mirror of
https://github.com/eigent-ai/eigent.git
synced 2026-05-19 16:31:36 +00:00
refactor: format python backend code (#1132)
Co-authored-by: bytecii <bytecii@users.noreply.github.com> Co-authored-by: Wendong-Fan <133094783+Wendong-Fan@users.noreply.github.com>
This commit is contained in:
parent
649bdc6822
commit
42ce1d96be
45 changed files with 1639 additions and 826 deletions
42
.github/workflows/pre-commit.yml
vendored
42
.github/workflows/pre-commit.yml
vendored
|
|
@ -30,21 +30,31 @@ jobs:
|
|||
run: uv sync --group dev
|
||||
|
||||
- name: Run pre-commit
|
||||
run: >-
|
||||
uv run pre-commit run --files
|
||||
app/agent/README.md
|
||||
app/agent/__init__.py
|
||||
app/agent/factory/__init__.py
|
||||
app/agent/factory/social_media.py
|
||||
app/service/chat_service.py
|
||||
app/service/task.py
|
||||
app/utils/toolkit/google_calendar_toolkit.py
|
||||
app/utils/toolkit/google_gmail_mcp_toolkit.py
|
||||
app/utils/toolkit/linkedin_toolkit.py
|
||||
app/utils/toolkit/reddit_toolkit.py
|
||||
app/utils/toolkit/slack_toolkit.py
|
||||
app/utils/toolkit/twitter_toolkit.py
|
||||
app/utils/toolkit/whatsapp_toolkit.py
|
||||
tests/app/agent/factory/test_social_media.py
|
||||
run: |
|
||||
uv run pre-commit run --files \
|
||||
$(find \
|
||||
app/agent \
|
||||
app/controller \
|
||||
app/exception \
|
||||
app/middleware \
|
||||
app/model \
|
||||
app/service \
|
||||
tests/app \
|
||||
-type f ! -path '*__pycache__*') \
|
||||
app/__init__.py \
|
||||
app/router.py \
|
||||
app/component/__init__.py \
|
||||
app/component/pydantic/__init__.py \
|
||||
app/utils/listen/__init__.py \
|
||||
app/utils/server/__init__.py \
|
||||
app/utils/toolkit/__init__.py \
|
||||
app/utils/toolkit/google_calendar_toolkit.py \
|
||||
app/utils/toolkit/google_gmail_mcp_toolkit.py \
|
||||
app/utils/toolkit/linkedin_toolkit.py \
|
||||
app/utils/toolkit/reddit_toolkit.py \
|
||||
app/utils/toolkit/slack_toolkit.py \
|
||||
app/utils/toolkit/twitter_toolkit.py \
|
||||
app/utils/toolkit/whatsapp_toolkit.py \
|
||||
tests/conftest.py
|
||||
env:
|
||||
SKIP: no-commit-to-branch
|
||||
|
|
|
|||
39
.github/workflows/test.yml
vendored
Normal file
39
.github/workflows/test.yml
vendored
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
name: Test
|
||||
|
||||
'on':
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
pytest:
|
||||
name: Run Python Tests
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
- name: Set up Python
|
||||
run: uv python install 3.10
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
cd backend
|
||||
uv sync
|
||||
|
||||
- name: Run unit tests
|
||||
run: |
|
||||
cd backend
|
||||
uv run pytest tests/app -v
|
||||
|
|
@ -20,5 +20,9 @@ api = FastAPI(title="Eigent Multi-Agent System API")
|
|||
|
||||
# Add CORS middleware
|
||||
api.add_middleware(
|
||||
CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -19,21 +19,23 @@ import uuid
|
|||
from threading import Lock
|
||||
from typing import Any, Callable
|
||||
|
||||
from app.agent.listen_chat_agent import ListenChatAgent, logger
|
||||
from app.model.chat import AgentModelConfig, Chat
|
||||
from app.service.task import ActionCreateAgentData, Agents, get_task_lock
|
||||
from camel.messages import BaseMessage
|
||||
from camel.models import ModelFactory
|
||||
from camel.toolkits import FunctionTool, RegisteredAgentToolkit
|
||||
from camel.types import ModelPlatformType
|
||||
|
||||
from app.agent.listen_chat_agent import ListenChatAgent, logger
|
||||
from app.model.chat import AgentModelConfig, Chat
|
||||
from app.service.task import ActionCreateAgentData, Agents, get_task_lock
|
||||
|
||||
# Thread-safe reference to main event loop using contextvars
|
||||
# This ensures each request has its own event loop reference,
|
||||
# avoiding race conditions
|
||||
_main_event_loop_var: contextvars.ContextVar[asyncio.AbstractEventLoop
|
||||
| None] = contextvars.ContextVar(
|
||||
"_main_event_loop",
|
||||
default=None)
|
||||
default=None
|
||||
)
|
||||
|
||||
# Global fallback for main event loop reference
|
||||
# Used when contextvars don't propagate to worker threads
|
||||
|
|
@ -77,10 +79,12 @@ def _schedule_async_task(coro):
|
|||
asyncio.run_coroutine_threadsafe(coro, main_loop)
|
||||
else:
|
||||
# This should not happen in normal operation - log error and skip
|
||||
logging.error("No event loop available for async task "
|
||||
"scheduling, task skipped. Ensure "
|
||||
"set_main_event_loop() is called "
|
||||
"before parallel agent creation.")
|
||||
logging.error(
|
||||
"No event loop available for async task "
|
||||
"scheduling, task skipped. Ensure "
|
||||
"set_main_event_loop() is called "
|
||||
"before parallel agent creation."
|
||||
)
|
||||
|
||||
|
||||
def agent_model(
|
||||
|
|
@ -96,8 +100,10 @@ def agent_model(
|
|||
):
|
||||
task_lock = get_task_lock(options.project_id)
|
||||
agent_id = str(uuid.uuid4())
|
||||
logger.info(f"Creating agent: {agent_name} with id: {agent_id} "
|
||||
f"for project: {options.project_id}")
|
||||
logger.info(
|
||||
f"Creating agent: {agent_name} with id: {agent_id} "
|
||||
f"for project: {options.project_id}"
|
||||
)
|
||||
# Use thread-safe scheduling to support parallel agent creation
|
||||
_schedule_async_task(
|
||||
task_lock.put_queue(
|
||||
|
|
@ -106,7 +112,10 @@ def agent_model(
|
|||
"agent_name": agent_name,
|
||||
"agent_id": agent_id,
|
||||
"tools": tool_names or [],
|
||||
})))
|
||||
}
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Determine model configuration - use custom config if provided,
|
||||
# otherwise use task defaults
|
||||
|
|
@ -117,11 +126,14 @@ def agent_model(
|
|||
for attr in config_attrs:
|
||||
effective_config[attr] = getattr(custom_model_config, attr,
|
||||
None) or getattr(options, attr)
|
||||
extra_params = (custom_model_config.extra_params
|
||||
or options.extra_params or {})
|
||||
logger.info(f"Agent {agent_name} using custom model config: "
|
||||
f"platform={effective_config['model_platform']}, "
|
||||
f"type={effective_config['model_type']}")
|
||||
extra_params = (
|
||||
custom_model_config.extra_params or options.extra_params or {}
|
||||
)
|
||||
logger.info(
|
||||
f"Agent {agent_name} using custom model config: "
|
||||
f"platform={effective_config['model_platform']}, "
|
||||
f"type={effective_config['model_type']}"
|
||||
)
|
||||
else:
|
||||
for attr in config_attrs:
|
||||
effective_config[attr] = getattr(options, attr)
|
||||
|
|
@ -163,13 +175,14 @@ def agent_model(
|
|||
if agent_name == Agents.browser_agent:
|
||||
try:
|
||||
model_platform_enum = ModelPlatformType(
|
||||
effective_config["model_platform"].lower())
|
||||
effective_config["model_platform"].lower()
|
||||
)
|
||||
if model_platform_enum in {
|
||||
ModelPlatformType.OPENAI,
|
||||
ModelPlatformType.AZURE,
|
||||
ModelPlatformType.OPENAI_COMPATIBLE_MODEL,
|
||||
ModelPlatformType.LITELLM,
|
||||
ModelPlatformType.OPENROUTER,
|
||||
ModelPlatformType.OPENAI,
|
||||
ModelPlatformType.AZURE,
|
||||
ModelPlatformType.OPENAI_COMPATIBLE_MODEL,
|
||||
ModelPlatformType.LITELLM,
|
||||
ModelPlatformType.OPENROUTER,
|
||||
}:
|
||||
model_config["parallel_tool_calls"] = False
|
||||
except (ValueError, AttributeError):
|
||||
|
|
|
|||
|
|
@ -15,6 +15,9 @@
|
|||
import platform
|
||||
import uuid
|
||||
|
||||
from camel.messages import BaseMessage
|
||||
from camel.toolkits import ToolkitMessageIntegration
|
||||
|
||||
from app.agent.agent_model import agent_model
|
||||
from app.agent.listen_chat_agent import logger
|
||||
from app.agent.prompt import BROWSER_SYS_PROMPT
|
||||
|
|
@ -25,21 +28,23 @@ from app.service.task import Agents
|
|||
from app.utils.file_utils import get_working_directory
|
||||
from app.utils.toolkit.human_toolkit import HumanToolkit
|
||||
from app.utils.toolkit.hybrid_browser_toolkit import HybridBrowserToolkit
|
||||
|
||||
# TODO: Remove NoteTakingToolkit and use TerminalToolkit instead
|
||||
from app.utils.toolkit.note_taking_toolkit import NoteTakingToolkit
|
||||
from app.utils.toolkit.search_toolkit import SearchToolkit
|
||||
from app.utils.toolkit.terminal_toolkit import TerminalToolkit
|
||||
from camel.messages import BaseMessage
|
||||
from camel.toolkits import ToolkitMessageIntegration
|
||||
|
||||
|
||||
def browser_agent(options: Chat):
|
||||
working_directory = get_working_directory(options)
|
||||
logger.info(f"Creating browser agent for project: {options.project_id} "
|
||||
f"in directory: {working_directory}")
|
||||
logger.info(
|
||||
f"Creating browser agent for project: {options.project_id} "
|
||||
f"in directory: {working_directory}"
|
||||
)
|
||||
message_integration = ToolkitMessageIntegration(
|
||||
message_handler=HumanToolkit(
|
||||
options.project_id, Agents.browser_agent).send_message_to_user)
|
||||
message_handler=HumanToolkit(options.project_id, Agents.browser_agent
|
||||
).send_message_to_user
|
||||
)
|
||||
|
||||
web_toolkit_custom = HybridBrowserToolkit(
|
||||
options.project_id,
|
||||
|
|
@ -70,7 +75,8 @@ def browser_agent(options: Chat):
|
|||
# Save reference before registering for toolkits_to_register_agent
|
||||
web_toolkit_for_agent_registration = web_toolkit_custom
|
||||
web_toolkit_custom = message_integration.register_toolkits(
|
||||
web_toolkit_custom)
|
||||
web_toolkit_custom
|
||||
)
|
||||
|
||||
terminal_toolkit = TerminalToolkit(
|
||||
options.project_id,
|
||||
|
|
@ -80,11 +86,14 @@ def browser_agent(options: Chat):
|
|||
clone_current_env=True,
|
||||
)
|
||||
terminal_toolkit = message_integration.register_functions(
|
||||
[terminal_toolkit.shell_exec])
|
||||
[terminal_toolkit.shell_exec]
|
||||
)
|
||||
|
||||
note_toolkit = NoteTakingToolkit(options.project_id,
|
||||
Agents.browser_agent,
|
||||
working_directory=working_directory)
|
||||
note_toolkit = NoteTakingToolkit(
|
||||
options.project_id,
|
||||
Agents.browser_agent,
|
||||
working_directory=working_directory
|
||||
)
|
||||
note_toolkit = message_integration.register_toolkits(note_toolkit)
|
||||
|
||||
search_tools = SearchToolkit.get_can_use_tools(options.project_id)
|
||||
|
|
@ -94,8 +103,8 @@ def browser_agent(options: Chat):
|
|||
search_tools = []
|
||||
|
||||
tools = [
|
||||
*HumanToolkit.get_can_use_tools(options.project_id,
|
||||
Agents.browser_agent),
|
||||
*HumanToolkit.
|
||||
get_can_use_tools(options.project_id, Agents.browser_agent),
|
||||
*web_toolkit_custom.get_tools(),
|
||||
*terminal_toolkit,
|
||||
*note_toolkit.get_tools(),
|
||||
|
|
|
|||
|
|
@ -14,6 +14,9 @@
|
|||
|
||||
import platform
|
||||
|
||||
from camel.messages import BaseMessage
|
||||
from camel.toolkits import ToolkitMessageIntegration
|
||||
|
||||
from app.agent.agent_model import agent_model
|
||||
from app.agent.listen_chat_agent import logger
|
||||
from app.agent.prompt import DEVELOPER_SYS_PROMPT
|
||||
|
|
@ -22,22 +25,25 @@ from app.model.chat import Chat
|
|||
from app.service.task import Agents
|
||||
from app.utils.file_utils import get_working_directory
|
||||
from app.utils.toolkit.human_toolkit import HumanToolkit
|
||||
|
||||
# TODO: Remove NoteTakingToolkit and use TerminalToolkit instead
|
||||
from app.utils.toolkit.note_taking_toolkit import NoteTakingToolkit
|
||||
from app.utils.toolkit.screenshot_toolkit import ScreenshotToolkit
|
||||
from app.utils.toolkit.terminal_toolkit import TerminalToolkit
|
||||
from app.utils.toolkit.web_deploy_toolkit import WebDeployToolkit
|
||||
from camel.messages import BaseMessage
|
||||
from camel.toolkits import ToolkitMessageIntegration
|
||||
|
||||
|
||||
async def developer_agent(options: Chat):
|
||||
working_directory = get_working_directory(options)
|
||||
logger.info(f"Creating developer agent for project: {options.project_id} "
|
||||
f"in directory: {working_directory}")
|
||||
logger.info(
|
||||
f"Creating developer agent for project: {options.project_id} "
|
||||
f"in directory: {working_directory}"
|
||||
)
|
||||
message_integration = ToolkitMessageIntegration(
|
||||
message_handler=HumanToolkit(
|
||||
options.project_id, Agents.developer_agent).send_message_to_user)
|
||||
options.project_id, Agents.developer_agent
|
||||
).send_message_to_user
|
||||
)
|
||||
note_toolkit = NoteTakingToolkit(
|
||||
api_task_id=options.project_id,
|
||||
agent_name=Agents.developer_agent,
|
||||
|
|
@ -46,11 +52,14 @@ async def developer_agent(options: Chat):
|
|||
note_toolkit = message_integration.register_toolkits(note_toolkit)
|
||||
web_deploy_toolkit = WebDeployToolkit(api_task_id=options.project_id)
|
||||
web_deploy_toolkit = message_integration.register_toolkits(
|
||||
web_deploy_toolkit)
|
||||
screenshot_toolkit = ScreenshotToolkit(options.project_id,
|
||||
working_directory=working_directory)
|
||||
web_deploy_toolkit
|
||||
)
|
||||
screenshot_toolkit = ScreenshotToolkit(
|
||||
options.project_id, working_directory=working_directory
|
||||
)
|
||||
screenshot_toolkit = message_integration.register_toolkits(
|
||||
screenshot_toolkit)
|
||||
screenshot_toolkit
|
||||
)
|
||||
|
||||
terminal_toolkit = TerminalToolkit(
|
||||
options.project_id,
|
||||
|
|
@ -62,8 +71,8 @@ async def developer_agent(options: Chat):
|
|||
terminal_toolkit = message_integration.register_toolkits(terminal_toolkit)
|
||||
|
||||
tools = [
|
||||
*HumanToolkit.get_can_use_tools(options.project_id,
|
||||
Agents.developer_agent),
|
||||
*HumanToolkit.
|
||||
get_can_use_tools(options.project_id, Agents.developer_agent),
|
||||
*note_toolkit.get_tools(),
|
||||
*web_deploy_toolkit.get_tools(),
|
||||
*terminal_toolkit.get_tools(),
|
||||
|
|
|
|||
|
|
@ -13,6 +13,9 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
import platform
|
||||
|
||||
from camel.messages import BaseMessage
|
||||
from camel.toolkits import ToolkitMessageIntegration
|
||||
|
||||
from app.agent.agent_model import agent_model
|
||||
from app.agent.listen_chat_agent import logger
|
||||
from app.agent.prompt import DOCUMENT_SYS_PROMPT
|
||||
|
|
@ -25,36 +28,44 @@ from app.utils.toolkit.file_write_toolkit import FileToolkit
|
|||
from app.utils.toolkit.google_drive_mcp_toolkit import GoogleDriveMCPToolkit
|
||||
from app.utils.toolkit.human_toolkit import HumanToolkit
|
||||
from app.utils.toolkit.markitdown_toolkit import MarkItDownToolkit
|
||||
|
||||
# TODO: Remove NoteTakingToolkit and use TerminalToolkit instead
|
||||
from app.utils.toolkit.note_taking_toolkit import NoteTakingToolkit
|
||||
from app.utils.toolkit.pptx_toolkit import PPTXToolkit
|
||||
from app.utils.toolkit.terminal_toolkit import TerminalToolkit
|
||||
from camel.messages import BaseMessage
|
||||
from camel.toolkits import ToolkitMessageIntegration
|
||||
|
||||
|
||||
async def document_agent(options: Chat):
|
||||
working_directory = get_working_directory(options)
|
||||
logger.info(f"Creating document agent for project: {options.project_id} "
|
||||
f"in directory: {working_directory}")
|
||||
logger.info(
|
||||
f"Creating document agent for project: {options.project_id} "
|
||||
f"in directory: {working_directory}"
|
||||
)
|
||||
|
||||
message_integration = ToolkitMessageIntegration(
|
||||
message_handler=HumanToolkit(options.project_id,
|
||||
Agents.task_agent).send_message_to_user)
|
||||
file_write_toolkit = FileToolkit(options.project_id,
|
||||
working_directory=working_directory)
|
||||
pptx_toolkit = PPTXToolkit(options.project_id,
|
||||
working_directory=working_directory)
|
||||
message_handler=HumanToolkit(options.project_id, Agents.task_agent
|
||||
).send_message_to_user
|
||||
)
|
||||
file_write_toolkit = FileToolkit(
|
||||
options.project_id, working_directory=working_directory
|
||||
)
|
||||
pptx_toolkit = PPTXToolkit(
|
||||
options.project_id, working_directory=working_directory
|
||||
)
|
||||
pptx_toolkit = message_integration.register_toolkits(pptx_toolkit)
|
||||
mark_it_down_toolkit = MarkItDownToolkit(options.project_id)
|
||||
mark_it_down_toolkit = message_integration.register_toolkits(
|
||||
mark_it_down_toolkit)
|
||||
excel_toolkit = ExcelToolkit(options.project_id,
|
||||
working_directory=working_directory)
|
||||
mark_it_down_toolkit
|
||||
)
|
||||
excel_toolkit = ExcelToolkit(
|
||||
options.project_id, working_directory=working_directory
|
||||
)
|
||||
excel_toolkit = message_integration.register_toolkits(excel_toolkit)
|
||||
note_toolkit = NoteTakingToolkit(options.project_id,
|
||||
Agents.document_agent,
|
||||
working_directory=working_directory)
|
||||
note_toolkit = NoteTakingToolkit(
|
||||
options.project_id,
|
||||
Agents.document_agent,
|
||||
working_directory=working_directory
|
||||
)
|
||||
note_toolkit = message_integration.register_toolkits(note_toolkit)
|
||||
|
||||
terminal_toolkit = TerminalToolkit(
|
||||
|
|
@ -67,13 +78,14 @@ async def document_agent(options: Chat):
|
|||
terminal_toolkit = message_integration.register_toolkits(terminal_toolkit)
|
||||
|
||||
google_drive_tools = await GoogleDriveMCPToolkit.get_can_use_tools(
|
||||
options.project_id, options.get_bun_env())
|
||||
options.project_id, options.get_bun_env()
|
||||
)
|
||||
|
||||
tools = [
|
||||
*file_write_toolkit.get_tools(),
|
||||
*pptx_toolkit.get_tools(),
|
||||
*HumanToolkit.get_can_use_tools(options.project_id,
|
||||
Agents.document_agent),
|
||||
*HumanToolkit.
|
||||
get_can_use_tools(options.project_id, Agents.document_agent),
|
||||
*mark_it_down_toolkit.get_tools(),
|
||||
*excel_toolkit.get_tools(),
|
||||
*note_toolkit.get_tools(),
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from app.agent.listen_chat_agent import ListenChatAgent, logger
|
|||
from app.agent.prompt import MCP_SYS_PROMPT
|
||||
from app.agent.tools import get_mcp_tools
|
||||
from app.model.chat import Chat
|
||||
from app.service.task import Agents, ActionCreateAgentData, get_task_lock
|
||||
from app.service.task import ActionCreateAgentData, Agents, get_task_lock
|
||||
from app.utils.toolkit.mcp_search_toolkit import McpSearchToolkit
|
||||
|
||||
|
||||
|
|
@ -30,19 +30,22 @@ async def mcp_agent(options: Chat):
|
|||
f"with {len(options.installed_mcp['mcpServers'])} MCP servers"
|
||||
)
|
||||
tools = [
|
||||
# *HumanToolkit.get_can_use_tools(options.project_id, Agents.mcp_agent),
|
||||
*McpSearchToolkit(options.project_id).get_tools(),
|
||||
]
|
||||
if len(options.installed_mcp["mcpServers"]) > 0:
|
||||
try:
|
||||
mcp_tools = await get_mcp_tools(options.installed_mcp)
|
||||
logger.info(
|
||||
f"Retrieved {len(mcp_tools)} MCP tools for task {options.project_id}"
|
||||
f"Retrieved {len(mcp_tools)} MCP tools "
|
||||
f"for task {options.project_id}"
|
||||
)
|
||||
if mcp_tools:
|
||||
tool_names = [(tool.get_function_name() if hasattr(
|
||||
tool, "get_function_name") else str(tool))
|
||||
for tool in mcp_tools]
|
||||
tool_names = [
|
||||
(
|
||||
tool.get_function_name()
|
||||
if hasattr(tool, "get_function_name") else str(tool)
|
||||
) for tool in mcp_tools
|
||||
]
|
||||
logger.debug(f"MCP tools: {tool_names}")
|
||||
tools = [*tools, *mcp_tools]
|
||||
except Exception as e:
|
||||
|
|
@ -51,7 +54,8 @@ async def mcp_agent(options: Chat):
|
|||
task_lock = get_task_lock(options.project_id)
|
||||
agent_id = str(uuid.uuid4())
|
||||
logger.info(
|
||||
f"Creating MCP agent: {Agents.mcp_agent} with id: {agent_id} for task: {options.project_id}"
|
||||
f"Creating MCP agent: {Agents.mcp_agent} with id: "
|
||||
f"{agent_id} for task: {options.project_id}"
|
||||
)
|
||||
asyncio.create_task(
|
||||
task_lock.put_queue(
|
||||
|
|
@ -65,7 +69,10 @@ async def mcp_agent(options: Chat):
|
|||
key
|
||||
for key in options.installed_mcp["mcpServers"].keys()
|
||||
],
|
||||
})))
|
||||
}
|
||||
)
|
||||
)
|
||||
)
|
||||
return ListenChatAgent(
|
||||
options.project_id,
|
||||
Agents.mcp_agent,
|
||||
|
|
@ -75,9 +82,11 @@ async def mcp_agent(options: Chat):
|
|||
model_type=options.model_type,
|
||||
api_key=options.api_key,
|
||||
url=options.api_url,
|
||||
model_config_dict=({
|
||||
"user": str(options.project_id),
|
||||
} if options.is_cloud() else None),
|
||||
model_config_dict=(
|
||||
{
|
||||
"user": str(options.project_id),
|
||||
} if options.is_cloud() else None
|
||||
),
|
||||
timeout=600, # 10 minutes
|
||||
**{
|
||||
k: v
|
||||
|
|
|
|||
|
|
@ -13,6 +13,11 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
import platform
|
||||
|
||||
from camel.messages import BaseMessage
|
||||
from camel.models import OpenAIAudioModels
|
||||
from camel.toolkits import ToolkitMessageIntegration
|
||||
from camel.types import ModelPlatformType
|
||||
|
||||
from app.agent.agent_model import agent_model
|
||||
from app.agent.listen_chat_agent import logger
|
||||
from app.agent.prompt import MULTI_MODAL_SYS_PROMPT
|
||||
|
|
@ -23,34 +28,37 @@ from app.utils.file_utils import get_working_directory
|
|||
from app.utils.toolkit.audio_analysis_toolkit import AudioAnalysisToolkit
|
||||
from app.utils.toolkit.human_toolkit import HumanToolkit
|
||||
from app.utils.toolkit.image_analysis_toolkit import ImageAnalysisToolkit
|
||||
|
||||
# TODO: Remove NoteTakingToolkit and use TerminalToolkit instead
|
||||
from app.utils.toolkit.note_taking_toolkit import NoteTakingToolkit
|
||||
from app.utils.toolkit.openai_image_toolkit import OpenAIImageToolkit
|
||||
from app.utils.toolkit.search_toolkit import SearchToolkit
|
||||
from app.utils.toolkit.terminal_toolkit import TerminalToolkit
|
||||
from app.utils.toolkit.video_download_toolkit import VideoDownloaderToolkit
|
||||
from camel.messages import BaseMessage
|
||||
from camel.models import OpenAIAudioModels
|
||||
from camel.toolkits import ToolkitMessageIntegration
|
||||
from camel.types import ModelPlatformType
|
||||
|
||||
|
||||
def multi_modal_agent(options: Chat):
|
||||
working_directory = get_working_directory(options)
|
||||
logger.info(
|
||||
f"Creating multi-modal agent for project: {options.project_id} "
|
||||
f"in directory: {working_directory}")
|
||||
f"in directory: {working_directory}"
|
||||
)
|
||||
|
||||
message_integration = ToolkitMessageIntegration(
|
||||
message_handler=HumanToolkit(
|
||||
options.project_id, Agents.multi_modal_agent).send_message_to_user)
|
||||
options.project_id, Agents.multi_modal_agent
|
||||
).send_message_to_user
|
||||
)
|
||||
video_download_toolkit = VideoDownloaderToolkit(
|
||||
options.project_id, working_directory=working_directory)
|
||||
options.project_id, working_directory=working_directory
|
||||
)
|
||||
video_download_toolkit = message_integration.register_toolkits(
|
||||
video_download_toolkit)
|
||||
video_download_toolkit
|
||||
)
|
||||
image_analysis_toolkit = ImageAnalysisToolkit(options.project_id)
|
||||
image_analysis_toolkit = message_integration.register_toolkits(
|
||||
image_analysis_toolkit)
|
||||
image_analysis_toolkit
|
||||
)
|
||||
|
||||
terminal_toolkit = TerminalToolkit(
|
||||
options.project_id,
|
||||
|
|
@ -70,8 +78,8 @@ def multi_modal_agent(options: Chat):
|
|||
tools = [
|
||||
*video_download_toolkit.get_tools(),
|
||||
*image_analysis_toolkit.get_tools(),
|
||||
*HumanToolkit.get_can_use_tools(options.project_id,
|
||||
Agents.multi_modal_agent),
|
||||
*HumanToolkit.
|
||||
get_can_use_tools(options.project_id, Agents.multi_modal_agent),
|
||||
*terminal_toolkit.get_tools(),
|
||||
*note_toolkit.get_tools(),
|
||||
]
|
||||
|
|
@ -88,7 +96,8 @@ def multi_modal_agent(options: Chat):
|
|||
url=options.api_url,
|
||||
)
|
||||
open_ai_image_toolkit = message_integration.register_toolkits(
|
||||
open_ai_image_toolkit)
|
||||
open_ai_image_toolkit
|
||||
)
|
||||
tools = [
|
||||
*tools,
|
||||
*open_ai_image_toolkit.get_tools(),
|
||||
|
|
@ -109,7 +118,8 @@ def multi_modal_agent(options: Chat):
|
|||
),
|
||||
)
|
||||
audio_analysis_toolkit = message_integration.register_toolkits(
|
||||
audio_analysis_toolkit)
|
||||
audio_analysis_toolkit
|
||||
)
|
||||
tools.extend(audio_analysis_toolkit.get_tools())
|
||||
|
||||
system_message = MULTI_MODAL_SYS_PROMPT.format(
|
||||
|
|
|
|||
|
|
@ -18,15 +18,12 @@ import logging
|
|||
from threading import Event
|
||||
from typing import Any, Callable, Dict, List, Tuple
|
||||
|
||||
from app.service.task import (Action, ActionActivateAgentData,
|
||||
ActionActivateToolkitData, ActionBudgetNotEnough,
|
||||
ActionDeactivateAgentData,
|
||||
ActionDeactivateToolkitData, get_task_lock,
|
||||
set_process_task)
|
||||
from camel.agents import ChatAgent
|
||||
from camel.agents._types import ToolCallRequest
|
||||
from camel.agents.chat_agent import (AsyncStreamingChatAgentResponse,
|
||||
StreamingChatAgentResponse)
|
||||
from camel.agents.chat_agent import (
|
||||
AsyncStreamingChatAgentResponse,
|
||||
StreamingChatAgentResponse,
|
||||
)
|
||||
from camel.memories import AgentMemory
|
||||
from camel.messages import BaseMessage
|
||||
from camel.models import BaseModelBackend, ModelManager, ModelProcessingError
|
||||
|
|
@ -37,6 +34,17 @@ from camel.types import ModelPlatformType, ModelType
|
|||
from camel.types.agents import ToolCallingRecord
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.service.task import (
|
||||
Action,
|
||||
ActionActivateAgentData,
|
||||
ActionActivateToolkitData,
|
||||
ActionBudgetNotEnough,
|
||||
ActionDeactivateAgentData,
|
||||
ActionDeactivateToolkitData,
|
||||
get_task_lock,
|
||||
set_process_task,
|
||||
)
|
||||
|
||||
# Logger for agent tracking
|
||||
logger = logging.getLogger("agent")
|
||||
|
||||
|
|
@ -48,26 +56,30 @@ class ListenChatAgent(ChatAgent):
|
|||
api_task_id: str,
|
||||
agent_name: str,
|
||||
system_message: BaseMessage | str | None = None,
|
||||
model: (BaseModelBackend
|
||||
| ModelManager
|
||||
| Tuple[str, str]
|
||||
| str
|
||||
| ModelType
|
||||
| Tuple[ModelPlatformType, ModelType]
|
||||
| List[BaseModelBackend]
|
||||
| List[str]
|
||||
| List[ModelType]
|
||||
| List[Tuple[str, str]]
|
||||
| List[Tuple[ModelPlatformType, ModelType]]
|
||||
| None) = None,
|
||||
model: (
|
||||
BaseModelBackend
|
||||
| ModelManager
|
||||
| Tuple[str, str]
|
||||
| str
|
||||
| ModelType
|
||||
| Tuple[ModelPlatformType, ModelType]
|
||||
| List[BaseModelBackend]
|
||||
| List[str]
|
||||
| List[ModelType]
|
||||
| List[Tuple[str, str]]
|
||||
| List[Tuple[ModelPlatformType, ModelType]]
|
||||
| None
|
||||
) = None,
|
||||
memory: AgentMemory | None = None,
|
||||
message_window_size: int | None = None,
|
||||
token_limit: int | None = None,
|
||||
output_language: str | None = None,
|
||||
tools: List[FunctionTool | Callable[..., Any]] | None = None,
|
||||
toolkits_to_register_agent: List[RegisteredAgentToolkit] | None = None,
|
||||
external_tools: (List[FunctionTool | Callable[..., Any]
|
||||
| Dict[str, Any]] | None) = None,
|
||||
external_tools: (
|
||||
List[FunctionTool | Callable[..., Any]
|
||||
| Dict[str, Any]] | None
|
||||
) = None,
|
||||
response_terminators: List[ResponseTerminator] | None = None,
|
||||
scheduling_strategy: str = "round_robin",
|
||||
max_iteration: int | None = None,
|
||||
|
|
@ -117,23 +129,33 @@ class ListenChatAgent(ChatAgent):
|
|||
task_lock = get_task_lock(self.api_task_id)
|
||||
asyncio.create_task(
|
||||
task_lock.put_queue(
|
||||
ActionActivateAgentData(data={
|
||||
"agent_name":
|
||||
self.agent_name,
|
||||
"process_task_id":
|
||||
self.process_task_id,
|
||||
"agent_id":
|
||||
self.agent_id,
|
||||
"message": (input_message.content if isinstance(
|
||||
input_message, BaseMessage) else input_message),
|
||||
}, )))
|
||||
ActionActivateAgentData(
|
||||
data={
|
||||
"agent_name":
|
||||
self.agent_name,
|
||||
"process_task_id":
|
||||
self.process_task_id,
|
||||
"agent_id":
|
||||
self.agent_id,
|
||||
"message": (
|
||||
input_message.content
|
||||
if isinstance(input_message, BaseMessage) else
|
||||
input_message
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
)
|
||||
error_info = None
|
||||
message = None
|
||||
res = None
|
||||
msg = (input_message.content
|
||||
if isinstance(input_message, BaseMessage) else input_message)
|
||||
msg = (
|
||||
input_message.content
|
||||
if isinstance(input_message, BaseMessage) else input_message
|
||||
)
|
||||
logger.info(
|
||||
f"Agent {self.agent_name} starting step with message: {msg}")
|
||||
f"Agent {self.agent_name} starting step with message: {msg}"
|
||||
)
|
||||
try:
|
||||
res = super().step(input_message, response_format)
|
||||
except ModelProcessingError as e:
|
||||
|
|
@ -143,18 +165,21 @@ class ListenChatAgent(ChatAgent):
|
|||
message = "Budget has been exceeded"
|
||||
logger.warning(f"Agent {self.agent_name} budget exceeded")
|
||||
asyncio.create_task(
|
||||
task_lock.put_queue(ActionBudgetNotEnough()))
|
||||
task_lock.put_queue(ActionBudgetNotEnough())
|
||||
)
|
||||
else:
|
||||
message = str(e)
|
||||
logger.error(
|
||||
f"Agent {self.agent_name} model processing error: {e}")
|
||||
f"Agent {self.agent_name} model processing error: {e}"
|
||||
)
|
||||
total_tokens = 0
|
||||
except Exception as e:
|
||||
res = None
|
||||
error_info = e
|
||||
logger.error(
|
||||
f"Agent {self.agent_name} unexpected error in step: {e}",
|
||||
exc_info=True)
|
||||
exc_info=True
|
||||
)
|
||||
message = f"Error processing message: {e!s}"
|
||||
total_tokens = 0
|
||||
|
||||
|
|
@ -177,47 +202,55 @@ class ListenChatAgent(ChatAgent):
|
|||
total_tokens = 0
|
||||
if last_response:
|
||||
usage_info = last_response.info.get(
|
||||
"usage") or last_response.info.get(
|
||||
"token_usage") or {}
|
||||
"usage"
|
||||
) or last_response.info.get("token_usage") or {}
|
||||
if usage_info:
|
||||
total_tokens = usage_info.get(
|
||||
"total_tokens", 0)
|
||||
"total_tokens", 0
|
||||
)
|
||||
asyncio.create_task(
|
||||
task_lock.put_queue(
|
||||
ActionDeactivateAgentData(data={
|
||||
"agent_name":
|
||||
self.agent_name,
|
||||
"process_task_id":
|
||||
self.process_task_id,
|
||||
"agent_id":
|
||||
self.agent_id,
|
||||
"message":
|
||||
accumulated_content,
|
||||
"tokens":
|
||||
total_tokens,
|
||||
}, )))
|
||||
ActionDeactivateAgentData(
|
||||
data={
|
||||
"agent_name": self.agent_name,
|
||||
"process_task_id":
|
||||
self.process_task_id,
|
||||
"agent_id": self.agent_id,
|
||||
"message": accumulated_content,
|
||||
"tokens": total_tokens,
|
||||
},
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return StreamingChatAgentResponse(_stream_with_deactivate())
|
||||
|
||||
message = res.msg.content if res.msg else ""
|
||||
usage_info = res.info.get("usage") or res.info.get(
|
||||
"token_usage") or {}
|
||||
total_tokens = usage_info.get("total_tokens",
|
||||
0) if usage_info else 0
|
||||
logger.info(f"Agent {self.agent_name} completed step, "
|
||||
f"tokens used: {total_tokens}")
|
||||
usage_info = res.info.get("usage") or res.info.get("token_usage"
|
||||
) or {}
|
||||
total_tokens = usage_info.get(
|
||||
"total_tokens", 0
|
||||
) if usage_info else 0
|
||||
logger.info(
|
||||
f"Agent {self.agent_name} completed step, "
|
||||
f"tokens used: {total_tokens}"
|
||||
)
|
||||
|
||||
assert message is not None
|
||||
|
||||
asyncio.create_task(
|
||||
task_lock.put_queue(
|
||||
ActionDeactivateAgentData(data={
|
||||
"agent_name": self.agent_name,
|
||||
"process_task_id": self.process_task_id,
|
||||
"agent_id": self.agent_id,
|
||||
"message": message,
|
||||
"tokens": total_tokens,
|
||||
}, )))
|
||||
ActionDeactivateAgentData(
|
||||
data={
|
||||
"agent_name": self.agent_name,
|
||||
"process_task_id": self.process_task_id,
|
||||
"agent_id": self.agent_id,
|
||||
"message": message,
|
||||
"tokens": total_tokens,
|
||||
},
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if error_info is not None:
|
||||
raise error_info
|
||||
|
|
@ -240,18 +273,26 @@ class ListenChatAgent(ChatAgent):
|
|||
self.process_task_id,
|
||||
"agent_id":
|
||||
self.agent_id,
|
||||
"message": (input_message.content if isinstance(
|
||||
input_message, BaseMessage) else input_message),
|
||||
"message": (
|
||||
input_message.content
|
||||
if isinstance(input_message, BaseMessage) else
|
||||
input_message
|
||||
),
|
||||
},
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
error_info = None
|
||||
message = None
|
||||
res = None
|
||||
msg = (input_message.content
|
||||
if isinstance(input_message, BaseMessage) else input_message)
|
||||
logger.debug(f"Agent {self.agent_name} starting async step "
|
||||
f"with message: {msg}")
|
||||
msg = (
|
||||
input_message.content
|
||||
if isinstance(input_message, BaseMessage) else input_message
|
||||
)
|
||||
logger.debug(
|
||||
f"Agent {self.agent_name} starting async step "
|
||||
f"with message: {msg}"
|
||||
)
|
||||
|
||||
try:
|
||||
res = await super().astep(input_message, response_format)
|
||||
|
|
@ -264,46 +305,56 @@ class ListenChatAgent(ChatAgent):
|
|||
message = "Budget has been exceeded"
|
||||
logger.warning(f"Agent {self.agent_name} budget exceeded")
|
||||
asyncio.create_task(
|
||||
task_lock.put_queue(ActionBudgetNotEnough()))
|
||||
task_lock.put_queue(ActionBudgetNotEnough())
|
||||
)
|
||||
else:
|
||||
message = str(e)
|
||||
logger.error(
|
||||
f"Agent {self.agent_name} model processing error: {e}")
|
||||
f"Agent {self.agent_name} model processing error: {e}"
|
||||
)
|
||||
total_tokens = 0
|
||||
except Exception as e:
|
||||
res = None
|
||||
error_info = e
|
||||
logger.error(
|
||||
f"Agent {self.agent_name} unexpected error in async step: {e}",
|
||||
exc_info=True)
|
||||
exc_info=True
|
||||
)
|
||||
message = f"Error processing message: {e!s}"
|
||||
total_tokens = 0
|
||||
|
||||
if res is not None:
|
||||
message = res.msg.content if res.msg else ""
|
||||
total_tokens = res.info["usage"]["total_tokens"]
|
||||
logger.info(f"Agent {self.agent_name} completed step, "
|
||||
f"tokens used: {total_tokens}")
|
||||
logger.info(
|
||||
f"Agent {self.agent_name} completed step, "
|
||||
f"tokens used: {total_tokens}"
|
||||
)
|
||||
|
||||
assert message is not None
|
||||
|
||||
asyncio.create_task(
|
||||
task_lock.put_queue(
|
||||
ActionDeactivateAgentData(data={
|
||||
"agent_name": self.agent_name,
|
||||
"process_task_id": self.process_task_id,
|
||||
"agent_id": self.agent_id,
|
||||
"message": message,
|
||||
"tokens": total_tokens,
|
||||
}, )))
|
||||
ActionDeactivateAgentData(
|
||||
data={
|
||||
"agent_name": self.agent_name,
|
||||
"process_task_id": self.process_task_id,
|
||||
"agent_id": self.agent_id,
|
||||
"message": message,
|
||||
"tokens": total_tokens,
|
||||
},
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if error_info is not None:
|
||||
raise error_info
|
||||
assert res is not None
|
||||
return res
|
||||
|
||||
def _execute_tool(self,
|
||||
tool_call_request: ToolCallRequest) -> ToolCallingRecord:
|
||||
def _execute_tool(
|
||||
self, tool_call_request: ToolCallRequest
|
||||
) -> ToolCallingRecord:
|
||||
func_name = tool_call_request.tool_name
|
||||
tool: FunctionTool = self._internal_tools[func_name]
|
||||
# Route async functions to async execution
|
||||
|
|
@ -327,28 +378,31 @@ class ListenChatAgent(ChatAgent):
|
|||
task_lock = get_task_lock(self.api_task_id)
|
||||
|
||||
toolkit_name = getattr(tool, "_toolkit_name") if hasattr(
|
||||
tool, "_toolkit_name") else "mcp_toolkit"
|
||||
logger.debug(f"Agent {self.agent_name} executing tool: "
|
||||
f"{func_name} from toolkit: {toolkit_name} "
|
||||
f"with args: {json.dumps(args, ensure_ascii=False)}")
|
||||
tool, "_toolkit_name"
|
||||
) else "mcp_toolkit"
|
||||
logger.debug(
|
||||
f"Agent {self.agent_name} executing tool: "
|
||||
f"{func_name} from toolkit: {toolkit_name} "
|
||||
f"with args: {json.dumps(args, ensure_ascii=False)}"
|
||||
)
|
||||
|
||||
# Only send activate event if tool is
|
||||
# NOT wrapped by @listen_toolkit
|
||||
if not has_listen_decorator:
|
||||
asyncio.create_task(
|
||||
task_lock.put_queue(
|
||||
ActionActivateToolkitData(data={
|
||||
"agent_name":
|
||||
self.agent_name,
|
||||
"process_task_id":
|
||||
self.process_task_id,
|
||||
"toolkit_name":
|
||||
toolkit_name,
|
||||
"method_name":
|
||||
func_name,
|
||||
"message":
|
||||
json.dumps(args, ensure_ascii=False),
|
||||
}, )))
|
||||
ActionActivateToolkitData(
|
||||
data={
|
||||
"agent_name": self.agent_name,
|
||||
"process_task_id": self.process_task_id,
|
||||
"toolkit_name": toolkit_name,
|
||||
"method_name": func_name,
|
||||
"message":
|
||||
json.dumps(args, ensure_ascii=False),
|
||||
},
|
||||
)
|
||||
)
|
||||
)
|
||||
# Set process_task context for all tool executions
|
||||
with set_process_task(self.process_task_id):
|
||||
raw_result = tool(**args)
|
||||
|
|
@ -357,7 +411,8 @@ class ListenChatAgent(ChatAgent):
|
|||
self._secure_result_store[tool_call_id] = raw_result
|
||||
result = (
|
||||
"[The tool has been executed successfully, but the output"
|
||||
" from the tool is masked. You can move forward]")
|
||||
" from the tool is masked. You can move forward]"
|
||||
)
|
||||
mask_flag = True
|
||||
else:
|
||||
result = raw_result
|
||||
|
|
@ -369,30 +424,39 @@ class ListenChatAgent(ChatAgent):
|
|||
result_str = repr(result)
|
||||
MAX_RESULT_LENGTH = 500
|
||||
if len(result_str) > MAX_RESULT_LENGTH:
|
||||
result_msg = (result_str[:MAX_RESULT_LENGTH] +
|
||||
(f"... (truncated, total length: "
|
||||
f"{len(result_str)} chars)"))
|
||||
result_msg = (
|
||||
result_str[:MAX_RESULT_LENGTH] + (
|
||||
f"... (truncated, total length: "
|
||||
f"{len(result_str)} chars)"
|
||||
)
|
||||
)
|
||||
else:
|
||||
result_msg = result_str
|
||||
|
||||
# Only send deactivate event if tool is NOT wrapped by @listen_toolkit
|
||||
# Only send deactivate event if tool is
|
||||
# NOT wrapped by @listen_toolkit
|
||||
if not has_listen_decorator:
|
||||
asyncio.create_task(
|
||||
task_lock.put_queue(
|
||||
ActionDeactivateToolkitData(data={
|
||||
"agent_name": self.agent_name,
|
||||
"process_task_id": self.process_task_id,
|
||||
"toolkit_name": toolkit_name,
|
||||
"method_name": func_name,
|
||||
"message": result_msg,
|
||||
}, )))
|
||||
ActionDeactivateToolkitData(
|
||||
data={
|
||||
"agent_name": self.agent_name,
|
||||
"process_task_id": self.process_task_id,
|
||||
"toolkit_name": toolkit_name,
|
||||
"method_name": func_name,
|
||||
"message": result_msg,
|
||||
},
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
# Capture the error message to prevent framework crash
|
||||
error_msg = f"Error executing tool '{func_name}': {e!s}"
|
||||
result = f"Tool execution failed: {error_msg}"
|
||||
mask_flag = False
|
||||
logger.error(f"Tool execution failed for {func_name}: {e}",
|
||||
exc_info=True)
|
||||
logger.error(
|
||||
f"Tool execution failed for {func_name}: {e}", exc_info=True
|
||||
)
|
||||
|
||||
return self._record_tool_calling(
|
||||
func_name,
|
||||
|
|
@ -404,7 +468,8 @@ class ListenChatAgent(ChatAgent):
|
|||
)
|
||||
|
||||
async def _aexecute_tool(
|
||||
self, tool_call_request: ToolCallRequest) -> ToolCallingRecord:
|
||||
self, tool_call_request: ToolCallRequest
|
||||
) -> ToolCallingRecord:
|
||||
func_name = tool_call_request.tool_name
|
||||
tool: FunctionTool = self._internal_tools[func_name]
|
||||
|
||||
|
|
@ -420,21 +485,24 @@ class ListenChatAgent(ChatAgent):
|
|||
if hasattr(tool, "_toolkit_name"):
|
||||
toolkit_name = tool._toolkit_name
|
||||
|
||||
# Method 2: For MCP tools, check if func has __self__ (the toolkit instance)
|
||||
if not toolkit_name and hasattr(tool, "func") and hasattr(
|
||||
tool.func, "__self__"):
|
||||
# Method 2: For MCP tools, check if func
|
||||
# has __self__ (the toolkit instance)
|
||||
if not toolkit_name and hasattr(tool, "func"
|
||||
) and hasattr(tool.func, "__self__"):
|
||||
toolkit_instance = tool.func.__self__
|
||||
if hasattr(toolkit_instance, "toolkit_name") and callable(
|
||||
toolkit_instance.toolkit_name):
|
||||
toolkit_instance.toolkit_name
|
||||
):
|
||||
toolkit_name = toolkit_instance.toolkit_name()
|
||||
|
||||
# Method 3: Check if tool.func is a bound method with toolkit
|
||||
if not toolkit_name and hasattr(tool, "func"):
|
||||
if hasattr(tool.func, "func") and hasattr(tool.func.func,
|
||||
"__self__"):
|
||||
if hasattr(tool.func,
|
||||
"func") and hasattr(tool.func.func, "__self__"):
|
||||
toolkit_instance = tool.func.func.__self__
|
||||
if hasattr(toolkit_instance, "toolkit_name") and callable(
|
||||
toolkit_instance.toolkit_name):
|
||||
toolkit_instance.toolkit_name
|
||||
):
|
||||
toolkit_name = toolkit_instance.toolkit_name()
|
||||
|
||||
# Default fallback
|
||||
|
|
@ -442,7 +510,11 @@ class ListenChatAgent(ChatAgent):
|
|||
toolkit_name = "mcp_toolkit"
|
||||
|
||||
logger.info(
|
||||
f"Agent {self.agent_name} executing async tool: {func_name} from toolkit: {toolkit_name} with args: {json.dumps(args, ensure_ascii=False)}"
|
||||
f"Agent {self.agent_name} executing"
|
||||
f" async tool: {func_name}"
|
||||
f" from toolkit: {toolkit_name}"
|
||||
" with args:"
|
||||
f" {json.dumps(args, ensure_ascii=False)}"
|
||||
)
|
||||
|
||||
# Check if tool is wrapped by @listen_toolkit decorator
|
||||
|
|
@ -452,25 +524,24 @@ class ListenChatAgent(ChatAgent):
|
|||
# Only send activate event if tool is NOT wrapped by @listen_toolkit
|
||||
if not has_listen_decorator:
|
||||
await task_lock.put_queue(
|
||||
ActionActivateToolkitData(data={
|
||||
"agent_name":
|
||||
self.agent_name,
|
||||
"process_task_id":
|
||||
self.process_task_id,
|
||||
"toolkit_name":
|
||||
toolkit_name,
|
||||
"method_name":
|
||||
func_name,
|
||||
"message":
|
||||
json.dumps(args, ensure_ascii=False),
|
||||
}, ))
|
||||
ActionActivateToolkitData(
|
||||
data={
|
||||
"agent_name": self.agent_name,
|
||||
"process_task_id": self.process_task_id,
|
||||
"toolkit_name": toolkit_name,
|
||||
"method_name": func_name,
|
||||
"message": json.dumps(args, ensure_ascii=False),
|
||||
},
|
||||
)
|
||||
)
|
||||
try:
|
||||
# Set process_task context for all tool executions
|
||||
with set_process_task(self.process_task_id):
|
||||
# Try different invocation paths in order of preference
|
||||
if hasattr(tool, "func") and hasattr(tool.func, "async_call"):
|
||||
# Case: FunctionTool wrapping an MCP tool
|
||||
# Check if the wrapped tool is sync to avoid run_in_executor
|
||||
# Check if the wrapped tool is sync
|
||||
# to avoid run_in_executor
|
||||
if hasattr(tool, "is_async") and not tool.is_async:
|
||||
# Sync tool: call directly to preserve ContextVar
|
||||
result = tool(**args)
|
||||
|
|
@ -482,11 +553,14 @@ class ListenChatAgent(ChatAgent):
|
|||
|
||||
elif hasattr(tool, "async_call") and callable(tool.async_call):
|
||||
# Case: tool itself has async_call
|
||||
# Check if this is a sync tool to avoid run_in_executor (which breaks ContextVar)
|
||||
# Check if this is a sync tool to avoid
|
||||
# run_in_executor (breaks ContextVar)
|
||||
if hasattr(tool, "is_async") and not tool.is_async:
|
||||
# Sync tool: call directly to preserve ContextVar in same thread
|
||||
# Sync tool: call directly to preserve
|
||||
# ContextVar in same thread
|
||||
result = tool(**args)
|
||||
# Handle case where synchronous call returns a coroutine
|
||||
# Handle case where synchronous call
|
||||
# returns a coroutine
|
||||
if asyncio.iscoroutine(result):
|
||||
result = await result
|
||||
else:
|
||||
|
|
@ -494,7 +568,8 @@ class ListenChatAgent(ChatAgent):
|
|||
result = await tool.async_call(**args)
|
||||
|
||||
elif hasattr(tool, "func") and asyncio.iscoroutinefunction(
|
||||
tool.func):
|
||||
tool.func
|
||||
):
|
||||
# Case: tool wraps a direct async function
|
||||
result = await tool.func(**args)
|
||||
|
||||
|
|
@ -503,7 +578,8 @@ class ListenChatAgent(ChatAgent):
|
|||
result = await tool(**args)
|
||||
|
||||
else:
|
||||
# Fallback: synchronous call - call directly in current context
|
||||
# Fallback: synchronous call - call
|
||||
# directly in current context
|
||||
# DO NOT use run_in_executor to preserve ContextVar
|
||||
result = tool(**args)
|
||||
# Handle case where synchronous call returns a coroutine
|
||||
|
|
@ -514,8 +590,10 @@ class ListenChatAgent(ChatAgent):
|
|||
# Capture the error message to prevent framework crash
|
||||
error_msg = f"Error executing async tool '{func_name}': {e!s}"
|
||||
result = {"error": error_msg}
|
||||
logger.error(f"Async tool execution failed for {func_name}: {e}",
|
||||
exc_info=True)
|
||||
logger.error(
|
||||
f"Async tool execution failed for {func_name}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# Prepare result message with truncation
|
||||
if isinstance(result, str):
|
||||
|
|
@ -524,20 +602,27 @@ class ListenChatAgent(ChatAgent):
|
|||
result_str = repr(result)
|
||||
MAX_RESULT_LENGTH = 500
|
||||
if len(result_str) > MAX_RESULT_LENGTH:
|
||||
result_msg = result_str[:MAX_RESULT_LENGTH] + f"... (truncated, total length: {len(result_str)} chars)"
|
||||
result_msg = (
|
||||
result_str[:MAX_RESULT_LENGTH] + "... (truncated, total"
|
||||
f" length: {len(result_str)}"
|
||||
" chars)"
|
||||
)
|
||||
else:
|
||||
result_msg = result_str
|
||||
|
||||
# Only send deactivate event if tool is NOT wrapped by @listen_toolkit
|
||||
if not has_listen_decorator:
|
||||
await task_lock.put_queue(
|
||||
ActionDeactivateToolkitData(data={
|
||||
"agent_name": self.agent_name,
|
||||
"process_task_id": self.process_task_id,
|
||||
"toolkit_name": toolkit_name,
|
||||
"method_name": func_name,
|
||||
"message": result_msg,
|
||||
}, ))
|
||||
ActionDeactivateToolkitData(
|
||||
data={
|
||||
"agent_name": self.agent_name,
|
||||
"process_task_id": self.process_task_id,
|
||||
"toolkit_name": toolkit_name,
|
||||
"method_name": func_name,
|
||||
"message": result_msg,
|
||||
},
|
||||
)
|
||||
)
|
||||
return self._record_tool_calling(
|
||||
func_name,
|
||||
args,
|
||||
|
|
@ -560,8 +645,9 @@ class ListenChatAgent(ChatAgent):
|
|||
model=self.model_backend.models, # Pass the existing model_backend
|
||||
memory=None, # clone memory later
|
||||
message_window_size=getattr(self.memory, "window_size", None),
|
||||
token_limit=getattr(self.memory.get_context_creator(),
|
||||
"token_limit", None),
|
||||
token_limit=getattr(
|
||||
self.memory.get_context_creator(), "token_limit", None
|
||||
),
|
||||
output_language=self._output_language,
|
||||
tools=cloned_tools,
|
||||
toolkits_to_register_agent=toolkits_to_register,
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# flake8: noqa
|
||||
|
||||
SOCIAL_MEDIA_SYS_PROMPT = """\
|
||||
You are a Social Media Management Assistant with comprehensive capabilities
|
||||
|
|
@ -595,3 +596,17 @@ Your approach depends on available search tools:
|
|||
- When encountering verification challenges (like login, CAPTCHAs or
|
||||
robot checks), you MUST request help using the human toolkit.
|
||||
</web_search_workflow>"""
|
||||
|
||||
DEFAULT_SUMMARY_PROMPT = (
|
||||
"After completing the task, please generate"
|
||||
" a summary of the entire task completion. "
|
||||
"The summary must be enclosed in"
|
||||
" <summary></summary> tags and include:\n"
|
||||
"1. A confirmation of task completion,"
|
||||
" referencing the original goal.\n"
|
||||
"2. A high-level overview of the work"
|
||||
" performed and the final outcome.\n"
|
||||
"3. A bulleted list of key results"
|
||||
" or accomplishments.\n"
|
||||
"Adopt a confident and professional tone."
|
||||
)
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ import asyncio
|
|||
import logging
|
||||
import os
|
||||
|
||||
from camel.toolkits import MCPToolkit
|
||||
|
||||
from app.component.environment import env
|
||||
from app.model.chat import McpServers
|
||||
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
|
||||
|
|
@ -41,14 +43,15 @@ from app.utils.toolkit.twitter_toolkit import TwitterToolkit
|
|||
from app.utils.toolkit.video_analysis_toolkit import VideoAnalysisToolkit
|
||||
from app.utils.toolkit.video_download_toolkit import VideoDownloaderToolkit
|
||||
from app.utils.toolkit.whatsapp_toolkit import WhatsAppToolkit
|
||||
from camel.toolkits import MCPToolkit
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_toolkits(tools: list[str], agent_name: str, api_task_id: str):
|
||||
logger.info(f"Getting toolkits for agent: {agent_name}, "
|
||||
f"task: {api_task_id}, tools: {tools}")
|
||||
logger.info(
|
||||
f"Getting toolkits for agent: {agent_name}, "
|
||||
f"task: {api_task_id}, tools: {tools}"
|
||||
)
|
||||
toolkits = {
|
||||
"audio_analysis_toolkit": AudioAnalysisToolkit,
|
||||
"openai_image_toolkit": OpenAIImageToolkit,
|
||||
|
|
@ -80,7 +83,8 @@ async def get_toolkits(tools: list[str], agent_name: str, api_task_id: str):
|
|||
toolkit.agent_name = agent_name
|
||||
toolkit_tools = toolkit.get_can_use_tools(api_task_id)
|
||||
toolkit_tools = await toolkit_tools if asyncio.iscoroutine(
|
||||
toolkit_tools) else toolkit_tools
|
||||
toolkit_tools
|
||||
) else toolkit_tools
|
||||
res.extend(toolkit_tools)
|
||||
else:
|
||||
logger.warning(f"Toolkit {item} not found for agent {agent_name}")
|
||||
|
|
@ -89,7 +93,8 @@ async def get_toolkits(tools: list[str], agent_name: str, api_task_id: str):
|
|||
|
||||
async def get_mcp_tools(mcp_server: McpServers):
|
||||
logger.info(
|
||||
f"Getting MCP tools for {len(mcp_server['mcpServers'])} servers")
|
||||
f"Getting MCP tools for {len(mcp_server['mcpServers'])} servers"
|
||||
)
|
||||
if len(mcp_server["mcpServers"]) == 0:
|
||||
return []
|
||||
|
||||
|
|
@ -102,19 +107,26 @@ async def get_mcp_tools(mcp_server: McpServers):
|
|||
# Set global auth directory to persist authentication across tasks
|
||||
if "MCP_REMOTE_CONFIG_DIR" not in server_config["env"]:
|
||||
server_config["env"]["MCP_REMOTE_CONFIG_DIR"] = env(
|
||||
"MCP_REMOTE_CONFIG_DIR", os.path.expanduser("~/.mcp-auth"))
|
||||
"MCP_REMOTE_CONFIG_DIR", os.path.expanduser("~/.mcp-auth")
|
||||
)
|
||||
|
||||
mcp_toolkit = None
|
||||
try:
|
||||
mcp_toolkit = MCPToolkit(config_dict=config_dict, timeout=180)
|
||||
await mcp_toolkit.connect()
|
||||
|
||||
logger.info(f"Successfully connected to MCP toolkit with "
|
||||
f"{len(mcp_server['mcpServers'])} servers")
|
||||
logger.info(
|
||||
f"Successfully connected to MCP toolkit with "
|
||||
f"{len(mcp_server['mcpServers'])} servers"
|
||||
)
|
||||
tools = mcp_toolkit.get_tools()
|
||||
if tools:
|
||||
tool_names = [(tool.get_function_name() if hasattr(
|
||||
tool, "get_function_name") else str(tool)) for tool in tools]
|
||||
tool_names = [
|
||||
(
|
||||
tool.get_function_name()
|
||||
if hasattr(tool, "get_function_name") else str(tool)
|
||||
) for tool in tools
|
||||
]
|
||||
logging.debug(f"MCP tool names: {tool_names}")
|
||||
return tools
|
||||
except asyncio.CancelledError:
|
||||
|
|
|
|||
0
backend/app/component/__init__.py
Normal file
0
backend/app/component/__init__.py
Normal file
|
|
@ -1,29 +0,0 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
import inspect
|
||||
|
||||
|
||||
def dump_class(obj, max_val_len=1000):
|
||||
cls = obj.__class__
|
||||
print(f"Class: {cls.__name__}")
|
||||
print("Attributes:")
|
||||
for name, val in vars(obj).items():
|
||||
val_str = repr(val)
|
||||
if len(val_str) > max_val_len:
|
||||
val_str = val_str[:max_val_len] + "... [truncated]"
|
||||
print(f" {name} = {val_str}")
|
||||
# print("Methods:")
|
||||
# for name, method in inspect.getmembers(cls, predicate=inspect.isfunction):
|
||||
# print(f" {name}()")
|
||||
0
backend/app/component/pydantic/__init__.py
Normal file
0
backend/app/component/pydantic/__init__.py
Normal file
|
|
@ -11,4 +11,3 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
|
|
|
|||
|
|
@ -13,37 +13,44 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, HTTPException, Request, Response
|
||||
from fastapi import APIRouter, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
import logging
|
||||
|
||||
from app.component import code
|
||||
from app.component.environment import sanitize_env_path, set_user_env_path
|
||||
from app.exception.exception import UserException
|
||||
from app.model.chat import Chat, HumanReply, McpServers, Status, SupplementChat, AddTaskRequest, sse_json
|
||||
from app.model.chat import (
|
||||
AddTaskRequest,
|
||||
Chat,
|
||||
HumanReply,
|
||||
McpServers,
|
||||
Status,
|
||||
SupplementChat,
|
||||
sse_json,
|
||||
)
|
||||
from app.service.chat_service import step_solve
|
||||
from app.service.task import (
|
||||
Action,
|
||||
ActionAddTaskData,
|
||||
ActionImproveData,
|
||||
ActionInstallMcpData,
|
||||
ActionStopData,
|
||||
ActionSupplementData,
|
||||
ActionAddTaskData,
|
||||
ActionRemoveTaskData,
|
||||
ActionSkipTaskData,
|
||||
ActionStopData,
|
||||
ActionSupplementData,
|
||||
delete_task_lock,
|
||||
get_or_create_task_lock,
|
||||
get_task_lock,
|
||||
set_current_task_id,
|
||||
delete_task_lock,
|
||||
task_locks,
|
||||
)
|
||||
from app.component.environment import set_user_env_path, sanitize_env_path
|
||||
from app.utils.workforce import Workforce
|
||||
from camel.tasks.task import Task
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -69,23 +76,37 @@ async def _cleanup_task_lock_safe(task_lock, reason: str) -> bool:
|
|||
|
||||
# Check if task_lock still exists before attempting cleanup
|
||||
if task_lock.id not in task_locks:
|
||||
chat_logger.debug(f"[{reason}] Task lock already removed, skipping cleanup",
|
||||
extra={"task_id": task_lock.id})
|
||||
chat_logger.debug(
|
||||
f"[{reason}] Task lock already removed, skipping cleanup",
|
||||
extra={"task_id": task_lock.id}
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
task_lock.status = Status.done
|
||||
await delete_task_lock(task_lock.id)
|
||||
chat_logger.info(f"[{reason}] Task lock cleanup completed",
|
||||
extra={"task_id": task_lock.id})
|
||||
chat_logger.info(
|
||||
f"[{reason}] Task lock cleanup completed",
|
||||
extra={"task_id": task_lock.id}
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
chat_logger.error(f"[{reason}] Failed to cleanup task lock",
|
||||
extra={"task_id": task_lock.id, "error": str(e)}, exc_info=True)
|
||||
chat_logger.error(
|
||||
f"[{reason}] Failed to cleanup task lock",
|
||||
extra={
|
||||
"task_id": task_lock.id,
|
||||
"error": str(e)
|
||||
},
|
||||
exc_info=True
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def timeout_stream_wrapper(stream_generator, timeout_seconds: int = SSE_TIMEOUT_SECONDS, task_lock=None):
|
||||
async def timeout_stream_wrapper(
|
||||
stream_generator,
|
||||
timeout_seconds: int = SSE_TIMEOUT_SECONDS,
|
||||
task_lock=None
|
||||
):
|
||||
"""Wraps a stream generator with timeout handling.
|
||||
|
||||
Closes the SSE connection if no data is received within the timeout period.
|
||||
|
|
@ -101,26 +122,45 @@ async def timeout_stream_wrapper(stream_generator, timeout_seconds: int = SSE_TI
|
|||
remaining_timeout = timeout_seconds - elapsed
|
||||
|
||||
try:
|
||||
data = await asyncio.wait_for(generator.__anext__(), timeout=remaining_timeout)
|
||||
data = await asyncio.wait_for(
|
||||
generator.__anext__(), timeout=remaining_timeout
|
||||
)
|
||||
last_data_time = time.time()
|
||||
yield data
|
||||
except asyncio.TimeoutError:
|
||||
chat_logger.warning("SSE timeout: No data received, closing connection",
|
||||
extra={"timeout_seconds": timeout_seconds})
|
||||
yield sse_json("error", {"message": f"Connection timeout: No data received for {timeout_seconds // 60} minutes"})
|
||||
cleanup_triggered = await _cleanup_task_lock_safe(task_lock, "TIMEOUT")
|
||||
chat_logger.warning(
|
||||
"SSE timeout: No data received, closing connection",
|
||||
extra={"timeout_seconds": timeout_seconds}
|
||||
)
|
||||
timeout_min = timeout_seconds // 60
|
||||
yield sse_json(
|
||||
"error", {
|
||||
"message":
|
||||
"Connection timeout: No data"
|
||||
f" received for {timeout_min}"
|
||||
" minutes"
|
||||
}
|
||||
)
|
||||
cleanup_triggered = await _cleanup_task_lock_safe(
|
||||
task_lock, "TIMEOUT"
|
||||
)
|
||||
break
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
except asyncio.CancelledError:
|
||||
chat_logger.info("[STREAM-CANCELLED] Stream cancelled, triggering cleanup")
|
||||
chat_logger.info(
|
||||
"[STREAM-CANCELLED] Stream cancelled, triggering cleanup"
|
||||
)
|
||||
if not cleanup_triggered:
|
||||
await _cleanup_task_lock_safe(task_lock, "CANCELLED")
|
||||
raise
|
||||
except Exception as e:
|
||||
chat_logger.error("[STREAM-ERROR] Unexpected error in stream wrapper",
|
||||
extra={"error": str(e)}, exc_info=True)
|
||||
chat_logger.error(
|
||||
"[STREAM-ERROR] Unexpected error in stream wrapper",
|
||||
extra={"error": str(e)},
|
||||
exc_info=True
|
||||
)
|
||||
if not cleanup_triggered:
|
||||
await _cleanup_task_lock_safe(task_lock, "ERROR")
|
||||
raise
|
||||
|
|
@ -130,7 +170,11 @@ async def timeout_stream_wrapper(stream_generator, timeout_seconds: int = SSE_TI
|
|||
async def post(data: Chat, request: Request):
|
||||
chat_logger.info(
|
||||
"Starting new chat session",
|
||||
extra={"project_id": data.project_id, "task_id": data.task_id, "user": data.email}
|
||||
extra={
|
||||
"project_id": data.project_id,
|
||||
"task_id": data.task_id,
|
||||
"user": data.email
|
||||
}
|
||||
)
|
||||
|
||||
task_lock = get_or_create_task_lock(data.project_id)
|
||||
|
|
@ -145,7 +189,8 @@ async def post(data: Chat, request: Request):
|
|||
os.environ["file_save_path"] = data.file_save_path()
|
||||
os.environ["browser_port"] = str(data.browser_port)
|
||||
os.environ["OPENAI_API_KEY"] = data.api_key
|
||||
os.environ["OPENAI_API_BASE_URL"] = data.api_url or "https://api.openai.com/v1"
|
||||
os.environ["OPENAI_API_BASE_URL"
|
||||
] = data.api_url or "https://api.openai.com/v1"
|
||||
os.environ["CAMEL_MODEL_LOG_ENABLED"] = "true"
|
||||
|
||||
# Set user-specific search engine configuration if provided
|
||||
|
|
@ -153,16 +198,17 @@ async def post(data: Chat, request: Request):
|
|||
for key, value in data.search_config.items():
|
||||
if value:
|
||||
os.environ[key] = value
|
||||
chat_logger.debug(f"Set search config: {key}", extra={"project_id": data.project_id})
|
||||
chat_logger.debug(
|
||||
f"Set search config: {key}",
|
||||
extra={"project_id": data.project_id}
|
||||
)
|
||||
|
||||
email_sanitized = re.sub(r'[\\/*?:"<>|\s]', "_", data.email.split("@")[0]).strip(".")
|
||||
email_sanitized = re.sub(r'[\\/*?:"<>|\s]', "_",
|
||||
data.email.split("@")[0]).strip(".")
|
||||
camel_log = (
|
||||
Path.home()
|
||||
/ ".eigent"
|
||||
/ email_sanitized
|
||||
/ ("project_" + data.project_id)
|
||||
/ ("task_" + data.task_id)
|
||||
/ "camel_logs"
|
||||
Path.home() / ".eigent" / email_sanitized /
|
||||
("project_" + data.project_id) / ("task_" + data.task_id) /
|
||||
"camel_logs"
|
||||
)
|
||||
camel_log.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
|
@ -175,20 +221,35 @@ async def post(data: Chat, request: Request):
|
|||
set_current_task_id(data.project_id, data.task_id)
|
||||
|
||||
# Put initial action in queue to start processing
|
||||
await task_lock.put_queue(ActionImproveData(data=data.question, new_task_id=data.task_id))
|
||||
await task_lock.put_queue(
|
||||
ActionImproveData(data=data.question, new_task_id=data.task_id)
|
||||
)
|
||||
|
||||
chat_logger.info(
|
||||
"Chat session initialized",
|
||||
extra={"project_id": data.project_id, "task_id": data.task_id, "log_dir": str(camel_log)},
|
||||
extra={
|
||||
"project_id": data.project_id,
|
||||
"task_id": data.task_id,
|
||||
"log_dir": str(camel_log)
|
||||
},
|
||||
)
|
||||
return StreamingResponse(
|
||||
timeout_stream_wrapper(step_solve(data, request, task_lock), task_lock=task_lock), media_type="text/event-stream"
|
||||
timeout_stream_wrapper(
|
||||
step_solve(data, request, task_lock), task_lock=task_lock
|
||||
),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/chat/{id}", name="improve chat")
|
||||
def improve(id: str, data: SupplementChat):
|
||||
chat_logger.info("Chat improvement requested", extra={"task_id": id, "question_length": len(data.question)})
|
||||
chat_logger.info(
|
||||
"Chat improvement requested",
|
||||
extra={
|
||||
"task_id": id,
|
||||
"question_length": len(data.question)
|
||||
}
|
||||
)
|
||||
task_lock = get_task_lock(id)
|
||||
|
||||
# Allow continuing conversation even after task is done
|
||||
|
|
@ -203,11 +264,21 @@ def improve(id: str, data: SupplementChat):
|
|||
|
||||
# Log context preservation
|
||||
if hasattr(task_lock, "conversation_history"):
|
||||
chat_logger.info(f"[CONTEXT] Preserved {len(task_lock.conversation_history)} conversation entries")
|
||||
hist_len = len(task_lock.conversation_history)
|
||||
chat_logger.info(
|
||||
"[CONTEXT] Preserved"
|
||||
f" {hist_len} conversation entries"
|
||||
)
|
||||
if hasattr(task_lock, "last_task_result"):
|
||||
chat_logger.info(f"[CONTEXT] Preserved task result: {len(task_lock.last_task_result)} chars")
|
||||
result_len = len(task_lock.last_task_result)
|
||||
chat_logger.info(
|
||||
"[CONTEXT] Preserved task"
|
||||
f" result: {result_len} chars"
|
||||
)
|
||||
|
||||
# If task_id is provided, optimistically update file_save_path (will be destroyed if task is not complex)
|
||||
# If task_id is provided, optimistically update
|
||||
# file_save_path (will be destroyed if task is
|
||||
# not complex)
|
||||
# this is because a NEW workforce instance may be created for this task
|
||||
new_folder_path = None
|
||||
if data.task_id:
|
||||
|
|
@ -224,24 +295,49 @@ def improve(id: str, data: SupplementChat):
|
|||
if eigent_index + 1 < len(path_parts):
|
||||
current_email = path_parts[eigent_index + 1]
|
||||
|
||||
# If we have the necessary information, update the file_save_path
|
||||
# If we have the necessary info, update
|
||||
# the file_save_path
|
||||
if current_email and id:
|
||||
# Create new path using the existing pattern: email/project_{project_id}/task_{task_id}
|
||||
new_folder_path = Path.home() / "eigent" / current_email / f"project_{id}" / f"task_{data.task_id}"
|
||||
# Create new path using the existing
|
||||
# pattern: email/project_{id}/task_{id}
|
||||
new_folder_path = (
|
||||
Path.home() / "eigent" / current_email / f"project_{id}" /
|
||||
f"task_{data.task_id}"
|
||||
)
|
||||
new_folder_path.mkdir(parents=True, exist_ok=True)
|
||||
os.environ["file_save_path"] = str(new_folder_path)
|
||||
chat_logger.info(f"Updated file_save_path to: {new_folder_path}")
|
||||
chat_logger.info(
|
||||
f"Updated file_save_path to: {new_folder_path}"
|
||||
)
|
||||
|
||||
# Store the new folder path in task_lock for potential cleanup and persistence
|
||||
# Store the new folder path in task_lock
|
||||
# for potential cleanup and persistence
|
||||
task_lock.new_folder_path = new_folder_path
|
||||
else:
|
||||
chat_logger.warning(f"Could not update file_save_path - email: {current_email}, project_id: {id}")
|
||||
chat_logger.warning(
|
||||
"Could not update"
|
||||
" file_save_path -"
|
||||
f" email: {current_email},"
|
||||
f" project_id: {id}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
chat_logger.error(f"Error updating file path for project_id: {id}, task_id: {data.task_id}: {e}")
|
||||
chat_logger.error(
|
||||
"Error updating file path for"
|
||||
f" project_id: {id},"
|
||||
f" task_id: {data.task_id}:"
|
||||
f" {e}"
|
||||
)
|
||||
|
||||
asyncio.run(task_lock.put_queue(ActionImproveData(data=data.question, new_task_id=data.task_id)))
|
||||
chat_logger.info("Improvement request queued with preserved context", extra={"project_id": id})
|
||||
asyncio.run(
|
||||
task_lock.put_queue(
|
||||
ActionImproveData(data=data.question, new_task_id=data.task_id)
|
||||
)
|
||||
)
|
||||
chat_logger.info(
|
||||
"Improvement request queued with preserved context",
|
||||
extra={"project_id": id}
|
||||
)
|
||||
return Response(status_code=201)
|
||||
|
||||
|
||||
|
|
@ -260,24 +356,50 @@ def supplement(id: str, data: SupplementChat):
|
|||
def stop(id: str):
|
||||
"""stop the task"""
|
||||
chat_logger.info("=" * 80)
|
||||
chat_logger.info("🛑 [STOP-BUTTON] DELETE /chat/{id} request received from frontend")
|
||||
chat_logger.info(
|
||||
"🛑 [STOP-BUTTON] DELETE /chat/{id} request received from frontend"
|
||||
)
|
||||
chat_logger.info(f"[STOP-BUTTON] project_id/task_id: {id}")
|
||||
chat_logger.info("=" * 80)
|
||||
try:
|
||||
task_lock = get_task_lock(id)
|
||||
chat_logger.info(f"[STOP-BUTTON] Task lock retrieved, task_lock.id: {task_lock.id}, task_lock.status: {task_lock.status}")
|
||||
chat_logger.info(f"[STOP-BUTTON] Queueing ActionStopData(Action.stop) to task_lock queue")
|
||||
chat_logger.info(
|
||||
"[STOP-BUTTON] Task lock retrieved,"
|
||||
f" task_lock.id: {task_lock.id},"
|
||||
f" task_lock.status: {task_lock.status}"
|
||||
)
|
||||
chat_logger.info(
|
||||
"[STOP-BUTTON] Queueing"
|
||||
" ActionStopData(Action.stop)"
|
||||
" to task_lock queue"
|
||||
)
|
||||
asyncio.run(task_lock.put_queue(ActionStopData(action=Action.stop)))
|
||||
chat_logger.info(f"[STOP-BUTTON] ✅ ActionStopData queued successfully, this will trigger workforce.stop_gracefully()")
|
||||
chat_logger.info(
|
||||
"[STOP-BUTTON] ActionStopData queued"
|
||||
" successfully, this will trigger"
|
||||
" workforce.stop_gracefully()"
|
||||
)
|
||||
except Exception as e:
|
||||
# Task lock may not exist if task is already finished or never started
|
||||
chat_logger.warning(f"[STOP-BUTTON] ⚠️ Task lock not found or already stopped, task_id: {id}, error: {str(e)}")
|
||||
# Task lock may not exist if task is already
|
||||
# finished or never started
|
||||
chat_logger.warning(
|
||||
"[STOP-BUTTON] Task lock not found"
|
||||
" or already stopped,"
|
||||
f" task_id: {id},"
|
||||
f" error: {str(e)}"
|
||||
)
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.post("/chat/{id}/human-reply")
|
||||
def human_reply(id: str, data: HumanReply):
|
||||
chat_logger.info("Human reply received", extra={"task_id": id, "reply_length": len(data.reply)})
|
||||
chat_logger.info(
|
||||
"Human reply received",
|
||||
extra={
|
||||
"task_id": id,
|
||||
"reply_length": len(data.reply)
|
||||
}
|
||||
)
|
||||
task_lock = get_task_lock(id)
|
||||
asyncio.run(task_lock.put_human_input(data.agent, data.reply))
|
||||
chat_logger.debug("Human reply processed", extra={"task_id": id})
|
||||
|
|
@ -286,9 +408,19 @@ def human_reply(id: str, data: HumanReply):
|
|||
|
||||
@router.post("/chat/{id}/install-mcp")
|
||||
def install_mcp(id: str, data: McpServers):
|
||||
chat_logger.info("Installing MCP servers", extra={"task_id": id, "servers_count": len(data.get("mcpServers", {}))})
|
||||
chat_logger.info(
|
||||
"Installing MCP servers",
|
||||
extra={
|
||||
"task_id": id,
|
||||
"servers_count": len(data.get("mcpServers", {}))
|
||||
}
|
||||
)
|
||||
task_lock = get_task_lock(id)
|
||||
asyncio.run(task_lock.put_queue(ActionInstallMcpData(action=Action.install_mcp, data=data)))
|
||||
asyncio.run(
|
||||
task_lock.put_queue(
|
||||
ActionInstallMcpData(action=Action.install_mcp, data=data)
|
||||
)
|
||||
)
|
||||
chat_logger.info("MCP installation queued", extra={"task_id": id})
|
||||
return Response(status_code=201)
|
||||
|
||||
|
|
@ -296,7 +428,11 @@ def install_mcp(id: str, data: McpServers):
|
|||
@router.post("/chat/{id}/add-task", name="add task to workforce")
|
||||
def add_task(id: str, data: AddTaskRequest):
|
||||
"""Add a new task to the workforce"""
|
||||
chat_logger.info(f"Adding task to workforce for task_id: {id}, content: {data.content[:100]}...")
|
||||
chat_logger.info(
|
||||
"Adding task to workforce for"
|
||||
f" task_id: {id},"
|
||||
f" content: {data.content[:100]}..."
|
||||
)
|
||||
task_lock = get_task_lock(id)
|
||||
|
||||
try:
|
||||
|
|
@ -316,22 +452,35 @@ def add_task(id: str, data: AddTaskRequest):
|
|||
raise UserException(code.error, f"Failed to add task: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/chat/{project_id}/remove-task/{task_id}", name="remove task from workforce")
|
||||
@router.delete(
|
||||
"/chat/{project_id}/remove-task/{task_id}",
|
||||
name="remove task from workforce"
|
||||
)
|
||||
def remove_task(project_id: str, task_id: str):
|
||||
"""Remove a task from the workforce"""
|
||||
chat_logger.info(f"Removing task {task_id} from workforce for project_id: {project_id}")
|
||||
chat_logger.info(
|
||||
f"Removing task {task_id} from workforce for project_id: {project_id}"
|
||||
)
|
||||
task_lock = get_task_lock(project_id)
|
||||
|
||||
try:
|
||||
# Queue the remove task action
|
||||
remove_task_action = ActionRemoveTaskData(task_id=task_id, project_id=project_id)
|
||||
remove_task_action = ActionRemoveTaskData(
|
||||
task_id=task_id, project_id=project_id
|
||||
)
|
||||
asyncio.run(task_lock.put_queue(remove_task_action))
|
||||
|
||||
chat_logger.info(f"Task removal request queued for project_id: {project_id}, removing task: {task_id}")
|
||||
chat_logger.info(
|
||||
"Task removal request queued for"
|
||||
f" project_id: {project_id},"
|
||||
f" removing task: {task_id}"
|
||||
)
|
||||
return Response(status_code=204)
|
||||
|
||||
except Exception as e:
|
||||
chat_logger.error(f"Error removing task {task_id} for project_id: {project_id}: {e}")
|
||||
chat_logger.error(
|
||||
f"Error removing task {task_id} for project_id: {project_id}: {e}"
|
||||
)
|
||||
raise UserException(code.error, f"Failed to remove task: {str(e)}")
|
||||
|
||||
|
||||
|
|
@ -349,21 +498,45 @@ def skip_task(project_id: str):
|
|||
- Keeps SSE connection alive for multi-turn conversation
|
||||
"""
|
||||
chat_logger.info("=" * 80)
|
||||
chat_logger.info(f"🛑 [STOP-BUTTON] SKIP-TASK request received from frontend (User clicked Stop)")
|
||||
chat_logger.info(
|
||||
"[STOP-BUTTON] SKIP-TASK request"
|
||||
" received from frontend"
|
||||
" (User clicked Stop)"
|
||||
)
|
||||
chat_logger.info(f"[STOP-BUTTON] project_id: {project_id}")
|
||||
chat_logger.info("=" * 80)
|
||||
task_lock = get_task_lock(project_id)
|
||||
chat_logger.info(f"[STOP-BUTTON] Task lock retrieved, task_lock.id: {task_lock.id}, task_lock.status: {task_lock.status}")
|
||||
chat_logger.info(
|
||||
"[STOP-BUTTON] Task lock retrieved,"
|
||||
f" task_lock.id: {task_lock.id},"
|
||||
" task_lock.status:"
|
||||
f" {task_lock.status}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Queue the skip task action - this will preserve context for multi-turn
|
||||
# Queue the skip task action - this will
|
||||
# preserve context for multi-turn
|
||||
skip_task_action = ActionSkipTaskData(project_id=project_id)
|
||||
chat_logger.info(f"[STOP-BUTTON] Queueing ActionSkipTaskData (preserves context, marks as done)")
|
||||
chat_logger.info(
|
||||
"[STOP-BUTTON] Queueing"
|
||||
" ActionSkipTaskData"
|
||||
" (preserves context,"
|
||||
" marks as done)"
|
||||
)
|
||||
asyncio.run(task_lock.put_queue(skip_task_action))
|
||||
|
||||
chat_logger.info(f"[STOP-BUTTON] ✅ Skip request queued - task will stop gracefully and preserve context")
|
||||
chat_logger.info(
|
||||
"[STOP-BUTTON] Skip request"
|
||||
" queued - task will stop"
|
||||
" gracefully and preserve context"
|
||||
)
|
||||
return Response(status_code=201)
|
||||
|
||||
except Exception as e:
|
||||
chat_logger.error(f"[STOP-BUTTON] Error skipping task for project_id: {project_id}: {e}")
|
||||
chat_logger.error(
|
||||
"[STOP-BUTTON] Error skipping"
|
||||
" task for"
|
||||
f" project_id: {project_id}:"
|
||||
f" {e}"
|
||||
)
|
||||
raise UserException(code.error, f"Failed to skip task: {str(e)}")
|
||||
|
|
|
|||
|
|
@ -12,9 +12,10 @@
|
|||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("health_controller")
|
||||
|
||||
|
|
@ -28,9 +29,15 @@ class HealthResponse(BaseModel):
|
|||
|
||||
@router.get("/health", name="health check", response_model=HealthResponse)
|
||||
async def health_check():
|
||||
"""Health check endpoint for verifying backend is ready to accept requests."""
|
||||
"""Health check endpoint for verifying backend
|
||||
is ready to accept requests."""
|
||||
logger.debug("Health check requested")
|
||||
response = HealthResponse(status="ok", service="eigent")
|
||||
logger.debug("Health check completed", extra={"status": response.status, "service": response.service})
|
||||
logger.debug(
|
||||
"Health check completed",
|
||||
extra={
|
||||
"status": response.status,
|
||||
"service": response.service
|
||||
}
|
||||
)
|
||||
return response
|
||||
|
||||
|
|
|
|||
|
|
@ -12,16 +12,16 @@
|
|||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from app.component.model_validation import create_agent
|
||||
from app.model.chat import PLATFORM_MAPPING
|
||||
from camel.types import ModelType
|
||||
from app.component.error_format import normalize_error_to_openai_format
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("model_controller")
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.component.error_format import normalize_error_to_openai_format
|
||||
from app.component.model_validation import create_agent
|
||||
from app.model.chat import PLATFORM_MAPPING
|
||||
|
||||
logger = logging.getLogger("model_controller")
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -31,8 +31,12 @@ class ValidateModelRequest(BaseModel):
|
|||
model_type: str = Field("GPT_4O_MINI", description="Model type")
|
||||
api_key: str | None = Field(None, description="API key")
|
||||
url: str | None = Field(None, description="Model URL")
|
||||
model_config_dict: dict | None = Field(None, description="Model config dict")
|
||||
extra_params: dict | None = Field(None, description="Extra model parameters")
|
||||
model_config_dict: dict | None = Field(
|
||||
None, description="Model config dict"
|
||||
)
|
||||
extra_params: dict | None = Field(
|
||||
None, description="Extra model parameters"
|
||||
)
|
||||
|
||||
@field_validator("model_platform")
|
||||
@classmethod
|
||||
|
|
@ -56,11 +60,25 @@ async def validate_model(request: ValidateModelRequest):
|
|||
has_custom_url = request.url is not None
|
||||
has_config = request.model_config_dict is not None
|
||||
|
||||
logger.info("Model validation started", extra={"platform": platform, "model_type": model_type, "has_url": has_custom_url, "has_config": has_config})
|
||||
logger.info(
|
||||
"Model validation started",
|
||||
extra={
|
||||
"platform": platform,
|
||||
"model_type": model_type,
|
||||
"has_url": has_custom_url,
|
||||
"has_config": has_config
|
||||
}
|
||||
)
|
||||
|
||||
# API key validation
|
||||
if request.api_key is not None and str(request.api_key).strip() == "":
|
||||
logger.warning("Model validation failed: empty API key", extra={"platform": platform, "model_type": model_type})
|
||||
logger.warning(
|
||||
"Model validation failed: empty API key",
|
||||
extra={
|
||||
"platform": platform,
|
||||
"model_type": model_type
|
||||
}
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
|
|
@ -77,7 +95,13 @@ async def validate_model(request: ValidateModelRequest):
|
|||
try:
|
||||
extra = request.extra_params or {}
|
||||
|
||||
logger.debug("Creating agent for validation", extra={"platform": platform, "model_type": model_type})
|
||||
logger.debug(
|
||||
"Creating agent for validation",
|
||||
extra={
|
||||
"platform": platform,
|
||||
"model_type": model_type
|
||||
}
|
||||
)
|
||||
agent = create_agent(
|
||||
platform,
|
||||
model_type,
|
||||
|
|
@ -87,7 +111,13 @@ async def validate_model(request: ValidateModelRequest):
|
|||
**extra,
|
||||
)
|
||||
|
||||
logger.debug("Agent created, executing test step", extra={"platform": platform, "model_type": model_type})
|
||||
logger.debug(
|
||||
"Agent created, executing test step",
|
||||
extra={
|
||||
"platform": platform,
|
||||
"model_type": model_type
|
||||
}
|
||||
)
|
||||
response = agent.step(
|
||||
input_message="""
|
||||
Get the content of https://www.camel-ai.org,
|
||||
|
|
@ -97,10 +127,17 @@ async def validate_model(request: ValidateModelRequest):
|
|||
"""
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# Normalize error to OpenAI-style error structure
|
||||
logger.error("Model validation failed", extra={"platform": platform, "model_type": model_type, "error": str(e)}, exc_info=True)
|
||||
logger.error(
|
||||
"Model validation failed",
|
||||
extra={
|
||||
"platform": platform,
|
||||
"model_type": model_type,
|
||||
"error": str(e)
|
||||
},
|
||||
exc_info=True
|
||||
)
|
||||
message, error_code, error_obj = normalize_error_to_openai_format(e)
|
||||
|
||||
raise HTTPException(
|
||||
|
|
@ -111,7 +148,7 @@ async def validate_model(request: ValidateModelRequest):
|
|||
"error": error_obj,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# Check validation results
|
||||
is_valid = bool(response)
|
||||
is_tool_calls = False
|
||||
|
|
@ -119,21 +156,35 @@ async def validate_model(request: ValidateModelRequest):
|
|||
if response and hasattr(response, "info") and response.info:
|
||||
tool_calls = response.info.get("tool_calls", [])
|
||||
if tool_calls and len(tool_calls) > 0:
|
||||
is_tool_calls = (
|
||||
tool_calls[0].result
|
||||
== "Tool execution completed successfully for https://www.camel-ai.org, Website Content: Welcome to CAMEL AI!"
|
||||
expected = (
|
||||
"Tool execution completed"
|
||||
" successfully for"
|
||||
" https://www.camel-ai.org,"
|
||||
" Website Content:"
|
||||
" Welcome to CAMEL AI!"
|
||||
)
|
||||
is_tool_calls = (tool_calls[0].result == expected)
|
||||
|
||||
no_tool_msg = (
|
||||
"This model doesn't support tool calls."
|
||||
" please try with another model."
|
||||
)
|
||||
result = ValidateModelResponse(
|
||||
is_valid=is_valid,
|
||||
is_tool_calls=is_tool_calls,
|
||||
message="Validation Success"
|
||||
if is_tool_calls
|
||||
else "This model doesn't support tool calls. please try with another model.",
|
||||
message="Validation Success" if is_tool_calls else no_tool_msg,
|
||||
error_code=None,
|
||||
error=None,
|
||||
)
|
||||
|
||||
logger.info("Model validation completed", extra={"platform": platform, "model_type": model_type, "is_valid": is_valid, "is_tool_calls": is_tool_calls})
|
||||
logger.info(
|
||||
"Model validation completed",
|
||||
extra={
|
||||
"platform": platform,
|
||||
"model_type": model_type,
|
||||
"is_valid": is_valid,
|
||||
"is_tool_calls": is_tool_calls
|
||||
}
|
||||
)
|
||||
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -12,28 +12,29 @@
|
|||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Response
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.component.environment import sanitize_env_path, set_user_env_path
|
||||
from app.model.chat import NewAgent, UpdateData
|
||||
from app.service.task import (
|
||||
Action,
|
||||
ActionNewAgent,
|
||||
ActionStartData,
|
||||
ActionStopData,
|
||||
ActionTakeControl,
|
||||
ActionStartData,
|
||||
ActionUpdateTaskData,
|
||||
get_task_lock,
|
||||
task_locks,
|
||||
)
|
||||
import asyncio
|
||||
from app.component.environment import set_user_env_path, sanitize_env_path
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("task_controller")
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
|
|
@ -48,10 +49,26 @@ def start(id: str):
|
|||
|
||||
@router.put("/task/{id}", name="update task")
|
||||
def put(id: str, data: UpdateData):
|
||||
logger.info("Updating task", extra={"task_id": id, "task_items_count": len(data.task)})
|
||||
logger.debug("Update task data", extra={"task_id": id, "data": data.model_dump_json()})
|
||||
logger.info(
|
||||
"Updating task",
|
||||
extra={
|
||||
"task_id": id,
|
||||
"task_items_count": len(data.task)
|
||||
}
|
||||
)
|
||||
logger.debug(
|
||||
"Update task data",
|
||||
extra={
|
||||
"task_id": id,
|
||||
"data": data.model_dump_json()
|
||||
}
|
||||
)
|
||||
task_lock = get_task_lock(id)
|
||||
asyncio.run(task_lock.put_queue(ActionUpdateTaskData(action=Action.update_task, data=data)))
|
||||
asyncio.run(
|
||||
task_lock.put_queue(
|
||||
ActionUpdateTaskData(action=Action.update_task, data=data)
|
||||
)
|
||||
)
|
||||
logger.info("Task updated successfully", extra={"task_id": id})
|
||||
return Response(status_code=201)
|
||||
|
||||
|
|
@ -62,25 +79,55 @@ class TakeControl(BaseModel):
|
|||
|
||||
@router.put("/task/{id}/take-control", name="take control pause or resume")
|
||||
def take_control(id: str, data: TakeControl):
|
||||
logger.info("Task control action", extra={"task_id": id, "action": data.action})
|
||||
logger.info(
|
||||
"Task control action", extra={
|
||||
"task_id": id,
|
||||
"action": data.action
|
||||
}
|
||||
)
|
||||
task_lock = get_task_lock(id)
|
||||
asyncio.run(task_lock.put_queue(ActionTakeControl(action=data.action)))
|
||||
logger.info("Task control action completed", extra={"task_id": id, "action": data.action})
|
||||
logger.info(
|
||||
"Task control action completed",
|
||||
extra={
|
||||
"task_id": id,
|
||||
"action": data.action
|
||||
}
|
||||
)
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.post("/task/{id}/add-agent", name="add new agent")
|
||||
def add_agent(id: str, data: NewAgent):
|
||||
logger.info("Adding new agent to task", extra={"task_id": id, "agent_name": data.name})
|
||||
logger.debug("New agent data", extra={"task_id": id, "agent_data": data.model_dump_json()})
|
||||
logger.info(
|
||||
"Adding new agent to task",
|
||||
extra={
|
||||
"task_id": id,
|
||||
"agent_name": data.name
|
||||
}
|
||||
)
|
||||
logger.debug(
|
||||
"New agent data",
|
||||
extra={
|
||||
"task_id": id,
|
||||
"agent_data": data.model_dump_json()
|
||||
}
|
||||
)
|
||||
# Set user-specific environment path for this thread
|
||||
set_user_env_path(data.env_path)
|
||||
# Load environment with validated path
|
||||
safe_env_path = sanitize_env_path(data.env_path)
|
||||
if safe_env_path:
|
||||
load_dotenv(dotenv_path=safe_env_path)
|
||||
asyncio.run(get_task_lock(id).put_queue(ActionNewAgent(**data.model_dump())))
|
||||
logger.info("Agent added to task", extra={"task_id": id, "agent_name": data.name})
|
||||
asyncio.run(
|
||||
get_task_lock(id).put_queue(ActionNewAgent(**data.model_dump()))
|
||||
)
|
||||
logger.info(
|
||||
"Agent added to task", extra={
|
||||
"task_id": id,
|
||||
"agent_name": data.name
|
||||
}
|
||||
)
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,25 +12,19 @@
|
|||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from app.utils.toolkit.notion_mcp_toolkit import NotionMCPToolkit
|
||||
from app.utils.toolkit.google_calendar_toolkit import GoogleCalendarToolkit
|
||||
from app.utils.toolkit.linkedin_toolkit import LinkedInToolkit
|
||||
from app.utils.oauth_state_manager import oauth_state_manager
|
||||
import logging
|
||||
|
||||
|
||||
from camel.toolkits.hybrid_browser_toolkit.hybrid_browser_toolkit_ts import (
|
||||
HybridBrowserToolkit as BaseHybridBrowserToolkit,
|
||||
)
|
||||
from app.utils.cookie_manager import CookieManager
|
||||
|
||||
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils.cookie_manager import CookieManager
|
||||
from app.utils.oauth_state_manager import oauth_state_manager
|
||||
from app.utils.toolkit.google_calendar_toolkit import GoogleCalendarToolkit
|
||||
from app.utils.toolkit.linkedin_toolkit import LinkedInToolkit
|
||||
from app.utils.toolkit.notion_mcp_toolkit import NotionMCPToolkit
|
||||
|
||||
|
||||
class LinkedInTokenRequest(BaseModel):
|
||||
|
|
@ -58,7 +52,8 @@ async def install_tool(tool: str):
|
|||
"""
|
||||
if tool == "notion":
|
||||
try:
|
||||
# Use a dummy task_id for installation, as this is just for pre-authentication
|
||||
# Use a dummy task_id for installation,
|
||||
# as this is just for pre-authentication
|
||||
toolkit = NotionMCPToolkit("install_auth")
|
||||
|
||||
try:
|
||||
|
|
@ -66,10 +61,15 @@ async def install_tool(tool: str):
|
|||
await toolkit.connect()
|
||||
|
||||
# Get available tools to verify connection
|
||||
tools = [tool_func.func.__name__ for tool_func in
|
||||
toolkit.get_tools()]
|
||||
tools = [
|
||||
tool_func.func.__name__
|
||||
for tool_func in toolkit.get_tools()
|
||||
]
|
||||
logger.info(
|
||||
f"Successfully pre-instantiated {tool} toolkit with {len(tools)} tools")
|
||||
"Successfully pre-instantiated"
|
||||
f" {tool} toolkit with"
|
||||
f" {len(tools)} tools"
|
||||
)
|
||||
|
||||
# Disconnect, authentication info is saved
|
||||
await toolkit.disconnect()
|
||||
|
|
@ -77,35 +77,54 @@ async def install_tool(tool: str):
|
|||
return {
|
||||
"success": True,
|
||||
"tools": tools,
|
||||
"message": f"Successfully installed and authenticated {tool} toolkit",
|
||||
"message":
|
||||
f"Successfully installed and authenticated {tool} toolkit",
|
||||
"count": len(tools),
|
||||
"toolkit_name": "NotionMCPToolkit"
|
||||
}
|
||||
except Exception as connect_error:
|
||||
logger.warning(
|
||||
f"Could not connect to {tool} MCP server: {connect_error}")
|
||||
# Even if connection fails, mark as installed so user can use it later
|
||||
f"Could not connect to {tool} MCP server: {connect_error}"
|
||||
)
|
||||
# Even if connection fails, mark as
|
||||
# installed so user can use it later
|
||||
return {
|
||||
"success": True,
|
||||
"success":
|
||||
True,
|
||||
"tools": [],
|
||||
"message": f"{tool} toolkit installed but not connected. Will connect when needed.",
|
||||
"count": 0,
|
||||
"toolkit_name": "NotionMCPToolkit",
|
||||
"warning": "Could not connect to Notion MCP server. You may need to authenticate when using the tool."
|
||||
"message":
|
||||
f"{tool} toolkit installed but"
|
||||
" not connected. Will connect"
|
||||
" when needed.",
|
||||
"count":
|
||||
0,
|
||||
"toolkit_name":
|
||||
"NotionMCPToolkit",
|
||||
"warning":
|
||||
"Could not connect to Notion"
|
||||
" MCP server. You may need to"
|
||||
" authenticate when using"
|
||||
" the tool."
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to install {tool} toolkit: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to install {tool}: {str(e)}"
|
||||
status_code=500, detail=f"Failed to install {tool}: {str(e)}"
|
||||
)
|
||||
elif tool == "google_calendar":
|
||||
try:
|
||||
# Try to initialize toolkit - will succeed if credentials exist
|
||||
try:
|
||||
toolkit = GoogleCalendarToolkit("install_auth")
|
||||
tools = [tool_func.func.__name__ for tool_func in toolkit.get_tools()]
|
||||
logger.info(f"Successfully initialized Google Calendar toolkit with {len(tools)} tools")
|
||||
tools = [
|
||||
tool_func.func.__name__
|
||||
for tool_func in toolkit.get_tools()
|
||||
]
|
||||
logger.info(
|
||||
"Successfully initialized Google"
|
||||
" Calendar toolkit with"
|
||||
f" {len(tools)} tools"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
|
|
@ -116,24 +135,36 @@ async def install_tool(tool: str):
|
|||
}
|
||||
except ValueError as auth_error:
|
||||
# No credentials - need authorization
|
||||
logger.info(f"No credentials found, starting authorization: {auth_error}")
|
||||
logger.info(
|
||||
"No credentials found, starting"
|
||||
f" authorization: {auth_error}"
|
||||
)
|
||||
|
||||
# Start background authorization in a new thread
|
||||
logger.info("Starting background Google Calendar authorization")
|
||||
logger.info(
|
||||
"Starting background Google Calendar authorization"
|
||||
)
|
||||
GoogleCalendarToolkit.start_background_auth("install_auth")
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"status": "authorizing",
|
||||
"message": "Authorization required. Browser should open automatically. Complete authorization and try installing again.",
|
||||
"toolkit_name": "GoogleCalendarToolkit",
|
||||
"requires_auth": True
|
||||
"success":
|
||||
False,
|
||||
"status":
|
||||
"authorizing",
|
||||
"message":
|
||||
"Authorization required. Browser"
|
||||
" should open automatically."
|
||||
" Complete authorization and"
|
||||
" try installing again.",
|
||||
"toolkit_name":
|
||||
"GoogleCalendarToolkit",
|
||||
"requires_auth":
|
||||
True
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to install {tool} toolkit: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to install {tool}: {str(e)}"
|
||||
status_code=500, detail=f"Failed to install {tool}: {str(e)}"
|
||||
)
|
||||
elif tool == "linkedin":
|
||||
try:
|
||||
|
|
@ -143,17 +174,28 @@ async def install_tool(tool: str):
|
|||
if LinkedInToolkit.is_token_expired():
|
||||
logger.info("LinkedIn token has expired")
|
||||
return {
|
||||
"success": False,
|
||||
"status": "token_expired",
|
||||
"message": "LinkedIn token has expired. Please re-authenticate via OAuth.",
|
||||
"toolkit_name": "LinkedInToolkit",
|
||||
"requires_auth": True,
|
||||
"oauth_url": "/api/oauth/linkedin/login"
|
||||
"success":
|
||||
False,
|
||||
"status":
|
||||
"token_expired",
|
||||
"message":
|
||||
"LinkedIn token has expired."
|
||||
" Please re-authenticate"
|
||||
" via OAuth.",
|
||||
"toolkit_name":
|
||||
"LinkedInToolkit",
|
||||
"requires_auth":
|
||||
True,
|
||||
"oauth_url":
|
||||
"/api/oauth/linkedin/login"
|
||||
}
|
||||
|
||||
try:
|
||||
toolkit = LinkedInToolkit("install_auth")
|
||||
tools = [tool_func.func.__name__ for tool_func in toolkit.get_tools()]
|
||||
tools = [
|
||||
tool_func.func.__name__
|
||||
for tool_func in toolkit.get_tools()
|
||||
]
|
||||
|
||||
# Try to get profile to verify token is valid
|
||||
profile = toolkit.get_profile_safe()
|
||||
|
|
@ -163,10 +205,22 @@ async def install_tool(tool: str):
|
|||
if LinkedInToolkit.is_token_expiring_soon():
|
||||
token_info = LinkedInToolkit.get_token_info()
|
||||
if token_info and token_info.get("expires_at"):
|
||||
days_remaining = (token_info["expires_at"] - int(time.time())) // (24 * 60 * 60)
|
||||
token_warning = f"Token expires in {days_remaining} days. Consider re-authenticating soon."
|
||||
days_remaining = (
|
||||
token_info["expires_at"] - int(time.time())
|
||||
) // (24 * 60 * 60)
|
||||
token_warning = (
|
||||
"Token expires in"
|
||||
f" {days_remaining}"
|
||||
" days. Consider"
|
||||
" re-authenticating"
|
||||
" soon."
|
||||
)
|
||||
|
||||
logger.info(f"Successfully initialized LinkedIn toolkit with {len(tools)} tools")
|
||||
logger.info(
|
||||
"Successfully initialized"
|
||||
" LinkedIn toolkit with"
|
||||
f" {len(tools)} tools"
|
||||
)
|
||||
result = {
|
||||
"success": True,
|
||||
"tools": tools,
|
||||
|
|
@ -182,12 +236,20 @@ async def install_tool(tool: str):
|
|||
logger.warning(f"LinkedIn token may be invalid: {e}")
|
||||
# Token exists but may be expired/invalid
|
||||
return {
|
||||
"success": False,
|
||||
"status": "token_invalid",
|
||||
"message": "LinkedIn token may be expired or invalid. Please re-authenticate via OAuth.",
|
||||
"toolkit_name": "LinkedInToolkit",
|
||||
"requires_auth": True,
|
||||
"oauth_url": "/api/oauth/linkedin/login"
|
||||
"success":
|
||||
False,
|
||||
"status":
|
||||
"token_invalid",
|
||||
"message":
|
||||
"LinkedIn token may be expired"
|
||||
" or invalid. Please"
|
||||
" re-authenticate via OAuth.",
|
||||
"toolkit_name":
|
||||
"LinkedInToolkit",
|
||||
"requires_auth":
|
||||
True,
|
||||
"oauth_url":
|
||||
"/api/oauth/linkedin/login"
|
||||
}
|
||||
else:
|
||||
# No credentials - need OAuth authorization
|
||||
|
|
@ -195,7 +257,8 @@ async def install_tool(tool: str):
|
|||
return {
|
||||
"success": False,
|
||||
"status": "not_configured",
|
||||
"message": "LinkedIn OAuth required. Redirect user to OAuth login.",
|
||||
"message":
|
||||
"LinkedIn OAuth required. Redirect user to OAuth login.",
|
||||
"toolkit_name": "LinkedInToolkit",
|
||||
"requires_auth": True,
|
||||
"oauth_url": "/api/oauth/linkedin/login"
|
||||
|
|
@ -203,13 +266,18 @@ async def install_tool(tool: str):
|
|||
except Exception as e:
|
||||
logger.error(f"Failed to install {tool} toolkit: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to install {tool}: {str(e)}"
|
||||
status_code=500, detail=f"Failed to install {tool}: {str(e)}"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Tool '{tool}' not found. Available tools: ['notion', 'google_calendar', 'linkedin']"
|
||||
detail=(
|
||||
f"Tool '{tool}' not found."
|
||||
" Available tools:"
|
||||
" ['notion',"
|
||||
" 'google_calendar',"
|
||||
" 'linkedin']"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -224,26 +292,47 @@ async def list_available_tools():
|
|||
return {
|
||||
"tools": [
|
||||
{
|
||||
"name": "notion",
|
||||
"display_name": "Notion MCP",
|
||||
"description": "Notion workspace integration for reading and managing Notion pages",
|
||||
"toolkit_class": "NotionMCPToolkit",
|
||||
"requires_auth": True
|
||||
},
|
||||
{
|
||||
"name": "google_calendar",
|
||||
"display_name": "Google Calendar",
|
||||
"description": "Google Calendar integration for managing events and schedules",
|
||||
"toolkit_class": "GoogleCalendarToolkit",
|
||||
"requires_auth": True
|
||||
},
|
||||
{
|
||||
"name": "linkedin",
|
||||
"display_name": "LinkedIn",
|
||||
"description": "LinkedIn integration for creating posts, managing profile, and social media automation",
|
||||
"toolkit_class": "LinkedInToolkit",
|
||||
"requires_auth": True,
|
||||
"oauth_url": "/api/oauth/linkedin/login"
|
||||
"name":
|
||||
"notion",
|
||||
"display_name":
|
||||
"Notion MCP",
|
||||
"description":
|
||||
"Notion workspace integration"
|
||||
" for reading and managing"
|
||||
" Notion pages",
|
||||
"toolkit_class":
|
||||
"NotionMCPToolkit",
|
||||
"requires_auth":
|
||||
True
|
||||
}, {
|
||||
"name":
|
||||
"google_calendar",
|
||||
"display_name":
|
||||
"Google Calendar",
|
||||
"description":
|
||||
"Google Calendar integration"
|
||||
" for managing events"
|
||||
" and schedules",
|
||||
"toolkit_class":
|
||||
"GoogleCalendarToolkit",
|
||||
"requires_auth":
|
||||
True
|
||||
}, {
|
||||
"name":
|
||||
"linkedin",
|
||||
"display_name":
|
||||
"LinkedIn",
|
||||
"description":
|
||||
"LinkedIn integration for"
|
||||
" creating posts, managing"
|
||||
" profile, and social media"
|
||||
" automation",
|
||||
"toolkit_class":
|
||||
"LinkedInToolkit",
|
||||
"requires_auth":
|
||||
True,
|
||||
"oauth_url":
|
||||
"/api/oauth/linkedin/login"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -323,8 +412,8 @@ async def uninstall_tool(tool: str):
|
|||
|
||||
if tool == "notion":
|
||||
try:
|
||||
import hashlib
|
||||
import glob
|
||||
import hashlib
|
||||
|
||||
# Calculate the hash for Notion MCP URL
|
||||
# mcp-remote uses MD5 hash of the URL to generate file names
|
||||
|
|
@ -348,13 +437,21 @@ async def uninstall_tool(tool: str):
|
|||
try:
|
||||
os.remove(file_path)
|
||||
deleted_files.append(file_path)
|
||||
logger.info(f"Removed Notion auth file: {file_path}")
|
||||
logger.info(
|
||||
f"Removed Notion auth file: {file_path}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to remove {file_path}: {e}")
|
||||
logger.warning(
|
||||
f"Failed to remove {file_path}: {e}"
|
||||
)
|
||||
|
||||
message = f"Successfully uninstalled {tool}"
|
||||
if deleted_files:
|
||||
message += f" and cleaned up {len(deleted_files)} authentication file(s)"
|
||||
message += (
|
||||
" and cleaned up"
|
||||
f" {len(deleted_files)}"
|
||||
" authentication file(s)"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
|
|
@ -373,16 +470,31 @@ async def uninstall_tool(tool: str):
|
|||
# Clean up Google Calendar token directories (user-scoped + legacy)
|
||||
token_dirs = set()
|
||||
try:
|
||||
token_dirs.add(os.path.dirname(GoogleCalendarToolkit._build_canonical_token_path()))
|
||||
token_dirs.add(
|
||||
os.path.dirname(
|
||||
GoogleCalendarToolkit._build_canonical_token_path()
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to resolve canonical Google Calendar token path: {e}")
|
||||
logger.warning(
|
||||
"Failed to resolve canonical"
|
||||
" Google Calendar token"
|
||||
f" path: {e}"
|
||||
)
|
||||
|
||||
token_dirs.add(os.path.join(os.path.expanduser("~"), ".eigent", "tokens", "google_calendar"))
|
||||
token_dirs.add(
|
||||
os.path.join(
|
||||
os.path.expanduser("~"), ".eigent", "tokens",
|
||||
"google_calendar"
|
||||
)
|
||||
)
|
||||
|
||||
for token_dir in token_dirs:
|
||||
if os.path.exists(token_dir):
|
||||
shutil.rmtree(token_dir)
|
||||
logger.info(f"Removed Google Calendar token directory: {token_dir}")
|
||||
logger.info(
|
||||
f"Removed Google Calendar token directory: {token_dir}"
|
||||
)
|
||||
|
||||
# Clear OAuth state manager cache (this is the key fix!)
|
||||
# This removes the cached credentials from memory
|
||||
|
|
@ -390,14 +502,20 @@ async def uninstall_tool(tool: str):
|
|||
if state:
|
||||
if state.status in ["pending", "authorizing"]:
|
||||
state.cancel()
|
||||
logger.info("Cancelled ongoing Google Calendar authorization")
|
||||
logger.info(
|
||||
"Cancelled ongoing Google Calendar authorization"
|
||||
)
|
||||
# Clear the state completely to remove cached credentials
|
||||
oauth_state_manager._states.pop("google_calendar", None)
|
||||
logger.info("Cleared Google Calendar OAuth state cache")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Successfully uninstalled {tool} and cleaned up authentication tokens"
|
||||
"success":
|
||||
True,
|
||||
"message":
|
||||
"Successfully uninstalled"
|
||||
f" {tool} and cleaned up"
|
||||
" authentication tokens"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to uninstall {tool}: {e}")
|
||||
|
|
@ -412,13 +530,18 @@ async def uninstall_tool(tool: str):
|
|||
|
||||
if success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Successfully uninstalled {tool} and cleaned up authentication tokens"
|
||||
"success":
|
||||
True,
|
||||
"message":
|
||||
"Successfully uninstalled"
|
||||
f" {tool} and cleaned up"
|
||||
" authentication tokens"
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Uninstalled {tool} (no tokens found to clean up)"
|
||||
"message":
|
||||
f"Uninstalled {tool} (no tokens found to clean up)"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to uninstall {tool}: {e}")
|
||||
|
|
@ -429,7 +552,13 @@ async def uninstall_tool(tool: str):
|
|||
else:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Tool '{tool}' not found. Available tools: ['notion', 'google_calendar', 'linkedin']"
|
||||
detail=(
|
||||
f"Tool '{tool}' not found."
|
||||
" Available tools:"
|
||||
" ['notion',"
|
||||
" 'google_calendar',"
|
||||
" 'linkedin']"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -438,7 +567,8 @@ async def save_linkedin_token(token_request: LinkedInTokenRequest):
|
|||
r"""Save LinkedIn OAuth token after successful authorization.
|
||||
|
||||
Args:
|
||||
token_request: Token data containing access_token and optionally refresh_token
|
||||
token_request: Token data containing
|
||||
access_token and optionally refresh_token
|
||||
|
||||
Returns:
|
||||
Save result with tool information
|
||||
|
|
@ -453,7 +583,10 @@ async def save_linkedin_token(token_request: LinkedInTokenRequest):
|
|||
# Verify the token works by initializing toolkit
|
||||
try:
|
||||
toolkit = LinkedInToolkit("install_auth")
|
||||
tools = [tool_func.func.__name__ for tool_func in toolkit.get_tools()]
|
||||
tools = [
|
||||
tool_func.func.__name__
|
||||
for tool_func in toolkit.get_tools()
|
||||
]
|
||||
profile = toolkit.get_profile_safe()
|
||||
|
||||
return {
|
||||
|
|
@ -472,16 +605,14 @@ async def save_linkedin_token(token_request: LinkedInTokenRequest):
|
|||
}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to save LinkedIn token"
|
||||
status_code=500, detail="Failed to save LinkedIn token"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save LinkedIn token: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to save token: {str(e)}"
|
||||
status_code=500, detail=f"Failed to save token: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -508,8 +639,11 @@ async def get_linkedin_status():
|
|||
is_expiring_soon = LinkedInToolkit.is_token_expiring_soon()
|
||||
|
||||
result = {
|
||||
"authenticated": True,
|
||||
"status": "expired" if is_expired else ("expiring_soon" if is_expiring_soon else "valid"),
|
||||
"authenticated":
|
||||
True,
|
||||
"status":
|
||||
"expired" if is_expired else
|
||||
("expiring_soon" if is_expiring_soon else "valid"),
|
||||
}
|
||||
|
||||
if token_info:
|
||||
|
|
@ -528,7 +662,12 @@ async def get_linkedin_status():
|
|||
result["message"] = "Token has expired. Please re-authenticate."
|
||||
result["oauth_url"] = "/api/oauth/linkedin/login"
|
||||
elif is_expiring_soon:
|
||||
result["message"] = f"Token expires in {result.get('days_remaining', 'unknown')} days. Consider re-authenticating."
|
||||
days = result.get('days_remaining', 'unknown')
|
||||
result["message"] = (
|
||||
f"Token expires in {days}"
|
||||
" days. Consider"
|
||||
" re-authenticating."
|
||||
)
|
||||
result["oauth_url"] = "/api/oauth/linkedin/login"
|
||||
else:
|
||||
result["message"] = "LinkedIn is connected and token is valid."
|
||||
|
|
@ -537,45 +676,51 @@ async def get_linkedin_status():
|
|||
except Exception as e:
|
||||
logger.error(f"Failed to get LinkedIn status: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to get status: {str(e)}"
|
||||
status_code=500, detail=f"Failed to get status: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/browser/login", name="open browser for login")
|
||||
async def open_browser_login():
|
||||
"""
|
||||
Open an Electron-based Chrome browser for user login with a dedicated user data directory
|
||||
Open an Electron-based Chrome browser for
|
||||
user login with a dedicated user data directory
|
||||
|
||||
Returns:
|
||||
Browser session information
|
||||
"""
|
||||
try:
|
||||
import subprocess
|
||||
import platform
|
||||
import socket
|
||||
import json
|
||||
|
||||
import subprocess
|
||||
|
||||
# Use fixed profile name for persistent logins (no port suffix)
|
||||
session_id = "user_login"
|
||||
cdp_port = 9223
|
||||
|
||||
# IMPORTANT: Use dedicated profile for tool_controller browser
|
||||
# This is the SOURCE OF TRUTH for login data
|
||||
# On Eigent startup, this data will be copied to WebView partition (one-way sync)
|
||||
browser_profiles_base = os.path.expanduser("~/.eigent/browser_profiles")
|
||||
user_data_dir = os.path.join(browser_profiles_base, "profile_user_login")
|
||||
# On Eigent startup, this data will be copied
|
||||
# to WebView partition (one-way sync)
|
||||
browser_profiles_base = os.path.expanduser(
|
||||
"~/.eigent/browser_profiles"
|
||||
)
|
||||
user_data_dir = os.path.join(
|
||||
browser_profiles_base, "profile_user_login"
|
||||
)
|
||||
|
||||
os.makedirs(user_data_dir, exist_ok=True)
|
||||
|
||||
logger.info(
|
||||
f"Creating browser session {session_id} with profile at: {user_data_dir}")
|
||||
|
||||
"Creating browser session"
|
||||
f" {session_id} with profile"
|
||||
f" at: {user_data_dir}"
|
||||
)
|
||||
|
||||
# Check if browser is already running on this port
|
||||
def is_port_in_use(port):
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
return s.connect_ex(('localhost', port)) == 0
|
||||
|
||||
|
||||
if is_port_in_use(cdp_port):
|
||||
logger.info(f"Browser already running on port {cdp_port}")
|
||||
return {
|
||||
|
|
@ -583,31 +728,38 @@ async def open_browser_login():
|
|||
"session_id": session_id,
|
||||
"user_data_dir": user_data_dir,
|
||||
"cdp_port": cdp_port,
|
||||
"message": "Browser already running. Use existing window to log in.",
|
||||
"message":
|
||||
"Browser already running. Use existing window to log in.",
|
||||
"note": "Your login data will be saved in the profile."
|
||||
}
|
||||
|
||||
|
||||
# Use static Electron browser script
|
||||
electron_script_path = os.path.join(os.path.dirname(__file__), "electron_browser.cjs")
|
||||
electron_script_path = os.path.join(
|
||||
os.path.dirname(__file__), "electron_browser.cjs"
|
||||
)
|
||||
|
||||
# Verify script exists
|
||||
if not os.path.exists(electron_script_path):
|
||||
raise FileNotFoundError(f"Electron browser script not found: {electron_script_path}")
|
||||
raise FileNotFoundError(
|
||||
f"Electron browser script not found: {electron_script_path}"
|
||||
)
|
||||
|
||||
electron_cmd = "npx"
|
||||
electron_args = [
|
||||
electron_cmd,
|
||||
"electron",
|
||||
electron_script_path,
|
||||
user_data_dir,
|
||||
str(cdp_port),
|
||||
"https://www.google.com"
|
||||
electron_cmd, "electron", electron_script_path, user_data_dir,
|
||||
str(cdp_port), "https://www.google.com"
|
||||
]
|
||||
|
||||
|
||||
# Get the app's directory to run npx in the right context
|
||||
app_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
|
||||
logger.info(f"[PROFILE USER LOGIN] Launching Electron browser with CDP on port {cdp_port}")
|
||||
app_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"[PROFILE USER LOGIN] Launching"
|
||||
" Electron browser with CDP"
|
||||
f" on port {cdp_port}"
|
||||
)
|
||||
logger.info(f"[PROFILE USER LOGIN] Working directory: {app_dir}")
|
||||
logger.info(f"[PROFILE USER LOGIN] userData path: {user_data_dir}")
|
||||
logger.info(f"[PROFILE USER LOGIN] Electron args: {electron_args}")
|
||||
|
|
@ -632,29 +784,43 @@ async def open_browser_login():
|
|||
|
||||
import asyncio
|
||||
asyncio.create_task(log_electron_output())
|
||||
|
||||
|
||||
# Wait a bit for Electron to start
|
||||
import asyncio
|
||||
await asyncio.sleep(3)
|
||||
|
||||
logger.info(f"[PROFILE USER LOGIN] Electron browser launched with PID {process.pid}")
|
||||
logger.info(
|
||||
"[PROFILE USER LOGIN] Electron"
|
||||
" browser launched with"
|
||||
f" PID {process.pid}"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"session_id": session_id,
|
||||
"user_data_dir": user_data_dir,
|
||||
"cdp_port": cdp_port,
|
||||
"pid": process.pid,
|
||||
"chrome_version": "130.0.6723.191", # Electron 33's Chrome version
|
||||
"message": "Electron browser opened successfully. Please log in to your accounts.",
|
||||
"note": "The browser will remain open for you to log in. Your login data will be saved in the profile."
|
||||
"success":
|
||||
True,
|
||||
"session_id":
|
||||
session_id,
|
||||
"user_data_dir":
|
||||
user_data_dir,
|
||||
"cdp_port":
|
||||
cdp_port,
|
||||
"pid":
|
||||
process.pid,
|
||||
"chrome_version":
|
||||
"130.0.6723.191", # Electron 33's Chrome version
|
||||
"message":
|
||||
"Electron browser opened successfully."
|
||||
" Please log in to your accounts.",
|
||||
"note":
|
||||
"The browser will remain open for"
|
||||
" you to log in. Your login data"
|
||||
" will be saved in the profile."
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to open Electron browser for login: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to open browser: {str(e)}"
|
||||
status_code=500, detail=f"Failed to open browser: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -674,21 +840,37 @@ async def list_cookie_domains(search: str = None):
|
|||
user_data_base = os.path.expanduser("~/.eigent/browser_profiles")
|
||||
user_data_dir = os.path.join(user_data_base, "profile_user_login")
|
||||
|
||||
logger.info(f"[COOKIES CHECK] Tool controller user_data_dir: {user_data_dir}")
|
||||
logger.info(f"[COOKIES CHECK] Tool controller user_data_dir exists: {os.path.exists(user_data_dir)}")
|
||||
logger.info(
|
||||
f"[COOKIES CHECK] Tool controller user_data_dir: {user_data_dir}"
|
||||
)
|
||||
logger.info(
|
||||
"[COOKIES CHECK] Tool controller"
|
||||
" user_data_dir exists:"
|
||||
f" {os.path.exists(user_data_dir)}"
|
||||
)
|
||||
|
||||
# Check partition path
|
||||
partition_path = os.path.join(user_data_dir, "Partitions", "user_login")
|
||||
partition_path = os.path.join(
|
||||
user_data_dir, "Partitions", "user_login"
|
||||
)
|
||||
logger.info(f"[COOKIES CHECK] partition path: {partition_path}")
|
||||
logger.info(f"[COOKIES CHECK] partition exists: {os.path.exists(partition_path)}")
|
||||
logger.info(
|
||||
"[COOKIES CHECK] partition"
|
||||
f" exists: {os.path.exists(partition_path)}"
|
||||
)
|
||||
|
||||
# Check cookies file
|
||||
cookies_file = os.path.join(partition_path, "Cookies")
|
||||
logger.info(f"[COOKIES CHECK] cookies file: {cookies_file}")
|
||||
logger.info(f"[COOKIES CHECK] cookies file exists: {os.path.exists(cookies_file)}")
|
||||
logger.info(
|
||||
"[COOKIES CHECK] cookies file"
|
||||
f" exists: {os.path.exists(cookies_file)}"
|
||||
)
|
||||
if os.path.exists(cookies_file):
|
||||
stat = os.stat(cookies_file)
|
||||
logger.info(f"[COOKIES CHECK] cookies file size: {stat.st_size} bytes")
|
||||
logger.info(
|
||||
f"[COOKIES CHECK] cookies file size: {stat.st_size} bytes"
|
||||
)
|
||||
|
||||
# Try to read actual cookie count
|
||||
try:
|
||||
|
|
@ -697,16 +879,24 @@ async def list_cookie_domains(search: str = None):
|
|||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT COUNT(*) FROM cookies")
|
||||
count = cursor.fetchone()[0]
|
||||
logger.info(f"[COOKIES CHECK] actual cookie count in database: {count}")
|
||||
logger.info(
|
||||
f"[COOKIES CHECK] actual cookie count in database: {count}"
|
||||
)
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logger.error(f"[COOKIES CHECK] failed to read cookie count: {e}")
|
||||
logger.error(
|
||||
f"[COOKIES CHECK] failed to read cookie count: {e}"
|
||||
)
|
||||
|
||||
if not os.path.exists(user_data_dir):
|
||||
return {
|
||||
"success": True,
|
||||
"success":
|
||||
True,
|
||||
"domains": [],
|
||||
"message": "No browser profile found. Please login first using /browser/login."
|
||||
"message":
|
||||
"No browser profile found."
|
||||
" Please login first"
|
||||
" using /browser/login."
|
||||
}
|
||||
|
||||
cookie_manager = CookieManager(user_data_dir)
|
||||
|
|
@ -726,8 +916,7 @@ async def list_cookie_domains(search: str = None):
|
|||
except Exception as e:
|
||||
logger.error(f"Failed to list cookie domains: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to list cookies: {str(e)}"
|
||||
status_code=500, detail=f"Failed to list cookies: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -749,7 +938,11 @@ async def get_domain_cookies(domain: str):
|
|||
if not os.path.exists(user_data_dir):
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No browser profile found. Please login first using /browser/login."
|
||||
detail=(
|
||||
"No browser profile found."
|
||||
" Please login first using"
|
||||
" /browser/login."
|
||||
)
|
||||
)
|
||||
|
||||
cookie_manager = CookieManager(user_data_dir)
|
||||
|
|
@ -767,8 +960,7 @@ async def get_domain_cookies(domain: str):
|
|||
except Exception as e:
|
||||
logger.error(f"Failed to get cookies for domain {domain}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to get cookies: {str(e)}"
|
||||
status_code=500, detail=f"Failed to get cookies: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -790,7 +982,11 @@ async def delete_domain_cookies(domain: str):
|
|||
if not os.path.exists(user_data_dir):
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No browser profile found. Please login first using /browser/login."
|
||||
detail=(
|
||||
"No browser profile found."
|
||||
" Please login first using"
|
||||
" /browser/login."
|
||||
)
|
||||
)
|
||||
|
||||
cookie_manager = CookieManager(user_data_dir)
|
||||
|
|
@ -812,8 +1008,7 @@ async def delete_domain_cookies(domain: str):
|
|||
except Exception as e:
|
||||
logger.error(f"Failed to delete cookies for domain {domain}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to delete cookies: {str(e)}"
|
||||
status_code=500, detail=f"Failed to delete cookies: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -831,8 +1026,7 @@ async def delete_all_cookies():
|
|||
|
||||
if not os.path.exists(user_data_dir):
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No browser profile found."
|
||||
status_code=404, detail="No browser profile found."
|
||||
)
|
||||
|
||||
cookie_manager = CookieManager(user_data_dir)
|
||||
|
|
@ -845,8 +1039,7 @@ async def delete_all_cookies():
|
|||
}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to delete all cookies"
|
||||
status_code=500, detail="Failed to delete all cookies"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
|
|
@ -854,6 +1047,5 @@ async def delete_all_cookies():
|
|||
except Exception as e:
|
||||
logger.error(f"Failed to delete all cookies: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to delete cookies: {str(e)}"
|
||||
status_code=500, detail=f"Failed to delete cookies: {str(e)}"
|
||||
)
|
||||
|
|
|
|||
0
backend/app/exception/__init__.py
Normal file
0
backend/app/exception/__init__.py
Normal file
|
|
@ -12,23 +12,28 @@
|
|||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
|
||||
class UserException(Exception):
|
||||
|
||||
def __init__(self, code: int, description: str):
|
||||
self.code = code
|
||||
self.description = description
|
||||
|
||||
|
||||
class TokenException(Exception):
|
||||
|
||||
def __init__(self, code: int, text: str):
|
||||
self.code = code
|
||||
self.text = text
|
||||
|
||||
|
||||
class NoPermissionException(Exception):
|
||||
|
||||
def __init__(self, text: str):
|
||||
self.text = text
|
||||
|
||||
|
||||
class ProgramException(Exception):
|
||||
|
||||
def __init__(self, text: str):
|
||||
self.text = text
|
||||
|
|
|
|||
|
|
@ -12,17 +12,22 @@
|
|||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
import traceback
|
||||
import logging
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app import api
|
||||
from app.component import code
|
||||
from app.exception.exception import NoPermissionException, ProgramException, TokenException
|
||||
from app.component.pydantic.i18n import trans, get_language
|
||||
from app.exception.exception import UserException
|
||||
import logging
|
||||
from app.component.pydantic.i18n import get_language, trans
|
||||
from app.exception.exception import (
|
||||
NoPermissionException,
|
||||
ProgramException,
|
||||
TokenException,
|
||||
UserException,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("exception_handler")
|
||||
|
||||
|
|
@ -32,11 +37,13 @@ async def request_exception(request: Request, e: RequestValidationError):
|
|||
if (lang := get_language(request.headers.get("Accept-Language"))) is None:
|
||||
lang = "en_US"
|
||||
logger.warning(f"Validation error on {request.url.path}: {e.errors()}")
|
||||
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"code": code.form_error,
|
||||
"error": jsonable_encoder(trans.translate(list(e.errors()), locale=lang)),
|
||||
"code":
|
||||
code.form_error,
|
||||
"error":
|
||||
jsonable_encoder(trans.translate(list(e.errors()), locale=lang)),
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -58,16 +65,27 @@ async def no_permission(request: Request, exception: NoPermissionException):
|
|||
logger.warning(f"No permission on {request.url.path}: {exception.text}")
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={"code": code.no_permission_error, "text": exception.text},
|
||||
content={
|
||||
"code": code.no_permission_error,
|
||||
"text": exception.text
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@api.exception_handler(ProgramException)
|
||||
async def program_exception(request: Request, exception: NoPermissionException):
|
||||
logger.error(f"Program exception on {request.url.path}: {exception.text}", exc_info=True)
|
||||
async def program_exception(
|
||||
request: Request, exception: NoPermissionException
|
||||
):
|
||||
logger.error(
|
||||
f"Program exception on {request.url.path}: {exception.text}",
|
||||
exc_info=True
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={"code": code.program_error, "text": exception.text},
|
||||
content={
|
||||
"code": code.program_error,
|
||||
"text": exception.text
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,9 +12,9 @@
|
|||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from app import api
|
||||
from app.component.babel import babel_configs
|
||||
from fastapi_babel import BabelMiddleware
|
||||
|
||||
from app import api
|
||||
from app.component.babel import babel_configs
|
||||
|
||||
api.add_middleware(BabelMiddleware, babel_configs=babel_configs)
|
||||
|
|
|
|||
0
backend/app/model/__init__.py
Normal file
0
backend/app/model/__init__.py
Normal file
|
|
@ -12,14 +12,17 @@
|
|||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from enum import Enum
|
||||
import json
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import Literal
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from camel.types import ModelType, RoleType
|
||||
import logging
|
||||
import re
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from camel.types import ModelType, RoleType
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.agent.prompt import DEFAULT_SUMMARY_PROMPT
|
||||
|
||||
logger = logging.getLogger("chat_model")
|
||||
|
||||
|
|
@ -42,7 +45,8 @@ class QuestionAnalysisResult(BaseModel):
|
|||
)
|
||||
answer: str | None = Field(
|
||||
default=None,
|
||||
description="Direct answer for simple questions. None for complex tasks."
|
||||
description="Direct answer for simple questions."
|
||||
" None for complex tasks."
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -53,6 +57,7 @@ PLATFORM_MAPPING = {
|
|||
"ModelArk": "openai-compatible-model",
|
||||
}
|
||||
|
||||
|
||||
class Chat(BaseModel):
|
||||
task_id: str
|
||||
project_id: str
|
||||
|
|
@ -62,7 +67,8 @@ class Chat(BaseModel):
|
|||
model_platform: str
|
||||
model_type: str
|
||||
api_key: str
|
||||
api_url: str | None = None # for cloud version, user don't need to set api_url
|
||||
# for cloud version, user don't need to set api_url
|
||||
api_url: str | None = None
|
||||
language: str = "en"
|
||||
browser_port: int = 9222
|
||||
max_retries: int = 3
|
||||
|
|
@ -71,17 +77,13 @@ class Chat(BaseModel):
|
|||
bun_mirror: str = ""
|
||||
uvx_mirror: str = ""
|
||||
env_path: str | None = None
|
||||
summary_prompt: str = (
|
||||
"After completing the task, please generate a summary of the entire task completion. "
|
||||
"The summary must be enclosed in <summary></summary> tags and include:\n"
|
||||
"1. A confirmation of task completion, referencing the original goal.\n"
|
||||
"2. A high-level overview of the work performed and the final outcome.\n"
|
||||
"3. A bulleted list of key results or accomplishments.\n"
|
||||
"Adopt a confident and professional tone."
|
||||
)
|
||||
summary_prompt: str = DEFAULT_SUMMARY_PROMPT
|
||||
new_agents: list["NewAgent"] = []
|
||||
extra_params: dict | None = None # For provider-specific parameters like Azure
|
||||
search_config: dict[str, str] | None = None # User-specific search engine configurations (e.g., GOOGLE_API_KEY, SEARCH_ENGINE_ID)
|
||||
# For provider-specific parameters like Azure
|
||||
extra_params: dict | None = None
|
||||
# User-specific search engine configurations
|
||||
# (e.g., GOOGLE_API_KEY, SEARCH_ENGINE_ID)
|
||||
search_config: dict[str, str] | None = None
|
||||
|
||||
@field_validator("model_platform")
|
||||
@classmethod
|
||||
|
|
@ -99,18 +101,27 @@ class Chat(BaseModel):
|
|||
return model_type
|
||||
|
||||
def get_bun_env(self) -> dict[str, str]:
|
||||
return {"NPM_CONFIG_REGISTRY": self.bun_mirror} if self.bun_mirror else {}
|
||||
return {
|
||||
"NPM_CONFIG_REGISTRY": self.bun_mirror
|
||||
} if self.bun_mirror else {}
|
||||
|
||||
def get_uvx_env(self) -> dict[str, str]:
|
||||
return {"UV_DEFAULT_INDEX": self.uvx_mirror, "PIP_INDEX_URL": self.uvx_mirror} if self.uvx_mirror else {}
|
||||
return {
|
||||
"UV_DEFAULT_INDEX": self.uvx_mirror,
|
||||
"PIP_INDEX_URL": self.uvx_mirror
|
||||
} if self.uvx_mirror else {}
|
||||
|
||||
def is_cloud(self):
|
||||
return self.api_url is not None and "44.247.171.124" in self.api_url
|
||||
|
||||
def file_save_path(self, path: str | None = None):
|
||||
email = re.sub(r'[\\/*?:"<>|\s]', "_", self.email.split("@")[0]).strip(".")
|
||||
email = re.sub(r'[\\/*?:"<>|\s]', "_",
|
||||
self.email.split("@")[0]).strip(".")
|
||||
# Use project-based structure: project_{project_id}/task_{task_id}
|
||||
save_path = Path.home() / "eigent" / email / f"project_{self.project_id}" / f"task_{self.task_id}"
|
||||
save_path = (
|
||||
Path.home() / "eigent" / email / f"project_{self.project_id}" /
|
||||
f"task_{self.task_id}"
|
||||
)
|
||||
if path is not None:
|
||||
save_path = save_path / path
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -138,7 +149,8 @@ class UpdateData(BaseModel):
|
|||
|
||||
|
||||
class AgentModelConfig(BaseModel):
|
||||
"""Optional per-agent model configuration to override the default task model."""
|
||||
"""Optional per-agent model configuration
|
||||
to override the default task model."""
|
||||
model_platform: str | None = None
|
||||
model_type: str | None = None
|
||||
api_key: str | None = None
|
||||
|
|
@ -147,13 +159,15 @@ class AgentModelConfig(BaseModel):
|
|||
|
||||
def has_custom_config(self) -> bool:
|
||||
"""Check if any custom model configuration is set."""
|
||||
return any([
|
||||
self.model_platform is not None,
|
||||
self.model_type is not None,
|
||||
self.api_key is not None,
|
||||
self.api_url is not None,
|
||||
self.extra_params is not None,
|
||||
])
|
||||
return any(
|
||||
[
|
||||
self.model_platform is not None,
|
||||
self.model_type is not None,
|
||||
self.api_key is not None,
|
||||
self.api_url is not None,
|
||||
self.extra_params is not None,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class NewAgent(BaseModel):
|
||||
|
|
@ -177,6 +191,7 @@ class AddTaskRequest(BaseModel):
|
|||
class RemoveTaskRequest(BaseModel):
|
||||
task_id: str
|
||||
|
||||
|
||||
def sse_json(step: str, data):
|
||||
res_format = {"step": step, "data": data}
|
||||
return f"data: {json.dumps(res_format, ensure_ascii=False)}\n\n"
|
||||
|
|
|
|||
|
|
@ -11,27 +11,35 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""
|
||||
Centralized router registration for the Eigent API.
|
||||
All routers are explicitly registered here for better visibility and maintainability.
|
||||
All routers are explicitly registered here
|
||||
for better visibility and maintainability.
|
||||
"""
|
||||
from fastapi import FastAPI
|
||||
from app.controller import chat_controller, model_controller, task_controller, tool_controller, health_controller
|
||||
import logging
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app.controller import (
|
||||
chat_controller,
|
||||
health_controller,
|
||||
model_controller,
|
||||
task_controller,
|
||||
tool_controller,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("router")
|
||||
|
||||
|
||||
def register_routers(app: FastAPI, prefix: str = "") -> None:
|
||||
"""
|
||||
Register all API routers with their respective prefixes and tags.
|
||||
|
||||
|
||||
This replaces the auto-discovery mechanism for better:
|
||||
- Visibility: See all routes in one place
|
||||
- Maintainability: Easy to add/remove routes
|
||||
- Debugging: Clear registration order and configuration
|
||||
|
||||
|
||||
Args:
|
||||
app: FastAPI application instance
|
||||
prefix: Optional global prefix for all routes (e.g., "/api")
|
||||
|
|
@ -43,9 +51,11 @@ def register_routers(app: FastAPI, prefix: str = "") -> None:
|
|||
"description": "Health check endpoint for service readiness"
|
||||
},
|
||||
{
|
||||
"router": chat_controller.router,
|
||||
"router":
|
||||
chat_controller.router,
|
||||
"tags": ["chat"],
|
||||
"description": "Chat session management, improvements, and human interactions"
|
||||
"description":
|
||||
"Chat session management, improvements, and human interactions"
|
||||
},
|
||||
{
|
||||
"router": model_controller.router,
|
||||
|
|
@ -53,26 +63,28 @@ def register_routers(app: FastAPI, prefix: str = "") -> None:
|
|||
"description": "Model validation and configuration"
|
||||
},
|
||||
{
|
||||
"router": task_controller.router,
|
||||
"router":
|
||||
task_controller.router,
|
||||
"tags": ["task"],
|
||||
"description": "Task lifecycle management (start, stop, update, control)"
|
||||
"description":
|
||||
"Task lifecycle management (start, stop, update, control)"
|
||||
},
|
||||
{
|
||||
"router": tool_controller.router,
|
||||
"tags": ["tool"],
|
||||
"tags": ["tool"],
|
||||
"description": "Tool installation and management"
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
for config in routers_config:
|
||||
app.include_router(
|
||||
config["router"],
|
||||
prefix=prefix,
|
||||
tags=config["tags"]
|
||||
config["router"], prefix=prefix, tags=config["tags"]
|
||||
)
|
||||
route_count = len(config["router"].routes)
|
||||
logger.info(
|
||||
f"Registered {config['tags'][0]} router: {route_count} routes - {config['description']}"
|
||||
f"Registered {config['tags'][0]} router:"
|
||||
f" {route_count} routes -"
|
||||
f" {config['description']}"
|
||||
)
|
||||
|
||||
logger.info(f"Total routers registered: {len(routers_config)}")
|
||||
|
||||
logger.info(f"Total routers registered: {len(routers_config)}")
|
||||
|
|
|
|||
0
backend/app/service/__init__.py
Normal file
0
backend/app/service/__init__.py
Normal file
0
backend/app/utils/listen/__init__.py
Normal file
0
backend/app/utils/listen/__init__.py
Normal file
0
backend/app/utils/server/__init__.py
Normal file
0
backend/app/utils/server/__init__.py
Normal file
0
backend/app/utils/toolkit/__init__.py
Normal file
0
backend/app/utils/toolkit/__init__.py
Normal file
|
|
@ -13,12 +13,12 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agent.factory import browser_agent
|
||||
from app.model.chat import Chat
|
||||
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
|
|
@ -31,14 +31,15 @@ def test_browser_agent_creation(sample_chat_data):
|
|||
mock_task_lock = MagicMock()
|
||||
task_locks[options.task_id] = mock_task_lock
|
||||
|
||||
with patch('app.agent.factory.browser.agent_model') as mock_agent_model, \
|
||||
_mod = 'app.agent.factory.browser'
|
||||
with patch(f'{_mod}.agent_model') as mock_agent_model, \
|
||||
patch('asyncio.create_task'), \
|
||||
patch('app.agent.factory.browser.HumanToolkit') as mock_human_toolkit, \
|
||||
patch('app.agent.factory.browser.HybridBrowserToolkit') as mock_browser_toolkit, \
|
||||
patch('app.agent.factory.browser.TerminalToolkit') as mock_terminal_toolkit, \
|
||||
patch('app.agent.factory.browser.NoteTakingToolkit') as mock_note_toolkit, \
|
||||
patch('app.agent.factory.browser.SearchToolkit') as mock_search_toolkit, \
|
||||
patch('app.agent.factory.browser.ToolkitMessageIntegration'), \
|
||||
patch(f'{_mod}.HumanToolkit') as mock_human_toolkit, \
|
||||
patch(f'{_mod}.HybridBrowserToolkit') as mock_browser_toolkit, \
|
||||
patch(f'{_mod}.TerminalToolkit') as mock_terminal_toolkit, \
|
||||
patch(f'{_mod}.NoteTakingToolkit') as mock_note_toolkit, \
|
||||
patch(f'{_mod}.SearchToolkit') as mock_search_toolkit, \
|
||||
patch(f'{_mod}.ToolkitMessageIntegration'), \
|
||||
patch('uuid.uuid4') as mock_uuid:
|
||||
|
||||
# Mock all toolkit instances
|
||||
|
|
@ -66,10 +67,13 @@ def test_browser_agent_creation(sample_chat_data):
|
|||
|
||||
# Check that it was called with browser agent configuration
|
||||
call_args = mock_agent_model.call_args
|
||||
assert "browser_agent" in str(call_args[0][0]) # agent_name (enum contains this value)
|
||||
assert "browser_agent" in str(
|
||||
call_args[0][0]
|
||||
) # agent_name (enum contains this value)
|
||||
# The system_prompt is a BaseMessage, so check its content attribute
|
||||
system_message = call_args[0][1]
|
||||
if hasattr(system_message, 'content'):
|
||||
assert "search" in system_message.content.lower()
|
||||
else:
|
||||
assert "search" in str(system_message).lower() # system_prompt contains search
|
||||
assert "search" in str(system_message).lower(
|
||||
) # system_prompt contains search
|
||||
|
|
|
|||
|
|
@ -13,12 +13,12 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agent.factory import developer_agent
|
||||
from app.model.chat import Chat
|
||||
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
|
|
@ -32,14 +32,15 @@ async def test_developer_agent_creation(sample_chat_data):
|
|||
mock_task_lock = MagicMock()
|
||||
task_locks[options.task_id] = mock_task_lock
|
||||
|
||||
with patch('app.agent.factory.developer.agent_model') as mock_agent_model, \
|
||||
_mod = 'app.agent.factory.developer'
|
||||
with patch(f'{_mod}.agent_model') as mock_agent_model, \
|
||||
patch('asyncio.create_task'), \
|
||||
patch('app.agent.factory.developer.HumanToolkit') as mock_human_toolkit, \
|
||||
patch('app.agent.factory.developer.NoteTakingToolkit') as mock_note_toolkit, \
|
||||
patch('app.agent.factory.developer.WebDeployToolkit') as mock_web_toolkit, \
|
||||
patch('app.agent.factory.developer.ScreenshotToolkit') as mock_screenshot_toolkit, \
|
||||
patch('app.agent.factory.developer.TerminalToolkit') as mock_terminal_toolkit, \
|
||||
patch('app.agent.factory.developer.ToolkitMessageIntegration'):
|
||||
patch(f'{_mod}.HumanToolkit') as mock_human_toolkit, \
|
||||
patch(f'{_mod}.NoteTakingToolkit') as mock_note_toolkit, \
|
||||
patch(f'{_mod}.WebDeployToolkit') as mock_web_toolkit, \
|
||||
patch(f'{_mod}.ScreenshotToolkit') as mock_screenshot_toolkit, \
|
||||
patch(f'{_mod}.TerminalToolkit') as mock_terminal_toolkit, \
|
||||
patch(f'{_mod}.ToolkitMessageIntegration'):
|
||||
|
||||
# Mock all toolkit instances
|
||||
mock_human_toolkit.get_can_use_tools.return_value = []
|
||||
|
|
@ -58,7 +59,9 @@ async def test_developer_agent_creation(sample_chat_data):
|
|||
|
||||
# Should have called with development-related tools
|
||||
call_args = mock_agent_model.call_args
|
||||
assert "developer_agent" in str(call_args[0][0]) # agent_name (enum contains this value)
|
||||
assert "developer_agent" in str(
|
||||
call_args[0][0]
|
||||
) # agent_name (enum contains this value)
|
||||
tools_arg = call_args[0][3] # tools argument
|
||||
assert isinstance(tools_arg, list)
|
||||
|
||||
|
|
@ -73,14 +76,15 @@ async def test_developer_agent_with_multiple_toolkits(sample_chat_data):
|
|||
mock_task_lock = MagicMock()
|
||||
task_locks[options.task_id] = mock_task_lock
|
||||
|
||||
with patch('app.agent.factory.developer.agent_model') as mock_agent_model, \
|
||||
_mod = 'app.agent.factory.developer'
|
||||
with patch(f'{_mod}.agent_model') as mock_agent_model, \
|
||||
patch('asyncio.create_task'), \
|
||||
patch('app.agent.factory.developer.HumanToolkit') as mock_human_toolkit, \
|
||||
patch('app.agent.factory.developer.NoteTakingToolkit') as mock_note_toolkit, \
|
||||
patch('app.agent.factory.developer.WebDeployToolkit') as mock_web_toolkit, \
|
||||
patch('app.agent.factory.developer.ScreenshotToolkit') as mock_screenshot_toolkit, \
|
||||
patch('app.agent.factory.developer.TerminalToolkit') as mock_terminal_toolkit, \
|
||||
patch('app.agent.factory.developer.ToolkitMessageIntegration'):
|
||||
patch(f'{_mod}.HumanToolkit') as mock_human_toolkit, \
|
||||
patch(f'{_mod}.NoteTakingToolkit') as mock_note_toolkit, \
|
||||
patch(f'{_mod}.WebDeployToolkit') as mock_web_toolkit, \
|
||||
patch(f'{_mod}.ScreenshotToolkit') as mock_screenshot_toolkit, \
|
||||
patch(f'{_mod}.TerminalToolkit') as mock_terminal_toolkit, \
|
||||
patch(f'{_mod}.ToolkitMessageIntegration'):
|
||||
|
||||
# Mock all toolkit instances
|
||||
mock_human_toolkit.get_can_use_tools.return_value = []
|
||||
|
|
|
|||
|
|
@ -13,12 +13,12 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agent.factory import document_agent
|
||||
from app.model.chat import Chat
|
||||
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
|
|
@ -32,17 +32,18 @@ async def test_document_agent_creation(sample_chat_data):
|
|||
mock_task_lock = MagicMock()
|
||||
task_locks[options.task_id] = mock_task_lock
|
||||
|
||||
with patch('app.agent.factory.document.agent_model') as mock_agent_model, \
|
||||
_mod = 'app.agent.factory.document'
|
||||
with patch(f'{_mod}.agent_model') as mock_agent_model, \
|
||||
patch('asyncio.create_task'), \
|
||||
patch('app.agent.factory.document.HumanToolkit') as mock_human_toolkit, \
|
||||
patch('app.agent.factory.document.FileToolkit') as mock_file_toolkit, \
|
||||
patch('app.agent.factory.document.PPTXToolkit') as mock_pptx_toolkit, \
|
||||
patch('app.agent.factory.document.MarkItDownToolkit') as mock_markdown_toolkit, \
|
||||
patch('app.agent.factory.document.ExcelToolkit') as mock_excel_toolkit, \
|
||||
patch('app.agent.factory.document.NoteTakingToolkit') as mock_note_toolkit, \
|
||||
patch('app.agent.factory.document.TerminalToolkit') as mock_terminal_toolkit, \
|
||||
patch('app.agent.factory.document.GoogleDriveMCPToolkit') as mock_gdrive_toolkit, \
|
||||
patch('app.agent.factory.document.ToolkitMessageIntegration'):
|
||||
patch(f'{_mod}.HumanToolkit') as mock_human_toolkit, \
|
||||
patch(f'{_mod}.FileToolkit') as mock_file_toolkit, \
|
||||
patch(f'{_mod}.PPTXToolkit') as mock_pptx_toolkit, \
|
||||
patch(f'{_mod}.MarkItDownToolkit') as mock_markdown_toolkit, \
|
||||
patch(f'{_mod}.ExcelToolkit') as mock_excel_toolkit, \
|
||||
patch(f'{_mod}.NoteTakingToolkit') as mock_note_toolkit, \
|
||||
patch(f'{_mod}.TerminalToolkit') as mock_terminal_toolkit, \
|
||||
patch(f'{_mod}.GoogleDriveMCPToolkit') as mock_gdrive_toolkit, \
|
||||
patch(f'{_mod}.ToolkitMessageIntegration'):
|
||||
|
||||
# Mock all toolkit instances
|
||||
mock_human_toolkit.get_can_use_tools.return_value = []
|
||||
|
|
@ -64,4 +65,6 @@ async def test_document_agent_creation(sample_chat_data):
|
|||
|
||||
# Should have called with document-related tools
|
||||
call_args = mock_agent_model.call_args
|
||||
assert "document_agent" in str(call_args[0][0]) # agent_name (enum contains this value)
|
||||
assert "document_agent" in str(
|
||||
call_args[0][0]
|
||||
) # agent_name (enum contains this value)
|
||||
|
|
|
|||
|
|
@ -13,12 +13,12 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agent.factory import mcp_agent
|
||||
from app.model.chat import Chat
|
||||
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
|
|
@ -27,12 +27,13 @@ async def test_mcp_agent_creation(sample_chat_data):
|
|||
"""Test mcp_agent creates agent with MCP tools."""
|
||||
options = Chat(**sample_chat_data)
|
||||
|
||||
with patch('app.agent.factory.mcp.ListenChatAgent') as mock_listen_agent, \
|
||||
patch('app.agent.factory.mcp.ModelFactory.create') as mock_model_factory, \
|
||||
_mod = 'app.agent.factory.mcp'
|
||||
with patch(f'{_mod}.ListenChatAgent') as mock_listen_agent, \
|
||||
patch(f'{_mod}.ModelFactory.create') as mock_model_factory, \
|
||||
patch('asyncio.create_task'), \
|
||||
patch('app.agent.factory.mcp.McpSearchToolkit') as mock_mcp_search_toolkit, \
|
||||
patch('app.agent.factory.mcp.get_mcp_tools') as mock_get_mcp_tools, \
|
||||
patch('app.agent.factory.mcp.get_task_lock') as mock_get_task_lock:
|
||||
patch(f'{_mod}.McpSearchToolkit') as mock_mcp_search_toolkit, \
|
||||
patch(f'{_mod}.get_mcp_tools') as mock_get_mcp_tools, \
|
||||
patch(f'{_mod}.get_task_lock'):
|
||||
|
||||
# Mock toolkit instances
|
||||
mock_mcp_search_toolkit.return_value.get_tools.return_value = []
|
||||
|
|
@ -49,4 +50,6 @@ async def test_mcp_agent_creation(sample_chat_data):
|
|||
|
||||
# Check that it was called with MCP agent configuration
|
||||
call_args = mock_listen_agent.call_args
|
||||
assert "mcp_agent" in str(call_args[0][1]) # agent_name (enum contains this value)
|
||||
assert "mcp_agent" in str(
|
||||
call_args[0][1]
|
||||
) # agent_name (enum contains this value)
|
||||
|
|
|
|||
|
|
@ -13,12 +13,12 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agent.factory import multi_modal_agent
|
||||
from app.model.chat import Chat
|
||||
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
|
|
@ -31,17 +31,18 @@ def test_multi_modal_agent_creation(sample_chat_data):
|
|||
mock_task_lock = MagicMock()
|
||||
task_locks[options.task_id] = mock_task_lock
|
||||
|
||||
with patch('app.agent.factory.multi_modal.agent_model') as mock_agent_model, \
|
||||
_mod = 'app.agent.factory.multi_modal'
|
||||
with patch(f'{_mod}.agent_model') as mock_agent_model, \
|
||||
patch('asyncio.create_task'), \
|
||||
patch('app.agent.factory.multi_modal.HumanToolkit') as mock_human_toolkit, \
|
||||
patch('app.agent.factory.multi_modal.VideoDownloaderToolkit') as mock_video_toolkit, \
|
||||
patch('app.agent.factory.multi_modal.ImageAnalysisToolkit') as mock_image_toolkit, \
|
||||
patch('app.agent.factory.multi_modal.OpenAIImageToolkit') as mock_openai_image_toolkit, \
|
||||
patch('app.agent.factory.multi_modal.AudioAnalysisToolkit') as mock_audio_toolkit, \
|
||||
patch('app.agent.factory.multi_modal.TerminalToolkit') as mock_terminal_toolkit, \
|
||||
patch('app.agent.factory.multi_modal.NoteTakingToolkit') as mock_note_toolkit, \
|
||||
patch('app.agent.factory.multi_modal.SearchToolkit') as mock_search_toolkit, \
|
||||
patch('app.agent.factory.multi_modal.ToolkitMessageIntegration'):
|
||||
patch(f'{_mod}.HumanToolkit') as mock_human_toolkit, \
|
||||
patch(f'{_mod}.VideoDownloaderToolkit') as mock_video_toolkit, \
|
||||
patch(f'{_mod}.ImageAnalysisToolkit') as mock_image_toolkit, \
|
||||
patch(f'{_mod}.OpenAIImageToolkit') as mock_openai_image_toolkit, \
|
||||
patch(f'{_mod}.AudioAnalysisToolkit') as mock_audio_toolkit, \
|
||||
patch(f'{_mod}.TerminalToolkit') as mock_terminal_toolkit, \
|
||||
patch(f'{_mod}.NoteTakingToolkit') as mock_note_toolkit, \
|
||||
patch(f'{_mod}.SearchToolkit') as mock_search_toolkit, \
|
||||
patch(f'{_mod}.ToolkitMessageIntegration'):
|
||||
|
||||
# Mock all toolkit instances
|
||||
mock_human_toolkit.get_can_use_tools.return_value = []
|
||||
|
|
@ -63,4 +64,6 @@ def test_multi_modal_agent_creation(sample_chat_data):
|
|||
|
||||
# Check that it was called with multi-modal agent configuration
|
||||
call_args = mock_agent_model.call_args
|
||||
assert "multi_modal_agent" in str(call_args[0][0]) # agent_name (enum contains this value)
|
||||
assert "multi_modal_agent" in str(
|
||||
call_args[0][0]
|
||||
) # agent_name (enum contains this value)
|
||||
|
|
|
|||
|
|
@ -13,12 +13,12 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agent.factory import question_confirm_agent
|
||||
from app.model.chat import Chat
|
||||
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
|
|
@ -31,7 +31,8 @@ def test_question_confirm_agent_creation(sample_chat_data):
|
|||
mock_task_lock = MagicMock()
|
||||
task_locks[options.task_id] = mock_task_lock
|
||||
|
||||
with patch('app.agent.factory.question_confirm.agent_model') as mock_agent_model, \
|
||||
_mod = 'app.agent.factory.question_confirm'
|
||||
with patch(f'{_mod}.agent_model') as mock_agent_model, \
|
||||
patch('asyncio.create_task'):
|
||||
mock_agent = MagicMock()
|
||||
mock_agent_model.return_value = mock_agent
|
||||
|
|
|
|||
|
|
@ -13,12 +13,12 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agent.factory import task_summary_agent
|
||||
from app.model.chat import Chat
|
||||
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
|
|
@ -31,7 +31,8 @@ def test_task_summary_agent_creation(sample_chat_data):
|
|||
mock_task_lock = MagicMock()
|
||||
task_locks[options.task_id] = mock_task_lock
|
||||
|
||||
with patch('app.agent.factory.task_summary.agent_model') as mock_agent_model, \
|
||||
_mod = 'app.agent.factory.task_summary'
|
||||
with patch(f'{_mod}.agent_model') as mock_agent_model, \
|
||||
patch('asyncio.create_task'):
|
||||
mock_agent = MagicMock()
|
||||
mock_agent_model.return_value = mock_agent
|
||||
|
|
|
|||
|
|
@ -14,12 +14,12 @@
|
|||
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agent.agent_model import agent_model
|
||||
from app.model.chat import Chat
|
||||
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
|
|
@ -37,11 +37,11 @@ class TestAgentFactoryFunctions:
|
|||
mock_task_lock = MagicMock()
|
||||
task_locks[options.task_id] = mock_task_lock
|
||||
|
||||
agent_model_mod = sys.modules['app.agent.agent_model']
|
||||
with patch.object(agent_model_mod, 'ListenChatAgent') as mock_listen_agent, \
|
||||
patch.object(agent_model_mod, 'ModelFactory') as mock_model_factory, \
|
||||
patch.object(agent_model_mod, 'get_task_lock', return_value=mock_task_lock), \
|
||||
patch('asyncio.create_task') as mock_create_task:
|
||||
_m = sys.modules['app.agent.agent_model']
|
||||
with patch.object(_m, 'ListenChatAgent') as mock_listen_agent, \
|
||||
patch.object(_m, 'ModelFactory') as mock_model_factory, \
|
||||
patch.object(_m, 'get_task_lock', return_value=mock_task_lock), \
|
||||
patch('asyncio.create_task'):
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_listen_agent.return_value = mock_agent
|
||||
|
|
@ -84,11 +84,11 @@ class TestAgentIntegration:
|
|||
task_locks[api_task_id] = mock_task_lock
|
||||
|
||||
# Create agent
|
||||
agent_model_mod = sys.modules['app.agent.agent_model']
|
||||
with patch.object(agent_model_mod, 'ModelFactory') as mock_model_factory, \
|
||||
patch.object(agent_model_mod, '_schedule_async_task'), \
|
||||
patch.object(agent_model_mod, 'ListenChatAgent') as mock_listen_agent, \
|
||||
patch.object(agent_model_mod, 'get_task_lock', return_value=mock_task_lock):
|
||||
_m = sys.modules['app.agent.agent_model']
|
||||
with patch.object(_m, 'ModelFactory') as mock_model_factory, \
|
||||
patch.object(_m, '_schedule_async_task'), \
|
||||
patch.object(_m, 'ListenChatAgent') as mock_listen_agent, \
|
||||
patch.object(_m, 'get_task_lock', return_value=mock_task_lock):
|
||||
mock_model = MagicMock()
|
||||
mock_model_factory.return_value = mock_model
|
||||
|
||||
|
|
@ -96,7 +96,9 @@ class TestAgentIntegration:
|
|||
mock_agent_instance.api_task_id = api_task_id
|
||||
mock_listen_agent.return_value = mock_agent_instance
|
||||
|
||||
agent = agent_model("IntegrationAgent", "Test system prompt", options, [])
|
||||
agent = agent_model(
|
||||
"IntegrationAgent", "Test system prompt", options, []
|
||||
)
|
||||
|
||||
assert agent is mock_agent_instance
|
||||
assert agent.api_task_id == api_task_id
|
||||
|
|
|
|||
|
|
@ -14,8 +14,8 @@
|
|||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
import pytest
|
||||
from camel.agents import ChatAgent
|
||||
from camel.agents._types import ToolCallRequest
|
||||
from camel.messages import BaseMessage
|
||||
|
|
@ -26,6 +26,7 @@ from camel.types.agents import ToolCallingRecord
|
|||
from app.agent.listen_chat_agent import ListenChatAgent
|
||||
from app.model.chat import Chat
|
||||
|
||||
_LCA = 'app.agent.listen_chat_agent'
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
|
@ -38,7 +39,7 @@ class TestListenChatAgent:
|
|||
api_task_id = "test_api_task_123"
|
||||
agent_name = "TestAgent"
|
||||
|
||||
with patch('app.agent.listen_chat_agent.get_task_lock') as mock_get_lock, \
|
||||
with patch(f'{_LCA}.get_task_lock') as mock_get_lock, \
|
||||
patch('camel.models.ModelFactory.create') as mock_create_model:
|
||||
mock_task_lock = MagicMock()
|
||||
mock_get_lock.return_value = mock_task_lock
|
||||
|
|
@ -68,9 +69,9 @@ class TestListenChatAgent:
|
|||
api_task_id = "test_api_task_123"
|
||||
agent_name = "TestAgent"
|
||||
|
||||
with patch('app.agent.listen_chat_agent.get_task_lock', return_value=mock_task_lock), \
|
||||
with patch(f'{_LCA}.get_task_lock', return_value=mock_task_lock), \
|
||||
patch('camel.models.ModelFactory.create') as mock_create_model, \
|
||||
patch('asyncio.create_task') as mock_create_task:
|
||||
patch('asyncio.create_task'):
|
||||
|
||||
# Mock the model backend creation
|
||||
mock_backend = MagicMock()
|
||||
|
|
@ -80,9 +81,7 @@ class TestListenChatAgent:
|
|||
mock_create_model.return_value = mock_backend
|
||||
|
||||
agent = ListenChatAgent(
|
||||
api_task_id=api_task_id,
|
||||
agent_name=agent_name,
|
||||
model="gpt-4"
|
||||
api_task_id=api_task_id, agent_name=agent_name, model="gpt-4"
|
||||
)
|
||||
agent.process_task_id = "test_process_task"
|
||||
|
||||
|
|
@ -92,25 +91,31 @@ class TestListenChatAgent:
|
|||
mock_response.msg.content = "Test response content"
|
||||
mock_response.info = {"usage": {"total_tokens": 100}}
|
||||
|
||||
with patch.object(ChatAgent, 'step', return_value=mock_response) as mock_parent_step:
|
||||
with patch.object(
|
||||
ChatAgent, 'step', return_value=mock_response
|
||||
) as mock_parent_step:
|
||||
result = agent.step("Test input message")
|
||||
|
||||
assert result is mock_response
|
||||
# Check that step was called with the input message (don't assert on response_format param)
|
||||
# Check that step was called with
|
||||
# the input message (don't assert
|
||||
# on response_format param)
|
||||
mock_parent_step.assert_called_once()
|
||||
args, kwargs = mock_parent_step.call_args
|
||||
assert args[0] == "Test input message"
|
||||
# Should queue activation notification
|
||||
mock_task_lock.put_queue.assert_called()
|
||||
|
||||
def test_listen_chat_agent_step_with_base_message_input(self, mock_task_lock):
|
||||
def test_listen_chat_agent_step_with_base_message_input(
|
||||
self, mock_task_lock
|
||||
):
|
||||
"""Test ListenChatAgent step method with BaseMessage input."""
|
||||
api_task_id = "test_api_task_123"
|
||||
agent_name = "TestAgent"
|
||||
|
||||
with patch('app.agent.listen_chat_agent.get_task_lock', return_value=mock_task_lock), \
|
||||
with patch(f'{_LCA}.get_task_lock', return_value=mock_task_lock), \
|
||||
patch('camel.models.ModelFactory.create') as mock_create_model, \
|
||||
patch('asyncio.create_task') as mock_create_task:
|
||||
patch('asyncio.create_task'):
|
||||
|
||||
# Mock the model backend creation
|
||||
mock_backend = MagicMock()
|
||||
|
|
@ -120,9 +125,7 @@ class TestListenChatAgent:
|
|||
mock_create_model.return_value = mock_backend
|
||||
|
||||
agent = ListenChatAgent(
|
||||
api_task_id=api_task_id,
|
||||
agent_name=agent_name,
|
||||
model="gpt-4"
|
||||
api_task_id=api_task_id, agent_name=agent_name, model="gpt-4"
|
||||
)
|
||||
agent.agent_id = "test_agent_456"
|
||||
agent.process_task_id = "test_process_task"
|
||||
|
|
@ -137,18 +140,24 @@ class TestListenChatAgent:
|
|||
mock_response.msg.content = "Test response content"
|
||||
mock_response.info = {"usage": {"total_tokens": 100}}
|
||||
|
||||
with patch.object(ChatAgent, 'step', return_value=mock_response) as mock_parent_step:
|
||||
with patch.object(
|
||||
ChatAgent, 'step', return_value=mock_response
|
||||
) as mock_parent_step:
|
||||
result = agent.step(mock_message)
|
||||
|
||||
assert result is mock_response
|
||||
# Check that step was called with the mock message (don't assert on response_format param)
|
||||
# Check that step was called with
|
||||
# the mock message (don't assert
|
||||
# on response_format param)
|
||||
mock_parent_step.assert_called_once()
|
||||
args, kwargs = mock_parent_step.call_args
|
||||
assert args[0] is mock_message
|
||||
|
||||
# Should queue activation with message content
|
||||
mock_task_lock.put_queue.assert_called()
|
||||
# Just verify put_queue was called - don't check internal data structure details
|
||||
# Just verify put_queue was called -
|
||||
# don't check internal data
|
||||
# structure details
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listen_chat_agent_astep(self, mock_task_lock):
|
||||
|
|
@ -156,9 +165,9 @@ class TestListenChatAgent:
|
|||
api_task_id = "test_api_task_123"
|
||||
agent_name = "TestAgent"
|
||||
|
||||
with patch('app.agent.listen_chat_agent.get_task_lock', return_value=mock_task_lock), \
|
||||
with patch(f'{_LCA}.get_task_lock', return_value=mock_task_lock), \
|
||||
patch('camel.models.ModelFactory.create') as mock_create_model, \
|
||||
patch('asyncio.create_task') as mock_create_task:
|
||||
patch('asyncio.create_task'):
|
||||
|
||||
# Mock the model backend creation
|
||||
mock_backend = MagicMock()
|
||||
|
|
@ -168,9 +177,7 @@ class TestListenChatAgent:
|
|||
mock_create_model.return_value = mock_backend
|
||||
|
||||
agent = ListenChatAgent(
|
||||
api_task_id=api_task_id,
|
||||
agent_name=agent_name,
|
||||
model="gpt-4"
|
||||
api_task_id=api_task_id, agent_name=agent_name, model="gpt-4"
|
||||
)
|
||||
agent.process_task_id = "test_process_task"
|
||||
|
||||
|
|
@ -180,11 +187,15 @@ class TestListenChatAgent:
|
|||
mock_response.msg.content = "Test response message"
|
||||
mock_response.info = {"usage": {"total_tokens": 100}}
|
||||
|
||||
with patch.object(ChatAgent, 'astep', return_value=mock_response) as mock_parent_astep:
|
||||
with patch.object(
|
||||
ChatAgent, 'astep', return_value=mock_response
|
||||
) as mock_parent_astep:
|
||||
result = await agent.astep("Test async input")
|
||||
|
||||
assert result is mock_response
|
||||
# Check that astep was called with the input message (don't assert on response_format param)
|
||||
# Check that astep was called with
|
||||
# the input message (don't assert
|
||||
# on response_format param)
|
||||
mock_parent_astep.assert_called_once()
|
||||
args, kwargs = mock_parent_astep.call_args
|
||||
assert args[0] == "Test async input"
|
||||
|
|
@ -197,9 +208,9 @@ class TestListenChatAgent:
|
|||
api_task_id = "test_api_task_123"
|
||||
agent_name = "TestAgent"
|
||||
|
||||
with patch('app.agent.listen_chat_agent.get_task_lock', return_value=mock_task_lock), \
|
||||
with patch(f'{_LCA}.get_task_lock', return_value=mock_task_lock), \
|
||||
patch('camel.models.ModelFactory.create') as mock_create_model, \
|
||||
patch('asyncio.create_task') as mock_create_task:
|
||||
patch('asyncio.create_task'):
|
||||
|
||||
# Mock the model backend creation
|
||||
mock_backend = MagicMock()
|
||||
|
|
@ -209,9 +220,7 @@ class TestListenChatAgent:
|
|||
mock_create_model.return_value = mock_backend
|
||||
|
||||
agent = ListenChatAgent(
|
||||
api_task_id=api_task_id,
|
||||
agent_name=agent_name,
|
||||
model="gpt-4"
|
||||
api_task_id=api_task_id, agent_name=agent_name, model="gpt-4"
|
||||
)
|
||||
|
||||
# Create a mock tool and add it to _internal_tools
|
||||
|
|
@ -231,13 +240,16 @@ class TestListenChatAgent:
|
|||
# Mock tool calling record
|
||||
mock_record = MagicMock(spec=ToolCallingRecord)
|
||||
|
||||
with patch.object(agent, '_record_tool_calling', return_value=mock_record) as mock_record_func:
|
||||
with patch.object(
|
||||
agent, '_record_tool_calling', return_value=mock_record
|
||||
) as mock_record_func:
|
||||
result = agent._execute_tool(tool_call_request)
|
||||
|
||||
assert result is mock_record
|
||||
mock_record_func.assert_called_once()
|
||||
|
||||
# Should queue toolkit activation and deactivation notifications
|
||||
# Should queue toolkit activation
|
||||
# and deactivation notifications
|
||||
assert mock_task_lock.put_queue.call_count >= 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -246,7 +258,7 @@ class TestListenChatAgent:
|
|||
api_task_id = "test_api_task_123"
|
||||
agent_name = "TestAgent"
|
||||
|
||||
with patch('app.agent.listen_chat_agent.get_task_lock', return_value=mock_task_lock), \
|
||||
with patch(f'{_LCA}.get_task_lock', return_value=mock_task_lock), \
|
||||
patch('camel.models.ModelFactory.create') as mock_create_model:
|
||||
|
||||
# Mock the model backend creation
|
||||
|
|
@ -257,9 +269,7 @@ class TestListenChatAgent:
|
|||
mock_create_model.return_value = mock_backend
|
||||
|
||||
agent = ListenChatAgent(
|
||||
api_task_id=api_task_id,
|
||||
agent_name=agent_name,
|
||||
model="gpt-4"
|
||||
api_task_id=api_task_id, agent_name=agent_name, model="gpt-4"
|
||||
)
|
||||
|
||||
# Create a mock tool and add it to _internal_tools
|
||||
|
|
@ -277,13 +287,16 @@ class TestListenChatAgent:
|
|||
|
||||
mock_record = MagicMock(spec=ToolCallingRecord)
|
||||
|
||||
with patch.object(agent, '_record_tool_calling', return_value=mock_record) as mock_record_func:
|
||||
with patch.object(
|
||||
agent, '_record_tool_calling', return_value=mock_record
|
||||
) as mock_record_func:
|
||||
result = await agent._aexecute_tool(tool_call_request)
|
||||
|
||||
assert result is mock_record
|
||||
mock_record_func.assert_called_once()
|
||||
|
||||
# Should queue toolkit activation and deactivation notifications
|
||||
# Should queue toolkit activation
|
||||
# and deactivation notifications
|
||||
assert mock_task_lock.put_queue.call_count >= 2
|
||||
|
||||
def test_listen_chat_agent_clone(self, mock_task_lock):
|
||||
|
|
@ -291,7 +304,7 @@ class TestListenChatAgent:
|
|||
api_task_id = "test_api_task_123"
|
||||
agent_name = "TestAgent"
|
||||
|
||||
with patch('app.agent.listen_chat_agent.get_task_lock', return_value=mock_task_lock), \
|
||||
with patch(f'{_LCA}.get_task_lock', return_value=mock_task_lock), \
|
||||
patch('camel.models.ModelFactory.create') as mock_create_model:
|
||||
|
||||
# Mock the model backend creation
|
||||
|
|
@ -299,20 +312,21 @@ class TestListenChatAgent:
|
|||
mock_backend.model_type = "gpt-4"
|
||||
mock_backend.current_model = MagicMock()
|
||||
mock_backend.current_model.model_type = "gpt-4"
|
||||
mock_backend.models = "gpt-4" # String instead of list to avoid list processing
|
||||
# String instead of list to avoid
|
||||
# list processing
|
||||
mock_backend.models = "gpt-4"
|
||||
mock_backend.scheduling_strategy = MagicMock()
|
||||
mock_backend.scheduling_strategy.__name__ = "round_robin"
|
||||
mock_create_model.return_value = mock_backend
|
||||
|
||||
# Mock the clone process by patching ListenChatAgent constructor for clone
|
||||
# Mock the clone process by patching
|
||||
# ListenChatAgent constructor for clone
|
||||
cloned_agent = MagicMock()
|
||||
cloned_agent.process_task_id = "test_process_task"
|
||||
|
||||
# First create the initial agent
|
||||
agent = ListenChatAgent(
|
||||
api_task_id=api_task_id,
|
||||
agent_name=agent_name,
|
||||
model="gpt-4"
|
||||
api_task_id=api_task_id, agent_name=agent_name, model="gpt-4"
|
||||
)
|
||||
|
||||
# Set up necessary attributes for cloning
|
||||
|
|
@ -333,8 +347,13 @@ class TestListenChatAgent:
|
|||
agent.prune_tool_calls_from_memory = False
|
||||
|
||||
# Now mock the constructor for the clone call
|
||||
with patch('app.agent.listen_chat_agent.ListenChatAgent', return_value=cloned_agent) as mock_clone_constructor, \
|
||||
patch.object(agent, '_clone_tools', return_value=([], [])):
|
||||
with patch(
|
||||
f'{_LCA}.ListenChatAgent',
|
||||
return_value=cloned_agent
|
||||
) as mock_clone_constructor, \
|
||||
patch.object(
|
||||
agent, '_clone_tools',
|
||||
return_value=([], [])):
|
||||
|
||||
result = agent.clone(with_memory=True)
|
||||
|
||||
|
|
@ -350,7 +369,7 @@ class TestListenChatAgent:
|
|||
mock_tool = MagicMock(spec=FunctionTool)
|
||||
tools = [mock_tool]
|
||||
|
||||
with patch('app.agent.listen_chat_agent.get_task_lock', return_value=mock_task_lock), \
|
||||
with patch(f'{_LCA}.get_task_lock', return_value=mock_task_lock), \
|
||||
patch('camel.models.ModelFactory.create') as mock_create_model:
|
||||
|
||||
# Mock the model backend creation
|
||||
|
|
@ -372,7 +391,8 @@ class TestListenChatAgent:
|
|||
|
||||
assert len(agent.function_list) == 1 # Should have the tool
|
||||
# Check that tools were passed to parent class
|
||||
mock_task_lock.put_queue.assert_not_called() # No immediate action for tool setup
|
||||
mock_task_lock.put_queue.assert_not_called(
|
||||
) # No immediate action for tool setup
|
||||
|
||||
def test_listen_chat_agent_with_pause_event(self, mock_task_lock):
|
||||
"""Test ListenChatAgent with pause event."""
|
||||
|
|
@ -381,7 +401,7 @@ class TestListenChatAgent:
|
|||
|
||||
pause_event = asyncio.Event()
|
||||
|
||||
with patch('app.agent.listen_chat_agent.get_task_lock', return_value=mock_task_lock), \
|
||||
with patch(f'{_LCA}.get_task_lock', return_value=mock_task_lock), \
|
||||
patch('camel.models.ModelFactory.create') as mock_create_model:
|
||||
|
||||
# Mock the model backend creation
|
||||
|
|
@ -405,12 +425,19 @@ class TestListenChatAgent:
|
|||
api_task_id = "error_test_123"
|
||||
agent_name = "ErrorAgent"
|
||||
|
||||
with patch('app.agent.listen_chat_agent.get_task_lock') as mock_get_lock, \
|
||||
patch('camel.models.ModelFactory.create', side_effect=ValueError("Invalid model")):
|
||||
with patch(f'{_LCA}.get_task_lock') as mock_get_lock, \
|
||||
patch(
|
||||
'camel.models.ModelFactory.create',
|
||||
side_effect=ValueError(
|
||||
"Invalid model"
|
||||
)
|
||||
):
|
||||
mock_task_lock = MagicMock()
|
||||
mock_get_lock.return_value = mock_task_lock
|
||||
|
||||
# Try to create agent with invalid model which should raise an error through ModelFactory
|
||||
# Try to create agent with invalid
|
||||
# model which should raise an error
|
||||
# through ModelFactory
|
||||
with pytest.raises(ValueError):
|
||||
ListenChatAgent(
|
||||
api_task_id=api_task_id,
|
||||
|
|
@ -423,8 +450,14 @@ class TestListenChatAgent:
|
|||
api_task_id = "error_test_123"
|
||||
agent_name = "ErrorAgent"
|
||||
|
||||
with patch('app.agent.listen_chat_agent.get_task_lock', side_effect=Exception("Task lock not found")), \
|
||||
patch('camel.models.ModelFactory.create') as mock_create_model:
|
||||
with patch(
|
||||
f'{_LCA}.get_task_lock',
|
||||
side_effect=Exception(
|
||||
"Task lock not found"
|
||||
)
|
||||
), \
|
||||
patch(
|
||||
'camel.models.ModelFactory.create') as mock_create_model:
|
||||
|
||||
# Mock the model backend creation
|
||||
mock_backend = MagicMock()
|
||||
|
|
@ -434,9 +467,7 @@ class TestListenChatAgent:
|
|||
mock_create_model.return_value = mock_backend
|
||||
|
||||
agent = ListenChatAgent(
|
||||
api_task_id=api_task_id,
|
||||
agent_name=agent_name,
|
||||
model="gpt-4"
|
||||
api_task_id=api_task_id, agent_name=agent_name, model="gpt-4"
|
||||
)
|
||||
|
||||
# Should handle task lock errors gracefully
|
||||
|
|
@ -451,7 +482,7 @@ class TestAgentWithLLM:
|
|||
@pytest.mark.asyncio
|
||||
async def test_agent_with_real_model(self, sample_chat_data):
|
||||
"""Test agent creation with real LLM model."""
|
||||
options = Chat(**sample_chat_data)
|
||||
Chat(**sample_chat_data)
|
||||
|
||||
# This test would use real model backends
|
||||
# Marked as model_backend test for selective execution
|
||||
|
|
@ -460,7 +491,7 @@ class TestAgentWithLLM:
|
|||
@pytest.mark.very_slow
|
||||
async def test_full_agent_conversation_workflow(self, sample_chat_data):
|
||||
"""Test complete agent conversation workflow (very slow test)."""
|
||||
options = Chat(**sample_chat_data)
|
||||
Chat(**sample_chat_data)
|
||||
|
||||
# This test would run complete conversation workflow
|
||||
# Marked as very_slow for execution only in full test mode
|
||||
|
|
|
|||
|
|
@ -13,12 +13,12 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agent.tools import get_mcp_tools, get_toolkits
|
||||
from app.model.chat import McpServers
|
||||
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
|
|
@ -32,21 +32,28 @@ class TestToolkitFunctions:
|
|||
agent_name = "TestAgent"
|
||||
api_task_id = "test_task_123"
|
||||
|
||||
with patch('app.agent.tools.SearchToolkit') as mock_search_toolkit, \
|
||||
patch('app.agent.tools.TerminalToolkit') as mock_terminal_toolkit, \
|
||||
patch('app.agent.tools.FileToolkit') as mock_file_toolkit:
|
||||
_mod = 'app.agent.tools'
|
||||
with patch(f'{_mod}.SearchToolkit') as mock_search_toolkit, \
|
||||
patch(f'{_mod}.TerminalToolkit') as mock_terminal_toolkit, \
|
||||
patch(f'{_mod}.FileToolkit') as mock_file_toolkit:
|
||||
|
||||
# Mock toolkit instances - these should return tools directly from get_can_use_tools
|
||||
# Mock toolkit instances - these should
|
||||
# return tools directly
|
||||
# from get_can_use_tools
|
||||
mock_search_instance = MagicMock()
|
||||
mock_search_instance.agent_name = agent_name
|
||||
mock_search_tools = [MagicMock(), MagicMock()]
|
||||
mock_search_instance.get_can_use_tools.return_value = mock_search_tools
|
||||
mock_search_instance\
|
||||
.get_can_use_tools\
|
||||
.return_value = mock_search_tools
|
||||
mock_search_toolkit.return_value = mock_search_instance
|
||||
|
||||
mock_terminal_instance = MagicMock()
|
||||
mock_terminal_instance.agent_name = agent_name
|
||||
mock_terminal_tools = [MagicMock()]
|
||||
mock_terminal_instance.get_can_use_tools.return_value = mock_terminal_tools
|
||||
mock_terminal_instance\
|
||||
.get_can_use_tools\
|
||||
.return_value = mock_terminal_tools
|
||||
mock_terminal_toolkit.return_value = mock_terminal_instance
|
||||
|
||||
mock_file_instance = MagicMock()
|
||||
|
|
@ -55,16 +62,26 @@ class TestToolkitFunctions:
|
|||
mock_file_instance.get_can_use_tools.return_value = mock_file_tools
|
||||
mock_file_toolkit.return_value = mock_file_instance
|
||||
|
||||
# Mock the toolkit classes to have get_can_use_tools class method that returns the mock tools
|
||||
mock_search_toolkit.get_can_use_tools = MagicMock(return_value=mock_search_tools)
|
||||
mock_terminal_toolkit.get_can_use_tools = MagicMock(return_value=mock_terminal_tools)
|
||||
mock_file_toolkit.get_can_use_tools = MagicMock(return_value=mock_file_tools)
|
||||
# Mock the toolkit classes to have
|
||||
# get_can_use_tools class method
|
||||
# that returns the mock tools
|
||||
mock_search_toolkit.get_can_use_tools = MagicMock(
|
||||
return_value=mock_search_tools
|
||||
)
|
||||
mock_terminal_toolkit.get_can_use_tools = MagicMock(
|
||||
return_value=mock_terminal_tools
|
||||
)
|
||||
mock_file_toolkit.get_can_use_tools = MagicMock(
|
||||
return_value=mock_file_tools
|
||||
)
|
||||
|
||||
result = await get_toolkits(tools, agent_name, api_task_id)
|
||||
|
||||
# The result should contain tools from the toolkits that match
|
||||
assert isinstance(result, list)
|
||||
# Since get_toolkits filters by known toolkit names, only matching ones should be included
|
||||
# Since get_toolkits filters by known
|
||||
# toolkit names, only matching ones
|
||||
# should be included
|
||||
assert len(result) >= 0 # Should have some tools if any match
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -97,7 +114,10 @@ class TestToolkitFunctions:
|
|||
agent_name = "ErrorAgent"
|
||||
api_task_id = "error_test_123"
|
||||
|
||||
with patch('app.agent.tools.SearchToolkit', side_effect=Exception("Toolkit init failed")):
|
||||
with patch(
|
||||
'app.agent.tools.SearchToolkit',
|
||||
side_effect=Exception("Toolkit init failed")
|
||||
):
|
||||
# Should handle toolkit initialization errors
|
||||
result = await get_toolkits(tools, agent_name, api_task_id)
|
||||
# Should return what it can or empty list
|
||||
|
|
@ -154,7 +174,10 @@ class TestMcpTools:
|
|||
}
|
||||
}
|
||||
|
||||
with patch('app.agent.tools.MCPToolkit', side_effect=Exception("Connection failed")):
|
||||
with patch(
|
||||
'app.agent.tools.MCPToolkit',
|
||||
side_effect=Exception("Connection failed")
|
||||
):
|
||||
result = await get_mcp_tools(mcp_servers)
|
||||
assert result == []
|
||||
|
||||
|
|
|
|||
|
|
@ -15,9 +15,14 @@
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from app.utils.listen.toolkit_listen import (MAX_LENGTH, _format_args,
|
||||
_format_result, _truncate,
|
||||
listen_toolkit)
|
||||
|
||||
from app.utils.listen.toolkit_listen import (
|
||||
MAX_LENGTH,
|
||||
_format_args,
|
||||
_format_result,
|
||||
_truncate,
|
||||
listen_toolkit,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
|
|
@ -197,8 +202,10 @@ def test_listen_toolkit_sync_returns_result():
|
|||
mock_task_lock = MagicMock()
|
||||
mock_task_lock.put_queue = AsyncMock()
|
||||
|
||||
with patch("app.utils.listen.toolkit_listen.get_task_lock",
|
||||
return_value=mock_task_lock):
|
||||
with patch(
|
||||
"app.utils.listen.toolkit_listen.get_task_lock",
|
||||
return_value=mock_task_lock
|
||||
):
|
||||
|
||||
@listen_toolkit()
|
||||
def test_method(self, arg1):
|
||||
|
|
@ -215,8 +222,10 @@ def test_listen_toolkit_sync_raises_exception():
|
|||
mock_task_lock = MagicMock()
|
||||
mock_task_lock.put_queue = AsyncMock()
|
||||
|
||||
with patch("app.utils.listen.toolkit_listen.get_task_lock",
|
||||
return_value=mock_task_lock):
|
||||
with patch(
|
||||
"app.utils.listen.toolkit_listen.get_task_lock",
|
||||
return_value=mock_task_lock
|
||||
):
|
||||
|
||||
@listen_toolkit()
|
||||
def test_method(self):
|
||||
|
|
@ -265,8 +274,10 @@ async def test_listen_toolkit_async_returns_result():
|
|||
mock_task_lock = MagicMock()
|
||||
mock_task_lock.put_queue = AsyncMock()
|
||||
|
||||
with patch("app.utils.listen.toolkit_listen.get_task_lock",
|
||||
return_value=mock_task_lock):
|
||||
with patch(
|
||||
"app.utils.listen.toolkit_listen.get_task_lock",
|
||||
return_value=mock_task_lock
|
||||
):
|
||||
|
||||
@listen_toolkit()
|
||||
async def test_method(self, arg1):
|
||||
|
|
@ -284,8 +295,10 @@ async def test_listen_toolkit_async_raises_exception():
|
|||
mock_task_lock = MagicMock()
|
||||
mock_task_lock.put_queue = AsyncMock()
|
||||
|
||||
with patch("app.utils.listen.toolkit_listen.get_task_lock",
|
||||
return_value=mock_task_lock):
|
||||
with patch(
|
||||
"app.utils.listen.toolkit_listen.get_task_lock",
|
||||
return_value=mock_task_lock
|
||||
):
|
||||
|
||||
@listen_toolkit()
|
||||
async def test_method(self):
|
||||
|
|
@ -322,8 +335,10 @@ def test_listen_toolkit_with_custom_inputs_formatter():
|
|||
custom_formatter_called.append((arg1, arg2))
|
||||
return f"custom: {arg1}, {arg2}"
|
||||
|
||||
with patch("app.utils.listen.toolkit_listen.get_task_lock",
|
||||
return_value=mock_task_lock):
|
||||
with patch(
|
||||
"app.utils.listen.toolkit_listen.get_task_lock",
|
||||
return_value=mock_task_lock
|
||||
):
|
||||
|
||||
@listen_toolkit(inputs=custom_inputs)
|
||||
def test_method(self, arg1, arg2):
|
||||
|
|
@ -346,8 +361,8 @@ def test_listen_toolkit_with_custom_return_msg_formatter():
|
|||
return f"formatted: {res}"
|
||||
|
||||
with patch(
|
||||
"app.utils.listen.toolkit_listen.get_task_lock",
|
||||
return_value=mock_task_lock
|
||||
"app.utils.listen.toolkit_listen.get_task_lock",
|
||||
return_value=mock_task_lock
|
||||
), patch("app.utils.listen.toolkit_listen._format_result") as mock_format:
|
||||
mock_format.return_value = "formatted: test_result"
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
import pytest
|
||||
|
||||
from app.utils.toolkit.note_taking_toolkit import NoteTakingToolkit
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -54,7 +54,9 @@ def pytest_addoption(parser: pytest.Parser) -> None:
|
|||
)
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None:
|
||||
def pytest_collection_modifyitems(
|
||||
config: pytest.Config, items: list[pytest.Item]
|
||||
) -> None:
|
||||
if config.getoption("--llm-test-only"):
|
||||
skip_fast = pytest.mark.skip(reason="Skipped for llm test only")
|
||||
for item in items:
|
||||
|
|
@ -118,13 +120,13 @@ def mock_openai_api():
|
|||
with patch("openai.OpenAI") as mock_openai:
|
||||
mock_client = MagicMock()
|
||||
mock_openai.return_value = mock_client
|
||||
|
||||
|
||||
# Mock chat completion
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Test response"
|
||||
mock_response.usage.total_tokens = 100
|
||||
|
||||
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
yield mock_client
|
||||
|
||||
|
|
@ -148,19 +150,21 @@ def mock_camel_agent():
|
|||
agent = MagicMock() # Use MagicMock instead of AsyncMock
|
||||
agent.role_name = "test_agent"
|
||||
agent.agent_id = "test_agent_123"
|
||||
|
||||
|
||||
# Make step method return proper structure with both .msg and .msgs[0]
|
||||
mock_response = MagicMock()
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = "Test agent response"
|
||||
mock_message.parsed = None
|
||||
|
||||
|
||||
mock_response.msg = mock_message
|
||||
mock_response.msgs = [mock_message] # msgs[0] should point to the same content
|
||||
mock_response.msgs = [
|
||||
mock_message
|
||||
] # msgs[0] should point to the same content
|
||||
mock_response.info = {"usage": {"total_tokens": 50}}
|
||||
|
||||
|
||||
agent.step.return_value = mock_response
|
||||
|
||||
|
||||
agent.astep = AsyncMock()
|
||||
agent.astep.return_value.msg.content = "Test async agent response"
|
||||
agent.astep.return_value.msg.parsed = None
|
||||
|
|
@ -196,17 +200,18 @@ def mock_request():
|
|||
def app() -> FastAPI:
|
||||
"""Create FastAPI test application."""
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app.controller.chat_controller import router as chat_router
|
||||
from app.controller.model_controller import router as model_router
|
||||
from app.controller.task_controller import router as task_router
|
||||
from app.controller.tool_controller import router as tool_router
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(chat_router)
|
||||
app.include_router(model_router)
|
||||
app.include_router(task_router)
|
||||
app.include_router(tool_router)
|
||||
|
||||
|
||||
return app
|
||||
|
||||
|
||||
|
|
@ -253,14 +258,14 @@ def mock_worker_with_agent():
|
|||
worker.agent_id = "test_agent_123"
|
||||
worker.astep = AsyncMock()
|
||||
worker.step = MagicMock()
|
||||
|
||||
|
||||
# Mock response structure
|
||||
mock_response = MagicMock()
|
||||
mock_response.msg = MagicMock()
|
||||
mock_response.msg.content = "Test worker response"
|
||||
mock_response.msg.parsed = {"result": "test"}
|
||||
mock_response.info = {"usage": {"total_tokens": 50}}
|
||||
|
||||
|
||||
worker.astep.return_value = mock_response
|
||||
worker.step.return_value = mock_response
|
||||
return worker
|
||||
|
|
@ -285,7 +290,7 @@ def mock_environment_variables():
|
|||
"file_save_path": "/tmp/test_files",
|
||||
"browser_port": "8080"
|
||||
}
|
||||
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=False):
|
||||
yield env_vars
|
||||
|
||||
|
|
@ -327,15 +332,15 @@ async def async_mock_agent() -> AsyncGenerator[AsyncMock, None]:
|
|||
agent = AsyncMock()
|
||||
agent.role_name = "async_test_agent"
|
||||
agent.agent_id = "async_test_agent_456"
|
||||
|
||||
|
||||
# Mock async step method
|
||||
mock_response = MagicMock()
|
||||
mock_response.msg.content = "Async test response"
|
||||
mock_response.msg.parsed = {"test": "data"}
|
||||
mock_response.info = {"usage": {"total_tokens": 75}}
|
||||
|
||||
|
||||
agent.astep.return_value = mock_response
|
||||
|
||||
|
||||
yield agent
|
||||
|
||||
|
||||
|
|
@ -349,7 +354,8 @@ def pytest_configure(config):
|
|||
"markers", "model_backend: mark test as requiring model backend"
|
||||
)
|
||||
config.addinivalue_line(
|
||||
"markers", "very_slow: mark test as very slow (requires full test mode)"
|
||||
"markers",
|
||||
"very_slow: mark test as very slow (requires full test mode)"
|
||||
)
|
||||
config.addinivalue_line(
|
||||
"markers", "optional: mark test as optional (skipped in fast mode)"
|
||||
|
|
@ -357,6 +363,4 @@ def pytest_configure(config):
|
|||
config.addinivalue_line(
|
||||
"markers", "integration: mark test as integration test"
|
||||
)
|
||||
config.addinivalue_line(
|
||||
"markers", "unit: mark test as unit test"
|
||||
)
|
||||
config.addinivalue_line("markers", "unit: mark test as unit test")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue