eigent/backend/app/service/task.py
Tong Chen 09200a8cf6
Some checks are pending
CodeQL Advanced / Analyze (actions) (push) Waiting to run
CodeQL Advanced / Analyze (javascript-typescript) (push) Waiting to run
CodeQL Advanced / Analyze (python) (push) Waiting to run
Pre-commit / pre-commit (push) Waiting to run
Test / Run Python Tests (push) Waiting to run
fix: use current attachments for follow-up questions instead of first… (#1167)
Co-authored-by: Wendong-Fan <w3ndong.fan@gmail.com>
Co-authored-by: Wendong-Fan <133094783+Wendong-Fan@users.noreply.github.com>
2026-02-06 20:48:36 +08:00

676 lines
20 KiB
Python

# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
import asyncio
import logging
import weakref
from contextlib import contextmanager
from contextvars import ContextVar
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Literal
from camel.tasks import Task
from pydantic import BaseModel
from typing_extensions import TypedDict
from app.exception.exception import ProgramException
from app.model.chat import (
AgentModelConfig,
McpServers,
SupplementChat,
UpdateData,
)
from app.model.enums import Status
logger = logging.getLogger("task_service")
class Action(str, Enum):
improve = "improve" # user -> backend
update_task = "update_task" # user -> backend
task_state = "task_state" # backend -> user
new_task_state = "new_task_state" # backend -> user
# backend -> user (streaming decomposition)
decompose_progress = "decompose_progress"
decompose_text = "decompose_text" # backend -> user (raw streaming text)
start = "start" # user -> backend
create_agent = "create_agent" # backend -> user
activate_agent = "activate_agent" # backend -> user
deactivate_agent = "deactivate_agent" # backend -> user
assign_task = "assign_task" # backend -> user
activate_toolkit = "activate_toolkit" # backend -> user
deactivate_toolkit = "deactivate_toolkit" # backend -> user
write_file = "write_file" # backend -> user
ask = "ask" # backend -> user
notice = "notice" # backend -> user
search_mcp = "search_mcp" # backend -> user
install_mcp = "install_mcp" # backend -> user
terminal = "terminal" # backend -> user
end = "end" # backend -> user
stop = "stop" # user -> backend
supplement = "supplement" # user -> backend
pause = "pause" # user -> backend user take control
resume = "resume" # user -> backend user take control
new_agent = "new_agent" # user -> backend
budget_not_enough = "budget_not_enough" # backend -> user
add_task = "add_task" # user -> backend
remove_task = "remove_task" # user -> backend
skip_task = "skip_task" # user -> backend
timeout = "timeout" # backend -> user (task timeout error)
class ImprovePayload(BaseModel):
"""User input payload for an improve action."""
question: str
attaches: list[str] = []
class ActionImproveData(BaseModel):
action: Literal[Action.improve] = Action.improve
data: ImprovePayload
new_task_id: str | None = None
class ActionStartData(BaseModel):
action: Literal[Action.start] = Action.start
class ActionUpdateTaskData(BaseModel):
action: Literal[Action.update_task] = Action.update_task
data: UpdateData
class ActionTaskStateData(BaseModel):
action: Literal[Action.task_state] = Action.task_state
data: dict[
Literal["task_id", "content", "state", "result", "failure_count"],
str | int,
]
class ActionDecomposeProgressData(BaseModel):
action: Literal[Action.decompose_progress] = Action.decompose_progress
data: dict
class ActionDecomposeTextData(BaseModel):
action: Literal[Action.decompose_text] = Action.decompose_text
data: dict
class ActionNewTaskStateData(BaseModel):
action: Literal[Action.new_task_state] = Action.new_task_state
data: dict[
Literal["task_id", "content", "state", "result", "failure_count"],
str | int,
]
class ActionAskData(BaseModel):
action: Literal[Action.ask] = Action.ask
data: dict[Literal["question", "agent"], str]
class AgentDataDict(TypedDict):
agent_name: str
agent_id: str
tools: list[str]
class ActionCreateAgentData(BaseModel):
action: Literal[Action.create_agent] = Action.create_agent
data: AgentDataDict
class ActionActivateAgentData(BaseModel):
action: Literal[Action.activate_agent] = Action.activate_agent
data: dict[
Literal["agent_name", "process_task_id", "agent_id", "message"], str
]
class DataDict(TypedDict):
agent_name: str
agent_id: str
process_task_id: str
message: str
tokens: int
class ActionDeactivateAgentData(BaseModel):
action: Literal[Action.deactivate_agent] = Action.deactivate_agent
data: DataDict
class ActionAssignTaskData(BaseModel):
action: Literal[Action.assign_task] = Action.assign_task
data: dict[
Literal["assignee_id", "task_id", "content", "state", "failure_count"],
str | int,
]
class ActionActivateToolkitData(BaseModel):
action: Literal[Action.activate_toolkit] = Action.activate_toolkit
data: dict[
Literal[
"agent_name",
"toolkit_name",
"process_task_id",
"method_name",
"message",
],
str,
]
class ActionDeactivateToolkitData(BaseModel):
action: Literal[Action.deactivate_toolkit] = Action.deactivate_toolkit
data: dict[
Literal[
"agent_name",
"toolkit_name",
"process_task_id",
"method_name",
"message",
],
str,
]
class ActionWriteFileData(BaseModel):
action: Literal[Action.write_file] = Action.write_file
process_task_id: str
data: str
class ActionNoticeData(BaseModel):
action: Literal[Action.notice] = Action.notice
process_task_id: str
data: str
class ActionSearchMcpData(BaseModel):
action: Literal[Action.search_mcp] = Action.search_mcp
data: Any
class ActionInstallMcpData(BaseModel):
action: Literal[Action.install_mcp] = Action.install_mcp
data: McpServers
class ActionTerminalData(BaseModel):
action: Literal[Action.terminal] = Action.terminal
process_task_id: str
data: str
class ActionStopData(BaseModel):
action: Literal[Action.stop] = Action.stop
class ActionEndData(BaseModel):
action: Literal[Action.end] = Action.end
class ActionTimeoutData(BaseModel):
action: Literal[Action.timeout] = Action.timeout
data: dict[
Literal[
"message", "in_flight_tasks", "pending_tasks", "timeout_seconds"
],
str | int,
]
class ActionSupplementData(BaseModel):
action: Literal[Action.supplement] = Action.supplement
data: SupplementChat
class ActionTakeControl(BaseModel):
action: Literal[Action.pause, Action.resume]
class ActionNewAgent(BaseModel):
action: Literal[Action.new_agent] = Action.new_agent
name: str
description: str
tools: list[str]
mcp_tools: McpServers | None
custom_model_config: "AgentModelConfig | None" = None
class ActionBudgetNotEnough(BaseModel):
action: Literal[Action.budget_not_enough] = Action.budget_not_enough
class ActionAddTaskData(BaseModel):
action: Literal[Action.add_task] = Action.add_task
content: str
project_id: str | None = None
task_id: str | None = None
additional_info: dict | None = None
insert_position: int = -1
class ActionRemoveTaskData(BaseModel):
action: Literal[Action.remove_task] = Action.remove_task
task_id: str
project_id: str
class ActionSkipTaskData(BaseModel):
action: Literal[Action.skip_task] = Action.skip_task
project_id: str
ActionData = (
ActionImproveData
| ActionStartData
| ActionUpdateTaskData
| ActionTaskStateData
| ActionAskData
| ActionCreateAgentData
| ActionActivateAgentData
| ActionDeactivateAgentData
| ActionAssignTaskData
| ActionActivateToolkitData
| ActionDeactivateToolkitData
| ActionWriteFileData
| ActionNoticeData
| ActionSearchMcpData
| ActionInstallMcpData
| ActionTerminalData
| ActionStopData
| ActionEndData
| ActionTimeoutData
| ActionSupplementData
| ActionTakeControl
| ActionNewAgent
| ActionBudgetNotEnough
| ActionAddTaskData
| ActionRemoveTaskData
| ActionSkipTaskData
| ActionDecomposeTextData
| ActionDecomposeProgressData
)
class Agents(str, Enum):
task_agent = "task_agent"
coordinator_agent = "coordinator_agent"
new_worker_agent = "new_worker_agent"
developer_agent = "developer_agent"
browser_agent = "browser_agent"
document_agent = "document_agent"
multi_modal_agent = "multi_modal_agent"
social_media_agent = "social_media_agent"
mcp_agent = "mcp_agent"
class TaskLock:
id: str
status: Status = Status.confirming
active_agent: str = ""
mcp: list[str]
queue: asyncio.Queue[ActionData]
"""Queue monitoring for SSE response"""
human_input: dict[str, asyncio.Queue[str]]
"""After receiving user's reply, put the reply into the
corresponding agent's queue"""
created_at: datetime
last_accessed: datetime
background_tasks: set[asyncio.Task]
"""Track all background tasks for cleanup"""
registered_toolkits: list[Any]
"""Track toolkits for cleanup (e.g., TerminalToolkit venvs)"""
# Context management fields
conversation_history: list[dict[str, Any]]
"""Store conversation history for context"""
last_task_result: str
"""Store the last task execution result"""
question_agent: Any | None
"""Persistent question confirmation agent"""
summary_generated: bool
"""Track if summary has been generated for this project"""
current_task_id: str | None
"""Current task ID to be used in SSE responses"""
def __init__(
self, id: str, queue: asyncio.Queue, human_input: dict
) -> None:
self.id = id
self.queue = queue
self.human_input = human_input
self.created_at = datetime.now()
self.last_accessed = datetime.now()
self.background_tasks = set()
self.registered_toolkits = []
# Initialize context management fields
self.conversation_history = []
self.last_task_result = ""
self.last_task_summary = ""
self.question_agent = None
self.current_task_id = None
logger.info(
"Task lock initialized",
extra={"task_id": id, "created_at": self.created_at.isoformat()},
)
async def put_queue(self, data: ActionData):
self.last_accessed = datetime.now()
logger.debug(
"Adding item to task queue",
extra={"task_id": self.id, "action": data.action},
)
await self.queue.put(data)
async def get_queue(self):
self.last_accessed = datetime.now()
logger.debug(
"Getting item from task queue", extra={"task_id": self.id}
)
return await self.queue.get()
async def put_human_input(self, agent: str, data: Any = None):
logger.debug(
"Adding human input",
extra={
"task_id": self.id,
"agent": agent,
"has_data": data is not None,
},
)
await self.human_input[agent].put(data)
async def get_human_input(self, agent: str):
logger.debug(
"Getting human input", extra={"task_id": self.id, "agent": agent}
)
return await self.human_input[agent].get()
def add_human_input_listen(self, agent: str):
logger.debug(
"Adding human input listener",
extra={"task_id": self.id, "agent": agent},
)
self.human_input[agent] = asyncio.Queue(1)
def add_background_task(self, task: asyncio.Task) -> None:
r"""Add a task to track and clean up weak references"""
logger.debug(
"Adding background task",
extra={
"task_id": self.id,
"background_tasks_count": len(self.background_tasks),
},
)
self.background_tasks.add(task)
task.add_done_callback(lambda t: self.background_tasks.discard(t))
async def cleanup(self):
r"""Cancel all background tasks and clean up resources"""
logger.info(
"Starting task lock cleanup",
extra={
"task_id": self.id,
"background_tasks_count": len(self.background_tasks),
},
)
for task in list(self.background_tasks):
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
self.background_tasks.clear()
# Clean up registered toolkits (e.g., remove TerminalToolkit venvs)
for toolkit in self.registered_toolkits:
try:
if hasattr(toolkit, "cleanup"):
toolkit.cleanup()
logger.info(
"Toolkit cleanup completed",
extra={
"task_id": self.id,
"toolkit": type(toolkit).__name__,
},
)
except Exception as e:
logger.warning(
f"Failed to cleanup toolkit: {e}",
extra={
"task_id": self.id,
"toolkit": type(toolkit).__name__,
},
)
self.registered_toolkits.clear()
logger.info("Task lock cleanup completed", extra={"task_id": self.id})
def register_toolkit(self, toolkit: Any) -> None:
"""Register a toolkit for cleanup when task ends.
This is used to track toolkits that create resources (like venvs) that
should be cleaned up when the task is complete.
Note: Duplicate registrations of the same toolkit instance are ignored.
"""
# Prevent duplicate registration of the same toolkit instance
if any(t is toolkit for t in self.registered_toolkits):
logger.debug(
"Toolkit already registered, skipping",
extra={"task_id": self.id, "toolkit": type(toolkit).__name__},
)
return
self.registered_toolkits.append(toolkit)
logger.debug(
"Toolkit registered for cleanup",
extra={
"task_id": self.id,
"toolkit": type(toolkit).__name__,
"total_registered": len(self.registered_toolkits),
},
)
def add_conversation(self, role: str, content: str | dict):
"""Add a conversation entry to history"""
logger.debug(
"Adding conversation entry",
extra={
"task_id": self.id,
"role": role,
"content_length": len(str(content)),
},
)
self.conversation_history.append(
{
"role": role,
"content": content,
"timestamp": datetime.now().isoformat(),
}
)
def get_recent_context(self, max_entries: int = None) -> str:
"""Get recent conversation context as a formatted string"""
if not self.conversation_history:
return ""
context = "=== Recent Conversation ===\n"
if max_entries is None:
history_to_use = self.conversation_history
else:
history_to_use = self.conversation_history[-max_entries:]
for entry in history_to_use:
context += f"{entry['role']}: {entry['content']}\n"
return context
task_locks = dict[str, TaskLock]()
# Cleanup task for removing stale task locks
_cleanup_task: asyncio.Task | None = None
task_index: dict[str, weakref.ref[Task]] = {}
def get_task_lock(id: str) -> TaskLock:
if id not in task_locks:
logger.error("Task lock not found", extra={"task_id": id})
raise ProgramException("Task not found")
logger.debug("Task lock retrieved", extra={"task_id": id})
return task_locks[id]
def get_task_lock_if_exists(id: str) -> TaskLock | None:
"""Get task lock if it exists, otherwise return None"""
return task_locks.get(id)
def set_current_task_id(project_id: str, task_id: str) -> None:
"""Set the current task ID for a project's task lock"""
task_lock = get_task_lock(project_id)
task_lock.current_task_id = task_id
logger.info(
"Updated current task ID",
extra={"project_id": project_id, "task_id": task_id},
)
def create_task_lock(id: str) -> TaskLock:
if id in task_locks:
logger.warning(
"Attempting to create task lock that already exists",
extra={"task_id": id},
)
raise ProgramException("Task already exists")
logger.info("Creating new task lock", extra={"task_id": id})
task_locks[id] = TaskLock(id=id, queue=asyncio.Queue(), human_input={})
# Start cleanup task if not running
# global _cleanup_task
# if _cleanup_task is None or _cleanup_task.done():
# _cleanup_task = asyncio.create_task(_periodic_cleanup())
logger.info(
"Task lock created successfully",
extra={"task_id": id, "total_task_locks": len(task_locks)},
)
return task_locks[id]
def get_or_create_task_lock(id: str) -> TaskLock:
"""Get existing task lock or create a new one if it doesn't exist"""
if id in task_locks:
logger.debug("Using existing task lock", extra={"task_id": id})
return task_locks[id]
logger.info("Task lock not found, creating new one", extra={"task_id": id})
return create_task_lock(id)
async def delete_task_lock(id: str):
if id not in task_locks:
logger.warning(
"Attempting to delete non-existent task lock",
extra={"task_id": id},
)
raise ProgramException("Task not found")
# Clean up background tasks before deletion
task_lock = task_locks[id]
logger.info(
"Cleaning up task lock",
extra={
"task_id": id,
"background_tasks": len(task_lock.background_tasks),
},
)
await task_lock.cleanup()
del task_locks[id]
logger.info(
"Task lock deleted successfully",
extra={"task_id": id, "remaining_task_locks": len(task_locks)},
)
def get_camel_task(id: str, tasks: list[Task]) -> None | Task:
if id in task_index:
task_ref = task_index[id]
task = task_ref()
if task is not None:
return task
else:
# Weak reference died, remove from index
del task_index[id]
# Fallback to search and rebuild index
for item in tasks:
# Add to index
task_index[item.id] = weakref.ref(item)
if item.id == id:
return item
else:
task = get_camel_task(id, item.subtasks)
if task is not None:
return task
return None
async def _periodic_cleanup():
r"""Periodically clean up stale task locks"""
while True:
try:
await asyncio.sleep(300) # Run every 5 minutes
current_time = datetime.now()
stale_timeout = timedelta(
hours=4
) # Consider tasks stale after 4 hours
stale_ids = []
for task_id, task_lock in task_locks.items():
if current_time - task_lock.last_accessed > stale_timeout:
stale_ids.append(task_id)
for task_id in stale_ids:
logger.warning(f"Cleaning up stale task lock: {task_id}")
await delete_task_lock(task_id)
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error in periodic cleanup: {e}")
process_task = ContextVar[str]("id")
@contextmanager
def set_process_task(process_task_id: str):
origin = process_task.set(process_task_id)
try:
yield
finally:
process_task.reset(origin)