mirror of
https://github.com/eigent-ai/eigent.git
synced 2026-05-29 19:15:39 +00:00
enhance: fix task nonstopping error if model key is expired PR1017
This commit is contained in:
parent
2480ea0145
commit
e1e8d0ad8a
3 changed files with 3 additions and 241 deletions
|
|
@ -52,7 +52,6 @@ from app.service.task import (
|
|||
TaskLock,
|
||||
delete_task_lock,
|
||||
set_current_task_id,
|
||||
validate_model_before_task,
|
||||
)
|
||||
from app.utils.event_loop_utils import set_main_event_loop
|
||||
from app.utils.file_utils import get_working_directory
|
||||
|
|
@ -321,9 +320,8 @@ async def step_solve(options: Chat, request: Request, task_lock: TaskLock):
|
|||
"""Main task execution loop. Called when POST /chat endpoint
|
||||
is hit to start a new chat session.
|
||||
|
||||
Validates model configuration, processes task queue, manages
|
||||
workforce lifecycle, and streams responses back to the client
|
||||
via SSE.
|
||||
Processes task queue, manages workforce lifecycle, and streams
|
||||
responses back to the client via SSE.
|
||||
|
||||
Args:
|
||||
options (Chat): Chat configuration containing task details and
|
||||
|
|
@ -335,15 +333,6 @@ async def step_solve(options: Chat, request: Request, task_lock: TaskLock):
|
|||
Yields:
|
||||
SSE formatted responses for task progress, errors, and results
|
||||
"""
|
||||
# Validate model configuration before starting task
|
||||
is_valid, error_msg = await validate_model_before_task(options)
|
||||
if not is_valid:
|
||||
yield sse_json(
|
||||
"error", {"message": f"Model validation failed: {error_msg}"}
|
||||
)
|
||||
task_lock.status = Status.done
|
||||
return
|
||||
|
||||
start_event_loop = True
|
||||
|
||||
# Initialize task_lock attributes
|
||||
|
|
@ -2028,7 +2017,7 @@ Is this a complex task? (yes/no):"""
|
|||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in question_confirm: {e}")
|
||||
return True
|
||||
raise
|
||||
|
||||
|
||||
async def summary_task(agent: ListenChatAgent, task: Task) -> str:
|
||||
|
|
|
|||
|
|
@ -25,11 +25,9 @@ from camel.tasks import Task
|
|||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.component.model_validation import create_agent
|
||||
from app.exception.exception import ProgramException
|
||||
from app.model.chat import (
|
||||
AgentModelConfig,
|
||||
Chat,
|
||||
McpServers,
|
||||
SupplementChat,
|
||||
UpdateData,
|
||||
|
|
@ -676,52 +674,3 @@ def set_process_task(process_task_id: str):
|
|||
yield
|
||||
finally:
|
||||
process_task.reset(origin)
|
||||
|
||||
|
||||
async def validate_model_before_task(options: Chat) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Validate model configuration before starting a task.
|
||||
Makes a simple test request to ensure the API key and model are valid.
|
||||
|
||||
Args:
|
||||
options (Chat): Chat options containing model configuration.
|
||||
|
||||
Returns:
|
||||
(is_valid, error_message)
|
||||
- is_valid: True if validation passed
|
||||
- error_message: Raw error message if validation failed,
|
||||
None otherwise
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
f"Validating model configuration for task {options.task_id}"
|
||||
)
|
||||
|
||||
# Create test agent with same config as task will use
|
||||
agent = create_agent(
|
||||
model_platform=options.model_platform,
|
||||
model_type=options.model_type,
|
||||
api_key=options.api_key,
|
||||
url=options.api_url,
|
||||
model_config_dict=options.model_config,
|
||||
)
|
||||
|
||||
# Make a simple test call in executor to avoid blocking
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, lambda: agent.step("test"))
|
||||
|
||||
logger.info(f"Model validation passed for task {options.task_id}")
|
||||
return True, None
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(
|
||||
f"Model validation failed for task {options.task_id}: {error_msg}",
|
||||
extra={
|
||||
"project_id": options.project_id,
|
||||
"task_id": options.task_id,
|
||||
"error": error_msg,
|
||||
},
|
||||
exc_info=True,
|
||||
)
|
||||
return False, error_msg
|
||||
|
|
|
|||
|
|
@ -1,176 +0,0 @@
|
|||
"""
|
||||
Unit tests for validate_model_before_task function.
|
||||
|
||||
TODO: Rename this file to test_task.py after fixing errors
|
||||
in backend/tests/unit/service/test_task.py
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from camel.types import ModelPlatformType
|
||||
|
||||
from app.model.chat import Chat
|
||||
from app.service.task import validate_model_before_task
|
||||
|
||||
# Test data constants
|
||||
TEST_PROJECT_ID = "test_project"
|
||||
TEST_TASK_ID = "test_task_123"
|
||||
TEST_QUESTION = "Test question"
|
||||
TEST_EMAIL = "test@example.com"
|
||||
TEST_MODEL_PLATFORM = ModelPlatformType.OPENAI
|
||||
TEST_MODEL_TYPE = "gpt-4o"
|
||||
TEST_API_URL = "https://api.openai.com/v1"
|
||||
TEST_VALID_API_KEY = "sk-valid-key"
|
||||
TEST_INVALID_API_KEY = "sk-invalid-key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_model_success():
|
||||
"""Test successful model validation."""
|
||||
options = Chat(
|
||||
project_id=TEST_PROJECT_ID,
|
||||
task_id=TEST_TASK_ID,
|
||||
question=TEST_QUESTION,
|
||||
email=TEST_EMAIL,
|
||||
model_platform=TEST_MODEL_PLATFORM,
|
||||
model_type=TEST_MODEL_TYPE,
|
||||
api_key=TEST_VALID_API_KEY,
|
||||
api_url=TEST_API_URL,
|
||||
model_config={},
|
||||
)
|
||||
|
||||
# Mock the create_agent and agent.step
|
||||
mock_agent = Mock()
|
||||
mock_agent.step = Mock(return_value="test response")
|
||||
|
||||
with patch("app.service.task.create_agent", return_value=mock_agent):
|
||||
is_valid, error_msg = await validate_model_before_task(options)
|
||||
|
||||
assert is_valid is True
|
||||
assert error_msg is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_model_invalid_api_key():
|
||||
"""Test model validation with invalid API key."""
|
||||
options = Chat(
|
||||
project_id=TEST_PROJECT_ID,
|
||||
task_id=TEST_TASK_ID,
|
||||
question=TEST_QUESTION,
|
||||
email=TEST_EMAIL,
|
||||
model_platform=TEST_MODEL_PLATFORM,
|
||||
model_type=TEST_MODEL_TYPE,
|
||||
api_key=TEST_INVALID_API_KEY,
|
||||
api_url=TEST_API_URL,
|
||||
model_config={},
|
||||
)
|
||||
|
||||
# Mock the create_agent to raise authentication error
|
||||
with patch("app.service.task.create_agent") as mock_create:
|
||||
mock_agent = Mock()
|
||||
mock_agent.step = Mock(
|
||||
side_effect=Exception("Error code: 401 - Invalid API key")
|
||||
)
|
||||
mock_create.return_value = mock_agent
|
||||
|
||||
is_valid, error_msg = await validate_model_before_task(options)
|
||||
|
||||
assert is_valid is False
|
||||
assert error_msg is not None
|
||||
assert "401" in error_msg or "Invalid API key" in error_msg
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_model_network_error():
|
||||
"""Test model validation with network error."""
|
||||
options = Chat(
|
||||
project_id=TEST_PROJECT_ID,
|
||||
task_id=TEST_TASK_ID,
|
||||
question=TEST_QUESTION,
|
||||
email=TEST_EMAIL,
|
||||
model_platform=TEST_MODEL_PLATFORM,
|
||||
model_type=TEST_MODEL_TYPE,
|
||||
api_key=TEST_VALID_API_KEY,
|
||||
api_url="https://invalid-url.com",
|
||||
model_config={},
|
||||
)
|
||||
|
||||
# Mock the create_agent to raise network error
|
||||
with patch("app.service.task.create_agent") as mock_create:
|
||||
mock_agent = Mock()
|
||||
mock_agent.step = Mock(side_effect=Exception("Connection error"))
|
||||
mock_create.return_value = mock_agent
|
||||
|
||||
is_valid, error_msg = await validate_model_before_task(options)
|
||||
|
||||
assert is_valid is False
|
||||
assert error_msg is not None
|
||||
assert "Connection error" in error_msg
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_model_with_custom_config():
|
||||
"""Test model validation with custom model configuration."""
|
||||
custom_config = {"temperature": 0.7, "max_tokens": 1000}
|
||||
|
||||
options = Chat(
|
||||
project_id=TEST_PROJECT_ID,
|
||||
task_id=TEST_TASK_ID,
|
||||
question=TEST_QUESTION,
|
||||
email=TEST_EMAIL,
|
||||
model_platform=TEST_MODEL_PLATFORM,
|
||||
model_type=TEST_MODEL_TYPE,
|
||||
api_key=TEST_VALID_API_KEY,
|
||||
api_url=TEST_API_URL,
|
||||
model_config=custom_config,
|
||||
)
|
||||
|
||||
mock_agent = Mock()
|
||||
mock_agent.step = Mock(return_value="test response")
|
||||
|
||||
with patch(
|
||||
"app.service.task.create_agent", return_value=mock_agent
|
||||
) as mock_create:
|
||||
is_valid, error_msg = await validate_model_before_task(options)
|
||||
|
||||
# Verify create_agent was called
|
||||
mock_create.assert_called_once()
|
||||
call_args = mock_create.call_args
|
||||
assert call_args.kwargs["model_platform"] == options.model_platform
|
||||
assert call_args.kwargs["model_type"] == options.model_type
|
||||
assert call_args.kwargs["api_key"] == options.api_key
|
||||
assert call_args.kwargs["url"] == options.api_url
|
||||
|
||||
assert is_valid is True
|
||||
assert error_msg is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_model_rate_limit_error():
|
||||
"""Test model validation with rate limit error."""
|
||||
options = Chat(
|
||||
project_id=TEST_PROJECT_ID,
|
||||
task_id=TEST_TASK_ID,
|
||||
question=TEST_QUESTION,
|
||||
email=TEST_EMAIL,
|
||||
model_platform=TEST_MODEL_PLATFORM,
|
||||
model_type=TEST_MODEL_TYPE,
|
||||
api_key=TEST_VALID_API_KEY,
|
||||
api_url=TEST_API_URL,
|
||||
model_config={},
|
||||
)
|
||||
|
||||
# Mock the create_agent to raise rate limit error
|
||||
with patch("app.service.task.create_agent") as mock_create:
|
||||
mock_agent = Mock()
|
||||
mock_agent.step = Mock(
|
||||
side_effect=Exception("Error code: 429 - Rate limit exceeded")
|
||||
)
|
||||
mock_create.return_value = mock_agent
|
||||
|
||||
is_valid, error_msg = await validate_model_before_task(options)
|
||||
|
||||
assert is_valid is False
|
||||
assert error_msg is not None
|
||||
assert "429" in error_msg or "Rate limit" in error_msg
|
||||
Loading…
Add table
Add a link
Reference in a new issue