Merge branch 'main' into fix/markdown-ordered-list-numbering

This commit is contained in:
Wendong-Fan 2025-11-10 17:50:42 +08:00 committed by GitHub
commit cd30c7d840
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
277 changed files with 27748 additions and 8413 deletions

View file

@ -5,4 +5,26 @@ VITE_PROXY_URL=https://dev.eigent.ai
VITE_USE_LOCAL_PROXY=false
# VITE_PROXY_URL=http://localhost:3001
# VITE_USE_LOCAL_PROXY=true
# VITE_USE_LOCAL_PROXY=true
TRACEROOT_TOKEN=your_traceroot_token_here
TRACEROOT_SERVICE_NAME=eigent
TRACEROOT_GITHUB_OWNER=eigent
TRACEROOT_GITHUB_REPO_NAME=eigent-ai
TRACEROOT_GITHUB_COMMIT_HASH=main
TRACEROOT_ENABLE_SPAN_CLOUD_EXPORT=true
TRACEROOT_ENABLE_LOG_CLOUD_EXPORT=true
TRACEROOT_ENABLE_SPAN_CONSOLE_EXPORT=false
TRACEROOT_ENABLE_LOG_CONSOLE_EXPORT=true
TRACEROOT_TRACER_VERBOSE=false
TRACEROOT_LOGGER_VERBOSE=false

8
.gitignore vendored
View file

@ -46,3 +46,11 @@ public/
# Testing
coverage/
.traceroot-config.yaml
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python

View file

@ -2,6 +2,8 @@
// See http://go.microsoft.com/fwlink/?LinkId=827846
// for the documentation about the extensions.json format
"recommendations": [
"mrmlnc.vscode-json5"
"mrmlnc.vscode-json5",
"ms-python.python",
"ms-python.debugpy"
]
}

16
.vscode/launch.json vendored
View file

@ -50,5 +50,21 @@
"http://127.0.0.1:7777/**"
]
},
{
"name": "Debug Python Backend (Attach)",
"type": "debugpy",
"request": "attach",
"connect": {
"host": "localhost",
"port": 5678
},
"pathMappings": [
{
"localRoot": "${workspaceFolder}/backend",
"remoteRoot": "."
}
],
"justMyCode": false
}
]
}

View file

@ -1,4 +1,4 @@
from app.utils import traceroot_wrapper as traceroot
from utils import traceroot_wrapper as traceroot
import importlib.util
import os
from pathlib import Path

View file

@ -3,13 +3,12 @@ import os
import re
from pathlib import Path
from dotenv import load_dotenv
from fastapi import APIRouter, Request, Response
from fastapi import APIRouter, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from loguru import logger
from app.utils import traceroot_wrapper as traceroot
from utils import traceroot_wrapper as traceroot
from app.component import code
from app.exception.exception import UserException
from app.model.chat import Chat, HumanReply, McpServers, Status, SupplementChat
from app.model.chat import Chat, HumanReply, McpServers, Status, SupplementChat, AddTaskRequest
from app.service.chat_service import step_solve
from app.service.task import (
Action,
@ -17,13 +16,18 @@ from app.service.task import (
ActionInstallMcpData,
ActionStopData,
ActionSupplementData,
create_task_lock,
ActionAddTaskData,
ActionRemoveTaskData,
ActionSkipTaskData,
get_or_create_task_lock,
get_task_lock,
)
from app.component.environment import set_user_env_path
from app.utils.workforce import Workforce
from camel.tasks.task import Task
router = APIRouter(tags=["chat"])
router = APIRouter()
# Create traceroot logger for chat controller
chat_logger = traceroot.get_logger('chat_controller')
@ -32,55 +36,110 @@ chat_logger = traceroot.get_logger('chat_controller')
@router.post("/chat", name="start chat")
@traceroot.trace()
async def post(data: Chat, request: Request):
chat_logger.info(f"Starting new chat session for task_id: {data.task_id}, user: {data.email}")
task_lock = create_task_lock(data.task_id)
chat_logger.info("Starting new chat session", extra={"project_id": data.project_id, "task_id": data.task_id, "user": data.email})
task_lock = get_or_create_task_lock(data.project_id)
# Set user-specific environment path for this thread
set_user_env_path(data.env_path)
load_dotenv(dotenv_path=data.env_path)
# logger.debug(f"start chat: {data.model_dump_json()}")
os.environ["file_save_path"] = data.file_save_path()
os.environ["browser_port"] = str(data.browser_port)
os.environ["OPENAI_API_KEY"] = data.api_key
os.environ["OPENAI_API_BASE_URL"] = data.api_url or "https://api.openai.com/v1"
os.environ["CAMEL_MODEL_LOG_ENABLED"] = "true"
email = re.sub(r'[\\/*?:"<>|\s]', "_", data.email.split("@")[0]).strip(".")
camel_log = Path.home() / ".eigent" / email / ("task_" + data.task_id) / "camel_logs"
# Set user-specific search engine configuration if provided
if data.search_config:
for key, value in data.search_config.items():
if value: # Only set non-empty values
os.environ[key] = value
chat_logger.info(f"Set search config: {key}", extra={"project_id": data.project_id})
email_sanitized = re.sub(r'[\\/*?:"<>|\s]', "_", data.email.split("@")[0]).strip(".")
camel_log = Path.home() / ".eigent" / email_sanitized / ("project_" + data.project_id) / ("task_" + data.task_id) / "camel_logs"
camel_log.mkdir(parents=True, exist_ok=True)
os.environ["CAMEL_LOG_DIR"] = str(camel_log)
if data.is_cloud():
os.environ["cloud_api_key"] = data.api_key
chat_logger.info(f"Chat session initialized, starting streaming response for task_id: {data.task_id}")
# Put initial action in queue to start processing
await task_lock.put_queue(ActionImproveData(data=data.question))
chat_logger.info("Chat session initialized, starting streaming response", extra={"project_id": data.project_id, "task_id": data.task_id, "log_dir": str(camel_log)})
return StreamingResponse(step_solve(data, request, task_lock), media_type="text/event-stream")
@router.post("/chat/{id}", name="improve chat")
@traceroot.trace()
def improve(id: str, data: SupplementChat):
chat_logger.info(f"Improving chat for task_id: {id} with question: {data.question}")
chat_logger.info("Chat improvement requested", extra={"task_id": id, "question_length": len(data.question)})
task_lock = get_task_lock(id)
# Allow continuing conversation even after task is done
# This supports multi-turn conversation after complex task completion
if task_lock.status == Status.done:
raise UserException(code.error, "Task was done")
# Reset status to allow processing new messages
task_lock.status = Status.confirming
# Clear any existing background tasks since workforce was stopped
if hasattr(task_lock, 'background_tasks'):
task_lock.background_tasks.clear()
# Note: conversation_history and last_task_result are preserved
# Log context preservation
if hasattr(task_lock, 'conversation_history'):
chat_logger.info(f"[CONTEXT] Preserved {len(task_lock.conversation_history)} conversation entries")
if hasattr(task_lock, 'last_task_result'):
chat_logger.info(f"[CONTEXT] Preserved task result: {len(task_lock.last_task_result)} chars")
# Update file save path if task_id is provided
new_folder_path = None
if data.task_id:
try:
# Get current environment values needed to construct new path
current_email = None
# Extract email from current file_save_path if available
current_file_save_path = os.environ.get("file_save_path", "")
if current_file_save_path:
path_parts = Path(current_file_save_path).parts
if len(path_parts) >= 3 and "eigent" in path_parts:
eigent_index = path_parts.index("eigent")
if eigent_index + 1 < len(path_parts):
current_email = path_parts[eigent_index + 1]
# If we have the necessary information, update the file_save_path
if current_email and id:
# Create new path using the existing pattern: email/project_{project_id}/task_{task_id}
new_folder_path = Path.home() / "eigent" / current_email / f"project_{id}" / f"task_{data.task_id}"
new_folder_path.mkdir(parents=True, exist_ok=True)
os.environ["file_save_path"] = str(new_folder_path)
chat_logger.info(f"Updated file_save_path to: {new_folder_path}")
# Store the new folder path in task_lock for potential cleanup and persistence
task_lock.new_folder_path = new_folder_path
else:
chat_logger.warning(f"Could not update file_save_path - email: {current_email}, project_id: {id}")
except Exception as e:
chat_logger.error(f"Error updating file path for project_id: {id}, task_id: {data.task_id}: {e}")
asyncio.run(task_lock.put_queue(ActionImproveData(data=data.question)))
chat_logger.info(f"Improvement request queued for task_id: {id}")
chat_logger.info("Improvement request queued with preserved context", extra={"project_id": id})
return Response(status_code=201)
@router.put("/chat/{id}", name="supplement task")
@traceroot.trace()
def supplement(id: str, data: SupplementChat):
chat_logger.info(f"Supplementing task_id: {id} with additional data")
chat_logger.info("Chat supplement requested", extra={"task_id": id})
task_lock = get_task_lock(id)
if task_lock.status != Status.done:
raise UserException(code.error, "Please wait task done")
asyncio.run(task_lock.put_queue(ActionSupplementData(data=data)))
chat_logger.info(f"Supplement data queued for task_id: {id}")
chat_logger.debug("Supplement data queued", extra={"task_id": id})
return Response(status_code=201)
@ -88,28 +147,92 @@ def supplement(id: str, data: SupplementChat):
@traceroot.trace()
def stop(id: str):
"""stop the task"""
chat_logger.warning(f"Stopping chat session for task_id: {id}")
chat_logger.warning("Stopping chat session", extra={"task_id": id})
task_lock = get_task_lock(id)
asyncio.run(task_lock.put_queue(ActionStopData(action=Action.stop)))
chat_logger.info(f"Stop signal sent for task_id: {id}")
chat_logger.info("Chat stop signal sent", extra={"task_id": id})
return Response(status_code=204)
@router.post("/chat/{id}/human-reply")
@traceroot.trace()
def human_reply(id: str, data: HumanReply):
chat_logger.info(f"Human reply received for task_id: {id}, agent: {data.agent}")
chat_logger.info("Human reply received", extra={"task_id": id, "reply_length": len(data.reply)})
task_lock = get_task_lock(id)
asyncio.run(task_lock.put_human_input(data.agent, data.reply))
chat_logger.info(f"Human reply processed for task_id: {id}")
chat_logger.debug("Human reply processed", extra={"task_id": id})
return Response(status_code=201)
@router.post("/chat/{id}/install-mcp")
@traceroot.trace()
def install_mcp(id: str, data: McpServers):
chat_logger.info(f"Installing MCP servers for task_id: {id}, servers count: {len(data.get('mcpServers', {}))}")
chat_logger.info("Installing MCP servers", extra={"task_id": id, "servers_count": len(data.get('mcpServers', {}))})
task_lock = get_task_lock(id)
asyncio.run(task_lock.put_queue(ActionInstallMcpData(action=Action.install_mcp, data=data)))
chat_logger.info(f"MCP installation queued for task_id: {id}")
chat_logger.info("MCP installation queued", extra={"task_id": id})
return Response(status_code=201)
@router.post("/chat/{id}/add-task", name="add task to workforce")
@traceroot.trace()
def add_task(id: str, data: AddTaskRequest):
"""Add a new task to the workforce"""
chat_logger.info(f"Adding task to workforce for task_id: {id}, content: {data.content[:100]}...")
task_lock = get_task_lock(id)
try:
# Queue the add task action
add_task_action = ActionAddTaskData(
content=data.content,
project_id=data.project_id,
task_id=data.task_id,
additional_info=data.additional_info,
insert_position=data.insert_position
)
asyncio.run(task_lock.put_queue(add_task_action))
return Response(status_code=201)
except Exception as e:
chat_logger.error(f"Error adding task for task_id: {id}: {e}")
raise UserException(code.error, f"Failed to add task: {str(e)}")
@router.delete("/chat/{project_id}/remove-task/{task_id}", name="remove task from workforce")
@traceroot.trace()
def remove_task(project_id: str, task_id: str):
"""Remove a task from the workforce"""
chat_logger.info(f"Removing task {task_id} from workforce for project_id: {project_id}")
task_lock = get_task_lock(project_id)
try:
# Queue the remove task action
remove_task_action = ActionRemoveTaskData(task_id=task_id, project_id=project_id)
asyncio.run(task_lock.put_queue(remove_task_action))
chat_logger.info(f"Task removal request queued for project_id: {project_id}, removing task: {task_id}")
return Response(status_code=204)
except Exception as e:
chat_logger.error(f"Error removing task {task_id} for project_id: {project_id}: {e}")
raise UserException(code.error, f"Failed to remove task: {str(e)}")
@router.post("/chat/{project_id}/skip-task", name="skip task in workforce")
@traceroot.trace()
def skip_task(project_id: str):
"""Skip a task in the workforce"""
chat_logger.info(f"Skipping task in workforce for project_id: {project_id}")
task_lock = get_task_lock(project_id)
try:
# Queue the skip task action
skip_task_action = ActionSkipTaskData(project_id=project_id)
asyncio.run(task_lock.put_queue(skip_task_action))
chat_logger.info(f"Task skip request queued for project_id: {project_id}")
return Response(status_code=201)
except Exception as e:
chat_logger.error(f"Error skipping task for project_id: {project_id}: {e}")
raise UserException(code.error, f"Failed to skip task: {str(e)}")

View file

@ -0,0 +1,16 @@
from fastapi import APIRouter
from pydantic import BaseModel
router = APIRouter(tags=["Health"])
class HealthResponse(BaseModel):
status: str
service: str
@router.get("/health", name="health check", response_model=HealthResponse)
async def health_check():
"""Health check endpoint for verifying backend is ready to accept requests."""
return HealthResponse(status="ok", service="eigent")

View file

@ -1,11 +1,14 @@
from fastapi import APIRouter
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from app.component.model_validation import create_agent
from camel.types import ModelType
from app.component.error_format import normalize_error_to_openai_format
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("model_controller")
router = APIRouter(tags=["model"])
router = APIRouter()
class ValidateModelRequest(BaseModel):
@ -26,33 +29,46 @@ class ValidateModelResponse(BaseModel):
@router.post("/model/validate")
@traceroot.trace()
async def validate_model(request: ValidateModelRequest):
try:
# API key validation
if request.api_key is not None and str(request.api_key).strip() == "":
return ValidateModelResponse(
is_valid=False,
is_tool_calls=False,
message="Invalid key. Validation failed.",
error_code="invalid_api_key",
error={
"message": "Invalid key. Validation failed.",
"""Validate model configuration and tool call support."""
platform = request.model_platform
model_type = request.model_type
has_custom_url = request.url is not None
has_config = request.model_config_dict is not None
logger.info("Model validation started", extra={"platform": platform, "model_type": model_type, "has_url": has_custom_url, "has_config": has_config})
# API key validation
if request.api_key is not None and str(request.api_key).strip() == "":
logger.warning("Model validation failed: empty API key", extra={"platform": platform, "model_type": model_type})
raise HTTPException(
status_code=400,
detail={
"message": "Invalid key. Validation failed.",
"error_code": "invalid_api_key",
"error": {
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
},
)
}
)
try:
extra = request.extra_params or {}
logger.debug("Creating agent for validation", extra={"platform": platform, "model_type": model_type})
agent = create_agent(
request.model_platform,
request.model_type,
platform,
model_type,
api_key=request.api_key,
url=request.url,
model_config_dict=request.model_config_dict,
**extra,
)
logger.debug("Agent created, executing test step", extra={"platform": platform, "model_type": model_type})
response = agent.step(
input_message="""
Get the content of https://www.camel-ai.org,
@ -61,17 +77,23 @@ async def validate_model(request: ValidateModelRequest):
you must call the get_website_content tool only once.
"""
)
except Exception as e:
# Normalize error to OpenAI-style error structure
logger.error("Model validation failed", extra={"platform": platform, "model_type": model_type, "error": str(e)}, exc_info=True)
message, error_code, error_obj = normalize_error_to_openai_format(e)
return ValidateModelResponse(
is_valid=False,
is_tool_calls=False,
message=message,
error_code=error_code,
error=error_obj,
raise HTTPException(
status_code=400,
detail={
"message": message,
"error_code": error_code,
"error": error_obj,
}
)
# Check validation results
is_valid = bool(response)
is_tool_calls = False
@ -83,7 +105,7 @@ async def validate_model(request: ValidateModelRequest):
== "Tool execution completed successfully for https://www.camel-ai.org, Website Content: Welcome to CAMEL AI!"
)
return ValidateModelResponse(
result = ValidateModelResponse(
is_valid=is_valid,
is_tool_calls=is_tool_calls,
message="Validation Success"
@ -92,3 +114,7 @@ async def validate_model(request: ValidateModelRequest):
error_code=None,
error=None,
)
logger.info("Model validation completed", extra={"platform": platform, "model_type": model_type, "is_valid": is_valid, "is_tool_calls": is_tool_calls})
return result

View file

@ -1,7 +1,6 @@
from typing import Literal
from dotenv import load_dotenv
from fastapi import APIRouter, Response
from loguru import logger
from pydantic import BaseModel
from app.model.chat import NewAgent, UpdateData
from app.service.task import (
@ -16,24 +15,32 @@ from app.service.task import (
)
import asyncio
from app.component.environment import set_user_env_path
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("task_controller")
router = APIRouter(tags=["task"])
router = APIRouter()
@router.post("/task/{id}/start", name="start task")
@traceroot.trace()
def start(id: str):
task_lock = get_task_lock(id)
logger.debug(f"start task {id}")
logger.info("Starting task", extra={"task_id": id})
asyncio.run(task_lock.put_queue(ActionStartData(action=Action.start)))
logger.debug(f"start task {id} success")
logger.info("Task started successfully", extra={"task_id": id})
return Response(status_code=201)
@router.put("/task/{id}", name="update task")
@traceroot.trace()
def put(id: str, data: UpdateData):
logger.info("Updating task", extra={"task_id": id, "task_items_count": len(data.task)})
logger.debug("Update task data", extra={"task_id": id, "data": data.model_dump_json()})
task_lock = get_task_lock(id)
asyncio.run(task_lock.put_queue(ActionUpdateTaskData(action=Action.update_task, data=data)))
logger.info("Task updated successfully", extra={"task_id": id})
return Response(status_code=201)
@ -42,23 +49,33 @@ class TakeControl(BaseModel):
@router.put("/task/{id}/take-control", name="take control pause or resume")
@traceroot.trace()
def take_control(id: str, data: TakeControl):
logger.info("Task control action", extra={"task_id": id, "action": data.action})
task_lock = get_task_lock(id)
asyncio.run(task_lock.put_queue(ActionTakeControl(action=data.action)))
logger.info("Task control action completed", extra={"task_id": id, "action": data.action})
return Response(status_code=204)
@router.post("/task/{id}/add-agent", name="add new agent")
@traceroot.trace()
def add_agent(id: str, data: NewAgent):
logger.info("Adding new agent to task", extra={"task_id": id, "agent_name": data.name})
logger.debug("New agent data", extra={"task_id": id, "agent_data": data.model_dump_json()})
# Set user-specific environment path for this thread
set_user_env_path(data.env_path)
load_dotenv(dotenv_path=data.env_path)
asyncio.run(get_task_lock(id).put_queue(ActionNewAgent(**data.model_dump())))
logger.info("Agent added to task", extra={"task_id": id, "agent_name": data.name})
return Response(status_code=204)
@router.delete("/task/stop-all", name="stop all tasks")
@traceroot.trace()
def stop_all():
logger.warning("Stopping all tasks", extra={"task_count": len(task_locks)})
for task_lock in task_locks.values():
asyncio.run(task_lock.put_queue(ActionStopData()))
logger.info("All tasks stopped", extra={"task_count": len(task_locks)})
return Response(status_code=204)

View file

@ -1,10 +1,17 @@
from fastapi import APIRouter, HTTPException
from loguru import logger
from app.utils.toolkit.notion_mcp_toolkit import NotionMCPToolkit
from app.utils.toolkit.google_calendar_toolkit import GoogleCalendarToolkit
from app.utils.oauth_state_manager import oauth_state_manager
from utils import traceroot_wrapper as traceroot
from camel.toolkits.hybrid_browser_toolkit.hybrid_browser_toolkit_ts import (
HybridBrowserToolkit as BaseHybridBrowserToolkit,
)
from app.utils.cookie_manager import CookieManager
import os
import uuid
router = APIRouter(tags=["task"])
logger = traceroot.get_logger("tool_controller")
router = APIRouter()
@router.post("/install/tool/{tool}", name="install tool")
@ -28,8 +35,10 @@ async def install_tool(tool: str):
await toolkit.connect()
# Get available tools to verify connection
tools = [tool_func.func.__name__ for tool_func in toolkit.get_tools()]
logger.info(f"Successfully pre-instantiated {tool} toolkit with {len(tools)} tools")
tools = [tool_func.func.__name__ for tool_func in
toolkit.get_tools()]
logger.info(
f"Successfully pre-instantiated {tool} toolkit with {len(tools)} tools")
# Disconnect, authentication info is saved
await toolkit.disconnect()
@ -42,7 +51,8 @@ async def install_tool(tool: str):
"toolkit_name": "NotionMCPToolkit"
}
except Exception as connect_error:
logger.warning(f"Could not connect to {tool} MCP server: {connect_error}")
logger.warning(
f"Could not connect to {tool} MCP server: {connect_error}")
# Even if connection fails, mark as installed so user can use it later
return {
"success": True,
@ -60,20 +70,34 @@ async def install_tool(tool: str):
)
elif tool == "google_calendar":
try:
# Use a dummy task_id for installation, as this is just for pre-authentication
toolkit = GoogleCalendarToolkit("install_auth")
# Try to initialize toolkit - will succeed if credentials exist
try:
toolkit = GoogleCalendarToolkit("install_auth")
tools = [tool_func.func.__name__ for tool_func in toolkit.get_tools()]
logger.info(f"Successfully initialized Google Calendar toolkit with {len(tools)} tools")
# Get available tools to verify connection
tools = [tool_func.func.__name__ for tool_func in toolkit.get_tools()]
logger.info(f"Successfully pre-instantiated {tool} toolkit with {len(tools)} tools")
return {
"success": True,
"tools": tools,
"message": f"Successfully installed {tool} toolkit",
"count": len(tools),
"toolkit_name": "GoogleCalendarToolkit"
}
except ValueError as auth_error:
# No credentials - need authorization
logger.info(f"No credentials found, starting authorization: {auth_error}")
return {
"success": True,
"tools": tools,
"message": f"Successfully installed {tool} toolkit",
"count": len(tools),
"toolkit_name": "GoogleCalendarToolkit"
}
# Start background authorization in a new thread
logger.info("Starting background Google Calendar authorization")
GoogleCalendarToolkit.start_background_auth("install_auth")
return {
"success": False,
"status": "authorizing",
"message": "Authorization required. Browser should open automatically. Complete authorization and try installing again.",
"toolkit_name": "GoogleCalendarToolkit",
"requires_auth": True
}
except Exception as e:
logger.error(f"Failed to install {tool} toolkit: {e}")
raise HTTPException(
@ -113,3 +137,880 @@ async def list_available_tools():
}
]
}
@router.get("/oauth/status/{provider}", name="get oauth status")
async def get_oauth_status(provider: str):
"""
Get the current OAuth authorization status for a provider
Args:
provider: OAuth provider name (e.g., 'google_calendar')
Returns:
Current authorization status
"""
state = oauth_state_manager.get_state(provider)
if not state:
return {
"provider": provider,
"status": "not_started",
"message": "No authorization in progress"
}
return state.to_dict()
@router.post("/oauth/cancel/{provider}", name="cancel oauth")
async def cancel_oauth(provider: str):
"""
Cancel an ongoing OAuth authorization flow
Args:
provider: OAuth provider name (e.g., 'google_calendar')
Returns:
Cancellation result
"""
state = oauth_state_manager.get_state(provider)
if not state:
raise HTTPException(
status_code=404,
detail=f"No authorization found for provider '{provider}'"
)
if state.status not in ["pending", "authorizing"]:
raise HTTPException(
status_code=400,
detail=f"Cannot cancel authorization with status '{state.status}'"
)
state.cancel()
logger.info(f"Cancelled OAuth authorization for {provider}")
return {
"success": True,
"provider": provider,
"message": "Authorization cancelled successfully"
}
@router.delete("/uninstall/tool/{tool}", name="uninstall tool")
async def uninstall_tool(tool: str):
"""
Uninstall a tool and clean up its authentication data
Args:
tool: Tool name to uninstall (notion, google_calendar)
Returns:
Uninstallation result
"""
import os
import shutil
if tool == "notion":
try:
import hashlib
import glob
# Calculate the hash for Notion MCP URL
# mcp-remote uses MD5 hash of the URL to generate file names
notion_url = "https://mcp.notion.com/mcp"
url_hash = hashlib.md5(notion_url.encode()).hexdigest()
# Find and remove Notion-specific auth files
mcp_auth_dir = os.path.join(os.path.expanduser("~"), ".mcp-auth")
deleted_files = []
if os.path.exists(mcp_auth_dir):
# Look for all files with the Notion hash prefix
for version_dir in os.listdir(mcp_auth_dir):
version_path = os.path.join(mcp_auth_dir, version_dir)
if os.path.isdir(version_path):
# Find all files matching the hash pattern
pattern = os.path.join(version_path, f"{url_hash}_*")
notion_files = glob.glob(pattern)
for file_path in notion_files:
try:
os.remove(file_path)
deleted_files.append(file_path)
logger.info(f"Removed Notion auth file: {file_path}")
except Exception as e:
logger.warning(f"Failed to remove {file_path}: {e}")
message = f"Successfully uninstalled {tool}"
if deleted_files:
message += f" and cleaned up {len(deleted_files)} authentication file(s)"
return {
"success": True,
"message": message,
"deleted_files": deleted_files
}
except Exception as e:
logger.error(f"Failed to uninstall {tool}: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to uninstall {tool}: {str(e)}"
)
elif tool == "google_calendar":
try:
# Clean up Google Calendar token directory
token_dir = os.path.join(os.path.expanduser("~"), ".eigent", "tokens", "google_calendar")
if os.path.exists(token_dir):
shutil.rmtree(token_dir)
logger.info(f"Removed Google Calendar token directory: {token_dir}")
# Clear OAuth state manager cache (this is the key fix!)
# This removes the cached credentials from memory
state = oauth_state_manager.get_state("google_calendar")
if state:
if state.status in ["pending", "authorizing"]:
state.cancel()
logger.info("Cancelled ongoing Google Calendar authorization")
# Clear the state completely to remove cached credentials
oauth_state_manager._states.pop("google_calendar", None)
logger.info("Cleared Google Calendar OAuth state cache")
return {
"success": True,
"message": f"Successfully uninstalled {tool} and cleaned up authentication tokens"
}
except Exception as e:
logger.error(f"Failed to uninstall {tool}: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to uninstall {tool}: {str(e)}"
)
else:
raise HTTPException(
status_code=404,
detail=f"Tool '{tool}' not found. Available tools: ['notion', 'google_calendar']"
)
@router.post("/browser/login", name="open browser for login")
async def open_browser_login():
"""
Open an Electron-based Chrome browser for user login with a dedicated user data directory
Returns:
Browser session information
"""
try:
import subprocess
import platform
import socket
import json
# Use fixed profile name for persistent logins (no port suffix)
session_id = "user_login"
cdp_port = 9223
# IMPORTANT: Use dedicated profile for tool_controller browser
# This is the SOURCE OF TRUTH for login data
# On Eigent startup, this data will be copied to WebView partition (one-way sync)
browser_profiles_base = os.path.expanduser("~/.eigent/browser_profiles")
user_data_dir = os.path.join(browser_profiles_base, "profile_user_login")
os.makedirs(user_data_dir, exist_ok=True)
logger.info(
f"Creating browser session {session_id} with profile at: {user_data_dir}")
# Check if browser is already running on this port
def is_port_in_use(port):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0
if is_port_in_use(cdp_port):
logger.info(f"Browser already running on port {cdp_port}")
return {
"success": True,
"session_id": session_id,
"user_data_dir": user_data_dir,
"cdp_port": cdp_port,
"message": "Browser already running. Use existing window to log in.",
"note": "Your login data will be saved in the profile."
}
# Create Electron browser script with .cjs extension for CommonJS
electron_script_path = os.path.join(os.path.dirname(__file__), "electron_browser.cjs")
electron_script_content = '''
const { app, BrowserWindow, ipcMain } = require('electron');
const path = require('path');
// Parse command line arguments
const args = process.argv.slice(2);
const userDataDir = args[0];
const cdpPort = args[1];
const startUrl = args[2] || 'https://www.google.com';
// This must be called before app.ready
app.commandLine.appendSwitch('remote-debugging-port', cdpPort);
console.log('[ELECTRON BROWSER] Starting with:');
console.log(' Chrome version:', process.versions.chrome);
console.log(' User data dir (requested):', userDataDir);
console.log(' CDP port:', cdpPort);
console.log(' Start URL:', startUrl);
// Set app paths - must be done before app.ready
// Do NOT use commandLine.appendSwitch('user-data-dir') as it conflicts with setPath
app.setPath('userData', userDataDir);
app.setPath('sessionData', userDataDir);
app.whenReady().then(async () => {
const { session } = require('electron');
const fs = require('fs');
const path = require('path');
// Log actual paths being used
console.log('[ELECTRON BROWSER] Actual paths:');
console.log(' app.getPath("userData"):', app.getPath('userData'));
console.log(' app.getPath("sessionData"):', app.getPath('sessionData'));
console.log(' app.getPath("cache"):', app.getPath('cache'));
console.log(' app.getPath("temp"):', app.getPath('temp'));
console.log(' process.argv:', process.argv);
// Check command line switches
console.log('[ELECTRON BROWSER] Command line switches:');
console.log(' user-data-dir:', app.commandLine.getSwitchValue('user-data-dir'));
console.log(' remote-debugging-port:', app.commandLine.getSwitchValue('remote-debugging-port'));
// Log partition session info
const userLoginSession = session.fromPartition('persist:user_login');
console.log('[ELECTRON BROWSER] Session info:');
console.log(' Partition: persist:user_login');
console.log(' Session storage path:', userLoginSession.getStoragePath());
// Check if Cookies file exists
const cookiesPath = path.join(app.getPath('userData'), 'Partitions', 'user_login', 'Cookies');
console.log('[ELECTRON BROWSER] Cookies path:', cookiesPath);
console.log('[ELECTRON BROWSER] Cookies exists:', fs.existsSync(cookiesPath));
if (fs.existsSync(cookiesPath)) {
const stats = fs.statSync(cookiesPath);
console.log('[ELECTRON BROWSER] Cookies file size:', stats.size, 'bytes');
}
const win = new BrowserWindow({
width: 1400,
height: 900,
title: 'Eigent Browser - Login',
webPreferences: {
nodeIntegration: true,
contextIsolation: false,
webviewTag: true
}
});
// Create navigation bar and webview
const html = `
<!DOCTYPE html>
<html>
<head>
<style>
body {
margin: 0;
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Arial, sans-serif;
display: flex;
flex-direction: column;
height: 100vh;
overflow: hidden;
}
#nav-bar {
display: flex;
align-items: center;
padding: 8px;
background: #f5f5f5;
border-bottom: 1px solid #ddd;
gap: 8px;
}
button {
padding: 6px 12px;
border: 1px solid #ccc;
background: white;
border-radius: 4px;
cursor: pointer;
font-size: 14px;
display: flex;
align-items: center;
gap: 4px;
}
button:hover:not(:disabled) {
background: #f0f0f0;
}
button:disabled {
opacity: 0.5;
cursor: not-allowed;
}
#url-input {
flex: 1;
padding: 8px 12px;
border: 1px solid #ccc;
border-radius: 4px;
font-size: 14px;
}
#url-input:focus {
outline: none;
border-color: #4285f4;
}
#webview {
flex: 1;
width: 100%;
border: none;
}
.nav-icon {
font-size: 16px;
}
#loading-indicator {
width: 20px;
height: 20px;
border: 2px solid #f3f3f3;
border-top: 2px solid #4285f4;
border-radius: 50%;
animation: spin 1s linear infinite;
display: none;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.loading #loading-indicator {
display: block;
}
.loading #reload-btn .nav-icon {
display: none;
}
</style>
</head>
<body>
<div id="nav-bar">
<button id="back-btn" title="Back">
<span class="nav-icon"></span>
</button>
<button id="forward-btn" title="Forward">
<span class="nav-icon"></span>
</button>
<button id="reload-btn" title="Reload">
<span class="nav-icon"></span>
<div id="loading-indicator"></div>
</button>
<button id="home-btn" title="Home">
<span class="nav-icon">🏠</span>
</button>
<input type="text" id="url-input" placeholder="Enter URL..." />
<button id="go-btn">Go</button>
<button id="linkedin-btn" style="background: #0077B5; color: white; border-color: #0077B5;">
LinkedIn
</button>
<button id="info-btn" title="Show Info"></button>
</div>
<webview id="webview" src="${startUrl}" partition="persist:user_login"></webview>
<div id="info-panel" style="display: none; position: absolute; top: 50px; right: 10px; background: white; border: 1px solid #ccc; padding: 15px; box-shadow: 0 2px 10px rgba(0,0,0,0.1); z-index: 1000; max-width: 400px; font-size: 12px;">
<h4 style="margin: 0 0 10px 0;">Browser Info</h4>
<div id="info-content"></div>
<button onclick="document.getElementById('info-panel').style.display='none'" style="margin-top: 10px;">Close</button>
</div>
<script>
const webview = document.getElementById('webview');
const backBtn = document.getElementById('back-btn');
const forwardBtn = document.getElementById('forward-btn');
const reloadBtn = document.getElementById('reload-btn');
const homeBtn = document.getElementById('home-btn');
const urlInput = document.getElementById('url-input');
const goBtn = document.getElementById('go-btn');
const linkedinBtn = document.getElementById('linkedin-btn');
const navBar = document.getElementById('nav-bar');
const infoBtn = document.getElementById('info-btn');
const infoPanel = document.getElementById('info-panel');
const infoContent = document.getElementById('info-content');
// Show info panel
infoBtn.addEventListener('click', () => {
const { ipcRenderer } = require('electron');
// Get browser info
const info = {
'Chrome Version': process.versions.chrome,
'Electron Version': process.versions.electron,
'Node Version': process.versions.node,
'User Data Dir (requested)': '${userDataDir}',
'CDP Port': '${cdpPort}',
'Platform': process.platform,
'Architecture': process.arch
};
// Also check webview partition info
const partition = webview.partition || 'default';
info['WebView Partition'] = partition;
// Format info as HTML
let html = '<table style="width: 100%; border-collapse: collapse;">';
for (const [key, value] of Object.entries(info)) {
html += '<tr><td style="padding: 4px; border-bottom: 1px solid #eee;"><strong>' + key + ':</strong></td><td style="padding: 4px; border-bottom: 1px solid #eee; word-break: break-all;">' + value + '</td></tr>';
}
html += '</table>';
infoContent.innerHTML = html;
infoPanel.style.display = 'block';
});
// Update navigation buttons
function updateNavButtons() {
backBtn.disabled = !webview.canGoBack();
forwardBtn.disabled = !webview.canGoForward();
}
// Navigate to URL
function navigateToUrl() {
let url = urlInput.value.trim();
if (!url) return;
if (!url.startsWith('http://') && !url.startsWith('https://')) {
url = 'https://' + url;
}
webview.loadURL(url);
}
// Event listeners
backBtn.addEventListener('click', () => webview.goBack());
forwardBtn.addEventListener('click', () => webview.goForward());
reloadBtn.addEventListener('click', () => webview.reload());
homeBtn.addEventListener('click', () => webview.loadURL('${startUrl}'));
goBtn.addEventListener('click', navigateToUrl);
linkedinBtn.addEventListener('click', () => webview.loadURL('https://www.linkedin.com'));
urlInput.addEventListener('keypress', (e) => {
if (e.key === 'Enter') {
navigateToUrl();
}
});
// WebView events
webview.addEventListener('did-start-loading', () => {
navBar.classList.add('loading');
});
webview.addEventListener('did-stop-loading', () => {
navBar.classList.remove('loading');
updateNavButtons();
});
webview.addEventListener('did-navigate', (e) => {
urlInput.value = e.url;
updateNavButtons();
});
webview.addEventListener('did-navigate-in-page', (e) => {
urlInput.value = e.url;
updateNavButtons();
});
webview.addEventListener('new-window', (e) => {
// Open new windows in the same webview
e.preventDefault();
webview.loadURL(e.url);
});
// Initialize
updateNavButtons();
</script>
</body>
</html>`;
win.loadURL('data:text/html;charset=UTF-8,' + encodeURIComponent(html));
// Show window when ready
win.once('ready-to-show', () => {
win.show();
// Log cookies periodically to track changes
setInterval(async () => {
try {
const cookies = await userLoginSession.cookies.get({});
console.log('[ELECTRON BROWSER] Current cookies count:', cookies.length);
if (cookies.length > 0) {
console.log('[ELECTRON BROWSER] Cookie domains:', [...new Set(cookies.map(c => c.domain))]);
}
} catch (error) {
console.error('[ELECTRON BROWSER] Failed to get cookies:', error);
}
}, 5000); // Check every 5 seconds
});
win.on('closed', async () => {
console.log('[ELECTRON BROWSER] Window closed, preparing to quit...');
// Flush storage data before quitting to ensure cookies are saved
try {
const { session } = require('electron');
const fs = require('fs');
const path = require('path');
const userLoginSession = session.fromPartition('persist:user_login');
// Log cookies before flush
const cookiesBeforeFlush = await userLoginSession.cookies.get({});
console.log('[ELECTRON BROWSER] Cookies count before flush:', cookiesBeforeFlush.length);
// Flush storage
console.log('[ELECTRON BROWSER] Flushing storage data...');
await userLoginSession.flushStorageData();
console.log('[ELECTRON BROWSER] Storage data flushed successfully');
// Check cookies file after flush
const cookiesPath = path.join(app.getPath('userData'), 'Partitions', 'user_login', 'Cookies');
if (fs.existsSync(cookiesPath)) {
const stats = fs.statSync(cookiesPath);
console.log('[ELECTRON BROWSER] Cookies file size after flush:', stats.size, 'bytes');
} else {
console.log('[ELECTRON BROWSER] WARNING: Cookies file does not exist after flush!');
}
} catch (error) {
console.error('[ELECTRON BROWSER] Failed to flush storage data:', error);
}
app.quit();
});
});
let isQuitting = false;
app.on('before-quit', async (event) => {
if (isQuitting) return;
// Prevent immediate quit to allow storage flush and cookie sync
event.preventDefault();
isQuitting = true;
console.log('[ELECTRON BROWSER] before-quit event triggered');
try {
const { session } = require('electron');
const fs = require('fs');
const path = require('path');
const userLoginSession = session.fromPartition('persist:user_login');
// Log cookies before flush
const cookiesBeforeQuit = await userLoginSession.cookies.get({});
console.log('[ELECTRON BROWSER] Cookies count before quit:', cookiesBeforeQuit.length);
if (cookiesBeforeQuit.length > 0) {
console.log('[ELECTRON BROWSER] Cookie domains before quit:', [...new Set(cookiesBeforeQuit.map(c => c.domain))]);
}
// Flush storage
console.log('[ELECTRON BROWSER] Flushing storage on quit...');
await userLoginSession.flushStorageData();
console.log('[ELECTRON BROWSER] Storage data flushed on quit');
} catch (error) {
console.error('[ELECTRON BROWSER] Failed to sync cookies:', error);
} finally {
console.log('[ELECTRON BROWSER] Exiting now...');
// Force quit after sync
app.exit(0);
}
});
app.on('window-all-closed', () => {
if (!isQuitting) {
app.quit();
}
});
'''
# Write the Electron script
with open(electron_script_path, 'w') as f:
f.write(electron_script_content)
# Find Electron executable
# Try to use the same Electron version as the main app
electron_cmd = "npx"
electron_args = [
electron_cmd,
"electron",
electron_script_path,
user_data_dir,
str(cdp_port),
"https://www.google.com"
]
# Get the app's directory to run npx in the right context
app_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
logger.info(f"[PROFILE USER LOGIN] Launching Electron browser with CDP on port {cdp_port}")
logger.info(f"[PROFILE USER LOGIN] Working directory: {app_dir}")
logger.info(f"[PROFILE USER LOGIN] userData path: {user_data_dir}")
logger.info(f"[PROFILE USER LOGIN] Electron args: {electron_args}")
# Start process and capture output in real-time
process = subprocess.Popen(
electron_args,
cwd=app_dir,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, # Redirect stderr to stdout
universal_newlines=True,
bufsize=1 # Line buffered
)
# Create async task to log Electron output
async def log_electron_output():
for line in iter(process.stdout.readline, ''):
if line:
logger.info(f"[ELECTRON OUTPUT] {line.strip()}")
import asyncio
asyncio.create_task(log_electron_output())
# Wait a bit for Electron to start
import asyncio
await asyncio.sleep(3)
# Clean up the script file after a delay
async def cleanup_script():
await asyncio.sleep(10)
try:
os.remove(electron_script_path)
except:
pass
asyncio.create_task(cleanup_script())
logger.info(f"[PROFILE USER LOGIN] Electron browser launched with PID {process.pid}")
return {
"success": True,
"session_id": session_id,
"user_data_dir": user_data_dir,
"cdp_port": cdp_port,
"pid": process.pid,
"chrome_version": "130.0.6723.191", # Electron 33's Chrome version
"message": "Electron browser opened successfully. Please log in to your accounts.",
"note": "The browser will remain open for you to log in. Your login data will be saved in the profile."
}
except Exception as e:
logger.error(f"Failed to open Electron browser for login: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to open browser: {str(e)}"
)
@router.get("/browser/cookies", name="list cookie domains")
async def list_cookie_domains(search: str = None):
"""
list cookie domains
Args:
search: url
Returns:
list of cookie domains
"""
try:
# Use tool_controller browser's user data directory (source of truth)
user_data_base = os.path.expanduser("~/.eigent/browser_profiles")
user_data_dir = os.path.join(user_data_base, "profile_user_login")
logger.info(f"[COOKIES CHECK] Tool controller user_data_dir: {user_data_dir}")
logger.info(f"[COOKIES CHECK] Tool controller user_data_dir exists: {os.path.exists(user_data_dir)}")
# Check partition path
partition_path = os.path.join(user_data_dir, "Partitions", "user_login")
logger.info(f"[COOKIES CHECK] partition path: {partition_path}")
logger.info(f"[COOKIES CHECK] partition exists: {os.path.exists(partition_path)}")
# Check cookies file
cookies_file = os.path.join(partition_path, "Cookies")
logger.info(f"[COOKIES CHECK] cookies file: {cookies_file}")
logger.info(f"[COOKIES CHECK] cookies file exists: {os.path.exists(cookies_file)}")
if os.path.exists(cookies_file):
stat = os.stat(cookies_file)
logger.info(f"[COOKIES CHECK] cookies file size: {stat.st_size} bytes")
# Try to read actual cookie count
try:
import sqlite3
conn = sqlite3.connect(cookies_file)
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) FROM cookies")
count = cursor.fetchone()[0]
logger.info(f"[COOKIES CHECK] actual cookie count in database: {count}")
conn.close()
except Exception as e:
logger.error(f"[COOKIES CHECK] failed to read cookie count: {e}")
if not os.path.exists(user_data_dir):
return {
"success": True,
"domains": [],
"message": "No browser profile found. Please login first using /browser/login."
}
cookie_manager = CookieManager(user_data_dir)
if search:
domains = cookie_manager.search_cookies(search)
else:
domains = cookie_manager.get_cookie_domains()
return {
"success": True,
"domains": domains,
"total": len(domains),
"user_data_dir": user_data_dir
}
except Exception as e:
logger.error(f"Failed to list cookie domains: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to list cookies: {str(e)}"
)
@router.get("/browser/cookies/{domain}", name="get domain cookies")
async def get_domain_cookies(domain: str):
"""
get domain cookies
Args:
domain
Returns:
cookies
"""
try:
user_data_base = os.path.expanduser("~/.eigent/browser_profiles")
user_data_dir = os.path.join(user_data_base, "profile_user_login")
if not os.path.exists(user_data_dir):
raise HTTPException(
status_code=404,
detail="No browser profile found. Please login first using /browser/login."
)
cookie_manager = CookieManager(user_data_dir)
cookies = cookie_manager.get_cookies_for_domain(domain)
return {
"success": True,
"domain": domain,
"cookies": cookies,
"count": len(cookies)
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to get cookies for domain {domain}: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to get cookies: {str(e)}"
)
@router.delete("/browser/cookies/{domain}", name="delete domain cookies")
async def delete_domain_cookies(domain: str):
"""
Delete cookies
Args:
domain
Returns:
deleted cookies
"""
try:
user_data_base = os.path.expanduser("~/.eigent/browser_profiles")
user_data_dir = os.path.join(user_data_base, "profile_user_login")
if not os.path.exists(user_data_dir):
raise HTTPException(
status_code=404,
detail="No browser profile found. Please login first using /browser/login."
)
cookie_manager = CookieManager(user_data_dir)
success = cookie_manager.delete_cookies_for_domain(domain)
if success:
return {
"success": True,
"message": f"Successfully deleted cookies for domain: {domain}"
}
else:
raise HTTPException(
status_code=500,
detail=f"Failed to delete cookies for domain: {domain}"
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to delete cookies for domain {domain}: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to delete cookies: {str(e)}"
)
@router.delete("/browser/cookies", name="delete all cookies")
async def delete_all_cookies():
"""
delete all cookies
Returns:
deleted cookies
"""
try:
user_data_base = os.path.expanduser("~/.eigent/browser_profiles")
user_data_dir = os.path.join(user_data_base, "profile_user_login")
if not os.path.exists(user_data_dir):
raise HTTPException(
status_code=404,
detail="No browser profile found."
)
cookie_manager = CookieManager(user_data_dir)
success = cookie_manager.delete_all_cookies()
if success:
return {
"success": True,
"message": "Successfully deleted all cookies"
}
else:
raise HTTPException(
status_code=500,
detail="Failed to delete all cookies"
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to delete all cookies: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to delete cookies: {str(e)}"
)

View file

@ -3,18 +3,22 @@ from fastapi import Request
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from loguru import logger
from app import api
from app.component import code
from app.exception.exception import NoPermissionException, ProgramException, TokenException
from app.component.pydantic.i18n import trans, get_language
from app.exception.exception import UserException
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("exception_handler")
@api.exception_handler(RequestValidationError)
async def request_exception(request: Request, e: RequestValidationError):
if (lang := get_language(request.headers.get("Accept-Language"))) is None:
lang = "en_US"
logger.warning(f"Validation error on {request.url.path}: {e.errors()}")
return JSONResponse(
content={
"code": code.form_error,
@ -25,16 +29,19 @@ async def request_exception(request: Request, e: RequestValidationError):
@api.exception_handler(TokenException)
async def token_exception(request: Request, e: TokenException):
logger.warning(f"Token exception on {request.url.path}: {e.text}")
return JSONResponse(content={"code": e.code, "text": e.text})
@api.exception_handler(UserException)
async def user_exception(request: Request, e: UserException):
logger.info(f"User exception on {request.url.path}: {e.description}")
return JSONResponse(content={"code": e.code, "text": e.description})
@api.exception_handler(NoPermissionException)
async def no_permission(request: Request, exception: NoPermissionException):
logger.warning(f"No permission on {request.url.path}: {exception.text}")
return JSONResponse(
status_code=200,
content={"code": code.no_permission_error, "text": exception.text},
@ -43,6 +50,7 @@ async def no_permission(request: Request, exception: NoPermissionException):
@api.exception_handler(ProgramException)
async def program_exception(request: Request, exception: NoPermissionException):
logger.error(f"Program exception on {request.url.path}: {exception.text}", exc_info=True)
return JSONResponse(
status_code=200,
content={"code": code.program_error, "text": exception.text},
@ -51,8 +59,16 @@ async def program_exception(request: Request, exception: NoPermissionException):
@api.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
logger.error(f"Unhandled error: {exc}")
traceback.print_exc() # output to electron log
logger.error(
f"Unhandled exception on {request.method} {request.url.path}: {exc}",
exc_info=True,
extra={
"request_method": request.method,
"request_path": str(request.url.path),
"request_query": str(request.url.query),
"client_host": request.client.host if request.client else None,
}
)
return JSONResponse(
status_code=500,

View file

@ -3,9 +3,11 @@ import json
from pathlib import Path
import re
from typing import Literal
from loguru import logger
from pydantic import BaseModel, field_validator
from pydantic import BaseModel, Field, field_validator
from camel.types import ModelType, RoleType
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("chat_model")
class Status(str, Enum):
@ -20,11 +22,22 @@ class ChatHistory(BaseModel):
content: str
class QuestionAnalysisResult(BaseModel):
type: Literal["simple", "complex"] = Field(
description="Whether this is a simple question or complex task"
)
answer: str | None = Field(
default=None,
description="Direct answer for simple questions. None for complex tasks."
)
McpServers = dict[Literal["mcpServers"], dict[str, dict]]
class Chat(BaseModel):
task_id: str
project_id: str
question: str
email: str
attaches: list[str] = []
@ -50,6 +63,7 @@ class Chat(BaseModel):
)
new_agents: list["NewAgent"] = []
extra_params: dict | None = None # For provider-specific parameters like Azure
search_config: dict[str, str] | None = None # User-specific search engine configurations (e.g., GOOGLE_API_KEY, SEARCH_ENGINE_ID)
@field_validator("model_type")
@classmethod
@ -72,7 +86,8 @@ class Chat(BaseModel):
def file_save_path(self, path: str | None = None):
email = re.sub(r'[\\/*?:"<>|\s]', "_", self.email.split("@")[0]).strip(".")
save_path = Path.home() / "eigent" / email / ("task_" + self.task_id)
# Use project-based structure: project_{project_id}/task_{task_id}
save_path = Path.home() / "eigent" / email / f"project_{self.project_id}" / f"task_{self.task_id}"
if path is not None:
save_path = save_path / path
save_path.mkdir(parents=True, exist_ok=True)
@ -82,6 +97,7 @@ class Chat(BaseModel):
class SupplementChat(BaseModel):
question: str
task_id: str | None = None
class HumanReply(BaseModel):
@ -106,6 +122,18 @@ class NewAgent(BaseModel):
env_path: str | None = None
class AddTaskRequest(BaseModel):
content: str
project_id: str | None = None
task_id: str | None = None
additional_info: dict | None = None
insert_position: int = -1
is_independent: bool = False
class RemoveTaskRequest(BaseModel):
task_id: str
def sse_json(step: str, data):
res_format = {"step": step, "data": data}
return f"data: {json.dumps(res_format, ensure_ascii=False)}\n\n"

64
backend/app/router.py Normal file
View file

@ -0,0 +1,64 @@
"""
Centralized router registration for the Eigent API.
All routers are explicitly registered here for better visibility and maintainability.
"""
from fastapi import FastAPI
from app.controller import chat_controller, model_controller, task_controller, tool_controller, health_controller
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("router")
def register_routers(app: FastAPI, prefix: str = "") -> None:
"""
Register all API routers with their respective prefixes and tags.
This replaces the auto-discovery mechanism for better:
- Visibility: See all routes in one place
- Maintainability: Easy to add/remove routes
- Debugging: Clear registration order and configuration
Args:
app: FastAPI application instance
prefix: Optional global prefix for all routes (e.g., "/api")
"""
routers_config = [
{
"router": health_controller.router,
"tags": ["Health"],
"description": "Health check endpoint for service readiness"
},
{
"router": chat_controller.router,
"tags": ["chat"],
"description": "Chat session management, improvements, and human interactions"
},
{
"router": model_controller.router,
"tags": ["model"],
"description": "Model validation and configuration"
},
{
"router": task_controller.router,
"tags": ["task"],
"description": "Task lifecycle management (start, stop, update, control)"
},
{
"router": tool_controller.router,
"tags": ["tool"],
"description": "Tool installation and management"
},
]
for config in routers_config:
app.include_router(
config["router"],
prefix=prefix,
tags=config["tags"]
)
route_count = len(config["router"].routes)
logger.info(
f"Registered {config['tags'][0]} router: {route_count} routes - {config['description']}"
)
logger.info(f"Total routers registered: {len(routers_config)}")

View file

@ -1,5 +1,6 @@
import asyncio
import datetime
import json
from pathlib import Path
import platform
from typing import Literal
@ -8,6 +9,7 @@ from inflection import titleize
from pydash import chain
from app.component.debug import dump_class
from app.component.environment import env
from app.utils.file_utils import get_working_directory
from app.service.task import (
ActionImproveData,
ActionInstallMcpData,
@ -19,7 +21,6 @@ from camel.toolkits import AgentCommunicationToolkit, ToolkitMessageIntegration
from app.utils.toolkit.human_toolkit import HumanToolkit
from app.utils.toolkit.note_taking_toolkit import NoteTakingToolkit
from app.utils.workforce import Workforce
from loguru import logger
from app.model.chat import Chat, NewAgent, Status, sse_json, TaskContent
from camel.tasks import Task
from app.utils.agent import (
@ -40,9 +41,197 @@ from app.service.task import Action, Agents
from app.utils.server.sync_step import sync_step
from camel.types import ModelPlatformType
from camel.models import ModelProcessingError
from utils import traceroot_wrapper as traceroot
import os
logger = traceroot.get_logger("chat_service")
def format_task_context(task_data: dict, seen_files: set | None = None, skip_files: bool = False) -> str:
"""Format structured task data into a readable context string.
Args:
task_data: Dictionary containing task content, result, and working directory
seen_files: Optional set to track already-listed files and avoid duplicates (deprecated, use skip_files instead)
skip_files: If True, skip the file listing entirely
"""
context_parts = []
if task_data.get('task_content'):
context_parts.append(f"Previous Task: {task_data['task_content']}")
if task_data.get('task_result'):
context_parts.append(f"Previous Task Result: {task_data['task_result']}")
# Skip file listing if requested
if not skip_files:
working_directory = task_data.get('working_directory')
if working_directory:
try:
if os.path.exists(working_directory):
generated_files = []
for root, dirs, files in os.walk(working_directory):
dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ['node_modules', '__pycache__', 'venv']]
for file in files:
if not file.startswith('.') and not file.endswith(('.pyc', '.tmp')):
file_path = os.path.join(root, file)
absolute_path = os.path.abspath(file_path)
# Only add if not seen before (or if we're not tracking seen files)
if seen_files is None or absolute_path not in seen_files:
generated_files.append(absolute_path)
if seen_files is not None:
seen_files.add(absolute_path)
if generated_files:
context_parts.append("Generated Files from Previous Task:")
for file_path in sorted(generated_files):
context_parts.append(f" - {file_path}")
except Exception as e:
logger.warning(f"Failed to collect generated files: {e}")
return "\n".join(context_parts)
def collect_previous_task_context(working_directory: str, previous_task_content: str, previous_task_result: str, previous_summary: str = "") -> str:
"""
Collect context from previous task including content, result, summary, and generated files.
Args:
working_directory: The working directory to scan for generated files
previous_task_content: The content of the previous task
previous_task_result: The result/output of the previous task
previous_summary: The summary of the previous task
Returns:
Formatted context string to prepend to new task
"""
context_parts = []
# Add previous task information
context_parts.append("=== CONTEXT FROM PREVIOUS TASK ===\n")
# Add previous task content
if previous_task_content:
context_parts.append(f"Previous Task:\n{previous_task_content}\n")
# Add previous task summary
if previous_summary:
context_parts.append(f"Previous Task Summary:\n{previous_summary}\n")
# Add previous task result
if previous_task_result:
context_parts.append(f"Previous Task Result:\n{previous_task_result}\n")
# Collect generated files from working directory
try:
if os.path.exists(working_directory):
generated_files = []
for root, dirs, files in os.walk(working_directory):
dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ['node_modules', '__pycache__', 'venv']]
for file in files:
if not file.startswith('.') and not file.endswith(('.pyc', '.tmp')):
file_path = os.path.join(root, file)
absolute_path = os.path.abspath(file_path)
generated_files.append(absolute_path)
if generated_files:
context_parts.append("Generated Files from Previous Task:")
for file_path in sorted(generated_files):
context_parts.append(f" - {file_path}")
context_parts.append("")
except Exception as e:
logger.warning(f"Failed to collect generated files: {e}")
context_parts.append("=== END OF PREVIOUS TASK CONTEXT ===\n")
return "\n".join(context_parts)
def check_conversation_history_length(task_lock: TaskLock, max_length: int = 100000) -> tuple[bool, int]:
"""
Check if conversation history exceeds maximum length
Returns:
tuple: (is_exceeded, total_length)
"""
if not hasattr(task_lock, 'conversation_history') or not task_lock.conversation_history:
return False, 0
total_length = 0
for entry in task_lock.conversation_history:
total_length += len(entry.get('content', ''))
is_exceeded = total_length > max_length
if is_exceeded:
logger.warning(f"Conversation history length {total_length} exceeds maximum {max_length}")
return is_exceeded, total_length
def build_conversation_context(task_lock: TaskLock, header: str = "=== CONVERSATION HISTORY ===") -> str:
"""Build conversation context from task_lock history with files listed only once at the end.
Args:
task_lock: TaskLock containing conversation history
header: Header text for the context section
Returns:
Formatted context string with task history and files listed once at the end
"""
context = ""
working_directories = set() # Collect all unique working directories
if task_lock.conversation_history:
context = f"{header}\n"
for entry in task_lock.conversation_history:
if entry['role'] == 'task_result':
if isinstance(entry['content'], dict):
formatted_context = format_task_context(entry['content'], skip_files=True)
context += formatted_context + "\n\n"
if entry['content'].get('working_directory'):
working_directories.add(entry['content']['working_directory'])
else:
context += entry['content'] + "\n"
elif entry['role'] == 'assistant':
context += f"Assistant: {entry['content']}\n\n"
if working_directories:
all_generated_files = set() # Use set to avoid duplicates
for working_directory in working_directories:
try:
if os.path.exists(working_directory):
for root, dirs, files in os.walk(working_directory):
dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ['node_modules', '__pycache__', 'venv']]
for file in files:
if not file.startswith('.') and not file.endswith(('.pyc', '.tmp')):
file_path = os.path.join(root, file)
absolute_path = os.path.abspath(file_path)
all_generated_files.add(absolute_path)
except Exception as e:
logger.warning(f"Failed to collect generated files from {working_directory}: {e}")
if all_generated_files:
context += "Generated Files from Previous Tasks:\n"
for file_path in sorted(all_generated_files):
context += f" - {file_path}\n"
context += "\n"
context += "\n"
return context
def build_context_for_workforce(task_lock: TaskLock, options: Chat) -> str:
"""Build context information for workforce."""
return build_conversation_context(task_lock, header="=== CONVERSATION HISTORY ===")
@sync_step
@traceroot.trace()
async def step_solve(options: Chat, request: Request, task_lock: TaskLock):
# if True:
# import faulthandler
@ -52,12 +241,40 @@ async def step_solve(options: Chat, request: Request, task_lock: TaskLock):
# faulthandler.dump_traceback_later(second)
start_event_loop = True
question_agent = question_confirm_agent(options)
if not hasattr(task_lock, 'conversation_history'):
task_lock.conversation_history = []
if not hasattr(task_lock, 'last_task_result'):
task_lock.last_task_result = ""
if not hasattr(task_lock, 'question_agent'):
task_lock.question_agent = None
if not hasattr(task_lock, 'summary_generated'):
task_lock.summary_generated = False
# Create or reuse persistent question_agent
if task_lock.question_agent is None:
task_lock.question_agent = question_confirm_agent(options)
logger.info(f"Created new persistent question_agent for project {options.project_id}")
else:
logger.info(f"Reusing existing question_agent with {len(task_lock.conversation_history)} history entries")
question_agent = task_lock.question_agent
# Other variables
camel_task = None
workforce = None
last_completed_task_result = "" # Track the last completed task result
summary_task_content = "" # Track task summary
loop_iteration = 0
logger.info("Starting step_solve", extra={"project_id": options.project_id, "task_id": options.task_id})
logger.debug("Step solve options", extra={"task_id": options.task_id, "model_platform": options.model_platform})
while True:
loop_iteration += 1
if await request.is_disconnected():
logger.warning(f"Client disconnected for task {options.task_id}")
logger.warning(f"Client disconnected for project {options.project_id}")
if workforce is not None:
if workforce._running:
workforce.stop()
@ -70,10 +287,10 @@ async def step_solve(options: Chat, request: Request, task_lock: TaskLock):
break
try:
item = await task_lock.get_queue()
# logger.info(f"item: {dump_class(item)}")
except Exception as e:
logger.error(f"Error getting item from queue: {e}")
break
logger.error("Error getting item from queue", extra={"project_id": options.project_id, "task_id": options.task_id, "error": str(e)}, exc_info=True)
# Continue waiting instead of breaking on queue error
continue
try:
if item.action == Action.improve or start_event_loop:
@ -87,33 +304,116 @@ async def step_solve(options: Chat, request: Request, task_lock: TaskLock):
else:
assert isinstance(item, ActionImproveData)
question = item.data
if len(question) < 12 and len(options.attaches) == 0:
confirm = await question_confirm(question_agent, question)
else:
confirm = True
if confirm is not True:
yield confirm
is_exceeded, total_length = check_conversation_history_length(task_lock)
if is_exceeded:
logger.error("Conversation history too long", extra={"project_id": options.project_id, "current_length": total_length, "max_length": 100000})
yield sse_json("context_too_long", {
"message": "The conversation history is too long. Please create a new project to continue.",
"current_length": total_length,
"max_length": 100000
})
continue
# Simplified logic: attachments mean workforce, otherwise let agent decide
is_complex_task: bool
if len(options.attaches) > 0:
# Questions with attachments always need workforce
is_complex_task = True
else:
yield sse_json("confirmed", "")
is_complex_task = await question_confirm(question_agent, question, task_lock)
if not is_complex_task:
simple_answer_prompt = f"{build_conversation_context(task_lock, header='=== Previous Conversation ===')}User Query: {question}\n\nProvide a direct, helpful answer to this simple question."
try:
simple_resp = question_agent.step(simple_answer_prompt)
answer_content = simple_resp.msgs[0].content if simple_resp and simple_resp.msgs else "I understand your question, but I'm having trouble generating a response right now."
task_lock.add_conversation('assistant', answer_content)
yield sse_json("wait_confirm", {"content": answer_content, "question": question})
except Exception as e:
logger.error(f"Error generating simple answer: {e}")
yield sse_json("wait_confirm", {"content": "I encountered an error while processing your question.", "question": question})
# Clean up empty folder if it was created for this task
if hasattr(task_lock, 'new_folder_path') and task_lock.new_folder_path:
try:
folder_path = Path(task_lock.new_folder_path)
if folder_path.exists() and folder_path.is_dir():
# Check if folder is empty
if not any(folder_path.iterdir()):
folder_path.rmdir()
logger.info(f"Cleaned up empty folder: {folder_path}")
# Also clean up parent project folder if it becomes empty
project_folder = folder_path.parent
if project_folder.exists() and not any(project_folder.iterdir()):
project_folder.rmdir()
logger.info(f"Cleaned up empty project folder: {project_folder}")
else:
logger.info(f"Folder not empty, keeping: {folder_path}")
# Reset the folder path
task_lock.new_folder_path = None
except Exception as e:
logger.error(f"Error cleaning up folder: {e}")
else:
yield sse_json("confirmed", {"question": question})
context_for_coordinator = build_context_for_workforce(task_lock, options)
(workforce, mcp) = await construct_workforce(options)
for new_agent in options.new_agents:
workforce.add_single_agent_worker(
format_agent_description(new_agent), await new_agent_model(new_agent, options)
)
summary_task_agent = task_summary_agent(options)
task_lock.status = Status.confirmed
question = question + options.summary_prompt
camel_task = Task(content=question, id=options.task_id)
clean_task_content = question + options.summary_prompt
camel_task = Task(content=clean_task_content, id=options.task_id)
if len(options.attaches) > 0:
camel_task.additional_info = {Path(file_path).name: file_path for file_path in options.attaches}
sub_tasks = await asyncio.to_thread(workforce.eigent_make_sub_tasks, camel_task)
summary_task_content = await summary_task(summary_task_agent, camel_task)
sub_tasks = await asyncio.to_thread(
workforce.eigent_make_sub_tasks,
camel_task,
context_for_coordinator
)
if not task_lock.summary_generated:
summary_task_agent = task_summary_agent(options)
try:
summary_task_content = await asyncio.wait_for(
summary_task(summary_task_agent, camel_task), timeout=10
)
task_lock.summary_generated = True
logger.info("Generated summary for first task", extra={"project_id": options.project_id})
except asyncio.TimeoutError:
logger.warning("summary_task timeout", extra={"project_id": options.project_id, "task_id": options.task_id})
# Fallback to a minimal summary to unblock UI
fallback_name = "Task"
content_preview = camel_task.content if hasattr(camel_task, "content") else ""
if content_preview is None:
content_preview = ""
fallback_summary = (
(content_preview[:80] + "...") if len(content_preview) > 80 else content_preview
)
summary_task_content = f"{fallback_name}|{fallback_summary}"
task_lock.summary_generated = True
else:
if len(question) > 100:
summary_task_content = f"Task|{question[:97]}..."
else:
summary_task_content = f"Task|{question}"
logger.info("Skipped summary generation for subsequent task", extra={"project_id": options.project_id})
yield to_sub_tasks(camel_task, summary_task_content)
# tracer.stop()
# tracer.save("trace.json")
# Only auto-start in debug mode
if env("debug") == "on":
logger.info(f"[DEBUG] Auto-starting workforce in debug mode")
task_lock.status = Status.processing
task = asyncio.create_task(workforce.eigent_start(sub_tasks))
task_lock.add_background_task(task)
@ -124,12 +424,185 @@ async def step_solve(options: Chat, request: Request, task_lock: TaskLock):
sub_tasks = update_sub_tasks(sub_tasks, update_tasks)
add_sub_tasks(camel_task, item.data.task)
yield to_sub_tasks(camel_task, summary_task_content)
elif item.action == Action.add_task:
# Check if this might be a misrouted second question
if camel_task is None and workforce is None:
logger.error(f"Cannot add task: both camel_task and workforce are None for project {options.project_id}")
yield sse_json("error", {"message": "Cannot add task: task not initialized. Please start a task first."})
continue
assert camel_task is not None
if workforce is None:
logger.error(f"Cannot add task: workforce not initialized for project {options.project_id}")
yield sse_json("error", {"message": "Workforce not initialized. Please start the task first."})
continue
# Add task to the workforce queue
workforce.add_task(
item.content,
item.task_id,
item.additional_info
)
returnData = {
"project_id": item.project_id,
"task_id": item.task_id or (len(camel_task.subtasks) + 1)
}
yield sse_json("add_task", returnData)
elif item.action == Action.remove_task:
if workforce is None:
logger.error(f"Cannot remove task: workforce not initialized for project {options.project_id}")
yield sse_json("error", {"message": "Workforce not initialized. Please start the task first."})
continue
workforce.remove_task(item.task_id)
returnData = {
"project_id": item.project_id,
"task_id": item.task_id
}
yield sse_json("remove_task", returnData)
elif item.action == Action.skip_task:
if workforce is not None and item.project_id == options.project_id:
if workforce._state.name == 'PAUSED':
# Resume paused workforce to skip the task
workforce.resume()
workforce.skip_gracefully()
elif item.action == Action.start:
# Check conversation history length before starting task
is_exceeded, total_length = check_conversation_history_length(task_lock)
if is_exceeded:
logger.error(f"Cannot start task: conversation history too long ({total_length} chars) for project {options.project_id}")
yield sse_json("context_too_long", {
"message": "The conversation history is too long. Please create a new project to continue.",
"current_length": total_length,
"max_length": 100000
})
continue
if workforce is not None:
if workforce._state.name == 'PAUSED':
# Resume paused workforce - subtasks should already be loaded
workforce.resume()
continue
else:
continue
task_lock.status = Status.processing
task = asyncio.create_task(workforce.eigent_start(sub_tasks))
task_lock.add_background_task(task)
elif item.action == Action.task_state:
# Track completed task results for the end event
task_id = item.data.get('task_id', 'unknown')
task_state = item.data.get('state', 'unknown')
task_result = item.data.get('result', '')
if task_state == 'DONE' and task_result:
last_completed_task_result = task_result
yield sse_json("task_state", item.data)
elif item.action == Action.new_task_state:
# Log new task state details
new_task_id = item.data.get('task_id', 'unknown')
new_task_state = item.data.get('state', 'unknown')
new_task_result = item.data.get('result', '')
assert camel_task is not None
old_task_content: str = camel_task.content
old_task_result: str = await get_task_result_with_optional_summary(camel_task, options)
old_task_content_clean: str = old_task_content
if "=== CURRENT TASK ===" in old_task_content_clean:
old_task_content_clean = old_task_content_clean.split("=== CURRENT TASK ===")[-1].strip()
task_lock.add_conversation('task_result', {
'task_content': old_task_content_clean,
'task_result': old_task_result,
'working_directory': get_working_directory(options, task_lock)
})
new_task_content = item.data.get('content', '')
if new_task_content:
import time
task_id = item.data.get('task_id', f"{int(time.time() * 1000)}-multi")
new_camel_task = Task(content=new_task_content, id=task_id)
if hasattr(camel_task, 'additional_info') and camel_task.additional_info:
new_camel_task.additional_info = camel_task.additional_info
camel_task = new_camel_task
# Now trigger end of previous task using stored result
yield sse_json("end", old_task_result)
# Always yield new_task_state first - this is not optional
yield sse_json("new_task_state", item.data)
# Trigger Queue Removal
yield sse_json("remove_task", {"task_id": item.data.get("task_id")})
# Then handle multi-turn processing
if workforce is not None and new_task_content:
task_lock.status = Status.confirming
workforce.pause()
try:
is_multi_turn_complex = await question_confirm(question_agent, new_task_content, task_lock)
if not is_multi_turn_complex:
simple_answer_prompt = f"{build_conversation_context(task_lock, header='=== Previous Conversation ===')}User Query: {new_task_content}\n\nProvide a direct, helpful answer to this simple question."
try:
simple_resp = question_agent.step(simple_answer_prompt)
answer_content = simple_resp.msgs[0].content if simple_resp and simple_resp.msgs else "I understand your question, but I'm having trouble generating a response right now."
task_lock.add_conversation('assistant', answer_content)
# Send response to user (don't send confirmed if simple response)
yield sse_json("wait_confirm", {"content": answer_content, "question": new_task_content})
except Exception as e:
logger.error(f"Error generating simple answer in multi-turn: {e}")
yield sse_json("wait_confirm", {"content": "I encountered an error while processing your question.", "question": new_task_content})
workforce.resume()
continue # This continues the main while loop, waiting for next action
yield sse_json("confirmed", {"question": new_task_content})
task_lock.status = Status.confirmed
context_for_multi_turn = build_context_for_workforce(task_lock, options)
new_sub_tasks = await workforce.handle_decompose_append_task(
camel_task,
reset=False,
coordinator_context=context_for_multi_turn
)
task_content_for_summary = new_task_content
if len(task_content_for_summary) > 100:
new_summary_content = f"Follow-up Task|{task_content_for_summary[:97]}..."
else:
new_summary_content = f"Follow-up Task|{task_content_for_summary}"
# Send the extracted events
yield to_sub_tasks(camel_task, new_summary_content)
# Update the context with new task data
sub_tasks = new_sub_tasks
summary_task_content = new_summary_content
except Exception as e:
import traceback
logger.error(f"[TRACE] Traceback: {traceback.format_exc()}")
# Continue with existing context if decomposition fails
yield sse_json("error", {"message": f"Failed to process task: {str(e)}"})
else:
if workforce is None:
logger.warning(f"[TRACE] Workforce is None - this might be the issue")
if not new_task_content:
logger.warning(f"[TRACE] No new task content provided")
elif item.action == Action.create_agent:
yield sse_json("create_agent", item.data)
elif item.action == Action.activate_agent:
@ -167,9 +640,15 @@ async def step_solve(options: Chat, request: Request, task_lock: TaskLock):
elif item.action == Action.pause:
if workforce is not None:
workforce.pause()
logger.info(f"Workforce paused for project {options.project_id}")
else:
logger.warning(f"Cannot pause: workforce is None for project {options.project_id}")
elif item.action == Action.resume:
if workforce is not None:
workforce.resume()
logger.info(f"Workforce resumed for project {options.project_id}")
else:
logger.warning(f"Cannot resume: workforce is None for project {options.project_id}")
elif item.action == Action.new_agent:
if workforce is not None:
workforce.pause()
@ -180,21 +659,52 @@ async def step_solve(options: Chat, request: Request, task_lock: TaskLock):
elif item.action == Action.end:
assert camel_task is not None
task_lock.status = Status.done
yield sse_json("end", str(camel_task.result))
final_result: str = await get_task_result_with_optional_summary(camel_task, options)
task_lock.last_task_result = final_result
task_content: str = camel_task.content
if "=== CURRENT TASK ===" in task_content:
task_content = task_content.split("=== CURRENT TASK ===")[-1].strip()
task_lock.add_conversation('task_result', {
'task_content': task_content,
'task_result': final_result,
'working_directory': get_working_directory(options, task_lock)
})
yield sse_json("end", final_result)
if workforce is not None:
workforce.stop_gracefully()
break
logger.info(f"Workforce stopped gracefully for project {options.project_id}")
workforce = None
else:
logger.warning(f"Workforce already None at end action for project {options.project_id}")
camel_task = None
if question_agent is not None:
question_agent.reset()
logger.info(f"Reset question_agent for project {options.project_id}")
elif item.action == Action.supplement:
assert camel_task is not None
task_lock.status = Status.processing
camel_task.add_subtask(
Task(
content=item.data.question,
id=f"{camel_task.id}.{len(camel_task.subtasks)}",
# Check if this might be a misrouted second question
if camel_task is None:
logger.warning(f"SUPPLEMENT action received but camel_task is None for project {options.project_id}")
else:
assert camel_task is not None
task_lock.status = Status.processing
camel_task.add_subtask(
Task(
content=item.data.question,
id=f"{camel_task.id}.{len(camel_task.subtasks)}",
)
)
)
task = asyncio.create_task(workforce.eigent_start(camel_task.subtasks))
task_lock.add_background_task(task)
if workforce is not None:
task = asyncio.create_task(workforce.eigent_start(camel_task.subtasks))
task_lock.add_background_task(task)
elif item.action == Action.budget_not_enough:
if workforce is not None:
workforce.pause()
@ -204,32 +714,43 @@ async def step_solve(options: Chat, request: Request, task_lock: TaskLock):
if workforce._running:
workforce.stop()
workforce.stop_gracefully()
logger.info(f"Workforce stopped for project {options.project_id}")
else:
logger.warning(f"Workforce is None at stop action for project {options.project_id}")
await delete_task_lock(task_lock.id)
break
else:
logger.warning(f"Unknown action: {item.action}")
except ModelProcessingError as e:
if "Budget has been exceeded" in str(e):
logger.warning(f"Budget exceeded for task {options.task_id}, action: {item.action}")
# workforce decompose task don't use ListenAgent, this need return sse
if "workforce" in locals() and workforce is not None:
workforce.pause()
yield sse_json(Action.budget_not_enough, {"message": "budget not enouth"})
else:
logger.error(f"Error processing action {item.action}: {e}")
logger.error(f"ModelProcessingError for task {options.task_id}, action {item.action}: {e}", exc_info=True)
yield sse_json("error", {"message": str(e)})
if "workforce" in locals() and workforce is not None and workforce._running:
workforce.stop()
except Exception as e:
logger.error(f"Error processing action {item.action}: {e}")
logger.error(f"Unhandled exception for task {options.task_id}, action {item.action}: {e}", exc_info=True)
yield sse_json("error", {"message": str(e)})
# Continue processing other items instead of breaking
@traceroot.trace()
async def install_mcp(
mcp: ListenChatAgent,
install_mcp: ActionInstallMcpData,
):
mcp.add_tools(await get_mcp_tools(install_mcp.data))
logger.info(f"Installing MCP tools: {list(install_mcp.data.get('mcpServers', {}).keys())}")
try:
mcp.add_tools(await get_mcp_tools(install_mcp.data))
logger.info("MCP tools installed successfully")
except Exception as e:
logger.error(f"Error installing MCP tools: {e}", exc_info=True)
raise
def to_sub_tasks(task: Task, summary_task_content: str):
@ -287,30 +808,53 @@ def add_sub_tasks(camel_task: Task, update_tasks: list[TaskContent]):
)
async def question_confirm(agent: ListenChatAgent, prompt: str) -> str | Literal[True]:
prompt = f"""
> **Your Role:** You are a highly capable agent. Your primary function is to analyze a user's request and determine the appropriate course of action.
>
> **Your Process:**
>
> 1. **Analyze the User's Query:** Carefully examine the user's request: `{prompt}`.
>
> 2. **Categorize the Query:**
> * **Simple Query:** Is this a simple greeting, a question that can be answered directly, or a conversational interaction (e.g., "hello", "thank you")?
> * **Complex Task:** Is this a request that requires a series of steps, code execution, or interaction with tools to complete?
>
> 3. **Execute Your Decision:**
> * **For a Simple Query:** Provide a direct and helpful response.
> * **For a Complex Task:** Your *only* response should be "yes". This will trigger a specialized workforce to handle the task. Do not include any other text, punctuation, or pleasantries.
"""
resp = agent.step(prompt)
logger.info(f"resp: {agent.chat_history}")
if resp.msgs[0].content.lower() != "yes":
return sse_json("wait_confirm", {"content": resp.msgs[0].content})
else:
async def question_confirm(agent: ListenChatAgent, prompt: str, task_lock: TaskLock | None = None) -> bool:
"""Simple question confirmation - returns True for complex tasks, False for simple questions."""
context_prompt = ""
if task_lock:
context_prompt = build_conversation_context(task_lock, header="=== Previous Conversation ===")
full_prompt = f"""{context_prompt}User Query: {prompt}
Determine if this user query is a complex task or a simple question.
**Complex task** (answer "yes"): Requires tools, code execution, file operations, multi-step planning, or creating/modifying content
- Examples: "create a file", "search for X", "implement feature Y", "write code", "analyze data", "build something"
**Simple question** (answer "no"): Can be answered directly with knowledge or conversation history, no action needed
- Examples: greetings ("hello", "hi"), fact queries ("what is X?"), clarifications ("what did you mean?"), status checks ("how are you?")
Answer only "yes" or "no". Do not provide any explanation.
Is this a complex task? (yes/no):"""
try:
resp = agent.step(full_prompt)
if not resp or not resp.msgs or len(resp.msgs) == 0:
logger.warning("No response from agent, defaulting to complex task")
return True
content = resp.msgs[0].content
if not content:
logger.warning("Empty content from agent, defaulting to complex task")
return True
normalized = content.strip().lower()
is_complex = "yes" in normalized
logger.info(f"Question confirm result: {'complex task' if is_complex else 'simple question'}",
extra={"response": content, "is_complex": is_complex})
return is_complex
except Exception as e:
logger.error(f"Error in question_confirm: {e}")
return True
@traceroot.trace()
async def summary_task(agent: ListenChatAgent, task: Task) -> str:
prompt = f"""The user's task is:
---
@ -324,13 +868,100 @@ Your instructions are:
Example format: "Task Name|This is the summary of the task."
Do not include any other text or formatting.
"""
logger.debug("Generating task summary", extra={"task_id": task.id})
try:
res = agent.step(prompt)
summary = res.msgs[0].content
logger.info("Task summary generated", extra={"summary": summary})
return summary
except Exception as e:
logger.error("Error generating task summary", extra={"error": str(e)}, exc_info=True)
raise
async def summary_subtasks_result(agent: ListenChatAgent, task: Task) -> str:
"""
Summarize the aggregated results from all subtasks into a concise summary.
Args:
agent: The summary agent to use
task: The main task containing subtasks and their aggregated results
Returns:
A concise summary of all subtask results
"""
subtasks_info = ""
for i, subtask in enumerate(task.subtasks, 1):
subtasks_info += f"\n**Subtask {i}**\n"
subtasks_info += f"Description: {subtask.content}\n"
subtasks_info += f"Result: {subtask.result or 'No result'}\n"
subtasks_info += "---\n"
prompt = f"""You are a professional summarizer. Summarize the results of the following subtasks.
Main Task: {task.content}
Subtasks (with descriptions and results):
---
{subtasks_info}
---
Instructions:
1. Provide a concise summary of what was accomplished
2. Highlight key findings or outputs from each subtask
3. Mention any important files created or actions taken
4. Use bullet points or sections for clarity
5. DO NOT repeat the task name in your summary - go straight to the results
6. Keep it professional but conversational
Summary:
"""
res = agent.step(prompt)
logger.info(f"summary_task: {res.msgs[0].content}")
return res.msgs[0].content
summary = res.msgs[0].content
logger.info(f"Generated subtasks summary for task {task.id} with {len(task.subtasks)} subtasks")
return summary
async def get_task_result_with_optional_summary(task: Task, options: Chat) -> str:
"""
Get the task result, with LLM summary if there are multiple subtasks.
Args:
task: The task to get result from
options: Chat options for creating summary agent
Returns:
The task result (summarized if multiple subtasks, raw otherwise)
"""
result = str(task.result or "")
if task.subtasks and len(task.subtasks) > 1:
logger.info(f"Task {task.id} has {len(task.subtasks)} subtasks, generating summary")
try:
summary_agent = task_summary_agent(options)
summarized_result = await summary_subtasks_result(summary_agent, task)
result = summarized_result
logger.info(f"Successfully generated summary for task {task.id}")
except Exception as e:
logger.error(f"Failed to generate summary for task {task.id}: {e}")
elif task.subtasks and len(task.subtasks) == 1:
logger.info(f"Task {task.id} has only 1 subtask, skipping LLM summary")
if result and "--- Subtask" in result and "Result ---" in result:
parts = result.split("Result ---", 1)
if len(parts) > 1:
result = parts[1].strip()
return result
@traceroot.trace()
async def construct_workforce(options: Chat) -> tuple[Workforce, ListenChatAgent]:
working_directory = options.file_save_path()
logger.info("Constructing workforce", extra={"project_id": options.project_id, "task_id": options.task_id})
working_directory = get_working_directory(options)
logger.debug("Working directory set", extra={"working_directory": working_directory})
[coordinator_agent, task_agent] = [
agent_model(
key,
@ -339,8 +970,8 @@ async def construct_workforce(options: Chat) -> tuple[Workforce, ListenChatAgent
[
*(
ToolkitMessageIntegration(
message_handler=HumanToolkit(options.task_id, key).send_message_to_user
).register_toolkits(NoteTakingToolkit(options.task_id, working_directory=working_directory))
message_handler=HumanToolkit(options.project_id, key).send_message_to_user
).register_toolkits(NoteTakingToolkit(options.project_id, working_directory=working_directory))
).get_tools()
],
)
@ -373,11 +1004,11 @@ The current date is {datetime.date.today()}. For any date-related tasks, you MUS
""",
options,
[
*HumanToolkit.get_can_use_tools(options.task_id, Agents.new_worker_agent),
*HumanToolkit.get_can_use_tools(options.project_id, Agents.new_worker_agent),
*(
ToolkitMessageIntegration(
message_handler=HumanToolkit(options.task_id, Agents.new_worker_agent).send_message_to_user
).register_toolkits(NoteTakingToolkit(options.task_id, working_directory=working_directory))
message_handler=HumanToolkit(options.project_id, Agents.new_worker_agent).send_message_to_user
).register_toolkits(NoteTakingToolkit(options.project_id, working_directory=working_directory))
).get_tools(),
],
)
@ -402,7 +1033,7 @@ The current date is {datetime.date.today()}. For any date-related tasks, you MUS
model_platform_enum = None
workforce = Workforce(
options.task_id,
options.project_id,
"A workforce",
graceful_shutdown_timeout=3, # 30 seconds for debugging
share_memory=False,
@ -481,10 +1112,13 @@ def format_agent_description(agent_data: NewAgent | ActionNewAgent) -> str:
return " ".join(description_parts)
@traceroot.trace()
async def new_agent_model(data: NewAgent | ActionNewAgent, options: Chat):
working_directory = options.file_save_path()
logger.info("Creating new agent", extra={"agent_name": data.name, "project_id": options.project_id, "task_id": options.task_id})
logger.debug("New agent data", extra={"agent_data": data.model_dump_json()})
working_directory = get_working_directory(options)
tool_names = []
tools = [*await get_toolkits(data.tools, data.name, options.task_id)]
tools = [*await get_toolkits(data.tools, data.name, options.project_id)]
for item in data.tools:
tool_names.append(titleize(item))
if data.mcp_tools is not None:
@ -492,7 +1126,8 @@ async def new_agent_model(data: NewAgent | ActionNewAgent, options: Chat):
for item in data.mcp_tools["mcpServers"].keys():
tool_names.append(titleize(item))
for item in tools:
logger.debug(f"new agent function tool ====== {item.func.__name__}")
logger.debug(f"Agent {data.name} tool: {item.func.__name__}")
logger.info(f"Agent {data.name} created with {len(tools)} tools: {tool_names}")
# Enhanced system message with platform information
enhanced_description = f"""{data.description}
- You are now working in system {platform.system()} with architecture

View file

@ -1,4 +1,5 @@
from typing_extensions import Any, Literal, TypedDict
from typing import List, Dict, Optional
from pydantic import BaseModel
from app.exception.exception import ProgramException
from app.model.chat import McpServers, Status, SupplementChat, Chat, UpdateData
@ -9,13 +10,16 @@ from contextlib import contextmanager
from contextvars import ContextVar
from datetime import datetime, timedelta
import weakref
from loguru import logger
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("task_service")
class Action(str, Enum):
improve = "improve" # user -> backend
update_task = "update_task" # user -> backend
task_state = "task_state" # backend -> user
new_task_state = "new_task_state" # backend -> user
start = "start" # user -> backend
create_agent = "create_agent" # backend -> user
activate_agent = "activate_agent" # backend -> user
@ -36,6 +40,9 @@ class Action(str, Enum):
resume = "resume" # user -> backend user take control
new_agent = "new_agent" # user -> backend
budget_not_enough = "budget_not_enough" # backend -> user
add_task = "add_task" # user -> backend
remove_task = "remove_task" # user -> backend
skip_task = "skip_task" # user -> backend
class ActionImproveData(BaseModel):
@ -56,6 +63,10 @@ class ActionTaskStateData(BaseModel):
action: Literal[Action.task_state] = Action.task_state
data: dict[Literal["task_id", "content", "state", "result", "failure_count"], str | int]
class ActionNewTaskStateData(BaseModel):
action: Literal[Action.new_task_state] = Action.new_task_state
data: dict[Literal["task_id", "content", "state", "result", "failure_count"], str | int]
class ActionAskData(BaseModel):
action: Literal[Action.ask] = Action.ask
@ -169,6 +180,26 @@ class ActionBudgetNotEnough(BaseModel):
action: Literal[Action.budget_not_enough] = Action.budget_not_enough
class ActionAddTaskData(BaseModel):
action: Literal[Action.add_task] = Action.add_task
content: str
project_id: str | None = None
task_id: str | None = None
additional_info: dict | None = None
insert_position: int = -1
class ActionRemoveTaskData(BaseModel):
action: Literal[Action.remove_task] = Action.remove_task
task_id: str
project_id: str
class ActionSkipTaskData(BaseModel):
action: Literal[Action.skip_task] = Action.skip_task
project_id: str
ActionData = (
ActionImproveData
| ActionStartData
@ -192,6 +223,9 @@ ActionData = (
| ActionTakeControl
| ActionNewAgent
| ActionBudgetNotEnough
| ActionAddTaskData
| ActionRemoveTaskData
| ActionSkipTaskData
)
@ -221,6 +255,16 @@ class TaskLock:
background_tasks: set[asyncio.Task]
"""Track all background tasks for cleanup"""
# Context management fields
conversation_history: List[Dict[str, Any]]
"""Store conversation history for context"""
last_task_result: str
"""Store the last task execution result"""
question_agent: Optional[Any]
"""Persistent question confirmation agent"""
summary_generated: bool
"""Track if summary has been generated for this project"""
def __init__(self, id: str, queue: asyncio.Queue, human_input: dict) -> None:
self.id = id
self.queue = queue
@ -229,6 +273,12 @@ class TaskLock:
self.last_accessed = datetime.now()
self.background_tasks = set()
# Initialize context management fields
self.conversation_history = []
self.last_task_result = ""
self.last_task_summary = ""
self.question_agent = None
async def put_queue(self, data: ActionData):
self.last_accessed = datetime.now()
await self.queue.put(data)
@ -262,6 +312,25 @@ class TaskLock:
pass
self.background_tasks.clear()
def add_conversation(self, role: str, content: str | dict):
"""Add a conversation entry to history"""
self.conversation_history.append({
'role': role,
'content': content,
'timestamp': datetime.now().isoformat()
})
def get_recent_context(self, max_entries: int = None) -> str:
"""Get recent conversation context as a formatted string"""
if not self.conversation_history:
return ""
context = "=== Recent Conversation ===\n"
history_to_use = self.conversation_history if max_entries is None else self.conversation_history[-max_entries:]
for entry in history_to_use:
context += f"{entry['role']}: {entry['content']}\n"
return context
task_locks = dict[str, TaskLock]()
# Cleanup task for removing stale task locks
@ -275,6 +344,11 @@ def get_task_lock(id: str) -> TaskLock:
return task_locks[id]
def get_task_lock_if_exists(id: str) -> TaskLock | None:
"""Get task lock if it exists, otherwise return None"""
return task_locks.get(id)
def create_task_lock(id: str) -> TaskLock:
if id in task_locks:
raise ProgramException("Task already exists")
@ -288,6 +362,13 @@ def create_task_lock(id: str) -> TaskLock:
return task_locks[id]
def get_or_create_task_lock(id: str) -> TaskLock:
"""Get existing task lock or create a new one if it doesn't exist"""
if id in task_locks:
return task_locks[id]
return create_task_lock(id)
async def delete_task_lock(id: str):
if id not in task_locks:
raise ProgramException("Task not found")

View file

View file

@ -6,7 +6,7 @@ from threading import Event
import traceback
from typing import Any, Callable, Dict, List, Tuple
import uuid
from app.utils import traceroot_wrapper as traceroot
from utils import traceroot_wrapper as traceroot
from camel.agents import ChatAgent
from camel.agents.chat_agent import StreamingChatAgentResponse, AsyncStreamingChatAgentResponse
from camel.agents._types import ToolCallRequest
@ -18,6 +18,7 @@ from camel.terminators import ResponseTerminator
from camel.toolkits import FunctionTool, RegisteredAgentToolkit
from camel.types.agents import ToolCallingRecord
from app.component.environment import env
from app.utils.file_utils import get_working_directory
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
from app.utils.toolkit.hybrid_browser_toolkit import HybridBrowserToolkit
from app.utils.toolkit.excel_toolkit import ExcelToolkit
@ -50,7 +51,6 @@ from camel.types import ModelPlatformType, ModelType
from camel.toolkits import MCPToolkit, ToolkitMessageIntegration
import datetime
from pydantic import BaseModel
from loguru import logger
from app.model.chat import Chat, McpServers
# Create traceroot logger for agent tracking
@ -68,6 +68,8 @@ from app.service.task import (
)
from app.service.task import set_process_task
NOW_STR = datetime.datetime.now().strftime("%Y-%m-%d %H:00:00")
class ListenChatAgent(ChatAgent):
@traceroot.trace()
@ -171,7 +173,6 @@ class ListenChatAgent(ChatAgent):
except Exception as e:
res = None
error_info = e
logger.exception(e)
traceroot_logger.error(f"Agent {self.agent_name} unexpected error in step: {e}", exc_info=True)
message = f"Error processing message: {e!s}"
total_tokens = 0
@ -246,8 +247,7 @@ class ListenChatAgent(ChatAgent):
except Exception as e:
res = None
error_info = e
logger.exception(e)
traceroot_logger.error(f"Agent {self.agent_name} unexpected error in step: {e}", exc_info=True)
traceroot_logger.error(f"Agent {self.agent_name} unexpected error in async step: {e}", exc_info=True)
message = f"Error processing message: {e!s}"
total_tokens = 0
@ -323,6 +323,17 @@ class ListenChatAgent(ChatAgent):
else:
result = raw_result
mask_flag = False
# Prepare result message with truncation
if isinstance(result, str):
result_msg = result
else:
result_str = repr(result)
MAX_RESULT_LENGTH = 500
if len(result_str) > MAX_RESULT_LENGTH:
result_msg = result_str[:MAX_RESULT_LENGTH] + f"... (truncated, total length: {len(result_str)} chars)"
else:
result_msg = result_str
asyncio.create_task(
task_lock.put_queue(
ActionDeactivateToolkitData(
@ -331,7 +342,7 @@ class ListenChatAgent(ChatAgent):
"process_task_id": self.process_task_id,
"toolkit_name": toolkit_name,
"method_name": func_name,
"message": result if isinstance(result, str) else repr(result),
"message": result_msg,
},
)
)
@ -341,9 +352,7 @@ class ListenChatAgent(ChatAgent):
error_msg = f"Error executing tool '{func_name}': {e!s}"
result = f"Tool execution failed: {error_msg}"
mask_flag = False
logger.debug(error_msg)
traceroot_logger.error(f"Tool execution failed for {func_name}: {e}")
traceback.print_exc()
traceroot_logger.error(f"Tool execution failed for {func_name}: {e}", exc_info=True)
return self._record_tool_calling(func_name, args, result, tool_call_id, mask_output=mask_flag)
@ -403,9 +412,18 @@ class ListenChatAgent(ChatAgent):
# Capture the error message to prevent framework crash
error_msg = f"Error executing async tool '{func_name}': {e!s}"
result = {"error": error_msg}
logger.warning(error_msg)
traceroot_logger.error(f"Async tool execution failed for {func_name}: {e}")
traceback.print_exc()
traceroot_logger.error(f"Async tool execution failed for {func_name}: {e}", exc_info=True)
# Prepare result message with truncation
if isinstance(result, str):
result_msg = result
else:
result_str = repr(result)
MAX_RESULT_LENGTH = 500
if len(result_str) > MAX_RESULT_LENGTH:
result_msg = result_str[:MAX_RESULT_LENGTH] + f"... (truncated, total length: {len(result_str)} chars)"
else:
result_msg = result_str
await task_lock.put_queue(
ActionDeactivateToolkitData(
@ -414,7 +432,7 @@ class ListenChatAgent(ChatAgent):
"process_task_id": self.process_task_id,
"toolkit_name": toolkit_name,
"method_name": func_name,
"message": result if isinstance(result, str) else repr(result),
"message": result_msg,
},
)
)
@ -427,7 +445,7 @@ class ListenChatAgent(ChatAgent):
# Clone tools and collect toolkits that need registration
cloned_tools, toolkits_to_register = self._clone_tools()
new_agent = ListenChatAgent(
api_task_id=self.api_task_id,
agent_name=self.agent_name,
@ -443,7 +461,6 @@ class ListenChatAgent(ChatAgent):
response_terminators=self.response_terminators,
scheduling_strategy=self.model_backend.scheduling_strategy.__name__,
max_iteration=self.max_iteration,
agent_id=self.agent_id,
stop_event=self.stop_event,
tool_execution_timeout=self.tool_execution_timeout,
mask_tool_output=self.mask_tool_output,
@ -474,9 +491,9 @@ def agent_model(
tool_names: list[str] | None = None,
toolkits_to_register_agent: list[RegisteredAgentToolkit] | None = None,
):
task_lock = get_task_lock(options.task_id)
task_lock = get_task_lock(options.project_id)
agent_id = str(uuid.uuid4())
traceroot_logger.info(f"Creating agent: {agent_name} with id: {agent_id} for task: {options.task_id}")
traceroot_logger.info(f"Creating agent: {agent_name} with id: {agent_id} for project: {options.project_id}")
asyncio.create_task(
task_lock.put_queue(
ActionCreateAgentData(data={"agent_name": agent_name, "agent_id": agent_id, "tools": tool_names or []})
@ -484,7 +501,7 @@ def agent_model(
)
return ListenChatAgent(
options.task_id,
options.project_id,
agent_name,
system_message,
model=ModelFactory.create(
@ -493,7 +510,7 @@ def agent_model(
api_key=options.api_key,
url=options.api_url,
model_config_dict={
"user": str(options.task_id),
"user": str(options.project_id),
}
if options.is_cloud()
else None,
@ -515,7 +532,7 @@ def agent_model(
def question_confirm_agent(options: Chat):
return agent_model(
"question_confirm_agent",
f"You are a highly capable agent. Your primary function is to analyze a user's request and determine the appropriate course of action. The current date is {datetime.date.today()}. For any date-related tasks, you MUST use this as the current date.",
f"You are a highly capable agent. Your primary function is to analyze a user's request and determine the appropriate course of action. The current date is {NOW_STR}(Accurate to the hour). For any date-related tasks, you MUST use this as the current date.",
options,
)
@ -531,24 +548,24 @@ def task_summary_agent(options: Chat):
@traceroot.trace()
async def developer_agent(options: Chat):
working_directory = options.file_save_path()
traceroot_logger.info(f"Creating developer agent for task: {options.task_id} in directory: {working_directory}")
working_directory = get_working_directory(options)
traceroot_logger.info(f"Creating developer agent for project: {options.project_id} in directory: {working_directory}")
message_integration = ToolkitMessageIntegration(
message_handler=HumanToolkit(options.task_id, Agents.developer_agent).send_message_to_user
message_handler=HumanToolkit(options.project_id, Agents.developer_agent).send_message_to_user
)
note_toolkit = NoteTakingToolkit(
api_task_id=options.task_id, agent_name=Agents.developer_agent, working_directory=working_directory
api_task_id=options.project_id, agent_name=Agents.developer_agent, working_directory=working_directory
)
note_toolkit = message_integration.register_toolkits(note_toolkit)
web_deploy_toolkit = WebDeployToolkit(api_task_id=options.task_id)
web_deploy_toolkit = WebDeployToolkit(api_task_id=options.project_id)
web_deploy_toolkit = message_integration.register_toolkits(web_deploy_toolkit)
screenshot_toolkit = ScreenshotToolkit(options.task_id, working_directory=working_directory)
screenshot_toolkit = ScreenshotToolkit(options.project_id, working_directory=working_directory)
screenshot_toolkit = message_integration.register_toolkits(screenshot_toolkit)
terminal_toolkit = TerminalToolkit(options.task_id, Agents.document_agent, safe_mode=True, clone_current_env=False)
terminal_toolkit = TerminalToolkit(options.project_id, Agents.document_agent, safe_mode=True, clone_current_env=False)
terminal_toolkit = message_integration.register_toolkits(terminal_toolkit)
tools = [
*HumanToolkit.get_can_use_tools(options.task_id, Agents.developer_agent),
*HumanToolkit.get_can_use_tools(options.project_id, Agents.developer_agent),
*note_toolkit.get_tools(),
*web_deploy_toolkit.get_tools(),
*terminal_toolkit.get_tools(),
@ -577,7 +594,7 @@ and generation.
- **System**: {platform.system()} ({platform.machine()})
- **Working Directory**: `{working_directory}`. All local file operations must
occur here, but you can access files from any place in the file system. For all file system operations, you MUST use absolute paths to ensure precision and avoid ambiguity.
The current date is {datetime.date.today()}. For any date-related tasks, you MUST use this as the current date.
The current date is {NOW_STR}(Accurate to the hour). For any date-related tasks, you MUST use this as the current date.
</operating_environment>
<mandatory_instructions>
@ -702,14 +719,14 @@ these tips to maximize your effectiveness:
@traceroot.trace()
def search_agent(options: Chat):
working_directory = options.file_save_path()
traceroot_logger.info(f"Creating search agent for task: {options.task_id} in directory: {working_directory}")
working_directory = get_working_directory(options)
traceroot_logger.info(f"Creating search agent for project: {options.project_id} in directory: {working_directory}")
message_integration = ToolkitMessageIntegration(
message_handler=HumanToolkit(options.task_id, Agents.search_agent).send_message_to_user
message_handler=HumanToolkit(options.project_id, Agents.search_agent).send_message_to_user
)
web_toolkit_custom = HybridBrowserToolkit(
options.task_id,
options.project_id,
headless=False,
browser_log_to_file=True,
stealth=True,
@ -729,12 +746,14 @@ def search_agent(options: Chat):
],
)
# Save reference before registering for toolkits_to_register_agent
web_toolkit_for_agent_registration = web_toolkit_custom
web_toolkit_custom = message_integration.register_toolkits(web_toolkit_custom)
terminal_toolkit = TerminalToolkit(options.task_id, Agents.search_agent, safe_mode=True, clone_current_env=False)
terminal_toolkit = TerminalToolkit(options.project_id, Agents.search_agent, safe_mode=True, clone_current_env=False)
terminal_toolkit = message_integration.register_functions([terminal_toolkit.shell_exec])
note_toolkit = NoteTakingToolkit(options.task_id, Agents.search_agent, working_directory=working_directory)
note_toolkit = NoteTakingToolkit(options.project_id, Agents.search_agent, working_directory=working_directory)
note_toolkit = message_integration.register_toolkits(note_toolkit)
search_tools = SearchToolkit.get_can_use_tools(options.task_id)
search_tools = SearchToolkit.get_can_use_tools(options.project_id)
# Only register search tools if any are available
if search_tools:
search_tools = message_integration.register_functions(search_tools)
@ -742,7 +761,7 @@ def search_agent(options: Chat):
search_tools = []
tools = [
*HumanToolkit.get_can_use_tools(options.task_id, Agents.search_agent),
*HumanToolkit.get_can_use_tools(options.project_id, Agents.search_agent),
*web_toolkit_custom.get_tools(),
*terminal_toolkit,
*note_toolkit.get_tools(),
@ -772,7 +791,7 @@ comprehensive and well-documented information.
- **System**: {platform.system()} ({platform.machine()})
- **Working Directory**: `{working_directory}`. All local file operations must
occur here, but you can access files from any place in the file system. For all file system operations, you MUST use absolute paths to ensure precision and avoid ambiguity.
The current date is {datetime.date.today()}. For any date-related tasks, you MUST use this as the current date.
The current date is {NOW_STR}(Accurate to the hour). For any date-related tasks, you MUST use this as the current date.
</operating_environment>
<mandatory_instructions>
@ -793,7 +812,7 @@ The current date is {datetime.date.today()}. For any date-related tasks, you MUS
- **CRITICAL URL POLICY**: You are STRICTLY FORBIDDEN from inventing,
guessing, or constructing URLs yourself. You MUST only use URLs from
trusted sources:
1. URLs returned by search tools (like `search_google` or `search_exa`)
1. URLs returned by search tools (`search_google`)
2. URLs found on webpages you have visited through browser tools
3. URLs provided by the user in their request
Fabricating or guessing URLs is considered a critical error and must
@ -839,8 +858,6 @@ Your approach depends on available search tools:
sites using `browser_type` and submit with `browser_enter`
- **Extract URLs from results**: Only use URLs that appear in the search
results on these websites
- **Alternative Search**: If available, use `search_exa` for additional
results
**Common Browser Operations (both scenarios):**
- **Navigation and Exploration**: Use `browser_visit_page` to open URLs.
@ -877,41 +894,42 @@ Your approach depends on available search tools:
NoteTakingToolkit.toolkit_name(),
TerminalToolkit.toolkit_name(),
],
toolkits_to_register_agent=[web_toolkit_for_agent_registration],
)
@traceroot.trace()
async def document_agent(options: Chat):
working_directory = options.file_save_path()
traceroot_logger.info(f"Creating document agent for task: {options.task_id} in directory: {working_directory}")
working_directory = get_working_directory(options)
traceroot_logger.info(f"Creating document agent for project: {options.project_id} in directory: {working_directory}")
message_integration = ToolkitMessageIntegration(
message_handler=HumanToolkit(options.task_id, Agents.task_agent).send_message_to_user
message_handler=HumanToolkit(options.project_id, Agents.task_agent).send_message_to_user
)
file_write_toolkit = FileToolkit(options.task_id, working_directory=working_directory)
pptx_toolkit = PPTXToolkit(options.task_id, working_directory=working_directory)
file_write_toolkit = FileToolkit(options.project_id, working_directory=working_directory)
pptx_toolkit = PPTXToolkit(options.project_id, working_directory=working_directory)
pptx_toolkit = message_integration.register_toolkits(pptx_toolkit)
mark_it_down_toolkit = MarkItDownToolkit(options.task_id)
mark_it_down_toolkit = MarkItDownToolkit(options.project_id)
mark_it_down_toolkit = message_integration.register_toolkits(mark_it_down_toolkit)
excel_toolkit = ExcelToolkit(options.task_id, working_directory=working_directory)
excel_toolkit = ExcelToolkit(options.project_id, working_directory=working_directory)
excel_toolkit = message_integration.register_toolkits(excel_toolkit)
note_toolkit = NoteTakingToolkit(options.task_id, Agents.document_agent, working_directory=working_directory)
note_toolkit = NoteTakingToolkit(options.project_id, Agents.document_agent, working_directory=working_directory)
note_toolkit = message_integration.register_toolkits(note_toolkit)
terminal_toolkit = TerminalToolkit(options.task_id, Agents.document_agent, safe_mode=True, clone_current_env=False)
terminal_toolkit = TerminalToolkit(options.project_id, Agents.document_agent, safe_mode=True, clone_current_env=False)
terminal_toolkit = message_integration.register_toolkits(terminal_toolkit)
tools = [
*file_write_toolkit.get_tools(),
*pptx_toolkit.get_tools(),
*HumanToolkit.get_can_use_tools(options.task_id, Agents.document_agent),
*HumanToolkit.get_can_use_tools(options.project_id, Agents.document_agent),
*mark_it_down_toolkit.get_tools(),
*excel_toolkit.get_tools(),
*note_toolkit.get_tools(),
*terminal_toolkit.get_tools(),
*await GoogleDriveMCPToolkit.get_can_use_tools(options.task_id, options.get_bun_env()),
*await GoogleDriveMCPToolkit.get_can_use_tools(options.project_id, options.get_bun_env()),
]
if env("EXA_API_KEY") or options.is_cloud():
search_toolkit = SearchToolkit(options.task_id, Agents.document_agent).search_exa
search_toolkit = message_integration.register_functions([search_toolkit])
tools.extend(search_toolkit)
# if env("EXA_API_KEY") or options.is_cloud():
# search_toolkit = SearchToolkit(options.project_id, Agents.document_agent).search_exa
# search_toolkit = message_integration.register_functions([search_toolkit])
# tools.extend(search_toolkit)
system_message = f"""
<role>
You are a Documentation Specialist, responsible for creating, modifying, and
@ -935,7 +953,7 @@ to be embedded in your work.
- **System**: {platform.system()} ({platform.machine()})
- **Working Directory**: `{working_directory}`. All local file operations must
occur here, but you can access files from any place in the file system. For all file system operations, you MUST use absolute paths to ensure precision and avoid ambiguity.
The current date is {datetime.date.today()}. For any date-related tasks, you MUST use this as the current date.
The current date is {NOW_STR}(Accurate to the hour). For any date-related tasks, you MUST use this as the current date.
</operating_environment>
<mandatory_instructions>
@ -1083,32 +1101,32 @@ supported formats including advanced spreadsheet functionality.
@traceroot.trace()
def multi_modal_agent(options: Chat):
working_directory = options.file_save_path()
traceroot_logger.info(f"Creating multi-modal agent for task: {options.task_id} in directory: {working_directory}")
working_directory = get_working_directory(options)
traceroot_logger.info(f"Creating multi-modal agent for project: {options.project_id} in directory: {working_directory}")
message_integration = ToolkitMessageIntegration(
message_handler=HumanToolkit(options.task_id, Agents.multi_modal_agent).send_message_to_user
message_handler=HumanToolkit(options.project_id, Agents.multi_modal_agent).send_message_to_user
)
video_download_toolkit = VideoDownloaderToolkit(options.task_id, working_directory=working_directory)
video_download_toolkit = VideoDownloaderToolkit(options.project_id, working_directory=working_directory)
video_download_toolkit = message_integration.register_toolkits(video_download_toolkit)
image_analysis_toolkit = ImageAnalysisToolkit(options.task_id)
image_analysis_toolkit = ImageAnalysisToolkit(options.project_id)
image_analysis_toolkit = message_integration.register_toolkits(image_analysis_toolkit)
terminal_toolkit = TerminalToolkit(
options.task_id, agent_name=Agents.multi_modal_agent, safe_mode=True, clone_current_env=False
options.project_id, agent_name=Agents.multi_modal_agent, safe_mode=True, clone_current_env=False
)
terminal_toolkit = message_integration.register_toolkits(terminal_toolkit)
note_toolkit = NoteTakingToolkit(options.task_id, Agents.multi_modal_agent, working_directory=working_directory)
note_toolkit = NoteTakingToolkit(options.project_id, Agents.multi_modal_agent, working_directory=working_directory)
note_toolkit = message_integration.register_toolkits(note_toolkit)
tools = [
*video_download_toolkit.get_tools(),
*image_analysis_toolkit.get_tools(),
*HumanToolkit.get_can_use_tools(options.task_id, Agents.multi_modal_agent),
*HumanToolkit.get_can_use_tools(options.project_id, Agents.multi_modal_agent),
*terminal_toolkit.get_tools(),
*note_toolkit.get_tools(),
]
if options.is_cloud():
open_ai_image_toolkit = OpenAIImageToolkit( # todo check llm has this model
options.task_id,
options.project_id,
model="dall-e-3",
response_format="b64_json",
size="1024x1024",
@ -1130,7 +1148,7 @@ def multi_modal_agent(options: Chat):
if model_platform_enum == ModelPlatformType.OPENAI:
audio_analysis_toolkit = AudioAnalysisToolkit(
options.task_id,
options.project_id,
working_directory,
OpenAIAudioModels(
api_key=options.api_key,
@ -1140,10 +1158,10 @@ def multi_modal_agent(options: Chat):
audio_analysis_toolkit = message_integration.register_toolkits(audio_analysis_toolkit)
tools.extend(audio_analysis_toolkit.get_tools())
if env("EXA_API_KEY") or options.is_cloud():
search_toolkit = SearchToolkit(options.task_id, Agents.multi_modal_agent).search_exa
search_toolkit = message_integration.register_functions([search_toolkit])
tools.extend(search_toolkit)
# if env("EXA_API_KEY") or options.is_cloud():
# search_toolkit = SearchToolkit(options.project_id, Agents.multi_modal_agent).search_exa
# search_toolkit = message_integration.register_functions([search_toolkit])
# tools.extend(search_toolkit)
system_message = f"""
<role>
@ -1167,7 +1185,7 @@ presentations, and other documents.
- **System**: {platform.system()} ({platform.machine()})
- **Working Directory**: `{working_directory}`. All local file operations must
occur here, but you can access files from any place in the file system. For all file system operations, you MUST use absolute paths to ensure precision and avoid ambiguity.
The current date is {datetime.date.today()}. For any date-related tasks, you MUST use this as the current date.
The current date is {NOW_STR}(Accurate to the hour). For any date-related tasks, you MUST use this as the current date.
</operating_environment>
<mandatory_instructions>
@ -1253,27 +1271,27 @@ async def social_medium_agent(options: Chat):
Agent to handling tasks related to social media:
include toolkits: WhatsApp, Twitter, LinkedIn, Reddit, Notion, Slack, Discord and Google Suite.
"""
working_directory = options.file_save_path()
traceroot_logger.info(f"Creating social medium agent for task: {options.task_id} in directory: {working_directory}")
working_directory = get_working_directory(options)
traceroot_logger.info(f"Creating social medium agent for project: {options.project_id} in directory: {working_directory}")
tools = [
*WhatsAppToolkit.get_can_use_tools(options.task_id),
*TwitterToolkit.get_can_use_tools(options.task_id),
*LinkedInToolkit.get_can_use_tools(options.task_id),
*RedditToolkit.get_can_use_tools(options.task_id),
*await NotionMCPToolkit.get_can_use_tools(options.task_id),
# *SlackToolkit.get_can_use_tools(options.task_id),
*await GoogleGmailMCPToolkit.get_can_use_tools(options.task_id, options.get_bun_env()),
*GoogleCalendarToolkit.get_can_use_tools(options.task_id),
*HumanToolkit.get_can_use_tools(options.task_id, Agents.social_medium_agent),
*TerminalToolkit(options.task_id, agent_name=Agents.social_medium_agent, clone_current_env=False).get_tools(),
*WhatsAppToolkit.get_can_use_tools(options.project_id),
*TwitterToolkit.get_can_use_tools(options.project_id),
*LinkedInToolkit.get_can_use_tools(options.project_id),
*RedditToolkit.get_can_use_tools(options.project_id),
*await NotionMCPToolkit.get_can_use_tools(options.project_id),
# *SlackToolkit.get_can_use_tools(options.project_id),
*await GoogleGmailMCPToolkit.get_can_use_tools(options.project_id, options.get_bun_env()),
*GoogleCalendarToolkit.get_can_use_tools(options.project_id),
*HumanToolkit.get_can_use_tools(options.project_id, Agents.social_medium_agent),
*TerminalToolkit(options.project_id, agent_name=Agents.social_medium_agent, clone_current_env=False).get_tools(),
*NoteTakingToolkit(
options.task_id, Agents.social_medium_agent, working_directory=working_directory
options.project_id, Agents.social_medium_agent, working_directory=working_directory
).get_tools(),
# *DiscordToolkit(options.task_id).get_tools(), # Not supported temporarily
# *GoogleSuiteToolkit(options.task_id).get_tools(), # Not supported temporarily
# *DiscordToolkit(options.project_id).get_tools(), # Not supported temporarily
# *GoogleSuiteToolkit(options.project_id).get_tools(), # Not supported temporarily
]
if env("EXA_API_KEY") or options.is_cloud():
tools.append(FunctionTool(SearchToolkit(options.task_id, Agents.social_medium_agent).search_exa))
# if env("EXA_API_KEY") or options.is_cloud():
# tools.append(FunctionTool(SearchToolkit(options.project_id, Agents.social_medium_agent).search_exa))
return agent_model(
Agents.social_medium_agent,
BaseMessage.make_assistant_message(
@ -1290,7 +1308,7 @@ use plain text formatting instead.
- **Working Directory**: `{working_directory}`. All local file operations must
occur here, but you can access files from any place in the file system. For all file system operations, you MUST use absolute paths to ensure precision and avoid ambiguity.
The current date is {datetime.date.today()}. For any date-related tasks, you MUST use this as the current date.
The current date is {NOW_STR}(Accurate to the hour). For any date-related tasks, you MUST use this as the current date.
Your integrated toolkits enable you to:
@ -1369,16 +1387,16 @@ operations.
@traceroot.trace()
async def mcp_agent(options: Chat):
traceroot_logger.info(
f"Creating MCP agent for task: {options.task_id} with {len(options.installed_mcp['mcpServers'])} MCP servers"
f"Creating MCP agent for project: {options.project_id} with {len(options.installed_mcp['mcpServers'])} MCP servers"
)
tools = [
# *HumanToolkit.get_can_use_tools(options.task_id, Agents.mcp_agent),
*McpSearchToolkit(options.task_id).get_tools(),
# *HumanToolkit.get_can_use_tools(options.project_id, Agents.mcp_agent),
*McpSearchToolkit(options.project_id).get_tools(),
]
if len(options.installed_mcp["mcpServers"]) > 0:
try:
mcp_tools = await get_mcp_tools(options.installed_mcp)
traceroot_logger.info(f"Retrieved {len(mcp_tools)} MCP tools for task {options.task_id}")
traceroot_logger.info(f"Retrieved {len(mcp_tools)} MCP tools for task {options.project_id}")
if mcp_tools:
tool_names = [tool.get_function_name() if hasattr(tool, 'get_function_name') else str(tool) for tool in mcp_tools]
traceroot_logger.debug(f"MCP tools: {tool_names}")
@ -1386,9 +1404,9 @@ async def mcp_agent(options: Chat):
except Exception as e:
traceroot_logger.debug(repr(e))
task_lock = get_task_lock(options.task_id)
task_lock = get_task_lock(options.project_id)
agent_id = str(uuid.uuid4())
traceroot_logger.info(f"Creating MCP agent: {Agents.mcp_agent} with id: {agent_id} for task: {options.task_id}")
traceroot_logger.info(f"Creating MCP agent: {Agents.mcp_agent} with id: {agent_id} for task: {options.project_id}")
asyncio.create_task(
task_lock.put_queue(
ActionCreateAgentData(
@ -1401,7 +1419,7 @@ async def mcp_agent(options: Chat):
)
)
return ListenChatAgent(
options.task_id,
options.project_id,
Agents.mcp_agent,
system_message="You are a helpful assistant that can help users search mcp servers. The found mcp services will be returned to the user, and you will ask the user via ask_human_via_gui whether they want to install these mcp services.",
model=ModelFactory.create(
@ -1410,7 +1428,7 @@ async def mcp_agent(options: Chat):
api_key=options.api_key,
url=options.api_url,
model_config_dict={
"user": str(options.task_id),
"user": str(options.project_id),
}
if options.is_cloud()
else None,

View file

@ -0,0 +1,236 @@
import sqlite3
import os
from typing import List, Dict, Optional
from utils import traceroot_wrapper as traceroot
import shutil
from datetime import datetime
logger = traceroot.get_logger("cookie_manager")
class CookieManager:
"""Manager for reading and managing browser cookies
from Electron/Chrome SQLite database"""
def __init__(self, user_data_dir: str):
self.user_data_dir = user_data_dir
# Check for cookies in partition directory first (for persist:user_login)
partition_cookies_path = os.path.join(user_data_dir, "Partitions", "user_login", "Cookies")
if os.path.exists(partition_cookies_path):
self.cookies_db_path = partition_cookies_path
logger.info(f"Using partition cookies at: {partition_cookies_path}")
else:
# Fallback to default location
self.cookies_db_path = os.path.join(user_data_dir, "Cookies")
if not os.path.exists(self.cookies_db_path):
alt_path = os.path.join(user_data_dir, "Network", "Cookies")
if os.path.exists(alt_path):
self.cookies_db_path = alt_path
else:
logger.warning(f"Cookies database not found at {self.cookies_db_path} or {partition_cookies_path}")
def _get_cookies_connection(self) -> Optional[sqlite3.Connection]:
"""Get database connection using a temporary copy to avoid locks"""
if not os.path.exists(self.cookies_db_path):
logger.warning(f"Cookies database not found: {self.cookies_db_path}")
return None
try:
temp_db_path = self.cookies_db_path + ".tmp"
shutil.copy2(self.cookies_db_path, temp_db_path)
conn = sqlite3.connect(temp_db_path)
conn.row_factory = sqlite3.Row
return conn
except Exception as e:
logger.error(f"Error connecting to cookies database: {e}")
return None
def _cleanup_temp_db(self):
"""Clean up temporary database file"""
temp_db_path = self.cookies_db_path + ".tmp"
try:
if os.path.exists(temp_db_path):
os.remove(temp_db_path)
except Exception as e:
logger.debug(f"Error cleaning up temp database: {e}")
def get_cookie_domains(self) -> List[Dict[str, any]]:
"""Get list of all domains with cookies"""
conn = self._get_cookies_connection()
if not conn:
return []
try:
cursor = conn.cursor()
query = """
SELECT
host_key as domain,
COUNT(*) as cookie_count,
MAX(last_access_utc) as last_access
FROM cookies
GROUP BY host_key
ORDER BY last_access DESC
"""
cursor.execute(query)
rows = cursor.fetchall()
domains = []
for row in rows:
try:
chrome_timestamp = row['last_access']
if chrome_timestamp:
seconds_since_epoch = (chrome_timestamp / 1000000.0) - 11644473600
last_access = datetime.fromtimestamp(seconds_since_epoch).strftime('%Y-%m-%d %H:%M:%S')
else:
last_access = "Never"
except Exception as e:
logger.debug(f"Error converting timestamp: {e}")
last_access = "Unknown"
domains.append({
'domain': row['domain'],
'cookie_count': row['cookie_count'],
'last_access': last_access
})
logger.info(f"Found {len(domains)} domains with cookies")
return domains
except Exception as e:
logger.error(f"Error reading cookies: {e}")
return []
finally:
conn.close()
self._cleanup_temp_db()
def get_cookies_for_domain(self, domain: str) -> List[Dict[str, str]]:
"""Get all cookies for a specific domain"""
conn = self._get_cookies_connection()
if not conn:
return []
try:
cursor = conn.cursor()
query = """
SELECT
host_key,
name,
value,
path,
expires_utc,
is_secure,
is_httponly
FROM cookies
WHERE host_key = ? OR host_key LIKE ?
ORDER BY name
"""
cursor.execute(query, (domain, f'%.{domain}'))
rows = cursor.fetchall()
cookies = []
for row in rows:
cookies.append({
'domain': row['host_key'],
'name': row['name'],
'value': row['value'][:50] + '...' if len(row['value']) > 50 else row['value'],
'path': row['path'],
'secure': bool(row['is_secure']),
'httponly': bool(row['is_httponly'])
})
return cookies
except Exception as e:
logger.error(f"Error reading cookies for domain {domain}: {e}")
return []
finally:
conn.close()
self._cleanup_temp_db()
def delete_cookies_for_domain(self, domain: str) -> bool:
"""Delete all cookies for a specific domain"""
if not os.path.exists(self.cookies_db_path):
logger.warning(f"Cookies database not found: {self.cookies_db_path}")
return False
try:
conn = sqlite3.connect(self.cookies_db_path)
cursor = conn.cursor()
delete_query = """
DELETE FROM cookies
WHERE host_key = ? OR host_key LIKE ?
"""
cursor.execute(delete_query, (domain, f'%.{domain}'))
deleted_count = cursor.rowcount
conn.commit()
# IMPORTANT: Execute VACUUM to remove deleted data and compact database
# This prevents recovery from WAL files
cursor.execute("VACUUM")
conn.commit()
conn.close()
# Also remove WAL and SHM files to ensure clean state
self._cleanup_wal_files()
logger.info(f"Deleted {deleted_count} cookies for domain {domain}")
return True
except Exception as e:
logger.error(f"Error deleting cookies for domain {domain}: {e}")
return False
def _cleanup_wal_files(self):
"""Remove SQLite WAL and SHM files"""
try:
wal_path = self.cookies_db_path + '-wal'
shm_path = self.cookies_db_path + '-shm'
journal_path = self.cookies_db_path + '-journal'
for path in [wal_path, shm_path, journal_path]:
if os.path.exists(path):
os.remove(path)
logger.info(f"Removed temporary file: {path}")
except Exception as e:
logger.warning(f"Error cleaning up WAL files: {e}")
def delete_all_cookies(self) -> bool:
"""Delete all cookies"""
if not os.path.exists(self.cookies_db_path):
logger.warning(f"Cookies database not found: {self.cookies_db_path}")
return False
try:
conn = sqlite3.connect(self.cookies_db_path)
cursor = conn.cursor()
cursor.execute("DELETE FROM cookies")
deleted_count = cursor.rowcount
conn.commit()
# IMPORTANT: Execute VACUUM to remove deleted data and compact database
# This prevents recovery from WAL files
cursor.execute("VACUUM")
conn.commit()
conn.close()
# Also remove WAL and SHM files to ensure clean state
self._cleanup_wal_files()
logger.info(f"Deleted all {deleted_count} cookies")
return True
except Exception as e:
logger.error(f"Error deleting all cookies: {e}")
return False
def search_cookies(self, keyword: str) -> List[Dict[str, any]]:
"""Search cookies by domain keyword"""
domains = self.get_cookie_domains()
keyword_lower = keyword.lower()
return [
domain for domain in domains
if keyword_lower in domain['domain'].lower()
]

View file

@ -0,0 +1,20 @@
"""File system utilities."""
from app.component.environment import env
from app.model.chat import Chat
def get_working_directory(options: Chat, task_lock=None) -> str:
"""
Get the correct working directory for file operations.
First checks if there's an updated path from improve API call,
then falls back to environment variable or default path.
"""
if not task_lock:
from app.service.task import get_task_lock_if_exists
task_lock = get_task_lock_if_exists(options.project_id)
if task_lock and hasattr(task_lock, 'new_folder_path') and task_lock.new_folder_path:
return str(task_lock.new_folder_path)
else:
return env("file_save_path", options.file_save_path())

View file

@ -1,10 +1,11 @@
import asyncio
from functools import wraps
from inspect import iscoroutinefunction
from inspect import iscoroutinefunction, getmembers, ismethod, signature
import json
from typing import Any, Callable
from typing import Any, Callable, Type, TypeVar
import threading
from concurrent.futures import ThreadPoolExecutor
from loguru import logger
from app.service.task import (
ActionActivateToolkitData,
ActionDeactivateToolkitData,
@ -12,6 +13,41 @@ from app.service.task import (
)
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
from app.service.task import process_task
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("toolkit_listen")
def _safe_put_queue(task_lock, data):
"""Safely put data to the queue, handling both sync and async contexts"""
try:
# Try to get current event loop
loop = asyncio.get_running_loop()
# We're in an async context, create a task
task = asyncio.create_task(task_lock.put_queue(data))
if hasattr(task_lock, "add_background_task"):
task_lock.add_background_task(task)
except RuntimeError:
# No running event loop, we need to handle this differently
try:
# Create a new event loop in a separate thread to avoid conflicts
def run_in_thread():
try:
# Create a new event loop for this thread
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
new_loop.run_until_complete(task_lock.put_queue(data))
finally:
new_loop.close()
except Exception as e:
logger.error(f"[listen_toolkit] Failed to send data in thread: {e}")
# Run in a separate thread to avoid blocking
thread = threading.Thread(target=run_in_thread, daemon=True)
thread.start()
except Exception as e:
logger.error(f"[listen_toolkit] Failed to send data to queue: {e}")
def listen_toolkit(
@ -27,6 +63,11 @@ def listen_toolkit(
@wraps(wrap)
async def async_wrapper(*args, **kwargs):
toolkit: AbstractToolkit = args[0]
# Check if api_task_id exists
if not hasattr(toolkit, 'api_task_id'):
logger.warning(f"[listen_toolkit] {toolkit.__class__.__name__} missing api_task_id, calling method directly")
return await func(*args, **kwargs)
task_lock = get_task_lock(toolkit.api_task_id)
if inputs is not None:
@ -40,19 +81,23 @@ def listen_toolkit(
kwargs_str = ", ".join(f"{k}={v!r}" for k, v in kwargs.items())
args_str = f"{args_str}, {kwargs_str}" if args_str else kwargs_str
# Truncate args_str if too long
MAX_ARGS_LENGTH = 500
if len(args_str) > MAX_ARGS_LENGTH:
args_str = args_str[:MAX_ARGS_LENGTH] + f"... (truncated, total length: {len(args_str)} chars)"
toolkit_name = toolkit.toolkit_name()
method_name = func.__name__.replace("_", " ")
await task_lock.put_queue(
ActionActivateToolkitData(
data={
"agent_name": toolkit.agent_name,
"process_task_id": process_task.get(""),
"toolkit_name": toolkit_name,
"method_name": method_name,
"message": args_str,
},
)
activate_data = ActionActivateToolkitData(
data={
"agent_name": toolkit.agent_name,
"process_task_id": process_task.get(""),
"toolkit_name": toolkit_name,
"method_name": method_name,
"message": args_str,
},
)
await task_lock.put_queue(activate_data)
error = None
res = None
try:
@ -70,21 +115,26 @@ def listen_toolkit(
res_msg = json.dumps(res, ensure_ascii=False)
except TypeError:
# Handle cases where res contains non-serializable objects (like coroutines)
res_msg = str(res)
res_str = str(res)
# Truncate very long outputs to avoid flooding logs
MAX_LENGTH = 500
if len(res_str) > MAX_LENGTH:
res_msg = res_str[:MAX_LENGTH] + f"... (truncated, total length: {len(res_str)} chars)"
else:
res_msg = res_str
else:
res_msg = str(error)
await task_lock.put_queue(
ActionDeactivateToolkitData(
data={
"agent_name": toolkit.agent_name,
"process_task_id": process_task.get(""),
"toolkit_name": toolkit_name,
"method_name": method_name,
"message": res_msg,
},
)
deactivate_data = ActionDeactivateToolkitData(
data={
"agent_name": toolkit.agent_name,
"process_task_id": process_task.get(""),
"toolkit_name": toolkit_name,
"method_name": method_name,
"message": res_msg,
},
)
await task_lock.put_queue(deactivate_data)
if error is not None:
raise error
return res
@ -96,6 +146,11 @@ def listen_toolkit(
@wraps(wrap)
def sync_wrapper(*args, **kwargs):
toolkit: AbstractToolkit = args[0]
# Check if api_task_id exists
if not hasattr(toolkit, 'api_task_id'):
logger.warning(f"[listen_toolkit] {toolkit.__class__.__name__} missing api_task_id, calling method directly")
return func(*args, **kwargs)
task_lock = get_task_lock(toolkit.api_task_id)
if inputs is not None:
@ -109,34 +164,34 @@ def listen_toolkit(
kwargs_str = ", ".join(f"{k}={v!r}" for k, v in kwargs.items())
args_str = f"{args_str}, {kwargs_str}" if args_str else kwargs_str
# Truncate args_str if too long
MAX_ARGS_LENGTH = 500
if len(args_str) > MAX_ARGS_LENGTH:
args_str = args_str[:MAX_ARGS_LENGTH] + f"... (truncated, total length: {len(args_str)} chars)"
toolkit_name = toolkit.toolkit_name()
method_name = func.__name__.replace("_", " ")
task = asyncio.create_task(
task_lock.put_queue(
ActionActivateToolkitData(
data={
"agent_name": toolkit.agent_name,
"process_task_id": process_task.get(""),
"toolkit_name": toolkit_name,
"method_name": method_name,
"message": args_str,
},
)
)
activate_data = ActionActivateToolkitData(
data={
"agent_name": toolkit.agent_name,
"process_task_id": process_task.get(""),
"toolkit_name": toolkit_name,
"method_name": method_name,
"message": args_str,
},
)
if hasattr(task_lock, "add_background_task"):
task_lock.add_background_task(task)
_safe_put_queue(task_lock, activate_data)
error = None
res = None
try:
logger.debug(f"Executing toolkit method: {toolkit_name}.{method_name} for agent '{toolkit.agent_name}'")
res = func(*args, **kwargs)
# Safety check: if the result is a coroutine, we need to await it
# Safety check: if the result is a coroutine, this is a programming error
if asyncio.iscoroutine(res):
import warnings
warnings.warn(f"Async function {func.__name__} was incorrectly called synchronously")
res = asyncio.run(res)
error_msg = f"Async function {func.__name__} was incorrectly called in sync context. This is a bug - the function should be marked as async or should not return a coroutine."
logger.error(f"[listen_toolkit] {error_msg}")
# Cannot safely await in sync context - close the coroutine to prevent warnings
res.close()
raise TypeError(error_msg)
except Exception as e:
error = e
@ -150,25 +205,26 @@ def listen_toolkit(
res_msg = json.dumps(res, ensure_ascii=False)
except TypeError:
# Handle cases where res contains non-serializable objects (like coroutines)
res_msg = str(res)
res_str = str(res)
# Truncate very long outputs to avoid flooding logs
MAX_LENGTH = 500
if len(res_str) > MAX_LENGTH:
res_msg = res_str[:MAX_LENGTH] + f"... (truncated, total length: {len(res_str)} chars)"
else:
res_msg = res_str
else:
res_msg = str(error)
task = asyncio.create_task(
task_lock.put_queue(
ActionDeactivateToolkitData(
data={
"agent_name": toolkit.agent_name,
"process_task_id": process_task.get(""),
"toolkit_name": toolkit_name,
"method_name": method_name,
"message": res_msg,
},
)
)
deactivate_data = ActionDeactivateToolkitData(
data={
"agent_name": toolkit.agent_name,
"process_task_id": process_task.get(""),
"toolkit_name": toolkit_name,
"method_name": method_name,
"message": res_msg,
},
)
if hasattr(task_lock, "add_background_task"):
task_lock.add_background_task(task)
_safe_put_queue(task_lock, deactivate_data)
if error is not None:
raise error
return res
@ -176,3 +232,87 @@ def listen_toolkit(
return sync_wrapper
return decorator
T = TypeVar('T')
# Methods that should not be wrapped by auto_listen_toolkit
# These are utility/helper methods that don't perform actual tool operations
EXCLUDED_METHODS = {
'get_tools', # Tool enumeration
'get_can_use_tools', # Tool filtering
'toolkit_name', # Metadata getter
'run_mcp_server', # MCP server initialization
'model_dump', # Pydantic model serialization
'model_dump_json', # Pydantic model serialization
'dict', # Pydantic legacy dict method
'json', # Pydantic legacy json method
'copy', # Object copying
'update', # Object update
}
def auto_listen_toolkit(base_toolkit_class: Type[T]) -> Callable[[Type[T]], Type[T]]:
"""
Class decorator that automatically wraps all public methods from the base toolkit
with the @listen_toolkit decorator.
Excluded methods (not wrapped):
- get_tools, get_can_use_tools: Tool enumeration/filtering
- toolkit_name: Metadata getter
- run_mcp_server: MCP server initialization
- Pydantic serialization methods: model_dump, model_dump_json, dict, json
- Object utility methods: copy, update
These methods are typically called during initialization or for metadata,
and should not trigger activate/deactivate events.
Usage:
@auto_listen_toolkit(BaseNoteTakingToolkit)
class NoteTakingToolkit(BaseNoteTakingToolkit, AbstractToolkit):
agent_name: str = Agents.document_agent
"""
def class_decorator(cls: Type[T]) -> Type[T]:
base_methods = {}
for name in dir(base_toolkit_class):
# Skip private methods and excluded helper methods
if not name.startswith('_') and name not in EXCLUDED_METHODS:
attr = getattr(base_toolkit_class, name)
if callable(attr):
base_methods[name] = attr
for method_name, base_method in base_methods.items():
if method_name in cls.__dict__:
continue
sig = signature(base_method)
def create_wrapper(method_name: str, base_method: Callable) -> Callable:
# Unwrap decorators to check the actual function
unwrapped_method = base_method
while hasattr(unwrapped_method, '__wrapped__'):
unwrapped_method = unwrapped_method.__wrapped__
# Check if the unwrapped method is a coroutine function
if iscoroutinefunction(unwrapped_method):
async def async_method_wrapper(self, *args, **kwargs):
return await getattr(super(cls, self), method_name)(*args, **kwargs)
async_method_wrapper.__name__ = method_name
async_method_wrapper.__signature__ = sig
return async_method_wrapper
else:
def sync_method_wrapper(self, *args, **kwargs):
return getattr(super(cls, self), method_name)(*args, **kwargs)
sync_method_wrapper.__name__ = method_name
sync_method_wrapper.__signature__ = sig
return sync_method_wrapper
wrapper = create_wrapper(method_name, base_method)
decorated_method = listen_toolkit(base_method)(wrapper)
setattr(cls, method_name, decorated_method)
return cls
return class_decorator

View file

@ -0,0 +1,94 @@
"""
OAuth authorization state manager for background authorization flows
"""
import threading
from typing import Dict, Optional, Literal, Any
from datetime import datetime
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("main")
AuthStatus = Literal["pending", "authorizing", "success", "failed", "cancelled"]
class OAuthState:
"""Represents the state of an OAuth authorization flow"""
def __init__(self, provider: str):
self.provider = provider
self.status: AuthStatus = "pending"
self.error: Optional[str] = None
self.thread: Optional[threading.Thread] = None
self.result: Optional[Any] = None
self.started_at = datetime.now()
self.completed_at: Optional[datetime] = None
self._cancel_event = threading.Event()
self.server = None # Store the local server instance for forced shutdown
def is_cancelled(self) -> bool:
"""Check if cancellation has been requested"""
return self._cancel_event.is_set()
def cancel(self):
"""Request cancellation of the authorization flow"""
self._cancel_event.set()
self.status = "cancelled"
self.completed_at = datetime.now()
def to_dict(self) -> Dict:
"""Convert state to dictionary for API response"""
return {
"provider": self.provider,
"status": self.status,
"error": self.error,
"started_at": self.started_at.isoformat(),
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
}
class OAuthStateManager:
"""Manager for tracking OAuth authorization flows"""
def __init__(self):
self._states: Dict[str, OAuthState] = {}
self._lock = threading.Lock()
def create_state(self, provider: str) -> OAuthState:
"""Create a new OAuth state for a provider"""
with self._lock:
# Cancel any existing authorization for this provider
if provider in self._states:
old_state = self._states[provider]
if old_state.status in ["pending", "authorizing"]:
old_state.cancel()
logger.info(f"Cancelled previous {provider} authorization")
state = OAuthState(provider)
self._states[provider] = state
return state
def get_state(self, provider: str) -> Optional[OAuthState]:
"""Get the current state for a provider"""
with self._lock:
return self._states.get(provider)
def update_status(
self,
provider: str,
status: AuthStatus,
error: Optional[str] = None,
result: Optional[Any] = None
):
"""Update the status of an authorization flow"""
with self._lock:
if provider in self._states:
state = self._states[provider]
state.status = status
state.error = error
state.result = result
if status in ["success", "failed", "cancelled"]:
state.completed_at = datetime.now()
logger.info(f"Updated {provider} OAuth status to {status}")
# Global instance
oauth_state_manager = OAuthStateManager()

View file

@ -3,9 +3,11 @@ import httpx
import asyncio
import os
import json
from loguru import logger
from app.service.chat_service import Chat
from app.component.environment import env
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("sync_step")
def sync_step(func):
@ -28,7 +30,9 @@ def sync_step(func):
send_to_api(
sync_url,
{
"task_id": chat.task_id,
# TODO: revert to task_id to support multi-task project replay
# "task_id": chat.task_id,
"task_id": chat.project_id,
"step": json_data["step"],
"data": json_data["data"],
},

View file

@ -2,11 +2,15 @@ import datetime
from camel.agents.chat_agent import AsyncStreamingChatAgentResponse
from camel.societies.workforce.single_agent_worker import SingleAgentWorker as BaseSingleAgentWorker
from camel.tasks.task import Task, TaskState, is_task_result_insufficient
from utils import traceroot_wrapper as traceroot
from app.utils.agent import ListenChatAgent
from camel.societies.workforce.prompts import PROCESS_TASK_PROMPT
from colorama import Fore
from camel.societies.workforce.utils import TaskResult
from camel.utils.context_utils import ContextUtility
logger = traceroot.get_logger("single_agent_worker")
class SingleAgentWorker(BaseSingleAgentWorker):
@ -19,6 +23,8 @@ class SingleAgentWorker(BaseSingleAgentWorker):
pool_max_size: int = 10,
auto_scale_pool: bool = True,
use_structured_output_handler: bool = True,
context_utility: ContextUtility | None = None,
enable_workflow_memory: bool = False,
) -> None:
super().__init__(
description=description,
@ -28,6 +34,8 @@ class SingleAgentWorker(BaseSingleAgentWorker):
pool_max_size=pool_max_size,
auto_scale_pool=auto_scale_pool,
use_structured_output_handler=use_structured_output_handler,
context_utility=context_utility,
enable_workflow_memory=enable_workflow_memory,
)
self.worker = worker # change type hint
@ -54,6 +62,7 @@ class SingleAgentWorker(BaseSingleAgentWorker):
worker_agent.process_task_id = task.id # type: ignore rewrite line
response_content = ""
final_response = None
try:
dependency_tasks_info = self._get_dep_tasks_info(dependencies)
prompt = PROCESS_TASK_PROMPT.format(
@ -130,8 +139,28 @@ class SingleAgentWorker(BaseSingleAgentWorker):
usage_info = response.info.get("usage") or response.info.get("token_usage")
total_tokens = usage_info.get("total_tokens", 0) if usage_info else 0
# collect conversation from working agent to
# accumulator for workflow memory
# Only transfer memory if workflow memory is enabled
if self.enable_workflow_memory:
accumulator = self._get_conversation_accumulator()
# transfer all memory records from working agent to accumulator
try:
# retrieve all context records from the working agent
work_records = worker_agent.memory.retrieve()
# write these records to the accumulator's memory
memory_records = [record.memory_record for record in work_records]
accumulator.memory.write_records(memory_records)
logger.debug(f"Transferred {len(memory_records)} memory records to accumulator")
except Exception as e:
logger.warning(f"Failed to transfer conversation to accumulator: {e}")
except Exception as e:
print(f"{Fore.RED}Error processing task {task.id}: {type(e).__name__}: {e}{Fore.RESET}")
logger.error(f"Error processing task {task.id}: {type(e).__name__}: {e}")
# Store error information in task result
task.result = f"{type(e).__name__}: {e!s}"
return TaskState.FAILED
@ -144,6 +173,8 @@ class SingleAgentWorker(BaseSingleAgentWorker):
task.additional_info = {}
# Create worker attempt details with descriptive keys
# Use final_response if available (streaming), otherwise use response
response_for_info = final_response if final_response is not None else response
worker_attempt_details = {
"agent_id": getattr(worker_agent, "agent_id", worker_agent.role_name),
"original_worker_id": getattr(self.worker, "agent_id", self.worker.role_name),
@ -154,11 +185,7 @@ class SingleAgentWorker(BaseSingleAgentWorker):
f"{getattr(self.worker, 'agent_id', self.worker.role_name)}) "
f"to process task: {task.content}",
"response_content": response_content[:50],
"tool_calls": str(
final_response.info.get("tool_calls")
if isinstance(response, AsyncStreamingChatAgentResponse)
else response.info.get("tool_calls")
)[:50],
"tool_calls": str(response_for_info.info.get("tool_calls", []) if response_for_info and hasattr(response_for_info, 'info') else [])[:50],
"total_tokens": total_tokens,
}
@ -172,9 +199,12 @@ class SingleAgentWorker(BaseSingleAgentWorker):
print(f"======\n{Fore.GREEN}Response from {self}:{Fore.RESET}")
logger.info(f"Response from {self}:")
if not self.use_structured_output_handler:
# Handle native structured output parsing
if task_result is None:
logger.error("Error in worker step execution: Invalid task result")
print(f"{Fore.RED}Error in worker step execution: Invalid task result{Fore.RESET}")
task_result = TaskResult(
content="Failed to generate valid task result.",
@ -186,12 +216,17 @@ class SingleAgentWorker(BaseSingleAgentWorker):
f"\n{color}{task_result.content}{Fore.RESET}\n======", # type: ignore[union-attr]
)
if task_result.failed: # type: ignore[union-attr]
logger.error(f"{task_result.content}") # type: ignore[union-attr]
else:
logger.info(f"{task_result.content}") # type: ignore[union-attr]
task.result = task_result.content # type: ignore[union-attr]
if task_result.failed: # type: ignore[union-attr]
return TaskState.FAILED
if is_task_result_insufficient(task):
print(f"{Fore.RED}Task {task.id}: Content validation failed - task marked as failed{Fore.RESET}")
logger.warning(f"Task {task.id}: Content validation failed - task marked as failed")
return TaskState.FAILED
return TaskState.DONE

View file

@ -4,10 +4,11 @@ from camel.toolkits import AudioAnalysisToolkit as BaseAudioAnalysisToolkit
from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
@auto_listen_toolkit(BaseAudioAnalysisToolkit)
class AudioAnalysisToolkit(BaseAudioAnalysisToolkit, AbstractToolkit):
agent_name: str = Agents.multi_modal_agent
@ -23,14 +24,3 @@ class AudioAnalysisToolkit(BaseAudioAnalysisToolkit, AbstractToolkit):
cache_dir = env("file_save_path", os.path.expanduser("~/.eigent/tmp/"))
super().__init__(cache_dir, transcribe_model, audio_reasoning_model, timeout)
self.api_task_id = api_task_id
@listen_toolkit(
BaseAudioAnalysisToolkit.audio2text,
lambda _, audio_path, question: f"transcribe audio from {audio_path} and ask question: {question}",
)
def ask_question_about_audio(self, audio_path: str, question: str) -> str:
return super().ask_question_about_audio(audio_path, question)
@listen_toolkit(BaseAudioAnalysisToolkit.audio2text)
def audio2text(self, audio_path: str) -> str:
return super().audio2text(audio_path)

View file

@ -1,10 +1,11 @@
from typing import List, Literal
from camel.toolkits import CodeExecutionToolkit as BaseCodeExecutionToolkit, FunctionTool
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
@auto_listen_toolkit(BaseCodeExecutionToolkit)
class CodeExecutionToolkit(BaseCodeExecutionToolkit, AbstractToolkit):
agent_name: str = Agents.developer_agent
@ -21,18 +22,6 @@ class CodeExecutionToolkit(BaseCodeExecutionToolkit, AbstractToolkit):
self.api_task_id = api_task_id
super().__init__(sandbox, verbose, unsafe_mode, import_white_list, require_confirm, timeout)
@listen_toolkit(
BaseCodeExecutionToolkit.execute_code,
)
def execute_code(self, code: str, code_type: str = "python") -> str:
return super().execute_code(code, code_type)
@listen_toolkit(
BaseCodeExecutionToolkit.execute_command,
)
def execute_command(self, command: str) -> str | tuple[str, str]:
return super().execute_command(command)
def get_tools(self) -> List[FunctionTool]:
return [
FunctionTool(self.execute_code),

View file

@ -1,10 +1,11 @@
from camel.toolkits import Crawl4AIToolkit as BaseCrawl4AIToolkit
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
@auto_listen_toolkit(BaseCrawl4AIToolkit)
class Crawl4AIToolkit(BaseCrawl4AIToolkit, AbstractToolkit):
agent_name: str = Agents.search_agent
@ -12,18 +13,5 @@ class Crawl4AIToolkit(BaseCrawl4AIToolkit, AbstractToolkit):
self.api_task_id = api_task_id
super().__init__(timeout)
# async def _get_client(self):
# r"""Get or create the AsyncWebCrawler client."""
# if self._client is None:
# from crawl4ai import AsyncWebCrawler
# self._client = AsyncWebCrawler(use_managed_browser=True)
# await self._client.__aenter__()
# return self._client
@listen_toolkit(BaseCrawl4AIToolkit.scrape)
async def scrape(self, url: str) -> str:
return await super().scrape(url)
def toolkit_name(self) -> str:
return "Crawl Toolkit"

View file

@ -3,10 +3,11 @@ from camel.toolkits import ExcelToolkit as BaseExcelToolkit
from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
@auto_listen_toolkit(BaseExcelToolkit)
class ExcelToolkit(BaseExcelToolkit, AbstractToolkit):
agent_name: str = Agents.document_agent
@ -20,7 +21,3 @@ class ExcelToolkit(BaseExcelToolkit, AbstractToolkit):
if working_directory is None:
working_directory = env("file_save_path", os.path.expanduser("~/Downloads"))
super().__init__(timeout=timeout, working_directory=working_directory)
@listen_toolkit(BaseExcelToolkit.extract_excel_content)
def extract_excel_content(self, document_path: str) -> str:
return super().extract_excel_content(document_path)

View file

@ -5,10 +5,11 @@ from camel.toolkits import FileToolkit as BaseFileToolkit
from app.component.environment import env
from app.service.task import process_task
from app.service.task import ActionWriteFileData, Agents, get_task_lock
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit, listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
@auto_listen_toolkit(BaseFileToolkit)
class FileToolkit(BaseFileToolkit, AbstractToolkit):
agent_name: str = Agents.document_agent
@ -54,15 +55,3 @@ class FileToolkit(BaseFileToolkit, AbstractToolkit):
)
)
return res
@listen_toolkit(
BaseFileToolkit.read_file,
)
def read_file(self, file_paths: str | list[str]) -> str | dict[str, str]:
return super().read_file(file_paths)
@listen_toolkit(
BaseFileToolkit.edit_file,
)
def edit_file(self, file_path: str, old_content: str, new_content: str) -> str:
return super().edit_file(file_path, old_content, new_content)

View file

@ -3,10 +3,11 @@ from camel.toolkits import GithubToolkit as BaseGithubToolkit
from camel.toolkits.function_tool import FunctionTool
from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
@auto_listen_toolkit(BaseGithubToolkit)
class GithubToolkit(BaseGithubToolkit, AbstractToolkit):
agent_name: str = Agents.developer_agent
@ -19,86 +20,6 @@ class GithubToolkit(BaseGithubToolkit, AbstractToolkit):
super().__init__(access_token, timeout)
self.api_task_id = api_task_id
@listen_toolkit(
BaseGithubToolkit.create_pull_request,
lambda _,
repo_name,
file_path,
new_content,
pr_title,
body,
branch_name: f"Create PR in {repo_name} for {file_path} with title '{pr_title}', branch '{branch_name}', content '{new_content}'",
)
def create_pull_request(
self,
repo_name: str,
file_path: str,
new_content: str,
pr_title: str,
body: str,
branch_name: str,
) -> str:
return super().create_pull_request(repo_name, file_path, new_content, pr_title, body, branch_name)
@listen_toolkit(
BaseGithubToolkit.get_issue_list,
lambda _, repo_name, state="all": f"Get issue list from {repo_name} with state '{state}'",
lambda issues: f"Retrieved {len(issues)} issues",
)
def get_issue_list(
self, repo_name: str, state: Literal["open", "closed", "all"] = "all"
) -> list[dict[str, object]]:
return super().get_issue_list(repo_name, state)
@listen_toolkit(
BaseGithubToolkit.get_issue_content,
lambda _, repo_name, issue_number: f"Get content of issue {issue_number} from {repo_name}",
)
def get_issue_content(self, repo_name: str, issue_number: int) -> str:
return super().get_issue_content(repo_name, issue_number)
@listen_toolkit(
BaseGithubToolkit.get_pull_request_list,
lambda _, repo_name, state="all": f"Get pull request list from {repo_name} with state '{state}'",
lambda prs: f"Retrieved {len(prs)} pull requests",
)
def get_pull_request_list(
self, repo_name: str, state: Literal["open", "closed", "all"] = "all"
) -> list[dict[str, object]]:
return super().get_pull_request_list(repo_name, state)
@listen_toolkit(
BaseGithubToolkit.get_pull_request_code,
lambda _, repo_name, pr_number: f"Get code for pull request {pr_number} in {repo_name}",
lambda code: f"Retrieved {len(code)} code files",
)
def get_pull_request_code(self, repo_name: str, pr_number: int) -> list[dict[str, str]]:
return super().get_pull_request_code(repo_name, pr_number)
@listen_toolkit(
BaseGithubToolkit.get_pull_request_comments,
lambda _, repo_name, pr_number: f"Get comments for pull request {pr_number} in {repo_name}",
lambda comments: f"Retrieved {len(comments)} comments",
)
def get_pull_request_comments(self, repo_name: str, pr_number: int) -> list[dict[str, str]]:
return super().get_pull_request_comments(repo_name, pr_number)
@listen_toolkit(
BaseGithubToolkit.get_all_file_paths,
lambda _, repo_name, path="": f"Get all file paths from {repo_name}, path '{path}'",
lambda paths: f"Retrieved {len(paths)} file paths",
)
def get_all_file_paths(self, repo_name: str, path: str = "") -> list[str]:
return super().get_all_file_paths(repo_name, path)
@listen_toolkit(
BaseGithubToolkit.retrieve_file_content,
lambda _, repo_name, file_path: f"Retrieve content of file {file_path} from {repo_name}",
lambda content: f"Retrieved content of length {len(content)}",
)
def retrieve_file_content(self, repo_name: str, file_path: str) -> str:
return super().retrieve_file_content(repo_name, file_path)
@classmethod
def get_can_use_tools(cls, api_task_id: str) -> list[FunctionTool]:
if env("GITHUB_ACCESS_TOKEN"):

View file

@ -1,59 +1,272 @@
from typing import Any, Dict, List
import os
import threading
from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
from app.utils.oauth_state_manager import oauth_state_manager
from utils import traceroot_wrapper as traceroot
from camel.toolkits import GoogleCalendarToolkit as BaseGoogleCalendarToolkit
logger = traceroot.get_logger("main")
SCOPES = ['https://www.googleapis.com/auth/calendar']
@auto_listen_toolkit(BaseGoogleCalendarToolkit)
class GoogleCalendarToolkit(BaseGoogleCalendarToolkit, AbstractToolkit):
agent_name: str = Agents.social_medium_agent
def __init__(self, api_task_id: str, timeout: float | None = None):
self.api_task_id = api_task_id
self._token_path = (
env("GOOGLE_CALENDAR_TOKEN_PATH")
or os.path.join(
os.path.expanduser("~"),
".eigent",
"tokens",
"google_calendar",
f"google_calendar_token_{api_task_id}.json",
)
)
super().__init__(timeout)
@listen_toolkit(BaseGoogleCalendarToolkit.create_event)
def create_event(
self,
event_title: str,
start_time: str,
end_time: str,
description: str = "",
location: str = "",
attendees_email: List[str] | None = None,
timezone: str = "UTC",
) -> Dict[str, Any]:
return super().create_event(event_title, start_time, end_time, description, location, attendees_email, timezone)
@listen_toolkit(BaseGoogleCalendarToolkit.get_events)
def get_events(self, max_results: int = 10, time_min: str | None = None) -> List[Dict[str, Any]] | Dict[str, Any]:
return super().get_events(max_results, time_min)
@listen_toolkit(BaseGoogleCalendarToolkit.update_event)
def update_event(
self,
event_id: str,
event_title: str | None = None,
start_time: str | None = None,
end_time: str | None = None,
description: str | None = None,
location: str | None = None,
attendees_email: List[str] | None = None,
) -> Dict[str, Any]:
return super().update_event(event_id, event_title, start_time, end_time, description, location, attendees_email)
@listen_toolkit(BaseGoogleCalendarToolkit.delete_event)
def delete_event(self, event_id: str) -> str:
return super().delete_event(event_id)
@listen_toolkit(BaseGoogleCalendarToolkit.get_calendar_details)
def get_calendar_details(self) -> Dict[str, Any]:
return super().get_calendar_details()
@classmethod
def get_can_use_tools(cls, api_task_id: str):
if env("GOOGLE_CLIENT_ID") and env("GOOGLE_CLIENT_SECRET"):
from dotenv import load_dotenv
# Force reload environment variables
default_env_path = os.path.join(os.path.expanduser("~"), ".eigent", ".env")
if os.path.exists(default_env_path):
load_dotenv(dotenv_path=default_env_path, override=True)
if os.environ.get("GOOGLE_CLIENT_ID") and os.environ.get("GOOGLE_CLIENT_SECRET"):
return cls(api_task_id).get_tools()
else:
return []
def _get_calendar_service(self):
from googleapiclient.discovery import build
from google.auth.transport.requests import Request
creds = self._authenticate()
if creds and creds.expired and creds.refresh_token:
creds.refresh(Request())
try:
os.makedirs(os.path.dirname(self._token_path), exist_ok=True)
with open(self._token_path, "w") as f:
f.write(creds.to_json())
except Exception:
pass
return build("calendar", "v3", credentials=creds)
def _authenticate(self):
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import InstalledAppFlow
from google.auth.transport.requests import Request
from dotenv import load_dotenv
# Force reload environment variables from default .env file
default_env_path = os.path.join(os.path.expanduser("~"), ".eigent", ".env")
if os.path.exists(default_env_path):
load_dotenv(dotenv_path=default_env_path, override=True)
creds = None
# First, try to load from token file
try:
if os.path.exists(self._token_path):
logger.info(f"Loading credentials from token file: {self._token_path}")
creds = Credentials.from_authorized_user_file(self._token_path, SCOPES)
logger.info("Successfully loaded credentials from token file")
except Exception as e:
logger.warning(f"Could not load from token file: {e}")
creds = None
# If no token file, try environment variables
if not creds:
client_id = os.environ.get("GOOGLE_CLIENT_ID")
client_secret = os.environ.get("GOOGLE_CLIENT_SECRET")
refresh_token = os.environ.get("GOOGLE_REFRESH_TOKEN")
token_uri = os.environ.get("GOOGLE_TOKEN_URI") or "https://oauth2.googleapis.com/token"
if refresh_token and client_id and client_secret:
logger.info("Creating credentials from environment variables")
creds = Credentials(
None,
refresh_token=refresh_token,
token_uri=token_uri,
client_id=client_id,
client_secret=client_secret,
scopes=SCOPES,
)
# If still no creds, check background authorization
if not creds:
state = oauth_state_manager.get_state("google_calendar")
if state and state.status == "success" and state.result:
logger.info("Using credentials from background authorization")
creds = state.result
else:
# No credentials available
raise ValueError("No credentials available. Please run authorization first via /api/install/tool/google_calendar")
# Refresh if expired
if creds and creds.expired and creds.refresh_token:
try:
logger.info("Token expired, refreshing...")
creds.refresh(Request())
logger.info("Token refreshed successfully")
except Exception as e:
logger.error(f"Failed to refresh token: {e}")
raise ValueError("Failed to refresh expired token. Please re-authorize.")
# Save credentials
try:
os.makedirs(os.path.dirname(self._token_path), exist_ok=True)
with open(self._token_path, "w") as f:
f.write(creds.to_json())
except Exception as e:
logger.warning(f"Could not save credentials: {e}")
return creds
@staticmethod
def start_background_auth(api_task_id: str = "install_auth") -> str:
"""
Start background OAuth authorization flow with timeout
Returns the status of the authorization
"""
from google_auth_oauthlib.flow import InstalledAppFlow
from dotenv import load_dotenv
# Force reload environment variables from default .env file
default_env_path = os.path.join(os.path.expanduser("~"), ".eigent", ".env")
if os.path.exists(default_env_path):
logger.info(f"Reloading environment variables from {default_env_path}")
load_dotenv(dotenv_path=default_env_path, override=True)
# Check if there's an existing authorization and force stop it
old_state = oauth_state_manager.get_state("google_calendar")
if old_state and old_state.status in ["pending", "authorizing"]:
logger.info("Found existing authorization, forcing shutdown...")
old_state.cancel()
# Try to shutdown the old server if it exists
if hasattr(old_state, 'server') and old_state.server:
try:
old_state.server.shutdown()
logger.info("Old server shutdown successfully")
except Exception as e:
logger.warning(f"Could not shutdown old server: {e}")
# Create new state for this authorization
state = oauth_state_manager.create_state("google_calendar")
def auth_flow():
try:
state.status = "authorizing"
oauth_state_manager.update_status("google_calendar", "authorizing")
# Reload environment variables in this thread
from dotenv import load_dotenv
default_env_path = os.path.join(os.path.expanduser("~"), ".eigent", ".env")
if os.path.exists(default_env_path):
load_dotenv(dotenv_path=default_env_path, override=True)
client_id = os.environ.get("GOOGLE_CLIENT_ID")
client_secret = os.environ.get("GOOGLE_CLIENT_SECRET")
token_uri = os.environ.get("GOOGLE_TOKEN_URI") or "https://oauth2.googleapis.com/token"
logger.info(f"Google Calendar auth - client_id present: {bool(client_id)}, client_secret present: {bool(client_secret)}")
if not client_id or not client_secret:
error_msg = "GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET must be set in environment variables"
logger.error(error_msg)
raise ValueError(error_msg)
client_config = {
"installed": {
"client_id": client_id,
"client_secret": client_secret,
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": token_uri,
"redirect_uris": ["http://localhost"],
}
}
logger.debug(f"calendar client_config initialized with client_id: {client_id[:10]}...")
flow = InstalledAppFlow.from_client_config(client_config, SCOPES)
# Check for cancellation before starting
if state.is_cancelled():
logger.info("Authorization cancelled before starting")
return
# This will automatically open browser and wait for user authorization
logger.info("=" * 80)
logger.info(f"[Thread {threading.current_thread().name}] Starting local server for Google Calendar authorization")
logger.info("Browser should open automatically in a moment...")
logger.info("=" * 80)
# Run local server - this will block until authorization completes
# Note: Each call uses a random port (port=0), so multiple concurrent attempts won't conflict
try:
creds = flow.run_local_server(
port=0,
authorization_prompt_message="",
success_message="<h1>Authorization successful!</h1><p>You can close this window and return to Eigent.</p>",
open_browser=True
)
logger.info("Authorization flow completed successfully!")
except Exception as server_error:
logger.error(f"Error during run_local_server: {server_error}")
raise
# Check for cancellation after auth
if state.is_cancelled():
logger.info("Authorization cancelled after completion")
return
# Save credentials to token file
token_path = os.path.join(
os.path.expanduser("~"),
".eigent",
"tokens",
"google_calendar",
f"google_calendar_token_{api_task_id}.json",
)
try:
os.makedirs(os.path.dirname(token_path), exist_ok=True)
with open(token_path, "w") as f:
f.write(creds.to_json())
logger.info(f"Saved Google Calendar credentials to {token_path}")
except Exception as e:
logger.warning(f"Could not save credentials: {e}")
# Update state with success
oauth_state_manager.update_status("google_calendar", "success", result=creds)
logger.info("Google Calendar authorization successful!")
except Exception as e:
if state.is_cancelled():
logger.info("Authorization was cancelled")
oauth_state_manager.update_status("google_calendar", "cancelled")
else:
error_msg = str(e)
logger.error(f"Google Calendar authorization failed: {error_msg}")
oauth_state_manager.update_status("google_calendar", "failed", error=error_msg)
finally:
# Clean up server reference
state.server = None
# Start authorization in background thread
thread = threading.Thread(target=auth_flow, daemon=True, name=f"GoogleCalendar-OAuth-{state.started_at.timestamp()}")
state.thread = thread
thread.start()
logger.info("Started background Google Calendar authorization")
return "authorizing"

View file

@ -1,14 +1,16 @@
import asyncio
from camel.toolkits.base import BaseToolkit
from loguru import logger
from camel.toolkits.function_tool import FunctionTool
from app.service.task import Action, ActionAskData, ActionNoticeData, get_task_lock
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit, listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
from app.service.task import process_task
# Rewrite HumanToolkit because the system's user interaction was using console, but in electron we cannot use console. Changed to use SSE response to let frontend show dialog for user interaction
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("human_toolkit")
@auto_listen_toolkit(BaseToolkit)
class HumanToolkit(BaseToolkit, AbstractToolkit):
r"""A class representing a toolkit for human interaction.
Note:

View file

@ -12,12 +12,14 @@ from camel.toolkits.hybrid_browser_toolkit_py.actions import ActionExecutor
from camel.toolkits.hybrid_browser_toolkit_py.snapshot import PageSnapshot
from camel.toolkits.hybrid_browser_toolkit_py.agent import PlaywrightLLMAgent
from camel.toolkits.function_tool import FunctionTool
from loguru import logger
from app.component.environment import env
from app.exception.exception import ProgramException
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit, listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("hybrid_browser_python_toolkit")
class BrowserSession(BaseHybridBrowserSession):
@ -124,6 +126,7 @@ class BrowserSession(BaseHybridBrowserSession):
break
@auto_listen_toolkit(BaseHybridBrowserToolkit)
class HybridBrowserPythonToolkit(BaseHybridBrowserToolkit, AbstractToolkit):
agent_name: str = Agents.search_agent
@ -224,14 +227,6 @@ class HybridBrowserPythonToolkit(BaseHybridBrowserToolkit, AbstractToolkit):
self._agent: PlaywrightLLMAgent | None = None
self._unified_script = self._load_unified_analyzer()
@listen_toolkit(BaseHybridBrowserToolkit.browser_open)
async def browser_open(self) -> Dict[str, str]:
return await super().browser_open()
@listen_toolkit(BaseHybridBrowserToolkit.browser_close)
async def browser_close(self) -> str:
return await super().browser_close()
@listen_toolkit(BaseHybridBrowserToolkit.browser_visit_page, lambda _, url: url)
async def browser_visit_page(self, url: str) -> Dict[str, Any]:
r"""Navigates to a URL.
@ -282,66 +277,6 @@ class HybridBrowserPythonToolkit(BaseHybridBrowserToolkit, AbstractToolkit):
return {"result": nav_result, "snapshot": snapshot, **tab_info}
@listen_toolkit(BaseHybridBrowserToolkit.browser_back)
async def browser_back(self) -> Dict[str, Any]:
return await super().browser_back()
@listen_toolkit(BaseHybridBrowserToolkit.browser_forward)
async def browser_forward(self) -> Dict[str, Any]:
return await super().browser_forward()
@listen_toolkit(BaseHybridBrowserToolkit.browser_click)
async def browser_click(self, *, ref: str) -> Dict[str, Any]:
return await super().browser_click(ref=ref)
@listen_toolkit(BaseHybridBrowserToolkit.browser_type)
async def browser_type(self, *, ref: str, text: str) -> Dict[str, Any]:
return await super().browser_type(ref=ref, text=text)
@listen_toolkit(BaseHybridBrowserToolkit.browser_switch_tab)
async def browser_switch_tab(self, *, tab_id: str) -> Dict[str, Any]:
return await super().browser_switch_tab(tab_id=tab_id)
@listen_toolkit(BaseHybridBrowserToolkit.browser_select)
async def browser_select(self, *, ref: str, value: str) -> Dict[str, str]:
return await super().browser_select(ref=ref, value=value)
@listen_toolkit(BaseHybridBrowserToolkit.browser_scroll)
async def browser_scroll(self, *, direction: str, amount: int) -> Dict[str, str]:
return await super().browser_scroll(direction=direction, amount=amount)
@listen_toolkit(BaseHybridBrowserToolkit.browser_wait_user)
async def browser_wait_user(self, timeout_sec: float | None = None) -> Dict[str, str]:
return await super().browser_wait_user(timeout_sec)
@listen_toolkit(BaseHybridBrowserToolkit.browser_enter)
async def browser_enter(self) -> Dict[str, str]:
return await super().browser_enter()
@listen_toolkit(BaseHybridBrowserToolkit.browser_solve_task)
async def browser_solve_task(self, task_prompt: str, start_url: str, max_steps: int = 15) -> str:
return await super().browser_solve_task(task_prompt, start_url, max_steps)
@listen_toolkit(BaseHybridBrowserToolkit.browser_get_page_snapshot)
async def browser_get_page_snapshot(self) -> str:
return await super().browser_get_page_snapshot()
@listen_toolkit(BaseHybridBrowserToolkit.browser_get_som_screenshot)
async def browser_get_som_screenshot(self):
return await super().browser_get_som_screenshot()
@listen_toolkit(BaseHybridBrowserToolkit.browser_get_page_links)
async def browser_get_page_links(self, *, ref: List[str]) -> Dict[str, Any]:
return await super().browser_get_page_links(ref=ref)
@listen_toolkit(BaseHybridBrowserToolkit.browser_close_tab)
async def browser_close_tab(self, *, tab_id: str) -> Dict[str, Any]:
return await super().browser_close_tab(tab_id=tab_id)
@listen_toolkit(BaseHybridBrowserToolkit.browser_get_tab_info)
async def browser_get_tab_info(self) -> Dict[str, Any]:
return await super().browser_get_tab_info()
@classmethod
def get_can_use_tools(cls, api_task_id: str) -> list[FunctionTool]:
browser = HybridBrowserPythonToolkit(

View file

@ -4,7 +4,6 @@ import time
import asyncio
import json
from typing import Any, Dict, List, Optional
from loguru import logger
import websockets
import websockets.exceptions
@ -16,8 +15,11 @@ from camel.toolkits.hybrid_browser_toolkit.ws_wrapper import WebSocketBrowserWra
from app.component.command import bun, uv
from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("hybrid_browser_toolkit")
class WebSocketBrowserWrapper(BaseWebSocketBrowserWrapper):
@ -45,8 +47,13 @@ class WebSocketBrowserWrapper(BaseWebSocketBrowserWrapper):
future.set_result(response)
logger.debug(f"Processed response for message {message_id}")
else:
# Log unexpected messages
logger.warning(f"Received unexpected message: {response}")
message_summary = {
"id": response.get("id"),
"success": response.get("success"),
"has_result": "result" in response,
"result_type": type(response.get("result")).__name__ if "result" in response else None
}
logger.debug(f"Received unexpected message: {message_summary}")
except asyncio.CancelledError:
disconnect_reason = "Receive loop cancelled"
@ -210,6 +217,7 @@ class WebSocketConnectionPool:
websocket_connection_pool = WebSocketConnectionPool()
@auto_listen_toolkit(BaseHybridBrowserToolkit)
class HybridBrowserToolkit(BaseHybridBrowserToolkit, AbstractToolkit):
agent_name: str = Agents.search_agent
@ -240,7 +248,22 @@ class HybridBrowserToolkit(BaseHybridBrowserToolkit, AbstractToolkit):
cdp_keep_current_page: bool = False,
full_visual_mode: bool = False,
) -> None:
logger.info(f"[HybridBrowserToolkit] Initializing with api_task_id: {api_task_id}")
self.api_task_id = api_task_id
logger.debug(f"[HybridBrowserToolkit] api_task_id set to: {self.api_task_id}")
# Set default user_data_dir if not provided
if user_data_dir is None:
# Use browser port to determine profile directory
browser_port = env('browser_port', '9222')
user_data_base = os.path.expanduser("~/.eigent/browser_profiles")
user_data_dir = os.path.join(user_data_base, f"profile_{browser_port}")
os.makedirs(user_data_dir, exist_ok=True)
logger.info(f"[HybridBrowserToolkit] Using port-based user_data_dir: {user_data_dir} (port: {browser_port})")
else:
logger.info(f"[HybridBrowserToolkit] Using provided user_data_dir: {user_data_dir}")
logger.debug(f"[HybridBrowserToolkit] Calling super().__init__ with session_id: {session_id}")
super().__init__(
headless=headless,
user_data_dir=user_data_dir,
@ -264,16 +287,24 @@ class HybridBrowserToolkit(BaseHybridBrowserToolkit, AbstractToolkit):
cdp_keep_current_page=cdp_keep_current_page,
full_visual_mode=full_visual_mode,
)
logger.info(f"[HybridBrowserToolkit] Initialization complete for api_task_id: {self.api_task_id}")
async def _ensure_ws_wrapper(self):
"""Ensure WebSocket wrapper is initialized using connection pool."""
logger.debug(f"[HybridBrowserToolkit] _ensure_ws_wrapper called for api_task_id: {getattr(self, 'api_task_id', 'NOT SET')}")
global websocket_connection_pool
# Get session ID from config or use default
session_id = self._ws_config.get("session_id", "default")
logger.debug(f"[HybridBrowserToolkit] Using session_id: {session_id}")
# Log when connecting to browser
cdp_url = self._ws_config.get("cdp_url", f"http://localhost:{env('browser_port', '9222')}")
logger.info(f"[PROJECT BROWSER] Connecting to browser via CDP at {cdp_url}")
# Get or create connection from pool
self._ws_wrapper = await websocket_connection_pool.get_connection(session_id, self._ws_config)
logger.info(f"[HybridBrowserToolkit] WebSocket wrapper initialized for session: {session_id}")
# Additional health check
if self._ws_wrapper.websocket is None:
@ -287,10 +318,16 @@ class HybridBrowserToolkit(BaseHybridBrowserToolkit, AbstractToolkit):
if new_session_id is None:
new_session_id = str(uuid.uuid4())[:8]
# For cloned sessions, use the same user_data_dir to share login state
# This allows multiple agents to use the same browser profile without conflicts
logger.info(f"Cloning session {new_session_id} with shared user_data_dir: {self._user_data_dir}")
# Use the same session_id to share the same browser instance
# This ensures all clones use the same WebSocket connection and browser
return HybridBrowserToolkit(
self.api_task_id,
headless=self._headless,
user_data_dir=self._user_data_dir,
user_data_dir=self._user_data_dir, # Use the same user_data_dir
stealth=self._stealth,
web_agent_model=self._web_agent_model,
cache_dir=f"{self._cache_dir.rstrip('/')}/_clone_{new_session_id}/",
@ -336,74 +373,3 @@ class HybridBrowserToolkit(BaseHybridBrowserToolkit, AbstractToolkit):
if hasattr(self, "_ws_wrapper") and self._ws_wrapper:
session_id = self._ws_config.get("session_id", "default")
logger.debug(f"HybridBrowserToolkit for session {session_id} is being garbage collected")
@listen_toolkit(BaseHybridBrowserToolkit.browser_open)
async def browser_open(self) -> Dict[str, Any]:
return await super().browser_open()
@listen_toolkit(BaseHybridBrowserToolkit.browser_close)
async def browser_close(self) -> str:
return await super().browser_close()
@listen_toolkit(BaseHybridBrowserToolkit.browser_visit_page)
async def browser_visit_page(self, url: str) -> Dict[str, Any]:
logger.debug(f"browser_visit_page called with URL: {url}")
try:
result = await super().browser_visit_page(url)
logger.debug(f"browser_visit_page succeeded for URL: {url}")
return result
except Exception as e:
logger.error(f"browser_visit_page failed for URL {url}: {type(e).__name__}: {e}")
raise
@listen_toolkit(BaseHybridBrowserToolkit.browser_back)
async def browser_back(self) -> Dict[str, Any]:
return await super().browser_back()
@listen_toolkit(BaseHybridBrowserToolkit.browser_forward)
async def browser_forward(self) -> Dict[str, Any]:
return await super().browser_forward()
@listen_toolkit(BaseHybridBrowserToolkit.browser_get_page_snapshot)
async def browser_get_page_snapshot(self) -> str:
return await super().browser_get_page_snapshot()
@listen_toolkit(BaseHybridBrowserToolkit.browser_get_som_screenshot)
async def browser_get_som_screenshot(self, read_image: bool = False, instruction: str | None = None) -> str:
return await super().browser_get_som_screenshot(read_image, instruction)
@listen_toolkit(BaseHybridBrowserToolkit.browser_click)
async def browser_click(self, *, ref: str) -> Dict[str, Any]:
return await super().browser_click(ref=ref)
@listen_toolkit(BaseHybridBrowserToolkit.browser_type)
async def browser_type(self, *, ref: str, text: str) -> Dict[str, Any]:
return await super().browser_type(ref=ref, text=text)
@listen_toolkit(BaseHybridBrowserToolkit.browser_select)
async def browser_select(self, *, ref: str, value: str) -> Dict[str, Any]:
return await super().browser_select(ref=ref, value=value)
@listen_toolkit(BaseHybridBrowserToolkit.browser_scroll)
async def browser_scroll(self, *, direction: str, amount: int = 500) -> Dict[str, Any]:
return await super().browser_scroll(direction=direction, amount=amount)
@listen_toolkit(BaseHybridBrowserToolkit.browser_enter)
async def browser_enter(self) -> Dict[str, Any]:
return await super().browser_enter()
@listen_toolkit(BaseHybridBrowserToolkit.browser_wait_user)
async def browser_wait_user(self, timeout_sec: float | None = None) -> Dict[str, Any]:
return await super().browser_wait_user(timeout_sec)
@listen_toolkit(BaseHybridBrowserToolkit.browser_switch_tab)
async def browser_switch_tab(self, *, tab_id: str) -> Dict[str, Any]:
return await super().browser_switch_tab(tab_id=tab_id)
@listen_toolkit(BaseHybridBrowserToolkit.browser_close_tab)
async def browser_close_tab(self, *, tab_id: str) -> Dict[str, Any]:
return await super().browser_close_tab(tab_id=tab_id)
@listen_toolkit(BaseHybridBrowserToolkit.browser_get_tab_info)
async def browser_get_tab_info(self) -> Dict[str, Any]:
return await super().browser_get_tab_info()

View file

@ -2,10 +2,11 @@ from camel.models import BaseModelBackend
from camel.toolkits import ImageAnalysisToolkit as BaseImageAnalysisToolkit
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
@auto_listen_toolkit(BaseImageAnalysisToolkit)
class ImageAnalysisToolkit(BaseImageAnalysisToolkit, AbstractToolkit):
agent_name: str = Agents.multi_modal_agent
@ -17,24 +18,3 @@ class ImageAnalysisToolkit(BaseImageAnalysisToolkit, AbstractToolkit):
):
super().__init__(model, timeout)
self.api_task_id = api_task_id
@listen_toolkit(
BaseImageAnalysisToolkit.image_to_text,
lambda _,
image_path,
sys_prompt: f"transcribe image from {image_path} and ask sys_prompt: {sys_prompt}",
)
def image_to_text(self, image_path: str, sys_prompt: str | None = None) -> str:
return super().image_to_text(image_path, sys_prompt)
@listen_toolkit(
BaseImageAnalysisToolkit.ask_question_about_image,
lambda _,
image_path,
question,
sys_prompt: f"transcribe image from {image_path} and ask question: {question} with sys_prompt: {sys_prompt}",
)
def ask_question_about_image(
self, image_path: str, question: str, sys_prompt: str | None = None
) -> str:
return super().ask_question_about_image(image_path, question, sys_prompt)

View file

@ -2,10 +2,11 @@ from camel.toolkits import LinkedInToolkit as BaseLinkedInToolkit
from camel.toolkits.function_tool import FunctionTool
from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
@auto_listen_toolkit(BaseLinkedInToolkit)
class LinkedInToolkit(BaseLinkedInToolkit, AbstractToolkit):
agent_name: str = Agents.social_medium_agent
@ -13,27 +14,6 @@ class LinkedInToolkit(BaseLinkedInToolkit, AbstractToolkit):
super().__init__(timeout)
self.api_task_id = api_task_id
@listen_toolkit(
BaseLinkedInToolkit.create_post,
lambda _, text: f"create a LinkedIn post with text: {text}",
)
def create_post(self, text: str) -> dict:
return super().create_post(text)
@listen_toolkit(
BaseLinkedInToolkit.delete_post,
lambda _, post_id: f"delete LinkedIn post with id: {post_id}",
)
def delete_post(self, post_id: str) -> str:
return super().delete_post(post_id)
@listen_toolkit(
BaseLinkedInToolkit.get_profile,
lambda _, include_id: f"get LinkedIn profile with include_id: {include_id}",
)
def get_profile(self, include_id: bool = False) -> dict:
return super().get_profile(include_id)
@classmethod
def get_can_use_tools(cls, api_task_id: str) -> list[FunctionTool]:
if env("LINKEDIN_ACCESS_TOKEN"):

View file

@ -2,17 +2,14 @@ from typing import Dict, List
from camel.toolkits import MarkItDownToolkit as BaseMarkItDownToolkit
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
@auto_listen_toolkit(BaseMarkItDownToolkit)
class MarkItDownToolkit(BaseMarkItDownToolkit, AbstractToolkit):
agent_name: str = Agents.document_agent
def __init__(self, api_task_id: str, timeout: float | None = None):
self.api_task_id = api_task_id
super().__init__(timeout)
@listen_toolkit(BaseMarkItDownToolkit.read_files)
def read_files(self, file_paths: List[str]) -> Dict[str, str]:
return super().read_files(file_paths)

View file

@ -5,10 +5,11 @@ from typing import Optional
from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
@auto_listen_toolkit(BaseNoteTakingToolkit)
class NoteTakingToolkit(BaseNoteTakingToolkit, AbstractToolkit):
agent_name: str = Agents.document_agent
@ -25,19 +26,3 @@ class NoteTakingToolkit(BaseNoteTakingToolkit, AbstractToolkit):
if working_directory is None:
working_directory = env("file_save_path", os.path.expanduser("~/.eigent/notes")) + "/note.md"
super().__init__(working_directory=working_directory, timeout=timeout)
@listen_toolkit(BaseNoteTakingToolkit.append_note)
def append_note(self, note_name: str, content: str) -> str:
return super().append_note(note_name=note_name, content=content)
@listen_toolkit(BaseNoteTakingToolkit.read_note)
def read_note(self, note_name: Optional[str] = "all_notes") -> str:
return super().read_note(note_name=note_name)
@listen_toolkit(BaseNoteTakingToolkit.create_note)
def create_note(self, note_name: str, content: str, overwrite: bool = False) -> str:
return super().create_note(note_name=note_name, content=content, overwrite=overwrite)
@listen_toolkit(BaseNoteTakingToolkit.list_note)
def list_note(self) -> str:
return super().list_note()

View file

@ -1,11 +1,36 @@
import os
import json
import asyncio
from textwrap import indent
from typing import Any, Dict, List
from loguru import logger
from camel.toolkits import FunctionTool
from app.component.environment import env
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
from camel.toolkits.mcp_toolkit import MCPToolkit
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("notion_mcp_toolkit")
def _customize_function_parameters(schema: Dict[str, Any]) -> None:
r"""Customize function parameters for specific functions.
This method allows modifying parameter descriptions or other schema
attributes for specific functions.
"""
function_info = schema.get("function", {})
function_name = function_info.get("name", "")
parameters = function_info.get("parameters", {})
properties = parameters.get("properties", {})
required = parameters.get("required", [])
help_description = "If you need use parent, you can use `notion-search` for the information"
# Modify the notion-create-pages function to make parent optional
if function_name == "notion-create-pages" or function_name == "notion-create-database":
required.remove("parent")
parameters["required"] = required
if "parent" in properties:
# Update the parent parameter description
properties["parent"]["description"] = "Optional. " + properties["parent"]["description"] + help_description
class NotionMCPToolkit(MCPToolkit, AbstractToolkit):
@ -33,80 +58,57 @@ class NotionMCPToolkit(MCPToolkit, AbstractToolkit):
}
}
}
super().__init__(config_dict=config_dict, timeout=timeout)
def get_tools(self) -> List[FunctionTool]:
r"""Returns a list of tools provided by the NotionMCPToolkit.
Returns:
List[FunctionTool]: List of available tools.
"""
all_tools = []
for client in self.clients:
try:
original_build_schema = client._build_tool_schema
def create_wrapper(orig_func):
def wrapper(mcp_tool):
return self._build_custom_tool_schema(
mcp_tool, orig_func
)
return wrapper
client._build_tool_schema = create_wrapper( # type: ignore[method-assign]
original_build_schema
)
client_tools = client.get_tools()
all_tools.extend(client_tools)
client._build_tool_schema = original_build_schema # type: ignore[method-assign]
except Exception as e:
logger.error(f"Failed to get tools from client: {e}")
return all_tools
def _build_custom_tool_schema(self, mcp_tool, original_build_schema):
r"""Build tool schema with custom modifications."""
schema = original_build_schema(mcp_tool)
self._customize_function_parameters(schema)
return schema
def _customize_function_parameters(self, schema: Dict[str, Any]) -> None:
r"""Customize function parameters for specific functions.
This method allows modifying parameter descriptions or other schema
attributes for specific functions.
"""
function_info = schema.get("function", {})
function_name = function_info.get("name", "")
parameters = function_info.get("parameters", {})
properties = parameters.get("properties", {})
# Modify the notion-create-pages function to make parent optional
if function_name == "notion-create-pages":
if "parent" in properties:
# Update the parent parameter description
properties["parent"]["description"] = (
"Optional. The parent under which the new pages will be created. "
"This can be a page (page_id), a database page (database_id), or "
"a data source/collection under a database (data_source_id). "
"If omitted, the new pages will be created as private pages at the workspace level. "
"Use data_source_id when you have a collection:// URL from the fetch tool."
)
super().__init__(config_dict=config_dict, timeout=timeout)
@classmethod
async def get_can_use_tools(cls, api_task_id: str) -> list[FunctionTool]:
tools = []
toolkit = cls(api_task_id)
try:
await toolkit.connect()
# Use subclass implementation that inlines upstream processing
all_tools = toolkit.get_tools()
for item in all_tools:
setattr(item, "_toolkit_name", cls.__name__)
tools.append(item)
except Exception as e:
print(f"Warning: Could not connect to Notion MCP server: {e}")
return tools
# Retry mechanism for remote MCP connection
max_retries = 3
retry_delay = 2 # seconds
for attempt in range(max_retries):
tools = []
toolkit = None
try:
# Create a fresh toolkit instance for each retry
toolkit = cls(api_task_id)
logger.info(f"Attempting to connect to Notion MCP server (attempt {attempt + 1}/{max_retries})")
await toolkit.connect()
# Get tools from the connected toolkit
all_tools = toolkit.get_tools()
tool_schema = [
item.get_openai_tool_schema() for item in all_tools
]
# Adjust tool schema
for item in tool_schema:
_customize_function_parameters(item)
for item in all_tools:
setattr(item, "_toolkit_name", cls.__name__)
tools.append(item)
# Check if we actually got tools
if len(tools) == 0:
logger.warning(f"Connected to Notion MCP server but got 0 tools (attempt {attempt + 1}/{max_retries})")
raise Exception("No tools retrieved from Notion MCP server")
# Success! Got tools
logger.info(f"Successfully connected to Notion MCP server and loaded {len(tools)} tools")
return tools
except Exception as e:
logger.warning(f"Failed to connect to Notion MCP server (attempt {attempt + 1}/{max_retries}): {e}")
# If not the last attempt, wait and retry
if attempt < max_retries - 1:
logger.info(f"Retrying in {retry_delay} seconds...")
await asyncio.sleep(retry_delay)
else:
# Last attempt failed
logger.error(f"All {max_retries} connection attempts to Notion MCP server failed. Notion tools will not be available for this task.")
return []

View file

@ -3,10 +3,11 @@ from camel.toolkits import NotionToolkit as BaseNotionToolkit
from camel.toolkits.function_tool import FunctionTool
from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
@auto_listen_toolkit(BaseNotionToolkit)
class NotionToolkit(BaseNotionToolkit, AbstractToolkit):
agent_name: str = Agents.document_agent
@ -19,29 +20,6 @@ class NotionToolkit(BaseNotionToolkit, AbstractToolkit):
super().__init__(notion_token, timeout)
self.api_task_id = api_task_id
@listen_toolkit(
BaseNotionToolkit.list_all_pages,
lambda _: "list all pages in Notion workspace",
lambda result: f"{len(result)} pages found",
)
def list_all_pages(self) -> List[dict]:
return super().list_all_pages()
@listen_toolkit(
BaseNotionToolkit.list_all_users,
lambda _: "list all users in Notion workspace",
lambda result: f"{len(result)} users found",
)
def list_all_users(self) -> List[dict]:
return super().list_all_users()
@listen_toolkit(
BaseNotionToolkit.get_notion_block_text_content,
lambda _, page_id: f"get text content of page with id: {page_id}",
)
def get_notion_block_text_content(self, block_id: str) -> str:
return super().get_notion_block_text_content(block_id)
@classmethod
def get_can_use_tools(cls, api_task_id: str) -> List[FunctionTool]:
if env("NOTION_TOKEN"):

View file

@ -3,11 +3,12 @@ from camel.toolkits import OpenAIImageToolkit as BaseOpenAIImageToolkit
from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit, listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
from typing import Literal, Optional, Union, List
@auto_listen_toolkit(BaseOpenAIImageToolkit)
class OpenAIImageToolkit(BaseOpenAIImageToolkit, AbstractToolkit):
agent_name: str = Agents.multi_modal_agent

View file

@ -4,11 +4,12 @@ from camel.toolkits import PPTXToolkit as BasePPTXToolkit
from app.component.environment import env
from app.service.task import ActionWriteFileData, Agents, get_task_lock
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit, listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
from app.service.task import process_task
@auto_listen_toolkit(BasePPTXToolkit)
class PPTXToolkit(BasePPTXToolkit, AbstractToolkit):
agent_name: str = Agents.document_agent

View file

@ -4,10 +4,11 @@ from camel.toolkits import PyAutoGUIToolkit as BasePyAutoGUIToolkit
from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
@auto_listen_toolkit(BasePyAutoGUIToolkit)
class PyAutoGUIToolkit(BasePyAutoGUIToolkit, AbstractToolkit):
agent_name: str = Agents.search_agent
@ -21,69 +22,3 @@ class PyAutoGUIToolkit(BasePyAutoGUIToolkit, AbstractToolkit):
screenshots_dir = env("file_save_path", os.path.expanduser("~/Downloads"))
super().__init__(timeout, screenshots_dir)
self.api_task_id = api_task_id
@listen_toolkit(BasePyAutoGUIToolkit.mouse_move, lambda _, x, y: f"mouse move to {x}, {y}")
def mouse_move(self, x: int, y: int) -> str:
return super().mouse_move(x, y)
@listen_toolkit(
BasePyAutoGUIToolkit.mouse_click,
lambda _, button="left", clicks=1, x=None, y=None: f"mouse click {button} {clicks} times at {x}, {y}",
)
def mouse_click(
self,
button: Literal["left", "middle", "right"] = "left",
clicks: int = 1,
x: int | None = None,
y: int | None = None,
) -> str:
return super().mouse_click(button, clicks, x, y)
@listen_toolkit(
BasePyAutoGUIToolkit.keyboard_type,
lambda _, text, interval=0: f"keyboard type {text}, interval {interval}",
)
def keyboard_type(self, text: str, interval: float = 0) -> str:
return super().keyboard_type(text, interval)
@listen_toolkit(BasePyAutoGUIToolkit.take_screenshot)
def take_screenshot(self) -> str:
return super().take_screenshot()
@listen_toolkit(BasePyAutoGUIToolkit.get_mouse_position)
def get_mouse_position(self) -> str:
return super().get_mouse_position()
@listen_toolkit(BasePyAutoGUIToolkit.press_key, lambda _, key: f"press key {key}")
def press_key(self, key: str | list[str]) -> str:
return super().press_key(key)
@listen_toolkit(BasePyAutoGUIToolkit.hotkey, lambda _, keys: f"hotkey {keys}")
def hotkey(self, keys: List[str]) -> str:
return super().hotkey(keys)
@listen_toolkit(
BasePyAutoGUIToolkit.mouse_drag,
lambda _,
start_x,
start_y,
end_x,
end_y,
button="left": f"mouse drag from {start_x}, {start_y} to {end_x}, {end_y} with {button} button",
)
def mouse_drag(
self,
start_x: int,
start_y: int,
end_x: int,
end_y: int,
button: Literal["left", "middle", "right"] = "left",
) -> str:
return super().mouse_drag(start_x, start_y, end_x, end_y, button)
@listen_toolkit(
BasePyAutoGUIToolkit.scroll,
lambda _, scroll_amount, x=None, y=None: f"scroll {scroll_amount} at {x}, {y}",
)
def scroll(self, scroll_amount: int, x: int | None = None, y: int | None = None) -> str:
return super().scroll(scroll_amount, x, y)

View file

@ -3,10 +3,11 @@ from camel.toolkits import RedditToolkit as BaseRedditToolkit
from camel.toolkits.function_tool import FunctionTool
from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
@auto_listen_toolkit(BaseRedditToolkit)
class RedditToolkit(BaseRedditToolkit, AbstractToolkit):
agent_name: str = Agents.social_medium_agent
@ -20,47 +21,6 @@ class RedditToolkit(BaseRedditToolkit, AbstractToolkit):
super().__init__(retries, delay, timeout)
self.api_task_id = api_task_id
@listen_toolkit(
BaseRedditToolkit.collect_top_posts,
lambda _,
subreddit_name,
post_limit=5,
comment_limit=5: f"collect top posts from subreddit: {subreddit_name} with post limit: {post_limit} and comment limit: {comment_limit}",
lambda result: f"top posts collected: {result}",
)
def collect_top_posts(
self, subreddit_name: str, post_limit: int = 5, comment_limit: int = 5
) -> List[Dict[str, Any]] | str:
return super().collect_top_posts(subreddit_name, post_limit, comment_limit)
@listen_toolkit(
BaseRedditToolkit.perform_sentiment_analysis,
lambda _, data: f"perform sentiment analysis on data number: {len(data)}",
lambda result: f"perform analysis result: {result}",
)
def perform_sentiment_analysis(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
return super().perform_sentiment_analysis(data)
@listen_toolkit(
BaseRedditToolkit.track_keyword_discussions,
lambda _,
subreddits,
keywords,
post_limit=10,
comment_limit=10,
sentiment_analysis=False: f"track keyword discussions for subreddits: {subreddits}, keywords: {keywords}",
lambda result: f"track keyword discussions result: {result}",
)
def track_keyword_discussions(
self,
subreddits: List[str],
keywords: List[str],
post_limit: int = 10,
comment_limit: int = 10,
sentiment_analysis: bool = False,
) -> List[Dict[str, Any]] | str:
return super().track_keyword_discussions(subreddits, keywords, post_limit, comment_limit, sentiment_analysis)
@classmethod
def get_can_use_tools(cls, api_task_id: str) -> list[FunctionTool]:
if env("REDDIT_CLIENT_ID") and env("REDDIT_CLIENT_SECRET") and env("REDDIT_USER_AGENT"):

View file

@ -3,10 +3,11 @@ from camel.toolkits import ScreenshotToolkit as BaseScreenshotToolkit
from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
@auto_listen_toolkit(BaseScreenshotToolkit)
class ScreenshotToolkit(BaseScreenshotToolkit, AbstractToolkit):
agent_name: str = Agents.developer_agent
@ -15,13 +16,3 @@ class ScreenshotToolkit(BaseScreenshotToolkit, AbstractToolkit):
if working_directory is None:
working_directory = env("file_save_path", os.path.expanduser("~/Downloads"))
super().__init__(working_directory, timeout)
@listen_toolkit(BaseScreenshotToolkit.take_screenshot_and_read_image)
def take_screenshot_and_read_image(
self, filename: str, save_to_file: bool = True, read_image: bool = True, instruction: str | None = None
) -> str:
return super().take_screenshot_and_read_image(filename, save_to_file, read_image, instruction)
@listen_toolkit(BaseScreenshotToolkit.read_image)
def read_image(self, image_path: str, instruction: str = "") -> str:
return super().read_image(image_path, instruction)

View file

@ -2,13 +2,17 @@ from typing import Any, Dict, List, Literal
from camel.toolkits import SearchToolkit as BaseSearchToolkit
from camel.toolkits.function_tool import FunctionTool
import httpx
from loguru import logger
import os
from app.component.environment import env, env_not_empty
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit, listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("search_toolkit")
@auto_listen_toolkit(BaseSearchToolkit)
class SearchToolkit(BaseSearchToolkit, AbstractToolkit):
agent_name: str = Agents.search_agent
@ -25,6 +29,32 @@ class SearchToolkit(BaseSearchToolkit, AbstractToolkit):
super().__init__(
timeout=timeout, exclude_domains=exclude_domains
)
# Cache for user-specific search configurations
self._user_google_api_key = None
self._user_search_engine_id = None
self._config_loaded = False
def _load_user_search_config(self):
"""
Load user-specific Google Search configuration from user's .env file.
This is called lazily when search_google is invoked.
"""
if self._config_loaded:
return
self._config_loaded = True
# Try to get user-specific configuration from thread-local environment
# which is set by the middleware based on the user's project settings
google_api_key = env("GOOGLE_API_KEY")
search_engine_id = env("SEARCH_ENGINE_ID")
if google_api_key and search_engine_id:
self._user_google_api_key = google_api_key
self._user_search_engine_id = search_engine_id
logger.info("Loaded user-specific Google Search configuration")
else:
logger.debug("No user-specific Google Search configuration found, will use cloud search")
# @listen_toolkit(BaseSearchToolkit.search_wiki)
# def search_wiki(self, entity: str) -> str:
@ -50,19 +80,61 @@ class SearchToolkit(BaseSearchToolkit, AbstractToolkit):
@listen_toolkit(
BaseSearchToolkit.search_google,
lambda _, query, search_type="web": f"with query '{query}' and {search_type} result pages",
lambda _, query, search_type="web", number_of_result_pages=10, start_page=1: f"with query '{query}', {search_type} type, {number_of_result_pages} result pages starting from page {start_page}",
)
def search_google(self, query: str, search_type: str = "web") -> list[dict[str, Any]]:
if env("GOOGLE_API_KEY") and env("SEARCH_ENGINE_ID"):
return super().search_google(query, search_type)
else:
return self.cloud_search_google(query, search_type)
def search_google(
self,
query: str,
search_type: str = "web",
number_of_result_pages: int = 10,
start_page: int = 1
) -> list[dict[str, Any]]:
# Load user-specific configuration
self._load_user_search_config()
def cloud_search_google(self, query: str, search_type):
# If user has configured their own Google API keys, use them
if self._user_google_api_key and self._user_search_engine_id:
logger.info("Using user-configured Google Search API")
# Temporarily set environment variables for this search
old_google_key = os.environ.get("GOOGLE_API_KEY")
old_search_id = os.environ.get("SEARCH_ENGINE_ID")
try:
os.environ["GOOGLE_API_KEY"] = self._user_google_api_key
os.environ["SEARCH_ENGINE_ID"] = self._user_search_engine_id
return super().search_google(query, search_type, number_of_result_pages, start_page)
finally:
# Restore original environment variables
if old_google_key is not None:
os.environ["GOOGLE_API_KEY"] = old_google_key
elif "GOOGLE_API_KEY" in os.environ:
del os.environ["GOOGLE_API_KEY"]
if old_search_id is not None:
os.environ["SEARCH_ENGINE_ID"] = old_search_id
elif "SEARCH_ENGINE_ID" in os.environ:
del os.environ["SEARCH_ENGINE_ID"]
else:
# Fallback to cloud search
logger.info("Using cloud Google Search (no user configuration found)")
return self.cloud_search_google(query, search_type, number_of_result_pages, start_page)
def cloud_search_google(
self,
query: str,
search_type: str = "web",
number_of_result_pages: int = 10,
start_page: int = 1
):
url = env_not_empty("SERVER_URL")
res = httpx.get(
url + "/proxy/google",
params={"query": query, "search_type": search_type},
params={
"query": query,
"search_type": search_type,
"number_of_result_pages": number_of_result_pages,
"start_page": start_page
},
headers={"api-key": env_not_empty("cloud_api_key")},
)
return res.json()
@ -163,73 +235,73 @@ class SearchToolkit(BaseSearchToolkit, AbstractToolkit):
# def search_bing(self, query: str) -> dict[str, Any]:
# return super().search_bing(query)
@listen_toolkit(BaseSearchToolkit.search_exa, lambda _, query, *args, **kwargs: f"{query}, {args}, {kwargs}")
def search_exa(
self,
query: str,
search_type: Literal["auto", "neural", "keyword"] = "auto",
category: None
| Literal[
"company",
"research paper",
"news",
"pdf",
"github",
"tweet",
"personal site",
"linkedin profile",
"financial report",
] = None,
include_text: List[str] | None = None,
exclude_text: List[str] | None = None,
use_autoprompt: bool = True,
text: bool = False,
) -> Dict[str, Any]:
if env("EXA_API_KEY"):
res = super().search_exa(query, search_type, category, include_text, exclude_text, use_autoprompt, text)
return res
else:
return self.cloud_search_exa(query, search_type, category, include_text, exclude_text, use_autoprompt, text)
def cloud_search_exa(
self,
query: str,
search_type: Literal["auto", "neural", "keyword"] = "auto",
category: None
| Literal[
"company",
"research paper",
"news",
"pdf",
"github",
"tweet",
"personal site",
"linkedin profile",
"financial report",
] = None,
include_text: List[str] | None = None,
exclude_text: List[str] | None = None,
use_autoprompt: bool = True,
text: bool = False,
):
url = env_not_empty("SERVER_URL")
logger.debug(f">>>>>>>>>>>>>>>>{url}<<<<")
res = httpx.post(
url + "/proxy/exa",
json={
"query": query,
"search_type": search_type,
"category": category,
"include_text": include_text,
"exclude_text": exclude_text,
"use_autoprompt": use_autoprompt,
"text": text,
},
headers={"api-key": env_not_empty("cloud_api_key")},
)
logger.debug(">>>>>>>>>>>>>>>>>")
logger.debug(res)
return res.json()
# @listen_toolkit(BaseSearchToolkit.search_exa, lambda _, query, *args, **kwargs: f"{query}, {args}, {kwargs}")
# def search_exa(
# self,
# query: str,
# search_type: Literal["auto", "neural", "keyword"] = "auto",
# category: None
# | Literal[
# "company",
# "research paper",
# "news",
# "pdf",
# "github",
# "tweet",
# "personal site",
# "linkedin profile",
# "financial report",
# ] = None,
# include_text: List[str] | None = None,
# exclude_text: List[str] | None = None,
# use_autoprompt: bool = True,
# text: bool = False,
# ) -> Dict[str, Any]:
# if env("EXA_API_KEY"):
# res = super().search_exa(query, search_type, category, include_text, exclude_text, use_autoprompt, text)
# return res
# else:
# return self.cloud_search_exa(query, search_type, category, include_text, exclude_text, use_autoprompt, text)
#
# def cloud_search_exa(
# self,
# query: str,
# search_type: Literal["auto", "neural", "keyword"] = "auto",
# category: None
# | Literal[
# "company",
# "research paper",
# "news",
# "pdf",
# "github",
# "tweet",
# "personal site",
# "linkedin profile",
# "financial report",
# ] = None,
# include_text: List[str] | None = None,
# exclude_text: List[str] | None = None,
# use_autoprompt: bool = True,
# text: bool = False,
# ):
# url = env_not_empty("SERVER_URL")
# logger.debug(f">>>>>>>>>>>>>>>>{url}<<<<")
# res = httpx.post(
# url + "/proxy/exa",
# json={
# "query": query,
# "search_type": search_type,
# "category": category,
# "include_text": include_text,
# "exclude_text": exclude_text,
# "use_autoprompt": use_autoprompt,
# "text": text,
# },
# headers={"api-key": env_not_empty("cloud_api_key")},
# )
# logger.debug(">>>>>>>>>>>>>>>>>")
# logger.debug(res)
# return res.json()
# @listen_toolkit(
# BaseSearchToolkit.search_alibaba_tongxiao,
@ -289,12 +361,12 @@ class SearchToolkit(BaseSearchToolkit, AbstractToolkit):
# if env("BOCHA_API_KEY"):
# tools.append(FunctionTool(search_toolkit.search_bocha))
if env("EXA_API_KEY") or env("cloud_api_key"):
tools.append(FunctionTool(search_toolkit.search_exa))
# if env("EXA_API_KEY") or env("cloud_api_key"):
# tools.append(FunctionTool(search_toolkit.search_exa))
# if env("TONGXIAO_API_KEY"):
# tools.append(FunctionTool(search_toolkit.search_alibaba_tongxiao))
return tools
def get_tools(self) -> List[FunctionTool]:
return [FunctionTool(self.search_exa)]
# def get_tools(self) -> List[FunctionTool]:
# return [FunctionTool(self.search_exa)]

View file

@ -1,12 +1,15 @@
from camel.toolkits import SlackToolkit as BaseSlackToolkit
from camel.toolkits.function_tool import FunctionTool
from loguru import logger
from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("slack_toolkit")
@auto_listen_toolkit(BaseSlackToolkit)
class SlackToolkit(BaseSlackToolkit, AbstractToolkit):
agent_name: str = Agents.social_medium_agent
@ -14,71 +17,6 @@ class SlackToolkit(BaseSlackToolkit, AbstractToolkit):
super().__init__(timeout)
self.api_task_id = api_task_id
@listen_toolkit(
BaseSlackToolkit.create_slack_channel,
lambda _, name, is_private=True: f"create a Slack channel with name: {name} and is_private: {is_private}",
)
def create_slack_channel(self, name: str, is_private: bool | None = True) -> str:
return super().create_slack_channel(name, is_private)
@listen_toolkit(
BaseSlackToolkit.join_slack_channel,
lambda _, channel_id: f"join Slack channel with id: {channel_id}",
)
def join_slack_channel(self, channel_id: str) -> str:
return super().join_slack_channel(channel_id)
@listen_toolkit(
BaseSlackToolkit.leave_slack_channel,
lambda _, channel_id: f"leave Slack channel with id: {channel_id}",
)
def leave_slack_channel(self, channel_id: str) -> str:
return super().leave_slack_channel(channel_id)
@listen_toolkit(
BaseSlackToolkit.get_slack_channel_information,
lambda _: "get Slack channel information",
)
def get_slack_channel_information(self) -> str:
return super().get_slack_channel_information()
@listen_toolkit(
BaseSlackToolkit.get_slack_channel_message,
lambda _, channel_id: f"get Slack channel message for channel id: {channel_id}",
)
def get_slack_channel_message(self, channel_id: str) -> str:
return super().get_slack_channel_message(channel_id)
@listen_toolkit(
BaseSlackToolkit.send_slack_message,
lambda _, message, channel_id, file_path=None, user=None: f"send Slack message: {message} to channel id: {channel_id}, file: {file_path}, user: {user}",
)
def send_slack_message(self, message: str, channel_id: str, file_path: str | None = None, user: str | None = None) -> str:
return super().send_slack_message(message, channel_id, file_path, user)
@listen_toolkit(
BaseSlackToolkit.delete_slack_message,
lambda _,
time_stamp,
channel_id: f"delete Slack message with timestamp: {time_stamp} in channel id: {channel_id}",
)
def delete_slack_message(self, time_stamp: str, channel_id: str) -> str:
return super().delete_slack_message(time_stamp, channel_id)
@listen_toolkit(
BaseSlackToolkit.get_slack_user_list,
lambda _: "get Slack user list",
)
def get_slack_user_list(self) -> str:
return super().get_slack_user_list()
@listen_toolkit(
BaseSlackToolkit.get_slack_user_info,
lambda _, user_id: f"get Slack user info with user id: {user_id}",
)
def get_slack_user_info(self, user_id: str) -> str:
return super().get_slack_user_info(user_id)
@classmethod
def get_can_use_tools(cls, api_task_id: str) -> list[FunctionTool]:
logger.debug(f"slack===={env('SLACK_BOT_TOKEN')}")

View file

@ -1,16 +1,23 @@
import asyncio
import logging
import os
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from camel.toolkits.terminal_toolkit import TerminalToolkit as BaseTerminalToolkit
from camel.toolkits.terminal_toolkit.terminal_toolkit import _to_plain
from app.component.environment import env
from app.service.task import Action, ActionTerminalData, Agents, get_task_lock
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
from app.service.task import process_task
@auto_listen_toolkit(BaseTerminalToolkit)
class TerminalToolkit(BaseTerminalToolkit, AbstractToolkit):
agent_name: str = Agents.developer_agent
_thread_pool: Optional[ThreadPoolExecutor] = None
_thread_local = threading.local()
def __init__(
self,
@ -30,6 +37,11 @@ class TerminalToolkit(BaseTerminalToolkit, AbstractToolkit):
self.agent_name = agent_name
if working_directory is None:
working_directory = env("file_save_path", os.path.expanduser("~/.eigent/terminal/"))
if TerminalToolkit._thread_pool is None:
TerminalToolkit._thread_pool = ThreadPoolExecutor(
max_workers=1,
thread_name_prefix="terminal_toolkit"
)
super().__init__(
timeout=timeout,
working_directory=working_directory,
@ -54,58 +66,57 @@ class TerminalToolkit(BaseTerminalToolkit, AbstractToolkit):
def _update_terminal_output(self, output: str):
task_lock = get_task_lock(self.api_task_id)
# This method will be called during init. At that time, the process_task_id parameter does not exist, so it is set to be empty default
process_task_id = process_task.get("")
task = asyncio.create_task(
task_lock.put_queue(
ActionTerminalData(
action=Action.terminal,
process_task_id=process_task_id,
data=output,
)
# Create the coroutine
coro = task_lock.put_queue(
ActionTerminalData(
action=Action.terminal,
process_task_id=process_task_id,
data=output,
)
)
if hasattr(task_lock, "add_background_task"):
task_lock.add_background_task(task)
@listen_toolkit(
BaseTerminalToolkit.shell_exec,
lambda _, id, command, block=True: f"id: {id}, command: {command}, block: {block}",
)
def shell_exec(self, id: str, command: str, block: bool = True) -> str:
return super().shell_exec(id=id, command=command, block=block)
# Try to get the current event loop, if none exists, create a new one in a thread
try:
loop = asyncio.get_running_loop()
# If we're in an async context, schedule the coroutine
task = loop.create_task(coro)
if hasattr(task_lock, "add_background_task"):
task_lock.add_background_task(task)
except RuntimeError:
self._thread_pool.submit(self._run_coro_in_thread, coro,task_lock)
@listen_toolkit(
BaseTerminalToolkit.shell_view,
lambda _, id: f"id: {id}",
)
def shell_view(self, id: str) -> str:
return super().shell_view(id)
@staticmethod
def _run_coro_in_thread(coro,task_lock):
"""
Execute coro in the thread pool, with each thread bound to a long-term event loop
"""
if not hasattr(TerminalToolkit._thread_local, "loop"):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
TerminalToolkit._thread_local.loop = loop
else:
loop = TerminalToolkit._thread_local.loop
@listen_toolkit(
BaseTerminalToolkit.shell_wait,
lambda _, id, wait_seconds=None: f"id: {id}, wait_seconds: {wait_seconds}",
)
def shell_wait(self, id: str, wait_seconds: float = 5.0) -> str:
return super().shell_wait(id=id, wait_seconds=wait_seconds)
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
TerminalToolkit._thread_local.loop = loop
@listen_toolkit(
BaseTerminalToolkit.shell_write_to_process,
lambda _, id, command: f"id: {id}, command: {command}",
)
def shell_write_to_process(self, id: str, command: str) -> str:
return super().shell_write_to_process(id=id, command=command)
try:
task = loop.create_task(coro)
if hasattr(task_lock, "add_background_task"):
task_lock.add_background_task(task)
loop.run_until_complete(task)
except Exception as e:
logging.error(
f"Failed to execute coroutine in thread pool: {str(e)}",
exc_info=True
)
@listen_toolkit(
BaseTerminalToolkit.shell_kill_process,
lambda _, id: f"id: {id}",
)
def shell_kill_process(self, id: str) -> str:
return super().shell_kill_process(id=id)
@listen_toolkit(
BaseTerminalToolkit.shell_ask_user_for_help,
lambda _, id, prompt: f"id: {id}, prompt: {prompt}",
)
def shell_ask_user_for_help(self, id: str, prompt: str) -> str:
return super().shell_ask_user_for_help(id=id, prompt=prompt)
@classmethod
def shutdown(cls):
if cls._thread_pool:
cls._thread_pool.shutdown(wait=True)
cls._thread_pool = None

View file

@ -1,40 +1,13 @@
from camel.toolkits import ThinkingToolkit as BaseThinkingToolkit
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
@auto_listen_toolkit(BaseThinkingToolkit)
class ThinkingToolkit(BaseThinkingToolkit, AbstractToolkit):
def __init__(self, api_task_id: str, agent_name: str, timeout: float | None = None):
super().__init__(timeout)
self.api_task_id = api_task_id
self.agent_name = agent_name
@listen_toolkit(BaseThinkingToolkit.plan)
def plan(self, plan: str) -> str:
return super().plan(plan)
@listen_toolkit(BaseThinkingToolkit.hypothesize)
def hypothesize(self, hypothesis: str) -> str:
return super().hypothesize(hypothesis)
@listen_toolkit(BaseThinkingToolkit.think)
def think(self, thought: str) -> str:
return super().think(thought)
@listen_toolkit(BaseThinkingToolkit.contemplate)
def contemplate(self, contemplation: str) -> str:
return super().contemplate(contemplation)
@listen_toolkit(BaseThinkingToolkit.critique)
def critique(self, critique: str) -> str:
return super().critique(critique)
@listen_toolkit(BaseThinkingToolkit.synthesize)
def synthesize(self, synthesis: str) -> str:
return super().synthesize(synthesis)
@listen_toolkit(BaseThinkingToolkit.reflect)
def reflect(self, reflection: str) -> str:
return super().reflect(reflection)

View file

@ -9,10 +9,11 @@ from camel.toolkits.twitter_toolkit import (
from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit, listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
@auto_listen_toolkit(BaseTwitterToolkit)
class TwitterToolkit(BaseTwitterToolkit, AbstractToolkit):
agent_name: str = Agents.social_medium_agent

View file

@ -4,10 +4,11 @@ from camel.toolkits import VideoAnalysisToolkit as BaseVideoAnalysisToolkit
from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
@auto_listen_toolkit(BaseVideoAnalysisToolkit)
class VideoAnalysisToolkit(BaseVideoAnalysisToolkit, AbstractToolkit):
agent_name: str = Agents.multi_modal_agent
@ -36,10 +37,3 @@ class VideoAnalysisToolkit(BaseVideoAnalysisToolkit, AbstractToolkit):
cookies_path,
timeout,
)
@listen_toolkit(
BaseVideoAnalysisToolkit.ask_question_about_video,
lambda _, video_path, question: f"transcribe video from {video_path} and ask question: {question}",
)
def ask_question_about_video(self, video_path: str, question: str) -> str:
return super().ask_question_about_video(video_path, question)

View file

@ -5,10 +5,11 @@ from camel.toolkits import VideoDownloaderToolkit as BaseVideoDownloaderToolkit
from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
@auto_listen_toolkit(BaseVideoDownloaderToolkit)
class VideoDownloaderToolkit(BaseVideoDownloaderToolkit, AbstractToolkit):
agent_name: str = Agents.multi_modal_agent
@ -23,23 +24,3 @@ class VideoDownloaderToolkit(BaseVideoDownloaderToolkit, AbstractToolkit):
working_directory = env("file_save_path", os.path.expanduser("~/Downloads"))
super().__init__(working_directory, cookies_path, timeout)
self.api_task_id = api_task_id
@listen_toolkit(BaseVideoDownloaderToolkit.download_video)
def download_video(self, url: str) -> str:
return super().download_video(url)
@listen_toolkit(
BaseVideoDownloaderToolkit.get_video_bytes,
lambda _, video_path: f"get video bytes from {video_path}",
lambda _: "get video bytes",
)
def get_video_bytes(self, video_path: str) -> bytes:
return super().get_video_bytes(video_path)
@listen_toolkit(
BaseVideoDownloaderToolkit.get_video_screenshots,
lambda _, video_path, amount: f"get video screenshots from {video_path}, amount: {amount}",
lambda results: f"get video screenshots {len(results)}",
)
def get_video_screenshots(self, video_path: str, amount: int) -> List[Image]:
return super().get_video_screenshots(video_path, amount)

View file

@ -3,10 +3,11 @@ from typing import Any, Dict
from camel.toolkits import WebDeployToolkit as BaseWebDeployToolkit
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit, listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
@auto_listen_toolkit(BaseWebDeployToolkit)
class WebDeployToolkit(BaseWebDeployToolkit, AbstractToolkit):
agent_name: str = Agents.developer_agent
@ -43,11 +44,3 @@ class WebDeployToolkit(BaseWebDeployToolkit, AbstractToolkit):
) -> Dict[str, Any]:
subdirectory = str(uuid.uuid4())
return super().deploy_folder(folder_path, port, domain, subdirectory)
@listen_toolkit(BaseWebDeployToolkit.stop_server)
def stop_server(self, port: int) -> Dict[str, Any]:
return super().stop_server(port)
@listen_toolkit(BaseWebDeployToolkit.list_running_servers)
def list_running_servers(self) -> Dict[str, Any]:
return super().list_running_servers()

View file

@ -3,10 +3,11 @@ from camel.toolkits import WhatsAppToolkit as BaseWhatsAppToolkit
from camel.toolkits.function_tool import FunctionTool
from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
@auto_listen_toolkit(BaseWhatsAppToolkit)
class WhatsAppToolkit(BaseWhatsAppToolkit, AbstractToolkit):
agent_name: str = Agents.social_medium_agent
@ -14,30 +15,6 @@ class WhatsAppToolkit(BaseWhatsAppToolkit, AbstractToolkit):
super().__init__(timeout)
self.api_task_id = api_task_id
@listen_toolkit(
BaseWhatsAppToolkit.send_message,
lambda _, to, message: f"send message to {to}: {message}",
lambda result: f"message sent result: {result}",
)
def send_message(self, to: str, message: str) -> Dict[str, Any] | str:
return super().send_message(to, message)
@listen_toolkit(
BaseWhatsAppToolkit.get_message_templates,
lambda _: "get message templates",
lambda result: f"message templates: {result}",
)
def get_message_templates(self) -> List[Dict[str, Any]] | str:
return super().get_message_templates()
@listen_toolkit(
BaseWhatsAppToolkit.get_business_profile,
lambda _: "get business profile",
lambda result: f"business profile: {result}",
)
def get_business_profile(self) -> Dict[str, Any] | str:
return super().get_business_profile()
@classmethod
def get_can_use_tools(cls, api_task_id: str) -> list[FunctionTool]:
if env("WHATSAPP_ACCESS_TOKEN") and env("WHATSAPP_PHONE_NUMBER_ID"):

View file

@ -1,35 +0,0 @@
"""Conditional traceroot wrapper - only loads if .traceroot-config.yaml exists."""
from pathlib import Path
from typing import Callable
def _find_config() -> bool:
"""Check if .traceroot-config.yaml exists in current or parent directories."""
path = Path.cwd()
for _ in range(5):
if (path / ".traceroot-config.yaml").exists():
return True
if path == path.parent:
break
path = path.parent
return False
# Load traceroot only if config exists
if _find_config():
import traceroot
trace = traceroot.trace
get_logger = traceroot.get_logger
else:
# No-op implementations
def trace():
def decorator(func: Callable) -> Callable:
return func
return decorator
class _NoOpLogger:
def __getattr__(self, name):
return lambda *args, **kwargs: None
def get_logger(name: str):
return _NoOpLogger()

View file

@ -9,7 +9,6 @@ from camel.societies.workforce.workforce import (
from camel.societies.workforce.task_channel import TaskChannel
from camel.societies.workforce.base import BaseNode
from camel.societies.workforce.utils import TaskAssignResult
from loguru import logger
from camel.tasks.task import Task, TaskState, validate_task_content
from app.component import code
from app.exception.exception import UserException
@ -18,30 +17,16 @@ from app.service.task import (
Action,
ActionAssignTaskData,
ActionEndData,
ActionNewTaskStateData,
ActionTaskStateData,
get_camel_task,
get_task_lock,
)
from app.utils.single_agent_worker import SingleAgentWorker
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("workforce")
# === Debug sink === Write detailed dependency debug logs to file (logs/workforce_debug.log)
# Create a new file every day, keep the logs for the last 7 days, and write asynchronously without blocking the main process
logger.add(
"logs/workforce_debug_{time:YYYY-MM-DD}.log",
rotation="00:00",
retention="7 days",
enqueue=True,
level="DEBUG",
)
# Independent sink: only collect the "[WF]" debug lines we insert to quickly view the dependency chain
logger.add(
"logs/wf_trace_{time:YYYY-MM-DD-HH}.log",
rotation="00:00",
retention="7 days",
enqueue=True,
level="DEBUG",
filter=lambda record: record["message"].startswith("[WF]"),
)
class Workforce(BaseWorkforce):
@ -69,8 +54,15 @@ class Workforce(BaseWorkforce):
use_structured_output_handler=use_structured_output_handler,
)
def eigent_make_sub_tasks(self, task: Task):
"""split process_task method to eigent_make_sub_tasks and eigent_start method"""
def eigent_make_sub_tasks(self, task: Task, coordinator_context: str = ""):
"""
Split process_task method to eigent_make_sub_tasks and eigent_start method.
Args:
task: The main task to decompose
coordinator_context: Optional context ONLY for coordinator agent during decomposition.
This context will NOT be passed to subtasks or worker agents.
"""
if not validate_task_content(task.content, task.id):
task.state = TaskState.FAILED
@ -85,10 +77,20 @@ class Workforce(BaseWorkforce):
self.set_channel(TaskChannel())
self._state = WorkforceState.RUNNING
task.state = TaskState.OPEN
self._pending_tasks.append(task)
# Decompose the task into subtasks first
subtasks_result = self._decompose_task(task)
if coordinator_context:
original_content = task.content
task_with_context = coordinator_context
if coordinator_context:
task_with_context += "\n=== CURRENT TASK ===\n"
task_with_context += original_content
task.content = task_with_context
subtasks_result = self._decompose_task(task)
task.content = original_content
else:
subtasks_result = self._decompose_task(task)
# Handle both streaming and non-streaming results
if isinstance(subtasks_result, Generator):
@ -119,6 +121,64 @@ class Workforce(BaseWorkforce):
if self._state != WorkforceState.STOPPED:
self._state = WorkforceState.IDLE
async def handle_decompose_append_task(
self, task: Task, reset: bool = True, coordinator_context: str = ""
) -> List[Task]:
"""
Override to support coordinator_context parameter.
Handle task decomposition and validation, then append to pending tasks.
Args:
task: The task to be processed
reset: Should trigger workforce reset (Workforce must not be running)
coordinator_context: Optional context ONLY for coordinator during decomposition
Returns:
List[Task]: The decomposed subtasks or the original task
"""
if not validate_task_content(task.content, task.id):
task.state = TaskState.FAILED
task.result = "Task failed: Invalid or empty content provided"
logger.warning(
f"Task {task.id} rejected: Invalid or empty content. "
f"Content preview: '{task.content}'"
)
return [task]
if reset and self._state != WorkforceState.RUNNING:
self.reset()
logger.info("Workforce reset before handling task.")
self._task = task
task.state = TaskState.FAILED
if coordinator_context:
original_content = task.content
task_with_context = coordinator_context
if coordinator_context:
task_with_context += "\n=== CURRENT TASK ===\n"
task_with_context += original_content
task.content = task_with_context
subtasks_result = self._decompose_task(task)
task.content = original_content
else:
subtasks_result = self._decompose_task(task)
if isinstance(subtasks_result, Generator):
subtasks = []
for new_tasks in subtasks_result:
subtasks.extend(new_tasks)
else:
subtasks = subtasks_result
if subtasks:
self._pending_tasks.extendleft(reversed(subtasks))
logger.info(f"Appended {len(subtasks)} subtasks to pending tasks")
return subtasks if subtasks else [task]
async def _find_assignee(self, tasks: List[Task]) -> TaskAssignResult:
# Task assignment phase: send "waiting for execution" notification to the frontend, and send "start execution" notification when the task actually begins execution
assigned = await super()._find_assignee(tasks)
@ -133,7 +193,9 @@ class Workforce(BaseWorkforce):
# Find task content
task_obj = get_camel_task(item.task_id, tasks)
if task_obj is None:
logger.warning(f"[WF] WARN: Task {item.task_id} not found in tasks list during ASSIGN phase. This may indicate a task tree inconsistency.")
logger.warning(
f"[WF] WARN: Task {item.task_id} not found in tasks list during ASSIGN phase. This may indicate a task tree inconsistency."
)
content = ""
else:
content = task_obj.content
@ -179,7 +241,11 @@ class Workforce(BaseWorkforce):
await super()._post_task(task, assignee_id)
def add_single_agent_worker(
self, description: str, worker: ListenChatAgent, pool_max_size: int = DEFAULT_WORKER_POOL_SIZE
self,
description: str,
worker: ListenChatAgent,
pool_max_size: int = DEFAULT_WORKER_POOL_SIZE,
enable_workflow_memory: bool = False,
) -> BaseWorkforce:
if self._state == WorkforceState.RUNNING:
raise RuntimeError("Cannot add workers while workforce is running. Pause the workforce first.")
@ -195,6 +261,8 @@ class Workforce(BaseWorkforce):
worker=worker,
pool_max_size=pool_max_size,
use_structured_output_handler=self.use_structured_output_handler,
context_utility=None, # Will be set during save/load operations
enable_workflow_memory=enable_workflow_memory,
)
self._children.append(worker_node)
@ -218,17 +286,33 @@ class Workforce(BaseWorkforce):
logger.debug(f"[WF] DONE {task.id}")
task_lock = get_task_lock(self.api_task_id)
await task_lock.put_queue(
ActionTaskStateData(
data={
"task_id": task.id,
"content": task.content,
"state": task.state,
"result": task.result or "",
"failure_count": task.failure_count,
},
# Log task completion with result details
is_main_task = self._task and task.id == self._task.id
task_type = "MAIN TASK" if is_main_task else "SUB-TASK"
logger.info(f"[TASK-RESULT] {task_type} COMPLETED: {task.id}")
logger.info(f"[TASK-RESULT] Content: {task.content[:200]}..." if len(task.content) > 200 else f"[TASK-RESULT] Content: {task.content}")
logger.info(f"[TASK-RESULT] Result: {task.result[:500]}..." if task.result and len(str(task.result)) > 500 else f"[TASK-RESULT] Result: {task.result}")
task_data = {
"task_id": task.id,
"content": task.content,
"state": task.state,
"result": task.result or "",
"failure_count": task.failure_count,
}
if self._task_is_new(task_data):
await task_lock.put_queue(
ActionNewTaskStateData(
data=task_data
)
)
else:
await task_lock.put_queue(
ActionTaskStateData(
data=task_data
)
)
)
return await super()._handle_completed_task(task)
@ -260,6 +344,36 @@ class Workforce(BaseWorkforce):
return result
def _task_is_new(self, item:dict) -> bool:
# Validate the task state data object first
assert isinstance(item, dict)
task_id = item.get("task_id", "")
state = item.get("state", "")
result = item.get("result", "")
failure_count = item.get("failure_count", 0)
# Validate required fields
if not task_id:
logger.error("Missing task_id in task_state data")
return False
elif not state:
logger.error(f"Missing state in task_state data for task {task_id}")
return False
# Ensure failure_count is an integer
try:
failure_count = int(failure_count)
except (ValueError, TypeError):
logger.error(f"Invalid failure_count in task_state data for task {task_id}: {failure_count}")
failure_count = 0 # Default to 0 if invalid
should_send_new_task_state = (
state == "FAILED" or
(failure_count == 0 and result.strip() == "")
)
return should_send_new_task_state
def stop(self) -> None:
super().stop()
task_lock = get_task_lock(self.api_task_id)

View file

@ -1,37 +1,58 @@
import os
import sys
import pathlib
import signal
import asyncio
import atexit
# Add project root to Python path to import shared utils
_project_root = pathlib.Path(__file__).parent.parent
if str(_project_root) not in sys.path:
sys.path.insert(0, str(_project_root))
# 1) Load env and init traceroot BEFORE importing modules that get a logger
from utils import traceroot_wrapper as traceroot
from app import api
from loguru import logger
from app.component.environment import auto_include_routers, env
# Only initialize traceroot if enabled
if traceroot.is_enabled():
from traceroot.integrations.fastapi import connect_fastapi
connect_fastapi(api)
# 2) Now safe to import modules that use traceroot.get_logger() at import-time
from app.component.environment import env
from app.router import register_routers
os.environ["PYTHONIOENCODING"] = "utf-8"
app_logger = traceroot.get_logger("main")
# Log application startup
logger.info("Starting Eigent Multi-Agent System API")
logger.info(f"Python encoding: {os.environ.get('PYTHONIOENCODING')}")
logger.info(f"Environment: {os.environ.get('ENVIRONMENT', 'development')}")
app_logger.info("Starting Eigent Multi-Agent System API")
app_logger.info(f"Python encoding: {os.environ.get('PYTHONIOENCODING')}")
app_logger.info(f"Environment: {os.environ.get('ENVIRONMENT', 'development')}")
prefix = env("url_prefix", "")
logger.info(f"Loading routers with prefix: '{prefix}'")
auto_include_routers(api, prefix, "app/controller")
logger.info("All routers loaded successfully")
app_logger.info(f"Loading routers with prefix: '{prefix}'")
register_routers(api, prefix)
app_logger.info("All routers loaded successfully")
# Check if debug mode is enabled via environment variable
if os.environ.get('ENABLE_PYTHON_DEBUG') == 'true':
try:
import debugpy
DEBUG_PORT = int(os.environ.get('DEBUG_PORT', '5678'))
app_logger.info(f"Debug mode enabled - Starting debugpy server on port {DEBUG_PORT}")
debugpy.listen(("localhost", DEBUG_PORT))
app_logger.info(f"Debugger ready for attachment on localhost:{DEBUG_PORT}")
#📝 In VS Code: Run 'Debug Python Backend (Attach)' configuration
# Don't wait for client automatically - let it attach when ready
except ImportError:
app_logger.warning("debugpy not available, install with: uv add debugpy")
except Exception as e:
app_logger.error(f"Failed to start debugpy: {e}")
# Configure Loguru
log_path = os.path.expanduser("~/.eigent/runtime/log/app.log")
os.makedirs(os.path.dirname(log_path), exist_ok=True)
logger.add(
log_path, # Log file
rotation="10 MB", # Log rotation: 10MB per file
retention="10 days", # Retain logs for the last 10 days
level="DEBUG", # Log level
encoding="utf-8",
)
logger.info(f"Loguru configured with log file: {log_path}")
dir = pathlib.Path(__file__).parent / "runtime"
dir.mkdir(parents=True, exist_ok=True)
@ -44,12 +65,12 @@ async def write_pid_file():
async with aiofiles.open(dir / "run.pid", "w") as f:
await f.write(str(os.getpid()))
logger.info(f"PID file written: {os.getpid()}")
app_logger.info(f"PID file written: {os.getpid()}")
# Create task to write PID
pid_task = asyncio.create_task(write_pid_file())
logger.info("PID write task created")
app_logger.info("PID write task created")
# Graceful shutdown handler
shutdown_event = asyncio.Event()
@ -57,8 +78,7 @@ shutdown_event = asyncio.Event()
async def cleanup_resources():
r"""Cleanup all resources on shutdown"""
logger.info("Starting graceful shutdown...")
logger.info("Starting graceful shutdown process")
app_logger.info("Starting graceful shutdown process")
from app.service.task import task_locks, _cleanup_task
@ -75,21 +95,19 @@ async def cleanup_resources():
task_lock = task_locks[task_id]
await task_lock.cleanup()
except Exception as e:
logger.error(f"Error cleaning up task {task_id}: {e}")
app_logger.error(f"Error cleaning up task {task_id}: {e}")
# Remove PID file
pid_file = dir / "run.pid"
if pid_file.exists():
pid_file.unlink()
logger.info("Graceful shutdown completed")
logger.info("All resources cleaned up successfully")
app_logger.info("All resources cleaned up successfully")
def signal_handler(signum, frame):
r"""Handle shutdown signals"""
logger.info(f"Received signal {signum}")
logger.warning(f"Received shutdown signal: {signum}")
app_logger.warning(f"Received shutdown signal: {signum}")
asyncio.create_task(cleanup_resources())
shutdown_event.set()
@ -97,8 +115,19 @@ def signal_handler(signum, frame):
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
# Register cleanup on exit
atexit.register(lambda: asyncio.run(cleanup_resources()))
# Register cleanup on exit with safe synchronous wrapper
def sync_cleanup():
"""Synchronous cleanup for atexit - handles PID file removal"""
try:
# Only perform synchronous cleanup tasks
pid_file = dir / "run.pid"
if pid_file.exists():
pid_file.unlink()
app_logger.info("PID file removed during shutdown")
except Exception as e:
app_logger.error(f"Error during atexit cleanup: {e}")
atexit.register(sync_cleanup)
# Log successful initialization
logger.info("Application initialization completed successfully")
app_logger.info("Application initialization completed successfully")

View file

@ -5,21 +5,21 @@ description = "Add your description here"
readme = "README.md"
requires-python = "==3.10.16"
dependencies = [
"camel-ai[eigent]==0.2.76a13",
"camel-ai[eigent]==0.2.78",
"fastapi>=0.115.12",
"fastapi-babel>=1.0.0",
"uvicorn[standard]>=0.34.2",
"pydantic-i18n>=0.4.5",
"python-dotenv>=1.1.0",
"httpx[socks]>=0.28.1",
"loguru>=0.7.3",
"pydash>=8.0.5",
"inflection>=0.5.1",
"aiofiles>=24.1.0",
"openai>=1.99.3,<2",
"traceroot>=0.0.5a2",
"traceroot>=0.0.7",
"nodejs-wheel>=22.18.0",
"numpy>=1.23.0,<2.0.0",
"debugpy>=1.8.17",
]

View file

@ -145,14 +145,21 @@ def mock_model_backend():
@pytest.fixture
def mock_camel_agent():
"""Mock CAMEL agent for testing."""
agent = AsyncMock()
agent = MagicMock() # Use MagicMock instead of AsyncMock
agent.role_name = "test_agent"
agent.agent_id = "test_agent_123"
# Make step method async and return proper structure
agent.step = AsyncMock()
agent.step.return_value.msgs = [MagicMock()]
agent.step.return_value.msgs[0].content = "Test agent response"
# Make step method return proper structure with both .msg and .msgs[0]
mock_response = MagicMock()
mock_message = MagicMock()
mock_message.content = "Test agent response"
mock_message.parsed = None
mock_response.msg = mock_message
mock_response.msgs = [mock_message] # msgs[0] should point to the same content
mock_response.info = {"usage": {"total_tokens": 50}}
agent.step.return_value = mock_response
agent.astep = AsyncMock()
agent.astep.return_value.msg.content = "Test async agent response"
@ -288,6 +295,7 @@ def sample_chat_data():
"""Sample chat data for testing."""
return {
"task_id": "test_task_123",
"project_id": "test_project_456",
"email": "test@example.com",
"question": "Create a simple Python script",
"attaches": [],

View file

@ -1,5 +1,8 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import os
import tempfile
from pathlib import Path
from app.service.chat_service import (
step_solve,
@ -12,14 +15,335 @@ from app.service.chat_service import (
summary_task,
construct_workforce,
format_agent_description,
new_agent_model
new_agent_model,
collect_previous_task_context,
build_context_for_workforce
)
from app.model.chat import Chat, NewAgent
from app.service.task import Action, ActionImproveData, ActionEndData, ActionInstallMcpData
from app.service.task import Action, ActionImproveData, ActionEndData, ActionInstallMcpData, TaskLock
from camel.tasks import Task
from camel.tasks.task import TaskState
@pytest.mark.unit
class TestCollectPreviousTaskContext:
"""Test cases for collect_previous_task_context function."""
def test_collect_previous_task_context_basic(self, temp_dir):
"""Test collect_previous_task_context with basic inputs."""
working_directory = str(temp_dir)
previous_task_content = "Create a Python script"
previous_task_result = "Successfully created script.py"
previous_summary = "Python Script Creation Task"
result = collect_previous_task_context(
working_directory=working_directory,
previous_task_content=previous_task_content,
previous_task_result=previous_task_result,
previous_summary=previous_summary
)
# Check that all sections are included
assert "=== CONTEXT FROM PREVIOUS TASK ===" in result
assert "Previous Task:" in result
assert "Create a Python script" in result
assert "Previous Task Summary:" in result
assert "Python Script Creation Task" in result
assert "Previous Task Result:" in result
assert "Successfully created script.py" in result
assert "=== END OF PREVIOUS TASK CONTEXT ===" in result
assert "=== NEW TASK ===" in result
def test_collect_previous_task_context_with_generated_files(self, temp_dir):
"""Test collect_previous_task_context with generated files in working directory."""
working_directory = str(temp_dir)
# Create some test files
(temp_dir / "script.py").write_text("print('Hello World')")
(temp_dir / "config.json").write_text('{"test": true}')
(temp_dir / "README.md").write_text("# Test Project")
# Create a subdirectory with files
sub_dir = temp_dir / "utils"
sub_dir.mkdir()
(sub_dir / "helper.py").write_text("def helper(): pass")
result = collect_previous_task_context(
working_directory=working_directory,
previous_task_content="Create project files",
previous_task_result="Files created successfully",
previous_summary=""
)
# Check that generated files are listed
assert "Generated Files from Previous Task:" in result
assert "script.py" in result
assert "config.json" in result
assert "README.md" in result
assert "utils/helper.py" in result or "utils\\helper.py" in result # Handle Windows paths
# Files should be sorted
lines = result.split('\n')
file_lines = [line.strip() for line in lines if line.strip().startswith('- ')]
assert len(file_lines) == 4
def test_collect_previous_task_context_filters_hidden_files(self, temp_dir):
"""Test that hidden files and directories are filtered out."""
working_directory = str(temp_dir)
# Create regular files
(temp_dir / "visible.py").write_text("# Visible file")
# Create hidden files and directories
(temp_dir / ".hidden_file").write_text("hidden content")
(temp_dir / ".env").write_text("SECRET=hidden")
hidden_dir = temp_dir / ".hidden_dir"
hidden_dir.mkdir()
(hidden_dir / "file.txt").write_text("in hidden dir")
# Create cache directories
cache_dir = temp_dir / "__pycache__"
cache_dir.mkdir()
(cache_dir / "module.pyc").write_text("compiled")
node_modules = temp_dir / "node_modules"
node_modules.mkdir()
(node_modules / "package").mkdir()
result = collect_previous_task_context(
working_directory=working_directory,
previous_task_content="Test filtering",
previous_task_result="Files filtered",
previous_summary=""
)
# Should only include visible files
assert "visible.py" in result
assert ".hidden_file" not in result
assert ".env" not in result
assert "__pycache__" not in result
assert "node_modules" not in result
assert ".hidden_dir" not in result
def test_collect_previous_task_context_filters_temp_files(self, temp_dir):
"""Test that temporary files are filtered out."""
working_directory = str(temp_dir)
# Create regular files
(temp_dir / "main.py").write_text("# Main file")
# Create temporary files
(temp_dir / "temp.tmp").write_text("temporary")
(temp_dir / "compiled.pyc").write_text("compiled python")
result = collect_previous_task_context(
working_directory=working_directory,
previous_task_content="Test temp filtering",
previous_task_result="Temp files filtered",
previous_summary=""
)
# Should only include regular files
assert "main.py" in result
assert "temp.tmp" not in result
assert "compiled.pyc" not in result
def test_collect_previous_task_context_nonexistent_directory(self):
"""Test collect_previous_task_context with non-existent working directory."""
working_directory = "/nonexistent/directory"
result = collect_previous_task_context(
working_directory=working_directory,
previous_task_content="Test task",
previous_task_result="Test result",
previous_summary="Test summary"
)
# Should not crash and should not include file listing
assert "=== CONTEXT FROM PREVIOUS TASK ===" in result
assert "Test task" in result
assert "Test result" in result
assert "Test summary" in result
assert "Generated Files from Previous Task:" not in result
def test_collect_previous_task_context_empty_inputs(self, temp_dir):
"""Test collect_previous_task_context with empty string inputs."""
working_directory = str(temp_dir)
result = collect_previous_task_context(
working_directory=working_directory,
previous_task_content="",
previous_task_result="",
previous_summary=""
)
# Should still have the structural elements
assert "=== CONTEXT FROM PREVIOUS TASK ===" in result
assert "=== END OF PREVIOUS TASK CONTEXT ===" in result
assert "=== NEW TASK ===" in result
# Should not have content sections for empty inputs
assert "Previous Task:" not in result
assert "Previous Task Summary:" not in result
assert "Previous Task Result:" not in result
def test_collect_previous_task_context_only_summary(self, temp_dir):
"""Test collect_previous_task_context with only summary provided."""
working_directory = str(temp_dir)
result = collect_previous_task_context(
working_directory=working_directory,
previous_task_content="",
previous_task_result="",
previous_summary="Only summary provided"
)
# Should include summary section only
assert "Previous Task Summary:" in result
assert "Only summary provided" in result
assert "Previous Task:" not in result
assert "Previous Task Result:" not in result
@patch('app.service.chat_service.logger')
def test_collect_previous_task_context_file_system_error(self, mock_logger, temp_dir):
"""Test collect_previous_task_context handles file system errors gracefully."""
working_directory = str(temp_dir)
# Mock os.walk to raise an exception
with patch('os.walk', side_effect=PermissionError("Access denied")):
result = collect_previous_task_context(
working_directory=working_directory,
previous_task_content="Test task",
previous_task_result="Test result",
previous_summary="Test summary"
)
# Should still return result without files
assert "=== CONTEXT FROM PREVIOUS TASK ===" in result
assert "Test task" in result
assert "Generated Files from Previous Task:" not in result
# Should log warning
mock_logger.warning.assert_called_once()
def test_collect_previous_task_context_relative_paths(self, temp_dir):
"""Test that file paths are correctly converted to relative paths."""
working_directory = str(temp_dir)
# Create nested directory structure
deep_dir = temp_dir / "level1" / "level2" / "level3"
deep_dir.mkdir(parents=True)
(deep_dir / "deep_file.txt").write_text("deep content")
result = collect_previous_task_context(
working_directory=working_directory,
previous_task_content="Test relative paths",
previous_task_result="Paths converted",
previous_summary=""
)
# Check that the path is relative to working directory
expected_path = "level1/level2/level3/deep_file.txt"
windows_path = "level1\\level2\\level3\\deep_file.txt"
# Should contain relative path (handle both Unix and Windows separators)
assert expected_path in result or windows_path in result
@pytest.mark.unit
class TestBuildContextForWorkforce:
"""Test cases for build_context_for_workforce function."""
def test_build_context_for_workforce_basic(self, temp_dir):
"""Test build_context_for_workforce with basic task lock and options."""
# Create mock TaskLock
task_lock = MagicMock(spec=TaskLock)
task_lock.conversation_history = [
{'role': 'user', 'content': 'Create a Python script'},
{'role': 'assistant', 'content': 'I will create a Python script for you'}
]
task_lock.last_task_result = "Script created successfully"
task_lock.last_task_summary = "Python Script Creation"
# Create mock Chat options
options = MagicMock()
options.file_save_path.return_value = str(temp_dir)
result = build_context_for_workforce(task_lock, options)
# Should include conversation history
assert "=== CONVERSATION HISTORY ===" in result
assert "user: Create a Python script" in result
assert "assistant: I will create a Python script for you" in result
# Should include previous task context
assert "=== CONTEXT FROM PREVIOUS TASK ===" in result
assert "Script created successfully" in result
def test_build_context_for_workforce_empty_history(self, temp_dir):
"""Test build_context_for_workforce with empty conversation history."""
task_lock = MagicMock(spec=TaskLock)
task_lock.conversation_history = []
task_lock.last_task_result = ""
task_lock.last_task_summary = ""
options = MagicMock()
options.file_save_path.return_value = str(temp_dir)
result = build_context_for_workforce(task_lock, options)
# Should return empty string for no context
assert result == ""
def test_build_context_for_workforce_task_result_role(self, temp_dir):
"""Test build_context_for_workforce handles 'task_result' role specially."""
task_lock = MagicMock(spec=TaskLock)
task_lock.conversation_history = [
{'role': 'user', 'content': 'First question'},
{'role': 'task_result', 'content': 'Full task context from previous task'},
{'role': 'user', 'content': 'Second question'}
]
task_lock.last_task_result = "Final result"
task_lock.last_task_summary = "Task summary"
options = MagicMock()
options.file_save_path.return_value = str(temp_dir)
result = build_context_for_workforce(task_lock, options)
# Should simplify task_result display
assert "[Previous Task Completed]" in result
assert "Full task context from previous task" not in result # Should not show full content
assert "user: First question" in result
assert "user: Second question" in result
def test_build_context_for_workforce_with_last_task_result(self, temp_dir):
"""Test build_context_for_workforce includes last task result context."""
# Create some files in temp directory
(temp_dir / "output.txt").write_text("Task output")
task_lock = MagicMock(spec=TaskLock)
task_lock.conversation_history = [
{'role': 'user', 'content': 'Test question'}
]
task_lock.last_task_result = "Task completed with output.txt"
task_lock.last_task_summary = "File creation task"
options = MagicMock()
options.file_save_path.return_value = str(temp_dir)
result = build_context_for_workforce(task_lock, options)
# Should include conversation history and task context
assert "=== CONVERSATION HISTORY ===" in result
assert "user: Test question" in result
assert "=== CONTEXT FROM PREVIOUS TASK ===" in result
assert "Task completed with output.txt" in result
assert "File creation task" in result
assert "output.txt" in result # Generated file should be listed
@pytest.mark.unit
class TestChatServiceUtilities:
"""Test cases for chat service utility functions."""
@ -324,6 +648,130 @@ class TestChatServiceIntegration:
"""Integration tests for chat service."""
@pytest.mark.asyncio
async def test_step_solve_context_building_workflow(self, sample_chat_data, mock_request, temp_dir):
"""Test step_solve builds context correctly using collect_previous_task_context."""
options = Chat(**sample_chat_data)
# Create actual TaskLock with context data
task_lock = TaskLock(
id="test_task_123",
queue=AsyncMock(),
human_input={}
)
task_lock.conversation_history = [
{'role': 'user', 'content': 'Create a Python script'},
{'role': 'assistant', 'content': 'Script created successfully'}
]
task_lock.last_task_result = "def hello(): print('Hello World')"
task_lock.last_task_summary = "Python Hello World Script"
# Create some files in working directory
working_dir = temp_dir / "test_project"
working_dir.mkdir()
(working_dir / "script.py").write_text("def hello(): print('Hello World')")
# Mock file_save_path method to return our temp directory
with patch.object(Chat, 'file_save_path', return_value=str(working_dir)):
# Test the context building directly
context = build_context_for_workforce(task_lock, options)
# Verify context includes conversation history
assert "=== CONVERSATION HISTORY ===" in context
assert "user: Create a Python script" in context
assert "assistant: Script created successfully" in context
# Verify context includes task context with files
assert "=== CONTEXT FROM PREVIOUS TASK ===" in context
assert "def hello(): print('Hello World')" in context
assert "Python Hello World Script" in context
assert "script.py" in context
@pytest.mark.asyncio
async def test_step_solve_new_task_state_context_collection(self, sample_chat_data, mock_request, temp_dir):
"""Test step_solve correctly collects context in new_task_state action."""
options = Chat(**sample_chat_data)
working_dir = temp_dir / "project"
working_dir.mkdir()
# Create files that should be included in context
(working_dir / "main.py").write_text("print('main')")
(working_dir / "config.json").write_text('{"version": "1.0"}')
# Mock file_save_path to return our temp directory
with patch.object(Chat, 'file_save_path', return_value=str(working_dir)):
# Test collect_previous_task_context directly with the scenario
result = collect_previous_task_context(
working_directory=str(working_dir),
previous_task_content="Create project structure",
previous_task_result="Project files created successfully",
previous_summary="Project Setup Task"
)
# Verify all expected elements are present
assert "=== CONTEXT FROM PREVIOUS TASK ===" in result
assert "Previous Task:" in result
assert "Create project structure" in result
assert "Previous Task Summary:" in result
assert "Project Setup Task" in result
assert "Previous Task Result:" in result
assert "Project files created successfully" in result
assert "Generated Files from Previous Task:" in result
assert "main.py" in result
assert "config.json" in result
assert "=== END OF PREVIOUS TASK CONTEXT ===" in result
assert "=== NEW TASK ===" in result
@pytest.mark.asyncio
async def test_step_solve_end_action_context_collection(self, sample_chat_data, mock_request, temp_dir):
"""Test step_solve correctly collects and saves context in end action."""
options = Chat(**sample_chat_data)
working_dir = temp_dir / "finished_project"
working_dir.mkdir()
# Create output files
(working_dir / "output.txt").write_text("Final output")
(working_dir / "report.md").write_text("# Task Report")
# Create actual TaskLock
task_lock = TaskLock(
id="test_end_task",
queue=AsyncMock(),
human_input={}
)
task_lock.last_task_summary = "Final Task Summary"
# Mock file_save_path
with patch.object(Chat, 'file_save_path', return_value=str(working_dir)):
# Test the context collection for end action scenario
task_content = "Generate final report"
task_result = "Report generated successfully with output files"
context = collect_previous_task_context(
working_directory=str(working_dir),
previous_task_content=task_content,
previous_task_result=task_result,
previous_summary=task_lock.last_task_summary
)
# Verify context structure for end action
assert "=== CONTEXT FROM PREVIOUS TASK ===" in context
assert "Generate final report" in context
assert "Report generated successfully with output files" in context
assert "Final Task Summary" in context
assert "output.txt" in context
assert "report.md" in context
# Test that context can be added to conversation history
task_lock.add_conversation('task_result', context)
assert len(task_lock.conversation_history) == 1
assert task_lock.conversation_history[0]['role'] == 'task_result'
assert task_lock.conversation_history[0]['content'] == context
@pytest.mark.asyncio
@pytest.mark.skip(reason="Gets Stuck for some reason.")
async def test_step_solve_basic_workflow(self, sample_chat_data, mock_request, mock_task_lock):
"""Test step_solve basic workflow integration."""
options = Chat(**sample_chat_data)
@ -380,6 +828,7 @@ class TestChatServiceIntegration:
# Note: Workforce might not be created/stopped if request is immediately disconnected
@pytest.mark.asyncio
@pytest.mark.skip(reason="Gets Stuck for some reason.")
async def test_step_solve_error_handling(self, sample_chat_data, mock_request, mock_task_lock):
"""Test step_solve handles errors gracefully."""
options = Chat(**sample_chat_data)
@ -423,7 +872,195 @@ class TestChatServiceWithLLM:
@pytest.mark.unit
class TestChatServiceErrorCases:
"""Test error cases and edge conditions for chat service."""
def test_collect_previous_task_context_os_walk_exception(self, temp_dir):
"""Test collect_previous_task_context handles os.walk exceptions."""
working_directory = str(temp_dir)
with patch('os.walk', side_effect=OSError("Permission denied")):
with patch('app.service.chat_service.logger') as mock_logger:
result = collect_previous_task_context(
working_directory=working_directory,
previous_task_content="Test task",
previous_task_result="Test result",
previous_summary="Test summary"
)
# Should still include basic context
assert "=== CONTEXT FROM PREVIOUS TASK ===" in result
assert "Test task" in result
assert "Test result" in result
assert "Test summary" in result
# Should not include file listing
assert "Generated Files from Previous Task:" not in result
# Should log warning
mock_logger.warning.assert_called_once()
def test_collect_previous_task_context_relpath_exception(self, temp_dir):
"""Test collect_previous_task_context handles os.path.relpath exceptions."""
working_directory = str(temp_dir)
# Create a test file
(temp_dir / "test.txt").write_text("test content")
with patch('os.path.relpath', side_effect=ValueError("Invalid path")):
with patch('app.service.chat_service.logger') as mock_logger:
result = collect_previous_task_context(
working_directory=working_directory,
previous_task_content="Test task",
previous_task_result="Test result",
previous_summary="Test summary"
)
# Should handle the exception gracefully
assert "=== CONTEXT FROM PREVIOUS TASK ===" in result
# Should log warning about file collection failure
mock_logger.warning.assert_called_once()
def test_build_context_for_workforce_missing_attributes(self, temp_dir):
"""Test build_context_for_workforce handles missing attributes gracefully."""
# Create task_lock without required attributes
task_lock = MagicMock(spec=TaskLock)
task_lock.conversation_history = None # Missing attribute
task_lock.last_task_result = None # Missing attribute
task_lock.last_task_summary = None # Missing attribute
options = MagicMock()
options.file_save_path.return_value = str(temp_dir)
result = build_context_for_workforce(task_lock, options)
# Should handle missing attributes gracefully
assert result == ""
def test_build_context_for_workforce_file_save_path_exception(self):
"""Test build_context_for_workforce handles file_save_path exceptions."""
task_lock = MagicMock(spec=TaskLock)
task_lock.conversation_history = []
task_lock.last_task_result = "Test result"
task_lock.last_task_summary = "Test summary"
options = MagicMock()
options.file_save_path.side_effect = Exception("Path error")
with patch('app.service.chat_service.logger') as mock_logger:
# Should handle exception when getting file path
with pytest.raises(Exception, match="Path error"):
build_context_for_workforce(task_lock, options)
def test_collect_previous_task_context_unicode_handling(self, temp_dir):
"""Test collect_previous_task_context handles unicode content correctly."""
working_directory = str(temp_dir)
# Create files with unicode content
(temp_dir / "unicode_file.txt").write_text("Unicode content: 🐍 Python ñáéíóú", encoding='utf-8')
unicode_task_content = "Create files with unicode: 🔥 emojis and ñáéíóú accents"
unicode_result = "Files created successfully with unicode: ✅ done"
unicode_summary = "Unicode Task: 📝 file creation"
result = collect_previous_task_context(
working_directory=working_directory,
previous_task_content=unicode_task_content,
previous_task_result=unicode_result,
previous_summary=unicode_summary
)
# Should handle unicode correctly
assert "🔥 emojis" in result
assert "ñáéíóú accents" in result
assert "✅ done" in result
assert "📝 file creation" in result
assert "unicode_file.txt" in result
def test_collect_previous_task_context_very_long_content(self, temp_dir):
"""Test collect_previous_task_context handles very long content."""
working_directory = str(temp_dir)
# Create very long content strings
long_content = "Very long task content. " * 1000 # ~25KB
long_result = "Very long task result. " * 1000 # ~23KB
long_summary = "Very long summary. " * 100 # ~1.8KB
result = collect_previous_task_context(
working_directory=working_directory,
previous_task_content=long_content,
previous_task_result=long_result,
previous_summary=long_summary
)
# Should handle long content without issues
assert len(result) > 49000 # Should be quite long
assert "Very long task content." in result
assert "Very long task result." in result
assert "Very long summary." in result
def test_collect_previous_task_context_many_files(self, temp_dir):
"""Test collect_previous_task_context performance with many files."""
working_directory = str(temp_dir)
# Create many files to test performance
for i in range(100):
(temp_dir / f"file_{i:03d}.txt").write_text(f"Content {i}")
# Create subdirectories with files
for dir_i in range(10):
sub_dir = temp_dir / f"subdir_{dir_i}"
sub_dir.mkdir()
for file_i in range(10):
(sub_dir / f"subfile_{file_i}.txt").write_text(f"Sub content {dir_i}-{file_i}")
import time
start_time = time.time()
result = collect_previous_task_context(
working_directory=working_directory,
previous_task_content="Test many files",
previous_task_result="Many files processed",
previous_summary="Performance test"
)
end_time = time.time()
execution_time = end_time - start_time
# Should complete in reasonable time (less than 1 second for 200 files)
assert execution_time < 1.0
# Should list all files
assert "Generated Files from Previous Task:" in result
# Count number of file entries
file_lines = [line for line in result.split('\n') if ' - ' in line]
assert len(file_lines) == 200 # 100 main files + 100 subfiles
def test_collect_previous_task_context_special_characters_in_filenames(self, temp_dir):
"""Test collect_previous_task_context handles special characters in filenames."""
working_directory = str(temp_dir)
# Create files with special characters (that are valid in filenames)
try:
(temp_dir / "file with spaces.txt").write_text("content")
(temp_dir / "file-with-dashes.txt").write_text("content")
(temp_dir / "file_with_underscores.txt").write_text("content")
(temp_dir / "file.with.dots.txt").write_text("content")
except OSError:
# Skip if filesystem doesn't support these characters
pytest.skip("Filesystem doesn't support special characters in filenames")
result = collect_previous_task_context(
working_directory=working_directory,
previous_task_content="Test special chars",
previous_task_result="Files created",
previous_summary=""
)
# Should list files with special characters
assert "file with spaces.txt" in result
assert "file-with-dashes.txt" in result
assert "file_with_underscores.txt" in result
assert "file.with.dots.txt" in result
@pytest.mark.asyncio
async def test_question_confirm_agent_error(self, mock_camel_agent):
"""Test question_confirm when agent raises error."""

View file

@ -0,0 +1,100 @@
import asyncio
import threading
import time
import pytest
from app.service.task import task_locks, TaskLock
from app.utils.toolkit.terminal_toolkit import TerminalToolkit
@pytest.mark.unit
class TestTerminalToolkit:
"""Test to verify the RuntimeError: no running event loop."""
def test_no_runtime_error_in_sync_context(self):
"""Test no running event loop."""
test_api_task_id = "test_api_task_123"
if test_api_task_id not in task_locks:
task_locks[test_api_task_id] = TaskLock(id=test_api_task_id, queue=asyncio.Queue(), human_input={})
toolkit = TerminalToolkit("test_api_task_123")
# This should NOT raise RuntimeError: no running event loop
# This simulates the exact scenario from the error traceback
try:
toolkit._write_to_log("/tmp/test.log", "Test output")
time.sleep(0.1) # Give thread time to complete
except RuntimeError as e:
if "no running event loop" in str(e):
pytest.fail("RuntimeError: no running event loop should not be raised - the fix is not working!")
else:
raise # Re-raise if it's a different RuntimeError
def test_multiple_calls_no_runtime_error(self):
"""Test that multiple calls don't raise RuntimeError."""
test_api_task_id = "test_api_task_123"
if test_api_task_id not in task_locks:
task_locks[test_api_task_id] = TaskLock(id=test_api_task_id, queue=asyncio.Queue(), human_input={})
toolkit = TerminalToolkit("test_api_task_123")
# Make multiple calls - none should raise RuntimeError
try:
for i in range(5):
toolkit._write_to_log(f"/tmp/test_{i}.log", f"Output {i}")
time.sleep(0.2) # Give threads time to complete
except RuntimeError as e:
if "no running event loop" in str(e):
pytest.fail("RuntimeError: no running event loop should not be raised!")
else:
raise
def test_thread_safety_no_runtime_error(self):
"""Test thread safety without RuntimeError."""
test_api_task_id = "test_api_task_123"
if test_api_task_id not in task_locks:
task_locks[test_api_task_id] = TaskLock(id=test_api_task_id, queue=asyncio.Queue(), human_input={})
toolkit = TerminalToolkit("test_api_task_123")
# Create multiple threads that call _write_to_log
threads = []
for i in range(5):
thread = threading.Thread(
target=toolkit._write_to_log,
args=(f"/tmp/test_{i}.log", f"Thread {i} output")
)
threads.append(thread)
thread.start()
# Wait for all threads to complete
for thread in threads:
thread.join()
time.sleep(0.2) # Give async operations time to complete
# Should not have raised any RuntimeError
def test_async_context_still_works(self):
"""Test that async context still works without RuntimeError."""
test_api_task_id = "test_api_task_123"
if test_api_task_id not in task_locks:
task_locks[test_api_task_id] = TaskLock(id=test_api_task_id, queue=asyncio.Queue(), human_input={})
toolkit = TerminalToolkit("test_api_task_123")
async def test_async_context():
toolkit._write_to_log("/tmp/async_test.log", "Async context test")
await asyncio.sleep(0.1)
# Should work in async context without RuntimeError
try:
asyncio.run(test_async_context())
except RuntimeError as e:
if "no running event loop" in str(e):
pytest.fail("RuntimeError: no running event loop should not be raised in async context!")
else:
raise

View file

@ -370,13 +370,13 @@ class TestWorkforce:
)
with patch('app.service.task.delete_task_lock', side_effect=Exception("Delete failed")), \
patch('loguru.logger.error') as mock_log_error:
patch('traceroot.get_logger') as mock_get_logger:
# Should not raise exception
await workforce.cleanup()
# Should log the error
mock_log_error.assert_called_once()
mock_get_logger.assert_called_once()
@pytest.mark.integration
@ -623,13 +623,13 @@ class TestWorkforceErrorCases:
)
with patch('app.service.task.delete_task_lock', side_effect=Exception("Task lock not found")), \
patch('loguru.logger.error') as mock_log_error:
patch('traceroot.get_logger') as mock_get_logger:
# Should handle missing task lock gracefully
await workforce.cleanup()
# Should log the error
mock_log_error.assert_called_once()
mock_get_logger.assert_called_once()
def test_workforce_inheritance(self):
"""Test that Workforce properly inherits from BaseWorkforce."""

1661
backend/uv.lock generated

File diff suppressed because it is too large Load diff

View file

@ -10,6 +10,7 @@
"cssVariables": true,
"prefix": ""
},
"iconLibrary": "lucide",
"aliases": {
"components": "@/components",
"utils": "@/lib/utils",
@ -17,5 +18,7 @@
"lib": "@/lib",
"hooks": "@/hooks"
},
"iconLibrary": "lucide"
}
"registries": {
"@animate-ui": "https://animate-ui.com/r/{name}.json"
}
}

View file

@ -0,0 +1,15 @@
{
"profiles": {
"userLogin": {
"name": "profile_user_login",
"partition": "user_login",
"description": "Profile for user login browser"
},
"project": {
"nameTemplate": "profile_{port}",
"partitionTemplate": "project_{port}",
"description": "Profile for project browser instances"
}
},
"basePath": "~/.eigent/browser_profiles"
}

View file

@ -12,6 +12,10 @@
"from": "backend",
"to": "backend",
"filter": ["**/*", "!.venv/**/*"]
},
{
"from": "utils",
"to": "utils"
}
],
"protocols": [

View file

@ -9,6 +9,15 @@ import https from 'https'
import http from 'http'
import { URL } from 'url'
interface FileInfo {
path: string;
name: string;
type: string;
isFolder: boolean;
relativePath: string;
task_id?: string;
project_id?: string;
}
export class FileReader {
private win: BrowserWindow | null = null
@ -541,12 +550,54 @@ export class FileReader {
}
}
public getFileList(email: string, taskId: string): FileInfo[] {
private findTaskInProjects(userDir: string, taskId: string): string | null {
try {
if (!fs.existsSync(userDir)) {
return null;
}
const entries = fs.readdirSync(userDir);
// Look for project directories
for (const entry of entries) {
if (entry.startsWith('project_')) {
const projectDir = path.join(userDir, entry);
const taskDir = path.join(projectDir, `task_${taskId}`);
if (fs.existsSync(taskDir)) {
return taskDir;
}
}
}
return null;
} catch (err) {
console.error("Error finding task in projects:", err);
return null;
}
}
public getFileList(email: string, taskId: string, projectId?: string): FileInfo[] {
const safeEmail = email.split('@')[0].replace(/[\\/*?:"<>|\s]/g, "_").replace(/^\.+|\.+$/g, "");
const userHome = app.getPath('home');
const dirPath = path.join(userHome, "eigent", safeEmail, `task_${taskId}`);
let dirPath: string;
// Check if projectId is provided for new project-based structure
if (projectId) {
dirPath = path.join(userHome, "eigent", safeEmail, `project_${projectId}`, `task_${taskId}`);
} else {
// First try project-based structure (scan for existing projects)
const userDir = path.join(userHome, "eigent", safeEmail);
const projectBasedPath = this.findTaskInProjects(userDir, taskId);
if (projectBasedPath) {
dirPath = projectBasedPath;
} else {
// Fallback to legacy direct task structure
dirPath = path.join(userHome, "eigent", safeEmail, `task_${taskId}`);
}
}
try {
if (!fs.existsSync(dirPath)) {
@ -560,21 +611,53 @@ export class FileReader {
}
}
public deleteTaskFiles(email: string, taskId: string): {
public deleteTaskFiles(email: string, taskId: string, projectId?: string): {
success: boolean;
path: { dirPath: string; logPath: string }
}
{
const safeEmail = email.split('@')[0].replace(/[\\/*?:"<>|\s]/g, "_").replace(/^\.+|\.+$/g, "");
const userHome = app.getPath('home');
const dirPath = path.join(userHome, "eigent", safeEmail, `task_${taskId}`);
const logPath = path.join(userHome, ".eigent", safeEmail, `task_${taskId}`);
try {
if (fs.existsSync(dirPath)&&fs.existsSync(logPath)) {
fs.rmSync(dirPath, { recursive: true, force: true });
fs.rmSync(logPath, { recursive: true, force: true });
let dirPath: string;
let logPath: string;
// Check if projectId is provided for new project-based structure
if (projectId) {
dirPath = path.join(userHome, "eigent", safeEmail, `project_${projectId}`, `task_${taskId}`);
logPath = path.join(userHome, ".eigent", safeEmail, `project_${projectId}`, `task_${taskId}`);
} else {
// First try project-based structure
const userDir = path.join(userHome, "eigent", safeEmail);
const projectBasedPath = this.findTaskInProjects(userDir, taskId);
if (projectBasedPath) {
dirPath = projectBasedPath;
// Extract project from path to construct log path
const projectMatch = projectBasedPath.match(/project_([^\\\/]+)/);
if (projectMatch) {
logPath = path.join(userHome, ".eigent", safeEmail, projectMatch[0], `task_${taskId}`);
} else {
logPath = path.join(userHome, ".eigent", safeEmail, `task_${taskId}`);
}
} else {
// Fallback to legacy direct task structure
dirPath = path.join(userHome, "eigent", safeEmail, `task_${taskId}`);
logPath = path.join(userHome, ".eigent", safeEmail, `task_${taskId}`);
}
return { success: true, path: { dirPath, logPath } };
}
try {
let success = false;
if (fs.existsSync(dirPath)) {
fs.rmSync(dirPath, { recursive: true, force: true });
success = true;
}
if (fs.existsSync(logPath)) {
fs.rmSync(logPath, { recursive: true, force: true });
success = true;
}
return { success, path: { dirPath, logPath } };
} catch (err) {
console.error("Delete task files failed:", dirPath, err);
return { success: false, path: { dirPath, logPath } };
@ -582,9 +665,7 @@ export class FileReader {
}
public getLogFolder(email: string): string {
const safeEmail = email.split('@')[0].replace(/[\\/*?:"<>|\s]/g, "_").replace(/^\.+|\.+$/g, "");
const userHome = app.getPath('home');
const dirPath = path.join(userHome, "eigent", safeEmail);
@ -599,5 +680,205 @@ export class FileReader {
return '';
}
}
public createProjectStructure(email: string, projectId: string): { success: boolean; path: string } {
const safeEmail = email.split('@')[0].replace(/[\\/*?:"<>|\s]/g, "_").replace(/^\.+|\.+$/g, "");
const userHome = app.getPath('home');
const projectPath = path.join(userHome, "eigent", safeEmail, `project_${projectId}`);
try {
if (!fs.existsSync(projectPath)) {
fs.mkdirSync(projectPath, { recursive: true });
}
return { success: true, path: projectPath };
} catch (err) {
console.error("Create project structure failed:", err);
return { success: false, path: projectPath };
}
}
public getProjectList(email: string): Array<{ id: string; name: string; path: string; taskCount: number; createdAt: Date }> {
const safeEmail = email.split('@')[0].replace(/[\\/*?:"<>|\s]/g, "_").replace(/^\.+|\.+$/g, "");
const userHome = app.getPath('home');
const userDir = path.join(userHome, "eigent", safeEmail);
try {
if (!fs.existsSync(userDir)) {
return [];
}
const entries = fs.readdirSync(userDir);
const projects: Array<{ id: string; name: string; path: string; taskCount: number; createdAt: Date }> = [];
for (const entry of entries) {
if (entry.startsWith('project_')) {
const projectPath = path.join(userDir, entry);
const stats = fs.statSync(projectPath);
if (stats.isDirectory()) {
const projectId = entry.replace('project_', '');
// Count tasks in this project
const taskCount = this.countTasksInProject(projectPath);
projects.push({
id: projectId,
name: `Project ${projectId}`,
path: projectPath,
taskCount,
createdAt: stats.birthtime
});
}
}
}
return projects.sort((a, b) => b.createdAt.getTime() - a.createdAt.getTime());
} catch (err) {
console.error("Get project list failed:", err);
return [];
}
}
public getTasksInProject(email: string, projectId: string): Array<{ id: string; name: string; path: string; createdAt: Date }> {
const safeEmail = email.split('@')[0].replace(/[\\/*?:"<>|\s]/g, "_").replace(/^\.+|\.+$/g, "");
const userHome = app.getPath('home');
const projectPath = path.join(userHome, "eigent", safeEmail, `project_${projectId}`);
try {
if (!fs.existsSync(projectPath)) {
return [];
}
const entries = fs.readdirSync(projectPath);
const tasks: Array<{ id: string; name: string; path: string; createdAt: Date }> = [];
for (const entry of entries) {
if (entry.startsWith('task_')) {
const taskPath = path.join(projectPath, entry);
const stats = fs.statSync(taskPath);
if (stats.isDirectory()) {
const taskId = entry.replace('task_', '');
tasks.push({
id: taskId,
name: `Task ${taskId}`,
path: taskPath,
createdAt: stats.birthtime
});
}
}
}
return tasks.sort((a, b) => b.createdAt.getTime() - a.createdAt.getTime());
} catch (err) {
console.error("Get tasks in project failed:", err);
return [];
}
}
public moveTaskToProject(email: string, taskId: string, projectId: string): { success: boolean; message: string } {
const safeEmail = email.split('@')[0].replace(/[\\/*?:"<>|\s]/g, "_").replace(/^\.+|\.+$/g, "");
const userHome = app.getPath('home');
// Source path (legacy structure)
const sourcePath = path.join(userHome, "eigent", safeEmail, `task_${taskId}`);
const sourceLogPath = path.join(userHome, ".eigent", safeEmail, `task_${taskId}`);
// Destination paths (project structure)
const projectPath = path.join(userHome, "eigent", safeEmail, `project_${projectId}`);
const destPath = path.join(projectPath, `task_${taskId}`);
const destLogPath = path.join(userHome, ".eigent", safeEmail, `project_${projectId}`, `task_${taskId}`);
try {
// Create project structure if it doesn't exist
if (!fs.existsSync(projectPath)) {
fs.mkdirSync(projectPath, { recursive: true });
}
// Create destination log directory
const destLogDir = path.dirname(destLogPath);
if (!fs.existsSync(destLogDir)) {
fs.mkdirSync(destLogDir, { recursive: true });
}
// Move task files
if (fs.existsSync(sourcePath)) {
fs.renameSync(sourcePath, destPath);
}
// Move log files
if (fs.existsSync(sourceLogPath)) {
fs.renameSync(sourceLogPath, destLogPath);
}
return { success: true, message: `Task ${taskId} moved to project ${projectId}` };
} catch (err) {
console.error("Move task to project failed:", err);
return { success: false, message: `Failed to move task: ${err}` };
}
}
public getProjectFileList(email: string, projectId: string): FileInfo[] {
const safeEmail = email.split('@')[0].replace(/[\\/*?:"<>|\s]/g, "_").replace(/^\.+|\.+$/g, "");
const userHome = app.getPath('home');
const projectPath = path.join(userHome, "eigent", safeEmail, `project_${projectId}`);
try {
if (!fs.existsSync(projectPath)) {
return [];
}
const allFiles: FileInfo[] = [];
const taskDirs = fs.readdirSync(projectPath);
for (const taskDir of taskDirs) {
if (!taskDir.startsWith('task_')) continue;
const taskPath = path.join(projectPath, taskDir);
const stats = fs.statSync(taskPath);
if (stats.isDirectory()) {
const taskId = taskDir.replace('task_', '');
const taskFiles = this.getFilesRecursive(taskPath, taskPath);
const enrichedFiles = taskFiles.map(file => {
const fileDir = path.dirname(file.path);
const relativeParentPath = path.relative(projectPath, fileDir);
return {
...file,
task_id: taskId,
project_id: projectId,
relativePath: relativeParentPath === '.' ? '' : relativeParentPath
};
});
allFiles.push(...enrichedFiles);
}
}
return allFiles.sort((a, b) => {
// Sort by task_id first, then by file path
if (a.task_id !== b.task_id) {
return a.task_id!.localeCompare(b.task_id!);
}
return a.path.localeCompare(b.path);
});
} catch (err) {
console.error("Get project file list failed:", err);
return [];
}
}
private countTasksInProject(projectPath: string): number {
try {
const entries = fs.readdirSync(projectPath);
return entries.filter(entry => entry.startsWith('task_')).length;
} catch (err) {
console.error("Count tasks in project failed:", err);
return 0;
}
}
}

View file

@ -40,15 +40,45 @@ let python_process: ChildProcessWithoutNullStreams | null = null;
let backendPort: number = 5001;
let browser_port = 9222;
// Protocol URL queue for handling URLs before window is ready
let protocolUrlQueue: string[] = [];
let isWindowReady = false;
// ==================== path config ====================
const preload = path.join(__dirname, '../preload/index.mjs');
const indexHtml = path.join(RENDERER_DIST, 'index.html');
const logPath = log.transports.file.getFile().path;
// Profile initialization promise
let profileInitPromise: Promise<void>;
// Set remote debugging port
findAvailablePort(browser_port).then(port => {
// Storage strategy:
// 1. Main window: partition 'persist:main_window' in app userData → Eigent account (persistent)
// 2. WebView: partition 'persist:user_login' in app userData → will import cookies from tool_controller via session API
// 3. tool_controller: ~/.eigent/browser_profiles/profile_user_login → source of truth for login cookies
// 4. CDP browser: uses separate profile (doesn't share with main app)
profileInitPromise = findAvailablePort(browser_port).then(async port => {
browser_port = port;
app.commandLine.appendSwitch('remote-debugging-port', port + '');
// Create isolated profile for CDP browser only
const browserProfilesBase = path.join(os.homedir(), '.eigent', 'browser_profiles');
const cdpProfile = path.join(browserProfilesBase, `cdp_profile_${port}`);
try {
await fsp.mkdir(cdpProfile, { recursive: true });
log.info(`[CDP BROWSER] Created CDP profile directory at ${cdpProfile}`);
} catch (error) {
log.error(`[CDP BROWSER] Failed to create directory: ${error}`);
}
// Set user-data-dir for Chrome DevTools Protocol only
app.commandLine.appendSwitch('user-data-dir', cdpProfile);
log.info(`[CDP BROWSER] Chrome DevTools Protocol enabled on port ${port}`);
log.info(`[CDP BROWSER] CDP profile directory: ${cdpProfile}`);
log.info(`[STORAGE] Main app userData: ${app.getPath('userData')}`);
});
// Memory optimization settings
@ -97,6 +127,19 @@ const setupProtocolHandlers = () => {
// ==================== protocol url handle ====================
function handleProtocolUrl(url: string) {
log.info('enter handleProtocolUrl', url);
// If window is not ready, queue the URL
if (!isWindowReady || !win || win.isDestroyed()) {
log.info('Window not ready, queuing protocol URL:', url);
protocolUrlQueue.push(url);
return;
}
processProtocolUrl(url);
}
// Process a single protocol URL
function processProtocolUrl(url: string) {
const urlObj = new URL(url);
const code = urlObj.searchParams.get('code');
const share_token = urlObj.searchParams.get('share_token');
@ -130,6 +173,26 @@ function handleProtocolUrl(url: string) {
}
}
// Process all queued protocol URLs
function processQueuedProtocolUrls() {
if (protocolUrlQueue.length > 0) {
log.info('Processing queued protocol URLs:', protocolUrlQueue.length);
// Verify window is ready before processing
if (!win || win.isDestroyed() || !isWindowReady) {
log.warn('Window not ready for processing queued URLs, keeping URLs in queue');
return;
}
const urls = [...protocolUrlQueue];
protocolUrlQueue = [];
urls.forEach(url => {
processProtocolUrl(url);
});
}
}
// ==================== single instance lock ====================
const setupSingleInstanceLock = () => {
const gotLock = app.requestSingleInstanceLock();
@ -207,11 +270,26 @@ const checkManagerInstance = (manager: any, name: string) => {
function registerIpcHandlers() {
// ==================== basic info handler ====================
ipcMain.handle('get-browser-port', () => {
log.info('Starting new task')
log.info('Getting browser port')
return browser_port
});
ipcMain.handle('get-app-version', () => app.getVersion());
ipcMain.handle('get-backend-port', () => backendPort);
// ==================== restart app handler ====================
ipcMain.handle('restart-app', async () => {
log.info('[RESTART] Restarting app to apply user profile changes');
// Clean up Python process first
await cleanupPythonProcess();
// Schedule relaunch after a short delay
setTimeout(() => {
app.relaunch();
app.quit();
}, 100);
});
ipcMain.handle('restart-backend', async () => {
try {
if (backendPort) {
@ -609,6 +687,13 @@ function registerIpcHandlers() {
return { success: false, error: 'File does not exist' };
}
// Check if it's a directory
const stats = await fsp.stat(filePath);
if (stats.isDirectory()) {
log.error('Path is a directory, not a file:', filePath);
return { success: false, error: 'Path is a directory, not a file' };
}
// Read file content
const fileContent = await fsp.readFile(filePath);
log.info('File read successfully:', filePath);
@ -712,6 +797,24 @@ function registerIpcHandlers() {
let lines = content.split(/\r?\n/);
lines = updateEnvBlock(lines, { [key]: value });
fs.writeFileSync(ENV_PATH, lines.join('\n'), 'utf-8');
// Also write to global .env file for backend process to read
const GLOBAL_ENV_PATH = path.join(os.homedir(), '.eigent', '.env');
let globalContent = '';
try {
globalContent = fs.existsSync(GLOBAL_ENV_PATH) ? fs.readFileSync(GLOBAL_ENV_PATH, 'utf-8') : '';
} catch (error) {
log.error("global env-write read error:", error);
}
let globalLines = globalContent.split(/\r?\n/);
globalLines = updateEnvBlock(globalLines, { [key]: value });
try {
fs.writeFileSync(GLOBAL_ENV_PATH, globalLines.join('\n'), 'utf-8');
log.info(`env-write: wrote ${key} to both user and global .env files`);
} catch (error) {
log.error("global env-write error:", error);
}
return { success: true };
});
@ -728,6 +831,19 @@ function registerIpcHandlers() {
lines = removeEnvKey(lines, key);
fs.writeFileSync(ENV_PATH, lines.join('\n'), 'utf-8');
log.info("env-remove success", ENV_PATH);
// Also remove from global .env file
const GLOBAL_ENV_PATH = path.join(os.homedir(), '.eigent', '.env');
try {
let globalContent = fs.existsSync(GLOBAL_ENV_PATH) ? fs.readFileSync(GLOBAL_ENV_PATH, 'utf-8') : '';
let globalLines = globalContent.split(/\r?\n/);
globalLines = removeEnvKey(globalLines, key);
fs.writeFileSync(GLOBAL_ENV_PATH, globalLines.join('\n'), 'utf-8');
log.info(`env-remove: removed ${key} from both user and global .env files`);
} catch (error) {
log.error("global env-remove error:", error);
}
return { success: true };
});
@ -802,14 +918,40 @@ function registerIpcHandlers() {
}
});
ipcMain.handle('get-file-list', async (_, email: string, taskId: string) => {
ipcMain.handle('get-file-list', async (_, email: string, taskId: string, projectId?: string) => {
const manager = checkManagerInstance(fileReader, 'FileReader');
return manager.getFileList(email, taskId);
return manager.getFileList(email, taskId, projectId);
});
ipcMain.handle('delete-task-files', async (_, email: string, taskId: string) => {
ipcMain.handle('delete-task-files', async (_, email: string, taskId: string, projectId?: string) => {
const manager = checkManagerInstance(fileReader, 'FileReader');
return manager.deleteTaskFiles(email, taskId);
return manager.deleteTaskFiles(email, taskId, projectId);
});
// New project management handlers
ipcMain.handle('create-project-structure', async (_, email: string, projectId: string) => {
const manager = checkManagerInstance(fileReader, 'FileReader');
return manager.createProjectStructure(email, projectId);
});
ipcMain.handle('get-project-list', async (_, email: string) => {
const manager = checkManagerInstance(fileReader, 'FileReader');
return manager.getProjectList(email);
});
ipcMain.handle('get-tasks-in-project', async (_, email: string, projectId: string) => {
const manager = checkManagerInstance(fileReader, 'FileReader');
return manager.getTasksInProject(email, projectId);
});
ipcMain.handle('move-task-to-project', async (_, email: string, taskId: string, projectId: string) => {
const manager = checkManagerInstance(fileReader, 'FileReader');
return manager.moveTaskToProject(email, taskId, projectId);
});
ipcMain.handle('get-project-file-list', async (_, email: string, projectId: string) => {
const manager = checkManagerInstance(fileReader, 'FileReader');
return manager.getProjectFileList(email, projectId);
});
ipcMain.handle('get-log-folder', async (_, email: string) => {
@ -905,6 +1047,10 @@ async function createWindow() {
// Ensure .eigent directories exist before anything else
ensureEigentDirectories();
log.info(`[PROJECT BROWSER WINDOW] Creating BrowserWindow which will start Chrome with CDP on port ${browser_port}`);
log.info(`[PROJECT BROWSER WINDOW] Current user data path: ${app.getPath('userData')}`);
log.info(`[PROJECT BROWSER WINDOW] Command line switch user-data-dir: ${app.commandLine.getSwitchValue('user-data-dir')}`);
win = new BrowserWindow({
title: 'Eigent',
width: 1200,
@ -921,6 +1067,9 @@ async function createWindow() {
icon: path.join(VITE_PUBLIC, 'favicon.ico'),
roundedCorners: true,
webPreferences: {
// Use a dedicated partition for main window to isolate from webviews
// This ensures main window's auth data (localStorage) is stored separately and persists across restarts
partition: 'persist:main_window',
webSecurity: false,
preload,
nodeIntegration: true,
@ -930,14 +1079,58 @@ async function createWindow() {
},
});
// Main window now uses default userData directly with partition 'persist:main_window'
// No migration needed - data is already persistent
// ==================== Import cookies from tool_controller to WebView BEFORE creating WebViews ====================
// Copy partition data files before any session accesses them
try {
const browserProfilesBase = path.join(os.homedir(), '.eigent', 'browser_profiles');
const toolControllerProfile = path.join(browserProfilesBase, 'profile_user_login');
const toolControllerPartitionPath = path.join(toolControllerProfile, 'Partitions', 'user_login');
if (fs.existsSync(toolControllerPartitionPath)) {
log.info('[COOKIE SYNC] Found tool_controller partition, copying to WebView partition...');
const targetPartitionPath = path.join(app.getPath('userData'), 'Partitions', 'user_login');
log.info('[COOKIE SYNC] From:', toolControllerPartitionPath);
log.info('[COOKIE SYNC] To:', targetPartitionPath);
// Ensure target directory exists
if (!fs.existsSync(path.dirname(targetPartitionPath))) {
fs.mkdirSync(path.dirname(targetPartitionPath), { recursive: true });
}
// Copy the entire partition directory
fs.cpSync(toolControllerPartitionPath, targetPartitionPath, {
recursive: true,
force: true
});
log.info('[COOKIE SYNC] Successfully copied partition data to WebView');
// Verify cookies were copied
const targetCookies = path.join(targetPartitionPath, 'Cookies');
if (fs.existsSync(targetCookies)) {
const stats = fs.statSync(targetCookies);
log.info(`[COOKIE SYNC] Cookies file size: ${stats.size} bytes`);
}
} else {
log.info('[COOKIE SYNC] No tool_controller partition found, WebView will start fresh');
}
} catch (error) {
log.error('[COOKIE SYNC] Failed to sync partition data:', error);
}
// ==================== initialize manager ====================
fileReader = new FileReader(win);
webViewManager = new WebViewManager(win);
// create initial webviews (reduced from 8 to 3)
for (let i = 1; i <= 3; i++) {
// create multiple webviews
log.info(`[PROJECT BROWSER] Creating WebViews with partition: persist:user_login`);
for (let i = 1; i <= 8; i++) {
webViewManager.createWebview(i === 1 ? undefined : i.toString());
}
log.info('[PROJECT BROWSER] WebViewManager initialized with webviews');
// ==================== set event listeners ====================
setupWindowEventListeners();
@ -990,7 +1183,9 @@ async function createWindow() {
log.info('Installation needed - clearing auth storage to force carousel state');
// Clear the persisted auth storage file to force fresh initialization with carousel
const localStoragePath = path.join(app.getPath('userData'), 'Local Storage');
// Main window uses partition 'persist:main_window', so data is in Partitions/main_window
const partitionPath = path.join(app.getPath('userData'), 'Partitions', 'main_window');
const localStoragePath = path.join(partitionPath, 'Local Storage');
const leveldbPath = path.join(localStoragePath, 'leveldb');
try {
@ -1056,8 +1251,10 @@ async function createWindow() {
(function() {
try {
const authStorage = localStorage.getItem('auth-storage');
console.log('[ELECTRON DEBUG] Current auth-storage:', authStorage);
if (authStorage) {
const parsed = JSON.parse(authStorage);
console.log('[ELECTRON DEBUG] Parsed state:', parsed.state);
if (parsed.state && parsed.state.initState !== 'done') {
console.log('[ELECTRON] Updating initState from', parsed.state.initState, 'to done');
// Only update the initState field, preserve all other data
@ -1071,7 +1268,11 @@ async function createWindow() {
localStorage.setItem('auth-storage', JSON.stringify(updatedStorage));
console.log('[ELECTRON] initState updated to done, reloading page...');
return true; // Signal that we need to reload
} else {
console.log('[ELECTRON DEBUG] initState already done or state missing');
}
} else {
console.log('[ELECTRON DEBUG] No auth-storage found in localStorage');
}
return false; // No reload needed
} catch (e) {
@ -1107,6 +1308,11 @@ async function createWindow() {
});
});
// Mark window as ready and process any queued protocol URLs
isWindowReady = true;
log.info('Window is ready, processing queued protocol URLs...');
processQueuedProtocolUrls();
// Now check and install dependencies
let res:PromiseReturnType = await checkAndInstallDepsOnUpdate({ win });
if (!res.success) {
@ -1282,7 +1488,15 @@ const handleBeforeClose = () => {
}
// ==================== app event handle ====================
app.whenReady().then(() => {
app.whenReady().then(async () => {
// Wait for profile initialization to complete
log.info('[MAIN] Waiting for profile initialization...');
try {
await profileInitPromise;
log.info('[MAIN] Profile initialization completed');
} catch (error) {
log.error('[MAIN] Profile initialization failed:', error);
}
// ==================== download handle ====================
session.defaultSession.on('will-download', (event, item, webContents) => {
@ -1339,7 +1553,10 @@ app.on('window-all-closed', () => {
webViewManager = null;
}
// Reset window state
win = null;
isWindowReady = false;
protocolUrlQueue = [];
if (process.platform !== 'darwin') {
app.quit();
@ -1363,35 +1580,42 @@ app.on('activate', () => {
app.on('before-quit', async (event) => {
log.info('before-quit');
log.info('quit python_process.pid: ' + python_process?.pid);
// Prevent default quit to ensure cleanup completes
event.preventDefault();
try {
// NOTE: Profile sync removed - we now use app userData directly for all partitions
// No need to sync between different profile directories
// Clean up resources
if (webViewManager) {
webViewManager.destroy();
webViewManager = null;
}
if (win && !win.isDestroyed()) {
win.destroy();
win = null;
}
// Wait for Python process cleanup
await cleanupPythonProcess();
// Clean up file reader if exists
if (fileReader) {
fileReader = null;
}
// Clear any remaining timeouts/intervals
if (global.gc) {
global.gc();
}
// Reset protocol handling state
isWindowReady = false;
protocolUrlQueue = [];
log.info('All cleanup completed, exiting...');
} catch (error) {
log.error('Error during cleanup:', error);

View file

@ -4,6 +4,7 @@ import log from 'electron-log'
import fs from 'fs'
import path from 'path'
import * as net from "net";
import * as http from "http";
import { ipcMain, BrowserWindow, app } from 'electron'
import { promisify } from 'util'
import { detectInstallationLogs, PromiseReturnType } from "./install-deps";
@ -195,21 +196,77 @@ export async function startBackend(setPort?: (port: number) => void): Promise<an
let started = false;
let healthCheckInterval: NodeJS.Timeout | null = null;
const startTimeout = setTimeout(() => {
if (!started) {
if (healthCheckInterval) clearInterval(healthCheckInterval);
node_process.kill();
reject(new Error('Backend failed to start within timeout'));
}
}, 30000); // 30 second timeout
// Helper function to poll health endpoint
const pollHealthEndpoint = (): void => {
let attempts = 0;
const maxAttempts = 20; // 5 seconds total (20 * 250ms)
const intervalMs = 250;
healthCheckInterval = setInterval(() => {
attempts++;
const healthUrl = `http://127.0.0.1:${port}/health`;
const req = http.get(healthUrl, { timeout: 1000 }, (res) => {
if (res.statusCode === 200) {
log.info(`Backend health check passed after ${attempts} attempts`);
started = true;
clearTimeout(startTimeout);
if (healthCheckInterval) clearInterval(healthCheckInterval);
resolve(node_process);
} else {
// Non-200 status (e.g., 404), continue polling unless max attempts reached
if (attempts >= maxAttempts) {
log.error(`Backend health check failed after ${attempts} attempts with status ${res.statusCode}`);
started = true;
clearTimeout(startTimeout);
if (healthCheckInterval) clearInterval(healthCheckInterval);
node_process.kill();
reject(new Error(`Backend health check failed: HTTP ${res.statusCode}`));
}
}
});
req.on('error', () => {
// Connection error - backend might not be ready yet, continue polling
if (attempts >= maxAttempts) {
log.error(`Backend health check failed after ${attempts} attempts: unable to connect`);
started = true;
clearTimeout(startTimeout);
if (healthCheckInterval) clearInterval(healthCheckInterval);
node_process.kill();
reject(new Error('Backend health check failed: unable to connect'));
}
});
req.on('timeout', () => {
req.destroy();
if (attempts >= maxAttempts) {
log.error(`Backend health check timed out after ${attempts} attempts`);
started = true;
clearTimeout(startTimeout);
if (healthCheckInterval) clearInterval(healthCheckInterval);
node_process.kill();
reject(new Error('Backend health check timed out'));
}
});
}, intervalMs);
};
node_process.stdout.on('data', (data) => {
displayFilteredLogs(data);
// check output content, judge if start success
if (!started && data.toString().includes("Uvicorn running on")) {
started = true;
clearTimeout(startTimeout);
resolve(node_process);
log.info('Uvicorn startup detected, starting health check polling...');
pollHealthEndpoint();
}
});
@ -217,9 +274,8 @@ export async function startBackend(setPort?: (port: number) => void): Promise<an
displayFilteredLogs(data);
if (!started && data.toString().includes("Uvicorn running on")) {
started = true;
clearTimeout(startTimeout);
resolve(node_process);
log.info('Uvicorn startup detected (stderr), starting health check polling...');
pollHealthEndpoint();
}
// Check for port binding errors
@ -227,6 +283,7 @@ export async function startBackend(setPort?: (port: number) => void): Promise<an
data.toString().includes("bind() failed")) {
started = true; // Prevent multiple rejections
clearTimeout(startTimeout);
if (healthCheckInterval) clearInterval(healthCheckInterval);
node_process.kill();
reject(new Error(`Port ${port} is already in use`));
}
@ -234,6 +291,7 @@ export async function startBackend(setPort?: (port: number) => void): Promise<an
node_process.on('close', (code) => {
clearTimeout(startTimeout);
if (healthCheckInterval) clearInterval(healthCheckInterval);
if (!started) {
reject(new Error(`fastapi exited with code ${code}`));
}

View file

@ -6,6 +6,7 @@ import fs from 'node:fs'
import { getBackendPath, getBinaryPath, getCachePath, getVenvPath, cleanupOldVenvs, isBinaryExists, runInstallScript } from './utils/process'
import { spawn } from 'child_process'
import { safeMainWindowSend } from './utils/safeWebContentsSend'
import os from 'node:os'
const userData = app.getPath('userData');
const versionFile = path.join(userData, 'version.txt');
@ -57,6 +58,13 @@ Promise<PromiseReturnType> => {
return new Promise(async (resolve, reject) => {
try {
// Clean up cache in production environment BEFORE any checks
// This ensures users always get fresh dependencies in production
if (app.isPackaged) {
log.info('[CACHE CLEANUP] Production environment detected, cleaning cache before dependency check...');
cleanupCacheInProduction();
}
const versionExists:boolean = checkInstallOperations.getSavedVersion();
// Check if command tools are installed
@ -230,10 +238,8 @@ class InstallLogs {
/**Display filtered logs based on severity */
displayFilteredLogs(data:String) {
if (!data) return;
if (!data) return;
const msg = data.toString().trimEnd();
//Detect if uv sync is run
detectInstallationLogs(msg);
if (msg.toLowerCase().includes("error") || msg.toLowerCase().includes("traceback")) {
log.error(`BACKEND: [DEPS INSTALL] ${msg}`);
safeMainWindowSend('install-dependencies-log', { type: 'stderr', data: data.toString() });
@ -282,6 +288,34 @@ class InstallLogs {
}
}
/**
* Clean up cache directory
* This ensures users get fresh dependencies
* Note: Only call this in production environment (caller should check app.isPackaged)
*/
function cleanupCacheInProduction(): void {
try {
const cacheBaseDir = path.join(os.homedir(), '.eigent', 'cache');
if (!fs.existsSync(cacheBaseDir)) {
log.info('[CACHE CLEANUP] Cache directory does not exist, nothing to clean');
return;
}
log.info('[CACHE CLEANUP] Cleaning cache directory:', cacheBaseDir);
fs.rmSync(cacheBaseDir, { recursive: true, force: true });
log.info('[CACHE CLEANUP] Cache directory cleaned successfully');
fs.mkdirSync(cacheBaseDir, { recursive: true });
log.info('[CACHE CLEANUP] Empty cache directory recreated');
} catch (error) {
log.error('[CACHE CLEANUP] Failed to clean cache directory:', error);
}
}
const runInstall = (extraArgs: string[], version: string) => {
const installLogs = new InstallLogs(extraArgs, version);
return new Promise<PromiseReturnType>((resolveInner, rejectInner) => {
@ -358,6 +392,29 @@ export async function installDependencies(version: string): Promise<PromiseRetur
return true; // Not an error if the toolkit isn't installed
}
// Check if npm dependencies are already installed
const npmMarkerPath = path.join(toolkitPath, '.npm_dependencies_installed');
const nodeModulesPath = path.join(toolkitPath, 'node_modules');
const distPath = path.join(toolkitPath, 'dist');
// Check if marker exists and verify version
if (fs.existsSync(npmMarkerPath) && fs.existsSync(nodeModulesPath) && fs.existsSync(distPath)) {
try {
const markerContent = JSON.parse(fs.readFileSync(npmMarkerPath, 'utf-8'));
if (markerContent.version === version) {
log.info('[DEPS INSTALL] hybrid_browser_toolkit npm dependencies already installed for current version, skipping...');
return true;
} else {
log.info('[DEPS INSTALL] npm dependencies installed for different version, will reinstall...');
// Clean up old installation
fs.unlinkSync(npmMarkerPath);
}
} catch (error) {
log.warn('[DEPS INSTALL] Could not read npm marker file, will reinstall...', error);
// If we can't read the marker, assume we need to reinstall
}
}
log.info('[DEPS INSTALL] Installing hybrid_browser_toolkit npm dependencies...');
safeMainWindowSend('install-dependencies-log', {
type: 'stdout',
@ -515,6 +572,13 @@ export async function installDependencies(version: string): Promise<PromiseRetur
// Non-critical, continue
}
// Create marker file to indicate npm dependencies are installed
fs.writeFileSync(npmMarkerPath, JSON.stringify({
installedAt: new Date().toISOString(),
version: version
}));
log.info('[DEPS INSTALL] Created npm dependencies marker file');
log.info('[DEPS INSTALL] hybrid_browser_toolkit dependencies installed successfully');
return true;
} catch (error) {
@ -542,6 +606,32 @@ export async function installDependencies(version: string): Promise<PromiseRetur
// Set Installing Lock Files
InstallLogs.setLockPath();
// Clean up npm dependencies marker when reinstalling Python deps
// This ensures npm deps are reinstalled when Python environment changes
try {
let sitePackagesPath: string | null = null;
const libPath = path.join(venvPath, 'lib');
if (fs.existsSync(libPath)) {
const libContents = fs.readdirSync(libPath);
const pythonDir = libContents.find(name => name.startsWith('python'));
if (pythonDir) {
sitePackagesPath = path.join(libPath, pythonDir, 'site-packages');
}
}
if (sitePackagesPath) {
const npmMarkerPath = path.join(sitePackagesPath, 'camel', 'toolkits', 'hybrid_browser_toolkit', 'ts', '.npm_dependencies_installed');
if (fs.existsSync(npmMarkerPath)) {
fs.unlinkSync(npmMarkerPath);
log.info('[DEPS INSTALL] Removed npm dependencies marker for fresh installation');
}
}
} catch (error) {
log.warn('[DEPS INSTALL] Could not clean npm marker file:', error);
// Non-critical, continue
}
// try default install
const installSuccess = await runInstall([], version)
if (installSuccess.success) {
@ -592,6 +682,24 @@ export async function installDependencies(version: string): Promise<PromiseRetur
let dependencyInstallationDetected = false;
let installationNotificationSent = false;
export function detectInstallationLogs(msg:string) {
// CRITICAL FIX: Use file system to check if installation is complete
// Don't rely on module variables as they can be reset during hot reload
// Check if dependencies are already installed
const isAlreadyInstalled = fs.existsSync(installedLockPath);
// If installed lock file exists, dependencies are already installed
// Skip all detection to avoid false positives
if (isAlreadyInstalled) {
// Dependencies are already installed, skip detection entirely
return;
}
// Also skip if notification was already sent (in current session)
if (installationNotificationSent) {
return;
}
// Check for UV dependency installation patterns
const installPatterns = [
"Resolved", // UV resolving dependencies
@ -605,18 +713,18 @@ export function detectInstallationLogs(msg:string) {
"× No solution found when resolving dependencies", // Dependency resolution issues
"Audited" // UV auditing dependencies
];
// Detect if UV is installing dependencies
if (!dependencyInstallationDetected && installPatterns.some(pattern =>
if (!dependencyInstallationDetected && installPatterns.some(pattern =>
msg.includes(pattern) && !msg.includes("Uvicorn running on")
)) {
dependencyInstallationDetected = true;
log.info('[BACKEND STARTUP] UV dependency installation detected during uvicorn startup');
// Create installing lock file to maintain consistency with install-deps.ts
InstallLogs.setLockPath();
log.info('[BACKEND STARTUP] Created uv_installing.lock file');
// Notify frontend that installation has started (only once)
if (!installationNotificationSent) {
installationNotificationSent = true;

View file

@ -64,6 +64,9 @@ export class WebViewManager {
}
const view = new WebContentsView({
webPreferences: {
// Use a separate session partition for webviews to isolate storage from main window
// This ensures clearing webview storage won't affect main window's auth data
partition: 'persist:user_login',
nodeIntegration: false,
contextIsolation: true,
backgroundThrottling: true,
@ -269,10 +272,11 @@ export class WebViewManager {
if (!webViewInfo.view.webContents.isDestroyed()) {
webViewInfo.view.webContents.removeAllListeners()
// DO NOT clear storage data here!
// Multiple webviews share the same partition 'persist:user_login'
// Clearing storage would affect ALL webviews and remove login cookies
// Only clear cache which is per-webContents
webViewInfo.view.webContents.session.clearCache()
webViewInfo.view.webContents.session.clearStorageData({
storages: ['cookies', 'localstorage', 'websql', 'indexdb', 'serviceworkers', 'cachestorage']
})
}
// remove webview from parent container

View file

@ -88,6 +88,7 @@ contextBridge.exposeInMainWorld('electronAPI', {
ipcRenderer.removeAllListeners(channel);
},
getEmailFolderPath: (email: string) => ipcRenderer.invoke('get-email-folder-path', email),
restartApp: () => ipcRenderer.invoke('restart-app'),
});

View file

@ -14,7 +14,8 @@
"type": "module",
"scripts": {
"compile-babel": "cd backend && uv run pybabel compile -d lang",
"dev": "npm run compile-babel && vite",
"clean-cache": "rimraf node_modules/.vite",
"dev": "npm run clean-cache && npm run compile-babel && vite",
"build": "npm run compile-babel && tsc && vite build && electron-builder -- --publish always",
"build:mac": "npm run compile-babel && tsc && vite build && electron-builder --mac",
"build:win": "npm run compile-babel && tsc && vite build && electron-builder --win",
@ -30,6 +31,7 @@
"dependencies": {
"@electron/notarize": "^2.5.0",
"@fontsource/inter": "^5.2.5",
"@gsap/react": "^2.1.2",
"@microsoft/fetch-event-source": "^2.0.1",
"@monaco-editor/loader": "^1.5.0",
"@monaco-editor/react": "^4.7.0",
@ -71,8 +73,10 @@
"lucide-react": "^0.509.0",
"mammoth": "^1.9.1",
"monaco-editor": "^0.52.2",
"motion": "^12.23.24",
"next-themes": "^0.4.6",
"papaparse": "^5.5.3",
"postprocessing": "^6.37.8",
"react-markdown": "^10.1.0",
"react-resizable-panels": "^3.0.4",
"react-router-dom": "^7.6.0",
@ -81,6 +85,7 @@
"tailwind-merge": "^3.3.0",
"tailwindcss-animate": "^1.0.7",
"tar": "^7.4.3",
"three": "^0.180.0",
"tree-kill": "^1.2.2",
"tw-animate-css": "^1.2.9",
"unzipper": "^0.12.3",
@ -112,6 +117,7 @@
"react": "^18.3.1",
"react-dom": "^18.3.1",
"react-i18next": "^15.7.3",
"rimraf": "^6.0.1",
"tailwindcss": "^3.4.15",
"typescript": "^5.4.2",
"vite": "^5.4.11",

View file

@ -5,3 +5,5 @@ database_url=postgresql://postgres:postgres@localhost:5432/postgres
# Chat Share Secret Key
CHAT_SHARE_SECRET_KEY=put-your-secret-key-here
CHAT_SHARE_SALT=put-your-encode-salt-here

View file

@ -1,5 +1,5 @@
# Use a Python image with uv pre-installed
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim
FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim
# Install the project into `/app`
WORKDIR /app
@ -15,9 +15,13 @@ ENV UV_PYTHON_INSTALL_MIRROR=https://registry.npmmirror.com/-/binary/python-buil
ARG database_url
ENV database_url=$database_url
RUN apt-get update && apt-get install -y \
gcc \
python3-dev \
&& rm -rf /var/lib/apt/lists/*
# Copy dependency files first
COPY pyproject.toml uv.lock ./
COPY server/pyproject.toml server/uv.lock ./
# Install the project's dependencies
RUN --mount=type=cache,target=/root/.cache/uv \
@ -25,7 +29,11 @@ RUN --mount=type=cache,target=/root/.cache/uv \
# Then, add the rest of the project source code and install it
# Installing separately from its dependencies allows optimal layer caching
COPY . /app
COPY server/ /app
# Copy the utils directory from the parent project
COPY utils /app/utils
RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --no-dev
@ -41,7 +49,7 @@ RUN apt-get update && apt-get install -y curl netcat-openbsd && rm -rf /var/lib/
ENV PATH="/app/.venv/bin:$PATH"
# Copy and make the start script executable
COPY start.sh /app/start.sh
COPY server/start.sh /app/start.sh
RUN sed -i 's/\r$//' /app/start.sh && chmod +x /app/start.sh
# Reset the entrypoint, don't invoke `uv`

View file

@ -1,4 +1,11 @@
from logging.config import fileConfig
import sys
import pathlib
# Add project root to Python path to import shared utils
_project_root = pathlib.Path(__file__).parent.parent.parent
if str(_project_root) not in sys.path:
sys.path.insert(0, str(_project_root))
from sqlalchemy import engine_from_config, pool
from alembic import context

View file

@ -0,0 +1,36 @@
"""modify_chat_history_add_project_id
Revision ID: eec7242b3a9b
Revises: d74ab2a44600
Create Date: 2025-10-15 14:46:47.904254
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
# revision identifiers, used by Alembic.
revision: str = "eec7242b3a9b"
down_revision: Union[str, None] = "d74ab2a44600"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("chat_history", sa.Column("project_id", sqlmodel.sql.sqltypes.AutoString(), nullable=True))
op.create_index(op.f("ix_chat_history_project_id"), "chat_history", ["project_id"], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_chat_history_project_id"), table_name="chat_history")
op.drop_column("chat_history", "project_id")
# ### end Alembic commands ###

View file

@ -1,5 +1,6 @@
from fastapi import FastAPI
from fastapi_pagination import add_pagination
api = FastAPI(swagger_ui_parameters={"persistAuthorization": True})
add_pagination(api)

View file

@ -3,50 +3,110 @@ from fastapi_pagination import Page
from fastapi_pagination.ext.sqlmodel import paginate
from app.model.chat.chat_history import ChatHistoryOut, ChatHistoryIn, ChatHistory, ChatHistoryUpdate
from fastapi_babel import _
from sqlmodel import Session, select, desc
from sqlmodel import Session, select, desc, case
from app.component.auth import Auth, auth_must
from app.component.database import session
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("server_chat_history")
router = APIRouter(prefix="/chat", tags=["Chat History"])
@router.post("/history", name="save chat history", response_model=ChatHistoryOut)
@traceroot.trace()
def create_chat_history(data: ChatHistoryIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
data.user_id = auth.user.id
chat_history = ChatHistory(**data.model_dump())
session.add(chat_history)
session.commit()
session.refresh(chat_history)
return chat_history
"""Save new chat history."""
user_id = auth.user.id
try:
data.user_id = user_id
chat_history = ChatHistory(**data.model_dump())
session.add(chat_history)
session.commit()
session.refresh(chat_history)
logger.info("Chat history created", extra={"user_id": user_id, "history_id": chat_history.id, "task_id": data.task_id})
return chat_history
except Exception as e:
session.rollback()
logger.error("Chat history creation failed", extra={"user_id": user_id, "task_id": data.task_id, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/histories", name="get chat history")
@traceroot.trace()
def list_chat_history(session: Session = Depends(session), auth: Auth = Depends(auth_must)) -> Page[ChatHistoryOut]:
stmt = select(ChatHistory).where(ChatHistory.user_id == auth.user.id).order_by(desc(ChatHistory.created_at))
return paginate(session, stmt)
"""List chat histories for current user."""
user_id = auth.user.id
# Order by created_at descending, but fallback to id descending for old records without timestamps
# This ensures newer records with timestamps come first, followed by old records ordered by id
stmt = (
select(ChatHistory)
.where(ChatHistory.user_id == user_id)
.order_by(
desc(case((ChatHistory.created_at.is_(None), 0), else_=1)), # Non-null created_at first
desc(ChatHistory.created_at), # Then by created_at descending
desc(ChatHistory.id) # Finally by id descending for records with same/null created_at
)
)
result = paginate(session, stmt)
total = result.total if hasattr(result, 'total') else 0
logger.debug("Chat histories listed", extra={"user_id": user_id, "total": total})
return result
@router.delete("/history/{history_id}", name="delete chat history")
def delete_chat_history(history_id: str, session: Session = Depends(session)):
@traceroot.trace()
def delete_chat_history(history_id: str, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
"""Delete chat history."""
user_id = auth.user.id
history = session.exec(select(ChatHistory).where(ChatHistory.id == history_id)).first()
if not history:
raise HTTPException(status_code=404, detail="Caht History not found")
session.delete(history)
session.commit()
return Response(status_code=204)
logger.warning("Chat history not found for deletion", extra={"user_id": user_id, "history_id": history_id})
raise HTTPException(status_code=404, detail="Chat History not found")
if history.user_id != user_id:
logger.warning("Unauthorized deletion attempt", extra={"user_id": user_id, "history_id": history_id, "owner_id": history.user_id})
raise HTTPException(status_code=403, detail="You are not allowed to delete this chat history")
try:
session.delete(history)
session.commit()
logger.info("Chat history deleted", extra={"user_id": user_id, "history_id": history_id})
return Response(status_code=204)
except Exception as e:
session.rollback()
logger.error("Chat history deletion failed", extra={"user_id": user_id, "history_id": history_id, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")
@router.put("/history/{history_id}", name="update chat history", response_model=ChatHistoryOut)
@traceroot.trace()
def update_chat_history(
history_id: int, data: ChatHistoryUpdate, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
"""Update chat history."""
user_id = auth.user.id
history = session.exec(select(ChatHistory).where(ChatHistory.id == history_id)).first()
if not history:
logger.warning("Chat history not found for update", extra={"user_id": user_id, "history_id": history_id})
raise HTTPException(status_code=404, detail="Chat History not found")
if history.user_id != auth.user.id:
if history.user_id != user_id:
logger.warning("Unauthorized update attempt", extra={"user_id": user_id, "history_id": history_id, "owner_id": history.user_id})
raise HTTPException(status_code=403, detail="You are not allowed to update this chat history")
update_data = data.model_dump(exclude_unset=True)
history.update_fields(update_data)
history.save(session)
session.refresh(history)
return history
try:
update_data = data.model_dump(exclude_unset=True)
history.update_fields(update_data)
history.save(session)
session.refresh(history)
logger.info("Chat history updated", extra={"user_id": user_id, "history_id": history_id, "fields_updated": list(update_data.keys())})
return history
except Exception as e:
logger.error("Chat history update failed", extra={"user_id": user_id, "history_id": history_id, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")

View file

@ -1,78 +1,107 @@
from fastapi import APIRouter, Depends, HTTPException, Response
from sqlmodel import Session, asc, select
from app.component.database import session
import json
import asyncio
from itsdangerous import SignatureExpired, BadTimeSignature
from starlette.responses import StreamingResponse
from app.model.chat.chat_share import ChatHistoryShareOut, ChatShare, ChatShareIn
from app.model.chat.chat_step import ChatStep
from app.model.chat.chat_history import ChatHistory
router = APIRouter(prefix="/chat", tags=["Chat Share"])
@router.get("/share/info/{token}", name="Get shared chat info", response_model=ChatHistoryShareOut)
def get_share_info(token: str, session: Session = Depends(session)):
"""
Get shared chat history info by token, excluding sensitive data.
"""
try:
task_id = ChatShare.verify_token(token, False)
except (SignatureExpired, BadTimeSignature):
raise HTTPException(status_code=400, detail="Share link is invalid or has expired.")
stmt = select(ChatHistory).where(ChatHistory.task_id == task_id)
history = session.exec(stmt).one_or_none()
if not history:
raise HTTPException(status_code=404, detail="Chat history not found.")
return history
@router.get("/share/playback/{token}", name="Playback shared chat via SSE")
async def share_playback(token: str, session: Session = Depends(session), delay_time: float = 0):
"""
Playbacks the chat history via a sharing token (SSE).
delay_time: control sse interval, max 5 seconds
"""
if delay_time > 5:
delay_time = 5
try:
task_id = ChatShare.verify_token(token, False)
except SignatureExpired:
raise HTTPException(status_code=400, detail="Share link has expired.")
except BadTimeSignature:
raise HTTPException(status_code=400, detail="Share link is invalid.")
async def event_generator():
stmt = select(ChatStep).where(ChatStep.task_id == task_id).order_by(asc(ChatStep.id))
steps = session.exec(stmt).all()
if not steps:
yield f"data: {json.dumps({'error': 'No steps found for this task.'})}\n\n"
return
for step in steps:
step_data = {
"id": step.id,
"task_id": step.task_id,
"step": step.step,
"data": step.data,
"created_at": step.created_at.isoformat() if step.created_at else None,
}
yield f"data: {json.dumps(step_data)}\n\n"
if delay_time > 0 and step.step != "create_agent":
await asyncio.sleep(delay_time)
return StreamingResponse(event_generator(), media_type="text/event-stream")
@router.post("/share", name="Generate sharable link for a task(1 day expiration)")
def create_share_link(data: ChatShareIn):
"""
Generates a sharing token with an expiration time for the specified task_id.
"""
share_token = ChatShare.generate_token(data.task_id)
return {"share_token": share_token}
from fastapi import APIRouter, Depends, HTTPException, Response
from sqlmodel import Session, asc, select
from app.component.database import session
import json
import asyncio
from itsdangerous import SignatureExpired, BadTimeSignature
from starlette.responses import StreamingResponse
from app.model.chat.chat_share import ChatHistoryShareOut, ChatShare, ChatShareIn
from app.model.chat.chat_step import ChatStep
from app.model.chat.chat_history import ChatHistory
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("server_chat_share")
router = APIRouter(prefix="/chat", tags=["Chat Share"])
@router.get("/share/info/{token}", name="Get shared chat info", response_model=ChatHistoryShareOut)
@traceroot.trace()
def get_share_info(token: str, session: Session = Depends(session)):
"""
Get shared chat history info by token, excluding sensitive data.
"""
try:
task_id = ChatShare.verify_token(token, False)
except SignatureExpired:
logger.warning("Shared chat access failed: token expired", extra={"token_prefix": token[:10]})
raise HTTPException(status_code=400, detail="Share link is invalid or has expired.")
except BadTimeSignature:
logger.warning("Shared chat access failed: invalid token", extra={"token_prefix": token[:10]})
raise HTTPException(status_code=400, detail="Share link is invalid or has expired.")
stmt = select(ChatHistory).where(ChatHistory.task_id == task_id)
history = session.exec(stmt).one_or_none()
if not history:
logger.warning("Shared chat not found", extra={"task_id": task_id})
raise HTTPException(status_code=404, detail="Chat history not found.")
logger.info("Shared chat info accessed", extra={"task_id": task_id})
return history
@router.get("/share/playback/{token}", name="Playback shared chat via SSE")
@traceroot.trace()
async def share_playback(token: str, session: Session = Depends(session), delay_time: float = 0):
"""
Playbacks the chat history via a sharing token (SSE).
delay_time: control sse interval, max 5 seconds
"""
if delay_time > 5:
logger.debug("Delay time capped", extra={"requested": delay_time, "capped": 5})
delay_time = 5
try:
task_id = ChatShare.verify_token(token, False)
except SignatureExpired:
logger.warning("Shared chat playback failed: token expired", extra={"token_prefix": token[:10]})
raise HTTPException(status_code=400, detail="Share link has expired.")
except BadTimeSignature:
logger.warning("Shared chat playback failed: invalid token", extra={"token_prefix": token[:10]})
raise HTTPException(status_code=400, detail="Share link is invalid.")
async def event_generator():
try:
stmt = select(ChatStep).where(ChatStep.task_id == task_id).order_by(asc(ChatStep.id))
steps = session.exec(stmt).all()
if not steps:
logger.warning("No steps found for playback", extra={"task_id": task_id})
yield f"data: {json.dumps({'error': 'No steps found for this task.'})}\n\n"
return
logger.info("Shared chat playback started", extra={"task_id": task_id, "step_count": len(steps), "delay_time": delay_time})
for idx, step in enumerate(steps, start=1):
step_data = {
"id": step.id,
"task_id": step.task_id,
"step": step.step,
"data": step.data,
"created_at": step.created_at.isoformat() if step.created_at else None,
}
yield f"data: {json.dumps(step_data)}\n\n"
if delay_time > 0 and step.step != "create_agent":
await asyncio.sleep(delay_time)
logger.info("Shared chat playback completed", extra={"task_id": task_id, "step_count": len(steps)})
except Exception as e:
logger.error("Shared chat playback error", extra={"task_id": task_id, "error": str(e)}, exc_info=True)
yield f"data: {json.dumps({'error': 'Playback error occurred.'})}\n\n"
return StreamingResponse(event_generator(), media_type="text/event-stream")
@router.post("/share", name="Generate sharable link for a task(1 day expiration)")
@traceroot.trace()
def create_share_link(data: ChatShareIn):
"""Generate sharing token with 1-day expiration for task."""
try:
share_token = ChatShare.generate_token(data.task_id)
logger.info("Share link created", extra={"task_id": data.task_id, "token_prefix": share_token[:10]})
return {"share_token": share_token}
except Exception as e:
logger.error("Share link creation failed", extra={"task_id": data.task_id, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")

View file

@ -1,81 +1,138 @@
from app.model.chat.chat_snpshot import ChatSnapshot, ChatSnapshotIn
from typing import List, Optional
from fastapi import Depends, HTTPException, Response, APIRouter
from sqlmodel import Session, select
from app.component.database import session
from app.component.auth import Auth, auth_must
from fastapi_babel import _
router = APIRouter(prefix="/chat", tags=["Chat Snapshot Management"])
@router.get("/snapshots", name="list chat snapshots", response_model=List[ChatSnapshot])
async def list_chat_snapshots(
api_task_id: Optional[str] = None,
camel_task_id: Optional[str] = None,
browser_url: Optional[str] = None,
session: Session = Depends(session),
):
query = select(ChatSnapshot)
if api_task_id is not None:
query = query.where(ChatSnapshot.api_task_id == api_task_id)
if camel_task_id is not None:
query = query.where(ChatSnapshot.camel_task_id == camel_task_id)
if browser_url is not None:
query = query.where(ChatSnapshot.browser_url == browser_url)
snapshots = session.exec(query).all()
return snapshots
@router.get("/snapshots/{snapshot_id}", name="get chat snapshot", response_model=ChatSnapshot)
async def get_chat_snapshot(snapshot_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
snapshot = session.get(ChatSnapshot, snapshot_id)
if not snapshot:
raise HTTPException(status_code=404, detail=_("Chat snapshot not found"))
return snapshot
@router.post("/snapshots", name="create chat snapshot", response_model=ChatSnapshot)
async def create_chat_snapshot(
snapshot: ChatSnapshotIn, auth: Auth = Depends(auth_must), session: Session = Depends(session)
):
image_path = ChatSnapshotIn.save_image(auth.user.id, snapshot.api_task_id, snapshot.image_base64)
chat_snapshot = ChatSnapshot(
user_id=auth.user.id,
api_task_id=snapshot.api_task_id,
camel_task_id=snapshot.camel_task_id,
browser_url=snapshot.browser_url,
image_path=image_path,
)
session.add(chat_snapshot)
session.commit()
session.refresh(chat_snapshot)
return Response(status_code=200)
@router.put("/snapshots/{snapshot_id}", name="update chat snapshot", response_model=ChatSnapshot)
async def update_chat_snapshot(
snapshot_id: int,
snapshot_update: ChatSnapshot,
session: Session = Depends(session),
auth: Auth = Depends(auth_must),
):
db_snapshot = session.get(ChatSnapshot, snapshot_id)
if not db_snapshot:
raise HTTPException(status_code=404, detail=_("Chat snapshot not found"))
for key, value in snapshot_update.dict(exclude_unset=True).items():
setattr(db_snapshot, key, value)
session.add(db_snapshot)
session.commit()
session.refresh(db_snapshot)
return db_snapshot
@router.delete("/snapshots/{snapshot_id}", name="delete chat snapshot")
async def delete_chat_snapshot(snapshot_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
db_snapshot = session.get(ChatSnapshot, snapshot_id)
if not db_snapshot:
raise HTTPException(status_code=404, detail=_("Chat snapshot not found"))
session.delete(db_snapshot)
session.commit()
return Response(status_code=204)
from app.model.chat.chat_snpshot import ChatSnapshot, ChatSnapshotIn
from typing import List, Optional
from fastapi import Depends, HTTPException, Response, APIRouter
from sqlmodel import Session, select
from app.component.database import session
from app.component.auth import Auth, auth_must
from fastapi_babel import _
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("server_chat_snapshot")
router = APIRouter(prefix="/chat", tags=["Chat Snapshot Management"])
@router.get("/snapshots", name="list chat snapshots", response_model=List[ChatSnapshot])
@traceroot.trace()
async def list_chat_snapshots(
api_task_id: Optional[str] = None,
camel_task_id: Optional[str] = None,
browser_url: Optional[str] = None,
session: Session = Depends(session),
):
"""List chat snapshots with optional filtering."""
query = select(ChatSnapshot)
if api_task_id is not None:
query = query.where(ChatSnapshot.api_task_id == api_task_id)
if camel_task_id is not None:
query = query.where(ChatSnapshot.camel_task_id == camel_task_id)
if browser_url is not None:
query = query.where(ChatSnapshot.browser_url == browser_url)
snapshots = session.exec(query).all()
logger.debug("Snapshots listed", extra={"api_task_id": api_task_id, "camel_task_id": camel_task_id, "count": len(snapshots)})
return snapshots
@router.get("/snapshots/{snapshot_id}", name="get chat snapshot", response_model=ChatSnapshot)
@traceroot.trace()
async def get_chat_snapshot(snapshot_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
"""Get specific chat snapshot."""
user_id = auth.user.id
snapshot = session.get(ChatSnapshot, snapshot_id)
if not snapshot:
logger.warning("Snapshot not found", extra={"user_id": user_id, "snapshot_id": snapshot_id})
raise HTTPException(status_code=404, detail=_("Chat snapshot not found"))
logger.debug("Snapshot retrieved", extra={"user_id": user_id, "snapshot_id": snapshot_id, "api_task_id": snapshot.api_task_id})
return snapshot
@router.post("/snapshots", name="create chat snapshot", response_model=ChatSnapshot)
@traceroot.trace()
async def create_chat_snapshot(
snapshot: ChatSnapshotIn, auth: Auth = Depends(auth_must), session: Session = Depends(session)
):
"""Create new chat snapshot from image."""
user_id = auth.user.id
try:
image_path = ChatSnapshotIn.save_image(user_id, snapshot.api_task_id, snapshot.image_base64)
chat_snapshot = ChatSnapshot(
user_id=user_id,
api_task_id=snapshot.api_task_id,
camel_task_id=snapshot.camel_task_id,
browser_url=snapshot.browser_url,
image_path=image_path,
)
session.add(chat_snapshot)
session.commit()
session.refresh(chat_snapshot)
logger.info("Snapshot created", extra={"user_id": user_id, "snapshot_id": chat_snapshot.id, "api_task_id": snapshot.api_task_id, "image_path": image_path})
return chat_snapshot
except Exception as e:
session.rollback()
logger.error("Snapshot creation failed", extra={"user_id": user_id, "api_task_id": snapshot.api_task_id, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")
@router.put("/snapshots/{snapshot_id}", name="update chat snapshot", response_model=ChatSnapshot)
@traceroot.trace()
async def update_chat_snapshot(
snapshot_id: int,
snapshot_update: ChatSnapshot,
session: Session = Depends(session),
auth: Auth = Depends(auth_must),
):
"""Update chat snapshot."""
user_id = auth.user.id
db_snapshot = session.get(ChatSnapshot, snapshot_id)
if not db_snapshot:
logger.warning("Snapshot not found for update", extra={"user_id": user_id, "snapshot_id": snapshot_id})
raise HTTPException(status_code=404, detail=_("Chat snapshot not found"))
if db_snapshot.user_id != user_id:
logger.warning("Unauthorized snapshot update", extra={"user_id": user_id, "snapshot_id": snapshot_id, "owner_id": db_snapshot.user_id})
raise HTTPException(status_code=403, detail=_("You are not allowed to update this snapshot"))
try:
update_data = snapshot_update.dict(exclude_unset=True)
for key, value in update_data.items():
setattr(db_snapshot, key, value)
session.add(db_snapshot)
session.commit()
session.refresh(db_snapshot)
logger.info("Snapshot updated", extra={"user_id": user_id, "snapshot_id": snapshot_id, "fields_updated": list(update_data.keys())})
return db_snapshot
except Exception as e:
session.rollback()
logger.error("Snapshot update failed", extra={"user_id": user_id, "snapshot_id": snapshot_id, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")
@router.delete("/snapshots/{snapshot_id}", name="delete chat snapshot")
@traceroot.trace()
async def delete_chat_snapshot(snapshot_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
"""Delete chat snapshot."""
user_id = auth.user.id
db_snapshot = session.get(ChatSnapshot, snapshot_id)
if not db_snapshot:
logger.warning("Snapshot not found for deletion", extra={"user_id": user_id, "snapshot_id": snapshot_id})
raise HTTPException(status_code=404, detail=_("Chat snapshot not found"))
if db_snapshot.user_id != user_id:
logger.warning("Unauthorized snapshot deletion", extra={"user_id": user_id, "snapshot_id": snapshot_id, "owner_id": db_snapshot.user_id})
raise HTTPException(status_code=403, detail=_("You are not allowed to delete this snapshot"))
try:
session.delete(db_snapshot)
session.commit()
logger.info("Snapshot deleted", extra={"user_id": user_id, "snapshot_id": snapshot_id, "image_path": db_snapshot.image_path})
return Response(status_code=204)
except Exception as e:
session.rollback()
logger.error("Snapshot deletion failed", extra={"user_id": user_id, "snapshot_id": snapshot_id, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")

View file

@ -1,105 +1,163 @@
import asyncio
import json
from typing import List, Optional
from fastapi import Depends, HTTPException, Query, Response, APIRouter
from fastapi.responses import StreamingResponse
from sqlmodel import Session, asc, select
from app.component.database import session
from app.component.auth import Auth, auth_must
from fastapi_babel import _
from app.model.chat.chat_step import ChatStep, ChatStepOut, ChatStepIn
router = APIRouter(prefix="/chat", tags=["Chat Step Management"])
@router.get("/steps", name="list chat steps", response_model=List[ChatStepOut])
async def list_chat_steps(
task_id: str, step: Optional[str] = None, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
query = select(ChatStep)
if task_id is not None:
query = query.where(ChatStep.task_id == task_id)
if step is not None:
query = query.where(ChatStep.step == step)
chat_steps = session.exec(query).all()
return chat_steps
@router.get("/steps/playback/{task_id}", name="Playback Chat Step via SSE")
async def share_playback(
task_id: str, delay_time: float = 0, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
"""
Playbacks the chat steps (SSE).
"""
if delay_time > 5:
delay_time = 5
async def event_generator():
stmt = select(ChatStep).where(ChatStep.task_id == task_id).order_by(asc(ChatStep.id))
steps = session.exec(stmt).all()
if not steps:
yield f"data: {json.dumps({'error': 'No steps found for this task.'})}\n\n"
return
for step in steps:
step_data = {
"id": step.id,
"task_id": step.task_id,
"step": step.step,
"data": step.data,
"created_at": step.created_at.isoformat() if step.created_at else None,
}
yield f"data: {json.dumps(step_data)}\n\n"
if delay_time > 0:
await asyncio.sleep(delay_time)
return StreamingResponse(event_generator(), media_type="text/event-stream")
@router.get("/steps/{step_id}", name="get chat step", response_model=ChatStepOut)
async def get_chat_step(step_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
chat_step = session.get(ChatStep, step_id)
if not chat_step:
raise HTTPException(status_code=404, detail=_("Chat step not found"))
return chat_step
@router.post("/steps", name="create chat step")
# TODO Limit request sources
async def create_chat_step(step: ChatStepIn, session: Session = Depends(session)):
chat_step = ChatStep(
task_id=step.task_id,
step=step.step,
data=step.data,
)
session.add(chat_step)
session.commit()
session.refresh(chat_step)
return {"code": 200, "msg": "success"}
@router.put("/steps/{step_id}", name="update chat step", response_model=ChatStepOut)
async def update_chat_step(
step_id: int, chat_step_update: ChatStep, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
db_chat_step = session.get(ChatStep, step_id)
if not db_chat_step:
raise HTTPException(status_code=404, detail=_("Chat step not found"))
for key, value in chat_step_update.dict(exclude_unset=True).items():
setattr(db_chat_step, key, value)
session.add(db_chat_step)
session.commit()
session.refresh(db_chat_step)
return db_chat_step
@router.delete("/steps/{step_id}", name="delete chat step")
async def delete_chat_step(step_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
db_chat_step = session.get(ChatStep, step_id)
if not db_chat_step:
raise HTTPException(status_code=404, detail=_("Chat step not found"))
session.delete(db_chat_step)
session.commit()
return Response(status_code=204)
import asyncio
import json
from typing import List, Optional
from fastapi import Depends, HTTPException, Query, Response, APIRouter
from fastapi.responses import StreamingResponse
from sqlmodel import Session, asc, select
from app.component.database import session
from app.component.auth import Auth, auth_must
from fastapi_babel import _
from app.model.chat.chat_step import ChatStep, ChatStepOut, ChatStepIn
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("server_chat_step")
router = APIRouter(prefix="/chat", tags=["Chat Step Management"])
@router.get("/steps", name="list chat steps", response_model=List[ChatStepOut])
@traceroot.trace()
async def list_chat_steps(
task_id: str, step: Optional[str] = None, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
"""List chat steps for a task with optional step type filtering."""
user_id = auth.user.id
query = select(ChatStep)
if task_id is not None:
query = query.where(ChatStep.task_id == task_id)
if step is not None:
query = query.where(ChatStep.step == step)
chat_steps = session.exec(query).all()
logger.debug("Chat steps listed", extra={"user_id": user_id, "task_id": task_id, "step_type": step, "count": len(chat_steps)})
return chat_steps
@router.get("/steps/playback/{task_id}", name="Playback Chat Step via SSE")
@traceroot.trace()
async def share_playback(
task_id: str, delay_time: float = 0, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
"""Playback chat steps via SSE stream."""
user_id = auth.user.id
if delay_time > 5:
logger.debug("Delay time capped", extra={"user_id": user_id, "task_id": task_id, "requested": delay_time, "capped": 5})
delay_time = 5
async def event_generator():
try:
stmt = select(ChatStep).where(ChatStep.task_id == task_id).order_by(asc(ChatStep.id))
steps = session.exec(stmt).all()
if not steps:
logger.warning("No steps found for playback", extra={"user_id": user_id, "task_id": task_id})
yield f"data: {json.dumps({'error': 'No steps found for this task.'})}\n\n"
return
logger.info("Chat step playback started", extra={"user_id": user_id, "task_id": task_id, "step_count": len(steps), "delay_time": delay_time})
for step in steps:
step_data = {
"id": step.id,
"task_id": step.task_id,
"step": step.step,
"data": step.data,
"created_at": step.created_at.isoformat() if step.created_at else None,
}
yield f"data: {json.dumps(step_data)}\n\n"
if delay_time > 0:
await asyncio.sleep(delay_time)
logger.info("Chat step playback completed", extra={"user_id": user_id, "task_id": task_id, "step_count": len(steps)})
except Exception as e:
logger.error("Chat step playback error", extra={"user_id": user_id, "task_id": task_id, "error": str(e)}, exc_info=True)
yield f"data: {json.dumps({'error': 'Playback error occurred.'})}\n\n"
return StreamingResponse(event_generator(), media_type="text/event-stream")
@router.get("/steps/{step_id}", name="get chat step", response_model=ChatStepOut)
@traceroot.trace()
async def get_chat_step(step_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
"""Get specific chat step."""
user_id = auth.user.id
chat_step = session.get(ChatStep, step_id)
if not chat_step:
logger.warning("Chat step not found", extra={"user_id": user_id, "step_id": step_id})
raise HTTPException(status_code=404, detail=_("Chat step not found"))
logger.debug("Chat step retrieved", extra={"user_id": user_id, "step_id": step_id, "task_id": chat_step.task_id})
return chat_step
@router.post("/steps", name="create chat step")
@traceroot.trace()
async def create_chat_step(step: ChatStepIn, session: Session = Depends(session)):
"""Create new chat step. TODO: Implement request source validation."""
try:
chat_step = ChatStep(
task_id=step.task_id,
step=step.step,
data=step.data,
)
session.add(chat_step)
session.commit()
session.refresh(chat_step)
logger.info("Chat step created", extra={"step_id": chat_step.id, "task_id": step.task_id, "step_type": step.step})
return {"code": 200, "msg": "success"}
except Exception as e:
session.rollback()
logger.error("Chat step creation failed", extra={"task_id": step.task_id, "step_type": step.step, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")
@router.put("/steps/{step_id}", name="update chat step", response_model=ChatStepOut)
@traceroot.trace()
async def update_chat_step(
step_id: int, chat_step_update: ChatStep, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
"""Update chat step."""
user_id = auth.user.id
db_chat_step = session.get(ChatStep, step_id)
if not db_chat_step:
logger.warning("Chat step not found for update", extra={"user_id": user_id, "step_id": step_id})
raise HTTPException(status_code=404, detail=_("Chat step not found"))
try:
update_data = chat_step_update.dict(exclude_unset=True)
for key, value in update_data.items():
setattr(db_chat_step, key, value)
session.add(db_chat_step)
session.commit()
session.refresh(db_chat_step)
logger.info("Chat step updated", extra={"user_id": user_id, "step_id": step_id, "task_id": db_chat_step.task_id, "fields_updated": list(update_data.keys())})
return db_chat_step
except Exception as e:
session.rollback()
logger.error("Chat step update failed", extra={"user_id": user_id, "step_id": step_id, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")
@router.delete("/steps/{step_id}", name="delete chat step")
@traceroot.trace()
async def delete_chat_step(step_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
"""Delete chat step."""
user_id = auth.user.id
db_chat_step = session.get(ChatStep, step_id)
if not db_chat_step:
logger.warning("Chat step not found for deletion", extra={"user_id": user_id, "step_id": step_id})
raise HTTPException(status_code=404, detail=_("Chat step not found"))
try:
session.delete(db_chat_step)
session.commit()
logger.info("Chat step deleted", extra={"user_id": user_id, "step_id": step_id, "task_id": db_chat_step.task_id})
return Response(status_code=204)
except Exception as e:
session.rollback()
logger.error("Chat step deletion failed", extra={"user_id": user_id, "step_id": step_id, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")

View file

@ -1,121 +1,172 @@
from typing import List, Optional
from fastapi import Depends, HTTPException, Query, Response, APIRouter
from sqlmodel import Session, select, or_
from app.component.database import session
from app.component.auth import Auth, auth_must
from fastapi_babel import _
from app.model.config.config import Config, ConfigCreate, ConfigUpdate, ConfigInfo, ConfigOut
router = APIRouter(tags=["Config Management"])
@router.get("/configs", name="list configs", response_model=list[ConfigOut])
async def list_configs(
config_group: Optional[str] = None, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
query = select(Config)
user_id = auth.user.id
if user_id is not None:
query = query.where(Config.user_id == user_id)
if config_group is not None:
query = query.where(Config.config_group == config_group)
configs = session.exec(query).all()
return configs
@router.get("/configs/{config_id}", name="get config", response_model=ConfigOut)
async def get_config(
config_id: int,
session: Session = Depends(session),
auth: Auth = Depends(auth_must),
):
query = select(Config).where(Config.user_id == auth.user.id)
if config_id is not None:
query = query.where(Config.id == config_id)
config = session.exec(query).first()
if not config:
raise HTTPException(status_code=404, detail=_("Configuration not found"))
return config
@router.post("/configs", name="create config", response_model=ConfigOut)
async def create_config(config: ConfigCreate, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
if not ConfigInfo.is_valid_env_var(config.config_group, config.config_name):
raise HTTPException(status_code=400, detail=_("Config Name is valid"))
# Check if configuration already exists
existing_config = session.exec(
select(Config).where(Config.user_id == auth.user.id, Config.config_name == config.config_name)
).first()
if existing_config:
raise HTTPException(status_code=400, detail=_("Configuration already exists for this user"))
db_config = Config(
user_id=auth.user.id,
config_name=config.config_name,
config_value=config.config_value,
config_group=config.config_group,
)
session.add(db_config)
session.commit()
session.refresh(db_config)
return db_config
@router.put("/configs/{config_id}", name="update config", response_model=ConfigOut)
async def update_config(
config_id: int, config_update: ConfigUpdate, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
db_config = session.exec(select(Config).where(Config.id == config_id, Config.user_id == auth.user.id)).first()
if not db_config:
raise HTTPException(status_code=404, detail=_("Configuration not found"))
# Check if configuration group is valid
if not ConfigInfo.is_valid_env_var(config_update.config_group, config_update.config_name):
raise HTTPException(status_code=400, detail=_("Invalid configuration group"))
# Check for conflicts with other configurations
existing_config = session.exec(
select(Config).where(
Config.user_id == auth.user.id,
Config.config_name == config_update.config_name,
Config.id != config_id,
)
).first()
if existing_config:
raise HTTPException(status_code=400, detail=_("Configuration already exists for this user"))
db_config.config_name = config_update.config_name
db_config.config_value = config_update.config_value
session.add(db_config)
session.commit()
session.refresh(db_config)
return db_config
@router.delete("/configs/{config_id}", name="delete config")
async def delete_config(config_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
db_config = session.exec(select(Config).where(Config.id == config_id, Config.user_id == auth.user.id)).first()
if not db_config:
raise HTTPException(status_code=404, detail=_("Configuration not found"))
session.delete(db_config)
session.commit()
return Response(status_code=204)
@router.get("/config/info", name="get config info")
async def get_config_info(
show_all: bool = Query(False, description="Show all config info, including those with empty env_vars"),
):
configs = ConfigInfo.getinfo()
if show_all:
return configs
return {k: v for k, v in configs.items() if v.get("env_vars") and len(v["env_vars"]) > 0}
from typing import List, Optional
from fastapi import Depends, HTTPException, Query, Response, APIRouter
from sqlmodel import Session, select, or_
from app.component.database import session
from app.component.auth import Auth, auth_must
from fastapi_babel import _
from app.model.config.config import Config, ConfigCreate, ConfigUpdate, ConfigInfo, ConfigOut
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("server_config_controller")
router = APIRouter(tags=["Config Management"])
@router.get("/configs", name="list configs", response_model=list[ConfigOut])
@traceroot.trace()
async def list_configs(
config_group: Optional[str] = None, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
"""List user's configurations with optional group filtering."""
user_id = auth.user.id
query = select(Config).where(Config.user_id == user_id)
if config_group is not None:
query = query.where(Config.config_group == config_group)
configs = session.exec(query).all()
logger.debug("Configs listed", extra={"user_id": user_id, "config_group": config_group, "count": len(configs)})
return configs
@router.get("/configs/{config_id}", name="get config", response_model=ConfigOut)
@traceroot.trace()
async def get_config(
config_id: int,
session: Session = Depends(session),
auth: Auth = Depends(auth_must),
):
query = select(Config).where(Config.user_id == auth.user.id)
if config_id is not None:
query = query.where(Config.id == config_id)
config = session.exec(query).first()
if not config:
logger.warning("Config not found")
raise HTTPException(status_code=404, detail=_("Configuration not found"))
logger.debug("Config retrieved")
return config
@router.post("/configs", name="create config", response_model=ConfigOut)
@traceroot.trace()
async def create_config(config: ConfigCreate, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
"""Create new configuration."""
user_id = auth.user.id
if not ConfigInfo.is_valid_env_var(config.config_group, config.config_name):
logger.warning("Config validation failed", extra={"user_id": user_id, "config_group": config.config_group, "config_name": config.config_name})
raise HTTPException(status_code=400, detail=_("Invalid config name or group"))
# Check if configuration already exists
existing_config = session.exec(
select(Config).where(Config.user_id == user_id, Config.config_name == config.config_name)
).first()
if existing_config:
logger.warning("Config creation failed: already exists", extra={"user_id": user_id, "config_name": config.config_name})
raise HTTPException(status_code=400, detail=_("Configuration already exists for this user"))
try:
db_config = Config(
user_id=user_id,
config_name=config.config_name,
config_value=config.config_value,
config_group=config.config_group,
)
session.add(db_config)
session.commit()
session.refresh(db_config)
logger.info("Config created", extra={"user_id": user_id, "config_id": db_config.id, "config_group": config.config_group, "config_name": config.config_name})
return db_config
except Exception as e:
session.rollback()
logger.error("Config creation failed", extra={"user_id": user_id, "config_name": config.config_name, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")
@router.put("/configs/{config_id}", name="update config", response_model=ConfigOut)
@traceroot.trace()
async def update_config(
config_id: int, config_update: ConfigUpdate, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
"""Update configuration."""
user_id = auth.user.id
db_config = session.exec(select(Config).where(Config.id == config_id, Config.user_id == user_id)).first()
if not db_config:
logger.warning("Config not found for update", extra={"user_id": user_id, "config_id": config_id})
raise HTTPException(status_code=404, detail=_("Configuration not found"))
# Check if configuration group is valid
if not ConfigInfo.is_valid_env_var(config_update.config_group, config_update.config_name):
logger.warning("Config update validation failed", extra={"user_id": user_id, "config_id": config_id, "config_group": config_update.config_group})
raise HTTPException(status_code=400, detail=_("Invalid configuration group"))
# Check for conflicts with other configurations
existing_config = session.exec(
select(Config).where(
Config.user_id == user_id,
Config.config_name == config_update.config_name,
Config.id != config_id,
)
).first()
if existing_config:
logger.warning("Config update failed: duplicate name", extra={"user_id": user_id, "config_id": config_id, "config_name": config_update.config_name})
raise HTTPException(status_code=400, detail=_("Configuration already exists for this user"))
try:
db_config.config_name = config_update.config_name
db_config.config_value = config_update.config_value
db_config.config_group = config_update.config_group
session.add(db_config)
session.commit()
session.refresh(db_config)
logger.info("Config updated", extra={"user_id": user_id, "config_id": config_id, "config_group": config_update.config_group})
return db_config
except Exception as e:
session.rollback()
logger.error("Config update failed", extra={"user_id": user_id, "config_id": config_id, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")
@router.delete("/configs/{config_id}", name="delete config")
@traceroot.trace()
async def delete_config(config_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
"""Delete configuration."""
user_id = auth.user.id
db_config = session.exec(select(Config).where(Config.id == config_id, Config.user_id == user_id)).first()
if not db_config:
logger.warning("Config not found for deletion", extra={"user_id": user_id, "config_id": config_id})
raise HTTPException(status_code=404, detail=_("Configuration not found"))
try:
session.delete(db_config)
session.commit()
logger.info("Config deleted", extra={"user_id": user_id, "config_id": config_id, "config_name": db_config.config_name})
return Response(status_code=204)
except Exception as e:
session.rollback()
logger.error("Config deletion failed", extra={"user_id": user_id, "config_id": config_id, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/config/info", name="get config info")
@traceroot.trace()
async def get_config_info(
show_all: bool = Query(False, description="Show all config info, including those with empty env_vars"),
):
"""Get available configuration templates and info."""
configs = ConfigInfo.getinfo()
if show_all:
logger.debug("Config info retrieved", extra={"show_all": True, "count": len(configs)})
return configs
filtered = {k: v for k, v in configs.items() if v.get("env_vars") and len(v["env_vars"]) > 0}
logger.debug("Config info retrieved", extra={"show_all": False, "total_count": len(configs), "filtered_count": len(filtered)})
return filtered

View file

@ -0,0 +1,15 @@
from fastapi import APIRouter
from pydantic import BaseModel
router = APIRouter(tags=["Health"])
class HealthResponse(BaseModel):
status: str
service: str
@router.get("/health", name="health check", response_model=HealthResponse)
async def health_check():
"""Health check endpoint for monitoring and container orchestration."""
return HealthResponse(status="ok", service="eigent-server")

View file

@ -1,214 +1,262 @@
import os
from typing import Dict
from fastapi import Depends, HTTPException, APIRouter
from fastapi_babel import _
from fastapi_pagination import Page
from fastapi_pagination.ext.sqlmodel import paginate
from sqlmodel import Session, col, select
from sqlalchemy.orm import selectinload, with_loader_criteria
from app.component.auth import Auth, auth_must
from app.component.database import session
from app.model.mcp.mcp import Mcp, McpOut, McpType
from app.model.mcp.mcp_env import McpEnv, Status as McpEnvStatus
from app.model.mcp.mcp_user import McpImportType, McpUser, Status
from loguru import logger
from camel.toolkits.mcp_toolkit import MCPToolkit
from app.component.environment import env
from app.component.validator.McpServer import (
McpRemoteServer,
McpServerItem,
validate_mcp_remote_servers,
validate_mcp_servers,
)
router = APIRouter(tags=["Mcp Servers"])
async def pre_instantiate_mcp_toolkit(config_dict: dict) -> bool:
"""
Pre-instantiate MCP toolkit to complete authentication process
Args:
config_dict: MCP server configuration dictionary
Returns:
bool: Whether successfully instantiated and connected
"""
try:
# Ensure unified auth directory for all mcp servers
for server_config in config_dict.get("mcpServers", {}).values():
if "env" not in server_config:
server_config["env"] = {}
# Set global auth directory to persist authentication across tasks
if "MCP_REMOTE_CONFIG_DIR" not in server_config["env"]:
server_config["env"]["MCP_REMOTE_CONFIG_DIR"] = env(
"MCP_REMOTE_CONFIG_DIR",
os.path.expanduser("~/.mcp-auth")
)
# Create MCP toolkit and attempt to connect
mcp_toolkit = MCPToolkit(config_dict=config_dict, timeout=30)
await mcp_toolkit.connect()
# Get tools list to ensure connection is successful
tools = mcp_toolkit.get_tools()
logger.info(f"Successfully pre-instantiated MCP toolkit with {len(tools)} tools")
# Disconnect, authentication info is already saved
await mcp_toolkit.disconnect()
return True
except Exception as e:
logger.warning(f"Failed to pre-instantiate MCP toolkit: {e!r}")
return False
@router.get("/mcps", name="mcp list")
async def gets(
keyword: str | None = None,
category_id: int | None = None,
mine: int | None = None,
session: Session = Depends(session),
auth: Auth = Depends(auth_must),
) -> Page[McpOut]:
stmt = (
select(Mcp)
.where(Mcp.no_delete())
.options(
selectinload(Mcp.category),
selectinload(Mcp.envs),
with_loader_criteria(McpEnv, col(McpEnv.status) == McpEnvStatus.in_use),
)
# .order_by(col(Mcp.sort).desc())
)
if keyword:
stmt = stmt.where(col(Mcp.key).like(f"%{keyword.lower()}%"))
if category_id:
stmt = stmt.where(Mcp.category_id == category_id)
if mine and auth:
stmt = (
stmt.join(McpUser)
.where(McpUser.user_id == auth.user.id)
.options(
selectinload(Mcp.mcp_user),
with_loader_criteria(McpUser, col(McpUser.user_id) == auth.user.id),
)
)
return paginate(session, stmt)
@router.get("/mcp", name="mcp detail", response_model=McpOut)
async def get(id: int, session: Session = Depends(session)):
stmt = select(Mcp).where(Mcp.no_delete(), Mcp.id == id).options(selectinload(Mcp.category), selectinload(Mcp.envs))
model = session.exec(stmt).one()
return model
@router.post("/mcp/install", name="mcp install")
async def install(mcp_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
mcp = session.get_one(Mcp, mcp_id)
if not mcp:
raise HTTPException(status_code=404, detail=_("Mcp not found"))
exists = session.exec(select(McpUser).where(McpUser.mcp_id == mcp.id, McpUser.user_id == auth.user.id)).first()
if exists:
raise HTTPException(status_code=400, detail=_("mcp is installed"))
install_command: dict = mcp.install_command
# Pre-instantiate MCP toolkit for authentication
config_dict = {
"mcpServers": {
mcp.key: install_command
}
}
try:
success = await pre_instantiate_mcp_toolkit(config_dict)
if not success:
logger.warning(f"Pre-instantiation failed for MCP {mcp.key}, but continuing with installation")
except Exception as e:
logger.warning(f"Exception during pre-instantiation for MCP {mcp.key}: {e}")
mcp_user = McpUser(
mcp_id=mcp.id,
user_id=auth.user.id,
mcp_name=mcp.name,
mcp_key=mcp.key,
mcp_desc=mcp.description,
type=mcp.type,
status=Status.enable,
command=install_command["command"],
args=install_command["args"],
env=install_command["env"],
server_url=None,
)
mcp_user.save()
return mcp_user
@router.post("/mcp/import/{mcp_type}", name="mcp import")
async def import_mcp(
mcp_type: McpImportType, mcp_data: dict, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
logger.debug(mcp_type, mcp_type.value)
if mcp_type == McpImportType.Local:
is_valid, res = validate_mcp_servers(mcp_data)
if not is_valid:
raise HTTPException(status_code=400, detail=res)
mcp_data: Dict[str, McpServerItem] = res.mcpServers
for name, data in mcp_data.items():
# Pre-instantiate MCP toolkit for authentication
config_dict = {
"mcpServers": {
name: {
"command": data.command,
"args": data.args,
"env": data.env or {}
}
}
}
try:
success = await pre_instantiate_mcp_toolkit(config_dict)
if not success:
logger.warning(f"Pre-instantiation failed for local MCP {name}, but continuing with installation")
except Exception as e:
logger.warning(f"Exception during pre-instantiation for local MCP {name}: {e}")
mcp_user = McpUser(
mcp_id=0,
user_id=auth.user.id,
mcp_name=name,
mcp_key=name,
mcp_desc=name,
type=McpType.Local,
status=Status.enable,
command=data.command,
args=data.args,
env=data.env,
server_url=None,
)
mcp_user.save()
return {"message": "Local MCP servers imported successfully", "count": len(mcp_data)}
elif mcp_type == McpImportType.Remote:
is_valid, res = validate_mcp_remote_servers(mcp_data)
if not is_valid:
raise HTTPException(status_code=400, detail=res)
data: McpRemoteServer = res
# For remote servers, we don't need to pre-instantiate as they typically don't require authentication
# but we can still try to validate the connection if needed
mcp_user = McpUser(
mcp_id=0,
user_id=auth.user.id,
type=McpType.Remote,
status=Status.enable,
mcp_name=data.server_name,
server_url=data.server_url,
)
mcp_user.save()
return mcp_user
import os
from typing import Dict
from fastapi import Depends, HTTPException, APIRouter
from fastapi_babel import _
from fastapi_pagination import Page
from fastapi_pagination.ext.sqlmodel import paginate
from sqlmodel import Session, col, select
from sqlalchemy.orm import selectinload, with_loader_criteria
from app.component.auth import Auth, auth_must
from app.component.database import session
from app.model.mcp.mcp import Mcp, McpOut, McpType
from app.model.mcp.mcp_env import McpEnv, Status as McpEnvStatus
from app.model.mcp.mcp_user import McpImportType, McpUser, Status
from camel.toolkits.mcp_toolkit import MCPToolkit
from app.component.environment import env
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("server_mcp_controller")
from app.component.validator.McpServer import (
McpRemoteServer,
McpServerItem,
validate_mcp_remote_servers,
validate_mcp_servers,
)
router = APIRouter(tags=["Mcp Servers"])
async def pre_instantiate_mcp_toolkit(config_dict: dict) -> bool:
"""
Pre-instantiate MCP toolkit to complete authentication process
Args:
config_dict: MCP server configuration dictionary
Returns:
bool: Whether successfully instantiated and connected
"""
try:
# Ensure unified auth directory for all mcp servers
for server_config in config_dict.get("mcpServers", {}).values():
if "env" not in server_config:
server_config["env"] = {}
# Set global auth directory to persist authentication across tasks
if "MCP_REMOTE_CONFIG_DIR" not in server_config["env"]:
server_config["env"]["MCP_REMOTE_CONFIG_DIR"] = env(
"MCP_REMOTE_CONFIG_DIR",
os.path.expanduser("~/.mcp-auth")
)
# Create MCP toolkit and attempt to connect
mcp_toolkit = MCPToolkit(config_dict=config_dict, timeout=30)
await mcp_toolkit.connect()
# Get tools list to ensure connection is successful
tools = mcp_toolkit.get_tools()
logger.info("MCP toolkit pre-instantiated", extra={"tools_count": len(tools)})
# Disconnect, authentication info is already saved
await mcp_toolkit.disconnect()
return True
except Exception as e:
logger.warning("MCP toolkit pre-instantiation failed", extra={"error": str(e)}, exc_info=True)
return False
@router.get("/mcps", name="mcp list")
@traceroot.trace()
async def gets(
keyword: str | None = None,
category_id: int | None = None,
mine: int | None = None,
session: Session = Depends(session),
auth: Auth = Depends(auth_must),
) -> Page[McpOut]:
"""List MCP servers with optional filtering."""
user_id = auth.user.id
stmt = (
select(Mcp)
.where(Mcp.no_delete())
.options(
selectinload(Mcp.category),
selectinload(Mcp.envs),
with_loader_criteria(McpEnv, col(McpEnv.status) == McpEnvStatus.in_use),
)
)
if keyword:
stmt = stmt.where(col(Mcp.key).like(f"%{keyword.lower()}%"))
if category_id:
stmt = stmt.where(Mcp.category_id == category_id)
if mine and auth:
stmt = (
stmt.join(McpUser)
.where(McpUser.user_id == user_id)
.options(
selectinload(Mcp.mcp_user),
with_loader_criteria(McpUser, col(McpUser.user_id) == user_id),
)
)
result = paginate(session, stmt)
total = result.total if hasattr(result, 'total') else 0
logger.debug("MCP list retrieved", extra={"user_id": user_id, "keyword": keyword, "category_id": category_id, "mine": mine, "total": total})
return result
@router.get("/mcp", name="mcp detail", response_model=McpOut)
@traceroot.trace()
async def get(id: int, session: Session = Depends(session)):
"""Get MCP server details."""
try:
stmt = select(Mcp).where(Mcp.no_delete(), Mcp.id == id).options(selectinload(Mcp.category), selectinload(Mcp.envs))
model = session.exec(stmt).one()
logger.debug("MCP detail retrieved", extra={"mcp_id": id, "mcp_key": model.key})
return model
except Exception as e:
logger.warning("MCP not found", extra={"mcp_id": id})
raise HTTPException(status_code=404, detail=_("Mcp not found"))
@router.post("/mcp/install", name="mcp install")
@traceroot.trace()
async def install(mcp_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
"""Install MCP server for user."""
user_id = auth.user.id
mcp = session.get_one(Mcp, mcp_id)
if not mcp:
logger.warning("MCP install failed: MCP not found", extra={"user_id": user_id, "mcp_id": mcp_id})
raise HTTPException(status_code=404, detail=_("Mcp not found"))
exists = session.exec(select(McpUser).where(McpUser.mcp_id == mcp.id, McpUser.user_id == user_id)).first()
if exists:
logger.warning("MCP install failed: already installed", extra={"user_id": user_id, "mcp_id": mcp_id, "mcp_key": mcp.key})
raise HTTPException(status_code=400, detail=_("mcp is installed"))
install_command: dict = mcp.install_command
# Pre-instantiate MCP toolkit for authentication
config_dict = {
"mcpServers": {
mcp.key: install_command
}
}
try:
success = await pre_instantiate_mcp_toolkit(config_dict)
if not success:
logger.warning("MCP pre-instantiation failed, continuing with installation", extra={"user_id": user_id, "mcp_id": mcp_id, "mcp_key": mcp.key})
else:
logger.debug("MCP toolkit pre-instantiated", extra={"mcp_key": mcp.key})
except Exception as e:
logger.warning("MCP pre-instantiation exception", extra={"user_id": user_id, "mcp_key": mcp.key, "error": str(e)}, exc_info=True)
try:
mcp_user = McpUser(
mcp_id=mcp.id,
user_id=user_id,
mcp_name=mcp.name,
mcp_key=mcp.key,
mcp_desc=mcp.description,
type=mcp.type,
status=Status.enable,
command=install_command["command"],
args=install_command["args"],
env=install_command["env"],
server_url=None,
)
mcp_user.save()
logger.info("MCP installed", extra={"user_id": user_id, "mcp_id": mcp_id, "mcp_key": mcp.key})
return mcp_user
except Exception as e:
logger.error("MCP installation failed", extra={"user_id": user_id, "mcp_id": mcp_id, "mcp_key": mcp.key, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")
@router.post("/mcp/import/{mcp_type}", name="mcp import")
@traceroot.trace()
async def import_mcp(
mcp_type: McpImportType, mcp_data: dict, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
"""Import MCP servers (local or remote)."""
user_id = auth.user.id
if mcp_type == McpImportType.Local:
logger.info("Importing local MCP servers", extra={"user_id": user_id})
is_valid, res = validate_mcp_servers(mcp_data)
if not is_valid:
logger.warning("Local MCP import validation failed", extra={"user_id": user_id, "error": res})
raise HTTPException(status_code=400, detail=res)
mcp_data: Dict[str, McpServerItem] = res.mcpServers
imported_count = 0
for name, data in mcp_data.items():
config_dict = {
"mcpServers": {
name: {
"command": data.command,
"args": data.args,
"env": data.env or {}
}
}
}
try:
success = await pre_instantiate_mcp_toolkit(config_dict)
if not success:
logger.warning("Local MCP pre-instantiation failed, continuing", extra={"user_id": user_id, "mcp_name": name})
except Exception as e:
logger.warning("Local MCP pre-instantiation exception", extra={"user_id": user_id, "mcp_name": name, "error": str(e)})
try:
mcp_user = McpUser(
mcp_id=0,
user_id=user_id,
mcp_name=name,
mcp_key=name,
mcp_desc=name,
type=McpType.Local,
status=Status.enable,
command=data.command,
args=data.args,
env=data.env,
server_url=None,
)
mcp_user.save()
imported_count += 1
except Exception as e:
logger.error("Failed to import local MCP", extra={"user_id": user_id, "mcp_name": name, "error": str(e)}, exc_info=True)
logger.info("Local MCPs imported", extra={"user_id": user_id, "count": imported_count})
return {"message": "Local MCP servers imported successfully", "count": imported_count}
elif mcp_type == McpImportType.Remote:
logger.info("Importing remote MCP server", extra={"user_id": user_id})
is_valid, res = validate_mcp_remote_servers(mcp_data)
if not is_valid:
logger.warning("Remote MCP import validation failed", extra={"user_id": user_id, "error": res})
raise HTTPException(status_code=400, detail=res)
data: McpRemoteServer = res
try:
# For remote servers, we don't need to pre-instantiate as they typically don't require authentication
# but we can still try to validate the connection if needed
mcp_user = McpUser(
mcp_id=0,
user_id=user_id,
type=McpType.Remote,
status=Status.enable,
mcp_name=data.server_name,
server_url=data.server_url,
)
mcp_user.save()
logger.info("Remote MCP imported", extra={"user_id": user_id, "server_name": data.server_name, "server_url": data.server_url})
return mcp_user
except Exception as e:
logger.error("Remote MCP import failed", extra={"user_id": user_id, "server_name": data.server_name, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")

View file

@ -1,173 +1,196 @@
from fastapi import APIRouter, Depends
from exa_py import Exa
from loguru import logger
from app.component.auth import key_must
from app.component.environment import env_not_empty
from app.model.mcp.proxy import ExaSearch
from typing import Any, cast
import requests
from app.model.user.key import Key
router = APIRouter(prefix="/proxy", tags=["Mcp Servers"])
@router.post("/exa")
def exa_search(search: ExaSearch, key: Key = Depends(key_must)):
EXA_API_KEY = env_not_empty("EXA_API_KEY")
try:
exa = Exa(EXA_API_KEY)
if search.num_results is not None and not 0 < search.num_results <= 100:
raise ValueError("num_results must be between 1 and 100")
if search.include_text is not None:
if len(search.include_text) > 1:
raise ValueError("include_text can only contain 1 string")
if len(search.include_text[0].split()) > 5:
raise ValueError("include_text string cannot be longer than 5 words")
if search.exclude_text is not None:
if len(search.exclude_text) > 1:
raise ValueError("exclude_text can only contain 1 string")
if len(search.exclude_text[0].split()) > 5:
raise ValueError("exclude_text string cannot be longer than 5 words")
# Call Exa API with direct parameters
if search.text:
results = cast(
dict[str, Any],
exa.search_and_contents(
query=search.query,
type=search.search_type,
category=search.category,
num_results=search.num_results,
include_text=search.include_text,
exclude_text=search.exclude_text,
use_autoprompt=search.use_autoprompt,
text=True,
),
)
else:
results = cast(
dict[str, Any],
exa.search(
query=search.query,
type=search.search_type,
category=search.category,
num_results=search.num_results,
include_text=search.include_text,
exclude_text=search.exclude_text,
use_autoprompt=search.use_autoprompt,
),
)
return results
except Exception as e:
return {"error": f"Exa search failed: {e!s}"}
@router.get("/google")
def google_search(query: str, search_type: str = "web", key: Key = Depends(key_must)):
# https://developers.google.com/custom-search/v1/overview
GOOGLE_API_KEY = env_not_empty("GOOGLE_API_KEY")
# https://cse.google.com/cse/all
SEARCH_ENGINE_ID = env_not_empty("SEARCH_ENGINE_ID")
# Using the first page
start_page_idx = 1
# Different language may get different result
search_language = "en"
# How many pages to return
num_result_pages = 10
# Constructing the URL
# Doc: https://developers.google.com/custom-search/v1/using_rest
base_url = (
f"https://www.googleapis.com/customsearch/v1?"
f"key={GOOGLE_API_KEY}&cx={SEARCH_ENGINE_ID}&q={query}&start="
f"{start_page_idx}&lr={search_language}&num={num_result_pages}"
)
if search_type == "image":
url = base_url + "&searchType=image"
else:
url = base_url
responses = []
# Fetch the results given the URL
try:
# Make the get
result = requests.get(url)
data = result.json()
# Get the result items
if "items" in data:
search_items = data.get("items")
# Iterate over results found
for i, search_item in enumerate(search_items, start=1):
if search_type == "image":
# Process image search results
title = search_item.get("title")
image_url = search_item.get("link")
display_link = search_item.get("displayLink")
# Get context URL (page containing the image)
image_info = search_item.get("image", {})
context_url = image_info.get("contextLink", "")
# Get image dimensions if available
width = image_info.get("width")
height = image_info.get("height")
response = {
"result_id": i,
"title": title,
"image_url": image_url,
"display_link": display_link,
"context_url": context_url,
}
# Add dimensions if available
if width:
response["width"] = int(width)
if height:
response["height"] = int(height)
responses.append(response)
else:
# Process web search results (existing logic)
# Check metatags are present
if "pagemap" not in search_item:
continue
if "metatags" not in search_item["pagemap"]:
continue
if "og:description" in search_item["pagemap"]["metatags"][0]:
long_description = search_item["pagemap"]["metatags"][0]["og:description"]
else:
long_description = "N/A"
# Get the page title
title = search_item.get("title")
# Page snippet
snippet = search_item.get("snippet")
# Extract the page url
link = search_item.get("link")
response = {
"result_id": i,
"title": title,
"description": snippet,
"long_description": long_description,
"url": link,
}
responses.append(response)
else:
error_info = data.get("error", {})
logger.error(f"Google search failed - API response: {error_info}")
responses.append({"error": f"Google search failed - API response: {error_info}"})
except Exception as e:
responses.append({"error": f"google search failed: {e!s}"})
return responses
from fastapi import APIRouter, Depends, HTTPException
from exa_py import Exa
from app.component.auth import key_must
from app.component.environment import env_not_empty
from app.model.mcp.proxy import ExaSearch
from typing import Any, cast
import requests
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("server_proxy_controller")
from app.model.user.key import Key
router = APIRouter(prefix="/proxy", tags=["Mcp Servers"])
@router.post("/exa")
@traceroot.trace()
def exa_search(search: ExaSearch, key: Key = Depends(key_must)):
"""Search using Exa API."""
EXA_API_KEY = env_not_empty("EXA_API_KEY")
try:
# Validate input parameters
if search.num_results is not None and not 0 < search.num_results <= 100:
logger.warning("Invalid exa search parameter", extra={"param": "num_results", "value": search.num_results})
raise ValueError("num_results must be between 1 and 100")
if search.include_text is not None and len(search.include_text) > 0:
if len(search.include_text) > 1:
logger.warning("Invalid exa search parameter", extra={"param": "include_text", "reason": "more than 1 string"})
raise ValueError("include_text can only contain 1 string")
if len(search.include_text[0].split()) > 5:
logger.warning("Invalid exa search parameter", extra={"param": "include_text", "reason": "exceeds 5 words"})
raise ValueError("include_text string cannot be longer than 5 words")
if search.exclude_text is not None and len(search.exclude_text) > 0:
if len(search.exclude_text) > 1:
logger.warning("Invalid exa search parameter", extra={"param": "exclude_text", "reason": "more than 1 string"})
raise ValueError("exclude_text can only contain 1 string")
if len(search.exclude_text[0].split()) > 5:
logger.warning("Invalid exa search parameter", extra={"param": "exclude_text", "reason": "exceeds 5 words"})
raise ValueError("exclude_text string cannot be longer than 5 words")
exa = Exa(EXA_API_KEY)
# Call Exa API with direct parameters
if search.text:
results = cast(
dict[str, Any],
exa.search_and_contents(
query=search.query,
type=search.search_type,
category=search.category,
num_results=search.num_results,
include_text=search.include_text,
exclude_text=search.exclude_text,
use_autoprompt=search.use_autoprompt,
text=True,
),
)
else:
results = cast(
dict[str, Any],
exa.search(
query=search.query,
type=search.search_type,
category=search.category,
num_results=search.num_results,
include_text=search.include_text,
exclude_text=search.exclude_text,
use_autoprompt=search.use_autoprompt,
),
)
result_count = len(results.get("results", [])) if "results" in results else 0
logger.info("Exa search completed", extra={"query": search.query, "search_type": search.search_type, "result_count": result_count})
return results
except ValueError as e:
logger.warning("Exa search validation error", extra={"error": str(e)})
raise HTTPException(status_code=500, detail="Internal server error")
except Exception as e:
logger.error("Exa search failed", extra={"query": search.query, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/google")
@traceroot.trace()
def google_search(query: str, search_type: str = "web", key: Key = Depends(key_must)):
"""Search using Google Custom Search API."""
# https://developers.google.com/custom-search/v1/overview
GOOGLE_API_KEY = env_not_empty("GOOGLE_API_KEY")
# https://cse.google.com/cse/all
SEARCH_ENGINE_ID = env_not_empty("SEARCH_ENGINE_ID")
# Using the first page
start_page_idx = 1
# Different language may get different result
search_language = "en"
# How many pages to return
num_result_pages = 10
# Constructing the URL
# Doc: https://developers.google.com/custom-search/v1/using_rest
base_url = (
f"https://www.googleapis.com/customsearch/v1?"
f"key={GOOGLE_API_KEY}&cx={SEARCH_ENGINE_ID}&q={query}&start="
f"{start_page_idx}&lr={search_language}&num={num_result_pages}"
)
if search_type == "image":
url = base_url + "&searchType=image"
else:
url = base_url
responses = []
try:
# Make the GET request
result = requests.get(url)
data = result.json()
# Get the result items
if "items" in data:
search_items = data.get("items")
# Iterate over results found
for i, search_item in enumerate(search_items, start=1):
if search_type == "image":
# Process image search results
title = search_item.get("title")
image_url = search_item.get("link")
display_link = search_item.get("displayLink")
# Get context URL (page containing the image)
image_info = search_item.get("image", {})
context_url = image_info.get("contextLink", "")
# Get image dimensions if available
width = image_info.get("width")
height = image_info.get("height")
response = {
"result_id": i,
"title": title,
"image_url": image_url,
"display_link": display_link,
"context_url": context_url,
}
# Add dimensions if available
if width:
response["width"] = int(width)
if height:
response["height"] = int(height)
responses.append(response)
else:
# Process web search results
# Check metatags are present
if "pagemap" not in search_item:
continue
if "metatags" not in search_item["pagemap"]:
continue
if "og:description" in search_item["pagemap"]["metatags"][0]:
long_description = search_item["pagemap"]["metatags"][0]["og:description"]
else:
long_description = "N/A"
# Get the page title
title = search_item.get("title")
# Page snippet
snippet = search_item.get("snippet")
# Extract the page url
link = search_item.get("link")
response = {
"result_id": i,
"title": title,
"description": snippet,
"long_description": long_description,
"url": link,
}
responses.append(response)
logger.info("Google search completed", extra={"query": query, "search_type": search_type, "result_count": len(responses)})
else:
error_info = data.get("error", {})
logger.error("Google search API error", extra={"query": query, "api_error": error_info})
raise HTTPException(status_code=500, detail="Internal server error")
except Exception as e:
logger.error("Google search failed", extra={"query": query, "search_type": search_type, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")
return responses

View file

@ -1,139 +1,181 @@
import os
from typing import List, Optional
from fastapi import Depends, HTTPException, Query, Response, APIRouter
from sqlmodel import Session, select
from app.component.database import session
from app.component.auth import Auth, auth_must
from fastapi_babel import _
from app.model.mcp.mcp_user import McpUser, McpUserIn, McpUserOut, McpUserUpdate, Status
from app.model.mcp.mcp import Mcp
from loguru import logger
from camel.toolkits.mcp_toolkit import MCPToolkit
from app.component.environment import env
router = APIRouter(tags=["McpUser Management"])
async def pre_instantiate_mcp_toolkit(config_dict: dict) -> bool:
"""
Pre-instantiate MCP toolkit to complete authentication process
Args:
config_dict: MCP server configuration dictionary
Returns:
bool: Whether successfully instantiated and connected
"""
try:
# Ensure unified auth directory for all mcp servers
for server_config in config_dict.get("mcpServers", {}).values():
if "env" not in server_config:
server_config["env"] = {}
# Set global auth directory to persist authentication across tasks
if "MCP_REMOTE_CONFIG_DIR" not in server_config["env"]:
server_config["env"]["MCP_REMOTE_CONFIG_DIR"] = env(
"MCP_REMOTE_CONFIG_DIR",
os.path.expanduser("~/.mcp-auth")
)
# Create MCP toolkit and attempt to connect
mcp_toolkit = MCPToolkit(config_dict=config_dict, timeout=30)
await mcp_toolkit.connect()
# Get tools list to ensure connection is successful
tools = mcp_toolkit.get_tools()
logger.info(f"Successfully pre-instantiated MCP toolkit with {len(tools)} tools")
# Disconnect, authentication info is already saved
await mcp_toolkit.disconnect()
return True
except Exception as e:
logger.warning(f"Failed to pre-instantiate MCP toolkit: {e!r}")
return False
@router.get("/mcp/users", name="list mcp users", response_model=List[McpUserOut])
async def list_mcp_users(
mcp_id: Optional[int] = None,
session: Session = Depends(session),
auth: Auth = Depends(auth_must),
):
user_id = auth.user.id
query = select(McpUser)
if mcp_id is not None:
query = query.where(McpUser.mcp_id == mcp_id)
if user_id is not None:
query = query.where(McpUser.user_id == user_id)
mcp_users = session.exec(query).all()
return mcp_users
@router.get("/mcp/users/{mcp_user_id}", name="get mcp user", response_model=McpUserOut)
async def get_mcp_user(mcp_user_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
query = select(McpUser).where(McpUser.id == mcp_user_id)
mcp_user = session.exec(query).first()
if not mcp_user:
raise HTTPException(status_code=404, detail=_("McpUser not found"))
return mcp_user
@router.post("/mcp/users", name="create mcp user", response_model=McpUserOut)
async def create_mcp_user(mcp_user: McpUserIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
exists = session.exec(
select(McpUser).where(McpUser.mcp_id == mcp_user.mcp_id, McpUser.user_id == auth.user.id)
).first()
if exists:
raise HTTPException(status_code=400, detail=_("mcp is installed"))
# Get MCP configuration from the main Mcp table
mcp = session.get(Mcp, mcp_user.mcp_id)
if mcp and mcp.install_command:
# Pre-instantiate MCP toolkit for authentication
config_dict = {
"mcpServers": {
mcp.key: mcp.install_command
}
}
try:
success = await pre_instantiate_mcp_toolkit(config_dict)
if not success:
logger.warning(f"Pre-instantiation failed for MCP {mcp.key}, but continuing with user creation")
except Exception as e:
logger.warning(f"Exception during pre-instantiation for MCP {mcp.key}: {e}")
db_mcp_user = McpUser(mcp_id=mcp_user.mcp_id, user_id=auth.user.id, env=mcp_user.env)
session.add(db_mcp_user)
session.commit()
session.refresh(db_mcp_user)
return db_mcp_user
@router.put("/mcp/users/{id}", name="update mcp user")
async def update_mcp_user(
id: int,
update_item: McpUserUpdate,
session: Session = Depends(session),
auth: Auth = Depends(auth_must),
):
model = session.get(McpUser, id)
if not model:
raise HTTPException(status_code=404, detail=_("Mcp Info not found"))
if model.user_id != auth.user.id:
raise HTTPException(status_code=400, detail=_("current user have no permission to modify"))
update_data = update_item.model_dump(exclude_unset=True)
model.update_fields(update_data)
model.save(session)
session.refresh(model)
return model
@router.delete("/mcp/users/{mcp_user_id}", name="delete mcp user")
async def delete_mcp_user(mcp_user_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
db_mcp_user = session.get(McpUser, mcp_user_id)
if not db_mcp_user:
raise HTTPException(status_code=404, detail=_("Mcp Info not found"))
session.delete(db_mcp_user)
session.commit()
return Response(status_code=204)
import os
from typing import List, Optional
from fastapi import Depends, HTTPException, Query, Response, APIRouter
from sqlmodel import Session, select
from app.component.database import session
from app.component.auth import Auth, auth_must
from fastapi_babel import _
from app.model.mcp.mcp_user import McpUser, McpUserIn, McpUserOut, McpUserUpdate, Status
from app.model.mcp.mcp import Mcp
from camel.toolkits.mcp_toolkit import MCPToolkit
from app.component.environment import env
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("server_mcp_user_controller")
router = APIRouter(tags=["McpUser Management"])
async def pre_instantiate_mcp_toolkit(config_dict: dict) -> bool:
"""
Pre-instantiate MCP toolkit to complete authentication process
Args:
config_dict: MCP server configuration dictionary
Returns:
bool: Whether successfully instantiated and connected
"""
try:
# Ensure unified auth directory for all mcp servers
for server_config in config_dict.get("mcpServers", {}).values():
if "env" not in server_config:
server_config["env"] = {}
# Set global auth directory to persist authentication across tasks
if "MCP_REMOTE_CONFIG_DIR" not in server_config["env"]:
server_config["env"]["MCP_REMOTE_CONFIG_DIR"] = env(
"MCP_REMOTE_CONFIG_DIR",
os.path.expanduser("~/.mcp-auth")
)
# Create MCP toolkit and attempt to connect
mcp_toolkit = MCPToolkit(config_dict=config_dict, timeout=30)
await mcp_toolkit.connect()
# Get tools list to ensure connection is successful
tools = mcp_toolkit.get_tools()
logger.info("MCP toolkit pre-instantiated", extra={"tools_count": len(tools)})
# Disconnect, authentication info is already saved
await mcp_toolkit.disconnect()
return True
except Exception as e:
logger.warning("MCP toolkit pre-instantiation failed", extra={"error": str(e)}, exc_info=True)
return False
@router.get("/mcp/users", name="list mcp users", response_model=List[McpUserOut])
@traceroot.trace()
async def list_mcp_users(
mcp_id: Optional[int] = None,
session: Session = Depends(session),
auth: Auth = Depends(auth_must),
):
"""List MCP users for current user."""
user_id = auth.user.id
query = select(McpUser)
if mcp_id is not None:
query = query.where(McpUser.mcp_id == mcp_id)
if user_id is not None:
query = query.where(McpUser.user_id == user_id)
mcp_users = session.exec(query).all()
logger.debug("MCP users listed", extra={"user_id": user_id, "mcp_id": mcp_id, "count": len(mcp_users)})
return mcp_users
@router.get("/mcp/users/{mcp_user_id}", name="get mcp user", response_model=McpUserOut)
@traceroot.trace()
async def get_mcp_user(mcp_user_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
"""Get MCP user details."""
query = select(McpUser).where(McpUser.id == mcp_user_id)
mcp_user = session.exec(query).first()
if not mcp_user:
logger.warning("MCP user not found", extra={"user_id": auth.user.id, "mcp_user_id": mcp_user_id})
raise HTTPException(status_code=404, detail=_("McpUser not found"))
logger.debug("MCP user retrieved", extra={"user_id": auth.user.id, "mcp_user_id": mcp_user_id, "mcp_id": mcp_user.mcp_id})
return mcp_user
@router.post("/mcp/users", name="create mcp user", response_model=McpUserOut)
@traceroot.trace()
async def create_mcp_user(mcp_user: McpUserIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
"""Create MCP user installation."""
user_id = auth.user.id
mcp_id = mcp_user.mcp_id
exists = session.exec(
select(McpUser).where(McpUser.mcp_id == mcp_id, McpUser.user_id == user_id)
).first()
if exists:
logger.warning("MCP already installed", extra={"user_id": user_id, "mcp_id": mcp_id})
raise HTTPException(status_code=400, detail=_("mcp is installed"))
# Get MCP configuration from the main Mcp table
mcp = session.get(Mcp, mcp_id)
if mcp and mcp.install_command:
config_dict = {
"mcpServers": {
mcp.key: mcp.install_command
}
}
try:
success = await pre_instantiate_mcp_toolkit(config_dict)
if not success:
logger.warning("MCP pre-instantiation failed, continuing", extra={"user_id": user_id, "mcp_id": mcp_id, "mcp_key": mcp.key})
except Exception as e:
logger.warning("MCP pre-instantiation exception", extra={"user_id": user_id, "mcp_id": mcp_id, "error": str(e)}, exc_info=True)
try:
db_mcp_user = McpUser(mcp_id=mcp_id, user_id=user_id, env=mcp_user.env)
session.add(db_mcp_user)
session.commit()
session.refresh(db_mcp_user)
logger.info("MCP user created", extra={"user_id": user_id, "mcp_id": mcp_id, "mcp_user_id": db_mcp_user.id})
return db_mcp_user
except Exception as e:
session.rollback()
logger.error("MCP user creation failed", extra={"user_id": user_id, "mcp_id": mcp_id, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")
@router.put("/mcp/users/{id}", name="update mcp user")
@traceroot.trace()
async def update_mcp_user(
id: int,
update_item: McpUserUpdate,
session: Session = Depends(session),
auth: Auth = Depends(auth_must),
):
"""Update MCP user settings."""
user_id = auth.user.id
model = session.get(McpUser, id)
if not model:
logger.warning("MCP user not found for update", extra={"user_id": user_id, "mcp_user_id": id})
raise HTTPException(status_code=404, detail=_("Mcp Info not found"))
if model.user_id != user_id:
logger.warning("Unauthorized MCP user update", extra={"user_id": user_id, "mcp_user_id": id, "owner_id": model.user_id})
raise HTTPException(status_code=400, detail=_("current user have no permission to modify"))
try:
update_data = update_item.model_dump(exclude_unset=True)
model.update_fields(update_data)
model.save(session)
session.refresh(model)
logger.info("MCP user updated", extra={"user_id": user_id, "mcp_user_id": id})
return model
except Exception as e:
logger.error("MCP user update failed", extra={"user_id": user_id, "mcp_user_id": id, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")
@router.delete("/mcp/users/{mcp_user_id}", name="delete mcp user")
@traceroot.trace()
async def delete_mcp_user(mcp_user_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
"""Delete MCP user installation."""
user_id = auth.user.id
db_mcp_user = session.get(McpUser, mcp_user_id)
if not db_mcp_user:
logger.warning("MCP user not found for deletion", extra={"user_id": user_id, "mcp_user_id": mcp_user_id})
raise HTTPException(status_code=404, detail=_("Mcp Info not found"))
try:
session.delete(db_mcp_user)
session.commit()
logger.info("MCP user deleted", extra={"user_id": user_id, "mcp_user_id": mcp_user_id, "mcp_id": db_mcp_user.mcp_id})
return Response(status_code=204)
except Exception as e:
session.rollback()
logger.error("MCP user deletion failed", extra={"user_id": user_id, "mcp_user_id": mcp_user_id, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")

View file

@ -1,58 +1,81 @@
from fastapi import APIRouter, Request, HTTPException
from fastapi.responses import RedirectResponse, JSONResponse, HTMLResponse
from app.component.environment import env
from app.component.oauth_adapter import OauthCallbackPayload, get_oauth_adapter
from typing import Optional
router = APIRouter(prefix="/oauth", tags=["Oauth Servers"])
@router.get("/{app}/login", name="OAuth Login Redirect")
def oauth_login(app: str, request: Request, state: Optional[str] = None):
try:
callback_url = str(request.url_for("OAuth Callback", app=app))
if callback_url.startswith("http://"):
callback_url = "https://" + callback_url[len("http://") :]
adapter = get_oauth_adapter(app, callback_url)
url = adapter.get_authorize_url(state)
if not url:
raise HTTPException(status_code=400, detail="Failed to generate authorization URL")
return RedirectResponse(str(url))
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/{app}/callback", name="OAuth Callback")
def oauth_callback(app: str, request: Request, code: Optional[str] = None, state: Optional[str] = None):
if not code:
raise HTTPException(status_code=400, detail="Missing code parameter")
redirect_url = f"eigent://callback/oauth?provider={app}&code={code}&state={state}"
html_content = f"""
<html>
<head>
<title>OAuth Callback</title>
</head>
<body>
<script type='text/javascript'>
window.location.href = '{redirect_url}';
</script>
<p>Redirecting, please wait...</p>
<button onclick='window.close()'>Close this window</button>
</body>
</html>
"""
return HTMLResponse(content=html_content)
@router.post("/{app}/token", name="OAuth Fetch Token")
def fetch_token(app: str, request: Request, data: OauthCallbackPayload):
try:
callback_url = str(request.url_for("OAuth Callback", app=app))
if callback_url.startswith("http://"):
callback_url = "https://" + callback_url[len("http://") :]
adapter = get_oauth_adapter(app, callback_url)
token_data = adapter.fetch_token(data.code)
return JSONResponse(token_data)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
from fastapi import APIRouter, Request, HTTPException
from fastapi.responses import RedirectResponse, JSONResponse, HTMLResponse
from app.component.environment import env
from app.component.oauth_adapter import OauthCallbackPayload, get_oauth_adapter
from typing import Optional
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("server_oauth_controller")
router = APIRouter(prefix="/oauth", tags=["Oauth Servers"])
@router.get("/{app}/login", name="OAuth Login Redirect")
@traceroot.trace()
def oauth_login(app: str, request: Request, state: Optional[str] = None):
"""Redirect user to OAuth provider's authorization endpoint."""
try:
callback_url = str(request.url_for("OAuth Callback", app=app))
if callback_url.startswith("http://"):
callback_url = "https://" + callback_url[len("http://") :]
adapter = get_oauth_adapter(app, callback_url)
url = adapter.get_authorize_url(state)
if not url:
logger.error("Failed to generate authorization URL", extra={"provider": app, "callback_url": callback_url})
raise HTTPException(status_code=400, detail="Failed to generate authorization URL")
logger.info("OAuth login initiated", extra={"provider": app})
return RedirectResponse(str(url))
except HTTPException:
raise
except Exception as e:
logger.error("OAuth login failed", extra={"provider": app, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=400, detail="OAuth login failed")
@router.get("/{app}/callback", name="OAuth Callback")
@traceroot.trace()
def oauth_callback(app: str, request: Request, code: Optional[str] = None, state: Optional[str] = None):
"""Handle OAuth provider callback and redirect to client app."""
if not code:
logger.warning("OAuth callback missing code", extra={"provider": app})
raise HTTPException(status_code=400, detail="Missing code parameter")
logger.info("OAuth callback received", extra={"provider": app, "has_state": state is not None})
redirect_url = f"eigent://callback/oauth?provider={app}&code={code}&state={state}"
html_content = f"""
<html>
<head>
<title>OAuth Callback</title>
</head>
<body>
<script type='text/javascript'>
window.location.href = '{redirect_url}';
</script>
<p>Redirecting, please wait...</p>
<button onclick='window.close()'>Close this window</button>
</body>
</html>
"""
return HTMLResponse(content=html_content)
@router.post("/{app}/token", name="OAuth Fetch Token")
@traceroot.trace()
def fetch_token(app: str, request: Request, data: OauthCallbackPayload):
"""Exchange authorization code for access token."""
try:
callback_url = str(request.url_for("OAuth Callback", app=app))
if callback_url.startswith("http://"):
callback_url = "https://" + callback_url[len("http://") :]
adapter = get_oauth_adapter(app, callback_url)
token_data = adapter.fetch_token(data.code)
logger.info("OAuth token fetched", extra={"provider": app})
return JSONResponse(token_data)
except Exception as e:
logger.error("OAuth token fetch failed", extra={"provider": app, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")

View file

@ -1,100 +1,140 @@
from typing import List, Optional
from fastapi import Depends, HTTPException, Query, Response, APIRouter
from fastapi_babel import _
from fastapi_pagination import Page
from fastapi_pagination.ext.sqlmodel import paginate
from sqlalchemy import update
from sqlmodel import Session, select, col
from sqlalchemy.exc import SQLAlchemyError
from app.component.database import session
from app.component.auth import Auth, auth_must
from app.model.provider.provider import Provider, ProviderIn, ProviderOut, ProviderPreferIn
router = APIRouter(tags=["Provider Management"])
@router.get("/providers", name="list providers", response_model=Page[ProviderOut])
async def gets(
keyword: str | None = None,
prefer: Optional[bool] = Query(None, description="Filter by prefer status"),
session: Session = Depends(session),
auth: Auth = Depends(auth_must),
) -> Page[ProviderOut]:
user_id = auth.user.id
stmt = select(Provider).where(Provider.user_id == user_id, Provider.no_delete())
if keyword:
stmt = stmt.where(col(Provider.provider_name).like(f"%{keyword}%"))
if prefer is not None:
stmt = stmt.where(Provider.prefer == prefer)
stmt = stmt.order_by(col(Provider.created_at).desc(), col(Provider.id).desc()) # Added for consistent pagination
return paginate(session, stmt)
@router.get("/provider", name="get provider detail", response_model=ProviderOut)
async def get(id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
user_id = auth.user.id
stmt = select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == id)
model = session.exec(stmt).one_or_none()
if not model:
raise HTTPException(status_code=404, detail=_("Provider not found"))
return model
@router.post("/provider", name="create provider", response_model=ProviderOut)
async def post(data: ProviderIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
user_id = auth.user.id
model = Provider(**data.model_dump(), user_id=user_id)
model.save(session)
return model
@router.put("/provider/{id}", name="update provider", response_model=ProviderOut)
async def put(id: int, data: ProviderIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
user_id = auth.user.id
model = session.exec(
select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == id)
).one_or_none()
if not model:
raise HTTPException(status_code=404, detail=_("Provider not found"))
model.model_type = data.model_type
model.provider_name = data.provider_name
model.api_key = data.api_key
model.endpoint_url = data.endpoint_url
model.encrypted_config = data.encrypted_config
model.is_vaild = data.is_vaild
model.save(session)
session.refresh(model)
return model
@router.delete("/provider/{id}", name="delete provider")
async def delete(id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
user_id = auth.user.id
model = session.exec(
select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == id)
).one_or_none()
if not model:
raise HTTPException(status_code=404, detail=_("Provider not found"))
model.delete(session)
return Response(status_code=204)
@router.post("/provider/prefer", name="set provider prefer")
async def set_prefer(data: ProviderPreferIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
user_id = auth.user.id
try:
# 1. current user's all provider prefer set to false
session.exec(update(Provider).where(Provider.user_id == user_id, Provider.no_delete()).values(prefer=False))
# 2. set the prefer of the specified provider_id to true
session.exec(
update(Provider)
.where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == data.provider_id)
.values(prefer=True)
)
session.commit()
return {"success": True}
except SQLAlchemyError as e:
session.rollback()
raise HTTPException(status_code=500, detail=str(e))
from typing import List, Optional
from fastapi import Depends, HTTPException, Query, Response, APIRouter
from fastapi_babel import _
from fastapi_pagination import Page
from fastapi_pagination.ext.sqlmodel import paginate
from sqlalchemy import update
from sqlmodel import Session, select, col
from sqlalchemy.exc import SQLAlchemyError
from app.component.database import session
from app.component.auth import Auth, auth_must
from app.model.provider.provider import Provider, ProviderIn, ProviderOut, ProviderPreferIn
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("server_provider_controller")
router = APIRouter(tags=["Provider Management"])
@router.get("/providers", name="list providers", response_model=Page[ProviderOut])
@traceroot.trace()
async def gets(
keyword: str | None = None,
prefer: Optional[bool] = Query(None, description="Filter by prefer status"),
session: Session = Depends(session),
auth: Auth = Depends(auth_must),
) -> Page[ProviderOut]:
"""List user's providers with optional filtering."""
user_id = auth.user.id
stmt = select(Provider).where(Provider.user_id == user_id, Provider.no_delete())
if keyword:
stmt = stmt.where(col(Provider.provider_name).like(f"%{keyword}%"))
if prefer is not None:
stmt = stmt.where(Provider.prefer == prefer)
stmt = stmt.order_by(col(Provider.created_at).desc(), col(Provider.id).desc())
logger.debug("Providers listed", extra={"user_id": user_id, "keyword": keyword, "prefer_filter": prefer})
return paginate(session, stmt)
@router.get("/provider", name="get provider detail", response_model=ProviderOut)
@traceroot.trace()
async def get(id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
"""Get provider details."""
user_id = auth.user.id
stmt = select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == id)
model = session.exec(stmt).one_or_none()
if not model:
logger.warning("Provider not found", extra={"user_id": user_id, "provider_id": id})
raise HTTPException(status_code=404, detail=_("Provider not found"))
logger.debug("Provider retrieved", extra={"user_id": user_id, "provider_id": id})
return model
@router.post("/provider", name="create provider", response_model=ProviderOut)
@traceroot.trace()
async def post(data: ProviderIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
"""Create a new provider."""
user_id = auth.user.id
try:
model = Provider(**data.model_dump(), user_id=user_id)
model.save(session)
logger.info("Provider created", extra={"user_id": user_id, "provider_id": model.id, "provider_name": data.provider_name})
return model
except Exception as e:
logger.error("Provider creation failed", extra={"user_id": user_id, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")
@router.put("/provider/{id}", name="update provider", response_model=ProviderOut)
@traceroot.trace()
async def put(id: int, data: ProviderIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
"""Update provider details."""
user_id = auth.user.id
model = session.exec(
select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == id)
).one_or_none()
if not model:
logger.warning("Provider not found for update", extra={"user_id": user_id, "provider_id": id})
raise HTTPException(status_code=404, detail=_("Provider not found"))
try:
model.model_type = data.model_type
model.provider_name = data.provider_name
model.api_key = data.api_key
model.endpoint_url = data.endpoint_url
model.encrypted_config = data.encrypted_config
model.is_vaild = data.is_vaild
model.save(session)
session.refresh(model)
logger.info("Provider updated", extra={"user_id": user_id, "provider_id": id, "provider_name": data.provider_name})
return model
except Exception as e:
logger.error("Provider update failed", extra={"user_id": user_id, "provider_id": id, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")
@router.delete("/provider/{id}", name="delete provider")
@traceroot.trace()
async def delete(id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
"""Delete a provider."""
user_id = auth.user.id
model = session.exec(
select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == id)
).one_or_none()
if not model:
logger.warning("Provider not found for deletion", extra={"user_id": user_id, "provider_id": id})
raise HTTPException(status_code=404, detail=_("Provider not found"))
try:
model.delete(session)
logger.info("Provider deleted", extra={"user_id": user_id, "provider_id": id})
return Response(status_code=204)
except Exception as e:
logger.error("Provider deletion failed", extra={"user_id": user_id, "provider_id": id, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")
@router.post("/provider/prefer", name="set provider prefer")
@traceroot.trace()
async def set_prefer(data: ProviderPreferIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
"""Set preferred provider for user."""
user_id = auth.user.id
provider_id = data.provider_id
try:
# 1. Set all current user's providers prefer to false
session.exec(update(Provider).where(Provider.user_id == user_id, Provider.no_delete()).values(prefer=False))
# 2. Set the prefer of the specified provider_id to true
session.exec(
update(Provider)
.where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == provider_id)
.values(prefer=True)
)
session.commit()
logger.info("Preferred provider set", extra={"user_id": user_id, "provider_id": provider_id})
return {"success": True}
except SQLAlchemyError as e:
session.rollback()
logger.error("Failed to set preferred provider", extra={"user_id": user_id, "provider_id": provider_id, "error": str(e)}, exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")

View file

@ -1,90 +1,114 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi_babel import _
from sqlmodel import Session
from app.component import code
from app.component.auth import Auth
from app.component.database import session
from app.component.encrypt import password_verify
from app.component.stack_auth import StackAuth
from app.exception.exception import UserException
from app.model.user.user import LoginByPasswordIn, LoginResponse, Status, User, RegisterIn
from loguru import logger
from app.component.environment import env
router = APIRouter(tags=["Login/Registration"])
@router.post("/login", name="login by email or password")
async def by_password(data: LoginByPasswordIn, session: Session = Depends(session)) -> LoginResponse:
"""
User login with email and password
"""
user = User.by(User.email == data.email, s=session).one_or_none()
if not user or not password_verify(data.password, user.password):
raise UserException(code.password, _("Account or password error"))
return LoginResponse(token=Auth.create_access_token(user.id), email=user.email)
@router.post("/login-by_stack", name="login by stack")
async def by_stack_auth(
token: str,
type: str = "signup",
invite_code: str | None = None,
session: Session = Depends(session),
):
try:
stack_id = await StackAuth.user_id(token)
info = await StackAuth.user_info(token)
except Exception as e:
logger.error(e)
raise HTTPException(500, detail=_(f"{e}"))
user = User.by(User.stack_id == stack_id, s=session).one_or_none()
if not user:
# Only signup can create user
if type != "signup":
raise UserException(code.error, _("User not found"))
with session as s:
try:
user = User(
username=info["username"] if "username" in info else None,
nickname=info["display_name"],
email=info["primary_email"],
avatar=info["profile_image_url"],
stack_id=stack_id,
)
s.add(user)
s.commit()
session.refresh(user)
return LoginResponse(token=Auth.create_access_token(user.id), email=user.email)
except Exception as e:
s.rollback()
logger.error(f"Failed to register: {e}")
raise UserException(code.error, _("Failed to register"))
else:
if user.status == Status.Block:
raise UserException(code.error, _("Your account has been blocked."))
return LoginResponse(token=Auth.create_access_token(user.id), email=user.email)
@router.post("/register", name="register by email/password")
async def register(data: RegisterIn, session: Session = Depends(session)):
# Check if email is already registered
if User.by(User.email == data.email, s=session).one_or_none():
raise UserException(code.error, _("Email already registered"))
with session as s:
try:
user = User(
email=data.email,
password=data.password,
)
s.add(user)
s.commit()
s.refresh(user)
except Exception as e:
s.rollback()
logger.error(f"Failed to register: {e}")
raise UserException(code.error, _("Failed to register"))
return {"status": "success"}
from fastapi import APIRouter, Depends, HTTPException
from fastapi_babel import _
from sqlmodel import Session
from app.component import code
from app.component.auth import Auth
from app.component.database import session
from app.component.encrypt import password_verify
from app.component.stack_auth import StackAuth
from app.exception.exception import UserException
from app.model.user.user import LoginByPasswordIn, LoginResponse, Status, User, RegisterIn
from app.component.environment import env
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("server_login_controller")
router = APIRouter(tags=["Login/Registration"])
@router.post("/login", name="login by email or password")
@traceroot.trace()
async def by_password(data: LoginByPasswordIn, session: Session = Depends(session)) -> LoginResponse:
"""
User login with email and password
"""
email = data.email
user = User.by(User.email == email, s=session).one_or_none()
if not user:
logger.warning("Login failed: user not found", extra={"email": email})
raise UserException(code.password, _("Account or password error"))
if not password_verify(data.password, user.password):
logger.warning("Login failed: invalid password", extra={"user_id": user.id, "email": email})
raise UserException(code.password, _("Account or password error"))
logger.info("User login successful", extra={"user_id": user.id, "email": email})
return LoginResponse(token=Auth.create_access_token(user.id), email=user.email)
@router.post("/login-by_stack", name="login by stack")
@traceroot.trace()
async def by_stack_auth(
token: str,
type: str = "signup",
invite_code: str | None = None,
session: Session = Depends(session),
):
try:
stack_id = await StackAuth.user_id(token)
info = await StackAuth.user_info(token)
except Exception as e:
logger.error("Stack auth failed", extra={"type": type, "error": str(e)}, exc_info=True)
raise HTTPException(500, detail=_("Authentication failed"))
user = User.by(User.stack_id == stack_id, s=session).one_or_none()
if not user:
if type != "signup":
logger.warning("Stack auth signup blocked: user not found", extra={"stack_id": stack_id, "type": type})
raise UserException(code.error, _("User not found"))
with session as s:
try:
user = User(
username=info["username"] if "username" in info else None,
nickname=info["display_name"],
email=info["primary_email"],
avatar=info["profile_image_url"],
stack_id=stack_id,
)
s.add(user)
s.commit()
s.refresh(user)
logger.info("New user registered via stack", extra={"user_id": user.id, "email": user.email, "stack_id": stack_id})
return LoginResponse(token=Auth.create_access_token(user.id), email=user.email)
except Exception as e:
s.rollback()
logger.error("Stack auth registration failed", extra={"stack_id": stack_id, "error": str(e)}, exc_info=True)
raise UserException(code.error, _("Failed to register"))
else:
if user.status == Status.Block:
logger.warning("Blocked user login attempt", extra={"user_id": user.id, "stack_id": stack_id})
raise UserException(code.error, _("Your account has been blocked."))
logger.info("User login via stack successful", extra={"user_id": user.id, "email": user.email, "stack_id": stack_id})
return LoginResponse(token=Auth.create_access_token(user.id), email=user.email)
@router.post("/register", name="register by email/password")
@traceroot.trace()
async def register(data: RegisterIn, session: Session = Depends(session)):
email = data.email
if User.by(User.email == email, s=session).one_or_none():
logger.warning("Registration failed: email already exists", extra={"email": email})
raise UserException(code.error, _("Email already registered"))
with session as s:
try:
user = User(
email=email,
password=data.password,
)
s.add(user)
s.commit()
s.refresh(user)
logger.info("User registered successfully", extra={"user_id": user.id, "email": email})
except Exception as e:
s.rollback()
logger.error("User registration failed", extra={"email": email, "error": str(e)}, exc_info=True)
raise UserException(code.error, _("Failed to register"))
return {"status": "success"}

View file

@ -1,115 +1,151 @@
from fastapi import APIRouter, Depends
from sqlalchemy import func
from sqlmodel import Session, select
from app.component.auth import Auth, auth_must
from app.component.database import session
from app.model.user.privacy import UserPrivacy, UserPrivacySettings
from app.model.user.user import User, UserIn, UserOut, UserProfile
from app.model.user.user_stat import UserStat, UserStatActionIn, UserStatOut
from app.model.chat.chat_history import ChatHistory
from app.model.mcp.mcp_user import McpUser
from app.model.config.config import Config
from app.model.chat.chat_snpshot import ChatSnapshot
from app.model.user.user_credits_record import UserCreditsRecord
router = APIRouter(tags=["User"])
@router.get("/user", name="user info", response_model=UserOut)
def get(auth: Auth = Depends(auth_must), session: Session = Depends(session)):
# 获取用户信息时触发积分刷新
user: User = auth.user
user.refresh_credits_on_active(session)
return user
@router.put("/user", name="update user info", response_model=UserOut)
def put(data: UserIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
model = auth.user
model.username = data.username
model.save(session)
return model
@router.put("/user/profile", name="update user profile", response_model=UserProfile)
def put_profile(data: UserProfile, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
model = auth.user
model.nickname = data.nickname
model.fullname = data.fullname
model.work_desc = data.work_desc
model.save(session)
return model
@router.get("/user/privacy", name="get user privacy")
def get_privacy(session: Session = Depends(session), auth: Auth = Depends(auth_must)):
user_id = auth.user.id
stmt = select(UserPrivacy).where(UserPrivacy.user_id == user_id)
model = session.exec(stmt).one_or_none()
if not model:
return UserPrivacySettings.default_settings()
return model.pricacy_setting
@router.put("/user/privacy", name="update user privacy")
def put_privacy(data: UserPrivacySettings, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
user_id = auth.user.id
stmt = select(UserPrivacy).where(UserPrivacy.user_id == user_id)
model = session.exec(stmt).one_or_none()
default_settings = UserPrivacySettings.default_settings()
if model:
model.pricacy_setting = {**model.pricacy_setting, **data.model_dump()}
model.save(session)
else:
model = UserPrivacy(user_id=user_id, pricacy_setting={**default_settings, **data.model_dump()})
model.save(session)
return model.pricacy_setting
@router.get("/user/current_credits", name="get user current credits")
def get_user_credits(auth: Auth = Depends(auth_must), session: Session = Depends(session)):
user = auth.user
user.refresh_credits_on_active(session)
credits = user.credits
daily_credits: UserCreditsRecord | None = UserCreditsRecord.get_daily_balance(user.id)
current_daily_credits = 0
if daily_credits:
current_daily_credits = daily_credits.amount - daily_credits.balance
credits += current_daily_credits if current_daily_credits > 0 else 0
return {"credits": credits, "daily_credits": current_daily_credits}
@router.get("/user/stat", name="get user stat", response_model=UserStatOut)
def get_user_stat(auth: Auth = Depends(auth_must), session: Session = Depends(session)):
"""Get current user's operation statistics."""
stat = session.exec(select(UserStat).where(UserStat.user_id == auth.user.id)).first()
data = UserStatOut()
if stat:
data = UserStatOut(**stat.model_dump())
else:
data = UserStatOut(user_id=auth.user.id)
data.task_queries = ChatHistory.count(ChatHistory.user_id == auth.user.id, s=session)
mcp = McpUser.count(McpUser.user_id == auth.user.id, s=session)
tool: list = session.exec(
select(func.count("*")).where(Config.user_id == auth.user.id).group_by(Config.config_group)
).all()
tool = tool.__len__()
data.mcp_install_count = mcp + tool
data.storage_used = ChatSnapshot.caclDir(ChatSnapshot.get_user_dir(auth.user.id))
return data
@router.post("/user/stat", name="record user stat")
def record_user_stat(
data: UserStatActionIn,
auth: Auth = Depends(auth_must),
session: Session = Depends(session),
):
"""Record or update current user's operation statistics."""
data.user_id = auth.user.id
stat = UserStat.record_action(session, data)
return stat
from fastapi import APIRouter, Depends
from sqlalchemy import func
from sqlmodel import Session, select
from app.component.auth import Auth, auth_must
from app.component.database import session
from app.model.user.privacy import UserPrivacy, UserPrivacySettings
from app.model.user.user import User, UserIn, UserOut, UserProfile
from app.model.user.user_stat import UserStat, UserStatActionIn, UserStatOut
from app.model.chat.chat_history import ChatHistory
from app.model.mcp.mcp_user import McpUser
from app.model.config.config import Config
from app.model.chat.chat_snpshot import ChatSnapshot
from app.model.user.user_credits_record import UserCreditsRecord
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("server_user_controller")
router = APIRouter(tags=["User"])
@router.get("/user", name="user info", response_model=UserOut)
@traceroot.trace()
def get(auth: Auth = Depends(auth_must), session: Session = Depends(session)):
"""Get current user information and refresh credits."""
user: User = auth.user
user.refresh_credits_on_active(session)
logger.debug("User info retrieved", extra={"user_id": user.id})
return user
@router.put("/user", name="update user info", response_model=UserOut)
@traceroot.trace()
def put(data: UserIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
"""Update user basic information."""
model = auth.user
model.username = data.username
model.save(session)
logger.info("User info updated", extra={"user_id": model.id, "username": data.username})
return model
@router.put("/user/profile", name="update user profile", response_model=UserProfile)
@traceroot.trace()
def put_profile(data: UserProfile, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
"""Update user profile details."""
model = auth.user
model.nickname = data.nickname
model.fullname = data.fullname
model.work_desc = data.work_desc
model.save(session)
logger.info("User profile updated", extra={"user_id": model.id, "nickname": data.nickname})
return model
@router.get("/user/privacy", name="get user privacy")
@traceroot.trace()
def get_privacy(session: Session = Depends(session), auth: Auth = Depends(auth_must)):
"""Get user privacy settings."""
user_id = auth.user.id
stmt = select(UserPrivacy).where(UserPrivacy.user_id == user_id)
model = session.exec(stmt).one_or_none()
if not model:
logger.debug("Privacy settings not found, returning defaults", extra={"user_id": user_id})
return UserPrivacySettings.default_settings()
logger.debug("Privacy settings retrieved", extra={"user_id": user_id})
return model.pricacy_setting
@router.put("/user/privacy", name="update user privacy")
@traceroot.trace()
def put_privacy(data: UserPrivacySettings, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
"""Update user privacy settings."""
user_id = auth.user.id
stmt = select(UserPrivacy).where(UserPrivacy.user_id == user_id)
model = session.exec(stmt).one_or_none()
default_settings = UserPrivacySettings.default_settings()
if model:
model.pricacy_setting = {**model.pricacy_setting, **data.model_dump()}
model.save(session)
logger.info("Privacy settings updated", extra={"user_id": user_id})
else:
model = UserPrivacy(user_id=user_id, pricacy_setting={**default_settings, **data.model_dump()})
model.save(session)
logger.info("Privacy settings created", extra={"user_id": user_id})
return model.pricacy_setting
@router.get("/user/current_credits", name="get user current credits")
@traceroot.trace()
def get_user_credits(auth: Auth = Depends(auth_must), session: Session = Depends(session)):
"""Get user's current credit balance."""
user = auth.user
user.refresh_credits_on_active(session)
credits = user.credits
daily_credits: UserCreditsRecord | None = UserCreditsRecord.get_daily_balance(user.id)
current_daily_credits = 0
if daily_credits:
current_daily_credits = daily_credits.amount - daily_credits.balance
credits += current_daily_credits if current_daily_credits > 0 else 0
logger.debug("Credits retrieved", extra={"user_id": user.id, "total_credits": credits, "daily_credits": current_daily_credits})
return {"credits": credits, "daily_credits": current_daily_credits}
@router.get("/user/stat", name="get user stat", response_model=UserStatOut)
@traceroot.trace()
def get_user_stat(auth: Auth = Depends(auth_must), session: Session = Depends(session)):
"""Get current user's operation statistics."""
user_id = auth.user.id
stat = session.exec(select(UserStat).where(UserStat.user_id == user_id)).first()
data = UserStatOut()
if stat:
data = UserStatOut(**stat.model_dump())
else:
data = UserStatOut(user_id=user_id)
data.task_queries = ChatHistory.count(ChatHistory.user_id == user_id, s=session)
mcp = McpUser.count(McpUser.user_id == user_id, s=session)
tool: list = session.exec(
select(func.count("*")).where(Config.user_id == user_id).group_by(Config.config_group)
).all()
tool = tool.__len__()
data.mcp_install_count = mcp + tool
data.storage_used = ChatSnapshot.caclDir(ChatSnapshot.get_user_dir(user_id))
logger.debug("User stats retrieved", extra={
"user_id": user_id,
"task_queries": data.task_queries,
"mcp_install_count": data.mcp_install_count,
"storage_used": data.storage_used
})
return data
@router.post("/user/stat", name="record user stat")
@traceroot.trace()
def record_user_stat(
data: UserStatActionIn,
auth: Auth = Depends(auth_must),
session: Session = Depends(session),
):
"""Record or update current user's operation statistics."""
data.user_id = auth.user.id
stat = UserStat.record_action(session, data)
logger.info("User stat recorded", extra={"user_id": data.user_id, "action": data.action if hasattr(data, 'action') else "unknown"})
return stat

View file

@ -1,24 +1,36 @@
from fastapi import APIRouter, Depends
from sqlmodel import Session
from app.component import code
from app.component.auth import Auth, auth_must
from app.component.database import session
from app.component.encrypt import password_hash, password_verify
from app.exception.exception import UserException
from app.model.user.user import UpdatePassword, UserOut
from fastapi_babel import _
router = APIRouter(tags=["User"])
@router.put("/user/update-password", name="update password", response_model=UserOut)
def update_password(data: UpdatePassword, auth: Auth = Depends(auth_must), session: Session = Depends(session)):
model = auth.user
if not password_verify(data.password, model.password):
raise UserException(code.error, _("Password is incorrect"))
if data.new_password != data.re_new_password:
raise UserException(code.error, _("The two passwords do not match"))
model.password = password_hash(data.new_password)
model.save(session)
return model
from fastapi import APIRouter, Depends
from sqlmodel import Session
from app.component import code
from app.component.auth import Auth, auth_must
from app.component.database import session
from app.component.encrypt import password_hash, password_verify
from app.exception.exception import UserException
from app.model.user.user import UpdatePassword, UserOut
from fastapi_babel import _
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("server_password_controller")
router = APIRouter(tags=["User"])
@router.put("/user/update-password", name="update password", response_model=UserOut)
@traceroot.trace()
def update_password(data: UpdatePassword, auth: Auth = Depends(auth_must), session: Session = Depends(session)):
"""Update user password after verifying current password."""
user_id = auth.user.id
model = auth.user
if not password_verify(data.password, model.password):
logger.warning("Password update failed: incorrect current password", extra={"user_id": user_id})
raise UserException(code.error, _("Password is incorrect"))
if data.new_password != data.re_new_password:
logger.warning("Password update failed: new passwords do not match", extra={"user_id": user_id})
raise UserException(code.error, _("The two passwords do not match"))
model.password = password_hash(data.new_password)
model.save(session)
logger.info("Password updated successfully", extra={"user_id": user_id})
return model

View file

@ -2,9 +2,10 @@ from sqlalchemy import Float, Integer
from sqlmodel import Field, SmallInteger, Column, JSON, String
from typing import Optional
from enum import IntEnum
from datetime import datetime
from sqlalchemy_utils import ChoiceType
from app.model.abstract.model import AbstractModel, DefaultTimes
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
class ChatStatus(IntEnum):
@ -13,9 +14,20 @@ class ChatStatus(IntEnum):
class ChatHistory(AbstractModel, DefaultTimes, table=True):
"""
Chat history model with timestamp tracking.
Inherits from DefaultTimes which provides:
- created_at: timestamp when record is created (auto-populated)
- updated_at: timestamp when record is last modified (auto-updated)
- deleted_at: timestamp for soft deletion (nullable)
For legacy records without timestamps, sorting falls back to id ordering.
"""
id: int = Field(default=None, primary_key=True)
user_id: int = Field(index=True)
task_id: str = Field(index=True, unique=True)
project_id: str = Field(index=True, unique=False, nullable=True)
question: str
language: str
model_platform: str
@ -34,6 +46,7 @@ class ChatHistory(AbstractModel, DefaultTimes, table=True):
class ChatHistoryIn(BaseModel):
task_id: str
project_id: str | None = None
user_id: int | None = None
question: str
language: str
@ -54,6 +67,7 @@ class ChatHistoryIn(BaseModel):
class ChatHistoryOut(BaseModel):
id: int
task_id: str
project_id: str | None = None
question: str
language: str
model_platform: str
@ -67,6 +81,22 @@ class ChatHistoryOut(BaseModel):
summary: str | None = None
tokens: int
status: int
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
@model_validator(mode="after")
def fill_project_id_from_task_id(self):
"""Fill project_id from task_id when project_id is None"""
if self.project_id is None:
self.project_id = self.task_id
return self
@model_validator(mode="after")
def handle_legacy_timestamps(self):
"""Handle legacy records that might not have timestamp fields"""
# For old records without timestamps, we rely on database-level defaults
# The sorting in the controller will handle ordering appropriately
return self
class ChatHistoryUpdate(BaseModel):
@ -74,3 +104,4 @@ class ChatHistoryUpdate(BaseModel):
summary: str | None = None
tokens: int | None = None
status: int | None = None
project_id: str | None = None

View file

@ -124,10 +124,10 @@ class ConfigInfo:
"env_vars": [],
"toolkit": "google_drive_mcp_toolkit",
},
# ConfigGroup.GOOGLE_GMAIL_MCP.value: {
# "env_vars": [],
# "toolkit": "google_gmail_mcp_toolkit",
# },
ConfigGroup.GOOGLE_GMAIL_MCP.value: {
"env_vars": ["GOOGLE_CLIENT_ID", "GOOGLE_CLIENT_SECRET", "GOOGLE_REFRESH_TOKEN"],
"toolkit": "google_gmail_native_toolkit",
},
ConfigGroup.IMAGE_ANALYSIS.value: {
"env_vars": [],
"toolkit": "image_analysis_toolkit",

View file

@ -1,381 +1,383 @@
from enum import IntEnum
from typing import Optional
from pydantic import BaseModel
from sqlmodel import Relationship, SQLModel, Field, Column, col, select, Session
from sqlalchemy_utils import ChoiceType
from sqlalchemy import Boolean, SmallInteger, text
from app.model.abstract.model import AbstractModel, DefaultTimes
from datetime import date, datetime, timedelta
from app.model.user.key import ModelType
from app.component.database import session_make
from loguru import logger
class CreditsChannel(IntEnum):
register = 1 # 注册赠送
invite = 2 # 邀请赠送
daily = 3 # 每日刷新积分
monthly = 4 # 每月刷新积分
paid = 5 # 付费积分
addon = 6 # 加量包
consume = 7 # 任务消费
class CreditsPriority(IntEnum):
daily = 1 # 每日刷新积分
monthly = 2 # 每月刷新积分
paid = 3 # 付费积分
addon = 4 # 加量包
class CreditsPoint(IntEnum):
register = 1000
invite = 500
special_register = 1500 # 1000 register + 500 invite credit
class UserCreditsRecord(AbstractModel, DefaultTimes, table=True):
id: int = Field(default=None, primary_key=True)
user_id: int = Field(index=True, foreign_key="user.id")
invite_by: int = Field(default=None, nullable=True, description="invite by user id")
invite_code: str = Field(default="", max_length=255)
amount: int = Field(default=0)
balance: int = Field(default=0)
channel: CreditsChannel = Field(
default=CreditsChannel.register.value, sa_column=Column(ChoiceType(CreditsChannel, SmallInteger()))
)
source_id: int = Field(default=0, description="source id")
remark: str = Field(default="", max_length=255)
expire_at: datetime = Field(default=None, nullable=True, description="Expiration time")
used: bool = Field(
default=False,
sa_column=Column(Boolean, server_default=text("false")),
description="Is this record used/expired",
)
used_at: datetime = Field(default=None, nullable=True, description="Time when this record was used/expired")
@classmethod
def get_permanent_credits(cls, user_id: int) -> int:
"""
获取可用的token总量直接用SQL聚合sum
Returns:
int: 可用的token总量
"""
session = session_make()
from sqlalchemy import func
statement = (
select(func.sum(UserCreditsRecord.amount))
.where(UserCreditsRecord.user_id == user_id)
.where(
UserCreditsRecord.channel.in_(
[
CreditsChannel.register,
CreditsChannel.invite,
CreditsChannel.paid,
CreditsChannel.addon,
CreditsChannel.monthly,
]
)
)
.where(UserCreditsRecord.used == False)
.where((UserCreditsRecord.expire_at.is_(None)) | (col(UserCreditsRecord.expire_at) > datetime.now()))
)
result = session.exec(statement).first()
return result or 0
@classmethod
def get_temp_credits(cls, user_id: int) -> tuple[int, date]:
"""
1. 获取可用的临时token总量需要通过credits 然后根据model_type来计算
2. 每天只允许赠送一次临时的量
Returns:
int: 可用的临时token总量
"""
session = session_make()
statement = (
select(UserCreditsRecord)
.where(UserCreditsRecord.user_id == user_id)
.where(UserCreditsRecord.channel == CreditsChannel.daily)
.where(UserCreditsRecord.used == False)
.where(UserCreditsRecord.expire_at.is_not(None))
.where(col(UserCreditsRecord.expire_at) > datetime.now())
)
record: UserCreditsRecord = session.exec(statement).first()
if record is None:
return 0, None
return record.amount - record.balance, record.expire_at
@classmethod
def consume_credits(cls, user_id: int, amount: int, session: Session, source_id: int = 0, remark: str = ""):
"""
消耗积分优先消耗每日积分daily再消耗monthlypaidaddon等
消耗时更新UserCreditsRecord的balance字段记录已消耗积分数
同时生成积分消耗记录更新用户积分credits字段不包括每日积分
避免重复生成积分消耗记录和重复扣减积分
"""
# 检查是否已有积分消耗记录
existing_consume_record = None
if source_id > 0:
existing_consume_record = session.exec(
select(UserCreditsRecord)
.where(UserCreditsRecord.user_id == user_id)
.where(UserCreditsRecord.channel == CreditsChannel.consume)
.where(UserCreditsRecord.source_id == source_id)
).first()
if existing_consume_record:
# 如果新amount更大需要额外消耗积分
if amount > 0:
existing_consume_record.amount -= amount
session.add(existing_consume_record)
# 直接处理额外的积分消耗,不生成新的消耗记录
cls._consume_credits_internal_update(user_id, amount, session, source_id, remark)
# 如果新amount更小需要退还积分这里可以根据业务需求决定是否实现
else:
# 暂时不实现退还逻辑,可以根据需要添加
pass
session.commit()
return
# 没有现有记录,执行正常的积分消耗流程
cls._consume_credits_internal(user_id, amount, session, source_id, remark)
@classmethod
def _consume_credits_internal(
cls, user_id: int, amount: int, session: Session, source_id: int = 0, remark: str = ""
):
"""
内部积分消耗逻辑处理实际的积分扣减
"""
from app.model.user.user import User
remain = amount
now = datetime.now()
consumed_from_daily = 0
consumed_from_other = 0
# 优先消耗daily
statement = (
select(UserCreditsRecord)
.where(UserCreditsRecord.user_id == user_id)
.where(UserCreditsRecord.channel == CreditsChannel.daily)
.where(UserCreditsRecord.used == False)
.where(UserCreditsRecord.expire_at.is_not(None))
.where(col(UserCreditsRecord.expire_at) > now)
.order_by(UserCreditsRecord.expire_at)
)
daily_records = session.exec(statement).first()
if daily_records:
can_consume = daily_records.amount - daily_records.balance
use = min(remain, can_consume)
daily_records.balance += use
session.add(daily_records)
remain -= use
consumed_from_daily = use
if remain == 0:
# 生成积分消耗记录
consume_record = UserCreditsRecord(
user_id=user_id,
amount=-amount,
channel=CreditsChannel.consume,
source_id=source_id,
remark=remark or f"Consumed {amount} credits (daily: {consumed_from_daily})",
)
session.add(consume_record)
session.commit()
return
# 若daily不够继续消耗monthly/paid/addon
if remain > 0:
statement = (
select(UserCreditsRecord)
.where(UserCreditsRecord.user_id == user_id)
.where(
UserCreditsRecord.channel.in_(
[
CreditsChannel.monthly,
CreditsChannel.paid,
CreditsChannel.addon,
CreditsChannel.register,
CreditsChannel.invite,
]
)
)
.where(UserCreditsRecord.used == False)
.where((UserCreditsRecord.expire_at.is_(None)) | (col(UserCreditsRecord.expire_at) > now))
.order_by(UserCreditsRecord.expire_at)
)
other_records = session.exec(statement).all()
for record in other_records:
can_consume = record.amount - record.balance
if can_consume <= 0:
continue
use = min(remain, can_consume)
record.balance += use
session.add(record)
remain -= use
consumed_from_other += use
if remain == 0:
break
# 更新用户积分字段(只扣除非每日积分消耗的部分)
if consumed_from_other > 0:
user = session.exec(select(User).where(User.id == user_id)).first()
if user:
user.credits -= consumed_from_other
session.add(user)
# 生成积分消耗记录
consume_record = UserCreditsRecord(
user_id=user_id,
amount=-amount,
channel=CreditsChannel.consume,
source_id=source_id,
remark=remark or f"Consumed {amount} credits (daily: {consumed_from_daily}, other: {consumed_from_other})",
)
session.add(consume_record)
session.commit()
if remain > 0:
raise Exception(f"Insufficient credits: need {amount}, remain {remain}")
@classmethod
def _consume_credits_internal_update(
cls, user_id: int, amount: int, session: Session, source_id: int = 0, remark: str = ""
):
"""
内部积分消耗逻辑更新模式处理实际的积分扣减但不生成新的消耗记录
用于更新现有消耗记录时的额外积分消耗
"""
from app.model.user.user import User
remain = amount
now = datetime.now()
consumed_from_daily = 0
consumed_from_other = 0
# 优先消耗daily
statement = (
select(UserCreditsRecord)
.where(UserCreditsRecord.user_id == user_id)
.where(UserCreditsRecord.channel == CreditsChannel.daily)
.where(UserCreditsRecord.used == False)
.where(UserCreditsRecord.expire_at.is_not(None))
.where(col(UserCreditsRecord.expire_at) > now)
.order_by(UserCreditsRecord.expire_at)
)
daily_records = session.exec(statement).first()
if daily_records:
can_consume = daily_records.amount - daily_records.balance
use = min(remain, can_consume)
daily_records.balance += use
session.add(daily_records)
remain -= use
consumed_from_daily = use
if remain == 0:
# 不生成新的消耗记录,只更新现有记录
return
# 若daily不够继续消耗monthly/paid/addon
if remain > 0:
statement = (
select(UserCreditsRecord)
.where(UserCreditsRecord.user_id == user_id)
.where(
UserCreditsRecord.channel.in_(
[
CreditsChannel.monthly,
CreditsChannel.paid,
CreditsChannel.addon,
CreditsChannel.register,
CreditsChannel.invite,
]
)
)
.where(UserCreditsRecord.used == False)
.where((UserCreditsRecord.expire_at.is_(None)) | (col(UserCreditsRecord.expire_at) > now))
.order_by(UserCreditsRecord.expire_at)
)
other_records = session.exec(statement).all()
for record in other_records:
can_consume = record.amount - record.balance
if can_consume <= 0:
continue
use = min(remain, can_consume)
record.balance += use
session.add(record)
remain -= use
consumed_from_other += use
if remain == 0:
break
logger.info(f"consumed_from_other: {consumed_from_other}")
# 更新用户积分字段(只扣除非每日积分消耗的部分)
if consumed_from_other > 0:
user = session.exec(select(User).where(User.id == user_id)).first()
if user:
user.credits -= consumed_from_other
session.add(user)
# 不生成新的消耗记录,因为现有记录已经在主函数中更新了
if remain > 0:
raise Exception(f"Insufficient credits: need {amount}, remain {remain}")
@classmethod
def get_daily_balance_sum(cls, user_id: int) -> int:
"""
获取用户所有每日积分daily channel的balance字段之和
"""
session = session_make()
statement = (
select(UserCreditsRecord.balance)
.where(UserCreditsRecord.user_id == user_id)
.where(UserCreditsRecord.channel == CreditsChannel.daily)
)
balances = session.exec(statement).all()
return sum(balances) if balances else 0
@classmethod
def get_daily_balance(cls, user_id: int) -> int:
"""
获取用户当前的每日积分数据
"""
session = session_make()
statement = (
select(UserCreditsRecord)
.where(UserCreditsRecord.user_id == user_id)
.where(UserCreditsRecord.channel == CreditsChannel.daily)
.where(UserCreditsRecord.used == False)
)
record = session.exec(statement).first()
return record
class UserCreditsRecordWithChatOut(BaseModel):
"""扩展的积分记录输出模型,包含聊天历史信息"""
amount: int
balance: int
channel: CreditsChannel
source_id: int
expire_at: Optional[datetime] = None
created_at: datetime
updated_at: Optional[datetime] = None
# 聊天历史相关字段当channel为consume且source_id有效时
chat_project_name: Optional[str] = None
chat_tokens: Optional[int] = None
class UserCreditsRecordOut(BaseModel):
amount: int
balance: int
channel: CreditsChannel
source_id: int
remark: str
expire_at: datetime | None
created_at: datetime
updated_at: datetime | None
from enum import IntEnum
from typing import Optional
from pydantic import BaseModel
from sqlmodel import Relationship, SQLModel, Field, Column, col, select, Session
from sqlalchemy_utils import ChoiceType
from sqlalchemy import Boolean, SmallInteger, text
from app.model.abstract.model import AbstractModel, DefaultTimes
from datetime import date, datetime, timedelta
from app.model.user.key import ModelType
from app.component.database import session_make
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("user_credits_record")
class CreditsChannel(IntEnum):
register = 1 # 注册赠送
invite = 2 # 邀请赠送
daily = 3 # 每日刷新积分
monthly = 4 # 每月刷新积分
paid = 5 # 付费积分
addon = 6 # 加量包
consume = 7 # 任务消费
class CreditsPriority(IntEnum):
daily = 1 # 每日刷新积分
monthly = 2 # 每月刷新积分
paid = 3 # 付费积分
addon = 4 # 加量包
class CreditsPoint(IntEnum):
register = 1000
invite = 500
special_register = 1500 # 1000 register + 500 invite credit
class UserCreditsRecord(AbstractModel, DefaultTimes, table=True):
id: int = Field(default=None, primary_key=True)
user_id: int = Field(index=True, foreign_key="user.id")
invite_by: int = Field(default=None, nullable=True, description="invite by user id")
invite_code: str = Field(default="", max_length=255)
amount: int = Field(default=0)
balance: int = Field(default=0)
channel: CreditsChannel = Field(
default=CreditsChannel.register.value, sa_column=Column(ChoiceType(CreditsChannel, SmallInteger()))
)
source_id: int = Field(default=0, description="source id")
remark: str = Field(default="", max_length=255)
expire_at: datetime = Field(default=None, nullable=True, description="Expiration time")
used: bool = Field(
default=False,
sa_column=Column(Boolean, server_default=text("false")),
description="Is this record used/expired",
)
used_at: datetime = Field(default=None, nullable=True, description="Time when this record was used/expired")
@classmethod
def get_permanent_credits(cls, user_id: int) -> int:
"""
获取可用的token总量直接用SQL聚合sum
Returns:
int: 可用的token总量
"""
session = session_make()
from sqlalchemy import func
statement = (
select(func.sum(UserCreditsRecord.amount))
.where(UserCreditsRecord.user_id == user_id)
.where(
UserCreditsRecord.channel.in_(
[
CreditsChannel.register,
CreditsChannel.invite,
CreditsChannel.paid,
CreditsChannel.addon,
CreditsChannel.monthly,
]
)
)
.where(UserCreditsRecord.used == False)
.where((UserCreditsRecord.expire_at.is_(None)) | (col(UserCreditsRecord.expire_at) > datetime.now()))
)
result = session.exec(statement).first()
return result or 0
@classmethod
def get_temp_credits(cls, user_id: int) -> tuple[int, date]:
"""
1. 获取可用的临时token总量需要通过credits 然后根据model_type来计算
2. 每天只允许赠送一次临时的量
Returns:
int: 可用的临时token总量
"""
session = session_make()
statement = (
select(UserCreditsRecord)
.where(UserCreditsRecord.user_id == user_id)
.where(UserCreditsRecord.channel == CreditsChannel.daily)
.where(UserCreditsRecord.used == False)
.where(UserCreditsRecord.expire_at.is_not(None))
.where(col(UserCreditsRecord.expire_at) > datetime.now())
)
record: UserCreditsRecord = session.exec(statement).first()
if record is None:
return 0, None
return record.amount - record.balance, record.expire_at
@classmethod
def consume_credits(cls, user_id: int, amount: int, session: Session, source_id: int = 0, remark: str = ""):
"""
消耗积分优先消耗每日积分daily再消耗monthlypaidaddon等
消耗时更新UserCreditsRecord的balance字段记录已消耗积分数
同时生成积分消耗记录更新用户积分credits字段不包括每日积分
避免重复生成积分消耗记录和重复扣减积分
"""
# 检查是否已有积分消耗记录
existing_consume_record = None
if source_id > 0:
existing_consume_record = session.exec(
select(UserCreditsRecord)
.where(UserCreditsRecord.user_id == user_id)
.where(UserCreditsRecord.channel == CreditsChannel.consume)
.where(UserCreditsRecord.source_id == source_id)
).first()
if existing_consume_record:
# 如果新amount更大需要额外消耗积分
if amount > 0:
existing_consume_record.amount -= amount
session.add(existing_consume_record)
# 直接处理额外的积分消耗,不生成新的消耗记录
cls._consume_credits_internal_update(user_id, amount, session, source_id, remark)
# 如果新amount更小需要退还积分这里可以根据业务需求决定是否实现
else:
# 暂时不实现退还逻辑,可以根据需要添加
pass
session.commit()
return
# 没有现有记录,执行正常的积分消耗流程
cls._consume_credits_internal(user_id, amount, session, source_id, remark)
@classmethod
def _consume_credits_internal(
cls, user_id: int, amount: int, session: Session, source_id: int = 0, remark: str = ""
):
"""
内部积分消耗逻辑处理实际的积分扣减
"""
from app.model.user.user import User
remain = amount
now = datetime.now()
consumed_from_daily = 0
consumed_from_other = 0
# 优先消耗daily
statement = (
select(UserCreditsRecord)
.where(UserCreditsRecord.user_id == user_id)
.where(UserCreditsRecord.channel == CreditsChannel.daily)
.where(UserCreditsRecord.used == False)
.where(UserCreditsRecord.expire_at.is_not(None))
.where(col(UserCreditsRecord.expire_at) > now)
.order_by(UserCreditsRecord.expire_at)
)
daily_records = session.exec(statement).first()
if daily_records:
can_consume = daily_records.amount - daily_records.balance
use = min(remain, can_consume)
daily_records.balance += use
session.add(daily_records)
remain -= use
consumed_from_daily = use
if remain == 0:
# 生成积分消耗记录
consume_record = UserCreditsRecord(
user_id=user_id,
amount=-amount,
channel=CreditsChannel.consume,
source_id=source_id,
remark=remark or f"Consumed {amount} credits (daily: {consumed_from_daily})",
)
session.add(consume_record)
session.commit()
return
# 若daily不够继续消耗monthly/paid/addon
if remain > 0:
statement = (
select(UserCreditsRecord)
.where(UserCreditsRecord.user_id == user_id)
.where(
UserCreditsRecord.channel.in_(
[
CreditsChannel.monthly,
CreditsChannel.paid,
CreditsChannel.addon,
CreditsChannel.register,
CreditsChannel.invite,
]
)
)
.where(UserCreditsRecord.used == False)
.where((UserCreditsRecord.expire_at.is_(None)) | (col(UserCreditsRecord.expire_at) > now))
.order_by(UserCreditsRecord.expire_at)
)
other_records = session.exec(statement).all()
for record in other_records:
can_consume = record.amount - record.balance
if can_consume <= 0:
continue
use = min(remain, can_consume)
record.balance += use
session.add(record)
remain -= use
consumed_from_other += use
if remain == 0:
break
# 更新用户积分字段(只扣除非每日积分消耗的部分)
if consumed_from_other > 0:
user = session.exec(select(User).where(User.id == user_id)).first()
if user:
user.credits -= consumed_from_other
session.add(user)
# 生成积分消耗记录
consume_record = UserCreditsRecord(
user_id=user_id,
amount=-amount,
channel=CreditsChannel.consume,
source_id=source_id,
remark=remark or f"Consumed {amount} credits (daily: {consumed_from_daily}, other: {consumed_from_other})",
)
session.add(consume_record)
session.commit()
if remain > 0:
raise Exception(f"Insufficient credits: need {amount}, remain {remain}")
@classmethod
def _consume_credits_internal_update(
cls, user_id: int, amount: int, session: Session, source_id: int = 0, remark: str = ""
):
"""
内部积分消耗逻辑更新模式处理实际的积分扣减但不生成新的消耗记录
用于更新现有消耗记录时的额外积分消耗
"""
from app.model.user.user import User
remain = amount
now = datetime.now()
consumed_from_daily = 0
consumed_from_other = 0
# 优先消耗daily
statement = (
select(UserCreditsRecord)
.where(UserCreditsRecord.user_id == user_id)
.where(UserCreditsRecord.channel == CreditsChannel.daily)
.where(UserCreditsRecord.used == False)
.where(UserCreditsRecord.expire_at.is_not(None))
.where(col(UserCreditsRecord.expire_at) > now)
.order_by(UserCreditsRecord.expire_at)
)
daily_records = session.exec(statement).first()
if daily_records:
can_consume = daily_records.amount - daily_records.balance
use = min(remain, can_consume)
daily_records.balance += use
session.add(daily_records)
remain -= use
consumed_from_daily = use
if remain == 0:
# 不生成新的消耗记录,只更新现有记录
return
# 若daily不够继续消耗monthly/paid/addon
if remain > 0:
statement = (
select(UserCreditsRecord)
.where(UserCreditsRecord.user_id == user_id)
.where(
UserCreditsRecord.channel.in_(
[
CreditsChannel.monthly,
CreditsChannel.paid,
CreditsChannel.addon,
CreditsChannel.register,
CreditsChannel.invite,
]
)
)
.where(UserCreditsRecord.used == False)
.where((UserCreditsRecord.expire_at.is_(None)) | (col(UserCreditsRecord.expire_at) > now))
.order_by(UserCreditsRecord.expire_at)
)
other_records = session.exec(statement).all()
for record in other_records:
can_consume = record.amount - record.balance
if can_consume <= 0:
continue
use = min(remain, can_consume)
record.balance += use
session.add(record)
remain -= use
consumed_from_other += use
if remain == 0:
break
logger.info(f"consumed_from_other: {consumed_from_other}")
# 更新用户积分字段(只扣除非每日积分消耗的部分)
if consumed_from_other > 0:
user = session.exec(select(User).where(User.id == user_id)).first()
if user:
user.credits -= consumed_from_other
session.add(user)
# 不生成新的消耗记录,因为现有记录已经在主函数中更新了
if remain > 0:
raise Exception(f"Insufficient credits: need {amount}, remain {remain}")
@classmethod
def get_daily_balance_sum(cls, user_id: int) -> int:
"""
获取用户所有每日积分daily channel的balance字段之和
"""
session = session_make()
statement = (
select(UserCreditsRecord.balance)
.where(UserCreditsRecord.user_id == user_id)
.where(UserCreditsRecord.channel == CreditsChannel.daily)
)
balances = session.exec(statement).all()
return sum(balances) if balances else 0
@classmethod
def get_daily_balance(cls, user_id: int) -> int:
"""
获取用户当前的每日积分数据
"""
session = session_make()
statement = (
select(UserCreditsRecord)
.where(UserCreditsRecord.user_id == user_id)
.where(UserCreditsRecord.channel == CreditsChannel.daily)
.where(UserCreditsRecord.used == False)
)
record = session.exec(statement).first()
return record
class UserCreditsRecordWithChatOut(BaseModel):
"""扩展的积分记录输出模型,包含聊天历史信息"""
amount: int
balance: int
channel: CreditsChannel
source_id: int
expire_at: Optional[datetime] = None
created_at: datetime
updated_at: Optional[datetime] = None
# 聊天历史相关字段当channel为consume且source_id有效时
chat_project_name: Optional[str] = None
chat_tokens: Optional[int] = None
class UserCreditsRecordOut(BaseModel):
amount: int
balance: int
channel: CreditsChannel
source_id: int
remark: str
expire_at: datetime | None
created_at: datetime
updated_at: datetime | None

View file

@ -21,7 +21,7 @@ class ConfigGroup(str, Enum):
GITHUB = "Github"
GOOGLE_CALENDAR = "Google Calendar"
GOOGLE_DRIVE_MCP = "Google Drive MCP"
GOOGLE_GMAIL_MCP = "Google Gmail MCP"
GOOGLE_GMAIL_MCP = "Google Gmail"
IMAGE_ANALYSIS = "Image Analysis"
MCP_SEARCH = "MCP Search"
PPTX = "PPTX"

View file

@ -25,8 +25,8 @@ services:
# FastAPI Application
api:
build:
context: .
dockerfile: Dockerfile
context: ..
dockerfile: server/Dockerfile
args:
database_url: postgresql://postgres:123456@postgres:5432/eigent
container_name: eigent_api

View file

@ -1,30 +1,36 @@
from app import api
from app.component.environment import auto_include_routers, env
from loguru import logger
import os
from fastapi.staticfiles import StaticFiles
prefix = env("url_prefix", "")
auto_include_routers(api, prefix, "app/controller")
public_dir = os.environ.get("PUBLIC_DIR") or os.path.join(os.path.dirname(__file__), "app", "public")
# Ensure static directory exists or gracefully skip mounting
if not os.path.isdir(public_dir):
try:
os.makedirs(public_dir, exist_ok=True)
logger.warning(f"Public directory did not exist. Created: {public_dir}")
except Exception as e:
logger.error(f"Public directory missing and could not be created: {public_dir}. Error: {e}")
public_dir = None
if public_dir and os.path.isdir(public_dir):
api.mount("/public", StaticFiles(directory=public_dir), name="public")
else:
logger.warning("Skipping /public mount because public directory is unavailable")
logger.add(
"runtime/log/app.log",
rotation="10 MB",
retention="10 days",
level="DEBUG",
enqueue=True,
)
import os
import sys
import pathlib
# Add project root to Python path to import shared utils
_project_root = pathlib.Path(__file__).parent.parent
if str(_project_root) not in sys.path:
sys.path.insert(0, str(_project_root))
from utils import traceroot_wrapper as traceroot
from app import api
from app.component.environment import auto_include_routers, env
from fastapi.staticfiles import StaticFiles
# Only initialize traceroot if enabled
if traceroot.is_enabled():
from traceroot.integrations.fastapi import connect_fastapi
connect_fastapi(api)
logger = traceroot.get_logger("server_main")
prefix = env("url_prefix", "")
auto_include_routers(api, prefix, "app/controller")
public_dir = os.environ.get("PUBLIC_DIR") or os.path.join(os.path.dirname(__file__), "app", "public")
if not os.path.isdir(public_dir):
try:
os.makedirs(public_dir, exist_ok=True)
logger.warning(f"Public directory did not exist. Created: {public_dir}")
except Exception as e:
logger.error(f"Public directory missing and could not be created: {public_dir}. Error: {e}")
public_dir = None
if public_dir and os.path.isdir(public_dir):
api.mount("/public", StaticFiles(directory=public_dir), name="public")
else:
logger.warning("Skipping /public mount because public directory is unavailable")

View file

@ -1,40 +1,41 @@
[project]
name = "Eigent"
version = "0.1.0"
description = "Eigent"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"alembic>=1.15.2",
"click>=8.1.8",
"fastapi>=0.115.12",
"fastapi-babel>=1.0.0",
"fastapi-pagination>=0.12.34",
"passlib[bcrypt]>=1.7.4",
"bcrypt==4.0.1",
"pydantic-i18n>=0.4.5",
"pydantic[email]>=2.11.1",
"pyjwt>=2.10.1",
"python-dotenv>=1.1.0",
"sqlalchemy-utils>=0.41.2",
"sqlmodel>=0.0.24",
"pandas>=2.2.3",
"openpyxl>=3.1.5",
"pandas>=2.2.3",
"arrow>=1.3.0",
"fastapi-filter>=2.0.1",
"psycopg2-binary>=2.9.10",
"convert-case>=1.2.3",
"python-multipart>=0.0.20",
"loguru>=0.7.3",
"httpx>=0.28.1",
"pydash>=8.0.5",
"requests>=2.32.4",
"itsdangerous>=2.2.0",
"cryptography>=45.0.4",
"sqids>=0.5.2",
"exa-py>=1.14.16",
]
[tool.ruff]
line-length = 120
[project]
name = "Eigent"
version = "0.1.0"
description = "Eigent"
readme = "README.md"
requires-python = ">=3.12,<3.13"
dependencies = [
"alembic>=1.15.2",
"openai>=1.99.3,<2",
"camel-ai==0.2.76a13",
"pydantic[email]>=2.11.1",
"click>=8.1.8",
"fastapi>=0.115.12",
"fastapi-babel>=1.0.0",
"fastapi-pagination>=0.12.34",
"passlib[bcrypt]>=1.7.4",
"bcrypt==4.0.1",
"pydantic-i18n>=0.4.5",
"pyjwt>=2.10.1",
"python-dotenv>=1.1.0",
"sqlalchemy-utils>=0.41.2",
"sqlmodel>=0.0.24",
"pandas>=2.2.3",
"openpyxl>=3.1.5",
"arrow>=1.3.0",
"fastapi-filter>=2.0.1",
"psycopg2-binary>=2.9.10",
"convert-case>=1.2.3",
"python-multipart>=0.0.20",
"httpx>=0.28.1",
"pydash>=8.0.5",
"requests>=2.32.4",
"itsdangerous>=2.2.0",
"cryptography>=45.0.4",
"sqids>=0.5.2",
"exa-py>=1.14.16",
"traceroot>=0.0.7",
]
[tool.ruff]
line-length = 120

1353
server/uv.lock generated

File diff suppressed because it is too large Load diff

View file

@ -170,7 +170,6 @@ async function proxyFetchRequest(
...customHeaders,
}
console.debug('url', url, token)
if (!url.includes('http://') && !url.includes('https://') && token) {
headers['Authorization'] = `Bearer ${token}`
}

Some files were not shown because too many files have changed in this diff Show more