mirror of
https://github.com/eigent-ai/eigent.git
synced 2026-05-28 17:55:55 +00:00
fix: raise RuntimeError if for_failure in _analyze_task
This commit is contained in:
parent
ce782762e8
commit
3d9dfafabf
2 changed files with 884 additions and 45 deletions
|
|
@ -57,6 +57,8 @@ from app.utils.telemetry.workforce_metrics import WorkforceMetricsCallback
|
|||
|
||||
logger = logging.getLogger("workforce")
|
||||
|
||||
_ANALYZE_TASK_MAX_RETRIES = 3
|
||||
|
||||
|
||||
class Workforce(BaseWorkforce):
|
||||
def __init__(
|
||||
|
|
@ -110,56 +112,64 @@ class Workforce(BaseWorkforce):
|
|||
for_failure: bool,
|
||||
error_message: str | None = None,
|
||||
) -> TaskAnalysisResult:
|
||||
"""Override to add debugging for None return issue.
|
||||
"""Override to retry when the base class returns None.
|
||||
|
||||
The base class should never return None, but we're seeing it happen.
|
||||
This override adds logging to help diagnose the root cause.
|
||||
The base class can return None when the LLM fails to produce
|
||||
valid structured output. We retry up to _ANALYZE_TASK_MAX_RETRIES
|
||||
times before falling back.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[WF-DEBUG] _analyze_task called: task_id={task.id}, "
|
||||
f"for_failure={for_failure}, "
|
||||
f"use_structured_output_handler={self.use_structured_output_handler}"
|
||||
last_exception: Exception | None = None
|
||||
|
||||
for attempt in range(1, _ANALYZE_TASK_MAX_RETRIES + 1):
|
||||
try:
|
||||
result = super()._analyze_task(
|
||||
task,
|
||||
for_failure=for_failure,
|
||||
error_message=error_message,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
logger.warning(
|
||||
f"[WF-RETRY] _analyze_task returned None "
|
||||
f"(attempt {attempt}/{_ANALYZE_TASK_MAX_RETRIES}), "
|
||||
f"task_id={task.id}, for_failure={for_failure}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
logger.warning(
|
||||
f"[WF-RETRY] _analyze_task raised "
|
||||
f"{type(e).__name__}: {e} "
|
||||
f"(attempt {attempt}/{_ANALYZE_TASK_MAX_RETRIES}), "
|
||||
f"task_id={task.id}, for_failure={for_failure}"
|
||||
)
|
||||
|
||||
# All retries exhausted
|
||||
logger.error(
|
||||
f"[WF-BUG] _analyze_task failed after "
|
||||
f"{_ANALYZE_TASK_MAX_RETRIES} retries, "
|
||||
f"task_id={task.id}, for_failure={for_failure}"
|
||||
)
|
||||
|
||||
try:
|
||||
result = super()._analyze_task(
|
||||
task, for_failure=for_failure, error_message=error_message
|
||||
)
|
||||
if for_failure:
|
||||
# Task already failed + analysis failed — raise to halt
|
||||
raise RuntimeError(
|
||||
f"_analyze_task returned None after "
|
||||
f"{_ANALYZE_TASK_MAX_RETRIES} retries for "
|
||||
f"failed task {task.id}"
|
||||
) from last_exception
|
||||
|
||||
logger.debug(
|
||||
f"[WF-DEBUG] _analyze_task result: type={type(result)}, "
|
||||
f"value={result}"
|
||||
)
|
||||
|
||||
if result is None:
|
||||
# This should never happen - log detailed info
|
||||
logger.error(
|
||||
f"[WF-BUG] _analyze_task returned None unexpectedly! "
|
||||
f"task_id={task.id}, for_failure={for_failure}, "
|
||||
f"use_structured_output_handler="
|
||||
f"{self.use_structured_output_handler}"
|
||||
)
|
||||
# Return fallback to prevent crash
|
||||
if for_failure:
|
||||
return TaskAnalysisResult(
|
||||
reasoning="BUG: _analyze_task returned None",
|
||||
recovery_strategy="retry",
|
||||
issues=[error_message] if error_message else [],
|
||||
)
|
||||
else:
|
||||
return TaskAnalysisResult(
|
||||
reasoning="BUG: _analyze_task returned None",
|
||||
quality_score=80,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[WF-DEBUG] _analyze_task exception: {type(e).__name__}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
# Quality evaluation failed — accept the task result as-is
|
||||
return TaskAnalysisResult(
|
||||
reasoning=(
|
||||
f"_analyze_task returned None after "
|
||||
f"{_ANALYZE_TASK_MAX_RETRIES} retries, "
|
||||
f"accepting task result"
|
||||
),
|
||||
quality_score=80,
|
||||
)
|
||||
|
||||
def eigent_make_sub_tasks(
|
||||
self,
|
||||
|
|
|
|||
829
backend/tests/app/utils/test_workforce.py
Normal file
829
backend/tests/app/utils/test_workforce.py
Normal file
|
|
@ -0,0 +1,829 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from camel.societies.workforce.utils import (
|
||||
TaskAnalysisResult,
|
||||
TaskAssignment,
|
||||
TaskAssignResult,
|
||||
)
|
||||
from camel.societies.workforce.workforce import (
|
||||
Workforce as BaseWorkforce,
|
||||
WorkforceState,
|
||||
)
|
||||
from camel.tasks import Task
|
||||
from camel.tasks.task import TaskState
|
||||
|
||||
from app.agent.listen_chat_agent import ListenChatAgent
|
||||
from app.exception.exception import UserException
|
||||
from app.service.task import (
|
||||
ActionAssignTaskData,
|
||||
ActionTaskStateData,
|
||||
create_task_lock,
|
||||
)
|
||||
from app.utils.workforce import _ANALYZE_TASK_MAX_RETRIES, Workforce
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_model_factory():
|
||||
"""Prevent ChatAgent from requiring real API keys during Workforce init."""
|
||||
mock_model = MagicMock()
|
||||
mock_model.model_type = MagicMock()
|
||||
mock_model.model_config_dict = {}
|
||||
with patch("camel.models.ModelFactory.create", return_value=mock_model):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_workforce_initialization():
|
||||
"""Test Workforce initialization with default settings."""
|
||||
api_task_id = "test_api_task_123"
|
||||
description = "Test workforce"
|
||||
|
||||
workforce = Workforce(api_task_id=api_task_id, description=description)
|
||||
|
||||
assert workforce.api_task_id == api_task_id
|
||||
assert workforce.description == description
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_eigent_make_sub_tasks_success():
|
||||
"""Test eigent_make_sub_tasks successfully decomposes task."""
|
||||
api_task_id = "test_api_task_123"
|
||||
workforce = Workforce(
|
||||
api_task_id=api_task_id, description="Test workforce"
|
||||
)
|
||||
|
||||
task = Task(content="Create a web application", id="main_task")
|
||||
|
||||
subtask1 = Task(content="Setup project structure", id="subtask_1")
|
||||
subtask2 = Task(content="Implement authentication", id="subtask_2")
|
||||
mock_subtasks = [subtask1, subtask2]
|
||||
|
||||
async def fake_decompose(*args, **kwargs):
|
||||
return mock_subtasks
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
workforce,
|
||||
"handle_decompose_append_task",
|
||||
side_effect=fake_decompose,
|
||||
),
|
||||
patch("app.utils.workforce.validate_task_content", return_value=True),
|
||||
):
|
||||
result = workforce.eigent_make_sub_tasks(task)
|
||||
|
||||
assert result == mock_subtasks
|
||||
assert workforce._task is task
|
||||
assert workforce._state == WorkforceState.RUNNING
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_eigent_make_sub_tasks_with_streaming_decomposition():
|
||||
"""Test eigent_make_sub_tasks with streaming decomposition result."""
|
||||
api_task_id = "test_api_task_123"
|
||||
workforce = Workforce(
|
||||
api_task_id=api_task_id, description="Test workforce"
|
||||
)
|
||||
|
||||
task = Task(content="Complex project task", id="main_task")
|
||||
|
||||
def mock_streaming_decomposition():
|
||||
yield [Task(content="Phase 1", id="phase_1")]
|
||||
yield [Task(content="Phase 2", id="phase_2")]
|
||||
yield [Task(content="Phase 3", id="phase_3")]
|
||||
|
||||
with (
|
||||
patch.object(workforce, "reset"),
|
||||
patch.object(workforce, "set_channel"),
|
||||
patch.object(
|
||||
workforce,
|
||||
"_decompose_task",
|
||||
return_value=mock_streaming_decomposition(),
|
||||
),
|
||||
patch("app.utils.workforce.validate_task_content", return_value=True),
|
||||
):
|
||||
result = workforce.eigent_make_sub_tasks(task)
|
||||
|
||||
assert len(result) == 3
|
||||
assert all(isinstance(subtask, Task) for subtask in result)
|
||||
assert result[0].content == "Phase 1"
|
||||
assert result[1].content == "Phase 2"
|
||||
assert result[2].content == "Phase 3"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_eigent_make_sub_tasks_invalid_content():
|
||||
"""Test eigent_make_sub_tasks with invalid task content."""
|
||||
api_task_id = "test_api_task_123"
|
||||
workforce = Workforce(
|
||||
api_task_id=api_task_id, description="Test workforce"
|
||||
)
|
||||
|
||||
task = Task(content="", id="invalid_task")
|
||||
|
||||
with patch(
|
||||
"app.utils.workforce.validate_task_content", return_value=False
|
||||
):
|
||||
with pytest.raises(UserException):
|
||||
workforce.eigent_make_sub_tasks(task)
|
||||
|
||||
assert task.state == TaskState.FAILED
|
||||
assert "Invalid or empty content" in task.result
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_eigent_start_success():
|
||||
"""Test eigent_start successfully starts workforce."""
|
||||
api_task_id = "test_api_task_123"
|
||||
workforce = Workforce(
|
||||
api_task_id=api_task_id, description="Test workforce"
|
||||
)
|
||||
|
||||
subtasks = [
|
||||
Task(content="Subtask 1", id="sub_1"),
|
||||
Task(content="Subtask 2", id="sub_2"),
|
||||
]
|
||||
|
||||
with (
|
||||
patch.object(workforce, "start", new_callable=AsyncMock) as mock_start,
|
||||
patch.object(workforce, "save_snapshot") as mock_save_snapshot,
|
||||
):
|
||||
await workforce.eigent_start(subtasks)
|
||||
|
||||
assert len(workforce._pending_tasks) >= len(subtasks)
|
||||
|
||||
mock_save_snapshot.assert_called_once_with(
|
||||
"Initial task decomposition"
|
||||
)
|
||||
mock_start.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_eigent_start_with_exception():
|
||||
"""Test eigent_start handles exceptions properly."""
|
||||
api_task_id = "test_api_task_123"
|
||||
workforce = Workforce(
|
||||
api_task_id=api_task_id, description="Test workforce"
|
||||
)
|
||||
|
||||
subtasks = [Task(content="Subtask 1", id="sub_1")]
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
workforce,
|
||||
"start",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Workforce start failed"),
|
||||
),
|
||||
patch.object(workforce, "save_snapshot"),
|
||||
):
|
||||
with pytest.raises(Exception, match="Workforce start failed"):
|
||||
await workforce.eigent_start(subtasks)
|
||||
|
||||
assert workforce._state == WorkforceState.STOPPED
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_assignee_with_notifications(mock_task_lock):
|
||||
"""Test _find_assignee sends proper task assignment notifications."""
|
||||
api_task_id = "test_api_task_123"
|
||||
workforce = Workforce(
|
||||
api_task_id=api_task_id, description="Test workforce"
|
||||
)
|
||||
|
||||
main_task = Task(content="Main task", id="main")
|
||||
subtask1 = Task(content="Subtask 1", id="sub_1")
|
||||
workforce._task = main_task
|
||||
|
||||
tasks = [main_task, subtask1]
|
||||
|
||||
assignments = [
|
||||
TaskAssignment(
|
||||
task_id="main", assignee_id="coordinator", dependencies=[]
|
||||
),
|
||||
TaskAssignment(
|
||||
task_id="sub_1", assignee_id="worker_node_1", dependencies=[]
|
||||
),
|
||||
]
|
||||
mock_assign_result = TaskAssignResult(assignments=assignments)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.utils.workforce.get_task_lock",
|
||||
return_value=mock_task_lock,
|
||||
),
|
||||
patch(
|
||||
"app.utils.workforce.get_camel_task",
|
||||
side_effect=lambda task_id, task_list: next(
|
||||
(t for t in task_list if t.id == task_id), None
|
||||
),
|
||||
),
|
||||
patch.object(
|
||||
workforce.__class__.__bases__[0],
|
||||
"_find_assignee",
|
||||
return_value=mock_assign_result,
|
||||
),
|
||||
patch.object(
|
||||
workforce,
|
||||
"_get_agent_id_from_node_id",
|
||||
return_value="agent_1",
|
||||
),
|
||||
):
|
||||
result = await workforce._find_assignee(tasks)
|
||||
|
||||
assert result is mock_assign_result
|
||||
# put_queue is called via asyncio.create_task, so check
|
||||
# add_background_task was called (tracks the async task)
|
||||
assert mock_task_lock.add_background_task.call_count >= 1
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_task_notification(mock_task_lock):
|
||||
"""Test _post_task sends running state notification."""
|
||||
api_task_id = "test_api_task_123"
|
||||
workforce = Workforce(
|
||||
api_task_id=api_task_id, description="Test workforce"
|
||||
)
|
||||
|
||||
main_task = Task(content="Main task", id="main")
|
||||
subtask = Task(content="Subtask", id="sub_1")
|
||||
workforce._task = main_task
|
||||
|
||||
assignee_id = "worker_node_1"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.utils.workforce.get_task_lock",
|
||||
return_value=mock_task_lock,
|
||||
),
|
||||
patch.object(
|
||||
workforce,
|
||||
"_get_agent_id_from_node_id",
|
||||
return_value="agent_1",
|
||||
),
|
||||
patch.object(
|
||||
workforce.__class__.__bases__[0],
|
||||
"_post_task",
|
||||
return_value=None,
|
||||
) as mock_super_post,
|
||||
):
|
||||
await workforce._post_task(subtask, assignee_id)
|
||||
|
||||
mock_task_lock.put_queue.assert_called_once()
|
||||
call_args = mock_task_lock.put_queue.call_args[0][0]
|
||||
assert isinstance(call_args, ActionAssignTaskData)
|
||||
assert call_args.data["assignee_id"] == "agent_1"
|
||||
assert call_args.data["task_id"] == "sub_1"
|
||||
assert call_args.data["state"] == "running"
|
||||
|
||||
mock_super_post.assert_called_once_with(subtask, assignee_id)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_single_agent_worker_success():
|
||||
"""Test add_single_agent_worker successfully adds worker."""
|
||||
api_task_id = "test_api_task_123"
|
||||
workforce = Workforce(
|
||||
api_task_id=api_task_id, description="Test workforce"
|
||||
)
|
||||
|
||||
mock_worker = MagicMock(spec=ListenChatAgent)
|
||||
mock_worker.agent_id = "test_worker_123"
|
||||
mock_worker.agent_name = "test_worker"
|
||||
description = "Test worker description"
|
||||
|
||||
with (
|
||||
patch.object(workforce, "_validate_agent_compatibility"),
|
||||
patch.object(workforce, "_attach_pause_event_to_agent"),
|
||||
patch.object(workforce, "_start_child_node_when_paused"),
|
||||
):
|
||||
result = workforce.add_single_agent_worker(
|
||||
description, mock_worker, pool_max_size=5
|
||||
)
|
||||
|
||||
assert result is workforce
|
||||
assert len(workforce._children) == 1
|
||||
|
||||
added_worker = workforce._children[0]
|
||||
assert hasattr(added_worker, "worker")
|
||||
assert added_worker.worker is mock_worker
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_single_agent_worker_while_running():
|
||||
"""Test add_single_agent_worker raises error when workforce is running."""
|
||||
api_task_id = "test_api_task_123"
|
||||
workforce = Workforce(
|
||||
api_task_id=api_task_id, description="Test workforce"
|
||||
)
|
||||
workforce._state = WorkforceState.RUNNING
|
||||
|
||||
mock_worker = MagicMock(spec=ListenChatAgent)
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError, match="Cannot add workers while workforce is running"
|
||||
):
|
||||
workforce.add_single_agent_worker("Test worker", mock_worker)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_completed_task(mock_task_lock):
|
||||
"""Test _handle_completed_task sends completion notification."""
|
||||
api_task_id = "test_api_task_123"
|
||||
workforce = Workforce(
|
||||
api_task_id=api_task_id, description="Test workforce"
|
||||
)
|
||||
|
||||
task = Task(content="Completed task", id="completed_123")
|
||||
task.state = TaskState.DONE
|
||||
task.result = "Task completed successfully"
|
||||
task.failure_count = 0
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.utils.workforce.get_task_lock",
|
||||
return_value=mock_task_lock,
|
||||
),
|
||||
patch.object(
|
||||
workforce.__class__.__bases__[0],
|
||||
"_handle_completed_task",
|
||||
return_value=None,
|
||||
) as mock_super_handle,
|
||||
):
|
||||
await workforce._handle_completed_task(task)
|
||||
|
||||
mock_task_lock.put_queue.assert_called_once()
|
||||
call_args = mock_task_lock.put_queue.call_args[0][0]
|
||||
assert isinstance(call_args, ActionTaskStateData)
|
||||
assert call_args.data["task_id"] == "completed_123"
|
||||
assert call_args.data["state"] == TaskState.DONE
|
||||
assert call_args.data["result"] == "Task completed successfully"
|
||||
|
||||
mock_super_handle.assert_called_once_with(task)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_failed_task(mock_task_lock):
|
||||
"""Test _handle_failed_task sends failure notification after max retries."""
|
||||
api_task_id = "test_api_task_123"
|
||||
workforce = Workforce(
|
||||
api_task_id=api_task_id, description="Test workforce"
|
||||
)
|
||||
# Clear callbacks to isolate notification test from metrics logging
|
||||
workforce._callbacks = []
|
||||
|
||||
task = Task(content="Failed task", id="failed_123")
|
||||
task.state = TaskState.FAILED
|
||||
# failure_count must be >= max_retries (default 3) for notification
|
||||
task.failure_count = 3
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.utils.workforce.get_task_lock",
|
||||
return_value=mock_task_lock,
|
||||
),
|
||||
patch.object(
|
||||
workforce.__class__.__bases__[0],
|
||||
"_handle_failed_task",
|
||||
return_value=True,
|
||||
) as mock_super_handle,
|
||||
):
|
||||
result = await workforce._handle_failed_task(task)
|
||||
|
||||
assert result is True
|
||||
|
||||
mock_task_lock.put_queue.assert_called_once()
|
||||
call_args = mock_task_lock.put_queue.call_args[0][0]
|
||||
assert isinstance(call_args, ActionTaskStateData)
|
||||
assert call_args.data["task_id"] == "failed_123"
|
||||
assert call_args.data["state"] == TaskState.FAILED
|
||||
assert call_args.data["failure_count"] == 3
|
||||
|
||||
mock_super_handle.assert_called_once_with(task)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_sends_end_notification(mock_task_lock):
|
||||
"""Test stop method sends end notification."""
|
||||
api_task_id = "test_api_task_123"
|
||||
workforce = Workforce(
|
||||
api_task_id=api_task_id, description="Test workforce"
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.utils.workforce.get_task_lock",
|
||||
return_value=mock_task_lock,
|
||||
),
|
||||
patch.object(
|
||||
workforce.__class__.__bases__[0], "stop"
|
||||
) as mock_super_stop,
|
||||
):
|
||||
workforce.stop()
|
||||
|
||||
mock_super_stop.assert_called_once()
|
||||
assert mock_task_lock.add_background_task.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_deletes_task_lock():
|
||||
"""Test cleanup method deletes task lock."""
|
||||
api_task_id = "test_api_task_123"
|
||||
workforce = Workforce(
|
||||
api_task_id=api_task_id, description="Test workforce"
|
||||
)
|
||||
|
||||
with patch("app.service.task.delete_task_lock") as mock_delete:
|
||||
await workforce.cleanup()
|
||||
|
||||
mock_delete.assert_called_once_with(api_task_id)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_handles_exception():
|
||||
"""Test cleanup handles exceptions gracefully."""
|
||||
api_task_id = "test_api_task_123"
|
||||
workforce = Workforce(
|
||||
api_task_id=api_task_id, description="Test workforce"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.service.task.delete_task_lock",
|
||||
side_effect=Exception("Delete failed"),
|
||||
):
|
||||
await workforce.cleanup()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_eigent_make_sub_tasks_with_none_task():
|
||||
"""Test eigent_make_sub_tasks with None task."""
|
||||
api_task_id = "error_test_123"
|
||||
workforce = Workforce(
|
||||
api_task_id=api_task_id, description="Error test workforce"
|
||||
)
|
||||
|
||||
with pytest.raises((AttributeError, TypeError)):
|
||||
workforce.eigent_make_sub_tasks(None)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_eigent_make_sub_tasks_with_malformed_task():
|
||||
"""Test eigent_make_sub_tasks with malformed task object."""
|
||||
api_task_id = "error_test_123"
|
||||
workforce = Workforce(
|
||||
api_task_id=api_task_id, description="Error test workforce"
|
||||
)
|
||||
|
||||
fake_task = MagicMock()
|
||||
fake_task.content = "Fake task content"
|
||||
fake_task.id = "fake_task"
|
||||
|
||||
with patch(
|
||||
"app.utils.workforce.validate_task_content", return_value=False
|
||||
):
|
||||
with pytest.raises(UserException):
|
||||
workforce.eigent_make_sub_tasks(fake_task)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_eigent_start_with_empty_subtasks():
|
||||
"""Test eigent_start with empty subtasks list."""
|
||||
api_task_id = "empty_test_123"
|
||||
workforce = Workforce(
|
||||
api_task_id=api_task_id, description="Empty test workforce"
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(workforce, "start", new_callable=AsyncMock),
|
||||
patch.object(workforce, "save_snapshot"),
|
||||
):
|
||||
await workforce.eigent_start([])
|
||||
|
||||
workforce.start.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_single_agent_worker_with_invalid_worker():
|
||||
"""Test add_single_agent_worker with invalid worker object."""
|
||||
api_task_id = "invalid_worker_test_123"
|
||||
workforce = Workforce(
|
||||
api_task_id=api_task_id,
|
||||
description="Invalid worker test workforce",
|
||||
)
|
||||
|
||||
invalid_worker = "not_an_agent"
|
||||
|
||||
with patch.object(
|
||||
workforce,
|
||||
"_validate_agent_compatibility",
|
||||
side_effect=ValueError("Invalid agent"),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Invalid agent"):
|
||||
workforce.add_single_agent_worker("Invalid worker", invalid_worker)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_assignee_with_get_task_lock_failure():
|
||||
"""Test _find_assignee when get_task_lock fails after parent method succeeds."""
|
||||
api_task_id = "lock_fail_test_123"
|
||||
workforce = Workforce(
|
||||
api_task_id=api_task_id, description="Lock fail test workforce"
|
||||
)
|
||||
|
||||
tasks = [Task(content="Test task", id="test")]
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
workforce.__class__.__bases__[0],
|
||||
"_find_assignee",
|
||||
return_value=TaskAssignResult(assignments=[]),
|
||||
) as mock_super_find,
|
||||
patch(
|
||||
"app.utils.workforce.get_task_lock",
|
||||
side_effect=Exception("Task lock not found"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Exception, match="Task lock not found"):
|
||||
await workforce._find_assignee(tasks)
|
||||
|
||||
mock_super_find.assert_called_once_with(tasks)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_with_nonexistent_task_lock():
|
||||
"""Test cleanup when task lock doesn't exist."""
|
||||
api_task_id = "nonexistent_lock_test_123"
|
||||
workforce = Workforce(
|
||||
api_task_id=api_task_id,
|
||||
description="Nonexistent lock test workforce",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.service.task.delete_task_lock",
|
||||
side_effect=Exception("Task lock not found"),
|
||||
):
|
||||
await workforce.cleanup()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workforce tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_workforce_inheritance():
|
||||
"""Test that Workforce properly inherits from BaseWorkforce."""
|
||||
api_task_id = "inheritance_test_123"
|
||||
workforce = Workforce(
|
||||
api_task_id=api_task_id, description="Inheritance test workforce"
|
||||
)
|
||||
|
||||
assert isinstance(workforce, BaseWorkforce)
|
||||
assert hasattr(workforce, "api_task_id")
|
||||
assert workforce.api_task_id == api_task_id
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_workforce_lifecycle():
|
||||
"""Test complete workforce lifecycle from creation to cleanup."""
|
||||
api_task_id = "integration_test_123"
|
||||
|
||||
create_task_lock(api_task_id)
|
||||
|
||||
workforce = Workforce(
|
||||
api_task_id=api_task_id, description="Integration test workforce"
|
||||
)
|
||||
|
||||
main_task = Task(content="Integration test task", id="main_task")
|
||||
|
||||
subtasks = [
|
||||
Task(content="Setup", id="setup_task"),
|
||||
Task(content="Implementation", id="impl_task"),
|
||||
Task(content="Testing", id="test_task"),
|
||||
]
|
||||
|
||||
with (
|
||||
patch("app.utils.workforce.validate_task_content", return_value=True),
|
||||
# Mock asyncio.run to return subtasks directly (can't nest event loops)
|
||||
patch("app.utils.workforce.asyncio.run", return_value=subtasks),
|
||||
patch.object(workforce, "start", new_callable=AsyncMock),
|
||||
):
|
||||
result_subtasks = workforce.eigent_make_sub_tasks(main_task)
|
||||
assert len(result_subtasks) == 3
|
||||
|
||||
await workforce.eigent_start(result_subtasks)
|
||||
|
||||
mock_worker = MagicMock(spec=ListenChatAgent)
|
||||
mock_worker.agent_id = "integration_worker_123"
|
||||
mock_worker.agent_name = "integration_worker"
|
||||
with (
|
||||
patch.object(workforce, "_validate_agent_compatibility"),
|
||||
patch.object(workforce, "_attach_pause_event_to_agent"),
|
||||
patch.object(workforce, "_start_child_node_when_paused"),
|
||||
):
|
||||
workforce.add_single_agent_worker(
|
||||
"Integration worker", mock_worker
|
||||
)
|
||||
|
||||
assert len(workforce._children) == 1
|
||||
|
||||
with patch.object(workforce.__class__.__bases__[0], "stop"):
|
||||
workforce.stop()
|
||||
|
||||
await workforce.cleanup()
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.asyncio
|
||||
async def test_workforce_with_multiple_workers():
|
||||
"""Test workforce with multiple workers."""
|
||||
api_task_id = "multi_worker_test_123"
|
||||
|
||||
create_task_lock(api_task_id)
|
||||
|
||||
workforce = Workforce(
|
||||
api_task_id=api_task_id, description="Multi-worker test workforce"
|
||||
)
|
||||
|
||||
workers = []
|
||||
for i in range(3):
|
||||
mock_worker = MagicMock(spec=ListenChatAgent)
|
||||
mock_worker.role_name = f"worker_{i}"
|
||||
mock_worker.agent_id = f"worker_{i}_123"
|
||||
mock_worker.agent_name = f"worker_{i}"
|
||||
workers.append(mock_worker)
|
||||
|
||||
with (
|
||||
patch.object(workforce, "_validate_agent_compatibility"),
|
||||
patch.object(workforce, "_attach_pause_event_to_agent"),
|
||||
patch.object(workforce, "_start_child_node_when_paused"),
|
||||
):
|
||||
for i, worker in enumerate(workers):
|
||||
workforce.add_single_agent_worker(f"Worker {i}", worker)
|
||||
|
||||
assert len(workforce._children) == 3
|
||||
|
||||
await workforce.cleanup()
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.asyncio
|
||||
async def test_workforce_task_state_tracking():
|
||||
"""Test workforce properly tracks task state changes."""
|
||||
api_task_id = "task_tracking_test_123"
|
||||
|
||||
create_task_lock(api_task_id)
|
||||
|
||||
workforce = Workforce(
|
||||
api_task_id=api_task_id, description="Task tracking test workforce"
|
||||
)
|
||||
|
||||
completed_task = Task(content="Completed task", id="completed")
|
||||
completed_task.state = TaskState.DONE
|
||||
completed_task.result = "Success"
|
||||
|
||||
with patch.object(
|
||||
workforce.__class__.__bases__[0],
|
||||
"_handle_completed_task",
|
||||
return_value=None,
|
||||
):
|
||||
await workforce._handle_completed_task(completed_task)
|
||||
|
||||
failed_task = Task(content="Failed task", id="failed")
|
||||
failed_task.state = TaskState.FAILED
|
||||
failed_task.failure_count = 1
|
||||
|
||||
with patch.object(
|
||||
workforce.__class__.__bases__[0],
|
||||
"_handle_failed_task",
|
||||
return_value=True,
|
||||
):
|
||||
result = await workforce._handle_failed_task(failed_task)
|
||||
assert result is True
|
||||
|
||||
await workforce.cleanup()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _analyze_task retry logic tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_analyze_task_returns_valid_result_on_first_attempt():
|
||||
"""Test _analyze_task returns immediately when base class succeeds."""
|
||||
workforce = Workforce(api_task_id="test_123", description="Test workforce")
|
||||
task = Task(content="Test task", id="task_1")
|
||||
|
||||
expected = TaskAnalysisResult(reasoning="looks good", quality_score=90)
|
||||
|
||||
with patch.object(
|
||||
workforce.__class__.__bases__[0],
|
||||
"_analyze_task",
|
||||
return_value=expected,
|
||||
) as mock_super:
|
||||
result = workforce._analyze_task(task, for_failure=False)
|
||||
|
||||
assert result is expected
|
||||
assert mock_super.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.parametrize(
|
||||
"side_effect, expected_calls",
|
||||
[
|
||||
([None, "success"], 2),
|
||||
([ValueError("err"), "success"], 2),
|
||||
([None, ValueError("err"), "success"], 3),
|
||||
],
|
||||
ids=["none_then_success", "exception_then_success", "mixed_then_success"],
|
||||
)
|
||||
def test_analyze_task_retries_and_succeeds(side_effect, expected_calls):
|
||||
"""Test _analyze_task retries on None/exception and succeeds."""
|
||||
workforce = Workforce(api_task_id="test_123", description="Test workforce")
|
||||
task = Task(content="Test task", id="task_1")
|
||||
|
||||
expected = TaskAnalysisResult(reasoning="recovered", quality_score=85)
|
||||
# Replace "success" sentinel with actual result
|
||||
resolved = [expected if s == "success" else s for s in side_effect]
|
||||
|
||||
with patch.object(
|
||||
workforce.__class__.__bases__[0],
|
||||
"_analyze_task",
|
||||
side_effect=resolved,
|
||||
) as mock_super:
|
||||
result = workforce._analyze_task(task, for_failure=False)
|
||||
|
||||
assert result is expected
|
||||
assert mock_super.call_count == expected_calls
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_analyze_task_quality_eval_fallback_after_exhausted_retries():
|
||||
"""Test _analyze_task returns quality_score=80 fallback when all retries
|
||||
fail for quality evaluation (for_failure=False)."""
|
||||
workforce = Workforce(api_task_id="test_123", description="Test workforce")
|
||||
task = Task(content="Test task", id="task_1")
|
||||
|
||||
with patch.object(
|
||||
workforce.__class__.__bases__[0],
|
||||
"_analyze_task",
|
||||
return_value=None,
|
||||
) as mock_super:
|
||||
result = workforce._analyze_task(task, for_failure=False)
|
||||
|
||||
assert mock_super.call_count == _ANALYZE_TASK_MAX_RETRIES
|
||||
assert result.quality_score == 80
|
||||
assert "retries" in result.reasoning
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_analyze_task_failure_raises_after_exhausted_retries():
|
||||
"""Test _analyze_task raises RuntimeError when all retries fail
|
||||
for failure analysis (for_failure=True)."""
|
||||
workforce = Workforce(api_task_id="test_123", description="Test workforce")
|
||||
task = Task(content="Test task", id="task_1")
|
||||
|
||||
with patch.object(
|
||||
workforce.__class__.__bases__[0],
|
||||
"_analyze_task",
|
||||
return_value=None,
|
||||
) as mock_super:
|
||||
with pytest.raises(RuntimeError, match="returned None after"):
|
||||
workforce._analyze_task(
|
||||
task, for_failure=True, error_message="task crashed"
|
||||
)
|
||||
|
||||
assert mock_super.call_count == _ANALYZE_TASK_MAX_RETRIES
|
||||
Loading…
Add table
Add a link
Reference in a new issue