Mirror of https://github.com/eigent-ai/eigent.git, synced 2026-05-14 16:42:47 +00:00
Commit cd30c7d840: Merge branch 'main' into fix/markdown-ordered-list-numbering
277 changed files with 27748 additions and 8413 deletions
@@ -5,4 +5,26 @@ VITE_PROXY_URL=https://dev.eigent.ai
VITE_USE_LOCAL_PROXY=false
# VITE_PROXY_URL=http://localhost:3001
# VITE_USE_LOCAL_PROXY=true

TRACEROOT_TOKEN=your_traceroot_token_here
TRACEROOT_SERVICE_NAME=eigent
TRACEROOT_GITHUB_OWNER=eigent
TRACEROOT_GITHUB_REPO_NAME=eigent-ai
TRACEROOT_GITHUB_COMMIT_HASH=main
TRACEROOT_ENABLE_SPAN_CLOUD_EXPORT=true
TRACEROOT_ENABLE_LOG_CLOUD_EXPORT=true
TRACEROOT_ENABLE_SPAN_CONSOLE_EXPORT=false
TRACEROOT_ENABLE_LOG_CONSOLE_EXPORT=true
TRACEROOT_TRACER_VERBOSE=false
TRACEROOT_LOGGER_VERBOSE=false
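For context, a minimal sketch of how boolean flags like TRACEROOT_ENABLE_SPAN_CLOUD_EXPORT might be read from the environment; the helper below is illustrative only and not part of the repository:

    import os

    def env_flag(name: str, default: bool = False) -> bool:
        # Treat "1", "true", "yes" (any case) as enabled; anything else as disabled.
        return os.getenv(name, str(default)).strip().lower() in ("1", "true", "yes")

    enable_span_export = env_flag("TRACEROOT_ENABLE_SPAN_CLOUD_EXPORT")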
.gitignore (vendored, 8 lines changed)
@@ -46,3 +46,11 @@ public/

# Testing
coverage/
.traceroot-config.yaml

# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
.vscode/extensions.json (vendored, 4 lines changed)
@@ -2,6 +2,8 @@
  // See http://go.microsoft.com/fwlink/?LinkId=827846
  // for the documentation about the extensions.json format
  "recommendations": [
    "mrmlnc.vscode-json5"
    "mrmlnc.vscode-json5",
    "ms-python.python",
    "ms-python.debugpy"
  ]
}
.vscode/launch.json (vendored, 16 lines changed)
@@ -50,5 +50,21 @@
        "http://127.0.0.1:7777/**"
      ]
    },
    {
      "name": "Debug Python Backend (Attach)",
      "type": "debugpy",
      "request": "attach",
      "connect": {
        "host": "localhost",
        "port": 5678
      },
      "pathMappings": [
        {
          "localRoot": "${workspaceFolder}/backend",
          "remoteRoot": "."
        }
      ],
      "justMyCode": false
    }
  ]
}
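The attach configuration above expects a debugpy server listening on port 5678 inside the backend process; a minimal sketch of starting one from Python (assuming debugpy is installed, as the new ms-python.debugpy recommendation suggests):

    import debugpy

    # Match the "Debug Python Backend (Attach)" configuration above.
    debugpy.listen(("localhost", 5678))
    debugpy.wait_for_client()  # optional: block until VS Code attaches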
@@ -1,4 +1,4 @@
from app.utils import traceroot_wrapper as traceroot
from utils import traceroot_wrapper as traceroot
import importlib.util
import os
from pathlib import Path
@@ -3,13 +3,12 @@ import os
import re
from pathlib import Path
from dotenv import load_dotenv
from fastapi import APIRouter, Request, Response
from fastapi import APIRouter, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from loguru import logger
from app.utils import traceroot_wrapper as traceroot
from utils import traceroot_wrapper as traceroot
from app.component import code
from app.exception.exception import UserException
from app.model.chat import Chat, HumanReply, McpServers, Status, SupplementChat
from app.model.chat import Chat, HumanReply, McpServers, Status, SupplementChat, AddTaskRequest
from app.service.chat_service import step_solve
from app.service.task import (
    Action,

@@ -17,13 +16,18 @@ from app.service.task import (
    ActionInstallMcpData,
    ActionStopData,
    ActionSupplementData,
    create_task_lock,
    ActionAddTaskData,
    ActionRemoveTaskData,
    ActionSkipTaskData,
    get_or_create_task_lock,
    get_task_lock,
)
from app.component.environment import set_user_env_path
from app.utils.workforce import Workforce
from camel.tasks.task import Task

router = APIRouter(tags=["chat"])
router = APIRouter()

# Create traceroot logger for chat controller
chat_logger = traceroot.get_logger('chat_controller')
@@ -32,55 +36,110 @@ chat_logger = traceroot.get_logger('chat_controller')
@router.post("/chat", name="start chat")
@traceroot.trace()
async def post(data: Chat, request: Request):
    chat_logger.info(f"Starting new chat session for task_id: {data.task_id}, user: {data.email}")
    task_lock = create_task_lock(data.task_id)

    chat_logger.info("Starting new chat session", extra={"project_id": data.project_id, "task_id": data.task_id, "user": data.email})
    task_lock = get_or_create_task_lock(data.project_id)

    # Set user-specific environment path for this thread
    set_user_env_path(data.env_path)
    load_dotenv(dotenv_path=data.env_path)

    # logger.debug(f"start chat: {data.model_dump_json()}")

    os.environ["file_save_path"] = data.file_save_path()
    os.environ["browser_port"] = str(data.browser_port)
    os.environ["OPENAI_API_KEY"] = data.api_key
    os.environ["OPENAI_API_BASE_URL"] = data.api_url or "https://api.openai.com/v1"
    os.environ["CAMEL_MODEL_LOG_ENABLED"] = "true"

    email = re.sub(r'[\\/*?:"<>|\s]', "_", data.email.split("@")[0]).strip(".")
    camel_log = Path.home() / ".eigent" / email / ("task_" + data.task_id) / "camel_logs"
    # Set user-specific search engine configuration if provided
    if data.search_config:
        for key, value in data.search_config.items():
            if value:  # Only set non-empty values
                os.environ[key] = value
                chat_logger.info(f"Set search config: {key}", extra={"project_id": data.project_id})

    email_sanitized = re.sub(r'[\\/*?:"<>|\s]', "_", data.email.split("@")[0]).strip(".")
    camel_log = Path.home() / ".eigent" / email_sanitized / ("project_" + data.project_id) / ("task_" + data.task_id) / "camel_logs"
    camel_log.mkdir(parents=True, exist_ok=True)

    os.environ["CAMEL_LOG_DIR"] = str(camel_log)

    if data.is_cloud():
        os.environ["cloud_api_key"] = data.api_key

    chat_logger.info(f"Chat session initialized, starting streaming response for task_id: {data.task_id}")

    # Put initial action in queue to start processing
    await task_lock.put_queue(ActionImproveData(data=data.question))

    chat_logger.info("Chat session initialized, starting streaming response", extra={"project_id": data.project_id, "task_id": data.task_id, "log_dir": str(camel_log)})
    return StreamingResponse(step_solve(data, request, task_lock), media_type="text/event-stream")


@router.post("/chat/{id}", name="improve chat")
@traceroot.trace()
def improve(id: str, data: SupplementChat):
    chat_logger.info(f"Improving chat for task_id: {id} with question: {data.question}")
    chat_logger.info("Chat improvement requested", extra={"task_id": id, "question_length": len(data.question)})
    task_lock = get_task_lock(id)

    # Allow continuing conversation even after task is done
    # This supports multi-turn conversation after complex task completion
    if task_lock.status == Status.done:
        raise UserException(code.error, "Task was done")
        # Reset status to allow processing new messages
        task_lock.status = Status.confirming
        # Clear any existing background tasks since workforce was stopped
        if hasattr(task_lock, 'background_tasks'):
            task_lock.background_tasks.clear()
        # Note: conversation_history and last_task_result are preserved

    # Log context preservation
    if hasattr(task_lock, 'conversation_history'):
        chat_logger.info(f"[CONTEXT] Preserved {len(task_lock.conversation_history)} conversation entries")
    if hasattr(task_lock, 'last_task_result'):
        chat_logger.info(f"[CONTEXT] Preserved task result: {len(task_lock.last_task_result)} chars")

    # Update file save path if task_id is provided
    new_folder_path = None
    if data.task_id:
        try:
            # Get current environment values needed to construct new path
            current_email = None

            # Extract email from current file_save_path if available
            current_file_save_path = os.environ.get("file_save_path", "")
            if current_file_save_path:
                path_parts = Path(current_file_save_path).parts
                if len(path_parts) >= 3 and "eigent" in path_parts:
                    eigent_index = path_parts.index("eigent")
                    if eigent_index + 1 < len(path_parts):
                        current_email = path_parts[eigent_index + 1]

            # If we have the necessary information, update the file_save_path
            if current_email and id:
                # Create new path using the existing pattern: email/project_{project_id}/task_{task_id}
                new_folder_path = Path.home() / "eigent" / current_email / f"project_{id}" / f"task_{data.task_id}"
                new_folder_path.mkdir(parents=True, exist_ok=True)
                os.environ["file_save_path"] = str(new_folder_path)
                chat_logger.info(f"Updated file_save_path to: {new_folder_path}")

                # Store the new folder path in task_lock for potential cleanup and persistence
                task_lock.new_folder_path = new_folder_path
            else:
                chat_logger.warning(f"Could not update file_save_path - email: {current_email}, project_id: {id}")

        except Exception as e:
            chat_logger.error(f"Error updating file path for project_id: {id}, task_id: {data.task_id}: {e}")

    asyncio.run(task_lock.put_queue(ActionImproveData(data=data.question)))
    chat_logger.info(f"Improvement request queued for task_id: {id}")
    chat_logger.info("Improvement request queued with preserved context", extra={"project_id": id})
    return Response(status_code=201)


@router.put("/chat/{id}", name="supplement task")
@traceroot.trace()
def supplement(id: str, data: SupplementChat):
    chat_logger.info(f"Supplementing task_id: {id} with additional data")
    chat_logger.info("Chat supplement requested", extra={"task_id": id})
    task_lock = get_task_lock(id)
    if task_lock.status != Status.done:
        raise UserException(code.error, "Please wait task done")
    asyncio.run(task_lock.put_queue(ActionSupplementData(data=data)))
    chat_logger.info(f"Supplement data queued for task_id: {id}")
    chat_logger.debug("Supplement data queued", extra={"task_id": id})
    return Response(status_code=201)
@@ -88,28 +147,92 @@ def supplement(id: str, data: SupplementChat):
@traceroot.trace()
def stop(id: str):
    """stop the task"""
    chat_logger.warning(f"Stopping chat session for task_id: {id}")
    chat_logger.warning("Stopping chat session", extra={"task_id": id})
    task_lock = get_task_lock(id)
    asyncio.run(task_lock.put_queue(ActionStopData(action=Action.stop)))
    chat_logger.info(f"Stop signal sent for task_id: {id}")
    chat_logger.info("Chat stop signal sent", extra={"task_id": id})
    return Response(status_code=204)


@router.post("/chat/{id}/human-reply")
@traceroot.trace()
def human_reply(id: str, data: HumanReply):
    chat_logger.info(f"Human reply received for task_id: {id}, agent: {data.agent}")
    chat_logger.info("Human reply received", extra={"task_id": id, "reply_length": len(data.reply)})
    task_lock = get_task_lock(id)
    asyncio.run(task_lock.put_human_input(data.agent, data.reply))
    chat_logger.info(f"Human reply processed for task_id: {id}")
    chat_logger.debug("Human reply processed", extra={"task_id": id})
    return Response(status_code=201)


@router.post("/chat/{id}/install-mcp")
@traceroot.trace()
def install_mcp(id: str, data: McpServers):
    chat_logger.info(f"Installing MCP servers for task_id: {id}, servers count: {len(data.get('mcpServers', {}))}")
    chat_logger.info("Installing MCP servers", extra={"task_id": id, "servers_count": len(data.get('mcpServers', {}))})
    task_lock = get_task_lock(id)
    asyncio.run(task_lock.put_queue(ActionInstallMcpData(action=Action.install_mcp, data=data)))
    chat_logger.info(f"MCP installation queued for task_id: {id}")
    chat_logger.info("MCP installation queued", extra={"task_id": id})
    return Response(status_code=201)


@router.post("/chat/{id}/add-task", name="add task to workforce")
@traceroot.trace()
def add_task(id: str, data: AddTaskRequest):
    """Add a new task to the workforce"""
    chat_logger.info(f"Adding task to workforce for task_id: {id}, content: {data.content[:100]}...")
    task_lock = get_task_lock(id)

    try:
        # Queue the add task action
        add_task_action = ActionAddTaskData(
            content=data.content,
            project_id=data.project_id,
            task_id=data.task_id,
            additional_info=data.additional_info,
            insert_position=data.insert_position
        )
        asyncio.run(task_lock.put_queue(add_task_action))
        return Response(status_code=201)

    except Exception as e:
        chat_logger.error(f"Error adding task for task_id: {id}: {e}")
        raise UserException(code.error, f"Failed to add task: {str(e)}")


@router.delete("/chat/{project_id}/remove-task/{task_id}", name="remove task from workforce")
@traceroot.trace()
def remove_task(project_id: str, task_id: str):
    """Remove a task from the workforce"""
    chat_logger.info(f"Removing task {task_id} from workforce for project_id: {project_id}")
    task_lock = get_task_lock(project_id)

    try:
        # Queue the remove task action
        remove_task_action = ActionRemoveTaskData(task_id=task_id, project_id=project_id)
        asyncio.run(task_lock.put_queue(remove_task_action))

        chat_logger.info(f"Task removal request queued for project_id: {project_id}, removing task: {task_id}")
        return Response(status_code=204)

    except Exception as e:
        chat_logger.error(f"Error removing task {task_id} for project_id: {project_id}: {e}")
        raise UserException(code.error, f"Failed to remove task: {str(e)}")


@router.post("/chat/{project_id}/skip-task", name="skip task in workforce")
@traceroot.trace()
def skip_task(project_id: str):
    """Skip a task in the workforce"""
    chat_logger.info(f"Skipping task in workforce for project_id: {project_id}")
    task_lock = get_task_lock(project_id)

    try:
        # Queue the skip task action
        skip_task_action = ActionSkipTaskData(project_id=project_id)
        asyncio.run(task_lock.put_queue(skip_task_action))

        chat_logger.info(f"Task skip request queued for project_id: {project_id}")
        return Response(status_code=201)

    except Exception as e:
        chat_logger.error(f"Error skipping task for project_id: {project_id}: {e}")
        raise UserException(code.error, f"Failed to skip task: {str(e)}")
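For illustration, a hedged sketch of calling the new add-task route from a client. The host and port are assumptions carried over from the 127.0.0.1:7777 URL in launch.json, the ids are placeholders, and AddTaskRequest's exact schema lives in app.model.chat:

    import requests

    project_id = "demo-project"  # hypothetical id
    resp = requests.post(
        f"http://127.0.0.1:7777/chat/{project_id}/add-task",
        json={
            "content": "Summarize the quarterly report",
            "project_id": project_id,
            "task_id": "demo-task",          # field names mirror ActionAddTaskData above
            "additional_info": None,
            "insert_position": 0,            # assumed to be an index; check AddTaskRequest
        },
    )
    print(resp.status_code)  # 201 on success, per the handler above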
backend/app/controller/health_controller.py (new file, 16 lines)
@@ -0,0 +1,16 @@
from fastapi import APIRouter
from pydantic import BaseModel

router = APIRouter(tags=["Health"])


class HealthResponse(BaseModel):
    status: str
    service: str


@router.get("/health", name="health check", response_model=HealthResponse)
async def health_check():
    """Health check endpoint for verifying backend is ready to accept requests."""
    return HealthResponse(status="ok", service="eigent")
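A quick sketch of exercising the new endpoint with FastAPI's TestClient; the import path matches the new file above, and mounting the router on a bare app is just for the test:

    from fastapi import FastAPI
    from fastapi.testclient import TestClient
    from app.controller.health_controller import router

    app = FastAPI()
    app.include_router(router)
    client = TestClient(app)

    resp = client.get("/health")
    assert resp.status_code == 200
    assert resp.json() == {"status": "ok", "service": "eigent"}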
@@ -1,11 +1,14 @@
from fastapi import APIRouter
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from app.component.model_validation import create_agent
from camel.types import ModelType
from app.component.error_format import normalize_error_to_openai_format
from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("model_controller")


router = APIRouter(tags=["model"])
router = APIRouter()


class ValidateModelRequest(BaseModel):

@@ -26,33 +29,46 @@ class ValidateModelResponse(BaseModel):


@router.post("/model/validate")
@traceroot.trace()
async def validate_model(request: ValidateModelRequest):
    try:
        # API key validation
        if request.api_key is not None and str(request.api_key).strip() == "":
            return ValidateModelResponse(
                is_valid=False,
                is_tool_calls=False,
                message="Invalid key. Validation failed.",
                error_code="invalid_api_key",
                error={
                    "message": "Invalid key. Validation failed.",
    """Validate model configuration and tool call support."""
    platform = request.model_platform
    model_type = request.model_type
    has_custom_url = request.url is not None
    has_config = request.model_config_dict is not None

    logger.info("Model validation started", extra={"platform": platform, "model_type": model_type, "has_url": has_custom_url, "has_config": has_config})

    # API key validation
    if request.api_key is not None and str(request.api_key).strip() == "":
        logger.warning("Model validation failed: empty API key", extra={"platform": platform, "model_type": model_type})
        raise HTTPException(
            status_code=400,
            detail={
                "message": "Invalid key. Validation failed.",
                "error_code": "invalid_api_key",
                "error": {
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_api_key",
                },
            )
            }
        )

    try:
        extra = request.extra_params or {}

        logger.debug("Creating agent for validation", extra={"platform": platform, "model_type": model_type})
        agent = create_agent(
            request.model_platform,
            request.model_type,
            platform,
            model_type,
            api_key=request.api_key,
            url=request.url,
            model_config_dict=request.model_config_dict,
            **extra,
        )

        logger.debug("Agent created, executing test step", extra={"platform": platform, "model_type": model_type})
        response = agent.step(
            input_message="""
            Get the content of https://www.camel-ai.org,

@@ -61,17 +77,23 @@ async def validate_model(request: ValidateModelRequest):
            you must call the get_website_content tool only once.
            """
        )

    except Exception as e:
        # Normalize error to OpenAI-style error structure
        logger.error("Model validation failed", extra={"platform": platform, "model_type": model_type, "error": str(e)}, exc_info=True)
        message, error_code, error_obj = normalize_error_to_openai_format(e)

        return ValidateModelResponse(
            is_valid=False,
            is_tool_calls=False,
            message=message,
            error_code=error_code,
            error=error_obj,
        raise HTTPException(
            status_code=400,
            detail={
                "message": message,
                "error_code": error_code,
                "error": error_obj,
            }
        )

    # Check validation results
    is_valid = bool(response)
    is_tool_calls = False

@@ -83,7 +105,7 @@ async def validate_model(request: ValidateModelRequest):
        == "Tool execution completed successfully for https://www.camel-ai.org, Website Content: Welcome to CAMEL AI!"
    )

    return ValidateModelResponse(
    result = ValidateModelResponse(
        is_valid=is_valid,
        is_tool_calls=is_tool_calls,
        message="Validation Success"

@@ -92,3 +114,7 @@ async def validate_model(request: ValidateModelRequest):
        error_code=None,
        error=None,
    )

    logger.info("Model validation completed", extra={"platform": platform, "model_type": model_type, "is_valid": is_valid, "is_tool_calls": is_tool_calls})

    return result
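Since validation failures now surface as HTTP 400 via HTTPException instead of a 200 ValidateModelResponse body, clients need to branch on the status code. A sketch, with the base URL assumed from launch.json and the model values used only as examples:

    import requests

    resp = requests.post(
        "http://127.0.0.1:7777/model/validate",
        json={"model_platform": "openai", "model_type": "gpt-4o-mini", "api_key": "sk-..."},
    )
    if resp.status_code == 400:
        detail = resp.json()["detail"]  # OpenAI-style error payload from the handler above
        print(detail["error_code"], detail["message"])
    else:
        print(resp.json())  # ValidateModelResponse with is_valid / is_tool_calls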
@@ -1,7 +1,6 @@
from typing import Literal
from dotenv import load_dotenv
from fastapi import APIRouter, Response
from loguru import logger
from pydantic import BaseModel
from app.model.chat import NewAgent, UpdateData
from app.service.task import (

@@ -16,24 +15,32 @@ from app.service.task import (
)
import asyncio
from app.component.environment import set_user_env_path
from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("task_controller")


router = APIRouter(tags=["task"])
router = APIRouter()


@router.post("/task/{id}/start", name="start task")
@traceroot.trace()
def start(id: str):
    task_lock = get_task_lock(id)
    logger.debug(f"start task {id}")
    logger.info("Starting task", extra={"task_id": id})
    asyncio.run(task_lock.put_queue(ActionStartData(action=Action.start)))
    logger.debug(f"start task {id} success")
    logger.info("Task started successfully", extra={"task_id": id})
    return Response(status_code=201)


@router.put("/task/{id}", name="update task")
@traceroot.trace()
def put(id: str, data: UpdateData):
    logger.info("Updating task", extra={"task_id": id, "task_items_count": len(data.task)})
    logger.debug("Update task data", extra={"task_id": id, "data": data.model_dump_json()})
    task_lock = get_task_lock(id)
    asyncio.run(task_lock.put_queue(ActionUpdateTaskData(action=Action.update_task, data=data)))
    logger.info("Task updated successfully", extra={"task_id": id})
    return Response(status_code=201)

@@ -42,23 +49,33 @@ class TakeControl(BaseModel):


@router.put("/task/{id}/take-control", name="take control pause or resume")
@traceroot.trace()
def take_control(id: str, data: TakeControl):
    logger.info("Task control action", extra={"task_id": id, "action": data.action})
    task_lock = get_task_lock(id)
    asyncio.run(task_lock.put_queue(ActionTakeControl(action=data.action)))
    logger.info("Task control action completed", extra={"task_id": id, "action": data.action})
    return Response(status_code=204)


@router.post("/task/{id}/add-agent", name="add new agent")
@traceroot.trace()
def add_agent(id: str, data: NewAgent):
    logger.info("Adding new agent to task", extra={"task_id": id, "agent_name": data.name})
    logger.debug("New agent data", extra={"task_id": id, "agent_data": data.model_dump_json()})
    # Set user-specific environment path for this thread
    set_user_env_path(data.env_path)
    load_dotenv(dotenv_path=data.env_path)
    asyncio.run(get_task_lock(id).put_queue(ActionNewAgent(**data.model_dump())))
    logger.info("Agent added to task", extra={"task_id": id, "agent_name": data.name})
    return Response(status_code=204)


@router.delete("/task/stop-all", name="stop all tasks")
@traceroot.trace()
def stop_all():
    logger.warning("Stopping all tasks", extra={"task_count": len(task_locks)})
    for task_lock in task_locks.values():
        asyncio.run(task_lock.put_queue(ActionStopData()))
    logger.info("All tasks stopped", extra={"task_count": len(task_locks)})
    return Response(status_code=204)
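For reference, the take-control route accepts a pause/resume action in its request body; a hedged client sketch, where the host/port and the "pause" value are assumptions (the TakeControl model's Literal options are not shown in this hunk):

    import requests

    task_id = "demo-task"  # hypothetical id
    requests.put(
        f"http://127.0.0.1:7777/task/{task_id}/take-control",
        json={"action": "pause"},  # assumed value; see the TakeControl model
    )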
@@ -1,10 +1,17 @@
from fastapi import APIRouter, HTTPException
from loguru import logger
from app.utils.toolkit.notion_mcp_toolkit import NotionMCPToolkit
from app.utils.toolkit.google_calendar_toolkit import GoogleCalendarToolkit
from app.utils.oauth_state_manager import oauth_state_manager
from utils import traceroot_wrapper as traceroot
from camel.toolkits.hybrid_browser_toolkit.hybrid_browser_toolkit_ts import (
    HybridBrowserToolkit as BaseHybridBrowserToolkit,
)
from app.utils.cookie_manager import CookieManager
import os
import uuid


router = APIRouter(tags=["task"])
logger = traceroot.get_logger("tool_controller")
router = APIRouter()


@router.post("/install/tool/{tool}", name="install tool")
@@ -28,8 +35,10 @@ async def install_tool(tool: str):
            await toolkit.connect()

            # Get available tools to verify connection
            tools = [tool_func.func.__name__ for tool_func in toolkit.get_tools()]
            logger.info(f"Successfully pre-instantiated {tool} toolkit with {len(tools)} tools")
            tools = [tool_func.func.__name__ for tool_func in
                     toolkit.get_tools()]
            logger.info(
                f"Successfully pre-instantiated {tool} toolkit with {len(tools)} tools")

            # Disconnect, authentication info is saved
            await toolkit.disconnect()

@@ -42,7 +51,8 @@ async def install_tool(tool: str):
                "toolkit_name": "NotionMCPToolkit"
            }
        except Exception as connect_error:
            logger.warning(f"Could not connect to {tool} MCP server: {connect_error}")
            logger.warning(
                f"Could not connect to {tool} MCP server: {connect_error}")
            # Even if connection fails, mark as installed so user can use it later
            return {
                "success": True,
@@ -60,20 +70,34 @@ async def install_tool(tool: str):
            )
    elif tool == "google_calendar":
        try:
            # Use a dummy task_id for installation, as this is just for pre-authentication
            toolkit = GoogleCalendarToolkit("install_auth")
            # Try to initialize toolkit - will succeed if credentials exist
            try:
                toolkit = GoogleCalendarToolkit("install_auth")
                tools = [tool_func.func.__name__ for tool_func in toolkit.get_tools()]
                logger.info(f"Successfully initialized Google Calendar toolkit with {len(tools)} tools")

            # Get available tools to verify connection
            tools = [tool_func.func.__name__ for tool_func in toolkit.get_tools()]
            logger.info(f"Successfully pre-instantiated {tool} toolkit with {len(tools)} tools")
                return {
                    "success": True,
                    "tools": tools,
                    "message": f"Successfully installed {tool} toolkit",
                    "count": len(tools),
                    "toolkit_name": "GoogleCalendarToolkit"
                }
            except ValueError as auth_error:
                # No credentials - need authorization
                logger.info(f"No credentials found, starting authorization: {auth_error}")

            return {
                "success": True,
                "tools": tools,
                "message": f"Successfully installed {tool} toolkit",
                "count": len(tools),
                "toolkit_name": "GoogleCalendarToolkit"
            }
                # Start background authorization in a new thread
                logger.info("Starting background Google Calendar authorization")
                GoogleCalendarToolkit.start_background_auth("install_auth")

                return {
                    "success": False,
                    "status": "authorizing",
                    "message": "Authorization required. Browser should open automatically. Complete authorization and try installing again.",
                    "toolkit_name": "GoogleCalendarToolkit",
                    "requires_auth": True
                }
        except Exception as e:
            logger.error(f"Failed to install {tool} toolkit: {e}")
            raise HTTPException(
@@ -113,3 +137,880 @@ async def list_available_tools():
            }
        ]
    }


@router.get("/oauth/status/{provider}", name="get oauth status")
async def get_oauth_status(provider: str):
    """
    Get the current OAuth authorization status for a provider

    Args:
        provider: OAuth provider name (e.g., 'google_calendar')

    Returns:
        Current authorization status
    """
    state = oauth_state_manager.get_state(provider)

    if not state:
        return {
            "provider": provider,
            "status": "not_started",
            "message": "No authorization in progress"
        }

    return state.to_dict()


@router.post("/oauth/cancel/{provider}", name="cancel oauth")
async def cancel_oauth(provider: str):
    """
    Cancel an ongoing OAuth authorization flow

    Args:
        provider: OAuth provider name (e.g., 'google_calendar')

    Returns:
        Cancellation result
    """
    state = oauth_state_manager.get_state(provider)

    if not state:
        raise HTTPException(
            status_code=404,
            detail=f"No authorization found for provider '{provider}'"
        )

    if state.status not in ["pending", "authorizing"]:
        raise HTTPException(
            status_code=400,
            detail=f"Cannot cancel authorization with status '{state.status}'"
        )

    state.cancel()
    logger.info(f"Cancelled OAuth authorization for {provider}")

    return {
        "success": True,
        "provider": provider,
        "message": "Authorization cancelled successfully"
    }
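A client can poll the status route above until the flow leaves the pending/authorizing states; a sketch with the host and port assumed as elsewhere in this commit:

    import time
    import requests

    while True:
        state = requests.get(
            "http://127.0.0.1:7777/oauth/status/google_calendar"
        ).json()
        if state.get("status") not in ("pending", "authorizing"):
            break  # completed, failed, cancelled, or not_started
        time.sleep(2)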
@router.delete("/uninstall/tool/{tool}", name="uninstall tool")
|
||||
async def uninstall_tool(tool: str):
|
||||
"""
|
||||
Uninstall a tool and clean up its authentication data
|
||||
|
||||
Args:
|
||||
tool: Tool name to uninstall (notion, google_calendar)
|
||||
|
||||
Returns:
|
||||
Uninstallation result
|
||||
"""
|
||||
import os
|
||||
import shutil
|
||||
|
||||
if tool == "notion":
|
||||
try:
|
||||
import hashlib
|
||||
import glob
|
||||
|
||||
# Calculate the hash for Notion MCP URL
|
||||
# mcp-remote uses MD5 hash of the URL to generate file names
|
||||
notion_url = "https://mcp.notion.com/mcp"
|
||||
url_hash = hashlib.md5(notion_url.encode()).hexdigest()
|
||||
|
||||
# Find and remove Notion-specific auth files
|
||||
mcp_auth_dir = os.path.join(os.path.expanduser("~"), ".mcp-auth")
|
||||
deleted_files = []
|
||||
|
||||
if os.path.exists(mcp_auth_dir):
|
||||
# Look for all files with the Notion hash prefix
|
||||
for version_dir in os.listdir(mcp_auth_dir):
|
||||
version_path = os.path.join(mcp_auth_dir, version_dir)
|
||||
if os.path.isdir(version_path):
|
||||
# Find all files matching the hash pattern
|
||||
pattern = os.path.join(version_path, f"{url_hash}_*")
|
||||
notion_files = glob.glob(pattern)
|
||||
|
||||
for file_path in notion_files:
|
||||
try:
|
||||
os.remove(file_path)
|
||||
deleted_files.append(file_path)
|
||||
logger.info(f"Removed Notion auth file: {file_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to remove {file_path}: {e}")
|
||||
|
||||
message = f"Successfully uninstalled {tool}"
|
||||
if deleted_files:
|
||||
message += f" and cleaned up {len(deleted_files)} authentication file(s)"
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": message,
|
||||
"deleted_files": deleted_files
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to uninstall {tool}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to uninstall {tool}: {str(e)}"
|
||||
)
|
||||
|
||||
elif tool == "google_calendar":
|
||||
try:
|
||||
# Clean up Google Calendar token directory
|
||||
token_dir = os.path.join(os.path.expanduser("~"), ".eigent", "tokens", "google_calendar")
|
||||
if os.path.exists(token_dir):
|
||||
shutil.rmtree(token_dir)
|
||||
logger.info(f"Removed Google Calendar token directory: {token_dir}")
|
||||
|
||||
# Clear OAuth state manager cache (this is the key fix!)
|
||||
# This removes the cached credentials from memory
|
||||
state = oauth_state_manager.get_state("google_calendar")
|
||||
if state:
|
||||
if state.status in ["pending", "authorizing"]:
|
||||
state.cancel()
|
||||
logger.info("Cancelled ongoing Google Calendar authorization")
|
||||
# Clear the state completely to remove cached credentials
|
||||
oauth_state_manager._states.pop("google_calendar", None)
|
||||
logger.info("Cleared Google Calendar OAuth state cache")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Successfully uninstalled {tool} and cleaned up authentication tokens"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to uninstall {tool}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to uninstall {tool}: {str(e)}"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Tool '{tool}' not found. Available tools: ['notion', 'google_calendar']"
|
||||
)
|
||||
|
||||
|
||||
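The Notion cleanup above keys auth files on the MD5 of the MCP URL; a quick standalone check of the prefix it globs for:

    import hashlib

    url_hash = hashlib.md5("https://mcp.notion.com/mcp".encode()).hexdigest()
    print(f"{url_hash}_*")  # the glob pattern used inside ~/.mcp-auth/<version>/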
@router.post("/browser/login", name="open browser for login")
|
||||
async def open_browser_login():
|
||||
"""
|
||||
Open an Electron-based Chrome browser for user login with a dedicated user data directory
|
||||
|
||||
Returns:
|
||||
Browser session information
|
||||
"""
|
||||
try:
|
||||
import subprocess
|
||||
import platform
|
||||
import socket
|
||||
import json
|
||||
|
||||
# Use fixed profile name for persistent logins (no port suffix)
|
||||
session_id = "user_login"
|
||||
cdp_port = 9223
|
||||
|
||||
# IMPORTANT: Use dedicated profile for tool_controller browser
|
||||
# This is the SOURCE OF TRUTH for login data
|
||||
# On Eigent startup, this data will be copied to WebView partition (one-way sync)
|
||||
browser_profiles_base = os.path.expanduser("~/.eigent/browser_profiles")
|
||||
user_data_dir = os.path.join(browser_profiles_base, "profile_user_login")
|
||||
|
||||
os.makedirs(user_data_dir, exist_ok=True)
|
||||
|
||||
logger.info(
|
||||
f"Creating browser session {session_id} with profile at: {user_data_dir}")
|
||||
|
||||
# Check if browser is already running on this port
|
||||
def is_port_in_use(port):
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
return s.connect_ex(('localhost', port)) == 0
|
||||
|
||||
if is_port_in_use(cdp_port):
|
||||
logger.info(f"Browser already running on port {cdp_port}")
|
||||
return {
|
||||
"success": True,
|
||||
"session_id": session_id,
|
||||
"user_data_dir": user_data_dir,
|
||||
"cdp_port": cdp_port,
|
||||
"message": "Browser already running. Use existing window to log in.",
|
||||
"note": "Your login data will be saved in the profile."
|
||||
}
|
||||
|
||||
# Create Electron browser script with .cjs extension for CommonJS
|
||||
electron_script_path = os.path.join(os.path.dirname(__file__), "electron_browser.cjs")
|
||||
electron_script_content = '''
|
||||
const { app, BrowserWindow, ipcMain } = require('electron');
const path = require('path');

// Parse command line arguments
const args = process.argv.slice(2);
const userDataDir = args[0];
const cdpPort = args[1];
const startUrl = args[2] || 'https://www.google.com';

// This must be called before app.ready
app.commandLine.appendSwitch('remote-debugging-port', cdpPort);

console.log('[ELECTRON BROWSER] Starting with:');
console.log(' Chrome version:', process.versions.chrome);
console.log(' User data dir (requested):', userDataDir);
console.log(' CDP port:', cdpPort);
console.log(' Start URL:', startUrl);

// Set app paths - must be done before app.ready
// Do NOT use commandLine.appendSwitch('user-data-dir') as it conflicts with setPath
app.setPath('userData', userDataDir);
app.setPath('sessionData', userDataDir);

app.whenReady().then(async () => {
  const { session } = require('electron');
  const fs = require('fs');
  const path = require('path');

  // Log actual paths being used
  console.log('[ELECTRON BROWSER] Actual paths:');
  console.log(' app.getPath("userData"):', app.getPath('userData'));
  console.log(' app.getPath("sessionData"):', app.getPath('sessionData'));
  console.log(' app.getPath("cache"):', app.getPath('cache'));
  console.log(' app.getPath("temp"):', app.getPath('temp'));
  console.log(' process.argv:', process.argv);

  // Check command line switches
  console.log('[ELECTRON BROWSER] Command line switches:');
  console.log(' user-data-dir:', app.commandLine.getSwitchValue('user-data-dir'));
  console.log(' remote-debugging-port:', app.commandLine.getSwitchValue('remote-debugging-port'));

  // Log partition session info
  const userLoginSession = session.fromPartition('persist:user_login');
  console.log('[ELECTRON BROWSER] Session info:');
  console.log(' Partition: persist:user_login');
  console.log(' Session storage path:', userLoginSession.getStoragePath());

  // Check if Cookies file exists
  const cookiesPath = path.join(app.getPath('userData'), 'Partitions', 'user_login', 'Cookies');
  console.log('[ELECTRON BROWSER] Cookies path:', cookiesPath);
  console.log('[ELECTRON BROWSER] Cookies exists:', fs.existsSync(cookiesPath));
  if (fs.existsSync(cookiesPath)) {
    const stats = fs.statSync(cookiesPath);
    console.log('[ELECTRON BROWSER] Cookies file size:', stats.size, 'bytes');
  }

  const win = new BrowserWindow({
    width: 1400,
    height: 900,
    title: 'Eigent Browser - Login',
    webPreferences: {
      nodeIntegration: true,
      contextIsolation: false,
      webviewTag: true
    }
  });

  // Create navigation bar and webview
  const html = `
<!DOCTYPE html>
<html>
<head>
<style>
body {
  margin: 0;
  font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Arial, sans-serif;
  display: flex;
  flex-direction: column;
  height: 100vh;
  overflow: hidden;
}

#nav-bar {
  display: flex;
  align-items: center;
  padding: 8px;
  background: #f5f5f5;
  border-bottom: 1px solid #ddd;
  gap: 8px;
}

button {
  padding: 6px 12px;
  border: 1px solid #ccc;
  background: white;
  border-radius: 4px;
  cursor: pointer;
  font-size: 14px;
  display: flex;
  align-items: center;
  gap: 4px;
}

button:hover:not(:disabled) {
  background: #f0f0f0;
}

button:disabled {
  opacity: 0.5;
  cursor: not-allowed;
}

#url-input {
  flex: 1;
  padding: 8px 12px;
  border: 1px solid #ccc;
  border-radius: 4px;
  font-size: 14px;
}

#url-input:focus {
  outline: none;
  border-color: #4285f4;
}

#webview {
  flex: 1;
  width: 100%;
  border: none;
}

.nav-icon {
  font-size: 16px;
}

#loading-indicator {
  width: 20px;
  height: 20px;
  border: 2px solid #f3f3f3;
  border-top: 2px solid #4285f4;
  border-radius: 50%;
  animation: spin 1s linear infinite;
  display: none;
}

@keyframes spin {
  0% { transform: rotate(0deg); }
  100% { transform: rotate(360deg); }
}

.loading #loading-indicator {
  display: block;
}

.loading #reload-btn .nav-icon {
  display: none;
}
</style>
</head>
<body>
<div id="nav-bar">
  <button id="back-btn" title="Back">
    <span class="nav-icon">←</span>
  </button>
  <button id="forward-btn" title="Forward">
    <span class="nav-icon">→</span>
  </button>
  <button id="reload-btn" title="Reload">
    <span class="nav-icon">↻</span>
    <div id="loading-indicator"></div>
  </button>
  <button id="home-btn" title="Home">
    <span class="nav-icon">🏠</span>
  </button>
  <input type="text" id="url-input" placeholder="Enter URL..." />
  <button id="go-btn">Go</button>
  <button id="linkedin-btn" style="background: #0077B5; color: white; border-color: #0077B5;">
    LinkedIn
  </button>
  <button id="info-btn" title="Show Info">ℹ️</button>
</div>

<webview id="webview" src="${startUrl}" partition="persist:user_login"></webview>

<div id="info-panel" style="display: none; position: absolute; top: 50px; right: 10px; background: white; border: 1px solid #ccc; padding: 15px; box-shadow: 0 2px 10px rgba(0,0,0,0.1); z-index: 1000; max-width: 400px; font-size: 12px;">
  <h4 style="margin: 0 0 10px 0;">Browser Info</h4>
  <div id="info-content"></div>
  <button onclick="document.getElementById('info-panel').style.display='none'" style="margin-top: 10px;">Close</button>
</div>

<script>
const webview = document.getElementById('webview');
const backBtn = document.getElementById('back-btn');
const forwardBtn = document.getElementById('forward-btn');
const reloadBtn = document.getElementById('reload-btn');
const homeBtn = document.getElementById('home-btn');
const urlInput = document.getElementById('url-input');
const goBtn = document.getElementById('go-btn');
const linkedinBtn = document.getElementById('linkedin-btn');
const navBar = document.getElementById('nav-bar');
const infoBtn = document.getElementById('info-btn');
const infoPanel = document.getElementById('info-panel');
const infoContent = document.getElementById('info-content');

// Show info panel
infoBtn.addEventListener('click', () => {
  const { ipcRenderer } = require('electron');

  // Get browser info
  const info = {
    'Chrome Version': process.versions.chrome,
    'Electron Version': process.versions.electron,
    'Node Version': process.versions.node,
    'User Data Dir (requested)': '${userDataDir}',
    'CDP Port': '${cdpPort}',
    'Platform': process.platform,
    'Architecture': process.arch
  };

  // Also check webview partition info
  const partition = webview.partition || 'default';
  info['WebView Partition'] = partition;

  // Format info as HTML
  let html = '<table style="width: 100%; border-collapse: collapse;">';
  for (const [key, value] of Object.entries(info)) {
    html += '<tr><td style="padding: 4px; border-bottom: 1px solid #eee;"><strong>' + key + ':</strong></td><td style="padding: 4px; border-bottom: 1px solid #eee; word-break: break-all;">' + value + '</td></tr>';
  }
  html += '</table>';

  infoContent.innerHTML = html;
  infoPanel.style.display = 'block';
});

// Update navigation buttons
function updateNavButtons() {
  backBtn.disabled = !webview.canGoBack();
  forwardBtn.disabled = !webview.canGoForward();
}

// Navigate to URL
function navigateToUrl() {
  let url = urlInput.value.trim();
  if (!url) return;

  if (!url.startsWith('http://') && !url.startsWith('https://')) {
    url = 'https://' + url;
  }

  webview.loadURL(url);
}

// Event listeners
backBtn.addEventListener('click', () => webview.goBack());
forwardBtn.addEventListener('click', () => webview.goForward());
reloadBtn.addEventListener('click', () => webview.reload());
homeBtn.addEventListener('click', () => webview.loadURL('${startUrl}'));
goBtn.addEventListener('click', navigateToUrl);
linkedinBtn.addEventListener('click', () => webview.loadURL('https://www.linkedin.com'));

urlInput.addEventListener('keypress', (e) => {
  if (e.key === 'Enter') {
    navigateToUrl();
  }
});

// WebView events
webview.addEventListener('did-start-loading', () => {
  navBar.classList.add('loading');
});

webview.addEventListener('did-stop-loading', () => {
  navBar.classList.remove('loading');
  updateNavButtons();
});

webview.addEventListener('did-navigate', (e) => {
  urlInput.value = e.url;
  updateNavButtons();
});

webview.addEventListener('did-navigate-in-page', (e) => {
  urlInput.value = e.url;
  updateNavButtons();
});

webview.addEventListener('new-window', (e) => {
  // Open new windows in the same webview
  e.preventDefault();
  webview.loadURL(e.url);
});

// Initialize
updateNavButtons();
</script>
</body>
</html>`;

  win.loadURL('data:text/html;charset=UTF-8,' + encodeURIComponent(html));

  // Show window when ready
  win.once('ready-to-show', () => {
    win.show();

    // Log cookies periodically to track changes
    setInterval(async () => {
      try {
        const cookies = await userLoginSession.cookies.get({});
        console.log('[ELECTRON BROWSER] Current cookies count:', cookies.length);
        if (cookies.length > 0) {
          console.log('[ELECTRON BROWSER] Cookie domains:', [...new Set(cookies.map(c => c.domain))]);
        }
      } catch (error) {
        console.error('[ELECTRON BROWSER] Failed to get cookies:', error);
      }
    }, 5000); // Check every 5 seconds
  });

  win.on('closed', async () => {
    console.log('[ELECTRON BROWSER] Window closed, preparing to quit...');

    // Flush storage data before quitting to ensure cookies are saved
    try {
      const { session } = require('electron');
      const fs = require('fs');
      const path = require('path');
      const userLoginSession = session.fromPartition('persist:user_login');

      // Log cookies before flush
      const cookiesBeforeFlush = await userLoginSession.cookies.get({});
      console.log('[ELECTRON BROWSER] Cookies count before flush:', cookiesBeforeFlush.length);

      // Flush storage
      console.log('[ELECTRON BROWSER] Flushing storage data...');
      await userLoginSession.flushStorageData();
      console.log('[ELECTRON BROWSER] Storage data flushed successfully');

      // Check cookies file after flush
      const cookiesPath = path.join(app.getPath('userData'), 'Partitions', 'user_login', 'Cookies');
      if (fs.existsSync(cookiesPath)) {
        const stats = fs.statSync(cookiesPath);
        console.log('[ELECTRON BROWSER] Cookies file size after flush:', stats.size, 'bytes');
      } else {
        console.log('[ELECTRON BROWSER] WARNING: Cookies file does not exist after flush!');
      }
    } catch (error) {
      console.error('[ELECTRON BROWSER] Failed to flush storage data:', error);
    }
    app.quit();
  });
});

let isQuitting = false;

app.on('before-quit', async (event) => {
  if (isQuitting) return;

  // Prevent immediate quit to allow storage flush and cookie sync
  event.preventDefault();
  isQuitting = true;

  console.log('[ELECTRON BROWSER] before-quit event triggered');

  try {
    const { session } = require('electron');
    const fs = require('fs');
    const path = require('path');
    const userLoginSession = session.fromPartition('persist:user_login');

    // Log cookies before flush
    const cookiesBeforeQuit = await userLoginSession.cookies.get({});
    console.log('[ELECTRON BROWSER] Cookies count before quit:', cookiesBeforeQuit.length);
    if (cookiesBeforeQuit.length > 0) {
      console.log('[ELECTRON BROWSER] Cookie domains before quit:', [...new Set(cookiesBeforeQuit.map(c => c.domain))]);
    }

    // Flush storage
    console.log('[ELECTRON BROWSER] Flushing storage on quit...');
    await userLoginSession.flushStorageData();
    console.log('[ELECTRON BROWSER] Storage data flushed on quit');
  } catch (error) {
    console.error('[ELECTRON BROWSER] Failed to sync cookies:', error);
  } finally {
    console.log('[ELECTRON BROWSER] Exiting now...');
    // Force quit after sync
    app.exit(0);
  }
});

app.on('window-all-closed', () => {
  if (!isQuitting) {
    app.quit();
  }
});
'''
        # Write the Electron script
        with open(electron_script_path, 'w') as f:
            f.write(electron_script_content)

        # Find Electron executable
        # Try to use the same Electron version as the main app
        electron_cmd = "npx"
        electron_args = [
            electron_cmd,
            "electron",
            electron_script_path,
            user_data_dir,
            str(cdp_port),
            "https://www.google.com"
        ]

        # Get the app's directory to run npx in the right context
        app_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))

        logger.info(f"[PROFILE USER LOGIN] Launching Electron browser with CDP on port {cdp_port}")
        logger.info(f"[PROFILE USER LOGIN] Working directory: {app_dir}")
        logger.info(f"[PROFILE USER LOGIN] userData path: {user_data_dir}")
        logger.info(f"[PROFILE USER LOGIN] Electron args: {electron_args}")

        # Start process and capture output in real-time
        process = subprocess.Popen(
            electron_args,
            cwd=app_dir,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # Redirect stderr to stdout
            universal_newlines=True,
            bufsize=1  # Line buffered
        )

        # Create async task to log Electron output
        async def log_electron_output():
            for line in iter(process.stdout.readline, ''):
                if line:
                    logger.info(f"[ELECTRON OUTPUT] {line.strip()}")

        import asyncio
        asyncio.create_task(log_electron_output())

        # Wait a bit for Electron to start
        import asyncio
        await asyncio.sleep(3)

        # Clean up the script file after a delay
        async def cleanup_script():
            await asyncio.sleep(10)
            try:
                os.remove(electron_script_path)
            except:
                pass

        asyncio.create_task(cleanup_script())

        logger.info(f"[PROFILE USER LOGIN] Electron browser launched with PID {process.pid}")

        return {
            "success": True,
            "session_id": session_id,
            "user_data_dir": user_data_dir,
            "cdp_port": cdp_port,
            "pid": process.pid,
            "chrome_version": "130.0.6723.191",  # Electron 33's Chrome version
            "message": "Electron browser opened successfully. Please log in to your accounts.",
            "note": "The browser will remain open for you to log in. Your login data will be saved in the profile."
        }

    except Exception as e:
        logger.error(f"Failed to open Electron browser for login: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to open browser: {str(e)}"
        )
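Once the Electron browser is up, its Chrome DevTools Protocol endpoint can be probed over HTTP; /json/version is a standard CDP discovery route, and the port below matches cdp_port above:

    import requests

    info = requests.get("http://localhost:9223/json/version").json()
    print(info.get("Browser"), info.get("webSocketDebuggerUrl"))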
@router.get("/browser/cookies", name="list cookie domains")
|
||||
async def list_cookie_domains(search: str = None):
|
||||
"""
|
||||
list cookie domains
|
||||
|
||||
Args:
|
||||
search: url
|
||||
|
||||
Returns:
|
||||
list of cookie domains
|
||||
"""
|
||||
try:
|
||||
# Use tool_controller browser's user data directory (source of truth)
|
||||
user_data_base = os.path.expanduser("~/.eigent/browser_profiles")
|
||||
user_data_dir = os.path.join(user_data_base, "profile_user_login")
|
||||
|
||||
logger.info(f"[COOKIES CHECK] Tool controller user_data_dir: {user_data_dir}")
|
||||
logger.info(f"[COOKIES CHECK] Tool controller user_data_dir exists: {os.path.exists(user_data_dir)}")
|
||||
|
||||
# Check partition path
|
||||
partition_path = os.path.join(user_data_dir, "Partitions", "user_login")
|
||||
logger.info(f"[COOKIES CHECK] partition path: {partition_path}")
|
||||
logger.info(f"[COOKIES CHECK] partition exists: {os.path.exists(partition_path)}")
|
||||
|
||||
# Check cookies file
|
||||
cookies_file = os.path.join(partition_path, "Cookies")
|
||||
logger.info(f"[COOKIES CHECK] cookies file: {cookies_file}")
|
||||
logger.info(f"[COOKIES CHECK] cookies file exists: {os.path.exists(cookies_file)}")
|
||||
if os.path.exists(cookies_file):
|
||||
stat = os.stat(cookies_file)
|
||||
logger.info(f"[COOKIES CHECK] cookies file size: {stat.st_size} bytes")
|
||||
|
||||
# Try to read actual cookie count
|
||||
try:
|
||||
import sqlite3
|
||||
conn = sqlite3.connect(cookies_file)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT COUNT(*) FROM cookies")
|
||||
count = cursor.fetchone()[0]
|
||||
logger.info(f"[COOKIES CHECK] actual cookie count in database: {count}")
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logger.error(f"[COOKIES CHECK] failed to read cookie count: {e}")
|
||||
|
||||
if not os.path.exists(user_data_dir):
|
||||
return {
|
||||
"success": True,
|
||||
"domains": [],
|
||||
"message": "No browser profile found. Please login first using /browser/login."
|
||||
}
|
||||
|
||||
cookie_manager = CookieManager(user_data_dir)
|
||||
|
||||
if search:
|
||||
domains = cookie_manager.search_cookies(search)
|
||||
else:
|
||||
domains = cookie_manager.get_cookie_domains()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"domains": domains,
|
||||
"total": len(domains),
|
||||
"user_data_dir": user_data_dir
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list cookie domains: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to list cookies: {str(e)}"
|
||||
)
|
||||
|
||||
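A hedged sketch of hitting the cookie listing route from a client; the base URL is assumed as elsewhere in this commit, and the search term is just an example:

    import requests

    base = "http://127.0.0.1:7777"
    resp = requests.get(f"{base}/browser/cookies", params={"search": "linkedin"}).json()
    print(resp.get("total"), resp.get("domains"))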
@router.get("/browser/cookies/{domain}", name="get domain cookies")
async def get_domain_cookies(domain: str):
    """
    get domain cookies

    Args:
        domain

    Returns:
        cookies
    """
    try:
        user_data_base = os.path.expanduser("~/.eigent/browser_profiles")
        user_data_dir = os.path.join(user_data_base, "profile_user_login")

        if not os.path.exists(user_data_dir):
            raise HTTPException(
                status_code=404,
                detail="No browser profile found. Please login first using /browser/login."
            )

        cookie_manager = CookieManager(user_data_dir)
        cookies = cookie_manager.get_cookies_for_domain(domain)

        return {
            "success": True,
            "domain": domain,
            "cookies": cookies,
            "count": len(cookies)
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to get cookies for domain {domain}: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get cookies: {str(e)}"
        )
@router.delete("/browser/cookies/{domain}", name="delete domain cookies")
|
||||
async def delete_domain_cookies(domain: str):
|
||||
"""
|
||||
Delete cookies
|
||||
|
||||
Args:
|
||||
domain
|
||||
|
||||
Returns:
|
||||
deleted cookies
|
||||
"""
|
||||
try:
|
||||
user_data_base = os.path.expanduser("~/.eigent/browser_profiles")
|
||||
user_data_dir = os.path.join(user_data_base, "profile_user_login")
|
||||
|
||||
if not os.path.exists(user_data_dir):
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No browser profile found. Please login first using /browser/login."
|
||||
)
|
||||
|
||||
cookie_manager = CookieManager(user_data_dir)
|
||||
success = cookie_manager.delete_cookies_for_domain(domain)
|
||||
|
||||
if success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Successfully deleted cookies for domain: {domain}"
|
||||
}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to delete cookies for domain: {domain}"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete cookies for domain {domain}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to delete cookies: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/browser/cookies", name="delete all cookies")
|
||||
async def delete_all_cookies():
|
||||
"""
|
||||
delete all cookies
|
||||
|
||||
Returns:
|
||||
deleted cookies
|
||||
"""
|
||||
try:
|
||||
user_data_base = os.path.expanduser("~/.eigent/browser_profiles")
|
||||
user_data_dir = os.path.join(user_data_base, "profile_user_login")
|
||||
|
||||
if not os.path.exists(user_data_dir):
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No browser profile found."
|
||||
)
|
||||
|
||||
cookie_manager = CookieManager(user_data_dir)
|
||||
success = cookie_manager.delete_all_cookies()
|
||||
|
||||
if success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Successfully deleted all cookies"
|
||||
}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to delete all cookies"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete all cookies: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to delete cookies: {str(e)}"
|
||||
)
|
||||
|
|
|
|||
|
|
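A minimal client sketch for the cookie endpoints above; the base URL, port, and domain are assumptions for illustration, not values fixed by this diff.

import requests

BASE = "http://127.0.0.1:7777"  # assumed local backend address

# Fetch the cookies stored for one domain
resp = requests.get(f"{BASE}/browser/cookies/github.com")
print(resp.json())  # {"success": True, "domain": "github.com", "cookies": [...], "count": ...}

# Delete that domain's cookies, then wipe everything
requests.delete(f"{BASE}/browser/cookies/github.com")
requests.delete(f"{BASE}/browser/cookies")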
@@ -3,18 +3,22 @@ from fastapi import Request
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from loguru import logger
from app import api
from app.component import code
from app.exception.exception import NoPermissionException, ProgramException, TokenException
from app.component.pydantic.i18n import trans, get_language
from app.exception.exception import UserException
from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("exception_handler")


@api.exception_handler(RequestValidationError)
async def request_exception(request: Request, e: RequestValidationError):
    if (lang := get_language(request.headers.get("Accept-Language"))) is None:
        lang = "en_US"
    logger.warning(f"Validation error on {request.url.path}: {e.errors()}")

    return JSONResponse(
        content={
            "code": code.form_error,

@@ -25,16 +29,19 @@ async def request_exception(request: Request, e: RequestValidationError):

@api.exception_handler(TokenException)
async def token_exception(request: Request, e: TokenException):
    logger.warning(f"Token exception on {request.url.path}: {e.text}")
    return JSONResponse(content={"code": e.code, "text": e.text})


@api.exception_handler(UserException)
async def user_exception(request: Request, e: UserException):
    logger.info(f"User exception on {request.url.path}: {e.description}")
    return JSONResponse(content={"code": e.code, "text": e.description})


@api.exception_handler(NoPermissionException)
async def no_permission(request: Request, exception: NoPermissionException):
    logger.warning(f"No permission on {request.url.path}: {exception.text}")
    return JSONResponse(
        status_code=200,
        content={"code": code.no_permission_error, "text": exception.text},

@@ -43,6 +50,7 @@ async def no_permission(request: Request, exception: NoPermissionException):

@api.exception_handler(ProgramException)
async def program_exception(request: Request, exception: ProgramException):
    logger.error(f"Program exception on {request.url.path}: {exception.text}", exc_info=True)
    return JSONResponse(
        status_code=200,
        content={"code": code.program_error, "text": exception.text},

@@ -51,8 +59,16 @@ async def program_exception(request: Request, exception: ProgramException):

@api.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    logger.error(f"Unhandled error: {exc}")
    traceback.print_exc()  # output to electron log
    logger.error(
        f"Unhandled exception on {request.method} {request.url.path}: {exc}",
        exc_info=True,
        extra={
            "request_method": request.method,
            "request_path": str(request.url.path),
            "request_query": str(request.url.query),
            "client_host": request.client.host if request.client else None,
        }
    )

    return JSONResponse(
        status_code=500,
@@ -3,9 +3,11 @@ import json
from pathlib import Path
import re
from typing import Literal
from loguru import logger
from pydantic import BaseModel, field_validator
from pydantic import BaseModel, Field, field_validator
from camel.types import ModelType, RoleType
from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("chat_model")


class Status(str, Enum):

@@ -20,11 +22,22 @@ class ChatHistory(BaseModel):
    content: str


class QuestionAnalysisResult(BaseModel):
    type: Literal["simple", "complex"] = Field(
        description="Whether this is a simple question or complex task"
    )
    answer: str | None = Field(
        default=None,
        description="Direct answer for simple questions. None for complex tasks."
    )


McpServers = dict[Literal["mcpServers"], dict[str, dict]]


class Chat(BaseModel):
    task_id: str
    project_id: str
    question: str
    email: str
    attaches: list[str] = []

@@ -50,6 +63,7 @@ class Chat(BaseModel):
    )
    new_agents: list["NewAgent"] = []
    extra_params: dict | None = None  # For provider-specific parameters like Azure
    search_config: dict[str, str] | None = None  # User-specific search engine configurations (e.g., GOOGLE_API_KEY, SEARCH_ENGINE_ID)

    @field_validator("model_type")
    @classmethod

@@ -72,7 +86,8 @@ class Chat(BaseModel):

    def file_save_path(self, path: str | None = None):
        email = re.sub(r'[\\/*?:"<>|\s]', "_", self.email.split("@")[0]).strip(".")
        save_path = Path.home() / "eigent" / email / ("task_" + self.task_id)
        # Use project-based structure: project_{project_id}/task_{task_id}
        save_path = Path.home() / "eigent" / email / f"project_{self.project_id}" / f"task_{self.task_id}"
        if path is not None:
            save_path = save_path / path
        save_path.mkdir(parents=True, exist_ok=True)

@@ -82,6 +97,7 @@ class Chat(BaseModel):

class SupplementChat(BaseModel):
    question: str
    task_id: str | None = None


class HumanReply(BaseModel):

@@ -106,6 +122,18 @@ class NewAgent(BaseModel):
    env_path: str | None = None


class AddTaskRequest(BaseModel):
    content: str
    project_id: str | None = None
    task_id: str | None = None
    additional_info: dict | None = None
    insert_position: int = -1
    is_independent: bool = False


class RemoveTaskRequest(BaseModel):
    task_id: str


def sse_json(step: str, data):
    res_format = {"step": step, "data": data}
    return f"data: {json.dumps(res_format, ensure_ascii=False)}\n\n"
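For reference, sse_json produces one Server-Sent-Events frame per call; a quick sketch of the wire format with an illustrative payload:

print(sse_json("confirmed", {"question": "hi"}))
# emits: data: {"step": "confirmed", "data": {"question": "hi"}}  followed by a blank line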
64 backend/app/router.py Normal file

@@ -0,0 +1,64 @@
"""
Centralized router registration for the Eigent API.
All routers are explicitly registered here for better visibility and maintainability.
"""
from fastapi import FastAPI
from app.controller import chat_controller, model_controller, task_controller, tool_controller, health_controller
from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("router")


def register_routers(app: FastAPI, prefix: str = "") -> None:
    """
    Register all API routers with their respective prefixes and tags.

    This replaces the auto-discovery mechanism for better:
    - Visibility: See all routes in one place
    - Maintainability: Easy to add/remove routes
    - Debugging: Clear registration order and configuration

    Args:
        app: FastAPI application instance
        prefix: Optional global prefix for all routes (e.g., "/api")
    """
    routers_config = [
        {
            "router": health_controller.router,
            "tags": ["Health"],
            "description": "Health check endpoint for service readiness"
        },
        {
            "router": chat_controller.router,
            "tags": ["chat"],
            "description": "Chat session management, improvements, and human interactions"
        },
        {
            "router": model_controller.router,
            "tags": ["model"],
            "description": "Model validation and configuration"
        },
        {
            "router": task_controller.router,
            "tags": ["task"],
            "description": "Task lifecycle management (start, stop, update, control)"
        },
        {
            "router": tool_controller.router,
            "tags": ["tool"],
            "description": "Tool installation and management"
        },
    ]

    for config in routers_config:
        app.include_router(
            config["router"],
            prefix=prefix,
            tags=config["tags"]
        )
        route_count = len(config["router"].routes)
        logger.info(
            f"Registered {config['tags'][0]} router: {route_count} routes - {config['description']}"
        )

    logger.info(f"Total routers registered: {len(routers_config)}")
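A sketch of how this registrar might be wired into the application entry point; the module paths come from the imports above, but the FastAPI app construction itself is an assumption:

from fastapi import FastAPI
from app.router import register_routers

app = FastAPI()
register_routers(app)            # mount everything at the root
# register_routers(app, "/api")  # or under a shared prefix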
@@ -1,5 +1,6 @@
import asyncio
import datetime
import json
from pathlib import Path
import platform
from typing import Literal

@@ -8,6 +9,7 @@ from inflection import titleize
from pydash import chain
from app.component.debug import dump_class
from app.component.environment import env
from app.utils.file_utils import get_working_directory
from app.service.task import (
    ActionImproveData,
    ActionInstallMcpData,

@@ -19,7 +21,6 @@ from camel.toolkits import AgentCommunicationToolkit, ToolkitMessageIntegration
from app.utils.toolkit.human_toolkit import HumanToolkit
from app.utils.toolkit.note_taking_toolkit import NoteTakingToolkit
from app.utils.workforce import Workforce
from loguru import logger
from app.model.chat import Chat, NewAgent, Status, sse_json, TaskContent
from camel.tasks import Task
from app.utils.agent import (
@@ -40,9 +41,197 @@ from app.service.task import Action, Agents
from app.utils.server.sync_step import sync_step
from camel.types import ModelPlatformType
from camel.models import ModelProcessingError
from utils import traceroot_wrapper as traceroot
import os

logger = traceroot.get_logger("chat_service")


def format_task_context(task_data: dict, seen_files: set | None = None, skip_files: bool = False) -> str:
    """Format structured task data into a readable context string.

    Args:
        task_data: Dictionary containing task content, result, and working directory
        seen_files: Optional set to track already-listed files and avoid duplicates (deprecated, use skip_files instead)
        skip_files: If True, skip the file listing entirely
    """
    context_parts = []

    if task_data.get('task_content'):
        context_parts.append(f"Previous Task: {task_data['task_content']}")

    if task_data.get('task_result'):
        context_parts.append(f"Previous Task Result: {task_data['task_result']}")

    # Skip file listing if requested
    if not skip_files:
        working_directory = task_data.get('working_directory')
        if working_directory:
            try:
                if os.path.exists(working_directory):
                    generated_files = []
                    for root, dirs, files in os.walk(working_directory):
                        dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ['node_modules', '__pycache__', 'venv']]
                        for file in files:
                            if not file.startswith('.') and not file.endswith(('.pyc', '.tmp')):
                                file_path = os.path.join(root, file)
                                absolute_path = os.path.abspath(file_path)

                                # Only add if not seen before (or if we're not tracking seen files)
                                if seen_files is None or absolute_path not in seen_files:
                                    generated_files.append(absolute_path)
                                    if seen_files is not None:
                                        seen_files.add(absolute_path)

                    if generated_files:
                        context_parts.append("Generated Files from Previous Task:")
                        for file_path in sorted(generated_files):
                            context_parts.append(f"  - {file_path}")
            except Exception as e:
                logger.warning(f"Failed to collect generated files: {e}")

    return "\n".join(context_parts)
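To illustrate what format_task_context emits, a hedged example with made-up values (the working directory is deliberately nonexistent, so no file list is appended):

ctx = format_task_context({
    "task_content": "Summarize report.pdf",   # hypothetical
    "task_result": "Wrote summary.md",        # hypothetical
    "working_directory": "/tmp/does-not-exist",
})
# ctx == "Previous Task: Summarize report.pdf\nPrevious Task Result: Wrote summary.md"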
def collect_previous_task_context(working_directory: str, previous_task_content: str, previous_task_result: str, previous_summary: str = "") -> str:
    """
    Collect context from the previous task, including content, result, summary, and generated files.

    Args:
        working_directory: The working directory to scan for generated files
        previous_task_content: The content of the previous task
        previous_task_result: The result/output of the previous task
        previous_summary: The summary of the previous task

    Returns:
        Formatted context string to prepend to the new task
    """

    context_parts = []

    # Add previous task information
    context_parts.append("=== CONTEXT FROM PREVIOUS TASK ===\n")

    # Add previous task content
    if previous_task_content:
        context_parts.append(f"Previous Task:\n{previous_task_content}\n")

    # Add previous task summary
    if previous_summary:
        context_parts.append(f"Previous Task Summary:\n{previous_summary}\n")

    # Add previous task result
    if previous_task_result:
        context_parts.append(f"Previous Task Result:\n{previous_task_result}\n")

    # Collect generated files from working directory
    try:
        if os.path.exists(working_directory):
            generated_files = []
            for root, dirs, files in os.walk(working_directory):
                dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ['node_modules', '__pycache__', 'venv']]
                for file in files:
                    if not file.startswith('.') and not file.endswith(('.pyc', '.tmp')):
                        file_path = os.path.join(root, file)
                        absolute_path = os.path.abspath(file_path)
                        generated_files.append(absolute_path)

            if generated_files:
                context_parts.append("Generated Files from Previous Task:")
                for file_path in sorted(generated_files):
                    context_parts.append(f"  - {file_path}")
                context_parts.append("")
    except Exception as e:
        logger.warning(f"Failed to collect generated files: {e}")

    context_parts.append("=== END OF PREVIOUS TASK CONTEXT ===\n")

    return "\n".join(context_parts)
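A usage sketch for the helper above; all argument values are illustrative:

prefix = collect_previous_task_context(
    working_directory="/tmp/empty",              # hypothetical path
    previous_task_content="Collect Q3 sales data",
    previous_task_result="Saved data.csv",
    previous_summary="Fetched Q3 sales figures",
)
new_prompt = prefix + "Now chart the data."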
def check_conversation_history_length(task_lock: TaskLock, max_length: int = 100000) -> tuple[bool, int]:
    """
    Check if conversation history exceeds the maximum length.

    Returns:
        tuple: (is_exceeded, total_length)
    """
    if not hasattr(task_lock, 'conversation_history') or not task_lock.conversation_history:
        return False, 0

    total_length = 0
    for entry in task_lock.conversation_history:
        total_length += len(entry.get('content', ''))

    is_exceeded = total_length > max_length

    if is_exceeded:
        logger.warning(f"Conversation history length {total_length} exceeds maximum {max_length}")

    return is_exceeded, total_length
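A sketch of the guard pattern this helper supports (task_lock is assumed to be a live TaskLock):

is_exceeded, total = check_conversation_history_length(task_lock, max_length=100000)
if is_exceeded:
    # emit the "context_too_long" SSE event instead of queueing more work
    ...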
def build_conversation_context(task_lock: TaskLock, header: str = "=== CONVERSATION HISTORY ===") -> str:
    """Build conversation context from task_lock history with files listed only once at the end.

    Args:
        task_lock: TaskLock containing conversation history
        header: Header text for the context section

    Returns:
        Formatted context string with task history and files listed once at the end
    """
    context = ""
    working_directories = set()  # Collect all unique working directories

    if task_lock.conversation_history:
        context = f"{header}\n"

        for entry in task_lock.conversation_history:
            if entry['role'] == 'task_result':
                if isinstance(entry['content'], dict):
                    formatted_context = format_task_context(entry['content'], skip_files=True)
                    context += formatted_context + "\n\n"
                    if entry['content'].get('working_directory'):
                        working_directories.add(entry['content']['working_directory'])
                else:
                    context += entry['content'] + "\n"
            elif entry['role'] == 'assistant':
                context += f"Assistant: {entry['content']}\n\n"

        if working_directories:
            all_generated_files = set()  # Use set to avoid duplicates
            for working_directory in working_directories:
                try:
                    if os.path.exists(working_directory):
                        for root, dirs, files in os.walk(working_directory):
                            dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ['node_modules', '__pycache__', 'venv']]
                            for file in files:
                                if not file.startswith('.') and not file.endswith(('.pyc', '.tmp')):
                                    file_path = os.path.join(root, file)
                                    absolute_path = os.path.abspath(file_path)
                                    all_generated_files.add(absolute_path)
                except Exception as e:
                    logger.warning(f"Failed to collect generated files from {working_directory}: {e}")

            if all_generated_files:
                context += "Generated Files from Previous Tasks:\n"
                for file_path in sorted(all_generated_files):
                    context += f"  - {file_path}\n"
                context += "\n"

        context += "\n"

    return context


def build_context_for_workforce(task_lock: TaskLock, options: Chat) -> str:
    """Build context information for the workforce."""
    return build_conversation_context(task_lock, header="=== CONVERSATION HISTORY ===")
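For reference, the history entries these builders consume are plain dicts: 'assistant' entries carry a string while 'task_result' entries carry a dict. A sketch with hypothetical values:

task_lock.conversation_history = [
    {"role": "assistant", "content": "Done!", "timestamp": "2025-01-01T00:00:00"},
    {"role": "task_result", "content": {
        "task_content": "Build a chart",
        "task_result": "chart.png written",
        "working_directory": "/tmp/project",  # hypothetical path
    }, "timestamp": "2025-01-01T00:01:00"},
]
print(build_conversation_context(task_lock))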
@sync_step
@traceroot.trace()
async def step_solve(options: Chat, request: Request, task_lock: TaskLock):
    # if True:
    #     import faulthandler

@@ -52,12 +241,40 @@ async def step_solve(options: Chat, request: Request, task_lock: TaskLock):
    #     faulthandler.dump_traceback_later(second)

    start_event_loop = True
    question_agent = question_confirm_agent(options)

    if not hasattr(task_lock, 'conversation_history'):
        task_lock.conversation_history = []
    if not hasattr(task_lock, 'last_task_result'):
        task_lock.last_task_result = ""
    if not hasattr(task_lock, 'question_agent'):
        task_lock.question_agent = None
    if not hasattr(task_lock, 'summary_generated'):
        task_lock.summary_generated = False

    # Create or reuse persistent question_agent
    if task_lock.question_agent is None:
        task_lock.question_agent = question_confirm_agent(options)
        logger.info(f"Created new persistent question_agent for project {options.project_id}")
    else:
        logger.info(f"Reusing existing question_agent with {len(task_lock.conversation_history)} history entries")

    question_agent = task_lock.question_agent

    # Other variables
    camel_task = None
    workforce = None
    last_completed_task_result = ""  # Track the last completed task result
    summary_task_content = ""  # Track task summary
    loop_iteration = 0

    logger.info("Starting step_solve", extra={"project_id": options.project_id, "task_id": options.task_id})
    logger.debug("Step solve options", extra={"task_id": options.task_id, "model_platform": options.model_platform})

    while True:
        loop_iteration += 1

        if await request.is_disconnected():
            logger.warning(f"Client disconnected for task {options.task_id}")
            logger.warning(f"Client disconnected for project {options.project_id}")
            if workforce is not None:
                if workforce._running:
                    workforce.stop()

@@ -70,10 +287,10 @@ async def step_solve(options: Chat, request: Request, task_lock: TaskLock):
            break
        try:
            item = await task_lock.get_queue()
            # logger.info(f"item: {dump_class(item)}")
        except Exception as e:
            logger.error(f"Error getting item from queue: {e}")
            break
            logger.error("Error getting item from queue", extra={"project_id": options.project_id, "task_id": options.task_id, "error": str(e)}, exc_info=True)
            # Continue waiting instead of breaking on queue error
            continue

        try:
            if item.action == Action.improve or start_event_loop:
@@ -87,33 +304,116 @@ async def step_solve(options: Chat, request: Request, task_lock: TaskLock):
                else:
                    assert isinstance(item, ActionImproveData)
                    question = item.data
                if len(question) < 12 and len(options.attaches) == 0:
                    confirm = await question_confirm(question_agent, question)
                else:
                    confirm = True

                if confirm is not True:
                    yield confirm
                is_exceeded, total_length = check_conversation_history_length(task_lock)
                if is_exceeded:
                    logger.error("Conversation history too long", extra={"project_id": options.project_id, "current_length": total_length, "max_length": 100000})
                    yield sse_json("context_too_long", {
                        "message": "The conversation history is too long. Please create a new project to continue.",
                        "current_length": total_length,
                        "max_length": 100000
                    })
                    continue

                # Simplified logic: attachments mean workforce, otherwise let agent decide
                is_complex_task: bool
                if len(options.attaches) > 0:
                    # Questions with attachments always need workforce
                    is_complex_task = True
                else:
                    yield sse_json("confirmed", "")
                    is_complex_task = await question_confirm(question_agent, question, task_lock)

                if not is_complex_task:
                    simple_answer_prompt = f"{build_conversation_context(task_lock, header='=== Previous Conversation ===')}User Query: {question}\n\nProvide a direct, helpful answer to this simple question."

                    try:
                        simple_resp = question_agent.step(simple_answer_prompt)
                        answer_content = simple_resp.msgs[0].content if simple_resp and simple_resp.msgs else "I understand your question, but I'm having trouble generating a response right now."

                        task_lock.add_conversation('assistant', answer_content)

                        yield sse_json("wait_confirm", {"content": answer_content, "question": question})
                    except Exception as e:
                        logger.error(f"Error generating simple answer: {e}")
                        yield sse_json("wait_confirm", {"content": "I encountered an error while processing your question.", "question": question})

                    # Clean up empty folder if it was created for this task
                    if hasattr(task_lock, 'new_folder_path') and task_lock.new_folder_path:
                        try:
                            folder_path = Path(task_lock.new_folder_path)
                            if folder_path.exists() and folder_path.is_dir():
                                # Check if folder is empty
                                if not any(folder_path.iterdir()):
                                    folder_path.rmdir()
                                    logger.info(f"Cleaned up empty folder: {folder_path}")
                                    # Also clean up parent project folder if it becomes empty
                                    project_folder = folder_path.parent
                                    if project_folder.exists() and not any(project_folder.iterdir()):
                                        project_folder.rmdir()
                                        logger.info(f"Cleaned up empty project folder: {project_folder}")
                                else:
                                    logger.info(f"Folder not empty, keeping: {folder_path}")
                            # Reset the folder path
                            task_lock.new_folder_path = None
                        except Exception as e:
                            logger.error(f"Error cleaning up folder: {e}")
                else:
                    yield sse_json("confirmed", {"question": question})

                    context_for_coordinator = build_context_for_workforce(task_lock, options)

                    (workforce, mcp) = await construct_workforce(options)
                    for new_agent in options.new_agents:
                        workforce.add_single_agent_worker(
                            format_agent_description(new_agent), await new_agent_model(new_agent, options)
                        )
                    summary_task_agent = task_summary_agent(options)
                    task_lock.status = Status.confirmed
                    question = question + options.summary_prompt
                    camel_task = Task(content=question, id=options.task_id)

                    clean_task_content = question + options.summary_prompt
                    camel_task = Task(content=clean_task_content, id=options.task_id)
                    if len(options.attaches) > 0:
                        camel_task.additional_info = {Path(file_path).name: file_path for file_path in options.attaches}

                    sub_tasks = await asyncio.to_thread(workforce.eigent_make_sub_tasks, camel_task)
                    summary_task_content = await summary_task(summary_task_agent, camel_task)
                    sub_tasks = await asyncio.to_thread(
                        workforce.eigent_make_sub_tasks,
                        camel_task,
                        context_for_coordinator
                    )

                    if not task_lock.summary_generated:
                        summary_task_agent = task_summary_agent(options)
                        try:
                            summary_task_content = await asyncio.wait_for(
                                summary_task(summary_task_agent, camel_task), timeout=10
                            )
                            task_lock.summary_generated = True
                            logger.info("Generated summary for first task", extra={"project_id": options.project_id})
                        except asyncio.TimeoutError:
                            logger.warning("summary_task timeout", extra={"project_id": options.project_id, "task_id": options.task_id})
                            # Fallback to a minimal summary to unblock UI
                            fallback_name = "Task"
                            content_preview = camel_task.content if hasattr(camel_task, "content") else ""
                            if content_preview is None:
                                content_preview = ""
                            fallback_summary = (
                                (content_preview[:80] + "...") if len(content_preview) > 80 else content_preview
                            )
                            summary_task_content = f"{fallback_name}|{fallback_summary}"
                            task_lock.summary_generated = True
                    else:
                        if len(question) > 100:
                            summary_task_content = f"Task|{question[:97]}..."
                        else:
                            summary_task_content = f"Task|{question}"
                        logger.info("Skipped summary generation for subsequent task", extra={"project_id": options.project_id})

                    yield to_sub_tasks(camel_task, summary_task_content)
                    # tracer.stop()
                    # tracer.save("trace.json")

                    # Only auto-start in debug mode
                    if env("debug") == "on":
                        logger.info(f"[DEBUG] Auto-starting workforce in debug mode")
                        task_lock.status = Status.processing
                        task = asyncio.create_task(workforce.eigent_start(sub_tasks))
                        task_lock.add_background_task(task)
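Both the real and fallback summaries above follow a "Name|summary" convention, so a consumer can split on the first pipe; a sketch of that parse with an illustrative string:

name, _, summary = "Task|Build a landing page...".partition("|")
# name == "Task", summary == "Build a landing page..."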
@@ -124,12 +424,185 @@ async def step_solve(options: Chat, request: Request, task_lock: TaskLock):
                sub_tasks = update_sub_tasks(sub_tasks, update_tasks)
                add_sub_tasks(camel_task, item.data.task)
                yield to_sub_tasks(camel_task, summary_task_content)
            elif item.action == Action.add_task:

                # Check if this might be a misrouted second question
                if camel_task is None and workforce is None:
                    logger.error(f"Cannot add task: both camel_task and workforce are None for project {options.project_id}")
                    yield sse_json("error", {"message": "Cannot add task: task not initialized. Please start a task first."})
                    continue

                assert camel_task is not None
                if workforce is None:
                    logger.error(f"Cannot add task: workforce not initialized for project {options.project_id}")
                    yield sse_json("error", {"message": "Workforce not initialized. Please start the task first."})
                    continue

                # Add task to the workforce queue
                workforce.add_task(
                    item.content,
                    item.task_id,
                    item.additional_info
                )

                returnData = {
                    "project_id": item.project_id,
                    "task_id": item.task_id or (len(camel_task.subtasks) + 1)
                }
                yield sse_json("add_task", returnData)
            elif item.action == Action.remove_task:
                if workforce is None:
                    logger.error(f"Cannot remove task: workforce not initialized for project {options.project_id}")
                    yield sse_json("error", {"message": "Workforce not initialized. Please start the task first."})
                    continue

                workforce.remove_task(item.task_id)
                returnData = {
                    "project_id": item.project_id,
                    "task_id": item.task_id
                }
                yield sse_json("remove_task", returnData)
            elif item.action == Action.skip_task:
                if workforce is not None and item.project_id == options.project_id:
                    if workforce._state.name == 'PAUSED':
                        # Resume paused workforce to skip the task
                        workforce.resume()
                    workforce.skip_gracefully()
            elif item.action == Action.start:
                # Check conversation history length before starting task
                is_exceeded, total_length = check_conversation_history_length(task_lock)
                if is_exceeded:
                    logger.error(f"Cannot start task: conversation history too long ({total_length} chars) for project {options.project_id}")
                    yield sse_json("context_too_long", {
                        "message": "The conversation history is too long. Please create a new project to continue.",
                        "current_length": total_length,
                        "max_length": 100000
                    })
                    continue

                if workforce is not None:
                    if workforce._state.name == 'PAUSED':
                        # Resume paused workforce - subtasks should already be loaded
                        workforce.resume()
                        continue
                    else:
                        continue

                task_lock.status = Status.processing
                task = asyncio.create_task(workforce.eigent_start(sub_tasks))
                task_lock.add_background_task(task)
            elif item.action == Action.task_state:
                # Track completed task results for the end event
                task_id = item.data.get('task_id', 'unknown')
                task_state = item.data.get('state', 'unknown')
                task_result = item.data.get('result', '')

                if task_state == 'DONE' and task_result:
                    last_completed_task_result = task_result

                yield sse_json("task_state", item.data)
            elif item.action == Action.new_task_state:

                # Log new task state details
                new_task_id = item.data.get('task_id', 'unknown')
                new_task_state = item.data.get('state', 'unknown')
                new_task_result = item.data.get('result', '')

                assert camel_task is not None

                old_task_content: str = camel_task.content
                old_task_result: str = await get_task_result_with_optional_summary(camel_task, options)

                old_task_content_clean: str = old_task_content
                if "=== CURRENT TASK ===" in old_task_content_clean:
                    old_task_content_clean = old_task_content_clean.split("=== CURRENT TASK ===")[-1].strip()

                task_lock.add_conversation('task_result', {
                    'task_content': old_task_content_clean,
                    'task_result': old_task_result,
                    'working_directory': get_working_directory(options, task_lock)
                })

                new_task_content = item.data.get('content', '')

                if new_task_content:
                    import time
                    task_id = item.data.get('task_id', f"{int(time.time() * 1000)}-multi")
                    new_camel_task = Task(content=new_task_content, id=task_id)
                    if hasattr(camel_task, 'additional_info') and camel_task.additional_info:
                        new_camel_task.additional_info = camel_task.additional_info
                    camel_task = new_camel_task

                # Now trigger end of previous task using stored result
                yield sse_json("end", old_task_result)
                # Always yield new_task_state first - this is not optional
                yield sse_json("new_task_state", item.data)
                # Trigger queue removal
                yield sse_json("remove_task", {"task_id": item.data.get("task_id")})

                # Then handle multi-turn processing
                if workforce is not None and new_task_content:
                    task_lock.status = Status.confirming
                    workforce.pause()

                    try:
                        is_multi_turn_complex = await question_confirm(question_agent, new_task_content, task_lock)

                        if not is_multi_turn_complex:
                            simple_answer_prompt = f"{build_conversation_context(task_lock, header='=== Previous Conversation ===')}User Query: {new_task_content}\n\nProvide a direct, helpful answer to this simple question."

                            try:
                                simple_resp = question_agent.step(simple_answer_prompt)
                                answer_content = simple_resp.msgs[0].content if simple_resp and simple_resp.msgs else "I understand your question, but I'm having trouble generating a response right now."

                                task_lock.add_conversation('assistant', answer_content)

                                # Send response to user (don't send confirmed if simple response)
                                yield sse_json("wait_confirm", {"content": answer_content, "question": new_task_content})
                            except Exception as e:
                                logger.error(f"Error generating simple answer in multi-turn: {e}")
                                yield sse_json("wait_confirm", {"content": "I encountered an error while processing your question.", "question": new_task_content})

                            workforce.resume()
                            continue  # This continues the main while loop, waiting for next action

                        yield sse_json("confirmed", {"question": new_task_content})
                        task_lock.status = Status.confirmed

                        context_for_multi_turn = build_context_for_workforce(task_lock, options)

                        new_sub_tasks = await workforce.handle_decompose_append_task(
                            camel_task,
                            reset=False,
                            coordinator_context=context_for_multi_turn
                        )

                        task_content_for_summary = new_task_content
                        if len(task_content_for_summary) > 100:
                            new_summary_content = f"Follow-up Task|{task_content_for_summary[:97]}..."
                        else:
                            new_summary_content = f"Follow-up Task|{task_content_for_summary}"

                        # Send the extracted events
                        yield to_sub_tasks(camel_task, new_summary_content)

                        # Update the context with new task data
                        sub_tasks = new_sub_tasks
                        summary_task_content = new_summary_content

                    except Exception as e:
                        import traceback
                        logger.error(f"[TRACE] Traceback: {traceback.format_exc()}")
                        # Continue with existing context if decomposition fails
                        yield sse_json("error", {"message": f"Failed to process task: {str(e)}"})
                else:
                    if workforce is None:
                        logger.warning(f"[TRACE] Workforce is None - this might be the issue")
                    if not new_task_content:
                        logger.warning(f"[TRACE] No new task content provided")
            elif item.action == Action.create_agent:
                yield sse_json("create_agent", item.data)
            elif item.action == Action.activate_agent:
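For orientation, a sketch of the queue item that drives the add_task branch above; the field values are illustrative, and ActionAddTaskData is defined in app/service/task.py later in this diff:

await task_lock.put_queue(ActionAddTaskData(
    content="Also export the table as CSV",
    project_id="proj-1",   # hypothetical id
    task_id=None,          # let the workforce assign one
))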
@@ -167,9 +640,15 @@ async def step_solve(options: Chat, request: Request, task_lock: TaskLock):
            elif item.action == Action.pause:
                if workforce is not None:
                    workforce.pause()
                    logger.info(f"Workforce paused for project {options.project_id}")
                else:
                    logger.warning(f"Cannot pause: workforce is None for project {options.project_id}")
            elif item.action == Action.resume:
                if workforce is not None:
                    workforce.resume()
                    logger.info(f"Workforce resumed for project {options.project_id}")
                else:
                    logger.warning(f"Cannot resume: workforce is None for project {options.project_id}")
            elif item.action == Action.new_agent:
                if workforce is not None:
                    workforce.pause()
@@ -180,21 +659,52 @@ async def step_solve(options: Chat, request: Request, task_lock: TaskLock):
            elif item.action == Action.end:
                assert camel_task is not None
                task_lock.status = Status.done
                yield sse_json("end", str(camel_task.result))
                final_result: str = await get_task_result_with_optional_summary(camel_task, options)

                task_lock.last_task_result = final_result

                task_content: str = camel_task.content
                if "=== CURRENT TASK ===" in task_content:
                    task_content = task_content.split("=== CURRENT TASK ===")[-1].strip()

                task_lock.add_conversation('task_result', {
                    'task_content': task_content,
                    'task_result': final_result,
                    'working_directory': get_working_directory(options, task_lock)
                })

                yield sse_json("end", final_result)

                if workforce is not None:
                    workforce.stop_gracefully()
                    break
                    logger.info(f"Workforce stopped gracefully for project {options.project_id}")
                    workforce = None
                else:
                    logger.warning(f"Workforce already None at end action for project {options.project_id}")

                camel_task = None

                if question_agent is not None:
                    question_agent.reset()
                    logger.info(f"Reset question_agent for project {options.project_id}")
            elif item.action == Action.supplement:
                assert camel_task is not None
                task_lock.status = Status.processing
                camel_task.add_subtask(
                    Task(
                        content=item.data.question,
                        id=f"{camel_task.id}.{len(camel_task.subtasks)}",

                # Check if this might be a misrouted second question
                if camel_task is None:
                    logger.warning(f"SUPPLEMENT action received but camel_task is None for project {options.project_id}")
                else:
                    assert camel_task is not None
                    task_lock.status = Status.processing
                    camel_task.add_subtask(
                        Task(
                            content=item.data.question,
                            id=f"{camel_task.id}.{len(camel_task.subtasks)}",
                        )
                    )
                )
                task = asyncio.create_task(workforce.eigent_start(camel_task.subtasks))
                task_lock.add_background_task(task)
                    if workforce is not None:
                        task = asyncio.create_task(workforce.eigent_start(camel_task.subtasks))
                        task_lock.add_background_task(task)
            elif item.action == Action.budget_not_enough:
                if workforce is not None:
                    workforce.pause()
@@ -204,32 +714,43 @@ async def step_solve(options: Chat, request: Request, task_lock: TaskLock):
                    if workforce._running:
                        workforce.stop()
                    workforce.stop_gracefully()
                    logger.info(f"Workforce stopped for project {options.project_id}")
                else:
                    logger.warning(f"Workforce is None at stop action for project {options.project_id}")
                await delete_task_lock(task_lock.id)
                break
            else:
                logger.warning(f"Unknown action: {item.action}")
        except ModelProcessingError as e:
            if "Budget has been exceeded" in str(e):
                logger.warning(f"Budget exceeded for task {options.task_id}, action: {item.action}")
                # Workforce task decomposition doesn't go through ListenAgent, so the SSE event must be returned here
                if "workforce" in locals() and workforce is not None:
                    workforce.pause()
                yield sse_json(Action.budget_not_enough, {"message": "budget not enough"})
            else:
                logger.error(f"Error processing action {item.action}: {e}")
                logger.error(f"ModelProcessingError for task {options.task_id}, action {item.action}: {e}", exc_info=True)
                yield sse_json("error", {"message": str(e)})
                if "workforce" in locals() and workforce is not None and workforce._running:
                    workforce.stop()
        except Exception as e:
            logger.error(f"Error processing action {item.action}: {e}")
            logger.error(f"Unhandled exception for task {options.task_id}, action {item.action}: {e}", exc_info=True)
            yield sse_json("error", {"message": str(e)})
            # Continue processing other items instead of breaking


@traceroot.trace()
async def install_mcp(
    mcp: ListenChatAgent,
    install_mcp: ActionInstallMcpData,
):
    mcp.add_tools(await get_mcp_tools(install_mcp.data))
    logger.info(f"Installing MCP tools: {list(install_mcp.data.get('mcpServers', {}).keys())}")
    try:
        mcp.add_tools(await get_mcp_tools(install_mcp.data))
        logger.info("MCP tools installed successfully")
    except Exception as e:
        logger.error(f"Error installing MCP tools: {e}", exc_info=True)
        raise


def to_sub_tasks(task: Task, summary_task_content: str):
@@ -287,30 +808,53 @@ def add_sub_tasks(camel_task: Task, update_tasks: list[TaskContent]):
    )


async def question_confirm(agent: ListenChatAgent, prompt: str) -> str | Literal[True]:
    prompt = f"""
> **Your Role:** You are a highly capable agent. Your primary function is to analyze a user's request and determine the appropriate course of action.
>
> **Your Process:**
>
> 1. **Analyze the User's Query:** Carefully examine the user's request: `{prompt}`.
>
> 2. **Categorize the Query:**
>    * **Simple Query:** Is this a simple greeting, a question that can be answered directly, or a conversational interaction (e.g., "hello", "thank you")?
>    * **Complex Task:** Is this a request that requires a series of steps, code execution, or interaction with tools to complete?
>
> 3. **Execute Your Decision:**
>    * **For a Simple Query:** Provide a direct and helpful response.
>    * **For a Complex Task:** Your *only* response should be "yes". This will trigger a specialized workforce to handle the task. Do not include any other text, punctuation, or pleasantries.
"""
    resp = agent.step(prompt)
    logger.info(f"resp: {agent.chat_history}")
    if resp.msgs[0].content.lower() != "yes":
        return sse_json("wait_confirm", {"content": resp.msgs[0].content})
    else:
async def question_confirm(agent: ListenChatAgent, prompt: str, task_lock: TaskLock | None = None) -> bool:
    """Simple question confirmation - returns True for complex tasks, False for simple questions."""

    context_prompt = ""
    if task_lock:
        context_prompt = build_conversation_context(task_lock, header="=== Previous Conversation ===")

    full_prompt = f"""{context_prompt}User Query: {prompt}

Determine if this user query is a complex task or a simple question.

**Complex task** (answer "yes"): Requires tools, code execution, file operations, multi-step planning, or creating/modifying content
- Examples: "create a file", "search for X", "implement feature Y", "write code", "analyze data", "build something"

**Simple question** (answer "no"): Can be answered directly with knowledge or conversation history, no action needed
- Examples: greetings ("hello", "hi"), fact queries ("what is X?"), clarifications ("what did you mean?"), status checks ("how are you?")

Answer only "yes" or "no". Do not provide any explanation.

Is this a complex task? (yes/no):"""

    try:
        resp = agent.step(full_prompt)

        if not resp or not resp.msgs or len(resp.msgs) == 0:
            logger.warning("No response from agent, defaulting to complex task")
            return True

        content = resp.msgs[0].content
        if not content:
            logger.warning("Empty content from agent, defaulting to complex task")
            return True

        normalized = content.strip().lower()
        is_complex = "yes" in normalized

        logger.info(f"Question confirm result: {'complex task' if is_complex else 'simple question'}",
                    extra={"response": content, "is_complex": is_complex})

        return is_complex

    except Exception as e:
        logger.error(f"Error in question_confirm: {e}")
        return True
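The rewritten question_confirm is now a pure boolean router; a sketch of a call site under that contract:

if await question_confirm(question_agent, "hello there", task_lock):
    ...  # complex: build the workforce and decompose the task
else:
    ...  # simple: answer directly via question_agent.step(...)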

@traceroot.trace()
async def summary_task(agent: ListenChatAgent, task: Task) -> str:
    prompt = f"""The user's task is:
---
@@ -324,13 +868,100 @@ Your instructions are:
Example format: "Task Name|This is the summary of the task."
Do not include any other text or formatting.
"""
    logger.debug("Generating task summary", extra={"task_id": task.id})
    try:
        res = agent.step(prompt)
        summary = res.msgs[0].content
        logger.info("Task summary generated", extra={"summary": summary})
        return summary
    except Exception as e:
        logger.error("Error generating task summary", extra={"error": str(e)}, exc_info=True)
        raise


async def summary_subtasks_result(agent: ListenChatAgent, task: Task) -> str:
    """
    Summarize the aggregated results from all subtasks into a concise summary.

    Args:
        agent: The summary agent to use
        task: The main task containing subtasks and their aggregated results

    Returns:
        A concise summary of all subtask results
    """
    subtasks_info = ""
    for i, subtask in enumerate(task.subtasks, 1):
        subtasks_info += f"\n**Subtask {i}**\n"
        subtasks_info += f"Description: {subtask.content}\n"
        subtasks_info += f"Result: {subtask.result or 'No result'}\n"
        subtasks_info += "---\n"

    prompt = f"""You are a professional summarizer. Summarize the results of the following subtasks.

Main Task: {task.content}

Subtasks (with descriptions and results):
---
{subtasks_info}
---

Instructions:
1. Provide a concise summary of what was accomplished
2. Highlight key findings or outputs from each subtask
3. Mention any important files created or actions taken
4. Use bullet points or sections for clarity
5. DO NOT repeat the task name in your summary - go straight to the results
6. Keep it professional but conversational

Summary:
"""

    res = agent.step(prompt)
    logger.info(f"summary_task: {res.msgs[0].content}")
    return res.msgs[0].content
    summary = res.msgs[0].content

    logger.info(f"Generated subtasks summary for task {task.id} with {len(task.subtasks)} subtasks")

    return summary


async def get_task_result_with_optional_summary(task: Task, options: Chat) -> str:
    """
    Get the task result, with LLM summary if there are multiple subtasks.

    Args:
        task: The task to get result from
        options: Chat options for creating summary agent

    Returns:
        The task result (summarized if multiple subtasks, raw otherwise)
    """
    result = str(task.result or "")

    if task.subtasks and len(task.subtasks) > 1:
        logger.info(f"Task {task.id} has {len(task.subtasks)} subtasks, generating summary")
        try:
            summary_agent = task_summary_agent(options)
            summarized_result = await summary_subtasks_result(summary_agent, task)
            result = summarized_result
            logger.info(f"Successfully generated summary for task {task.id}")
        except Exception as e:
            logger.error(f"Failed to generate summary for task {task.id}: {e}")
    elif task.subtasks and len(task.subtasks) == 1:
        logger.info(f"Task {task.id} has only 1 subtask, skipping LLM summary")
        if result and "--- Subtask" in result and "Result ---" in result:
            parts = result.split("Result ---", 1)
            if len(parts) > 1:
                result = parts[1].strip()

    return result
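The single-subtask path above strips the "--- Subtask N Result ---" wrapper with plain string surgery; a sketch of that step in isolation (the input string is hypothetical):

raw = "--- Subtask 1 Result ---\nWrote report.md"
clean = raw.split("Result ---", 1)[1].strip()  # "Wrote report.md"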

@traceroot.trace()
async def construct_workforce(options: Chat) -> tuple[Workforce, ListenChatAgent]:
    working_directory = options.file_save_path()
    logger.info("Constructing workforce", extra={"project_id": options.project_id, "task_id": options.task_id})
    working_directory = get_working_directory(options)
    logger.debug("Working directory set", extra={"working_directory": working_directory})
    [coordinator_agent, task_agent] = [
        agent_model(
            key,
@@ -339,8 +970,8 @@ async def construct_workforce(options: Chat) -> tuple[Workforce, ListenChatAgent
            [
                *(
                    ToolkitMessageIntegration(
                        message_handler=HumanToolkit(options.task_id, key).send_message_to_user
                    ).register_toolkits(NoteTakingToolkit(options.task_id, working_directory=working_directory))
                        message_handler=HumanToolkit(options.project_id, key).send_message_to_user
                    ).register_toolkits(NoteTakingToolkit(options.project_id, working_directory=working_directory))
                ).get_tools()
            ],
        )
@@ -373,11 +1004,11 @@ The current date is {datetime.date.today()}. For any date-related tasks, you MUS
        """,
        options,
        [
            *HumanToolkit.get_can_use_tools(options.task_id, Agents.new_worker_agent),
            *HumanToolkit.get_can_use_tools(options.project_id, Agents.new_worker_agent),
            *(
                ToolkitMessageIntegration(
                    message_handler=HumanToolkit(options.task_id, Agents.new_worker_agent).send_message_to_user
                ).register_toolkits(NoteTakingToolkit(options.task_id, working_directory=working_directory))
                    message_handler=HumanToolkit(options.project_id, Agents.new_worker_agent).send_message_to_user
                ).register_toolkits(NoteTakingToolkit(options.project_id, working_directory=working_directory))
            ).get_tools(),
        ],
    )
@@ -402,7 +1033,7 @@ The current date is {datetime.date.today()}. For any date-related tasks, you MUS
    model_platform_enum = None

    workforce = Workforce(
        options.task_id,
        options.project_id,
        "A workforce",
        graceful_shutdown_timeout=3,  # shortened to 3 seconds for debugging
        share_memory=False,
@@ -481,10 +1112,13 @@ def format_agent_description(agent_data: NewAgent | ActionNewAgent) -> str:
    return " ".join(description_parts)


@traceroot.trace()
async def new_agent_model(data: NewAgent | ActionNewAgent, options: Chat):
    working_directory = options.file_save_path()
    logger.info("Creating new agent", extra={"agent_name": data.name, "project_id": options.project_id, "task_id": options.task_id})
    logger.debug("New agent data", extra={"agent_data": data.model_dump_json()})
    working_directory = get_working_directory(options)
    tool_names = []
    tools = [*await get_toolkits(data.tools, data.name, options.task_id)]
    tools = [*await get_toolkits(data.tools, data.name, options.project_id)]
    for item in data.tools:
        tool_names.append(titleize(item))
    if data.mcp_tools is not None:
@@ -492,7 +1126,8 @@ async def new_agent_model(data: NewAgent | ActionNewAgent, options: Chat):
        for item in data.mcp_tools["mcpServers"].keys():
            tool_names.append(titleize(item))
    for item in tools:
        logger.debug(f"new agent function tool ====== {item.func.__name__}")
        logger.debug(f"Agent {data.name} tool: {item.func.__name__}")
    logger.info(f"Agent {data.name} created with {len(tools)} tools: {tool_names}")
    # Enhanced system message with platform information
    enhanced_description = f"""{data.description}
- You are now working in system {platform.system()} with architecture
@@ -1,4 +1,5 @@
from typing_extensions import Any, Literal, TypedDict
from typing import List, Dict, Optional
from pydantic import BaseModel
from app.exception.exception import ProgramException
from app.model.chat import McpServers, Status, SupplementChat, Chat, UpdateData

@@ -9,13 +10,16 @@ from contextlib import contextmanager
from contextvars import ContextVar
from datetime import datetime, timedelta
import weakref
from loguru import logger
from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("task_service")


class Action(str, Enum):
    improve = "improve"  # user -> backend
    update_task = "update_task"  # user -> backend
    task_state = "task_state"  # backend -> user
    new_task_state = "new_task_state"  # backend -> user
    start = "start"  # user -> backend
    create_agent = "create_agent"  # backend -> user
    activate_agent = "activate_agent"  # backend -> user

@@ -36,6 +40,9 @@ class Action(str, Enum):
    resume = "resume"  # user -> backend, user takes control
    new_agent = "new_agent"  # user -> backend
    budget_not_enough = "budget_not_enough"  # backend -> user
    add_task = "add_task"  # user -> backend
    remove_task = "remove_task"  # user -> backend
    skip_task = "skip_task"  # user -> backend


class ActionImproveData(BaseModel):

@@ -56,6 +63,10 @@ class ActionTaskStateData(BaseModel):
    action: Literal[Action.task_state] = Action.task_state
    data: dict[Literal["task_id", "content", "state", "result", "failure_count"], str | int]

class ActionNewTaskStateData(BaseModel):
    action: Literal[Action.new_task_state] = Action.new_task_state
    data: dict[Literal["task_id", "content", "state", "result", "failure_count"], str | int]


class ActionAskData(BaseModel):
    action: Literal[Action.ask] = Action.ask

@@ -169,6 +180,26 @@ class ActionBudgetNotEnough(BaseModel):
    action: Literal[Action.budget_not_enough] = Action.budget_not_enough


class ActionAddTaskData(BaseModel):
    action: Literal[Action.add_task] = Action.add_task
    content: str
    project_id: str | None = None
    task_id: str | None = None
    additional_info: dict | None = None
    insert_position: int = -1


class ActionRemoveTaskData(BaseModel):
    action: Literal[Action.remove_task] = Action.remove_task
    task_id: str
    project_id: str


class ActionSkipTaskData(BaseModel):
    action: Literal[Action.skip_task] = Action.skip_task
    project_id: str


ActionData = (
    ActionImproveData
    | ActionStartData

@@ -192,6 +223,9 @@ ActionData = (
    | ActionTakeControl
    | ActionNewAgent
    | ActionBudgetNotEnough
    | ActionAddTaskData
    | ActionRemoveTaskData
    | ActionSkipTaskData
)
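ActionData is a plain union rather than a tagged Pydantic discriminator, so consumers branch on the `action` field; a sketch with an illustrative id:

item = ActionSkipTaskData(project_id="proj-1")  # hypothetical id
if item.action == Action.skip_task:
    ...  # resume if paused, then workforce.skip_gracefully()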
@ -221,6 +255,16 @@ class TaskLock:
|
|||
background_tasks: set[asyncio.Task]
|
||||
"""Track all background tasks for cleanup"""
|
||||
|
||||
# Context management fields
|
||||
conversation_history: List[Dict[str, Any]]
|
||||
"""Store conversation history for context"""
|
||||
last_task_result: str
|
||||
"""Store the last task execution result"""
|
||||
question_agent: Optional[Any]
|
||||
"""Persistent question confirmation agent"""
|
||||
summary_generated: bool
|
||||
"""Track if summary has been generated for this project"""
|
||||
|
||||
def __init__(self, id: str, queue: asyncio.Queue, human_input: dict) -> None:
|
||||
self.id = id
|
||||
self.queue = queue
|
||||
|
|
@ -229,6 +273,12 @@ class TaskLock:
|
|||
self.last_accessed = datetime.now()
|
||||
self.background_tasks = set()
|
||||
|
||||
# Initialize context management fields
|
||||
self.conversation_history = []
|
||||
self.last_task_result = ""
|
||||
self.last_task_summary = ""
|
||||
self.question_agent = None
|
||||
|
||||
async def put_queue(self, data: ActionData):
|
||||
self.last_accessed = datetime.now()
|
||||
await self.queue.put(data)
|
||||
|
|
@ -262,6 +312,25 @@ class TaskLock:
|
|||
pass
|
||||
self.background_tasks.clear()
|
||||
|
||||
def add_conversation(self, role: str, content: str | dict):
|
||||
"""Add a conversation entry to history"""
|
||||
self.conversation_history.append({
|
||||
'role': role,
|
||||
'content': content,
|
||||
'timestamp': datetime.now().isoformat()
|
||||
})
|
||||
|
||||
def get_recent_context(self, max_entries: int = None) -> str:
|
||||
"""Get recent conversation context as a formatted string"""
|
||||
if not self.conversation_history:
|
||||
return ""
|
||||
|
||||
context = "=== Recent Conversation ===\n"
|
||||
history_to_use = self.conversation_history if max_entries is None else self.conversation_history[-max_entries:]
|
||||
for entry in history_to_use:
|
||||
context += f"{entry['role']}: {entry['content']}\n"
|
||||
return context
|
||||
|
||||
|
||||
task_locks = dict[str, TaskLock]()
|
||||
# Cleanup task for removing stale task locks
|
||||
|
|
@ -275,6 +344,11 @@ def get_task_lock(id: str) -> TaskLock:
|
|||
return task_locks[id]
|
||||
|
||||
|
||||
def get_task_lock_if_exists(id: str) -> TaskLock | None:
|
||||
"""Get task lock if it exists, otherwise return None"""
|
||||
return task_locks.get(id)
|
||||
|
||||
|
||||
def create_task_lock(id: str) -> TaskLock:
|
||||
if id in task_locks:
|
||||
raise ProgramException("Task already exists")
|
||||
|
|
@ -288,6 +362,13 @@ def create_task_lock(id: str) -> TaskLock:
    return task_locks[id]


+def get_or_create_task_lock(id: str) -> TaskLock:
+    """Get existing task lock or create a new one if it doesn't exist"""
+    if id in task_locks:
+        return task_locks[id]
+    return create_task_lock(id)
+
+
async def delete_task_lock(id: str):
    if id not in task_locks:
        raise ProgramException("Task not found")
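The registry semantics in one runnable sketch (a plain dict stand-in for task_locks, with ProgramException redefined locally so the snippet runs on its own):

    class ProgramException(Exception):
        pass

    registry: dict[str, object] = {}

    def create(id: str) -> object:
        # Mirrors create_task_lock: creating twice is an error.
        if id in registry:
            raise ProgramException("Task already exists")
        registry[id] = object()
        return registry[id]

    def get_or_create(id: str) -> object:
        # Mirrors get_or_create_task_lock: idempotent lookup-or-create.
        return registry[id] if id in registry else create(id)

    a = get_or_create("p1")   # creates the lock
    b = get_or_create("p1")   # returns the same lock
    assert a is b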
0 backend/app/utils/__init__.py Normal file
@ -6,7 +6,7 @@ from threading import Event
import traceback
from typing import Any, Callable, Dict, List, Tuple
import uuid
-from app.utils import traceroot_wrapper as traceroot
+from utils import traceroot_wrapper as traceroot
from camel.agents import ChatAgent
from camel.agents.chat_agent import StreamingChatAgentResponse, AsyncStreamingChatAgentResponse
from camel.agents._types import ToolCallRequest

@ -18,6 +18,7 @@ from camel.terminators import ResponseTerminator
from camel.toolkits import FunctionTool, RegisteredAgentToolkit
from camel.types.agents import ToolCallingRecord
from app.component.environment import env
+from app.utils.file_utils import get_working_directory
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
from app.utils.toolkit.hybrid_browser_toolkit import HybridBrowserToolkit
from app.utils.toolkit.excel_toolkit import ExcelToolkit

@ -50,7 +51,6 @@ from camel.types import ModelPlatformType, ModelType
from camel.toolkits import MCPToolkit, ToolkitMessageIntegration
import datetime
from pydantic import BaseModel
from loguru import logger
from app.model.chat import Chat, McpServers

# Create traceroot logger for agent tracking

@ -68,6 +68,8 @@ from app.service.task import (
)
from app.service.task import set_process_task

+NOW_STR = datetime.datetime.now().strftime("%Y-%m-%d %H:00:00")


class ListenChatAgent(ChatAgent):
    @traceroot.trace()
@ -171,7 +173,6 @@ class ListenChatAgent(ChatAgent):
        except Exception as e:
            res = None
            error_info = e
            logger.exception(e)
-            traceroot_logger.error(f"Agent {self.agent_name} unexpected error in step: {e}", exc_info=True)
            message = f"Error processing message: {e!s}"
            total_tokens = 0

@ -246,8 +247,7 @@ class ListenChatAgent(ChatAgent):
        except Exception as e:
            res = None
            error_info = e
            logger.exception(e)
-            traceroot_logger.error(f"Agent {self.agent_name} unexpected error in step: {e}", exc_info=True)
            traceroot_logger.error(f"Agent {self.agent_name} unexpected error in async step: {e}", exc_info=True)
            message = f"Error processing message: {e!s}"
            total_tokens = 0

@ -323,6 +323,17 @@ class ListenChatAgent(ChatAgent):
            else:
                result = raw_result
            mask_flag = False
+            # Prepare result message with truncation
+            if isinstance(result, str):
+                result_msg = result
+            else:
+                result_str = repr(result)
+                MAX_RESULT_LENGTH = 500
+                if len(result_str) > MAX_RESULT_LENGTH:
+                    result_msg = result_str[:MAX_RESULT_LENGTH] + f"... (truncated, total length: {len(result_str)} chars)"
+                else:
+                    result_msg = result_str

            asyncio.create_task(
                task_lock.put_queue(
                    ActionDeactivateToolkitData(

@ -331,7 +342,7 @@ class ListenChatAgent(ChatAgent):
                            "process_task_id": self.process_task_id,
                            "toolkit_name": toolkit_name,
                            "method_name": func_name,
-                            "message": result if isinstance(result, str) else repr(result),
+                            "message": result_msg,
                        },
                    )
                )
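The same truncate-on-repr pattern recurs below in the async tool path and again in listen.py; a small helper capturing it (a sketch, not part of the commit; note that plain strings intentionally pass through untruncated here, exactly as in the inline logic above):

    MAX_RESULT_LENGTH = 500

    def truncate_result(result) -> str:
        # Strings pass through as-is; everything else is repr()'d and
        # clipped so queue messages stay bounded.
        if isinstance(result, str):
            return result
        text = repr(result)
        if len(text) > MAX_RESULT_LENGTH:
            return text[:MAX_RESULT_LENGTH] + f"... (truncated, total length: {len(text)} chars)"
        return text

    print(truncate_result(list(range(1000)))[:60])  # clipped repr of a long list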
@ -341,9 +352,7 @@ class ListenChatAgent(ChatAgent):
            error_msg = f"Error executing tool '{func_name}': {e!s}"
            result = f"Tool execution failed: {error_msg}"
            mask_flag = False
            logger.debug(error_msg)
-            traceroot_logger.error(f"Tool execution failed for {func_name}: {e}")
-            traceback.print_exc()
+            traceroot_logger.error(f"Tool execution failed for {func_name}: {e}", exc_info=True)

        return self._record_tool_calling(func_name, args, result, tool_call_id, mask_output=mask_flag)

@ -403,9 +412,18 @@ class ListenChatAgent(ChatAgent):
            # Capture the error message to prevent framework crash
            error_msg = f"Error executing async tool '{func_name}': {e!s}"
            result = {"error": error_msg}
            logger.warning(error_msg)
-            traceroot_logger.error(f"Async tool execution failed for {func_name}: {e}")
-            traceback.print_exc()
+            traceroot_logger.error(f"Async tool execution failed for {func_name}: {e}", exc_info=True)

+        # Prepare result message with truncation
+        if isinstance(result, str):
+            result_msg = result
+        else:
+            result_str = repr(result)
+            MAX_RESULT_LENGTH = 500
+            if len(result_str) > MAX_RESULT_LENGTH:
+                result_msg = result_str[:MAX_RESULT_LENGTH] + f"... (truncated, total length: {len(result_str)} chars)"
+            else:
+                result_msg = result_str

        await task_lock.put_queue(
            ActionDeactivateToolkitData(

@ -414,7 +432,7 @@ class ListenChatAgent(ChatAgent):
                    "process_task_id": self.process_task_id,
                    "toolkit_name": toolkit_name,
                    "method_name": func_name,
-                    "message": result if isinstance(result, str) else repr(result),
+                    "message": result_msg,
                },
            )
        )
@ -427,7 +445,7 @@ class ListenChatAgent(ChatAgent):

        # Clone tools and collect toolkits that need registration
        cloned_tools, toolkits_to_register = self._clone_tools()


        new_agent = ListenChatAgent(
            api_task_id=self.api_task_id,
            agent_name=self.agent_name,

@ -443,7 +461,6 @@ class ListenChatAgent(ChatAgent):
            response_terminators=self.response_terminators,
            scheduling_strategy=self.model_backend.scheduling_strategy.__name__,
            max_iteration=self.max_iteration,
            agent_id=self.agent_id,
            stop_event=self.stop_event,
            tool_execution_timeout=self.tool_execution_timeout,
            mask_tool_output=self.mask_tool_output,
@ -474,9 +491,9 @@ def agent_model(
    tool_names: list[str] | None = None,
    toolkits_to_register_agent: list[RegisteredAgentToolkit] | None = None,
):
-    task_lock = get_task_lock(options.task_id)
+    task_lock = get_task_lock(options.project_id)
    agent_id = str(uuid.uuid4())
-    traceroot_logger.info(f"Creating agent: {agent_name} with id: {agent_id} for task: {options.task_id}")
+    traceroot_logger.info(f"Creating agent: {agent_name} with id: {agent_id} for project: {options.project_id}")
    asyncio.create_task(
        task_lock.put_queue(
            ActionCreateAgentData(data={"agent_name": agent_name, "agent_id": agent_id, "tools": tool_names or []})

@ -484,7 +501,7 @@ def agent_model(
    )

    return ListenChatAgent(
-        options.task_id,
+        options.project_id,
        agent_name,
        system_message,
        model=ModelFactory.create(

@ -493,7 +510,7 @@ def agent_model(
            api_key=options.api_key,
            url=options.api_url,
            model_config_dict={
-                "user": str(options.task_id),
+                "user": str(options.project_id),
            }
            if options.is_cloud()
            else None,
@ -515,7 +532,7 @@ def agent_model(
def question_confirm_agent(options: Chat):
    return agent_model(
        "question_confirm_agent",
-        f"You are a highly capable agent. Your primary function is to analyze a user's request and determine the appropriate course of action. The current date is {datetime.date.today()}. For any date-related tasks, you MUST use this as the current date.",
+        f"You are a highly capable agent. Your primary function is to analyze a user's request and determine the appropriate course of action. The current date is {NOW_STR}(Accurate to the hour). For any date-related tasks, you MUST use this as the current date.",
        options,
    )
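For reference, the NOW_STR format pins minutes and seconds to zero, which is what the "(Accurate to the hour)" wording in the prompts refers to:

    import datetime

    # e.g. "2025-01-30 14:00:00"; %H:00:00 zeroes out minutes and seconds,
    # so only the hour component is real.
    NOW_STR = datetime.datetime.now().strftime("%Y-%m-%d %H:00:00")
    print(NOW_STR)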
|
@ -531,24 +548,24 @@ def task_summary_agent(options: Chat):
|
|||
|
||||
@traceroot.trace()
|
||||
async def developer_agent(options: Chat):
|
||||
working_directory = options.file_save_path()
|
||||
traceroot_logger.info(f"Creating developer agent for task: {options.task_id} in directory: {working_directory}")
|
||||
working_directory = get_working_directory(options)
|
||||
traceroot_logger.info(f"Creating developer agent for project: {options.project_id} in directory: {working_directory}")
|
||||
message_integration = ToolkitMessageIntegration(
|
||||
message_handler=HumanToolkit(options.task_id, Agents.developer_agent).send_message_to_user
|
||||
message_handler=HumanToolkit(options.project_id, Agents.developer_agent).send_message_to_user
|
||||
)
|
||||
note_toolkit = NoteTakingToolkit(
|
||||
api_task_id=options.task_id, agent_name=Agents.developer_agent, working_directory=working_directory
|
||||
api_task_id=options.project_id, agent_name=Agents.developer_agent, working_directory=working_directory
|
||||
)
|
||||
note_toolkit = message_integration.register_toolkits(note_toolkit)
|
||||
web_deploy_toolkit = WebDeployToolkit(api_task_id=options.task_id)
|
||||
web_deploy_toolkit = WebDeployToolkit(api_task_id=options.project_id)
|
||||
web_deploy_toolkit = message_integration.register_toolkits(web_deploy_toolkit)
|
||||
screenshot_toolkit = ScreenshotToolkit(options.task_id, working_directory=working_directory)
|
||||
screenshot_toolkit = ScreenshotToolkit(options.project_id, working_directory=working_directory)
|
||||
screenshot_toolkit = message_integration.register_toolkits(screenshot_toolkit)
|
||||
|
||||
terminal_toolkit = TerminalToolkit(options.task_id, Agents.document_agent, safe_mode=True, clone_current_env=False)
|
||||
terminal_toolkit = TerminalToolkit(options.project_id, Agents.document_agent, safe_mode=True, clone_current_env=False)
|
||||
terminal_toolkit = message_integration.register_toolkits(terminal_toolkit)
|
||||
tools = [
|
||||
*HumanToolkit.get_can_use_tools(options.task_id, Agents.developer_agent),
|
||||
*HumanToolkit.get_can_use_tools(options.project_id, Agents.developer_agent),
|
||||
*note_toolkit.get_tools(),
|
||||
*web_deploy_toolkit.get_tools(),
|
||||
*terminal_toolkit.get_tools(),
|
||||
|
|
@ -577,7 +594,7 @@ and generation.
|
|||
- **System**: {platform.system()} ({platform.machine()})
|
||||
- **Working Directory**: `{working_directory}`. All local file operations must
|
||||
occur here, but you can access files from any place in the file system. For all file system operations, you MUST use absolute paths to ensure precision and avoid ambiguity.
|
||||
The current date is {datetime.date.today()}. For any date-related tasks, you MUST use this as the current date.
|
||||
The current date is {NOW_STR}(Accurate to the hour). For any date-related tasks, you MUST use this as the current date.
|
||||
</operating_environment>
|
||||
|
||||
<mandatory_instructions>
|
||||
|
|
@ -702,14 +719,14 @@ these tips to maximize your effectiveness:

@traceroot.trace()
def search_agent(options: Chat):
-    working_directory = options.file_save_path()
-    traceroot_logger.info(f"Creating search agent for task: {options.task_id} in directory: {working_directory}")
+    working_directory = get_working_directory(options)
+    traceroot_logger.info(f"Creating search agent for project: {options.project_id} in directory: {working_directory}")
    message_integration = ToolkitMessageIntegration(
-        message_handler=HumanToolkit(options.task_id, Agents.search_agent).send_message_to_user
+        message_handler=HumanToolkit(options.project_id, Agents.search_agent).send_message_to_user
    )

    web_toolkit_custom = HybridBrowserToolkit(
-        options.task_id,
+        options.project_id,
        headless=False,
        browser_log_to_file=True,
        stealth=True,

@ -729,12 +746,14 @@ def search_agent(options: Chat):
        ],
    )

+    # Save reference before registering for toolkits_to_register_agent
+    web_toolkit_for_agent_registration = web_toolkit_custom
    web_toolkit_custom = message_integration.register_toolkits(web_toolkit_custom)
-    terminal_toolkit = TerminalToolkit(options.task_id, Agents.search_agent, safe_mode=True, clone_current_env=False)
+    terminal_toolkit = TerminalToolkit(options.project_id, Agents.search_agent, safe_mode=True, clone_current_env=False)
    terminal_toolkit = message_integration.register_functions([terminal_toolkit.shell_exec])
-    note_toolkit = NoteTakingToolkit(options.task_id, Agents.search_agent, working_directory=working_directory)
+    note_toolkit = NoteTakingToolkit(options.project_id, Agents.search_agent, working_directory=working_directory)
    note_toolkit = message_integration.register_toolkits(note_toolkit)
-    search_tools = SearchToolkit.get_can_use_tools(options.task_id)
+    search_tools = SearchToolkit.get_can_use_tools(options.project_id)
    # Only register search tools if any are available
    if search_tools:
        search_tools = message_integration.register_functions(search_tools)
@ -742,7 +761,7 @@ def search_agent(options: Chat):
        search_tools = []

    tools = [
-        *HumanToolkit.get_can_use_tools(options.task_id, Agents.search_agent),
+        *HumanToolkit.get_can_use_tools(options.project_id, Agents.search_agent),
        *web_toolkit_custom.get_tools(),
        *terminal_toolkit,
        *note_toolkit.get_tools(),

@ -772,7 +791,7 @@ comprehensive and well-documented information.
- **System**: {platform.system()} ({platform.machine()})
- **Working Directory**: `{working_directory}`. All local file operations must
occur here, but you can access files from any place in the file system. For all file system operations, you MUST use absolute paths to ensure precision and avoid ambiguity.
-The current date is {datetime.date.today()}. For any date-related tasks, you MUST use this as the current date.
+The current date is {NOW_STR}(Accurate to the hour). For any date-related tasks, you MUST use this as the current date.
</operating_environment>

<mandatory_instructions>
@ -793,7 +812,7 @@ The current date is {datetime.date.today()}. For any date-related tasks, you MUS
- **CRITICAL URL POLICY**: You are STRICTLY FORBIDDEN from inventing,
guessing, or constructing URLs yourself. You MUST only use URLs from
trusted sources:
-    1. URLs returned by search tools (like `search_google` or `search_exa`)
+    1. URLs returned by search tools (`search_google`)
    2. URLs found on webpages you have visited through browser tools
    3. URLs provided by the user in their request
Fabricating or guessing URLs is considered a critical error and must

@ -839,8 +858,6 @@ Your approach depends on available search tools:
    sites using `browser_type` and submit with `browser_enter`
- **Extract URLs from results**: Only use URLs that appear in the search
    results on these websites
-- **Alternative Search**: If available, use `search_exa` for additional
-    results

**Common Browser Operations (both scenarios):**
- **Navigation and Exploration**: Use `browser_visit_page` to open URLs.
@ -877,41 +894,42 @@ Your approach depends on available search tools:
            NoteTakingToolkit.toolkit_name(),
            TerminalToolkit.toolkit_name(),
        ],
+        toolkits_to_register_agent=[web_toolkit_for_agent_registration],
    )


@traceroot.trace()
async def document_agent(options: Chat):
-    working_directory = options.file_save_path()
-    traceroot_logger.info(f"Creating document agent for task: {options.task_id} in directory: {working_directory}")
+    working_directory = get_working_directory(options)
+    traceroot_logger.info(f"Creating document agent for project: {options.project_id} in directory: {working_directory}")
    message_integration = ToolkitMessageIntegration(
-        message_handler=HumanToolkit(options.task_id, Agents.task_agent).send_message_to_user
+        message_handler=HumanToolkit(options.project_id, Agents.task_agent).send_message_to_user
    )
-    file_write_toolkit = FileToolkit(options.task_id, working_directory=working_directory)
-    pptx_toolkit = PPTXToolkit(options.task_id, working_directory=working_directory)
+    file_write_toolkit = FileToolkit(options.project_id, working_directory=working_directory)
+    pptx_toolkit = PPTXToolkit(options.project_id, working_directory=working_directory)
    pptx_toolkit = message_integration.register_toolkits(pptx_toolkit)
-    mark_it_down_toolkit = MarkItDownToolkit(options.task_id)
+    mark_it_down_toolkit = MarkItDownToolkit(options.project_id)
    mark_it_down_toolkit = message_integration.register_toolkits(mark_it_down_toolkit)
-    excel_toolkit = ExcelToolkit(options.task_id, working_directory=working_directory)
+    excel_toolkit = ExcelToolkit(options.project_id, working_directory=working_directory)
    excel_toolkit = message_integration.register_toolkits(excel_toolkit)
-    note_toolkit = NoteTakingToolkit(options.task_id, Agents.document_agent, working_directory=working_directory)
+    note_toolkit = NoteTakingToolkit(options.project_id, Agents.document_agent, working_directory=working_directory)
    note_toolkit = message_integration.register_toolkits(note_toolkit)
-    terminal_toolkit = TerminalToolkit(options.task_id, Agents.document_agent, safe_mode=True, clone_current_env=False)
+    terminal_toolkit = TerminalToolkit(options.project_id, Agents.document_agent, safe_mode=True, clone_current_env=False)
    terminal_toolkit = message_integration.register_toolkits(terminal_toolkit)
    tools = [
        *file_write_toolkit.get_tools(),
        *pptx_toolkit.get_tools(),
-        *HumanToolkit.get_can_use_tools(options.task_id, Agents.document_agent),
+        *HumanToolkit.get_can_use_tools(options.project_id, Agents.document_agent),
        *mark_it_down_toolkit.get_tools(),
        *excel_toolkit.get_tools(),
        *note_toolkit.get_tools(),
        *terminal_toolkit.get_tools(),
-        *await GoogleDriveMCPToolkit.get_can_use_tools(options.task_id, options.get_bun_env()),
+        *await GoogleDriveMCPToolkit.get_can_use_tools(options.project_id, options.get_bun_env()),
    ]
-    if env("EXA_API_KEY") or options.is_cloud():
-        search_toolkit = SearchToolkit(options.task_id, Agents.document_agent).search_exa
-        search_toolkit = message_integration.register_functions([search_toolkit])
-        tools.extend(search_toolkit)
+    # if env("EXA_API_KEY") or options.is_cloud():
+    #     search_toolkit = SearchToolkit(options.project_id, Agents.document_agent).search_exa
+    #     search_toolkit = message_integration.register_functions([search_toolkit])
+    #     tools.extend(search_toolkit)
    system_message = f"""
<role>
You are a Documentation Specialist, responsible for creating, modifying, and
@ -935,7 +953,7 @@ to be embedded in your work.
- **System**: {platform.system()} ({platform.machine()})
- **Working Directory**: `{working_directory}`. All local file operations must
occur here, but you can access files from any place in the file system. For all file system operations, you MUST use absolute paths to ensure precision and avoid ambiguity.
-The current date is {datetime.date.today()}. For any date-related tasks, you MUST use this as the current date.
+The current date is {NOW_STR}(Accurate to the hour). For any date-related tasks, you MUST use this as the current date.
</operating_environment>

<mandatory_instructions>
@ -1083,32 +1101,32 @@ supported formats including advanced spreadsheet functionality.

@traceroot.trace()
def multi_modal_agent(options: Chat):
-    working_directory = options.file_save_path()
-    traceroot_logger.info(f"Creating multi-modal agent for task: {options.task_id} in directory: {working_directory}")
+    working_directory = get_working_directory(options)
+    traceroot_logger.info(f"Creating multi-modal agent for project: {options.project_id} in directory: {working_directory}")
    message_integration = ToolkitMessageIntegration(
-        message_handler=HumanToolkit(options.task_id, Agents.multi_modal_agent).send_message_to_user
+        message_handler=HumanToolkit(options.project_id, Agents.multi_modal_agent).send_message_to_user
    )
-    video_download_toolkit = VideoDownloaderToolkit(options.task_id, working_directory=working_directory)
+    video_download_toolkit = VideoDownloaderToolkit(options.project_id, working_directory=working_directory)
    video_download_toolkit = message_integration.register_toolkits(video_download_toolkit)
-    image_analysis_toolkit = ImageAnalysisToolkit(options.task_id)
+    image_analysis_toolkit = ImageAnalysisToolkit(options.project_id)
    image_analysis_toolkit = message_integration.register_toolkits(image_analysis_toolkit)

    terminal_toolkit = TerminalToolkit(
-        options.task_id, agent_name=Agents.multi_modal_agent, safe_mode=True, clone_current_env=False
+        options.project_id, agent_name=Agents.multi_modal_agent, safe_mode=True, clone_current_env=False
    )
    terminal_toolkit = message_integration.register_toolkits(terminal_toolkit)
-    note_toolkit = NoteTakingToolkit(options.task_id, Agents.multi_modal_agent, working_directory=working_directory)
+    note_toolkit = NoteTakingToolkit(options.project_id, Agents.multi_modal_agent, working_directory=working_directory)
    note_toolkit = message_integration.register_toolkits(note_toolkit)
    tools = [
        *video_download_toolkit.get_tools(),
        *image_analysis_toolkit.get_tools(),
-        *HumanToolkit.get_can_use_tools(options.task_id, Agents.multi_modal_agent),
+        *HumanToolkit.get_can_use_tools(options.project_id, Agents.multi_modal_agent),
        *terminal_toolkit.get_tools(),
        *note_toolkit.get_tools(),
    ]
    if options.is_cloud():
        open_ai_image_toolkit = OpenAIImageToolkit(  # todo check llm has this model
-            options.task_id,
+            options.project_id,
            model="dall-e-3",
            response_format="b64_json",
            size="1024x1024",
@ -1130,7 +1148,7 @@ def multi_modal_agent(options: Chat):

    if model_platform_enum == ModelPlatformType.OPENAI:
        audio_analysis_toolkit = AudioAnalysisToolkit(
-            options.task_id,
+            options.project_id,
            working_directory,
            OpenAIAudioModels(
                api_key=options.api_key,

@ -1140,10 +1158,10 @@ def multi_modal_agent(options: Chat):
        audio_analysis_toolkit = message_integration.register_toolkits(audio_analysis_toolkit)
        tools.extend(audio_analysis_toolkit.get_tools())

-    if env("EXA_API_KEY") or options.is_cloud():
-        search_toolkit = SearchToolkit(options.task_id, Agents.multi_modal_agent).search_exa
-        search_toolkit = message_integration.register_functions([search_toolkit])
-        tools.extend(search_toolkit)
+    # if env("EXA_API_KEY") or options.is_cloud():
+    #     search_toolkit = SearchToolkit(options.project_id, Agents.multi_modal_agent).search_exa
+    #     search_toolkit = message_integration.register_functions([search_toolkit])
+    #     tools.extend(search_toolkit)

    system_message = f"""
<role>
@ -1167,7 +1185,7 @@ presentations, and other documents.
- **System**: {platform.system()} ({platform.machine()})
- **Working Directory**: `{working_directory}`. All local file operations must
occur here, but you can access files from any place in the file system. For all file system operations, you MUST use absolute paths to ensure precision and avoid ambiguity.
-The current date is {datetime.date.today()}. For any date-related tasks, you MUST use this as the current date.
+The current date is {NOW_STR}(Accurate to the hour). For any date-related tasks, you MUST use this as the current date.
</operating_environment>

<mandatory_instructions>
@ -1253,27 +1271,27 @@ async def social_medium_agent(options: Chat):
    Agent to handling tasks related to social media:
    include toolkits: WhatsApp, Twitter, LinkedIn, Reddit, Notion, Slack, Discord and Google Suite.
    """
-    working_directory = options.file_save_path()
-    traceroot_logger.info(f"Creating social medium agent for task: {options.task_id} in directory: {working_directory}")
+    working_directory = get_working_directory(options)
+    traceroot_logger.info(f"Creating social medium agent for project: {options.project_id} in directory: {working_directory}")
    tools = [
-        *WhatsAppToolkit.get_can_use_tools(options.task_id),
-        *TwitterToolkit.get_can_use_tools(options.task_id),
-        *LinkedInToolkit.get_can_use_tools(options.task_id),
-        *RedditToolkit.get_can_use_tools(options.task_id),
-        *await NotionMCPToolkit.get_can_use_tools(options.task_id),
-        # *SlackToolkit.get_can_use_tools(options.task_id),
-        *await GoogleGmailMCPToolkit.get_can_use_tools(options.task_id, options.get_bun_env()),
-        *GoogleCalendarToolkit.get_can_use_tools(options.task_id),
-        *HumanToolkit.get_can_use_tools(options.task_id, Agents.social_medium_agent),
-        *TerminalToolkit(options.task_id, agent_name=Agents.social_medium_agent, clone_current_env=False).get_tools(),
+        *WhatsAppToolkit.get_can_use_tools(options.project_id),
+        *TwitterToolkit.get_can_use_tools(options.project_id),
+        *LinkedInToolkit.get_can_use_tools(options.project_id),
+        *RedditToolkit.get_can_use_tools(options.project_id),
+        *await NotionMCPToolkit.get_can_use_tools(options.project_id),
+        # *SlackToolkit.get_can_use_tools(options.project_id),
+        *await GoogleGmailMCPToolkit.get_can_use_tools(options.project_id, options.get_bun_env()),
+        *GoogleCalendarToolkit.get_can_use_tools(options.project_id),
+        *HumanToolkit.get_can_use_tools(options.project_id, Agents.social_medium_agent),
+        *TerminalToolkit(options.project_id, agent_name=Agents.social_medium_agent, clone_current_env=False).get_tools(),
        *NoteTakingToolkit(
-            options.task_id, Agents.social_medium_agent, working_directory=working_directory
+            options.project_id, Agents.social_medium_agent, working_directory=working_directory
        ).get_tools(),
-        # *DiscordToolkit(options.task_id).get_tools(),  # Not supported temporarily
-        # *GoogleSuiteToolkit(options.task_id).get_tools(),  # Not supported temporarily
+        # *DiscordToolkit(options.project_id).get_tools(),  # Not supported temporarily
+        # *GoogleSuiteToolkit(options.project_id).get_tools(),  # Not supported temporarily
    ]
-    if env("EXA_API_KEY") or options.is_cloud():
-        tools.append(FunctionTool(SearchToolkit(options.task_id, Agents.social_medium_agent).search_exa))
+    # if env("EXA_API_KEY") or options.is_cloud():
+    #     tools.append(FunctionTool(SearchToolkit(options.project_id, Agents.social_medium_agent).search_exa))
    return agent_model(
        Agents.social_medium_agent,
        BaseMessage.make_assistant_message(
@ -1290,7 +1308,7 @@ use plain text formatting instead.

- **Working Directory**: `{working_directory}`. All local file operations must
occur here, but you can access files from any place in the file system. For all file system operations, you MUST use absolute paths to ensure precision and avoid ambiguity.
-The current date is {datetime.date.today()}. For any date-related tasks, you MUST use this as the current date.
+The current date is {NOW_STR}(Accurate to the hour). For any date-related tasks, you MUST use this as the current date.

Your integrated toolkits enable you to:
@ -1369,16 +1387,16 @@ operations.
@traceroot.trace()
async def mcp_agent(options: Chat):
    traceroot_logger.info(
-        f"Creating MCP agent for task: {options.task_id} with {len(options.installed_mcp['mcpServers'])} MCP servers"
+        f"Creating MCP agent for project: {options.project_id} with {len(options.installed_mcp['mcpServers'])} MCP servers"
    )
    tools = [
-        # *HumanToolkit.get_can_use_tools(options.task_id, Agents.mcp_agent),
-        *McpSearchToolkit(options.task_id).get_tools(),
+        # *HumanToolkit.get_can_use_tools(options.project_id, Agents.mcp_agent),
+        *McpSearchToolkit(options.project_id).get_tools(),
    ]
    if len(options.installed_mcp["mcpServers"]) > 0:
        try:
            mcp_tools = await get_mcp_tools(options.installed_mcp)
-            traceroot_logger.info(f"Retrieved {len(mcp_tools)} MCP tools for task {options.task_id}")
+            traceroot_logger.info(f"Retrieved {len(mcp_tools)} MCP tools for task {options.project_id}")
            if mcp_tools:
                tool_names = [tool.get_function_name() if hasattr(tool, 'get_function_name') else str(tool) for tool in mcp_tools]
                traceroot_logger.debug(f"MCP tools: {tool_names}")

@ -1386,9 +1404,9 @@ async def mcp_agent(options: Chat):
        except Exception as e:
            traceroot_logger.debug(repr(e))

-    task_lock = get_task_lock(options.task_id)
+    task_lock = get_task_lock(options.project_id)
    agent_id = str(uuid.uuid4())
-    traceroot_logger.info(f"Creating MCP agent: {Agents.mcp_agent} with id: {agent_id} for task: {options.task_id}")
+    traceroot_logger.info(f"Creating MCP agent: {Agents.mcp_agent} with id: {agent_id} for task: {options.project_id}")
    asyncio.create_task(
        task_lock.put_queue(
            ActionCreateAgentData(

@ -1401,7 +1419,7 @@ async def mcp_agent(options: Chat):
        )
    )
    return ListenChatAgent(
-        options.task_id,
+        options.project_id,
        Agents.mcp_agent,
        system_message="You are a helpful assistant that can help users search mcp servers. The found mcp services will be returned to the user, and you will ask the user via ask_human_via_gui whether they want to install these mcp services.",
        model=ModelFactory.create(

@ -1410,7 +1428,7 @@ async def mcp_agent(options: Chat):
        api_key=options.api_key,
        url=options.api_url,
        model_config_dict={
-            "user": str(options.task_id),
+            "user": str(options.project_id),
        }
        if options.is_cloud()
        else None,
236 backend/app/utils/cookie_manager.py Normal file
@ -0,0 +1,236 @@
import sqlite3
import os
from typing import List, Dict, Optional
from utils import traceroot_wrapper as traceroot
import shutil
from datetime import datetime

logger = traceroot.get_logger("cookie_manager")


class CookieManager:
    """Manager for reading and managing browser cookies
    from Electron/Chrome SQLite database"""

    def __init__(self, user_data_dir: str):
        self.user_data_dir = user_data_dir

        # Check for cookies in partition directory first (for persist:user_login)
        partition_cookies_path = os.path.join(user_data_dir, "Partitions", "user_login", "Cookies")

        if os.path.exists(partition_cookies_path):
            self.cookies_db_path = partition_cookies_path
            logger.info(f"Using partition cookies at: {partition_cookies_path}")
        else:
            # Fallback to default location
            self.cookies_db_path = os.path.join(user_data_dir, "Cookies")

            if not os.path.exists(self.cookies_db_path):
                alt_path = os.path.join(user_data_dir, "Network", "Cookies")
                if os.path.exists(alt_path):
                    self.cookies_db_path = alt_path
                else:
                    logger.warning(f"Cookies database not found at {self.cookies_db_path} or {partition_cookies_path}")

    def _get_cookies_connection(self) -> Optional[sqlite3.Connection]:
        """Get database connection using a temporary copy to avoid locks"""
        if not os.path.exists(self.cookies_db_path):
            logger.warning(f"Cookies database not found: {self.cookies_db_path}")
            return None

        try:
            temp_db_path = self.cookies_db_path + ".tmp"
            shutil.copy2(self.cookies_db_path, temp_db_path)
            conn = sqlite3.connect(temp_db_path)
            conn.row_factory = sqlite3.Row
            return conn
        except Exception as e:
            logger.error(f"Error connecting to cookies database: {e}")
            return None

    def _cleanup_temp_db(self):
        """Clean up temporary database file"""
        temp_db_path = self.cookies_db_path + ".tmp"
        try:
            if os.path.exists(temp_db_path):
                os.remove(temp_db_path)
        except Exception as e:
            logger.debug(f"Error cleaning up temp database: {e}")

    def get_cookie_domains(self) -> List[Dict[str, any]]:
        """Get list of all domains with cookies"""
        conn = self._get_cookies_connection()
        if not conn:
            return []

        try:
            cursor = conn.cursor()
            query = """
                SELECT
                    host_key as domain,
                    COUNT(*) as cookie_count,
                    MAX(last_access_utc) as last_access
                FROM cookies
                GROUP BY host_key
                ORDER BY last_access DESC
            """
            cursor.execute(query)
            rows = cursor.fetchall()

            domains = []
            for row in rows:
                try:
                    chrome_timestamp = row['last_access']
                    if chrome_timestamp:
                        seconds_since_epoch = (chrome_timestamp / 1000000.0) - 11644473600
                        last_access = datetime.fromtimestamp(seconds_since_epoch).strftime('%Y-%m-%d %H:%M:%S')
                    else:
                        last_access = "Never"
                except Exception as e:
                    logger.debug(f"Error converting timestamp: {e}")
                    last_access = "Unknown"

                domains.append({
                    'domain': row['domain'],
                    'cookie_count': row['cookie_count'],
                    'last_access': last_access
                })

            logger.info(f"Found {len(domains)} domains with cookies")
            return domains

        except Exception as e:
            logger.error(f"Error reading cookies: {e}")
            return []
        finally:
            conn.close()
            self._cleanup_temp_db()

    def get_cookies_for_domain(self, domain: str) -> List[Dict[str, str]]:
        """Get all cookies for a specific domain"""
        conn = self._get_cookies_connection()
        if not conn:
            return []

        try:
            cursor = conn.cursor()
            query = """
                SELECT
                    host_key,
                    name,
                    value,
                    path,
                    expires_utc,
                    is_secure,
                    is_httponly
                FROM cookies
                WHERE host_key = ? OR host_key LIKE ?
                ORDER BY name
            """
            cursor.execute(query, (domain, f'%.{domain}'))
            rows = cursor.fetchall()

            cookies = []
            for row in rows:
                cookies.append({
                    'domain': row['host_key'],
                    'name': row['name'],
                    'value': row['value'][:50] + '...' if len(row['value']) > 50 else row['value'],
                    'path': row['path'],
                    'secure': bool(row['is_secure']),
                    'httponly': bool(row['is_httponly'])
                })

            return cookies

        except Exception as e:
            logger.error(f"Error reading cookies for domain {domain}: {e}")
            return []
        finally:
            conn.close()
            self._cleanup_temp_db()

    def delete_cookies_for_domain(self, domain: str) -> bool:
        """Delete all cookies for a specific domain"""
        if not os.path.exists(self.cookies_db_path):
            logger.warning(f"Cookies database not found: {self.cookies_db_path}")
            return False

        try:
            conn = sqlite3.connect(self.cookies_db_path)
            cursor = conn.cursor()
            delete_query = """
                DELETE FROM cookies
                WHERE host_key = ? OR host_key LIKE ?
            """
            cursor.execute(delete_query, (domain, f'%.{domain}'))
            deleted_count = cursor.rowcount
            conn.commit()

            # IMPORTANT: Execute VACUUM to remove deleted data and compact database
            # This prevents recovery from WAL files
            cursor.execute("VACUUM")
            conn.commit()
            conn.close()

            # Also remove WAL and SHM files to ensure clean state
            self._cleanup_wal_files()

            logger.info(f"Deleted {deleted_count} cookies for domain {domain}")
            return True

        except Exception as e:
            logger.error(f"Error deleting cookies for domain {domain}: {e}")
            return False

    def _cleanup_wal_files(self):
        """Remove SQLite WAL and SHM files"""
        try:
            wal_path = self.cookies_db_path + '-wal'
            shm_path = self.cookies_db_path + '-shm'
            journal_path = self.cookies_db_path + '-journal'

            for path in [wal_path, shm_path, journal_path]:
                if os.path.exists(path):
                    os.remove(path)
                    logger.info(f"Removed temporary file: {path}")
        except Exception as e:
            logger.warning(f"Error cleaning up WAL files: {e}")

    def delete_all_cookies(self) -> bool:
        """Delete all cookies"""
        if not os.path.exists(self.cookies_db_path):
            logger.warning(f"Cookies database not found: {self.cookies_db_path}")
            return False

        try:
            conn = sqlite3.connect(self.cookies_db_path)
            cursor = conn.cursor()
            cursor.execute("DELETE FROM cookies")
            deleted_count = cursor.rowcount
            conn.commit()

            # IMPORTANT: Execute VACUUM to remove deleted data and compact database
            # This prevents recovery from WAL files
            cursor.execute("VACUUM")
            conn.commit()
            conn.close()

            # Also remove WAL and SHM files to ensure clean state
            self._cleanup_wal_files()

            logger.info(f"Deleted all {deleted_count} cookies")
            return True

        except Exception as e:
            logger.error(f"Error deleting all cookies: {e}")
            return False

    def search_cookies(self, keyword: str) -> List[Dict[str, any]]:
        """Search cookies by domain keyword"""
        domains = self.get_cookie_domains()
        keyword_lower = keyword.lower()
        return [
            domain for domain in domains
            if keyword_lower in domain['domain'].lower()
        ]
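The timestamp math in get_cookie_domains converts Chrome/WebKit cookie times (microseconds since 1601-01-01 UTC) to the Unix epoch. The same conversion as a standalone, testable function (a sketch; the constant is the 1601-to-1970 offset in seconds):

    from datetime import datetime, timezone

    WINDOWS_TO_UNIX_EPOCH_SECONDS = 11_644_473_600  # 1601-01-01 -> 1970-01-01

    def chrome_time_to_datetime(chrome_timestamp: int) -> datetime:
        # Chrome/WebKit stores microseconds since 1601-01-01 UTC; this is
        # the same arithmetic inlined in get_cookie_domains above.
        seconds = chrome_timestamp / 1_000_000 - WINDOWS_TO_UNIX_EPOCH_SECONDS
        return datetime.fromtimestamp(seconds, tz=timezone.utc)

    print(chrome_time_to_datetime(13_380_000_000_000_000))  # a date in late 2024 (UTC)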
20 backend/app/utils/file_utils.py Normal file
@ -0,0 +1,20 @@
"""File system utilities."""

from app.component.environment import env
from app.model.chat import Chat


def get_working_directory(options: Chat, task_lock=None) -> str:
    """
    Get the correct working directory for file operations.
    First checks if there's an updated path from improve API call,
    then falls back to environment variable or default path.
    """
    if not task_lock:
        from app.service.task import get_task_lock_if_exists
        task_lock = get_task_lock_if_exists(options.project_id)

    if task_lock and hasattr(task_lock, 'new_folder_path') and task_lock.new_folder_path:
        return str(task_lock.new_folder_path)
    else:
        return env("file_save_path", options.file_save_path())
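Behavior sketch for the precedence above, with the env() lookup omitted and stand-in objects so it runs anywhere (all names here are hypothetical):

    class FakeLock:
        new_folder_path = "/tmp/eigent/improved"

    class FakeOptions:
        project_id = "p1"
        def file_save_path(self):
            return "/tmp/eigent/default"

    def working_dir(options, task_lock=None):
        # Same precedence as get_working_directory above: an updated folder
        # on the task lock wins, otherwise fall back to the default path.
        if task_lock is not None and getattr(task_lock, "new_folder_path", None):
            return str(task_lock.new_folder_path)
        return options.file_save_path()

    print(working_dir(FakeOptions()))              # /tmp/eigent/default
    print(working_dir(FakeOptions(), FakeLock()))  # /tmp/eigent/improved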
@ -1,10 +1,11 @@
import asyncio
from functools import wraps
-from inspect import iscoroutinefunction
+from inspect import iscoroutinefunction, getmembers, ismethod, signature
import json
-from typing import Any, Callable
+from typing import Any, Callable, Type, TypeVar
+import threading
+from concurrent.futures import ThreadPoolExecutor

-from loguru import logger
from app.service.task import (
    ActionActivateToolkitData,
    ActionDeactivateToolkitData,

@ -12,6 +13,41 @@ from app.service.task import (
)
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
from app.service.task import process_task
+from utils import traceroot_wrapper as traceroot
+
+logger = traceroot.get_logger("toolkit_listen")
+
+
+def _safe_put_queue(task_lock, data):
+    """Safely put data to the queue, handling both sync and async contexts"""
+    try:
+        # Try to get current event loop
+        loop = asyncio.get_running_loop()
+        # We're in an async context, create a task
+        task = asyncio.create_task(task_lock.put_queue(data))
+        if hasattr(task_lock, "add_background_task"):
+            task_lock.add_background_task(task)
+    except RuntimeError:
+        # No running event loop, we need to handle this differently
+        try:
+            # Create a new event loop in a separate thread to avoid conflicts
+            def run_in_thread():
+                try:
+                    # Create a new event loop for this thread
+                    new_loop = asyncio.new_event_loop()
+                    asyncio.set_event_loop(new_loop)
+                    try:
+                        new_loop.run_until_complete(task_lock.put_queue(data))
+                    finally:
+                        new_loop.close()
+                except Exception as e:
+                    logger.error(f"[listen_toolkit] Failed to send data in thread: {e}")
+
+            # Run in a separate thread to avoid blocking
+            thread = threading.Thread(target=run_in_thread, daemon=True)
+            thread.start()
+        except Exception as e:
+            logger.error(f"[listen_toolkit] Failed to send data to queue: {e}")
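A condensed, runnable model of that fallback logic, exercising the no-running-loop branch (the short sleep only keeps the demo process alive long enough for the daemon thread to finish):

    import asyncio
    import threading
    import time

    class MiniLock:
        def __init__(self):
            self.queue = asyncio.Queue()
        async def put_queue(self, data):
            await self.queue.put(data)

    def safe_put(lock, data):
        # Condensed version of _safe_put_queue above.
        try:
            asyncio.get_running_loop()
            asyncio.create_task(lock.put_queue(data))
        except RuntimeError:
            def run():
                loop = asyncio.new_event_loop()
                try:
                    loop.run_until_complete(lock.put_queue(data))
                finally:
                    loop.close()
            threading.Thread(target=run, daemon=True).start()

    lock = MiniLock()
    safe_put(lock, {"event": "activate"})  # no loop running -> thread branch
    time.sleep(0.2)
    print(lock.queue.qsize())              # 1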


def listen_toolkit(
@ -27,6 +63,11 @@ def listen_toolkit(
        @wraps(wrap)
        async def async_wrapper(*args, **kwargs):
            toolkit: AbstractToolkit = args[0]
+            # Check if api_task_id exists
+            if not hasattr(toolkit, 'api_task_id'):
+                logger.warning(f"[listen_toolkit] {toolkit.__class__.__name__} missing api_task_id, calling method directly")
+                return await func(*args, **kwargs)
+
            task_lock = get_task_lock(toolkit.api_task_id)

            if inputs is not None:

@ -40,19 +81,23 @@ def listen_toolkit(
                kwargs_str = ", ".join(f"{k}={v!r}" for k, v in kwargs.items())
                args_str = f"{args_str}, {kwargs_str}" if args_str else kwargs_str

+            # Truncate args_str if too long
+            MAX_ARGS_LENGTH = 500
+            if len(args_str) > MAX_ARGS_LENGTH:
+                args_str = args_str[:MAX_ARGS_LENGTH] + f"... (truncated, total length: {len(args_str)} chars)"

            toolkit_name = toolkit.toolkit_name()
            method_name = func.__name__.replace("_", " ")
-            await task_lock.put_queue(
-                ActionActivateToolkitData(
-                    data={
-                        "agent_name": toolkit.agent_name,
-                        "process_task_id": process_task.get(""),
-                        "toolkit_name": toolkit_name,
-                        "method_name": method_name,
-                        "message": args_str,
-                    },
-                )
-            )
+            activate_data = ActionActivateToolkitData(
+                data={
+                    "agent_name": toolkit.agent_name,
+                    "process_task_id": process_task.get(""),
+                    "toolkit_name": toolkit_name,
+                    "method_name": method_name,
+                    "message": args_str,
+                },
+            )
+            await task_lock.put_queue(activate_data)
            error = None
            res = None
            try:
@ -70,21 +115,26 @@ def listen_toolkit(
                    res_msg = json.dumps(res, ensure_ascii=False)
                except TypeError:
                    # Handle cases where res contains non-serializable objects (like coroutines)
-                    res_msg = str(res)
+                    res_str = str(res)
+                    # Truncate very long outputs to avoid flooding logs
+                    MAX_LENGTH = 500
+                    if len(res_str) > MAX_LENGTH:
+                        res_msg = res_str[:MAX_LENGTH] + f"... (truncated, total length: {len(res_str)} chars)"
+                    else:
+                        res_msg = res_str
            else:
                res_msg = str(error)

-            await task_lock.put_queue(
-                ActionDeactivateToolkitData(
-                    data={
-                        "agent_name": toolkit.agent_name,
-                        "process_task_id": process_task.get(""),
-                        "toolkit_name": toolkit_name,
-                        "method_name": method_name,
-                        "message": res_msg,
-                    },
-                )
-            )
+            deactivate_data = ActionDeactivateToolkitData(
+                data={
+                    "agent_name": toolkit.agent_name,
+                    "process_task_id": process_task.get(""),
+                    "toolkit_name": toolkit_name,
+                    "method_name": method_name,
+                    "message": res_msg,
+                },
+            )
+            await task_lock.put_queue(deactivate_data)
            if error is not None:
                raise error
            return res
@ -96,6 +146,11 @@ def listen_toolkit(
        @wraps(wrap)
        def sync_wrapper(*args, **kwargs):
            toolkit: AbstractToolkit = args[0]
+            # Check if api_task_id exists
+            if not hasattr(toolkit, 'api_task_id'):
+                logger.warning(f"[listen_toolkit] {toolkit.__class__.__name__} missing api_task_id, calling method directly")
+                return func(*args, **kwargs)
+
            task_lock = get_task_lock(toolkit.api_task_id)

            if inputs is not None:
@ -109,34 +164,34 @@ def listen_toolkit(
                kwargs_str = ", ".join(f"{k}={v!r}" for k, v in kwargs.items())
                args_str = f"{args_str}, {kwargs_str}" if args_str else kwargs_str

+            # Truncate args_str if too long
+            MAX_ARGS_LENGTH = 500
+            if len(args_str) > MAX_ARGS_LENGTH:
+                args_str = args_str[:MAX_ARGS_LENGTH] + f"... (truncated, total length: {len(args_str)} chars)"

            toolkit_name = toolkit.toolkit_name()
            method_name = func.__name__.replace("_", " ")
-            task = asyncio.create_task(
-                task_lock.put_queue(
-                    ActionActivateToolkitData(
-                        data={
-                            "agent_name": toolkit.agent_name,
-                            "process_task_id": process_task.get(""),
-                            "toolkit_name": toolkit_name,
-                            "method_name": method_name,
-                            "message": args_str,
-                        },
-                    )
-                )
-            )
+            activate_data = ActionActivateToolkitData(
+                data={
+                    "agent_name": toolkit.agent_name,
+                    "process_task_id": process_task.get(""),
+                    "toolkit_name": toolkit_name,
+                    "method_name": method_name,
+                    "message": args_str,
+                },
+            )
-            if hasattr(task_lock, "add_background_task"):
-                task_lock.add_background_task(task)
+            _safe_put_queue(task_lock, activate_data)
            error = None
            res = None
            try:
                logger.debug(f"Executing toolkit method: {toolkit_name}.{method_name} for agent '{toolkit.agent_name}'")
                res = func(*args, **kwargs)
-                # Safety check: if the result is a coroutine, we need to await it
+                # Safety check: if the result is a coroutine, this is a programming error
                if asyncio.iscoroutine(res):
-                    import warnings
-
-                    warnings.warn(f"Async function {func.__name__} was incorrectly called synchronously")
-                    res = asyncio.run(res)
+                    error_msg = f"Async function {func.__name__} was incorrectly called in sync context. This is a bug - the function should be marked as async or should not return a coroutine."
+                    logger.error(f"[listen_toolkit] {error_msg}")
+                    # Cannot safely await in sync context - close the coroutine to prevent warnings
+                    res.close()
+                    raise TypeError(error_msg)
            except Exception as e:
                error = e

@ -150,25 +205,26 @@ def listen_toolkit(
                    res_msg = json.dumps(res, ensure_ascii=False)
                except TypeError:
                    # Handle cases where res contains non-serializable objects (like coroutines)
-                    res_msg = str(res)
+                    res_str = str(res)
+                    # Truncate very long outputs to avoid flooding logs
+                    MAX_LENGTH = 500
+                    if len(res_str) > MAX_LENGTH:
+                        res_msg = res_str[:MAX_LENGTH] + f"... (truncated, total length: {len(res_str)} chars)"
+                    else:
+                        res_msg = res_str
            else:
                res_msg = str(error)

-            task = asyncio.create_task(
-                task_lock.put_queue(
-                    ActionDeactivateToolkitData(
-                        data={
-                            "agent_name": toolkit.agent_name,
-                            "process_task_id": process_task.get(""),
-                            "toolkit_name": toolkit_name,
-                            "method_name": method_name,
-                            "message": res_msg,
-                        },
-                    )
-                )
-            )
+            deactivate_data = ActionDeactivateToolkitData(
+                data={
+                    "agent_name": toolkit.agent_name,
+                    "process_task_id": process_task.get(""),
+                    "toolkit_name": toolkit_name,
+                    "method_name": method_name,
+                    "message": res_msg,
+                },
+            )
-            if hasattr(task_lock, "add_background_task"):
-                task_lock.add_background_task(task)
+            _safe_put_queue(task_lock, deactivate_data)
            if error is not None:
                raise error
            return res
@ -176,3 +232,87 @@ def listen_toolkit(
        return sync_wrapper

    return decorator
+
+
+T = TypeVar('T')
+
+# Methods that should not be wrapped by auto_listen_toolkit
+# These are utility/helper methods that don't perform actual tool operations
+EXCLUDED_METHODS = {
+    'get_tools',  # Tool enumeration
+    'get_can_use_tools',  # Tool filtering
+    'toolkit_name',  # Metadata getter
+    'run_mcp_server',  # MCP server initialization
+    'model_dump',  # Pydantic model serialization
+    'model_dump_json',  # Pydantic model serialization
+    'dict',  # Pydantic legacy dict method
+    'json',  # Pydantic legacy json method
+    'copy',  # Object copying
+    'update',  # Object update
+}
+
+
+def auto_listen_toolkit(base_toolkit_class: Type[T]) -> Callable[[Type[T]], Type[T]]:
+    """
+    Class decorator that automatically wraps all public methods from the base toolkit
+    with the @listen_toolkit decorator.
+
+    Excluded methods (not wrapped):
+    - get_tools, get_can_use_tools: Tool enumeration/filtering
+    - toolkit_name: Metadata getter
+    - run_mcp_server: MCP server initialization
+    - Pydantic serialization methods: model_dump, model_dump_json, dict, json
+    - Object utility methods: copy, update
+
+    These methods are typically called during initialization or for metadata,
+    and should not trigger activate/deactivate events.
+
+    Usage:
+        @auto_listen_toolkit(BaseNoteTakingToolkit)
+        class NoteTakingToolkit(BaseNoteTakingToolkit, AbstractToolkit):
+            agent_name: str = Agents.document_agent
+    """
+    def class_decorator(cls: Type[T]) -> Type[T]:
+
+        base_methods = {}
+        for name in dir(base_toolkit_class):
+            # Skip private methods and excluded helper methods
+            if not name.startswith('_') and name not in EXCLUDED_METHODS:
+                attr = getattr(base_toolkit_class, name)
+                if callable(attr):
+                    base_methods[name] = attr
+
+        for method_name, base_method in base_methods.items():
+            if method_name in cls.__dict__:
+                continue
+
+            sig = signature(base_method)
+
+            def create_wrapper(method_name: str, base_method: Callable) -> Callable:
+                # Unwrap decorators to check the actual function
+                unwrapped_method = base_method
+                while hasattr(unwrapped_method, '__wrapped__'):
+                    unwrapped_method = unwrapped_method.__wrapped__
+
+                # Check if the unwrapped method is a coroutine function
+                if iscoroutinefunction(unwrapped_method):
+                    async def async_method_wrapper(self, *args, **kwargs):
+                        return await getattr(super(cls, self), method_name)(*args, **kwargs)
+                    async_method_wrapper.__name__ = method_name
+                    async_method_wrapper.__signature__ = sig
+                    return async_method_wrapper
+                else:
+                    def sync_method_wrapper(self, *args, **kwargs):
+                        return getattr(super(cls, self), method_name)(*args, **kwargs)
+                    sync_method_wrapper.__name__ = method_name
+                    sync_method_wrapper.__signature__ = sig
+                    return sync_method_wrapper
+
+            wrapper = create_wrapper(method_name, base_method)
+            decorated_method = listen_toolkit(base_method)(wrapper)
+
+            setattr(cls, method_name, decorated_method)
+
+        return cls
+
+    return class_decorator
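The mechanics of that class decorator, reduced to a runnable toy (all names here are hypothetical; the real wrapping goes through listen_toolkit and the EXCLUDED_METHODS set above, with prints standing in for the queue events):

    from typing import Callable, Type

    def wrap_public_methods(base: Type) -> Callable[[Type], Type]:
        # Toy version of auto_listen_toolkit: copy each public callable
        # from the base class onto the subclass, wrapped with logging.
        def decorate(cls: Type) -> Type:
            for name in dir(base):
                if name.startswith("_"):
                    continue
                attr = getattr(base, name)
                if not callable(attr) or name in cls.__dict__:
                    continue
                def make(method_name: str) -> Callable:
                    def wrapper(self, *args, **kwargs):
                        print(f"activate {method_name}")
                        try:
                            # super(cls, self) reaches the original base method.
                            return getattr(super(cls, self), method_name)(*args, **kwargs)
                        finally:
                            print(f"deactivate {method_name}")
                    return wrapper
                setattr(cls, name, make(name))
            return cls
        return decorate

    class BaseNotes:
        def append_note(self, text: str) -> str:
            return f"saved: {text}"

    @wrap_public_methods(BaseNotes)
    class Notes(BaseNotes):
        pass

    print(Notes().append_note("hello"))  # activate/deactivate printed around the call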
94 backend/app/utils/oauth_state_manager.py Normal file
@ -0,0 +1,94 @@
"""
OAuth authorization state manager for background authorization flows
"""
import threading
from typing import Dict, Optional, Literal, Any
from datetime import datetime
from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("main")

AuthStatus = Literal["pending", "authorizing", "success", "failed", "cancelled"]


class OAuthState:
    """Represents the state of an OAuth authorization flow"""

    def __init__(self, provider: str):
        self.provider = provider
        self.status: AuthStatus = "pending"
        self.error: Optional[str] = None
        self.thread: Optional[threading.Thread] = None
        self.result: Optional[Any] = None
        self.started_at = datetime.now()
        self.completed_at: Optional[datetime] = None
        self._cancel_event = threading.Event()
        self.server = None  # Store the local server instance for forced shutdown

    def is_cancelled(self) -> bool:
        """Check if cancellation has been requested"""
        return self._cancel_event.is_set()

    def cancel(self):
        """Request cancellation of the authorization flow"""
        self._cancel_event.set()
        self.status = "cancelled"
        self.completed_at = datetime.now()

    def to_dict(self) -> Dict:
        """Convert state to dictionary for API response"""
        return {
            "provider": self.provider,
            "status": self.status,
            "error": self.error,
            "started_at": self.started_at.isoformat(),
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
        }


class OAuthStateManager:
    """Manager for tracking OAuth authorization flows"""

    def __init__(self):
        self._states: Dict[str, OAuthState] = {}
        self._lock = threading.Lock()

    def create_state(self, provider: str) -> OAuthState:
        """Create a new OAuth state for a provider"""
        with self._lock:
            # Cancel any existing authorization for this provider
            if provider in self._states:
                old_state = self._states[provider]
                if old_state.status in ["pending", "authorizing"]:
                    old_state.cancel()
                    logger.info(f"Cancelled previous {provider} authorization")

            state = OAuthState(provider)
            self._states[provider] = state
            return state

    def get_state(self, provider: str) -> Optional[OAuthState]:
        """Get the current state for a provider"""
        with self._lock:
            return self._states.get(provider)

    def update_status(
        self,
        provider: str,
        status: AuthStatus,
        error: Optional[str] = None,
        result: Optional[Any] = None
    ):
        """Update the status of an authorization flow"""
        with self._lock:
            if provider in self._states:
                state = self._states[provider]
                state.status = status
                state.error = error
                state.result = result
                if status in ["success", "failed", "cancelled"]:
                    state.completed_at = datetime.now()
                logger.info(f"Updated {provider} OAuth status to {status}")


# Global instance
oauth_state_manager = OAuthStateManager()
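Intended call pattern for the manager, as a short sketch (assuming the backend package is importable, e.g. run from backend/ with the repo on PYTHONPATH; the provider name and timing are illustrative):

    import threading
    import time

    from app.utils.oauth_state_manager import oauth_state_manager

    state = oauth_state_manager.create_state("notion")

    def authorize():
        oauth_state_manager.update_status("notion", "authorizing")
        time.sleep(0.1)  # stand-in for the real browser round-trip
        oauth_state_manager.update_status("notion", "success")

    state.thread = threading.Thread(target=authorize, daemon=True)
    state.thread.start()
    state.thread.join()
    print(oauth_state_manager.get_state("notion").to_dict()["status"])  # success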
@ -3,9 +3,11 @@ import httpx
import asyncio
import os
import json
-from loguru import logger
from app.service.chat_service import Chat
from app.component.environment import env
+from utils import traceroot_wrapper as traceroot
+
+logger = traceroot.get_logger("sync_step")


def sync_step(func):

@ -28,7 +30,9 @@ def sync_step(func):
                send_to_api(
                    sync_url,
                    {
-                        "task_id": chat.task_id,
+                        # TODO: revert to task_id to support multi-task project replay
+                        # "task_id": chat.task_id,
+                        "task_id": chat.project_id,
                        "step": json_data["step"],
                        "data": json_data["data"],
                    },
@@ -2,11 +2,15 @@ import datetime
from camel.agents.chat_agent import AsyncStreamingChatAgentResponse
from camel.societies.workforce.single_agent_worker import SingleAgentWorker as BaseSingleAgentWorker
from camel.tasks.task import Task, TaskState, is_task_result_insufficient
from utils import traceroot_wrapper as traceroot

from app.utils.agent import ListenChatAgent
from camel.societies.workforce.prompts import PROCESS_TASK_PROMPT
from colorama import Fore
from camel.societies.workforce.utils import TaskResult
from camel.utils.context_utils import ContextUtility

logger = traceroot.get_logger("single_agent_worker")


class SingleAgentWorker(BaseSingleAgentWorker):

@@ -19,6 +23,8 @@ class SingleAgentWorker(BaseSingleAgentWorker):
        pool_max_size: int = 10,
        auto_scale_pool: bool = True,
        use_structured_output_handler: bool = True,
        context_utility: ContextUtility | None = None,
        enable_workflow_memory: bool = False,
    ) -> None:
        super().__init__(
            description=description,

@@ -28,6 +34,8 @@ class SingleAgentWorker(BaseSingleAgentWorker):
            pool_max_size=pool_max_size,
            auto_scale_pool=auto_scale_pool,
            use_structured_output_handler=use_structured_output_handler,
            context_utility=context_utility,
            enable_workflow_memory=enable_workflow_memory,
        )
        self.worker = worker  # change type hint

@@ -54,6 +62,7 @@ class SingleAgentWorker(BaseSingleAgentWorker):
        worker_agent.process_task_id = task.id  # type: ignore rewrite line

        response_content = ""
        final_response = None
        try:
            dependency_tasks_info = self._get_dep_tasks_info(dependencies)
            prompt = PROCESS_TASK_PROMPT.format(

@@ -130,8 +139,28 @@ class SingleAgentWorker(BaseSingleAgentWorker):
            usage_info = response.info.get("usage") or response.info.get("token_usage")
            total_tokens = usage_info.get("total_tokens", 0) if usage_info else 0

            # collect conversation from working agent to
            # accumulator for workflow memory
            # Only transfer memory if workflow memory is enabled
            if self.enable_workflow_memory:
                accumulator = self._get_conversation_accumulator()

                # transfer all memory records from working agent to accumulator
                try:
                    # retrieve all context records from the working agent
                    work_records = worker_agent.memory.retrieve()

                    # write these records to the accumulator's memory
                    memory_records = [record.memory_record for record in work_records]
                    accumulator.memory.write_records(memory_records)

                    logger.debug(f"Transferred {len(memory_records)} memory records to accumulator")

                except Exception as e:
                    logger.warning(f"Failed to transfer conversation to accumulator: {e}")

        except Exception as e:
            print(f"{Fore.RED}Error processing task {task.id}: {type(e).__name__}: {e}{Fore.RESET}")
            logger.error(f"Error processing task {task.id}: {type(e).__name__}: {e}")
            # Store error information in task result
            task.result = f"{type(e).__name__}: {e!s}"
            return TaskState.FAILED

@@ -144,6 +173,8 @@ class SingleAgentWorker(BaseSingleAgentWorker):
            task.additional_info = {}

        # Create worker attempt details with descriptive keys
        # Use final_response if available (streaming), otherwise use response
        response_for_info = final_response if final_response is not None else response
        worker_attempt_details = {
            "agent_id": getattr(worker_agent, "agent_id", worker_agent.role_name),
            "original_worker_id": getattr(self.worker, "agent_id", self.worker.role_name),

@@ -154,11 +185,7 @@ class SingleAgentWorker(BaseSingleAgentWorker):
            f"{getattr(self.worker, 'agent_id', self.worker.role_name)}) "
            f"to process task: {task.content}",
            "response_content": response_content[:50],
            "tool_calls": str(
                final_response.info.get("tool_calls")
                if isinstance(response, AsyncStreamingChatAgentResponse)
                else response.info.get("tool_calls")
            )[:50],
            "tool_calls": str(response_for_info.info.get("tool_calls", []) if response_for_info and hasattr(response_for_info, 'info') else [])[:50],
            "total_tokens": total_tokens,
        }

@@ -172,9 +199,12 @@ class SingleAgentWorker(BaseSingleAgentWorker):

        print(f"======\n{Fore.GREEN}Response from {self}:{Fore.RESET}")

        logger.info(f"Response from {self}:")

        if not self.use_structured_output_handler:
            # Handle native structured output parsing
            if task_result is None:
                logger.error("Error in worker step execution: Invalid task result")
                print(f"{Fore.RED}Error in worker step execution: Invalid task result{Fore.RESET}")
                task_result = TaskResult(
                    content="Failed to generate valid task result.",

@@ -186,12 +216,17 @@ class SingleAgentWorker(BaseSingleAgentWorker):
            f"\n{color}{task_result.content}{Fore.RESET}\n======",  # type: ignore[union-attr]
        )

        if task_result.failed:  # type: ignore[union-attr]
            logger.error(f"{task_result.content}")  # type: ignore[union-attr]
        else:
            logger.info(f"{task_result.content}")  # type: ignore[union-attr]

        task.result = task_result.content  # type: ignore[union-attr]

        if task_result.failed:  # type: ignore[union-attr]
            return TaskState.FAILED

        if is_task_result_insufficient(task):
            print(f"{Fore.RED}Task {task.id}: Content validation failed - task marked as failed{Fore.RESET}")
            logger.warning(f"Task {task.id}: Content validation failed - task marked as failed")
            return TaskState.FAILED
        return TaskState.DONE

@@ -4,10 +4,11 @@ from camel.toolkits import AudioAnalysisToolkit as BaseAudioAnalysisToolkit

from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit


@auto_listen_toolkit(BaseAudioAnalysisToolkit)
class AudioAnalysisToolkit(BaseAudioAnalysisToolkit, AbstractToolkit):
    agent_name: str = Agents.multi_modal_agent


@@ -23,14 +24,3 @@ class AudioAnalysisToolkit(BaseAudioAnalysisToolkit, AbstractToolkit):
        cache_dir = env("file_save_path", os.path.expanduser("~/.eigent/tmp/"))
        super().__init__(cache_dir, transcribe_model, audio_reasoning_model, timeout)
        self.api_task_id = api_task_id

    @listen_toolkit(
        BaseAudioAnalysisToolkit.audio2text,
        lambda _, audio_path, question: f"transcribe audio from {audio_path} and ask question: {question}",
    )
    def ask_question_about_audio(self, audio_path: str, question: str) -> str:
        return super().ask_question_about_audio(audio_path, question)

    @listen_toolkit(BaseAudioAnalysisToolkit.audio2text)
    def audio2text(self, audio_path: str) -> str:
        return super().audio2text(audio_path)

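# The change above (repeated across the toolkits below) swaps per-method
# @listen_toolkit wrappers for a single class-level @auto_listen_toolkit(Base...).
# A hypothetical sketch of such a class decorator; the real one in
# app/utils/listen/toolkit_listen.py emits SSE task events, while this one
# just prints, and it skips async methods for brevity.
import functools
import inspect

def auto_listen_toolkit(base_cls):
    def decorate(cls):
        for name, func in inspect.getmembers(base_cls, inspect.isfunction):
            if name.startswith("_") or inspect.iscoroutinefunction(func):
                continue  # leave private helpers and async methods alone

            def make_wrapper(method_name, method):
                @functools.wraps(method)
                def wrapper(self, *args, **kwargs):
                    print(f"[{cls.__name__}] -> {method_name}")
                    result = method(self, *args, **kwargs)
                    print(f"[{cls.__name__}] <- {method_name}")
                    return result
                return wrapper

            setattr(cls, name, make_wrapper(name, func))
        return cls
    return decorate
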
@@ -1,10 +1,11 @@
from typing import List, Literal
from camel.toolkits import CodeExecutionToolkit as BaseCodeExecutionToolkit, FunctionTool
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit


@auto_listen_toolkit(BaseCodeExecutionToolkit)
class CodeExecutionToolkit(BaseCodeExecutionToolkit, AbstractToolkit):
    agent_name: str = Agents.developer_agent


@@ -21,18 +22,6 @@ class CodeExecutionToolkit(BaseCodeExecutionToolkit, AbstractToolkit):
        self.api_task_id = api_task_id
        super().__init__(sandbox, verbose, unsafe_mode, import_white_list, require_confirm, timeout)

    @listen_toolkit(
        BaseCodeExecutionToolkit.execute_code,
    )
    def execute_code(self, code: str, code_type: str = "python") -> str:
        return super().execute_code(code, code_type)

    @listen_toolkit(
        BaseCodeExecutionToolkit.execute_command,
    )
    def execute_command(self, command: str) -> str | tuple[str, str]:
        return super().execute_command(command)

    def get_tools(self) -> List[FunctionTool]:
        return [
            FunctionTool(self.execute_code),

@@ -1,10 +1,11 @@
from camel.toolkits import Crawl4AIToolkit as BaseCrawl4AIToolkit

from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit


@auto_listen_toolkit(BaseCrawl4AIToolkit)
class Crawl4AIToolkit(BaseCrawl4AIToolkit, AbstractToolkit):
    agent_name: str = Agents.search_agent


@@ -12,18 +13,5 @@ class Crawl4AIToolkit(BaseCrawl4AIToolkit, AbstractToolkit):
        self.api_task_id = api_task_id
        super().__init__(timeout)

    # async def _get_client(self):
    #     r"""Get or create the AsyncWebCrawler client."""
    #     if self._client is None:
    #         from crawl4ai import AsyncWebCrawler

    #         self._client = AsyncWebCrawler(use_managed_browser=True)
    #         await self._client.__aenter__()
    #     return self._client

    @listen_toolkit(BaseCrawl4AIToolkit.scrape)
    async def scrape(self, url: str) -> str:
        return await super().scrape(url)

    def toolkit_name(self) -> str:
        return "Crawl Toolkit"

@@ -3,10 +3,11 @@ from camel.toolkits import ExcelToolkit as BaseExcelToolkit

from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit


@auto_listen_toolkit(BaseExcelToolkit)
class ExcelToolkit(BaseExcelToolkit, AbstractToolkit):
    agent_name: str = Agents.document_agent


@@ -20,7 +21,3 @@ class ExcelToolkit(BaseExcelToolkit, AbstractToolkit):
        if working_directory is None:
            working_directory = env("file_save_path", os.path.expanduser("~/Downloads"))
        super().__init__(timeout=timeout, working_directory=working_directory)

    @listen_toolkit(BaseExcelToolkit.extract_excel_content)
    def extract_excel_content(self, document_path: str) -> str:
        return super().extract_excel_content(document_path)

@@ -5,10 +5,11 @@ from camel.toolkits import FileToolkit as BaseFileToolkit

from app.component.environment import env
from app.service.task import process_task
from app.service.task import ActionWriteFileData, Agents, get_task_lock
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit, listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit


@auto_listen_toolkit(BaseFileToolkit)
class FileToolkit(BaseFileToolkit, AbstractToolkit):
    agent_name: str = Agents.document_agent


@@ -54,15 +55,3 @@ class FileToolkit(BaseFileToolkit, AbstractToolkit):
            )
        )
        return res

    @listen_toolkit(
        BaseFileToolkit.read_file,
    )
    def read_file(self, file_paths: str | list[str]) -> str | dict[str, str]:
        return super().read_file(file_paths)

    @listen_toolkit(
        BaseFileToolkit.edit_file,
    )
    def edit_file(self, file_path: str, old_content: str, new_content: str) -> str:
        return super().edit_file(file_path, old_content, new_content)

@@ -3,10 +3,11 @@ from camel.toolkits import GithubToolkit as BaseGithubToolkit
from camel.toolkits.function_tool import FunctionTool
from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit


@auto_listen_toolkit(BaseGithubToolkit)
class GithubToolkit(BaseGithubToolkit, AbstractToolkit):
    agent_name: str = Agents.developer_agent


@@ -19,86 +20,6 @@ class GithubToolkit(BaseGithubToolkit, AbstractToolkit):
        super().__init__(access_token, timeout)
        self.api_task_id = api_task_id

    @listen_toolkit(
        BaseGithubToolkit.create_pull_request,
        lambda _,
        repo_name,
        file_path,
        new_content,
        pr_title,
        body,
        branch_name: f"Create PR in {repo_name} for {file_path} with title '{pr_title}', branch '{branch_name}', content '{new_content}'",
    )
    def create_pull_request(
        self,
        repo_name: str,
        file_path: str,
        new_content: str,
        pr_title: str,
        body: str,
        branch_name: str,
    ) -> str:
        return super().create_pull_request(repo_name, file_path, new_content, pr_title, body, branch_name)

    @listen_toolkit(
        BaseGithubToolkit.get_issue_list,
        lambda _, repo_name, state="all": f"Get issue list from {repo_name} with state '{state}'",
        lambda issues: f"Retrieved {len(issues)} issues",
    )
    def get_issue_list(
        self, repo_name: str, state: Literal["open", "closed", "all"] = "all"
    ) -> list[dict[str, object]]:
        return super().get_issue_list(repo_name, state)

    @listen_toolkit(
        BaseGithubToolkit.get_issue_content,
        lambda _, repo_name, issue_number: f"Get content of issue {issue_number} from {repo_name}",
    )
    def get_issue_content(self, repo_name: str, issue_number: int) -> str:
        return super().get_issue_content(repo_name, issue_number)

    @listen_toolkit(
        BaseGithubToolkit.get_pull_request_list,
        lambda _, repo_name, state="all": f"Get pull request list from {repo_name} with state '{state}'",
        lambda prs: f"Retrieved {len(prs)} pull requests",
    )
    def get_pull_request_list(
        self, repo_name: str, state: Literal["open", "closed", "all"] = "all"
    ) -> list[dict[str, object]]:
        return super().get_pull_request_list(repo_name, state)

    @listen_toolkit(
        BaseGithubToolkit.get_pull_request_code,
        lambda _, repo_name, pr_number: f"Get code for pull request {pr_number} in {repo_name}",
        lambda code: f"Retrieved {len(code)} code files",
    )
    def get_pull_request_code(self, repo_name: str, pr_number: int) -> list[dict[str, str]]:
        return super().get_pull_request_code(repo_name, pr_number)

    @listen_toolkit(
        BaseGithubToolkit.get_pull_request_comments,
        lambda _, repo_name, pr_number: f"Get comments for pull request {pr_number} in {repo_name}",
        lambda comments: f"Retrieved {len(comments)} comments",
    )
    def get_pull_request_comments(self, repo_name: str, pr_number: int) -> list[dict[str, str]]:
        return super().get_pull_request_comments(repo_name, pr_number)

    @listen_toolkit(
        BaseGithubToolkit.get_all_file_paths,
        lambda _, repo_name, path="": f"Get all file paths from {repo_name}, path '{path}'",
        lambda paths: f"Retrieved {len(paths)} file paths",
    )
    def get_all_file_paths(self, repo_name: str, path: str = "") -> list[str]:
        return super().get_all_file_paths(repo_name, path)

    @listen_toolkit(
        BaseGithubToolkit.retrieve_file_content,
        lambda _, repo_name, file_path: f"Retrieve content of file {file_path} from {repo_name}",
        lambda content: f"Retrieved content of length {len(content)}",
    )
    def retrieve_file_content(self, repo_name: str, file_path: str) -> str:
        return super().retrieve_file_content(repo_name, file_path)

    @classmethod
    def get_can_use_tools(cls, api_task_id: str) -> list[FunctionTool]:
        if env("GITHUB_ACCESS_TOKEN"):

@@ -1,59 +1,272 @@
from typing import Any, Dict, List
import os
import threading

from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
from app.utils.oauth_state_manager import oauth_state_manager
from utils import traceroot_wrapper as traceroot

from camel.toolkits import GoogleCalendarToolkit as BaseGoogleCalendarToolkit

logger = traceroot.get_logger("main")

SCOPES = ['https://www.googleapis.com/auth/calendar']


@auto_listen_toolkit(BaseGoogleCalendarToolkit)
class GoogleCalendarToolkit(BaseGoogleCalendarToolkit, AbstractToolkit):
    agent_name: str = Agents.social_medium_agent

    def __init__(self, api_task_id: str, timeout: float | None = None):
        self.api_task_id = api_task_id
        self._token_path = (
            env("GOOGLE_CALENDAR_TOKEN_PATH")
            or os.path.join(
                os.path.expanduser("~"),
                ".eigent",
                "tokens",
                "google_calendar",
                f"google_calendar_token_{api_task_id}.json",
            )
        )
        super().__init__(timeout)

    @listen_toolkit(BaseGoogleCalendarToolkit.create_event)
    def create_event(
        self,
        event_title: str,
        start_time: str,
        end_time: str,
        description: str = "",
        location: str = "",
        attendees_email: List[str] | None = None,
        timezone: str = "UTC",
    ) -> Dict[str, Any]:
        return super().create_event(event_title, start_time, end_time, description, location, attendees_email, timezone)

    @listen_toolkit(BaseGoogleCalendarToolkit.get_events)
    def get_events(self, max_results: int = 10, time_min: str | None = None) -> List[Dict[str, Any]] | Dict[str, Any]:
        return super().get_events(max_results, time_min)

    @listen_toolkit(BaseGoogleCalendarToolkit.update_event)
    def update_event(
        self,
        event_id: str,
        event_title: str | None = None,
        start_time: str | None = None,
        end_time: str | None = None,
        description: str | None = None,
        location: str | None = None,
        attendees_email: List[str] | None = None,
    ) -> Dict[str, Any]:
        return super().update_event(event_id, event_title, start_time, end_time, description, location, attendees_email)

    @listen_toolkit(BaseGoogleCalendarToolkit.delete_event)
    def delete_event(self, event_id: str) -> str:
        return super().delete_event(event_id)

    @listen_toolkit(BaseGoogleCalendarToolkit.get_calendar_details)
    def get_calendar_details(self) -> Dict[str, Any]:
        return super().get_calendar_details()

    @classmethod
    def get_can_use_tools(cls, api_task_id: str):
        if env("GOOGLE_CLIENT_ID") and env("GOOGLE_CLIENT_SECRET"):
            from dotenv import load_dotenv

            # Force reload environment variables
            default_env_path = os.path.join(os.path.expanduser("~"), ".eigent", ".env")
            if os.path.exists(default_env_path):
                load_dotenv(dotenv_path=default_env_path, override=True)

        if os.environ.get("GOOGLE_CLIENT_ID") and os.environ.get("GOOGLE_CLIENT_SECRET"):
            return cls(api_task_id).get_tools()
        else:
            return []

    def _get_calendar_service(self):
        from googleapiclient.discovery import build
        from google.auth.transport.requests import Request

        creds = self._authenticate()

        if creds and creds.expired and creds.refresh_token:
            creds.refresh(Request())
            try:
                os.makedirs(os.path.dirname(self._token_path), exist_ok=True)
                with open(self._token_path, "w") as f:
                    f.write(creds.to_json())
            except Exception:
                pass

        return build("calendar", "v3", credentials=creds)

    def _authenticate(self):
        from google.oauth2.credentials import Credentials
        from google_auth_oauthlib.flow import InstalledAppFlow
        from google.auth.transport.requests import Request
        from dotenv import load_dotenv

        # Force reload environment variables from default .env file
        default_env_path = os.path.join(os.path.expanduser("~"), ".eigent", ".env")
        if os.path.exists(default_env_path):
            load_dotenv(dotenv_path=default_env_path, override=True)

        creds = None

        # First, try to load from token file
        try:
            if os.path.exists(self._token_path):
                logger.info(f"Loading credentials from token file: {self._token_path}")
                creds = Credentials.from_authorized_user_file(self._token_path, SCOPES)
                logger.info("Successfully loaded credentials from token file")
        except Exception as e:
            logger.warning(f"Could not load from token file: {e}")
            creds = None

        # If no token file, try environment variables
        if not creds:
            client_id = os.environ.get("GOOGLE_CLIENT_ID")
            client_secret = os.environ.get("GOOGLE_CLIENT_SECRET")
            refresh_token = os.environ.get("GOOGLE_REFRESH_TOKEN")
            token_uri = os.environ.get("GOOGLE_TOKEN_URI") or "https://oauth2.googleapis.com/token"

            if refresh_token and client_id and client_secret:
                logger.info("Creating credentials from environment variables")
                creds = Credentials(
                    None,
                    refresh_token=refresh_token,
                    token_uri=token_uri,
                    client_id=client_id,
                    client_secret=client_secret,
                    scopes=SCOPES,
                )

        # If still no creds, check background authorization
        if not creds:
            state = oauth_state_manager.get_state("google_calendar")
            if state and state.status == "success" and state.result:
                logger.info("Using credentials from background authorization")
                creds = state.result
            else:
                # No credentials available
                raise ValueError("No credentials available. Please run authorization first via /api/install/tool/google_calendar")

        # Refresh if expired
        if creds and creds.expired and creds.refresh_token:
            try:
                logger.info("Token expired, refreshing...")
                creds.refresh(Request())
                logger.info("Token refreshed successfully")
            except Exception as e:
                logger.error(f"Failed to refresh token: {e}")
                raise ValueError("Failed to refresh expired token. Please re-authorize.")

        # Save credentials
        try:
            os.makedirs(os.path.dirname(self._token_path), exist_ok=True)
            with open(self._token_path, "w") as f:
                f.write(creds.to_json())
        except Exception as e:
            logger.warning(f"Could not save credentials: {e}")

        return creds

    @staticmethod
    def start_background_auth(api_task_id: str = "install_auth") -> str:
        """
        Start background OAuth authorization flow with timeout
        Returns the status of the authorization
        """
        from google_auth_oauthlib.flow import InstalledAppFlow
        from dotenv import load_dotenv

        # Force reload environment variables from default .env file
        default_env_path = os.path.join(os.path.expanduser("~"), ".eigent", ".env")
        if os.path.exists(default_env_path):
            logger.info(f"Reloading environment variables from {default_env_path}")
            load_dotenv(dotenv_path=default_env_path, override=True)

        # Check if there's an existing authorization and force stop it
        old_state = oauth_state_manager.get_state("google_calendar")
        if old_state and old_state.status in ["pending", "authorizing"]:
            logger.info("Found existing authorization, forcing shutdown...")
            old_state.cancel()
            # Try to shutdown the old server if it exists
            if hasattr(old_state, 'server') and old_state.server:
                try:
                    old_state.server.shutdown()
                    logger.info("Old server shutdown successfully")
                except Exception as e:
                    logger.warning(f"Could not shutdown old server: {e}")

        # Create new state for this authorization
        state = oauth_state_manager.create_state("google_calendar")

        def auth_flow():
            try:
                state.status = "authorizing"
                oauth_state_manager.update_status("google_calendar", "authorizing")

                # Reload environment variables in this thread
                from dotenv import load_dotenv
                default_env_path = os.path.join(os.path.expanduser("~"), ".eigent", ".env")
                if os.path.exists(default_env_path):
                    load_dotenv(dotenv_path=default_env_path, override=True)

                client_id = os.environ.get("GOOGLE_CLIENT_ID")
                client_secret = os.environ.get("GOOGLE_CLIENT_SECRET")
                token_uri = os.environ.get("GOOGLE_TOKEN_URI") or "https://oauth2.googleapis.com/token"

                logger.info(f"Google Calendar auth - client_id present: {bool(client_id)}, client_secret present: {bool(client_secret)}")

                if not client_id or not client_secret:
                    error_msg = "GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET must be set in environment variables"
                    logger.error(error_msg)
                    raise ValueError(error_msg)

                client_config = {
                    "installed": {
                        "client_id": client_id,
                        "client_secret": client_secret,
                        "auth_uri": "https://accounts.google.com/o/oauth2/auth",
                        "token_uri": token_uri,
                        "redirect_uris": ["http://localhost"],
                    }
                }
                logger.debug(f"calendar client_config initialized with client_id: {client_id[:10]}...")
                flow = InstalledAppFlow.from_client_config(client_config, SCOPES)

                # Check for cancellation before starting
                if state.is_cancelled():
                    logger.info("Authorization cancelled before starting")
                    return

                # This will automatically open browser and wait for user authorization
                logger.info("=" * 80)
                logger.info(f"[Thread {threading.current_thread().name}] Starting local server for Google Calendar authorization")
                logger.info("Browser should open automatically in a moment...")
                logger.info("=" * 80)

                # Run local server - this will block until authorization completes
                # Note: Each call uses a random port (port=0), so multiple concurrent attempts won't conflict
                try:
                    creds = flow.run_local_server(
                        port=0,
                        authorization_prompt_message="",
                        success_message="<h1>Authorization successful!</h1><p>You can close this window and return to Eigent.</p>",
                        open_browser=True
                    )
                    logger.info("Authorization flow completed successfully!")
                except Exception as server_error:
                    logger.error(f"Error during run_local_server: {server_error}")
                    raise

                # Check for cancellation after auth
                if state.is_cancelled():
                    logger.info("Authorization cancelled after completion")
                    return

                # Save credentials to token file
                token_path = os.path.join(
                    os.path.expanduser("~"),
                    ".eigent",
                    "tokens",
                    "google_calendar",
                    f"google_calendar_token_{api_task_id}.json",
                )

                try:
                    os.makedirs(os.path.dirname(token_path), exist_ok=True)
                    with open(token_path, "w") as f:
                        f.write(creds.to_json())
                    logger.info(f"Saved Google Calendar credentials to {token_path}")
                except Exception as e:
                    logger.warning(f"Could not save credentials: {e}")

                # Update state with success
                oauth_state_manager.update_status("google_calendar", "success", result=creds)
                logger.info("Google Calendar authorization successful!")

            except Exception as e:
                if state.is_cancelled():
                    logger.info("Authorization was cancelled")
                    oauth_state_manager.update_status("google_calendar", "cancelled")
                else:
                    error_msg = str(e)
                    logger.error(f"Google Calendar authorization failed: {error_msg}")
                    oauth_state_manager.update_status("google_calendar", "failed", error=error_msg)
            finally:
                # Clean up server reference
                state.server = None

        # Start authorization in background thread
        thread = threading.Thread(target=auth_flow, daemon=True, name=f"GoogleCalendar-OAuth-{state.started_at.timestamp()}")
        state.thread = thread
        thread.start()

        logger.info("Started background Google Calendar authorization")
        return "authorizing"

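# A minimal caller for the background flow above: kick off the thread, poll
# the shared state manager, then load tools on success. The timings and the
# "install_auth" id are illustrative assumptions.
import time

status = GoogleCalendarToolkit.start_background_auth(api_task_id="install_auth")
print(f"auth started: {status}")

state = None
for _ in range(300):  # wait up to ~5 minutes for the user to finish in the browser
    state = oauth_state_manager.get_state("google_calendar")
    if state and state.status in ("success", "failed", "cancelled"):
        break
    time.sleep(1)

if state and state.status == "success":
    tools = GoogleCalendarToolkit.get_can_use_tools("install_auth")
    print(f"loaded {len(tools)} calendar tools")
else:
    print(f"authorization ended with status: {state.status if state else 'unknown'}")
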
@@ -1,14 +1,16 @@
import asyncio
from camel.toolkits.base import BaseToolkit
from loguru import logger
from camel.toolkits.function_tool import FunctionTool
from app.service.task import Action, ActionAskData, ActionNoticeData, get_task_lock
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit, listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
from app.service.task import process_task
# Rewrite HumanToolkit because the system's user interaction was using console, but in electron we cannot use console. Changed to use SSE response to let frontend show dialog for user interaction
from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("human_toolkit")


@auto_listen_toolkit(BaseToolkit)
class HumanToolkit(BaseToolkit, AbstractToolkit):
    r"""A class representing a toolkit for human interaction.
    Note:

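# The comment above motivates the rewrite: replace blocking console input
# with a dialog round trip. An illustrative, self-contained model of that
# shape using asyncio queues; the real toolkit enqueues an ActionAskData on
# the task lock and the Electron frontend replies over SSE, which is not
# reproduced here.
import asyncio

class HumanBridge:
    def __init__(self) -> None:
        self._questions: asyncio.Queue[str] = asyncio.Queue()  # to the UI
        self._answers: asyncio.Queue[str] = asyncio.Queue()    # from the UI

    async def ask_human(self, question: str) -> str:
        await self._questions.put(question)  # frontend renders a dialog
        return await self._answers.get()     # suspend until the user replies

    async def frontend_stub(self) -> None:
        question = await self._questions.get()
        await self._answers.put(f"auto-reply to: {question}")

async def main() -> None:
    bridge = HumanBridge()
    ui = asyncio.create_task(bridge.frontend_stub())
    print(await bridge.ask_human("Proceed with the download?"))
    await ui

asyncio.run(main())
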
@@ -12,12 +12,14 @@ from camel.toolkits.hybrid_browser_toolkit_py.actions import ActionExecutor
from camel.toolkits.hybrid_browser_toolkit_py.snapshot import PageSnapshot
from camel.toolkits.hybrid_browser_toolkit_py.agent import PlaywrightLLMAgent
from camel.toolkits.function_tool import FunctionTool
from loguru import logger
from app.component.environment import env
from app.exception.exception import ProgramException
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit, listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("hybrid_browser_python_toolkit")


class BrowserSession(BaseHybridBrowserSession):

@@ -124,6 +126,7 @@ class BrowserSession(BaseHybridBrowserSession):
                break


@auto_listen_toolkit(BaseHybridBrowserToolkit)
class HybridBrowserPythonToolkit(BaseHybridBrowserToolkit, AbstractToolkit):
    agent_name: str = Agents.search_agent


@@ -224,14 +227,6 @@ class HybridBrowserPythonToolkit(BaseHybridBrowserToolkit, AbstractToolkit):
        self._agent: PlaywrightLLMAgent | None = None
        self._unified_script = self._load_unified_analyzer()

    @listen_toolkit(BaseHybridBrowserToolkit.browser_open)
    async def browser_open(self) -> Dict[str, str]:
        return await super().browser_open()

    @listen_toolkit(BaseHybridBrowserToolkit.browser_close)
    async def browser_close(self) -> str:
        return await super().browser_close()

    @listen_toolkit(BaseHybridBrowserToolkit.browser_visit_page, lambda _, url: url)
    async def browser_visit_page(self, url: str) -> Dict[str, Any]:
        r"""Navigates to a URL.

@@ -282,66 +277,6 @@ class HybridBrowserPythonToolkit(BaseHybridBrowserToolkit, AbstractToolkit):

        return {"result": nav_result, "snapshot": snapshot, **tab_info}

    @listen_toolkit(BaseHybridBrowserToolkit.browser_back)
    async def browser_back(self) -> Dict[str, Any]:
        return await super().browser_back()

    @listen_toolkit(BaseHybridBrowserToolkit.browser_forward)
    async def browser_forward(self) -> Dict[str, Any]:
        return await super().browser_forward()

    @listen_toolkit(BaseHybridBrowserToolkit.browser_click)
    async def browser_click(self, *, ref: str) -> Dict[str, Any]:
        return await super().browser_click(ref=ref)

    @listen_toolkit(BaseHybridBrowserToolkit.browser_type)
    async def browser_type(self, *, ref: str, text: str) -> Dict[str, Any]:
        return await super().browser_type(ref=ref, text=text)

    @listen_toolkit(BaseHybridBrowserToolkit.browser_switch_tab)
    async def browser_switch_tab(self, *, tab_id: str) -> Dict[str, Any]:
        return await super().browser_switch_tab(tab_id=tab_id)

    @listen_toolkit(BaseHybridBrowserToolkit.browser_select)
    async def browser_select(self, *, ref: str, value: str) -> Dict[str, str]:
        return await super().browser_select(ref=ref, value=value)

    @listen_toolkit(BaseHybridBrowserToolkit.browser_scroll)
    async def browser_scroll(self, *, direction: str, amount: int) -> Dict[str, str]:
        return await super().browser_scroll(direction=direction, amount=amount)

    @listen_toolkit(BaseHybridBrowserToolkit.browser_wait_user)
    async def browser_wait_user(self, timeout_sec: float | None = None) -> Dict[str, str]:
        return await super().browser_wait_user(timeout_sec)

    @listen_toolkit(BaseHybridBrowserToolkit.browser_enter)
    async def browser_enter(self) -> Dict[str, str]:
        return await super().browser_enter()

    @listen_toolkit(BaseHybridBrowserToolkit.browser_solve_task)
    async def browser_solve_task(self, task_prompt: str, start_url: str, max_steps: int = 15) -> str:
        return await super().browser_solve_task(task_prompt, start_url, max_steps)

    @listen_toolkit(BaseHybridBrowserToolkit.browser_get_page_snapshot)
    async def browser_get_page_snapshot(self) -> str:
        return await super().browser_get_page_snapshot()

    @listen_toolkit(BaseHybridBrowserToolkit.browser_get_som_screenshot)
    async def browser_get_som_screenshot(self):
        return await super().browser_get_som_screenshot()

    @listen_toolkit(BaseHybridBrowserToolkit.browser_get_page_links)
    async def browser_get_page_links(self, *, ref: List[str]) -> Dict[str, Any]:
        return await super().browser_get_page_links(ref=ref)

    @listen_toolkit(BaseHybridBrowserToolkit.browser_close_tab)
    async def browser_close_tab(self, *, tab_id: str) -> Dict[str, Any]:
        return await super().browser_close_tab(tab_id=tab_id)

    @listen_toolkit(BaseHybridBrowserToolkit.browser_get_tab_info)
    async def browser_get_tab_info(self) -> Dict[str, Any]:
        return await super().browser_get_tab_info()

    @classmethod
    def get_can_use_tools(cls, api_task_id: str) -> list[FunctionTool]:
        browser = HybridBrowserPythonToolkit(

@@ -4,7 +4,6 @@ import time
import asyncio
import json
from typing import Any, Dict, List, Optional
from loguru import logger
import websockets
import websockets.exceptions


@@ -16,8 +15,11 @@ from camel.toolkits.hybrid_browser_toolkit.ws_wrapper import WebSocketBrowserWra
from app.component.command import bun, uv
from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("hybrid_browser_toolkit")


class WebSocketBrowserWrapper(BaseWebSocketBrowserWrapper):

@@ -45,8 +47,13 @@ class WebSocketBrowserWrapper(BaseWebSocketBrowserWrapper):
                    future.set_result(response)
                    logger.debug(f"Processed response for message {message_id}")
                else:
                    # Log unexpected messages
                    logger.warning(f"Received unexpected message: {response}")
                    message_summary = {
                        "id": response.get("id"),
                        "success": response.get("success"),
                        "has_result": "result" in response,
                        "result_type": type(response.get("result")).__name__ if "result" in response else None
                    }
                    logger.debug(f"Received unexpected message: {message_summary}")

        except asyncio.CancelledError:
            disconnect_reason = "Receive loop cancelled"

@@ -210,6 +217,7 @@ class WebSocketConnectionPool:
websocket_connection_pool = WebSocketConnectionPool()


@auto_listen_toolkit(BaseHybridBrowserToolkit)
class HybridBrowserToolkit(BaseHybridBrowserToolkit, AbstractToolkit):
    agent_name: str = Agents.search_agent


@@ -240,7 +248,22 @@ class HybridBrowserToolkit(BaseHybridBrowserToolkit, AbstractToolkit):
        cdp_keep_current_page: bool = False,
        full_visual_mode: bool = False,
    ) -> None:
        logger.info(f"[HybridBrowserToolkit] Initializing with api_task_id: {api_task_id}")
        self.api_task_id = api_task_id
        logger.debug(f"[HybridBrowserToolkit] api_task_id set to: {self.api_task_id}")

        # Set default user_data_dir if not provided
        if user_data_dir is None:
            # Use browser port to determine profile directory
            browser_port = env('browser_port', '9222')
            user_data_base = os.path.expanduser("~/.eigent/browser_profiles")
            user_data_dir = os.path.join(user_data_base, f"profile_{browser_port}")
            os.makedirs(user_data_dir, exist_ok=True)
            logger.info(f"[HybridBrowserToolkit] Using port-based user_data_dir: {user_data_dir} (port: {browser_port})")
        else:
            logger.info(f"[HybridBrowserToolkit] Using provided user_data_dir: {user_data_dir}")

        logger.debug(f"[HybridBrowserToolkit] Calling super().__init__ with session_id: {session_id}")
        super().__init__(
            headless=headless,
            user_data_dir=user_data_dir,

@@ -264,16 +287,24 @@ class HybridBrowserToolkit(BaseHybridBrowserToolkit, AbstractToolkit):
            cdp_keep_current_page=cdp_keep_current_page,
            full_visual_mode=full_visual_mode,
        )
        logger.info(f"[HybridBrowserToolkit] Initialization complete for api_task_id: {self.api_task_id}")

    async def _ensure_ws_wrapper(self):
        """Ensure WebSocket wrapper is initialized using connection pool."""
        logger.debug(f"[HybridBrowserToolkit] _ensure_ws_wrapper called for api_task_id: {getattr(self, 'api_task_id', 'NOT SET')}")
        global websocket_connection_pool

        # Get session ID from config or use default
        session_id = self._ws_config.get("session_id", "default")
        logger.debug(f"[HybridBrowserToolkit] Using session_id: {session_id}")

        # Log when connecting to browser
        cdp_url = self._ws_config.get("cdp_url", f"http://localhost:{env('browser_port', '9222')}")
        logger.info(f"[PROJECT BROWSER] Connecting to browser via CDP at {cdp_url}")

        # Get or create connection from pool
        self._ws_wrapper = await websocket_connection_pool.get_connection(session_id, self._ws_config)
        logger.info(f"[HybridBrowserToolkit] WebSocket wrapper initialized for session: {session_id}")

        # Additional health check
        if self._ws_wrapper.websocket is None:

@@ -287,10 +318,16 @@ class HybridBrowserToolkit(BaseHybridBrowserToolkit, AbstractToolkit):
        if new_session_id is None:
            new_session_id = str(uuid.uuid4())[:8]

        # For cloned sessions, use the same user_data_dir to share login state
        # This allows multiple agents to use the same browser profile without conflicts
        logger.info(f"Cloning session {new_session_id} with shared user_data_dir: {self._user_data_dir}")

        # Use the same session_id to share the same browser instance
        # This ensures all clones use the same WebSocket connection and browser
        return HybridBrowserToolkit(
            self.api_task_id,
            headless=self._headless,
            user_data_dir=self._user_data_dir,
            user_data_dir=self._user_data_dir,  # Use the same user_data_dir
            stealth=self._stealth,
            web_agent_model=self._web_agent_model,
            cache_dir=f"{self._cache_dir.rstrip('/')}/_clone_{new_session_id}/",

@@ -336,74 +373,3 @@ class HybridBrowserToolkit(BaseHybridBrowserToolkit, AbstractToolkit):
        if hasattr(self, "_ws_wrapper") and self._ws_wrapper:
            session_id = self._ws_config.get("session_id", "default")
            logger.debug(f"HybridBrowserToolkit for session {session_id} is being garbage collected")

    @listen_toolkit(BaseHybridBrowserToolkit.browser_open)
    async def browser_open(self) -> Dict[str, Any]:
        return await super().browser_open()

    @listen_toolkit(BaseHybridBrowserToolkit.browser_close)
    async def browser_close(self) -> str:
        return await super().browser_close()

    @listen_toolkit(BaseHybridBrowserToolkit.browser_visit_page)
    async def browser_visit_page(self, url: str) -> Dict[str, Any]:
        logger.debug(f"browser_visit_page called with URL: {url}")
        try:
            result = await super().browser_visit_page(url)
            logger.debug(f"browser_visit_page succeeded for URL: {url}")
            return result
        except Exception as e:
            logger.error(f"browser_visit_page failed for URL {url}: {type(e).__name__}: {e}")
            raise

    @listen_toolkit(BaseHybridBrowserToolkit.browser_back)
    async def browser_back(self) -> Dict[str, Any]:
        return await super().browser_back()

    @listen_toolkit(BaseHybridBrowserToolkit.browser_forward)
    async def browser_forward(self) -> Dict[str, Any]:
        return await super().browser_forward()

    @listen_toolkit(BaseHybridBrowserToolkit.browser_get_page_snapshot)
    async def browser_get_page_snapshot(self) -> str:
        return await super().browser_get_page_snapshot()

    @listen_toolkit(BaseHybridBrowserToolkit.browser_get_som_screenshot)
    async def browser_get_som_screenshot(self, read_image: bool = False, instruction: str | None = None) -> str:
        return await super().browser_get_som_screenshot(read_image, instruction)

    @listen_toolkit(BaseHybridBrowserToolkit.browser_click)
    async def browser_click(self, *, ref: str) -> Dict[str, Any]:
        return await super().browser_click(ref=ref)

    @listen_toolkit(BaseHybridBrowserToolkit.browser_type)
    async def browser_type(self, *, ref: str, text: str) -> Dict[str, Any]:
        return await super().browser_type(ref=ref, text=text)

    @listen_toolkit(BaseHybridBrowserToolkit.browser_select)
    async def browser_select(self, *, ref: str, value: str) -> Dict[str, Any]:
        return await super().browser_select(ref=ref, value=value)

    @listen_toolkit(BaseHybridBrowserToolkit.browser_scroll)
    async def browser_scroll(self, *, direction: str, amount: int = 500) -> Dict[str, Any]:
        return await super().browser_scroll(direction=direction, amount=amount)

    @listen_toolkit(BaseHybridBrowserToolkit.browser_enter)
    async def browser_enter(self) -> Dict[str, Any]:
        return await super().browser_enter()

    @listen_toolkit(BaseHybridBrowserToolkit.browser_wait_user)
    async def browser_wait_user(self, timeout_sec: float | None = None) -> Dict[str, Any]:
        return await super().browser_wait_user(timeout_sec)

    @listen_toolkit(BaseHybridBrowserToolkit.browser_switch_tab)
    async def browser_switch_tab(self, *, tab_id: str) -> Dict[str, Any]:
        return await super().browser_switch_tab(tab_id=tab_id)

    @listen_toolkit(BaseHybridBrowserToolkit.browser_close_tab)
    async def browser_close_tab(self, *, tab_id: str) -> Dict[str, Any]:
        return await super().browser_close_tab(tab_id=tab_id)

    @listen_toolkit(BaseHybridBrowserToolkit.browser_get_tab_info)
    async def browser_get_tab_info(self) -> Dict[str, Any]:
        return await super().browser_get_tab_info()

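# A minimal sketch of the session-keyed pool that _ensure_ws_wrapper above
# relies on, assuming only an async connect(config) factory; the project's
# WebSocketConnectionPool adds health checks and cleanup omitted here.
import asyncio

class ConnectionPool:
    def __init__(self, connect):
        self._connect = connect
        self._connections: dict[str, object] = {}
        self._lock = asyncio.Lock()

    async def get_connection(self, session_id: str, config: dict):
        async with self._lock:  # serialize connection creation per pool
            conn = self._connections.get(session_id)
            if conn is None:
                conn = await self._connect(config)
                self._connections[session_id] = conn
            return conn

async def demo():
    async def connect(config):
        await asyncio.sleep(0)  # stand-in for a WebSocket handshake
        return object()

    pool = ConnectionPool(connect)
    a = await pool.get_connection("default", {})
    b = await pool.get_connection("default", {})
    assert a is b  # same session id -> one shared connection

asyncio.run(demo())
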
@@ -2,10 +2,11 @@ from camel.models import BaseModelBackend
from camel.toolkits import ImageAnalysisToolkit as BaseImageAnalysisToolkit

from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit


@auto_listen_toolkit(BaseImageAnalysisToolkit)
class ImageAnalysisToolkit(BaseImageAnalysisToolkit, AbstractToolkit):
    agent_name: str = Agents.multi_modal_agent


@@ -17,24 +18,3 @@ class ImageAnalysisToolkit(BaseImageAnalysisToolkit, AbstractToolkit):
    ):
        super().__init__(model, timeout)
        self.api_task_id = api_task_id

    @listen_toolkit(
        BaseImageAnalysisToolkit.image_to_text,
        lambda _,
        image_path,
        sys_prompt: f"transcribe image from {image_path} and ask sys_prompt: {sys_prompt}",
    )
    def image_to_text(self, image_path: str, sys_prompt: str | None = None) -> str:
        return super().image_to_text(image_path, sys_prompt)

    @listen_toolkit(
        BaseImageAnalysisToolkit.ask_question_about_image,
        lambda _,
        image_path,
        question,
        sys_prompt: f"transcribe image from {image_path} and ask question: {question} with sys_prompt: {sys_prompt}",
    )
    def ask_question_about_image(
        self, image_path: str, question: str, sys_prompt: str | None = None
    ) -> str:
        return super().ask_question_about_image(image_path, question, sys_prompt)

@@ -2,10 +2,11 @@ from camel.toolkits import LinkedInToolkit as BaseLinkedInToolkit
from camel.toolkits.function_tool import FunctionTool
from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit


@auto_listen_toolkit(BaseLinkedInToolkit)
class LinkedInToolkit(BaseLinkedInToolkit, AbstractToolkit):
    agent_name: str = Agents.social_medium_agent


@@ -13,27 +14,6 @@ class LinkedInToolkit(BaseLinkedInToolkit, AbstractToolkit):
        super().__init__(timeout)
        self.api_task_id = api_task_id

    @listen_toolkit(
        BaseLinkedInToolkit.create_post,
        lambda _, text: f"create a LinkedIn post with text: {text}",
    )
    def create_post(self, text: str) -> dict:
        return super().create_post(text)

    @listen_toolkit(
        BaseLinkedInToolkit.delete_post,
        lambda _, post_id: f"delete LinkedIn post with id: {post_id}",
    )
    def delete_post(self, post_id: str) -> str:
        return super().delete_post(post_id)

    @listen_toolkit(
        BaseLinkedInToolkit.get_profile,
        lambda _, include_id: f"get LinkedIn profile with include_id: {include_id}",
    )
    def get_profile(self, include_id: bool = False) -> dict:
        return super().get_profile(include_id)

    @classmethod
    def get_can_use_tools(cls, api_task_id: str) -> list[FunctionTool]:
        if env("LINKEDIN_ACCESS_TOKEN"):

@@ -2,17 +2,14 @@ from typing import Dict, List
from camel.toolkits import MarkItDownToolkit as BaseMarkItDownToolkit

from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit


@auto_listen_toolkit(BaseMarkItDownToolkit)
class MarkItDownToolkit(BaseMarkItDownToolkit, AbstractToolkit):
    agent_name: str = Agents.document_agent

    def __init__(self, api_task_id: str, timeout: float | None = None):
        self.api_task_id = api_task_id
        super().__init__(timeout)

    @listen_toolkit(BaseMarkItDownToolkit.read_files)
    def read_files(self, file_paths: List[str]) -> Dict[str, str]:
        return super().read_files(file_paths)

@@ -5,10 +5,11 @@ from typing import Optional

from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit


@auto_listen_toolkit(BaseNoteTakingToolkit)
class NoteTakingToolkit(BaseNoteTakingToolkit, AbstractToolkit):
    agent_name: str = Agents.document_agent


@@ -25,19 +26,3 @@ class NoteTakingToolkit(BaseNoteTakingToolkit, AbstractToolkit):
        if working_directory is None:
            working_directory = env("file_save_path", os.path.expanduser("~/.eigent/notes")) + "/note.md"
        super().__init__(working_directory=working_directory, timeout=timeout)

    @listen_toolkit(BaseNoteTakingToolkit.append_note)
    def append_note(self, note_name: str, content: str) -> str:
        return super().append_note(note_name=note_name, content=content)

    @listen_toolkit(BaseNoteTakingToolkit.read_note)
    def read_note(self, note_name: Optional[str] = "all_notes") -> str:
        return super().read_note(note_name=note_name)

    @listen_toolkit(BaseNoteTakingToolkit.create_note)
    def create_note(self, note_name: str, content: str, overwrite: bool = False) -> str:
        return super().create_note(note_name=note_name, content=content, overwrite=overwrite)

    @listen_toolkit(BaseNoteTakingToolkit.list_note)
    def list_note(self) -> str:
        return super().list_note()

@ -1,11 +1,36 @@
|
|||
import os
|
||||
import json
|
||||
import asyncio
|
||||
from textwrap import indent
|
||||
from typing import Any, Dict, List
|
||||
from loguru import logger
|
||||
from camel.toolkits import FunctionTool
|
||||
from app.component.environment import env
|
||||
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
|
||||
from camel.toolkits.mcp_toolkit import MCPToolkit
|
||||
from utils import traceroot_wrapper as traceroot
|
||||
|
||||
logger = traceroot.get_logger("notion_mcp_toolkit")
|
||||
|
||||
def _customize_function_parameters(schema: Dict[str, Any]) -> None:
|
||||
r"""Customize function parameters for specific functions.
|
||||
|
||||
This method allows modifying parameter descriptions or other schema
|
||||
attributes for specific functions.
|
||||
"""
|
||||
function_info = schema.get("function", {})
|
||||
function_name = function_info.get("name", "")
|
||||
parameters = function_info.get("parameters", {})
|
||||
properties = parameters.get("properties", {})
|
||||
required = parameters.get("required", [])
|
||||
|
||||
help_description = "If you need use parent, you can use `notion-search` for the information"
|
||||
# Modify the notion-create-pages function to make parent optional
|
||||
if function_name == "notion-create-pages" or function_name == "notion-create-database":
|
||||
required.remove("parent")
|
||||
parameters["required"] = required
|
||||
if "parent" in properties:
|
||||
# Update the parent parameter description
|
||||
properties["parent"]["description"] = "Optional. " + properties["parent"]["description"] + help_description
|
||||
|
||||
class NotionMCPToolkit(MCPToolkit, AbstractToolkit):
|
||||
|
||||
|
|
@ -33,80 +58,57 @@ class NotionMCPToolkit(MCPToolkit, AbstractToolkit):
|
|||
}
|
||||
}
|
||||
}
|
||||
super().__init__(config_dict=config_dict, timeout=timeout)
|
||||
|
||||
def get_tools(self) -> List[FunctionTool]:
|
||||
r"""Returns a list of tools provided by the NotionMCPToolkit.
|
||||
|
||||
Returns:
|
||||
List[FunctionTool]: List of available tools.
|
||||
"""
|
||||
all_tools = []
|
||||
for client in self.clients:
|
||||
try:
|
||||
original_build_schema = client._build_tool_schema
|
||||
|
||||
def create_wrapper(orig_func):
|
||||
def wrapper(mcp_tool):
|
||||
return self._build_custom_tool_schema(
|
||||
mcp_tool, orig_func
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
client._build_tool_schema = create_wrapper( # type: ignore[method-assign]
|
||||
original_build_schema
|
||||
)
|
||||
|
||||
                client_tools = client.get_tools()
                all_tools.extend(client_tools)

                client._build_tool_schema = original_build_schema  # type: ignore[method-assign]

            except Exception as e:
                logger.error(f"Failed to get tools from client: {e}")

        return all_tools

    def _build_custom_tool_schema(self, mcp_tool, original_build_schema):
        r"""Build tool schema with custom modifications."""
        schema = original_build_schema(mcp_tool)
        self._customize_function_parameters(schema)
        return schema

    def _customize_function_parameters(self, schema: Dict[str, Any]) -> None:
        r"""Customize function parameters for specific functions.

        This method allows modifying parameter descriptions or other schema
        attributes for specific functions.
        """
        function_info = schema.get("function", {})
        function_name = function_info.get("name", "")
        parameters = function_info.get("parameters", {})
        properties = parameters.get("properties", {})

        # Modify the notion-create-pages function to make parent optional
        if function_name == "notion-create-pages":
            if "parent" in properties:
                # Update the parent parameter description
                properties["parent"]["description"] = (
                    "Optional. The parent under which the new pages will be created. "
                    "This can be a page (page_id), a database page (database_id), or "
                    "a data source/collection under a database (data_source_id). "
                    "If omitted, the new pages will be created as private pages at the workspace level. "
                    "Use data_source_id when you have a collection:// URL from the fetch tool."
                )

        super().__init__(config_dict=config_dict, timeout=timeout)

    @classmethod
    async def get_can_use_tools(cls, api_task_id: str) -> list[FunctionTool]:
        tools = []
        toolkit = cls(api_task_id)
        try:
            await toolkit.connect()
            # Use subclass implementation that inlines upstream processing
            all_tools = toolkit.get_tools()
            for item in all_tools:
                setattr(item, "_toolkit_name", cls.__name__)
                tools.append(item)
        except Exception as e:
            print(f"Warning: Could not connect to Notion MCP server: {e}")
        return tools

        # Retry mechanism for remote MCP connection
        max_retries = 3
        retry_delay = 2  # seconds

        for attempt in range(max_retries):
            tools = []
            toolkit = None

            try:
                # Create a fresh toolkit instance for each retry
                toolkit = cls(api_task_id)
                logger.info(f"Attempting to connect to Notion MCP server (attempt {attempt + 1}/{max_retries})")

                await toolkit.connect()

                # Get tools from the connected toolkit
                all_tools = toolkit.get_tools()
                tool_schema = [
                    item.get_openai_tool_schema() for item in all_tools
                ]

                # Adjust tool schema
                for item in tool_schema:
                    _customize_function_parameters(item)

                for item in all_tools:
                    setattr(item, "_toolkit_name", cls.__name__)
                    tools.append(item)

                # Check if we actually got tools
                if len(tools) == 0:
                    logger.warning(f"Connected to Notion MCP server but got 0 tools (attempt {attempt + 1}/{max_retries})")
                    raise Exception("No tools retrieved from Notion MCP server")

                # Success! Got tools
                logger.info(f"Successfully connected to Notion MCP server and loaded {len(tools)} tools")

                return tools

            except Exception as e:
                logger.warning(f"Failed to connect to Notion MCP server (attempt {attempt + 1}/{max_retries}): {e}")

                # If not the last attempt, wait and retry
                if attempt < max_retries - 1:
                    logger.info(f"Retrying in {retry_delay} seconds...")
                    await asyncio.sleep(retry_delay)
                else:
                    # Last attempt failed
                    logger.error(f"All {max_retries} connection attempts to Notion MCP server failed. Notion tools will not be available for this task.")
                    return []

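The retry loop above is written inline for the Notion MCP toolkit. The same connect-retry-validate shape could be factored into a shared helper; a minimal sketch under the assumption of a toolkit exposing an async connect() and a get_tools() method (the helper name and signature are illustrative, not part of this commit):

import asyncio
import logging

logger = logging.getLogger(__name__)


async def connect_with_retry(factory, max_retries: int = 3, retry_delay: float = 2.0):
    """Build and connect a toolkit, retrying on failure or empty tool lists.

    `factory` is any zero-argument callable returning an unconnected toolkit.
    Returns the tool list on success, or an empty list after exhausting retries.
    """
    for attempt in range(1, max_retries + 1):
        toolkit = factory()  # fresh instance per attempt
        try:
            await toolkit.connect()
            tools = toolkit.get_tools()
            if not tools:
                raise RuntimeError("connected but received 0 tools")
            return tools
        except Exception as exc:
            logger.warning("attempt %d/%d failed: %s", attempt, max_retries, exc)
            if attempt < max_retries:
                await asyncio.sleep(retry_delay)
    return []
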
@@ -3,10 +3,11 @@ from camel.toolkits import NotionToolkit as BaseNotionToolkit
from camel.toolkits.function_tool import FunctionTool
from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit


@auto_listen_toolkit(BaseNotionToolkit)
class NotionToolkit(BaseNotionToolkit, AbstractToolkit):
    agent_name: str = Agents.document_agent

@@ -19,29 +20,6 @@ class NotionToolkit(BaseNotionToolkit, AbstractToolkit):
        super().__init__(notion_token, timeout)
        self.api_task_id = api_task_id

    @listen_toolkit(
        BaseNotionToolkit.list_all_pages,
        lambda _: "list all pages in Notion workspace",
        lambda result: f"{len(result)} pages found",
    )
    def list_all_pages(self) -> List[dict]:
        return super().list_all_pages()

    @listen_toolkit(
        BaseNotionToolkit.list_all_users,
        lambda _: "list all users in Notion workspace",
        lambda result: f"{len(result)} users found",
    )
    def list_all_users(self) -> List[dict]:
        return super().list_all_users()

    @listen_toolkit(
        BaseNotionToolkit.get_notion_block_text_content,
        lambda _, page_id: f"get text content of page with id: {page_id}",
    )
    def get_notion_block_text_content(self, block_id: str) -> str:
        return super().get_notion_block_text_content(block_id)

    @classmethod
    def get_can_use_tools(cls, api_task_id: str) -> List[FunctionTool]:
        if env("NOTION_TOKEN"):

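The recurring move in this commit is swapping per-method @listen_toolkit overrides for one @auto_listen_toolkit class decorator. The real decorator lives in app/utils/listen/toolkit_listen.py and is not shown in this diff; the following is only a sketch of the mechanism such a decorator can use, with logging standing in for whatever instrumentation the project applies:

import functools
import inspect
import logging

logger = logging.getLogger(__name__)


def auto_listen_toolkit(base_cls):
    """Class decorator: wrap every public method inherited from base_cls,
    so that calls and results are observed without hand-writing one
    @listen_toolkit override per method."""
    def decorate(cls):
        for name, member in inspect.getmembers(base_cls, predicate=inspect.isfunction):
            if name.startswith("_"):
                continue  # skip private and dunder methods

            def make_wrapper(fn, fn_name):
                @functools.wraps(fn)
                def wrapper(self, *args, **kwargs):
                    logger.info("%s.%s called", cls.__name__, fn_name)
                    result = fn(self, *args, **kwargs)
                    logger.info("%s.%s returned", cls.__name__, fn_name)
                    return result
                return wrapper

            setattr(cls, name, make_wrapper(member, name))
        return cls
    return decorate

The inner make_wrapper closure is needed so each wrapped method binds its own function rather than the loop variable.
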
@@ -3,11 +3,12 @@ from camel.toolkits import OpenAIImageToolkit as BaseOpenAIImageToolkit

from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit, listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
from typing import Literal, Optional, Union, List


@auto_listen_toolkit(BaseOpenAIImageToolkit)
class OpenAIImageToolkit(BaseOpenAIImageToolkit, AbstractToolkit):
    agent_name: str = Agents.multi_modal_agent

@@ -4,11 +4,12 @@ from camel.toolkits import PPTXToolkit as BasePPTXToolkit

from app.component.environment import env
from app.service.task import ActionWriteFileData, Agents, get_task_lock
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit, listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
from app.service.task import process_task


@auto_listen_toolkit(BasePPTXToolkit)
class PPTXToolkit(BasePPTXToolkit, AbstractToolkit):
    agent_name: str = Agents.document_agent

@@ -4,10 +4,11 @@ from camel.toolkits import PyAutoGUIToolkit as BasePyAutoGUIToolkit

from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit


@auto_listen_toolkit(BasePyAutoGUIToolkit)
class PyAutoGUIToolkit(BasePyAutoGUIToolkit, AbstractToolkit):
    agent_name: str = Agents.search_agent

@@ -21,69 +22,3 @@ class PyAutoGUIToolkit(BasePyAutoGUIToolkit, AbstractToolkit):
            screenshots_dir = env("file_save_path", os.path.expanduser("~/Downloads"))
        super().__init__(timeout, screenshots_dir)
        self.api_task_id = api_task_id

    @listen_toolkit(BasePyAutoGUIToolkit.mouse_move, lambda _, x, y: f"mouse move to {x}, {y}")
    def mouse_move(self, x: int, y: int) -> str:
        return super().mouse_move(x, y)

    @listen_toolkit(
        BasePyAutoGUIToolkit.mouse_click,
        lambda _, button="left", clicks=1, x=None, y=None: f"mouse click {button} {clicks} times at {x}, {y}",
    )
    def mouse_click(
        self,
        button: Literal["left", "middle", "right"] = "left",
        clicks: int = 1,
        x: int | None = None,
        y: int | None = None,
    ) -> str:
        return super().mouse_click(button, clicks, x, y)

    @listen_toolkit(
        BasePyAutoGUIToolkit.keyboard_type,
        lambda _, text, interval=0: f"keyboard type {text}, interval {interval}",
    )
    def keyboard_type(self, text: str, interval: float = 0) -> str:
        return super().keyboard_type(text, interval)

    @listen_toolkit(BasePyAutoGUIToolkit.take_screenshot)
    def take_screenshot(self) -> str:
        return super().take_screenshot()

    @listen_toolkit(BasePyAutoGUIToolkit.get_mouse_position)
    def get_mouse_position(self) -> str:
        return super().get_mouse_position()

    @listen_toolkit(BasePyAutoGUIToolkit.press_key, lambda _, key: f"press key {key}")
    def press_key(self, key: str | list[str]) -> str:
        return super().press_key(key)

    @listen_toolkit(BasePyAutoGUIToolkit.hotkey, lambda _, keys: f"hotkey {keys}")
    def hotkey(self, keys: List[str]) -> str:
        return super().hotkey(keys)

    @listen_toolkit(
        BasePyAutoGUIToolkit.mouse_drag,
        lambda _,
        start_x,
        start_y,
        end_x,
        end_y,
        button="left": f"mouse drag from {start_x}, {start_y} to {end_x}, {end_y} with {button} button",
    )
    def mouse_drag(
        self,
        start_x: int,
        start_y: int,
        end_x: int,
        end_y: int,
        button: Literal["left", "middle", "right"] = "left",
    ) -> str:
        return super().mouse_drag(start_x, start_y, end_x, end_y, button)

    @listen_toolkit(
        BasePyAutoGUIToolkit.scroll,
        lambda _, scroll_amount, x=None, y=None: f"scroll {scroll_amount} at {x}, {y}",
    )
    def scroll(self, scroll_amount: int, x: int | None = None, y: int | None = None) -> str:
        return super().scroll(scroll_amount, x, y)

@@ -3,10 +3,11 @@ from camel.toolkits import RedditToolkit as BaseRedditToolkit
from camel.toolkits.function_tool import FunctionTool
from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit


@auto_listen_toolkit(BaseRedditToolkit)
class RedditToolkit(BaseRedditToolkit, AbstractToolkit):
    agent_name: str = Agents.social_medium_agent

@@ -20,47 +21,6 @@ class RedditToolkit(BaseRedditToolkit, AbstractToolkit):
        super().__init__(retries, delay, timeout)
        self.api_task_id = api_task_id

    @listen_toolkit(
        BaseRedditToolkit.collect_top_posts,
        lambda _,
        subreddit_name,
        post_limit=5,
        comment_limit=5: f"collect top posts from subreddit: {subreddit_name} with post limit: {post_limit} and comment limit: {comment_limit}",
        lambda result: f"top posts collected: {result}",
    )
    def collect_top_posts(
        self, subreddit_name: str, post_limit: int = 5, comment_limit: int = 5
    ) -> List[Dict[str, Any]] | str:
        return super().collect_top_posts(subreddit_name, post_limit, comment_limit)

    @listen_toolkit(
        BaseRedditToolkit.perform_sentiment_analysis,
        lambda _, data: f"perform sentiment analysis on data number: {len(data)}",
        lambda result: f"perform analysis result: {result}",
    )
    def perform_sentiment_analysis(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        return super().perform_sentiment_analysis(data)

    @listen_toolkit(
        BaseRedditToolkit.track_keyword_discussions,
        lambda _,
        subreddits,
        keywords,
        post_limit=10,
        comment_limit=10,
        sentiment_analysis=False: f"track keyword discussions for subreddits: {subreddits}, keywords: {keywords}",
        lambda result: f"track keyword discussions result: {result}",
    )
    def track_keyword_discussions(
        self,
        subreddits: List[str],
        keywords: List[str],
        post_limit: int = 10,
        comment_limit: int = 10,
        sentiment_analysis: bool = False,
    ) -> List[Dict[str, Any]] | str:
        return super().track_keyword_discussions(subreddits, keywords, post_limit, comment_limit, sentiment_analysis)

    @classmethod
    def get_can_use_tools(cls, api_task_id: str) -> list[FunctionTool]:
        if env("REDDIT_CLIENT_ID") and env("REDDIT_CLIENT_SECRET") and env("REDDIT_USER_AGENT"):

@@ -3,10 +3,11 @@ from camel.toolkits import ScreenshotToolkit as BaseScreenshotToolkit

from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit


@auto_listen_toolkit(BaseScreenshotToolkit)
class ScreenshotToolkit(BaseScreenshotToolkit, AbstractToolkit):
    agent_name: str = Agents.developer_agent

@@ -15,13 +16,3 @@ class ScreenshotToolkit(BaseScreenshotToolkit, AbstractToolkit):
        if working_directory is None:
            working_directory = env("file_save_path", os.path.expanduser("~/Downloads"))
        super().__init__(working_directory, timeout)

    @listen_toolkit(BaseScreenshotToolkit.take_screenshot_and_read_image)
    def take_screenshot_and_read_image(
        self, filename: str, save_to_file: bool = True, read_image: bool = True, instruction: str | None = None
    ) -> str:
        return super().take_screenshot_and_read_image(filename, save_to_file, read_image, instruction)

    @listen_toolkit(BaseScreenshotToolkit.read_image)
    def read_image(self, image_path: str, instruction: str = "") -> str:
        return super().read_image(image_path, instruction)

@@ -2,13 +2,17 @@ from typing import Any, Dict, List, Literal
from camel.toolkits import SearchToolkit as BaseSearchToolkit
from camel.toolkits.function_tool import FunctionTool
import httpx
from loguru import logger
import os
from app.component.environment import env, env_not_empty
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit, listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("search_toolkit")


@auto_listen_toolkit(BaseSearchToolkit)
class SearchToolkit(BaseSearchToolkit, AbstractToolkit):
    agent_name: str = Agents.search_agent

@@ -25,6 +29,32 @@ class SearchToolkit(BaseSearchToolkit, AbstractToolkit):
        super().__init__(
            timeout=timeout, exclude_domains=exclude_domains
        )
        # Cache for user-specific search configurations
        self._user_google_api_key = None
        self._user_search_engine_id = None
        self._config_loaded = False

    def _load_user_search_config(self):
        """
        Load user-specific Google Search configuration from the user's .env file.
        This is called lazily when search_google is invoked.
        """
        if self._config_loaded:
            return

        self._config_loaded = True

        # Try to get user-specific configuration from the thread-local environment,
        # which is set by the middleware based on the user's project settings
        google_api_key = env("GOOGLE_API_KEY")
        search_engine_id = env("SEARCH_ENGINE_ID")

        if google_api_key and search_engine_id:
            self._user_google_api_key = google_api_key
            self._user_search_engine_id = search_engine_id
            logger.info("Loaded user-specific Google Search configuration")
        else:
            logger.debug("No user-specific Google Search configuration found, will use cloud search")

    # @listen_toolkit(BaseSearchToolkit.search_wiki)
    # def search_wiki(self, entity: str) -> str:
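The comment above refers to a thread-local environment populated by middleware. One way such a layered env() lookup can be built, purely as an illustration (the project's actual implementation in app.component.environment is not shown in this diff):

import contextvars
import os

# Per-request overrides set by middleware; lookups fall back to the process env.
_overrides: contextvars.ContextVar[dict] = contextvars.ContextVar("env_overrides", default={})


def set_user_env(values: dict) -> None:
    """Called by middleware with the user's project-level settings."""
    _overrides.set({**_overrides.get(), **values})


def env(key: str, default=None):
    """Look up a key in the request-scoped overrides first, then os.environ."""
    value = _overrides.get().get(key)
    return value if value is not None else os.environ.get(key, default)
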
@@ -50,19 +80,61 @@ class SearchToolkit(BaseSearchToolkit, AbstractToolkit):

    @listen_toolkit(
        BaseSearchToolkit.search_google,
        lambda _, query, search_type="web": f"with query '{query}' and {search_type} result pages",
        lambda _, query, search_type="web", number_of_result_pages=10, start_page=1: f"with query '{query}', {search_type} type, {number_of_result_pages} result pages starting from page {start_page}",
    )
    def search_google(self, query: str, search_type: str = "web") -> list[dict[str, Any]]:
        if env("GOOGLE_API_KEY") and env("SEARCH_ENGINE_ID"):
            return super().search_google(query, search_type)
        else:
            return self.cloud_search_google(query, search_type)
    def search_google(
        self,
        query: str,
        search_type: str = "web",
        number_of_result_pages: int = 10,
        start_page: int = 1
    ) -> list[dict[str, Any]]:
        # Load user-specific configuration
        self._load_user_search_config()

    def cloud_search_google(self, query: str, search_type):
        # If the user has configured their own Google API keys, use them
        if self._user_google_api_key and self._user_search_engine_id:
            logger.info("Using user-configured Google Search API")
            # Temporarily set environment variables for this search
            old_google_key = os.environ.get("GOOGLE_API_KEY")
            old_search_id = os.environ.get("SEARCH_ENGINE_ID")

            try:
                os.environ["GOOGLE_API_KEY"] = self._user_google_api_key
                os.environ["SEARCH_ENGINE_ID"] = self._user_search_engine_id
                return super().search_google(query, search_type, number_of_result_pages, start_page)
            finally:
                # Restore original environment variables
                if old_google_key is not None:
                    os.environ["GOOGLE_API_KEY"] = old_google_key
                elif "GOOGLE_API_KEY" in os.environ:
                    del os.environ["GOOGLE_API_KEY"]

                if old_search_id is not None:
                    os.environ["SEARCH_ENGINE_ID"] = old_search_id
                elif "SEARCH_ENGINE_ID" in os.environ:
                    del os.environ["SEARCH_ENGINE_ID"]
        else:
            # Fall back to cloud search
            logger.info("Using cloud Google Search (no user configuration found)")
            return self.cloud_search_google(query, search_type, number_of_result_pages, start_page)

    def cloud_search_google(
        self,
        query: str,
        search_type: str = "web",
        number_of_result_pages: int = 10,
        start_page: int = 1
    ):
        url = env_not_empty("SERVER_URL")
        res = httpx.get(
            url + "/proxy/google",
            params={"query": query, "search_type": search_type},
            params={
                "query": query,
                "search_type": search_type,
                "number_of_result_pages": number_of_result_pages,
                "start_page": start_page
            },
            headers={"api-key": env_not_empty("cloud_api_key")},
        )
        return res.json()
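The try/finally save-and-restore of GOOGLE_API_KEY and SEARCH_ENGINE_ID above is a reusable pattern. The same logic expressed as a context manager, as a hedged sketch rather than something this commit contains:

import os
from contextlib import contextmanager


@contextmanager
def temporary_env(**overrides):
    """Temporarily set environment variables, restoring prior values on exit."""
    saved = {key: os.environ.get(key) for key in overrides}
    try:
        os.environ.update({k: v for k, v in overrides.items() if v is not None})
        yield
    finally:
        for key, old in saved.items():
            if old is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = old

# Usage mirroring search_google above:
# with temporary_env(GOOGLE_API_KEY=key, SEARCH_ENGINE_ID=engine_id):
#     return super().search_google(query, search_type, pages, start)
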
@@ -163,73 +235,73 @@ class SearchToolkit(BaseSearchToolkit, AbstractToolkit):
    # def search_bing(self, query: str) -> dict[str, Any]:
    #     return super().search_bing(query)

    @listen_toolkit(BaseSearchToolkit.search_exa, lambda _, query, *args, **kwargs: f"{query}, {args}, {kwargs}")
    def search_exa(
        self,
        query: str,
        search_type: Literal["auto", "neural", "keyword"] = "auto",
        category: None
        | Literal[
            "company",
            "research paper",
            "news",
            "pdf",
            "github",
            "tweet",
            "personal site",
            "linkedin profile",
            "financial report",
        ] = None,
        include_text: List[str] | None = None,
        exclude_text: List[str] | None = None,
        use_autoprompt: bool = True,
        text: bool = False,
    ) -> Dict[str, Any]:
        if env("EXA_API_KEY"):
            res = super().search_exa(query, search_type, category, include_text, exclude_text, use_autoprompt, text)
            return res
        else:
            return self.cloud_search_exa(query, search_type, category, include_text, exclude_text, use_autoprompt, text)

    def cloud_search_exa(
        self,
        query: str,
        search_type: Literal["auto", "neural", "keyword"] = "auto",
        category: None
        | Literal[
            "company",
            "research paper",
            "news",
            "pdf",
            "github",
            "tweet",
            "personal site",
            "linkedin profile",
            "financial report",
        ] = None,
        include_text: List[str] | None = None,
        exclude_text: List[str] | None = None,
        use_autoprompt: bool = True,
        text: bool = False,
    ):
        url = env_not_empty("SERVER_URL")
        logger.debug(f">>>>>>>>>>>>>>>>{url}<<<<")
        res = httpx.post(
            url + "/proxy/exa",
            json={
                "query": query,
                "search_type": search_type,
                "category": category,
                "include_text": include_text,
                "exclude_text": exclude_text,
                "use_autoprompt": use_autoprompt,
                "text": text,
            },
            headers={"api-key": env_not_empty("cloud_api_key")},
        )
        logger.debug(">>>>>>>>>>>>>>>>>")
        logger.debug(res)
        return res.json()
    # @listen_toolkit(BaseSearchToolkit.search_exa, lambda _, query, *args, **kwargs: f"{query}, {args}, {kwargs}")
    # def search_exa(
    #     self,
    #     query: str,
    #     search_type: Literal["auto", "neural", "keyword"] = "auto",
    #     category: None
    #     | Literal[
    #         "company",
    #         "research paper",
    #         "news",
    #         "pdf",
    #         "github",
    #         "tweet",
    #         "personal site",
    #         "linkedin profile",
    #         "financial report",
    #     ] = None,
    #     include_text: List[str] | None = None,
    #     exclude_text: List[str] | None = None,
    #     use_autoprompt: bool = True,
    #     text: bool = False,
    # ) -> Dict[str, Any]:
    #     if env("EXA_API_KEY"):
    #         res = super().search_exa(query, search_type, category, include_text, exclude_text, use_autoprompt, text)
    #         return res
    #     else:
    #         return self.cloud_search_exa(query, search_type, category, include_text, exclude_text, use_autoprompt, text)
    #
    # def cloud_search_exa(
    #     self,
    #     query: str,
    #     search_type: Literal["auto", "neural", "keyword"] = "auto",
    #     category: None
    #     | Literal[
    #         "company",
    #         "research paper",
    #         "news",
    #         "pdf",
    #         "github",
    #         "tweet",
    #         "personal site",
    #         "linkedin profile",
    #         "financial report",
    #     ] = None,
    #     include_text: List[str] | None = None,
    #     exclude_text: List[str] | None = None,
    #     use_autoprompt: bool = True,
    #     text: bool = False,
    # ):
    #     url = env_not_empty("SERVER_URL")
    #     logger.debug(f">>>>>>>>>>>>>>>>{url}<<<<")
    #     res = httpx.post(
    #         url + "/proxy/exa",
    #         json={
    #             "query": query,
    #             "search_type": search_type,
    #             "category": category,
    #             "include_text": include_text,
    #             "exclude_text": exclude_text,
    #             "use_autoprompt": use_autoprompt,
    #             "text": text,
    #         },
    #         headers={"api-key": env_not_empty("cloud_api_key")},
    #     )
    #     logger.debug(">>>>>>>>>>>>>>>>>")
    #     logger.debug(res)
    #     return res.json()

    # @listen_toolkit(
    #     BaseSearchToolkit.search_alibaba_tongxiao,

@@ -289,12 +361,12 @@ class SearchToolkit(BaseSearchToolkit, AbstractToolkit):
    # if env("BOCHA_API_KEY"):
    #     tools.append(FunctionTool(search_toolkit.search_bocha))

    if env("EXA_API_KEY") or env("cloud_api_key"):
        tools.append(FunctionTool(search_toolkit.search_exa))
    # if env("EXA_API_KEY") or env("cloud_api_key"):
    #     tools.append(FunctionTool(search_toolkit.search_exa))

    # if env("TONGXIAO_API_KEY"):
    #     tools.append(FunctionTool(search_toolkit.search_alibaba_tongxiao))
    return tools

    def get_tools(self) -> List[FunctionTool]:
        return [FunctionTool(self.search_exa)]
    # def get_tools(self) -> List[FunctionTool]:
    #     return [FunctionTool(self.search_exa)]

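The registration hunk above comments search_exa out in both places, so tool availability is driven entirely by which API keys are present in the environment. The gating pattern these toolkits share, reduced to a sketch (names follow the code above; the helper itself is illustrative):

from camel.toolkits.function_tool import FunctionTool


def build_search_tools(toolkit, env) -> list[FunctionTool]:
    """Register only the search backends whose credentials are configured."""
    tools: list[FunctionTool] = []
    if (env("GOOGLE_API_KEY") and env("SEARCH_ENGINE_ID")) or env("cloud_api_key"):
        tools.append(FunctionTool(toolkit.search_google))
    if env("EXA_API_KEY"):  # disabled in this commit; shown only for the pattern
        tools.append(FunctionTool(toolkit.search_exa))
    return tools
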
@@ -1,12 +1,15 @@
from camel.toolkits import SlackToolkit as BaseSlackToolkit
from camel.toolkits.function_tool import FunctionTool
from loguru import logger
from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("slack_toolkit")


@auto_listen_toolkit(BaseSlackToolkit)
class SlackToolkit(BaseSlackToolkit, AbstractToolkit):
    agent_name: str = Agents.social_medium_agent

@@ -14,71 +17,6 @@ class SlackToolkit(BaseSlackToolkit, AbstractToolkit):
        super().__init__(timeout)
        self.api_task_id = api_task_id

    @listen_toolkit(
        BaseSlackToolkit.create_slack_channel,
        lambda _, name, is_private=True: f"create a Slack channel with name: {name} and is_private: {is_private}",
    )
    def create_slack_channel(self, name: str, is_private: bool | None = True) -> str:
        return super().create_slack_channel(name, is_private)

    @listen_toolkit(
        BaseSlackToolkit.join_slack_channel,
        lambda _, channel_id: f"join Slack channel with id: {channel_id}",
    )
    def join_slack_channel(self, channel_id: str) -> str:
        return super().join_slack_channel(channel_id)

    @listen_toolkit(
        BaseSlackToolkit.leave_slack_channel,
        lambda _, channel_id: f"leave Slack channel with id: {channel_id}",
    )
    def leave_slack_channel(self, channel_id: str) -> str:
        return super().leave_slack_channel(channel_id)

    @listen_toolkit(
        BaseSlackToolkit.get_slack_channel_information,
        lambda _: "get Slack channel information",
    )
    def get_slack_channel_information(self) -> str:
        return super().get_slack_channel_information()

    @listen_toolkit(
        BaseSlackToolkit.get_slack_channel_message,
        lambda _, channel_id: f"get Slack channel message for channel id: {channel_id}",
    )
    def get_slack_channel_message(self, channel_id: str) -> str:
        return super().get_slack_channel_message(channel_id)

    @listen_toolkit(
        BaseSlackToolkit.send_slack_message,
        lambda _, message, channel_id, file_path=None, user=None: f"send Slack message: {message} to channel id: {channel_id}, file: {file_path}, user: {user}",
    )
    def send_slack_message(self, message: str, channel_id: str, file_path: str | None = None, user: str | None = None) -> str:
        return super().send_slack_message(message, channel_id, file_path, user)

    @listen_toolkit(
        BaseSlackToolkit.delete_slack_message,
        lambda _,
        time_stamp,
        channel_id: f"delete Slack message with timestamp: {time_stamp} in channel id: {channel_id}",
    )
    def delete_slack_message(self, time_stamp: str, channel_id: str) -> str:
        return super().delete_slack_message(time_stamp, channel_id)

    @listen_toolkit(
        BaseSlackToolkit.get_slack_user_list,
        lambda _: "get Slack user list",
    )
    def get_slack_user_list(self) -> str:
        return super().get_slack_user_list()

    @listen_toolkit(
        BaseSlackToolkit.get_slack_user_info,
        lambda _, user_id: f"get Slack user info with user id: {user_id}",
    )
    def get_slack_user_info(self, user_id: str) -> str:
        return super().get_slack_user_info(user_id)

    @classmethod
    def get_can_use_tools(cls, api_task_id: str) -> list[FunctionTool]:
        logger.debug(f"slack===={env('SLACK_BOT_TOKEN')}")

@@ -1,16 +1,23 @@
import asyncio
import logging
import os
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from camel.toolkits.terminal_toolkit import TerminalToolkit as BaseTerminalToolkit
from camel.toolkits.terminal_toolkit.terminal_toolkit import _to_plain
from app.component.environment import env
from app.service.task import Action, ActionTerminalData, Agents, get_task_lock
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit
from app.service.task import process_task


@auto_listen_toolkit(BaseTerminalToolkit)
class TerminalToolkit(BaseTerminalToolkit, AbstractToolkit):
    agent_name: str = Agents.developer_agent
    _thread_pool: Optional[ThreadPoolExecutor] = None
    _thread_local = threading.local()

    def __init__(
        self,

@@ -30,6 +37,11 @@ class TerminalToolkit(BaseTerminalToolkit, AbstractToolkit):
        self.agent_name = agent_name
        if working_directory is None:
            working_directory = env("file_save_path", os.path.expanduser("~/.eigent/terminal/"))
        if TerminalToolkit._thread_pool is None:
            TerminalToolkit._thread_pool = ThreadPoolExecutor(
                max_workers=1,
                thread_name_prefix="terminal_toolkit"
            )
        super().__init__(
            timeout=timeout,
            working_directory=working_directory,

@@ -54,58 +66,57 @@ class TerminalToolkit(BaseTerminalToolkit, AbstractToolkit):

    def _update_terminal_output(self, output: str):
        task_lock = get_task_lock(self.api_task_id)
        # This method is called during init; at that time the process_task_id does not exist yet, so it defaults to empty
        process_task_id = process_task.get("")
        task = asyncio.create_task(
            task_lock.put_queue(
                ActionTerminalData(
                    action=Action.terminal,
                    process_task_id=process_task_id,
                    data=output,
                )

        # Create the coroutine
        coro = task_lock.put_queue(
            ActionTerminalData(
                action=Action.terminal,
                process_task_id=process_task_id,
                data=output,
            )
        )
        if hasattr(task_lock, "add_background_task"):
            task_lock.add_background_task(task)

    @listen_toolkit(
        BaseTerminalToolkit.shell_exec,
        lambda _, id, command, block=True: f"id: {id}, command: {command}, block: {block}",
    )
    def shell_exec(self, id: str, command: str, block: bool = True) -> str:
        return super().shell_exec(id=id, command=command, block=block)
        # Try to get the current event loop; if none exists, run the coroutine in a worker thread
        try:
            loop = asyncio.get_running_loop()
            # If we're in an async context, schedule the coroutine
            task = loop.create_task(coro)
            if hasattr(task_lock, "add_background_task"):
                task_lock.add_background_task(task)
        except RuntimeError:
            self._thread_pool.submit(self._run_coro_in_thread, coro, task_lock)

    @listen_toolkit(
        BaseTerminalToolkit.shell_view,
        lambda _, id: f"id: {id}",
    )
    def shell_view(self, id: str) -> str:
        return super().shell_view(id)
    @staticmethod
    def _run_coro_in_thread(coro, task_lock):
        """
        Execute coro in the thread pool, with each thread bound to a long-lived event loop
        """
        if not hasattr(TerminalToolkit._thread_local, "loop"):
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            TerminalToolkit._thread_local.loop = loop
        else:
            loop = TerminalToolkit._thread_local.loop

    @listen_toolkit(
        BaseTerminalToolkit.shell_wait,
        lambda _, id, wait_seconds=None: f"id: {id}, wait_seconds: {wait_seconds}",
    )
    def shell_wait(self, id: str, wait_seconds: float = 5.0) -> str:
        return super().shell_wait(id=id, wait_seconds=wait_seconds)
        if loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            TerminalToolkit._thread_local.loop = loop

    @listen_toolkit(
        BaseTerminalToolkit.shell_write_to_process,
        lambda _, id, command: f"id: {id}, command: {command}",
    )
    def shell_write_to_process(self, id: str, command: str) -> str:
        return super().shell_write_to_process(id=id, command=command)
        try:
            task = loop.create_task(coro)
            if hasattr(task_lock, "add_background_task"):
                task_lock.add_background_task(task)
            loop.run_until_complete(task)
        except Exception as e:
            logging.error(
                f"Failed to execute coroutine in thread pool: {str(e)}",
                exc_info=True
            )

    @listen_toolkit(
        BaseTerminalToolkit.shell_kill_process,
        lambda _, id: f"id: {id}",
    )
    def shell_kill_process(self, id: str) -> str:
        return super().shell_kill_process(id=id)

    @listen_toolkit(
        BaseTerminalToolkit.shell_ask_user_for_help,
        lambda _, id, prompt: f"id: {id}, prompt: {prompt}",
    )
    def shell_ask_user_for_help(self, id: str, prompt: str) -> str:
        return super().shell_ask_user_for_help(id=id, prompt=prompt)
    @classmethod
    def shutdown(cls):
        if cls._thread_pool:
            cls._thread_pool.shutdown(wait=True)
            cls._thread_pool = None

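The _update_terminal_output changes solve a classic problem: emitting to an asyncio queue from code that may run either inside or outside a running event loop. The dispatch rule in isolation, as a hedged sketch (the commit keeps a long-lived loop per worker thread instead of asyncio.run, and additionally tracks background tasks on the TaskLock):

import asyncio
from concurrent.futures import ThreadPoolExecutor

_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="bridge")


def fire_and_forget(coro):
    """Schedule `coro` from sync or async context without blocking the caller."""
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread: hand off to a worker that runs one.
        _pool.submit(asyncio.run, coro)
        return None
    return loop.create_task(coro)  # already inside a loop: just schedule it
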
@@ -1,40 +1,13 @@
from camel.toolkits import ThinkingToolkit as BaseThinkingToolkit

from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit


@auto_listen_toolkit(BaseThinkingToolkit)
class ThinkingToolkit(BaseThinkingToolkit, AbstractToolkit):

    def __init__(self, api_task_id: str, agent_name: str, timeout: float | None = None):
        super().__init__(timeout)
        self.api_task_id = api_task_id
        self.agent_name = agent_name

    @listen_toolkit(BaseThinkingToolkit.plan)
    def plan(self, plan: str) -> str:
        return super().plan(plan)

    @listen_toolkit(BaseThinkingToolkit.hypothesize)
    def hypothesize(self, hypothesis: str) -> str:
        return super().hypothesize(hypothesis)

    @listen_toolkit(BaseThinkingToolkit.think)
    def think(self, thought: str) -> str:
        return super().think(thought)

    @listen_toolkit(BaseThinkingToolkit.contemplate)
    def contemplate(self, contemplation: str) -> str:
        return super().contemplate(contemplation)

    @listen_toolkit(BaseThinkingToolkit.critique)
    def critique(self, critique: str) -> str:
        return super().critique(critique)

    @listen_toolkit(BaseThinkingToolkit.synthesize)
    def synthesize(self, synthesis: str) -> str:
        return super().synthesize(synthesis)

    @listen_toolkit(BaseThinkingToolkit.reflect)
    def reflect(self, reflection: str) -> str:
        return super().reflect(reflection)

@@ -9,10 +9,11 @@ from camel.toolkits.twitter_toolkit import (

from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit, listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit


@auto_listen_toolkit(BaseTwitterToolkit)
class TwitterToolkit(BaseTwitterToolkit, AbstractToolkit):
    agent_name: str = Agents.social_medium_agent

@@ -4,10 +4,11 @@ from camel.toolkits import VideoAnalysisToolkit as BaseVideoAnalysisToolkit

from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit


@auto_listen_toolkit(BaseVideoAnalysisToolkit)
class VideoAnalysisToolkit(BaseVideoAnalysisToolkit, AbstractToolkit):
    agent_name: str = Agents.multi_modal_agent

@@ -36,10 +37,3 @@ class VideoAnalysisToolkit(BaseVideoAnalysisToolkit, AbstractToolkit):
            cookies_path,
            timeout,
        )

    @listen_toolkit(
        BaseVideoAnalysisToolkit.ask_question_about_video,
        lambda _, video_path, question: f"transcribe video from {video_path} and ask question: {question}",
    )
    def ask_question_about_video(self, video_path: str, question: str) -> str:
        return super().ask_question_about_video(video_path, question)

@@ -5,10 +5,11 @@ from camel.toolkits import VideoDownloaderToolkit as BaseVideoDownloaderToolkit

from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit


@auto_listen_toolkit(BaseVideoDownloaderToolkit)
class VideoDownloaderToolkit(BaseVideoDownloaderToolkit, AbstractToolkit):
    agent_name: str = Agents.multi_modal_agent

@@ -23,23 +24,3 @@ class VideoDownloaderToolkit(BaseVideoDownloaderToolkit, AbstractToolkit):
            working_directory = env("file_save_path", os.path.expanduser("~/Downloads"))
        super().__init__(working_directory, cookies_path, timeout)
        self.api_task_id = api_task_id

    @listen_toolkit(BaseVideoDownloaderToolkit.download_video)
    def download_video(self, url: str) -> str:
        return super().download_video(url)

    @listen_toolkit(
        BaseVideoDownloaderToolkit.get_video_bytes,
        lambda _, video_path: f"get video bytes from {video_path}",
        lambda _: "get video bytes",
    )
    def get_video_bytes(self, video_path: str) -> bytes:
        return super().get_video_bytes(video_path)

    @listen_toolkit(
        BaseVideoDownloaderToolkit.get_video_screenshots,
        lambda _, video_path, amount: f"get video screenshots from {video_path}, amount: {amount}",
        lambda results: f"get video screenshots {len(results)}",
    )
    def get_video_screenshots(self, video_path: str, amount: int) -> List[Image]:
        return super().get_video_screenshots(video_path, amount)

@@ -3,10 +3,11 @@ from typing import Any, Dict
from camel.toolkits import WebDeployToolkit as BaseWebDeployToolkit

from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit, listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit


@auto_listen_toolkit(BaseWebDeployToolkit)
class WebDeployToolkit(BaseWebDeployToolkit, AbstractToolkit):
    agent_name: str = Agents.developer_agent

@@ -43,11 +44,3 @@ class WebDeployToolkit(BaseWebDeployToolkit, AbstractToolkit):
    ) -> Dict[str, Any]:
        subdirectory = str(uuid.uuid4())
        return super().deploy_folder(folder_path, port, domain, subdirectory)

    @listen_toolkit(BaseWebDeployToolkit.stop_server)
    def stop_server(self, port: int) -> Dict[str, Any]:
        return super().stop_server(port)

    @listen_toolkit(BaseWebDeployToolkit.list_running_servers)
    def list_running_servers(self) -> Dict[str, Any]:
        return super().list_running_servers()

@@ -3,10 +3,11 @@ from camel.toolkits import WhatsAppToolkit as BaseWhatsAppToolkit
from camel.toolkits.function_tool import FunctionTool
from app.component.environment import env
from app.service.task import Agents
from app.utils.listen.toolkit_listen import listen_toolkit
from app.utils.listen.toolkit_listen import auto_listen_toolkit
from app.utils.toolkit.abstract_toolkit import AbstractToolkit


@auto_listen_toolkit(BaseWhatsAppToolkit)
class WhatsAppToolkit(BaseWhatsAppToolkit, AbstractToolkit):
    agent_name: str = Agents.social_medium_agent

@@ -14,30 +15,6 @@ class WhatsAppToolkit(BaseWhatsAppToolkit, AbstractToolkit):
        super().__init__(timeout)
        self.api_task_id = api_task_id

    @listen_toolkit(
        BaseWhatsAppToolkit.send_message,
        lambda _, to, message: f"send message to {to}: {message}",
        lambda result: f"message sent result: {result}",
    )
    def send_message(self, to: str, message: str) -> Dict[str, Any] | str:
        return super().send_message(to, message)

    @listen_toolkit(
        BaseWhatsAppToolkit.get_message_templates,
        lambda _: "get message templates",
        lambda result: f"message templates: {result}",
    )
    def get_message_templates(self) -> List[Dict[str, Any]] | str:
        return super().get_message_templates()

    @listen_toolkit(
        BaseWhatsAppToolkit.get_business_profile,
        lambda _: "get business profile",
        lambda result: f"business profile: {result}",
    )
    def get_business_profile(self) -> Dict[str, Any] | str:
        return super().get_business_profile()

    @classmethod
    def get_can_use_tools(cls, api_task_id: str) -> list[FunctionTool]:
        if env("WHATSAPP_ACCESS_TOKEN") and env("WHATSAPP_PHONE_NUMBER_ID"):

@@ -1,35 +0,0 @@
"""Conditional traceroot wrapper - only loads if .traceroot-config.yaml exists."""
from pathlib import Path
from typing import Callable


def _find_config() -> bool:
    """Check if .traceroot-config.yaml exists in current or parent directories."""
    path = Path.cwd()
    for _ in range(5):
        if (path / ".traceroot-config.yaml").exists():
            return True
        if path == path.parent:
            break
        path = path.parent
    return False


# Load traceroot only if config exists
if _find_config():
    import traceroot
    trace = traceroot.trace
    get_logger = traceroot.get_logger
else:
    # No-op implementations
    def trace():
        def decorator(func: Callable) -> Callable:
            return func
        return decorator

    class _NoOpLogger:
        def __getattr__(self, name):
            return lambda *args, **kwargs: None

    def get_logger(name: str):
        return _NoOpLogger()

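This deleted per-backend copy of the wrapper is superseded by the shared top-level utils module imported throughout the commit (from utils import traceroot_wrapper as traceroot). Call sites stay identical whether or not .traceroot-config.yaml exists; for example:

from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("example_module")


@traceroot.trace()
def do_work() -> int:
    logger.info("doing work")  # silently a no-op when tracing is not configured
    return 42
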
@@ -9,7 +9,6 @@ from camel.societies.workforce.workforce import (
from camel.societies.workforce.task_channel import TaskChannel
from camel.societies.workforce.base import BaseNode
from camel.societies.workforce.utils import TaskAssignResult
from loguru import logger
from camel.tasks.task import Task, TaskState, validate_task_content
from app.component import code
from app.exception.exception import UserException

@@ -18,30 +17,16 @@ from app.service.task import (
    Action,
    ActionAssignTaskData,
    ActionEndData,
    ActionNewTaskStateData,
    ActionTaskStateData,
    get_camel_task,
    get_task_lock,
)
from app.utils.single_agent_worker import SingleAgentWorker
from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("workforce")

# === Debug sink === Write detailed dependency debug logs to file (logs/workforce_debug.log)
# Create a new file every day, keep the logs for the last 7 days, and write asynchronously without blocking the main process
logger.add(
    "logs/workforce_debug_{time:YYYY-MM-DD}.log",
    rotation="00:00",
    retention="7 days",
    enqueue=True,
    level="DEBUG",
)
# Independent sink: only collect the "[WF]" debug lines we insert, to quickly view the dependency chain
logger.add(
    "logs/wf_trace_{time:YYYY-MM-DD-HH}.log",
    rotation="00:00",
    retention="7 days",
    enqueue=True,
    level="DEBUG",
    filter=lambda record: record["message"].startswith("[WF]"),
)


class Workforce(BaseWorkforce):

@@ -69,8 +54,15 @@ class Workforce(BaseWorkforce):
            use_structured_output_handler=use_structured_output_handler,
        )

    def eigent_make_sub_tasks(self, task: Task):
        """Split the process_task method into eigent_make_sub_tasks and eigent_start methods"""
    def eigent_make_sub_tasks(self, task: Task, coordinator_context: str = ""):
        """
        Split the process_task method into eigent_make_sub_tasks and eigent_start methods.

        Args:
            task: The main task to decompose
            coordinator_context: Optional context ONLY for the coordinator agent during decomposition.
                This context will NOT be passed to subtasks or worker agents.
        """

        if not validate_task_content(task.content, task.id):
            task.state = TaskState.FAILED

@@ -85,10 +77,20 @@ class Workforce(BaseWorkforce):
        self.set_channel(TaskChannel())
        self._state = WorkforceState.RUNNING
        task.state = TaskState.OPEN
        self._pending_tasks.append(task)

        # Decompose the task into subtasks first
        subtasks_result = self._decompose_task(task)
        if coordinator_context:
            original_content = task.content
            task_with_context = coordinator_context
            if coordinator_context:
                task_with_context += "\n=== CURRENT TASK ===\n"
                task_with_context += original_content
            task.content = task_with_context

            subtasks_result = self._decompose_task(task)

            task.content = original_content
        else:
            subtasks_result = self._decompose_task(task)

        # Handle both streaming and non-streaming results
        if isinstance(subtasks_result, Generator):
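The wrap, decompose, restore dance around task.content appears twice in this file (here and again in handle_decompose_append_task below). A hedged sketch of how it could be factored into a context manager; this is an illustrative refactoring, not what the commit does:

from contextlib import contextmanager


@contextmanager
def temporary_task_context(task, coordinator_context: str):
    """Prefix task.content with coordinator-only context, restoring it on exit."""
    if not coordinator_context:
        yield task
        return
    original = task.content
    task.content = f"{coordinator_context}\n=== CURRENT TASK ===\n{original}"
    try:
        yield task
    finally:
        task.content = original  # workers never see the coordinator context
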
@@ -119,6 +121,64 @@ class Workforce(BaseWorkforce):
        if self._state != WorkforceState.STOPPED:
            self._state = WorkforceState.IDLE

    async def handle_decompose_append_task(
        self, task: Task, reset: bool = True, coordinator_context: str = ""
    ) -> List[Task]:
        """
        Override to support the coordinator_context parameter.
        Handle task decomposition and validation, then append to pending tasks.

        Args:
            task: The task to be processed
            reset: Should trigger workforce reset (Workforce must not be running)
            coordinator_context: Optional context ONLY for the coordinator during decomposition

        Returns:
            List[Task]: The decomposed subtasks or the original task
        """
        if not validate_task_content(task.content, task.id):
            task.state = TaskState.FAILED
            task.result = "Task failed: Invalid or empty content provided"
            logger.warning(
                f"Task {task.id} rejected: Invalid or empty content. "
                f"Content preview: '{task.content}'"
            )
            return [task]

        if reset and self._state != WorkforceState.RUNNING:
            self.reset()
            logger.info("Workforce reset before handling task.")

        self._task = task
        task.state = TaskState.FAILED

        if coordinator_context:
            original_content = task.content
            task_with_context = coordinator_context
            if coordinator_context:
                task_with_context += "\n=== CURRENT TASK ===\n"
                task_with_context += original_content
            task.content = task_with_context

            subtasks_result = self._decompose_task(task)

            task.content = original_content
        else:
            subtasks_result = self._decompose_task(task)

        if isinstance(subtasks_result, Generator):
            subtasks = []
            for new_tasks in subtasks_result:
                subtasks.extend(new_tasks)
        else:
            subtasks = subtasks_result

        if subtasks:
            self._pending_tasks.extendleft(reversed(subtasks))
            logger.info(f"Appended {len(subtasks)} subtasks to pending tasks")

        return subtasks if subtasks else [task]

    async def _find_assignee(self, tasks: List[Task]) -> TaskAssignResult:
        # Task assignment phase: send a "waiting for execution" notification to the frontend, and send a "start execution" notification when the task actually begins execution
        assigned = await super()._find_assignee(tasks)

@@ -133,7 +193,9 @@ class Workforce(BaseWorkforce):
            # Find task content
            task_obj = get_camel_task(item.task_id, tasks)
            if task_obj is None:
                logger.warning(f"[WF] WARN: Task {item.task_id} not found in tasks list during ASSIGN phase. This may indicate a task tree inconsistency.")
                logger.warning(
                    f"[WF] WARN: Task {item.task_id} not found in tasks list during ASSIGN phase. This may indicate a task tree inconsistency."
                )
                content = ""
            else:
                content = task_obj.content

@@ -179,7 +241,11 @@ class Workforce(BaseWorkforce):
        await super()._post_task(task, assignee_id)

    def add_single_agent_worker(
        self, description: str, worker: ListenChatAgent, pool_max_size: int = DEFAULT_WORKER_POOL_SIZE
        self,
        description: str,
        worker: ListenChatAgent,
        pool_max_size: int = DEFAULT_WORKER_POOL_SIZE,
        enable_workflow_memory: bool = False,
    ) -> BaseWorkforce:
        if self._state == WorkforceState.RUNNING:
            raise RuntimeError("Cannot add workers while workforce is running. Pause the workforce first.")

@@ -195,6 +261,8 @@ class Workforce(BaseWorkforce):
            worker=worker,
            pool_max_size=pool_max_size,
            use_structured_output_handler=self.use_structured_output_handler,
            context_utility=None,  # Will be set during save/load operations
            enable_workflow_memory=enable_workflow_memory,
        )
        self._children.append(worker_node)

@@ -218,17 +286,33 @@ class Workforce(BaseWorkforce):
        logger.debug(f"[WF] DONE {task.id}")
        task_lock = get_task_lock(self.api_task_id)

        await task_lock.put_queue(
            ActionTaskStateData(
                data={
                    "task_id": task.id,
                    "content": task.content,
                    "state": task.state,
                    "result": task.result or "",
                    "failure_count": task.failure_count,
                },
        # Log task completion with result details
        is_main_task = self._task and task.id == self._task.id
        task_type = "MAIN TASK" if is_main_task else "SUB-TASK"
        logger.info(f"[TASK-RESULT] {task_type} COMPLETED: {task.id}")
        logger.info(f"[TASK-RESULT] Content: {task.content[:200]}..." if len(task.content) > 200 else f"[TASK-RESULT] Content: {task.content}")
        logger.info(f"[TASK-RESULT] Result: {task.result[:500]}..." if task.result and len(str(task.result)) > 500 else f"[TASK-RESULT] Result: {task.result}")

        task_data = {
            "task_id": task.id,
            "content": task.content,
            "state": task.state,
            "result": task.result or "",
            "failure_count": task.failure_count,
        }

        if self._task_is_new(task_data):
            await task_lock.put_queue(
                ActionNewTaskStateData(
                    data=task_data
                )
            )
        else:
            await task_lock.put_queue(
                ActionTaskStateData(
                    data=task_data
                )
            )
        )

        return await super()._handle_completed_task(task)

@@ -260,6 +344,36 @@ class Workforce(BaseWorkforce):

        return result

    def _task_is_new(self, item: dict) -> bool:
        # Validate the task state data object first
        assert isinstance(item, dict)
        task_id = item.get("task_id", "")
        state = item.get("state", "")
        result = item.get("result", "")
        failure_count = item.get("failure_count", 0)

        # Validate required fields
        if not task_id:
            logger.error("Missing task_id in task_state data")
            return False
        elif not state:
            logger.error(f"Missing state in task_state data for task {task_id}")
            return False

        # Ensure failure_count is an integer
        try:
            failure_count = int(failure_count)
        except (ValueError, TypeError):
            logger.error(f"Invalid failure_count in task_state data for task {task_id}: {failure_count}")
            failure_count = 0  # Default to 0 if invalid

        should_send_new_task_state = (
            state == "FAILED" or
            (failure_count == 0 and result.strip() == "")
        )

        return should_send_new_task_state

    def stop(self) -> None:
        super().stop()
        task_lock = get_task_lock(self.api_task_id)

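The new-versus-update split in _task_is_new is easy to pin down with a table-driven test. A sketch only; the import path is hypothetical and the method is called unbound since it does not touch self:

import pytest

from app.service.workforce import Workforce  # hypothetical import path


@pytest.mark.parametrize(
    "item, expected",
    [
        ({"task_id": "t1", "state": "FAILED", "result": "", "failure_count": 0}, True),
        ({"task_id": "t1", "state": "DONE", "result": "", "failure_count": 0}, True),
        ({"task_id": "t1", "state": "DONE", "result": "ok", "failure_count": 0}, False),
        ({"task_id": "t1", "state": "DONE", "result": "", "failure_count": 2}, False),
        ({"task_id": "", "state": "DONE", "result": "", "failure_count": 0}, False),
        # Invalid failure_count falls back to 0, so this still counts as new.
        ({"task_id": "t1", "state": "DONE", "result": "", "failure_count": "bad"}, True),
    ],
)
def test_task_is_new(item, expected):
    assert Workforce._task_is_new(None, item) is expected
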
@@ -1,37 +1,58 @@
import os
import sys
import pathlib
import signal
import asyncio
import atexit

# Add project root to Python path to import shared utils
_project_root = pathlib.Path(__file__).parent.parent
if str(_project_root) not in sys.path:
    sys.path.insert(0, str(_project_root))

# 1) Load env and init traceroot BEFORE importing modules that get a logger
from utils import traceroot_wrapper as traceroot
from app import api
from loguru import logger
from app.component.environment import auto_include_routers, env

# Only initialize traceroot if enabled
if traceroot.is_enabled():
    from traceroot.integrations.fastapi import connect_fastapi
    connect_fastapi(api)

# 2) Now safe to import modules that use traceroot.get_logger() at import-time
from app.component.environment import env
from app.router import register_routers


os.environ["PYTHONIOENCODING"] = "utf-8"

app_logger = traceroot.get_logger("main")

# Log application startup
logger.info("Starting Eigent Multi-Agent System API")
logger.info(f"Python encoding: {os.environ.get('PYTHONIOENCODING')}")
logger.info(f"Environment: {os.environ.get('ENVIRONMENT', 'development')}")
app_logger.info("Starting Eigent Multi-Agent System API")
app_logger.info(f"Python encoding: {os.environ.get('PYTHONIOENCODING')}")
app_logger.info(f"Environment: {os.environ.get('ENVIRONMENT', 'development')}")

prefix = env("url_prefix", "")
logger.info(f"Loading routers with prefix: '{prefix}'")
auto_include_routers(api, prefix, "app/controller")
logger.info("All routers loaded successfully")
app_logger.info(f"Loading routers with prefix: '{prefix}'")
register_routers(api, prefix)
app_logger.info("All routers loaded successfully")

# Check if debug mode is enabled via environment variable
if os.environ.get('ENABLE_PYTHON_DEBUG') == 'true':
    try:
        import debugpy
        DEBUG_PORT = int(os.environ.get('DEBUG_PORT', '5678'))
        app_logger.info(f"Debug mode enabled - Starting debugpy server on port {DEBUG_PORT}")
        debugpy.listen(("localhost", DEBUG_PORT))
        app_logger.info(f"Debugger ready for attachment on localhost:{DEBUG_PORT}")
        # In VS Code: run the 'Debug Python Backend (Attach)' configuration
        # Don't wait for a client automatically - let it attach when ready
    except ImportError:
        app_logger.warning("debugpy not available, install with: uv add debugpy")
    except Exception as e:
        app_logger.error(f"Failed to start debugpy: {e}")
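A note on usage: with ENABLE_PYTHON_DEBUG=true the server starts listening and execution continues immediately, so breakpoints only catch code that runs after attachment. If startup itself needs to be debugged, a blocking variant is possible; a hedged sketch (wait_for_client is part of debugpy's public API, but the DEBUG_WAIT_FOR_CLIENT flag here is hypothetical):

import os
import debugpy

if os.environ.get("ENABLE_PYTHON_DEBUG") == "true":
    debugpy.listen(("localhost", int(os.environ.get("DEBUG_PORT", "5678"))))
    if os.environ.get("DEBUG_WAIT_FOR_CLIENT") == "true":  # hypothetical extra flag
        debugpy.wait_for_client()  # block until the IDE attaches
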
# Configure Loguru
log_path = os.path.expanduser("~/.eigent/runtime/log/app.log")
os.makedirs(os.path.dirname(log_path), exist_ok=True)
logger.add(
    log_path,  # Log file
    rotation="10 MB",  # Log rotation: 10 MB per file
    retention="10 days",  # Retain logs for the last 10 days
    level="DEBUG",  # Log level
    encoding="utf-8",
)
logger.info(f"Loguru configured with log file: {log_path}")

dir = pathlib.Path(__file__).parent / "runtime"
dir.mkdir(parents=True, exist_ok=True)

@@ -44,12 +65,12 @@ async def write_pid_file():

    async with aiofiles.open(dir / "run.pid", "w") as f:
        await f.write(str(os.getpid()))
    logger.info(f"PID file written: {os.getpid()}")
    app_logger.info(f"PID file written: {os.getpid()}")


# Create task to write PID
pid_task = asyncio.create_task(write_pid_file())
logger.info("PID write task created")
app_logger.info("PID write task created")

# Graceful shutdown handler
shutdown_event = asyncio.Event()

@@ -57,8 +78,7 @@ shutdown_event = asyncio.Event()

async def cleanup_resources():
    r"""Cleanup all resources on shutdown"""
    logger.info("Starting graceful shutdown...")
    logger.info("Starting graceful shutdown process")
    app_logger.info("Starting graceful shutdown process")

    from app.service.task import task_locks, _cleanup_task

@@ -75,21 +95,19 @@ async def cleanup_resources():
            task_lock = task_locks[task_id]
            await task_lock.cleanup()
        except Exception as e:
            logger.error(f"Error cleaning up task {task_id}: {e}")
            app_logger.error(f"Error cleaning up task {task_id}: {e}")

    # Remove PID file
    pid_file = dir / "run.pid"
    if pid_file.exists():
        pid_file.unlink()

    logger.info("Graceful shutdown completed")
    logger.info("All resources cleaned up successfully")
    app_logger.info("All resources cleaned up successfully")


def signal_handler(signum, frame):
    r"""Handle shutdown signals"""
    logger.info(f"Received signal {signum}")
    logger.warning(f"Received shutdown signal: {signum}")
    app_logger.warning(f"Received shutdown signal: {signum}")
    asyncio.create_task(cleanup_resources())
    shutdown_event.set()


@@ -97,8 +115,19 @@ def signal_handler(signum, frame):
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)

# Register cleanup on exit
atexit.register(lambda: asyncio.run(cleanup_resources()))
# Register cleanup on exit with a safe synchronous wrapper
def sync_cleanup():
    """Synchronous cleanup for atexit - handles PID file removal"""
    try:
        # Only perform synchronous cleanup tasks
        pid_file = dir / "run.pid"
        if pid_file.exists():
            pid_file.unlink()
            app_logger.info("PID file removed during shutdown")
    except Exception as e:
        app_logger.error(f"Error during atexit cleanup: {e}")

atexit.register(sync_cleanup)

# Log successful initialization
logger.info("Application initialization completed successfully")
app_logger.info("Application initialization completed successfully")

|
|||
|
|
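The swap from `atexit.register(lambda: asyncio.run(cleanup_resources()))` to the synchronous `sync_cleanup` matters because `asyncio.run()` raises if it is invoked while an event loop is still running, which can happen when atexit fires during server shutdown. A minimal standalone sketch of the failure mode the old hook risked (illustrative, not the app's actual shutdown path):

import asyncio

async def cleanup_resources():
    print("async cleanup")

async def main():
    try:
        # The old atexit lambda effectively did this while the loop was alive:
        asyncio.run(cleanup_resources())
    except RuntimeError as e:
        print(f"old atexit hook would fail here: {e}")

asyncio.run(main())

Running this prints "asyncio.run() cannot be called from a running event loop", which is why the new hook restricts itself to loop-free work like unlinking the PID file.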
@@ -5,21 +5,21 @@ description = "Add your description here"
readme = "README.md"
requires-python = "==3.10.16"
dependencies = [
    "camel-ai[eigent]==0.2.76a13",
    "camel-ai[eigent]==0.2.78",
    "fastapi>=0.115.12",
    "fastapi-babel>=1.0.0",
    "uvicorn[standard]>=0.34.2",
    "pydantic-i18n>=0.4.5",
    "python-dotenv>=1.1.0",
    "httpx[socks]>=0.28.1",
    "loguru>=0.7.3",
    "pydash>=8.0.5",
    "inflection>=0.5.1",
    "aiofiles>=24.1.0",
    "openai>=1.99.3,<2",
    "traceroot>=0.0.5a2",
    "traceroot>=0.0.7",
    "nodejs-wheel>=22.18.0",
    "numpy>=1.23.0,<2.0.0",
    "debugpy>=1.8.17",
]
@@ -145,14 +145,21 @@ def mock_model_backend():
@pytest.fixture
def mock_camel_agent():
    """Mock CAMEL agent for testing."""
    agent = AsyncMock()
    agent = MagicMock()  # Use MagicMock instead of AsyncMock
    agent.role_name = "test_agent"
    agent.agent_id = "test_agent_123"

    # Make step method async and return proper structure
    agent.step = AsyncMock()
    agent.step.return_value.msgs = [MagicMock()]
    agent.step.return_value.msgs[0].content = "Test agent response"
    # Make step method return proper structure with both .msg and .msgs[0]
    mock_response = MagicMock()
    mock_message = MagicMock()
    mock_message.content = "Test agent response"
    mock_message.parsed = None

    mock_response.msg = mock_message
    mock_response.msgs = [mock_message]  # msgs[0] should point to the same content
    mock_response.info = {"usage": {"total_tokens": 50}}

    agent.step.return_value = mock_response

    agent.astep = AsyncMock()
    agent.astep.return_value.msg.content = "Test async agent response"
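For orientation, here is a hypothetical test consuming the reworked fixture; the test name and prompt string are illustrative, not from the repo, but the asserted shape follows directly from the fixture above:

def test_agent_reply_shape(mock_camel_agent):
    # .step is a plain MagicMock, so a synchronous call yields mock_response.
    response = mock_camel_agent.step("hello")
    assert response.msg.content == "Test agent response"
    assert response.msgs[0] is response.msg  # both views share one message object
    assert response.info["usage"]["total_tokens"] == 50

The point of the change is that `.msg` and `.msgs[0]` now reference the same message, so code paths that read either attribute see consistent content.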
@@ -288,6 +295,7 @@ def sample_chat_data():
    """Sample chat data for testing."""
    return {
        "task_id": "test_task_123",
        "project_id": "test_project_456",
        "email": "test@example.com",
        "question": "Create a simple Python script",
        "attaches": [],
@@ -1,5 +1,8 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import os
import tempfile
from pathlib import Path

from app.service.chat_service import (
    step_solve,

@@ -12,14 +15,335 @@ from app.service.chat_service import (
    summary_task,
    construct_workforce,
    format_agent_description,
    new_agent_model
    new_agent_model,
    collect_previous_task_context,
    build_context_for_workforce
)
from app.model.chat import Chat, NewAgent
from app.service.task import Action, ActionImproveData, ActionEndData, ActionInstallMcpData
from app.service.task import Action, ActionImproveData, ActionEndData, ActionInstallMcpData, TaskLock
from camel.tasks import Task
from camel.tasks.task import TaskState


@pytest.mark.unit
class TestCollectPreviousTaskContext:
    """Test cases for collect_previous_task_context function."""

    def test_collect_previous_task_context_basic(self, temp_dir):
        """Test collect_previous_task_context with basic inputs."""
        working_directory = str(temp_dir)
        previous_task_content = "Create a Python script"
        previous_task_result = "Successfully created script.py"
        previous_summary = "Python Script Creation Task"

        result = collect_previous_task_context(
            working_directory=working_directory,
            previous_task_content=previous_task_content,
            previous_task_result=previous_task_result,
            previous_summary=previous_summary
        )

        # Check that all sections are included
        assert "=== CONTEXT FROM PREVIOUS TASK ===" in result
        assert "Previous Task:" in result
        assert "Create a Python script" in result
        assert "Previous Task Summary:" in result
        assert "Python Script Creation Task" in result
        assert "Previous Task Result:" in result
        assert "Successfully created script.py" in result
        assert "=== END OF PREVIOUS TASK CONTEXT ===" in result
        assert "=== NEW TASK ===" in result

    def test_collect_previous_task_context_with_generated_files(self, temp_dir):
        """Test collect_previous_task_context with generated files in working directory."""
        working_directory = str(temp_dir)

        # Create some test files
        (temp_dir / "script.py").write_text("print('Hello World')")
        (temp_dir / "config.json").write_text('{"test": true}')
        (temp_dir / "README.md").write_text("# Test Project")

        # Create a subdirectory with files
        sub_dir = temp_dir / "utils"
        sub_dir.mkdir()
        (sub_dir / "helper.py").write_text("def helper(): pass")

        result = collect_previous_task_context(
            working_directory=working_directory,
            previous_task_content="Create project files",
            previous_task_result="Files created successfully",
            previous_summary=""
        )

        # Check that generated files are listed
        assert "Generated Files from Previous Task:" in result
        assert "script.py" in result
        assert "config.json" in result
        assert "README.md" in result
        assert "utils/helper.py" in result or "utils\\helper.py" in result  # Handle Windows paths

        # Files should be sorted
        lines = result.split('\n')
        file_lines = [line.strip() for line in lines if line.strip().startswith('- ')]
        assert len(file_lines) == 4

    def test_collect_previous_task_context_filters_hidden_files(self, temp_dir):
        """Test that hidden files and directories are filtered out."""
        working_directory = str(temp_dir)

        # Create regular files
        (temp_dir / "visible.py").write_text("# Visible file")

        # Create hidden files and directories
        (temp_dir / ".hidden_file").write_text("hidden content")
        (temp_dir / ".env").write_text("SECRET=hidden")

        hidden_dir = temp_dir / ".hidden_dir"
        hidden_dir.mkdir()
        (hidden_dir / "file.txt").write_text("in hidden dir")

        # Create cache directories
        cache_dir = temp_dir / "__pycache__"
        cache_dir.mkdir()
        (cache_dir / "module.pyc").write_text("compiled")

        node_modules = temp_dir / "node_modules"
        node_modules.mkdir()
        (node_modules / "package").mkdir()

        result = collect_previous_task_context(
            working_directory=working_directory,
            previous_task_content="Test filtering",
            previous_task_result="Files filtered",
            previous_summary=""
        )

        # Should only include visible files
        assert "visible.py" in result
        assert ".hidden_file" not in result
        assert ".env" not in result
        assert "__pycache__" not in result
        assert "node_modules" not in result
        assert ".hidden_dir" not in result

    def test_collect_previous_task_context_filters_temp_files(self, temp_dir):
        """Test that temporary files are filtered out."""
        working_directory = str(temp_dir)

        # Create regular files
        (temp_dir / "main.py").write_text("# Main file")

        # Create temporary files
        (temp_dir / "temp.tmp").write_text("temporary")
        (temp_dir / "compiled.pyc").write_text("compiled python")

        result = collect_previous_task_context(
            working_directory=working_directory,
            previous_task_content="Test temp filtering",
            previous_task_result="Temp files filtered",
            previous_summary=""
        )

        # Should only include regular files
        assert "main.py" in result
        assert "temp.tmp" not in result
        assert "compiled.pyc" not in result

    def test_collect_previous_task_context_nonexistent_directory(self):
        """Test collect_previous_task_context with non-existent working directory."""
        working_directory = "/nonexistent/directory"

        result = collect_previous_task_context(
            working_directory=working_directory,
            previous_task_content="Test task",
            previous_task_result="Test result",
            previous_summary="Test summary"
        )

        # Should not crash and should not include file listing
        assert "=== CONTEXT FROM PREVIOUS TASK ===" in result
        assert "Test task" in result
        assert "Test result" in result
        assert "Test summary" in result
        assert "Generated Files from Previous Task:" not in result

    def test_collect_previous_task_context_empty_inputs(self, temp_dir):
        """Test collect_previous_task_context with empty string inputs."""
        working_directory = str(temp_dir)

        result = collect_previous_task_context(
            working_directory=working_directory,
            previous_task_content="",
            previous_task_result="",
            previous_summary=""
        )

        # Should still have the structural elements
        assert "=== CONTEXT FROM PREVIOUS TASK ===" in result
        assert "=== END OF PREVIOUS TASK CONTEXT ===" in result
        assert "=== NEW TASK ===" in result

        # Should not have content sections for empty inputs
        assert "Previous Task:" not in result
        assert "Previous Task Summary:" not in result
        assert "Previous Task Result:" not in result

    def test_collect_previous_task_context_only_summary(self, temp_dir):
        """Test collect_previous_task_context with only summary provided."""
        working_directory = str(temp_dir)

        result = collect_previous_task_context(
            working_directory=working_directory,
            previous_task_content="",
            previous_task_result="",
            previous_summary="Only summary provided"
        )

        # Should include summary section only
        assert "Previous Task Summary:" in result
        assert "Only summary provided" in result
        assert "Previous Task:" not in result
        assert "Previous Task Result:" not in result

    @patch('app.service.chat_service.logger')
    def test_collect_previous_task_context_file_system_error(self, mock_logger, temp_dir):
        """Test collect_previous_task_context handles file system errors gracefully."""
        working_directory = str(temp_dir)

        # Mock os.walk to raise an exception
        with patch('os.walk', side_effect=PermissionError("Access denied")):
            result = collect_previous_task_context(
                working_directory=working_directory,
                previous_task_content="Test task",
                previous_task_result="Test result",
                previous_summary="Test summary"
            )

        # Should still return result without files
        assert "=== CONTEXT FROM PREVIOUS TASK ===" in result
        assert "Test task" in result
        assert "Generated Files from Previous Task:" not in result

        # Should log warning
        mock_logger.warning.assert_called_once()

    def test_collect_previous_task_context_relative_paths(self, temp_dir):
        """Test that file paths are correctly converted to relative paths."""
        working_directory = str(temp_dir)

        # Create nested directory structure
        deep_dir = temp_dir / "level1" / "level2" / "level3"
        deep_dir.mkdir(parents=True)
        (deep_dir / "deep_file.txt").write_text("deep content")

        result = collect_previous_task_context(
            working_directory=working_directory,
            previous_task_content="Test relative paths",
            previous_task_result="Paths converted",
            previous_summary=""
        )

        # Check that the path is relative to working directory
        expected_path = "level1/level2/level3/deep_file.txt"
        windows_path = "level1\\level2\\level3\\deep_file.txt"

        # Should contain relative path (handle both Unix and Windows separators)
        assert expected_path in result or windows_path in result


@pytest.mark.unit
class TestBuildContextForWorkforce:
    """Test cases for build_context_for_workforce function."""

    def test_build_context_for_workforce_basic(self, temp_dir):
        """Test build_context_for_workforce with basic task lock and options."""
        # Create mock TaskLock
        task_lock = MagicMock(spec=TaskLock)
        task_lock.conversation_history = [
            {'role': 'user', 'content': 'Create a Python script'},
            {'role': 'assistant', 'content': 'I will create a Python script for you'}
        ]
        task_lock.last_task_result = "Script created successfully"
        task_lock.last_task_summary = "Python Script Creation"

        # Create mock Chat options
        options = MagicMock()
        options.file_save_path.return_value = str(temp_dir)

        result = build_context_for_workforce(task_lock, options)

        # Should include conversation history
        assert "=== CONVERSATION HISTORY ===" in result
        assert "user: Create a Python script" in result
        assert "assistant: I will create a Python script for you" in result

        # Should include previous task context
        assert "=== CONTEXT FROM PREVIOUS TASK ===" in result
        assert "Script created successfully" in result

    def test_build_context_for_workforce_empty_history(self, temp_dir):
        """Test build_context_for_workforce with empty conversation history."""
        task_lock = MagicMock(spec=TaskLock)
        task_lock.conversation_history = []
        task_lock.last_task_result = ""
        task_lock.last_task_summary = ""

        options = MagicMock()
        options.file_save_path.return_value = str(temp_dir)

        result = build_context_for_workforce(task_lock, options)

        # Should return empty string for no context
        assert result == ""

    def test_build_context_for_workforce_task_result_role(self, temp_dir):
        """Test build_context_for_workforce handles 'task_result' role specially."""
        task_lock = MagicMock(spec=TaskLock)
        task_lock.conversation_history = [
            {'role': 'user', 'content': 'First question'},
            {'role': 'task_result', 'content': 'Full task context from previous task'},
            {'role': 'user', 'content': 'Second question'}
        ]
        task_lock.last_task_result = "Final result"
        task_lock.last_task_summary = "Task summary"

        options = MagicMock()
        options.file_save_path.return_value = str(temp_dir)

        result = build_context_for_workforce(task_lock, options)

        # Should simplify task_result display
        assert "[Previous Task Completed]" in result
        assert "Full task context from previous task" not in result  # Should not show full content
        assert "user: First question" in result
        assert "user: Second question" in result

    def test_build_context_for_workforce_with_last_task_result(self, temp_dir):
        """Test build_context_for_workforce includes last task result context."""
        # Create some files in temp directory
        (temp_dir / "output.txt").write_text("Task output")

        task_lock = MagicMock(spec=TaskLock)
        task_lock.conversation_history = [
            {'role': 'user', 'content': 'Test question'}
        ]
        task_lock.last_task_result = "Task completed with output.txt"
        task_lock.last_task_summary = "File creation task"

        options = MagicMock()
        options.file_save_path.return_value = str(temp_dir)

        result = build_context_for_workforce(task_lock, options)

        # Should include conversation history and task context
        assert "=== CONVERSATION HISTORY ===" in result
        assert "user: Test question" in result
        assert "=== CONTEXT FROM PREVIOUS TASK ===" in result
        assert "Task completed with output.txt" in result
        assert "File creation task" in result
        assert "output.txt" in result  # Generated file should be listed


@pytest.mark.unit
class TestChatServiceUtilities:
    """Test cases for chat service utility functions."""
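Taken together, these assertions pin down the layout of the context block without showing the implementation. One plausible rendering consistent with every assertion above (the section labels and the "- " file bullets come straight from the tests; the blank-line spacing is a guess):

EXAMPLE_CONTEXT = """\
=== CONTEXT FROM PREVIOUS TASK ===
Previous Task:
Create a Python script

Previous Task Summary:
Python Script Creation Task

Previous Task Result:
Successfully created script.py

Generated Files from Previous Task:
  - script.py
  - utils/helper.py
=== END OF PREVIOUS TASK CONTEXT ===

=== NEW TASK ===
"""

Note that a file entry such as "  - script.py" satisfies both checks the tests make: `line.strip().startswith('- ')` and `' - ' in line`.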
@@ -324,6 +648,130 @@ class TestChatServiceIntegration:
    """Integration tests for chat service."""

    @pytest.mark.asyncio
    async def test_step_solve_context_building_workflow(self, sample_chat_data, mock_request, temp_dir):
        """Test step_solve builds context correctly using collect_previous_task_context."""
        options = Chat(**sample_chat_data)

        # Create actual TaskLock with context data
        task_lock = TaskLock(
            id="test_task_123",
            queue=AsyncMock(),
            human_input={}
        )
        task_lock.conversation_history = [
            {'role': 'user', 'content': 'Create a Python script'},
            {'role': 'assistant', 'content': 'Script created successfully'}
        ]
        task_lock.last_task_result = "def hello(): print('Hello World')"
        task_lock.last_task_summary = "Python Hello World Script"

        # Create some files in working directory
        working_dir = temp_dir / "test_project"
        working_dir.mkdir()
        (working_dir / "script.py").write_text("def hello(): print('Hello World')")

        # Mock file_save_path method to return our temp directory
        with patch.object(Chat, 'file_save_path', return_value=str(working_dir)):

            # Test the context building directly
            context = build_context_for_workforce(task_lock, options)

            # Verify context includes conversation history
            assert "=== CONVERSATION HISTORY ===" in context
            assert "user: Create a Python script" in context
            assert "assistant: Script created successfully" in context

            # Verify context includes task context with files
            assert "=== CONTEXT FROM PREVIOUS TASK ===" in context
            assert "def hello(): print('Hello World')" in context
            assert "Python Hello World Script" in context
            assert "script.py" in context

    @pytest.mark.asyncio
    async def test_step_solve_new_task_state_context_collection(self, sample_chat_data, mock_request, temp_dir):
        """Test step_solve correctly collects context in new_task_state action."""
        options = Chat(**sample_chat_data)
        working_dir = temp_dir / "project"
        working_dir.mkdir()

        # Create files that should be included in context
        (working_dir / "main.py").write_text("print('main')")
        (working_dir / "config.json").write_text('{"version": "1.0"}')

        # Mock file_save_path to return our temp directory
        with patch.object(Chat, 'file_save_path', return_value=str(working_dir)):

            # Test collect_previous_task_context directly with the scenario
            result = collect_previous_task_context(
                working_directory=str(working_dir),
                previous_task_content="Create project structure",
                previous_task_result="Project files created successfully",
                previous_summary="Project Setup Task"
            )

            # Verify all expected elements are present
            assert "=== CONTEXT FROM PREVIOUS TASK ===" in result
            assert "Previous Task:" in result
            assert "Create project structure" in result
            assert "Previous Task Summary:" in result
            assert "Project Setup Task" in result
            assert "Previous Task Result:" in result
            assert "Project files created successfully" in result
            assert "Generated Files from Previous Task:" in result
            assert "main.py" in result
            assert "config.json" in result
            assert "=== END OF PREVIOUS TASK CONTEXT ===" in result
            assert "=== NEW TASK ===" in result

    @pytest.mark.asyncio
    async def test_step_solve_end_action_context_collection(self, sample_chat_data, mock_request, temp_dir):
        """Test step_solve correctly collects and saves context in end action."""
        options = Chat(**sample_chat_data)
        working_dir = temp_dir / "finished_project"
        working_dir.mkdir()

        # Create output files
        (working_dir / "output.txt").write_text("Final output")
        (working_dir / "report.md").write_text("# Task Report")

        # Create actual TaskLock
        task_lock = TaskLock(
            id="test_end_task",
            queue=AsyncMock(),
            human_input={}
        )
        task_lock.last_task_summary = "Final Task Summary"

        # Mock file_save_path
        with patch.object(Chat, 'file_save_path', return_value=str(working_dir)):

            # Test the context collection for end action scenario
            task_content = "Generate final report"
            task_result = "Report generated successfully with output files"

            context = collect_previous_task_context(
                working_directory=str(working_dir),
                previous_task_content=task_content,
                previous_task_result=task_result,
                previous_summary=task_lock.last_task_summary
            )

            # Verify context structure for end action
            assert "=== CONTEXT FROM PREVIOUS TASK ===" in context
            assert "Generate final report" in context
            assert "Report generated successfully with output files" in context
            assert "Final Task Summary" in context
            assert "output.txt" in context
            assert "report.md" in context

            # Test that context can be added to conversation history
            task_lock.add_conversation('task_result', context)
            assert len(task_lock.conversation_history) == 1
            assert task_lock.conversation_history[0]['role'] == 'task_result'
            assert task_lock.conversation_history[0]['content'] == context

    @pytest.mark.asyncio
    @pytest.mark.skip(reason="Gets stuck for some reason.")
    async def test_step_solve_basic_workflow(self, sample_chat_data, mock_request, mock_task_lock):
        """Test step_solve basic workflow integration."""
        options = Chat(**sample_chat_data)
@@ -380,6 +828,7 @@ class TestChatServiceIntegration:
        # Note: Workforce might not be created/stopped if request is immediately disconnected

    @pytest.mark.asyncio
    @pytest.mark.skip(reason="Gets stuck for some reason.")
    async def test_step_solve_error_handling(self, sample_chat_data, mock_request, mock_task_lock):
        """Test step_solve handles errors gracefully."""
        options = Chat(**sample_chat_data)
@@ -423,7 +872,195 @@ class TestChatServiceWithLLM:
@pytest.mark.unit
class TestChatServiceErrorCases:
    """Test error cases and edge conditions for chat service."""

    def test_collect_previous_task_context_os_walk_exception(self, temp_dir):
        """Test collect_previous_task_context handles os.walk exceptions."""
        working_directory = str(temp_dir)

        with patch('os.walk', side_effect=OSError("Permission denied")):
            with patch('app.service.chat_service.logger') as mock_logger:
                result = collect_previous_task_context(
                    working_directory=working_directory,
                    previous_task_content="Test task",
                    previous_task_result="Test result",
                    previous_summary="Test summary"
                )

                # Should still include basic context
                assert "=== CONTEXT FROM PREVIOUS TASK ===" in result
                assert "Test task" in result
                assert "Test result" in result
                assert "Test summary" in result

                # Should not include file listing
                assert "Generated Files from Previous Task:" not in result

                # Should log warning
                mock_logger.warning.assert_called_once()

    def test_collect_previous_task_context_relpath_exception(self, temp_dir):
        """Test collect_previous_task_context handles os.path.relpath exceptions."""
        working_directory = str(temp_dir)

        # Create a test file
        (temp_dir / "test.txt").write_text("test content")

        with patch('os.path.relpath', side_effect=ValueError("Invalid path")):
            with patch('app.service.chat_service.logger') as mock_logger:
                result = collect_previous_task_context(
                    working_directory=working_directory,
                    previous_task_content="Test task",
                    previous_task_result="Test result",
                    previous_summary="Test summary"
                )

                # Should handle the exception gracefully
                assert "=== CONTEXT FROM PREVIOUS TASK ===" in result
                # Should log warning about file collection failure
                mock_logger.warning.assert_called_once()

    def test_build_context_for_workforce_missing_attributes(self, temp_dir):
        """Test build_context_for_workforce handles missing attributes gracefully."""
        # Create task_lock without required attributes
        task_lock = MagicMock(spec=TaskLock)
        task_lock.conversation_history = None  # Missing attribute
        task_lock.last_task_result = None  # Missing attribute
        task_lock.last_task_summary = None  # Missing attribute

        options = MagicMock()
        options.file_save_path.return_value = str(temp_dir)

        result = build_context_for_workforce(task_lock, options)

        # Should handle missing attributes gracefully
        assert result == ""

    def test_build_context_for_workforce_file_save_path_exception(self):
        """Test build_context_for_workforce handles file_save_path exceptions."""
        task_lock = MagicMock(spec=TaskLock)
        task_lock.conversation_history = []
        task_lock.last_task_result = "Test result"
        task_lock.last_task_summary = "Test summary"

        options = MagicMock()
        options.file_save_path.side_effect = Exception("Path error")

        with patch('app.service.chat_service.logger') as mock_logger:
            # Should handle exception when getting file path
            with pytest.raises(Exception, match="Path error"):
                build_context_for_workforce(task_lock, options)

    def test_collect_previous_task_context_unicode_handling(self, temp_dir):
        """Test collect_previous_task_context handles unicode content correctly."""
        working_directory = str(temp_dir)

        # Create files with unicode content
        (temp_dir / "unicode_file.txt").write_text("Unicode content: 🐍 Python ñáéíóú", encoding='utf-8')

        unicode_task_content = "Create files with unicode: 🔥 emojis and ñáéíóú accents"
        unicode_result = "Files created successfully with unicode: ✅ done"
        unicode_summary = "Unicode Task: 📝 file creation"

        result = collect_previous_task_context(
            working_directory=working_directory,
            previous_task_content=unicode_task_content,
            previous_task_result=unicode_result,
            previous_summary=unicode_summary
        )

        # Should handle unicode correctly
        assert "🔥 emojis" in result
        assert "ñáéíóú accents" in result
        assert "✅ done" in result
        assert "📝 file creation" in result
        assert "unicode_file.txt" in result

    def test_collect_previous_task_context_very_long_content(self, temp_dir):
        """Test collect_previous_task_context handles very long content."""
        working_directory = str(temp_dir)

        # Create very long content strings
        long_content = "Very long task content. " * 1000  # ~25KB
        long_result = "Very long task result. " * 1000  # ~23KB
        long_summary = "Very long summary. " * 100  # ~1.8KB

        result = collect_previous_task_context(
            working_directory=working_directory,
            previous_task_content=long_content,
            previous_task_result=long_result,
            previous_summary=long_summary
        )

        # Should handle long content without issues
        assert len(result) > 49000  # Should be quite long
        assert "Very long task content." in result
        assert "Very long task result." in result
        assert "Very long summary." in result

    def test_collect_previous_task_context_many_files(self, temp_dir):
        """Test collect_previous_task_context performance with many files."""
        working_directory = str(temp_dir)

        # Create many files to test performance
        for i in range(100):
            (temp_dir / f"file_{i:03d}.txt").write_text(f"Content {i}")

        # Create subdirectories with files
        for dir_i in range(10):
            sub_dir = temp_dir / f"subdir_{dir_i}"
            sub_dir.mkdir()
            for file_i in range(10):
                (sub_dir / f"subfile_{file_i}.txt").write_text(f"Sub content {dir_i}-{file_i}")

        import time
        start_time = time.time()

        result = collect_previous_task_context(
            working_directory=working_directory,
            previous_task_content="Test many files",
            previous_task_result="Many files processed",
            previous_summary="Performance test"
        )

        end_time = time.time()
        execution_time = end_time - start_time

        # Should complete in reasonable time (less than 1 second for 200 files)
        assert execution_time < 1.0

        # Should list all files
        assert "Generated Files from Previous Task:" in result
        # Count number of file entries
        file_lines = [line for line in result.split('\n') if ' - ' in line]
        assert len(file_lines) == 200  # 100 main files + 100 subfiles

    def test_collect_previous_task_context_special_characters_in_filenames(self, temp_dir):
        """Test collect_previous_task_context handles special characters in filenames."""
        working_directory = str(temp_dir)

        # Create files with special characters (that are valid in filenames)
        try:
            (temp_dir / "file with spaces.txt").write_text("content")
            (temp_dir / "file-with-dashes.txt").write_text("content")
            (temp_dir / "file_with_underscores.txt").write_text("content")
            (temp_dir / "file.with.dots.txt").write_text("content")
        except OSError:
            # Skip if filesystem doesn't support these characters
            pytest.skip("Filesystem doesn't support special characters in filenames")

        result = collect_previous_task_context(
            working_directory=working_directory,
            previous_task_content="Test special chars",
            previous_task_result="Files created",
            previous_summary=""
        )

        # Should list files with special characters
        assert "file with spaces.txt" in result
        assert "file-with-dashes.txt" in result
        assert "file_with_underscores.txt" in result
        assert "file.with.dots.txt" in result

    @pytest.mark.asyncio
    async def test_question_confirm_agent_error(self, mock_camel_agent):
        """Test question_confirm when agent raises error."""
backend/tests/unit/utils/test_terminal_toolkit.py (new file, 100 lines)
@@ -0,0 +1,100 @@
import asyncio
import threading
import time
import pytest
from app.service.task import task_locks, TaskLock
from app.utils.toolkit.terminal_toolkit import TerminalToolkit


@pytest.mark.unit
class TestTerminalToolkit:
    """Tests verifying the fix for 'RuntimeError: no running event loop'."""

    def test_no_runtime_error_in_sync_context(self):
        """Test that no 'no running event loop' error occurs in a sync context."""
        test_api_task_id = "test_api_task_123"

        if test_api_task_id not in task_locks:
            task_locks[test_api_task_id] = TaskLock(id=test_api_task_id, queue=asyncio.Queue(), human_input={})
        toolkit = TerminalToolkit("test_api_task_123")

        # This should NOT raise RuntimeError: no running event loop
        # This simulates the exact scenario from the error traceback
        try:
            toolkit._write_to_log("/tmp/test.log", "Test output")
            time.sleep(0.1)  # Give thread time to complete
        except RuntimeError as e:
            if "no running event loop" in str(e):
                pytest.fail("RuntimeError: no running event loop should not be raised - the fix is not working!")
            else:
                raise  # Re-raise if it's a different RuntimeError

    def test_multiple_calls_no_runtime_error(self):
        """Test that multiple calls don't raise RuntimeError."""
        test_api_task_id = "test_api_task_123"

        if test_api_task_id not in task_locks:
            task_locks[test_api_task_id] = TaskLock(id=test_api_task_id, queue=asyncio.Queue(), human_input={})
        toolkit = TerminalToolkit("test_api_task_123")

        # Make multiple calls - none should raise RuntimeError
        try:
            for i in range(5):
                toolkit._write_to_log(f"/tmp/test_{i}.log", f"Output {i}")
            time.sleep(0.2)  # Give threads time to complete
        except RuntimeError as e:
            if "no running event loop" in str(e):
                pytest.fail("RuntimeError: no running event loop should not be raised!")
            else:
                raise

    def test_thread_safety_no_runtime_error(self):
        """Test thread safety without RuntimeError."""
        test_api_task_id = "test_api_task_123"

        if test_api_task_id not in task_locks:
            task_locks[test_api_task_id] = TaskLock(id=test_api_task_id, queue=asyncio.Queue(), human_input={})
        toolkit = TerminalToolkit("test_api_task_123")

        # Create multiple threads that call _write_to_log
        threads = []
        for i in range(5):
            thread = threading.Thread(
                target=toolkit._write_to_log,
                args=(f"/tmp/test_{i}.log", f"Thread {i} output")
            )
            threads.append(thread)
            thread.start()

        # Wait for all threads to complete
        for thread in threads:
            thread.join()

        time.sleep(0.2)  # Give async operations time to complete

        # Should not have raised any RuntimeError

    def test_async_context_still_works(self):
        """Test that async context still works without RuntimeError."""
        test_api_task_id = "test_api_task_123"

        if test_api_task_id not in task_locks:
            task_locks[test_api_task_id] = TaskLock(id=test_api_task_id, queue=asyncio.Queue(), human_input={})
        toolkit = TerminalToolkit("test_api_task_123")

        async def test_async_context():
            toolkit._write_to_log("/tmp/async_test.log", "Async context test")
            await asyncio.sleep(0.1)

        # Should work in async context without RuntimeError
        try:
            asyncio.run(test_async_context())
        except RuntimeError as e:
            if "no running event loop" in str(e):
                pytest.fail("RuntimeError: no running event loop should not be raised in async context!")
            else:
                raise
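These tests exercise `_write_to_log` from synchronous code, from worker threads, and from inside a running loop. One common pattern that satisfies all three call sites is shown below; this is an illustrative sketch of the general technique, not the toolkit's actual implementation:

import asyncio

def fire_and_forget(coro_factory, loop=None):
    """Schedule an async side effect from any context without requiring a running loop."""
    try:
        running = asyncio.get_running_loop()
    except RuntimeError:
        running = None  # called from plain sync code or a bare thread

    if running is not None:
        # Already inside a loop: schedule without blocking the caller.
        running.create_task(coro_factory())
    elif loop is not None and loop.is_running():
        # Called from a thread while a loop runs elsewhere: hand the coroutine over safely.
        asyncio.run_coroutine_threadsafe(coro_factory(), loop)
    else:
        # No loop anywhere: run the coroutine to completion synchronously.
        asyncio.run(coro_factory())

The key point is that `asyncio.get_running_loop()` is probed inside a try/except rather than assumed, which is exactly the failure the old code hit.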
@@ -370,13 +370,13 @@ class TestWorkforce:
        )

        with patch('app.service.task.delete_task_lock', side_effect=Exception("Delete failed")), \
            patch('loguru.logger.error') as mock_log_error:
            patch('traceroot.get_logger') as mock_get_logger:

            # Should not raise exception
            await workforce.cleanup()

            # Should log the error
            mock_log_error.assert_called_once()
            mock_get_logger.assert_called_once()

@pytest.mark.integration
@@ -623,13 +623,13 @@ class TestWorkforceErrorCases:
        )

        with patch('app.service.task.delete_task_lock', side_effect=Exception("Task lock not found")), \
            patch('loguru.logger.error') as mock_log_error:
            patch('traceroot.get_logger') as mock_get_logger:

            # Should handle missing task lock gracefully
            await workforce.cleanup()

            # Should log the error
            mock_log_error.assert_called_once()
            mock_get_logger.assert_called_once()

    def test_workforce_inheritance(self):
        """Test that Workforce properly inherits from BaseWorkforce."""
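Both hunks swap the loguru patch target for `traceroot.get_logger`, so the assertion shifts from "an error was logged" to "a traceroot logger was requested". A standalone sketch of the new pattern, assuming the traceroot package is installed (it is a backend dependency pinned above); the helper body is hypothetical and stands in for the real cleanup path:

from unittest.mock import patch
import traceroot

def test_cleanup_requests_traceroot_logger():
    with patch('traceroot.get_logger') as mock_get_logger:
        # Stand-in for the real error path that fetches a logger on failure.
        traceroot.get_logger().error("Delete failed")
        mock_get_logger.assert_called_once()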
backend/uv.lock (generated, 1661 lines): diff suppressed because it is too large.
@@ -10,6 +10,7 @@
    "cssVariables": true,
    "prefix": ""
  },
  "iconLibrary": "lucide",
  "aliases": {
    "components": "@/components",
    "utils": "@/lib/utils",

@@ -17,5 +18,7 @@
    "lib": "@/lib",
    "hooks": "@/hooks"
  },
  "iconLibrary": "lucide"
}
  "registries": {
    "@animate-ui": "https://animate-ui.com/r/{name}.json"
  }
}
config/browser-profiles.json (new file, 15 lines)

@@ -0,0 +1,15 @@
{
  "profiles": {
    "userLogin": {
      "name": "profile_user_login",
      "partition": "user_login",
      "description": "Profile for user login browser"
    },
    "project": {
      "nameTemplate": "profile_{port}",
      "partitionTemplate": "project_{port}",
      "description": "Profile for project browser instances"
    }
  },
  "basePath": "~/.eigent/browser_profiles"
}
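The `{port}` placeholders in `nameTemplate` and `partitionTemplate` imply per-instance expansion at runtime. A minimal sketch of how a consumer might resolve one project profile; the helper below is hypothetical, not code from the repo:

import json
from pathlib import Path

def resolve_project_profile(config_path: str, port: int) -> dict:
    """Expand the project profile templates for one browser instance."""
    cfg = json.loads(Path(config_path).read_text())
    project = cfg["profiles"]["project"]
    base = Path(cfg["basePath"]).expanduser()
    name = project["nameTemplate"].format(port=port)
    return {
        "name": name,
        "partition": project["partitionTemplate"].format(port=port),
        "path": str(base / name),
    }

# e.g. resolve_project_profile("config/browser-profiles.json", 9222)
# -> {'name': 'profile_9222', 'partition': 'project_9222', 'path': '~/.eigent/browser_profiles/profile_9222'}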
@@ -12,6 +12,10 @@
      "from": "backend",
      "to": "backend",
      "filter": ["**/*", "!.venv/**/*"]
    },
    {
      "from": "utils",
      "to": "utils"
    }
  ],
  "protocols": [
@@ -9,6 +9,15 @@ import https from 'https'
import http from 'http'
import { URL } from 'url'

interface FileInfo {
  path: string;
  name: string;
  type: string;
  isFolder: boolean;
  relativePath: string;
  task_id?: string;
  project_id?: string;
}

export class FileReader {
  private win: BrowserWindow | null = null

@@ -541,12 +550,54 @@ export class FileReader {
    }
  }

  public getFileList(email: string, taskId: string): FileInfo[] {
  private findTaskInProjects(userDir: string, taskId: string): string | null {
    try {
      if (!fs.existsSync(userDir)) {
        return null;
      }

      const entries = fs.readdirSync(userDir);

      // Look for project directories
      for (const entry of entries) {
        if (entry.startsWith('project_')) {
          const projectDir = path.join(userDir, entry);
          const taskDir = path.join(projectDir, `task_${taskId}`);

          if (fs.existsSync(taskDir)) {
            return taskDir;
          }
        }
      }

      return null;
    } catch (err) {
      console.error("Error finding task in projects:", err);
      return null;
    }
  }

  public getFileList(email: string, taskId: string, projectId?: string): FileInfo[] {
    const safeEmail = email.split('@')[0].replace(/[\\/*?:"<>|\s]/g, "_").replace(/^\.+|\.+$/g, "");

    const userHome = app.getPath('home');
    const dirPath = path.join(userHome, "eigent", safeEmail, `task_${taskId}`);

    let dirPath: string;

    // Check if projectId is provided for new project-based structure
    if (projectId) {
      dirPath = path.join(userHome, "eigent", safeEmail, `project_${projectId}`, `task_${taskId}`);
    } else {
      // First try project-based structure (scan for existing projects)
      const userDir = path.join(userHome, "eigent", safeEmail);
      const projectBasedPath = this.findTaskInProjects(userDir, taskId);

      if (projectBasedPath) {
        dirPath = projectBasedPath;
      } else {
        // Fallback to legacy direct task structure
        dirPath = path.join(userHome, "eigent", safeEmail, `task_${taskId}`);
      }
    }

    try {
      if (!fs.existsSync(dirPath)) {

@@ -560,21 +611,53 @@ export class FileReader {
    }
  }

  public deleteTaskFiles(email: string, taskId: string): {
  public deleteTaskFiles(email: string, taskId: string, projectId?: string): {
    success: boolean;
    path: { dirPath: string; logPath: string }
  }
  {
    const safeEmail = email.split('@')[0].replace(/[\\/*?:"<>|\s]/g, "_").replace(/^\.+|\.+$/g, "");
    const userHome = app.getPath('home');
    const dirPath = path.join(userHome, "eigent", safeEmail, `task_${taskId}`);
    const logPath = path.join(userHome, ".eigent", safeEmail, `task_${taskId}`);
    try {
      if (fs.existsSync(dirPath)&&fs.existsSync(logPath)) {
        fs.rmSync(dirPath, { recursive: true, force: true });
        fs.rmSync(logPath, { recursive: true, force: true });

    let dirPath: string;
    let logPath: string;

    // Check if projectId is provided for new project-based structure
    if (projectId) {
      dirPath = path.join(userHome, "eigent", safeEmail, `project_${projectId}`, `task_${taskId}`);
      logPath = path.join(userHome, ".eigent", safeEmail, `project_${projectId}`, `task_${taskId}`);
    } else {
      // First try project-based structure
      const userDir = path.join(userHome, "eigent", safeEmail);
      const projectBasedPath = this.findTaskInProjects(userDir, taskId);

      if (projectBasedPath) {
        dirPath = projectBasedPath;
        // Extract project from path to construct log path
        const projectMatch = projectBasedPath.match(/project_([^\\\/]+)/);
        if (projectMatch) {
          logPath = path.join(userHome, ".eigent", safeEmail, projectMatch[0], `task_${taskId}`);
        } else {
          logPath = path.join(userHome, ".eigent", safeEmail, `task_${taskId}`);
        }
      } else {
        // Fallback to legacy direct task structure
        dirPath = path.join(userHome, "eigent", safeEmail, `task_${taskId}`);
        logPath = path.join(userHome, ".eigent", safeEmail, `task_${taskId}`);
      }
    }
        return { success: true, path: { dirPath, logPath } };
      }

    try {
      let success = false;
      if (fs.existsSync(dirPath)) {
        fs.rmSync(dirPath, { recursive: true, force: true });
        success = true;
      }
      if (fs.existsSync(logPath)) {
        fs.rmSync(logPath, { recursive: true, force: true });
        success = true;
      }
      return { success, path: { dirPath, logPath } };
    } catch (err) {
      console.error("Delete task files failed:", dirPath, err);
      return { success: false, path: { dirPath, logPath } };

@@ -582,9 +665,7 @@ export class FileReader {
  }

  public getLogFolder(email: string): string {
    const safeEmail = email.split('@')[0].replace(/[\\/*?:"<>|\s]/g, "_").replace(/^\.+|\.+$/g, "");
    const userHome = app.getPath('home');
    const dirPath = path.join(userHome, "eigent", safeEmail);

@@ -599,5 +680,205 @@ export class FileReader {
    return '';
  }
}

  public createProjectStructure(email: string, projectId: string): { success: boolean; path: string } {
    const safeEmail = email.split('@')[0].replace(/[\\/*?:"<>|\s]/g, "_").replace(/^\.+|\.+$/g, "");
    const userHome = app.getPath('home');
    const projectPath = path.join(userHome, "eigent", safeEmail, `project_${projectId}`);

    try {
      if (!fs.existsSync(projectPath)) {
        fs.mkdirSync(projectPath, { recursive: true });
      }
      return { success: true, path: projectPath };
    } catch (err) {
      console.error("Create project structure failed:", err);
      return { success: false, path: projectPath };
    }
  }

  public getProjectList(email: string): Array<{ id: string; name: string; path: string; taskCount: number; createdAt: Date }> {
    const safeEmail = email.split('@')[0].replace(/[\\/*?:"<>|\s]/g, "_").replace(/^\.+|\.+$/g, "");
    const userHome = app.getPath('home');
    const userDir = path.join(userHome, "eigent", safeEmail);

    try {
      if (!fs.existsSync(userDir)) {
        return [];
      }

      const entries = fs.readdirSync(userDir);
      const projects: Array<{ id: string; name: string; path: string; taskCount: number; createdAt: Date }> = [];

      for (const entry of entries) {
        if (entry.startsWith('project_')) {
          const projectPath = path.join(userDir, entry);
          const stats = fs.statSync(projectPath);

          if (stats.isDirectory()) {
            const projectId = entry.replace('project_', '');

            // Count tasks in this project
            const taskCount = this.countTasksInProject(projectPath);

            projects.push({
              id: projectId,
              name: `Project ${projectId}`,
              path: projectPath,
              taskCount,
              createdAt: stats.birthtime
            });
          }
        }
      }

      return projects.sort((a, b) => b.createdAt.getTime() - a.createdAt.getTime());
    } catch (err) {
      console.error("Get project list failed:", err);
      return [];
    }
  }

  public getTasksInProject(email: string, projectId: string): Array<{ id: string; name: string; path: string; createdAt: Date }> {
    const safeEmail = email.split('@')[0].replace(/[\\/*?:"<>|\s]/g, "_").replace(/^\.+|\.+$/g, "");
    const userHome = app.getPath('home');
    const projectPath = path.join(userHome, "eigent", safeEmail, `project_${projectId}`);

    try {
      if (!fs.existsSync(projectPath)) {
        return [];
      }

      const entries = fs.readdirSync(projectPath);
      const tasks: Array<{ id: string; name: string; path: string; createdAt: Date }> = [];

      for (const entry of entries) {
        if (entry.startsWith('task_')) {
          const taskPath = path.join(projectPath, entry);
          const stats = fs.statSync(taskPath);

          if (stats.isDirectory()) {
            const taskId = entry.replace('task_', '');

            tasks.push({
              id: taskId,
              name: `Task ${taskId}`,
              path: taskPath,
              createdAt: stats.birthtime
            });
          }
        }
      }

      return tasks.sort((a, b) => b.createdAt.getTime() - a.createdAt.getTime());
    } catch (err) {
      console.error("Get tasks in project failed:", err);
      return [];
    }
  }

  public moveTaskToProject(email: string, taskId: string, projectId: string): { success: boolean; message: string } {
    const safeEmail = email.split('@')[0].replace(/[\\/*?:"<>|\s]/g, "_").replace(/^\.+|\.+$/g, "");
    const userHome = app.getPath('home');

    // Source path (legacy structure)
    const sourcePath = path.join(userHome, "eigent", safeEmail, `task_${taskId}`);
    const sourceLogPath = path.join(userHome, ".eigent", safeEmail, `task_${taskId}`);

    // Destination paths (project structure)
    const projectPath = path.join(userHome, "eigent", safeEmail, `project_${projectId}`);
    const destPath = path.join(projectPath, `task_${taskId}`);
    const destLogPath = path.join(userHome, ".eigent", safeEmail, `project_${projectId}`, `task_${taskId}`);

    try {
      // Create project structure if it doesn't exist
      if (!fs.existsSync(projectPath)) {
        fs.mkdirSync(projectPath, { recursive: true });
      }

      // Create destination log directory
      const destLogDir = path.dirname(destLogPath);
      if (!fs.existsSync(destLogDir)) {
        fs.mkdirSync(destLogDir, { recursive: true });
      }

      // Move task files
      if (fs.existsSync(sourcePath)) {
        fs.renameSync(sourcePath, destPath);
      }

      // Move log files
      if (fs.existsSync(sourceLogPath)) {
        fs.renameSync(sourceLogPath, destLogPath);
      }

      return { success: true, message: `Task ${taskId} moved to project ${projectId}` };
    } catch (err) {
      console.error("Move task to project failed:", err);
      return { success: false, message: `Failed to move task: ${err}` };
    }
  }

  public getProjectFileList(email: string, projectId: string): FileInfo[] {
    const safeEmail = email.split('@')[0].replace(/[\\/*?:"<>|\s]/g, "_").replace(/^\.+|\.+$/g, "");
    const userHome = app.getPath('home');
    const projectPath = path.join(userHome, "eigent", safeEmail, `project_${projectId}`);

    try {
      if (!fs.existsSync(projectPath)) {
        return [];
      }

      const allFiles: FileInfo[] = [];
      const taskDirs = fs.readdirSync(projectPath);

      for (const taskDir of taskDirs) {
        if (!taskDir.startsWith('task_')) continue;

        const taskPath = path.join(projectPath, taskDir);
        const stats = fs.statSync(taskPath);

        if (stats.isDirectory()) {
          const taskId = taskDir.replace('task_', '');
          const taskFiles = this.getFilesRecursive(taskPath, taskPath);

          const enrichedFiles = taskFiles.map(file => {
            const fileDir = path.dirname(file.path);
            const relativeParentPath = path.relative(projectPath, fileDir);

            return {
              ...file,
              task_id: taskId,
              project_id: projectId,
              relativePath: relativeParentPath === '.' ? '' : relativeParentPath
            };
          });

          allFiles.push(...enrichedFiles);
        }
      }

      return allFiles.sort((a, b) => {
        // Sort by task_id first, then by file path
        if (a.task_id !== b.task_id) {
          return a.task_id!.localeCompare(b.task_id!);
        }
        return a.path.localeCompare(b.path);
      });
    } catch (err) {
      console.error("Get project file list failed:", err);
      return [];
    }
  }

  private countTasksInProject(projectPath: string): number {
    try {
      const entries = fs.readdirSync(projectPath);
      return entries.filter(entry => entry.startsWith('task_')).length;
    } catch (err) {
      console.error("Count tasks in project failed:", err);
      return 0;
    }
  }
}
@@ -40,15 +40,45 @@ let python_process: ChildProcessWithoutNullStreams | null = null;
let backendPort: number = 5001;
let browser_port = 9222;

// Protocol URL queue for handling URLs before window is ready
let protocolUrlQueue: string[] = [];
let isWindowReady = false;

// ==================== path config ====================
const preload = path.join(__dirname, '../preload/index.mjs');
const indexHtml = path.join(RENDERER_DIST, 'index.html');
const logPath = log.transports.file.getFile().path;

// Profile initialization promise
let profileInitPromise: Promise<void>;

// Set remote debugging port
findAvailablePort(browser_port).then(port => {
// Storage strategy:
// 1. Main window: partition 'persist:main_window' in app userData → Eigent account (persistent)
// 2. WebView: partition 'persist:user_login' in app userData → will import cookies from tool_controller via session API
// 3. tool_controller: ~/.eigent/browser_profiles/profile_user_login → source of truth for login cookies
// 4. CDP browser: uses separate profile (doesn't share with main app)
profileInitPromise = findAvailablePort(browser_port).then(async port => {
  browser_port = port;
  app.commandLine.appendSwitch('remote-debugging-port', port + '');

  // Create isolated profile for CDP browser only
  const browserProfilesBase = path.join(os.homedir(), '.eigent', 'browser_profiles');
  const cdpProfile = path.join(browserProfilesBase, `cdp_profile_${port}`);

  try {
    await fsp.mkdir(cdpProfile, { recursive: true });
    log.info(`[CDP BROWSER] Created CDP profile directory at ${cdpProfile}`);
  } catch (error) {
    log.error(`[CDP BROWSER] Failed to create directory: ${error}`);
  }

  // Set user-data-dir for Chrome DevTools Protocol only
  app.commandLine.appendSwitch('user-data-dir', cdpProfile);

  log.info(`[CDP BROWSER] Chrome DevTools Protocol enabled on port ${port}`);
  log.info(`[CDP BROWSER] CDP profile directory: ${cdpProfile}`);
  log.info(`[STORAGE] Main app userData: ${app.getPath('userData')}`);
});

// Memory optimization settings
@@ -97,6 +127,19 @@ const setupProtocolHandlers = () => {
// ==================== protocol url handle ====================
function handleProtocolUrl(url: string) {
  log.info('enter handleProtocolUrl', url);

  // If window is not ready, queue the URL
  if (!isWindowReady || !win || win.isDestroyed()) {
    log.info('Window not ready, queuing protocol URL:', url);
    protocolUrlQueue.push(url);
    return;
  }

  processProtocolUrl(url);
}

// Process a single protocol URL
function processProtocolUrl(url: string) {
  const urlObj = new URL(url);
  const code = urlObj.searchParams.get('code');
  const share_token = urlObj.searchParams.get('share_token');

@@ -130,6 +173,26 @@ function handleProtocolUrl(url: string) {
  }
}

// Process all queued protocol URLs
function processQueuedProtocolUrls() {
  if (protocolUrlQueue.length > 0) {
    log.info('Processing queued protocol URLs:', protocolUrlQueue.length);

    // Verify window is ready before processing
    if (!win || win.isDestroyed() || !isWindowReady) {
      log.warn('Window not ready for processing queued URLs, keeping URLs in queue');
      return;
    }

    const urls = [...protocolUrlQueue];
    protocolUrlQueue = [];

    urls.forEach(url => {
      processProtocolUrl(url);
    });
  }
}

// ==================== single instance lock ====================
const setupSingleInstanceLock = () => {
  const gotLock = app.requestSingleInstanceLock();
@@ -207,11 +270,26 @@ const checkManagerInstance = (manager: any, name: string) => {
function registerIpcHandlers() {
  // ==================== basic info handler ====================
  ipcMain.handle('get-browser-port', () => {
    log.info('Starting new task')
    log.info('Getting browser port')
    return browser_port
  });
  ipcMain.handle('get-app-version', () => app.getVersion());
  ipcMain.handle('get-backend-port', () => backendPort);

  // ==================== restart app handler ====================
  ipcMain.handle('restart-app', async () => {
    log.info('[RESTART] Restarting app to apply user profile changes');

    // Clean up Python process first
    await cleanupPythonProcess();

    // Schedule relaunch after a short delay
    setTimeout(() => {
      app.relaunch();
      app.quit();
    }, 100);
  });

  ipcMain.handle('restart-backend', async () => {
    try {
      if (backendPort) {
@@ -609,6 +687,13 @@ function registerIpcHandlers() {
      return { success: false, error: 'File does not exist' };
    }

    // Check if it's a directory
    const stats = await fsp.stat(filePath);
    if (stats.isDirectory()) {
      log.error('Path is a directory, not a file:', filePath);
      return { success: false, error: 'Path is a directory, not a file' };
    }

    // Read file content
    const fileContent = await fsp.readFile(filePath);
    log.info('File read successfully:', filePath);
@@ -712,6 +797,24 @@ function registerIpcHandlers() {
    let lines = content.split(/\r?\n/);
    lines = updateEnvBlock(lines, { [key]: value });
    fs.writeFileSync(ENV_PATH, lines.join('\n'), 'utf-8');

    // Also write to global .env file for backend process to read
    const GLOBAL_ENV_PATH = path.join(os.homedir(), '.eigent', '.env');
    let globalContent = '';
    try {
      globalContent = fs.existsSync(GLOBAL_ENV_PATH) ? fs.readFileSync(GLOBAL_ENV_PATH, 'utf-8') : '';
    } catch (error) {
      log.error("global env-write read error:", error);
    }
    let globalLines = globalContent.split(/\r?\n/);
    globalLines = updateEnvBlock(globalLines, { [key]: value });
    try {
      fs.writeFileSync(GLOBAL_ENV_PATH, globalLines.join('\n'), 'utf-8');
      log.info(`env-write: wrote ${key} to both user and global .env files`);
    } catch (error) {
      log.error("global env-write error:", error);
    }

    return { success: true };
  });
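The dual write keeps `~/.eigent/.env` in sync with the per-user file so the backend process can read the same keys. `updateEnvBlock` itself is not shown in this diff; a rough Python equivalent of what such an upsert helper typically does (an assumption about its behavior, for reference only):

def update_env_block(lines: list[str], updates: dict[str, str]) -> list[str]:
    """Upsert KEY=VALUE pairs in dotenv-style lines, preserving unrelated lines."""
    remaining = dict(updates)
    out = []
    for line in lines:
        key = line.split("=", 1)[0].strip()
        if "=" in line and not line.lstrip().startswith("#") and key in remaining:
            out.append(f"{key}={remaining.pop(key)}")  # replace existing value
        else:
            out.append(line)  # comments, blanks, untouched keys
    out.extend(f"{k}={v}" for k, v in remaining.items())  # append new keys
    return out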
@@ -728,6 +831,19 @@ function registerIpcHandlers() {
    lines = removeEnvKey(lines, key);
    fs.writeFileSync(ENV_PATH, lines.join('\n'), 'utf-8');
    log.info("env-remove success", ENV_PATH);

    // Also remove from global .env file
    const GLOBAL_ENV_PATH = path.join(os.homedir(), '.eigent', '.env');
    try {
      let globalContent = fs.existsSync(GLOBAL_ENV_PATH) ? fs.readFileSync(GLOBAL_ENV_PATH, 'utf-8') : '';
      let globalLines = globalContent.split(/\r?\n/);
      globalLines = removeEnvKey(globalLines, key);
      fs.writeFileSync(GLOBAL_ENV_PATH, globalLines.join('\n'), 'utf-8');
      log.info(`env-remove: removed ${key} from both user and global .env files`);
    } catch (error) {
      log.error("global env-remove error:", error);
    }

    return { success: true };
  });
@@ -802,14 +918,40 @@ function registerIpcHandlers() {
    }
  });

  ipcMain.handle('get-file-list', async (_, email: string, taskId: string) => {
  ipcMain.handle('get-file-list', async (_, email: string, taskId: string, projectId?: string) => {
    const manager = checkManagerInstance(fileReader, 'FileReader');
    return manager.getFileList(email, taskId);
    return manager.getFileList(email, taskId, projectId);
  });

  ipcMain.handle('delete-task-files', async (_, email: string, taskId: string) => {
  ipcMain.handle('delete-task-files', async (_, email: string, taskId: string, projectId?: string) => {
    const manager = checkManagerInstance(fileReader, 'FileReader');
    return manager.deleteTaskFiles(email, taskId);
    return manager.deleteTaskFiles(email, taskId, projectId);
  });

  // New project management handlers
  ipcMain.handle('create-project-structure', async (_, email: string, projectId: string) => {
    const manager = checkManagerInstance(fileReader, 'FileReader');
    return manager.createProjectStructure(email, projectId);
  });

  ipcMain.handle('get-project-list', async (_, email: string) => {
    const manager = checkManagerInstance(fileReader, 'FileReader');
    return manager.getProjectList(email);
  });

  ipcMain.handle('get-tasks-in-project', async (_, email: string, projectId: string) => {
    const manager = checkManagerInstance(fileReader, 'FileReader');
    return manager.getTasksInProject(email, projectId);
  });

  ipcMain.handle('move-task-to-project', async (_, email: string, taskId: string, projectId: string) => {
    const manager = checkManagerInstance(fileReader, 'FileReader');
    return manager.moveTaskToProject(email, taskId, projectId);
  });

  ipcMain.handle('get-project-file-list', async (_, email: string, projectId: string) => {
    const manager = checkManagerInstance(fileReader, 'FileReader');
    return manager.getProjectFileList(email, projectId);
  });

  ipcMain.handle('get-log-folder', async (_, email: string) => {
@@ -905,6 +1047,10 @@ async function createWindow() {
  // Ensure .eigent directories exist before anything else
  ensureEigentDirectories();

  log.info(`[PROJECT BROWSER WINDOW] Creating BrowserWindow which will start Chrome with CDP on port ${browser_port}`);
  log.info(`[PROJECT BROWSER WINDOW] Current user data path: ${app.getPath('userData')}`);
  log.info(`[PROJECT BROWSER WINDOW] Command line switch user-data-dir: ${app.commandLine.getSwitchValue('user-data-dir')}`);

  win = new BrowserWindow({
    title: 'Eigent',
    width: 1200,
@@ -921,6 +1067,9 @@ async function createWindow() {
    icon: path.join(VITE_PUBLIC, 'favicon.ico'),
    roundedCorners: true,
    webPreferences: {
      // Use a dedicated partition for main window to isolate from webviews
      // This ensures main window's auth data (localStorage) is stored separately and persists across restarts
      partition: 'persist:main_window',
      webSecurity: false,
      preload,
      nodeIntegration: true,
@@ -930,14 +1079,58 @@ async function createWindow() {
    },
  });

  // Main window now uses default userData directly with partition 'persist:main_window'
  // No migration needed - data is already persistent

  // ==================== Import cookies from tool_controller to WebView BEFORE creating WebViews ====================
  // Copy partition data files before any session accesses them
  try {
    const browserProfilesBase = path.join(os.homedir(), '.eigent', 'browser_profiles');
    const toolControllerProfile = path.join(browserProfilesBase, 'profile_user_login');
    const toolControllerPartitionPath = path.join(toolControllerProfile, 'Partitions', 'user_login');

    if (fs.existsSync(toolControllerPartitionPath)) {
      log.info('[COOKIE SYNC] Found tool_controller partition, copying to WebView partition...');

      const targetPartitionPath = path.join(app.getPath('userData'), 'Partitions', 'user_login');
      log.info('[COOKIE SYNC] From:', toolControllerPartitionPath);
      log.info('[COOKIE SYNC] To:', targetPartitionPath);

      // Ensure target directory exists
      if (!fs.existsSync(path.dirname(targetPartitionPath))) {
        fs.mkdirSync(path.dirname(targetPartitionPath), { recursive: true });
      }

      // Copy the entire partition directory
      fs.cpSync(toolControllerPartitionPath, targetPartitionPath, {
        recursive: true,
        force: true
      });
      log.info('[COOKIE SYNC] Successfully copied partition data to WebView');

      // Verify cookies were copied
      const targetCookies = path.join(targetPartitionPath, 'Cookies');
      if (fs.existsSync(targetCookies)) {
        const stats = fs.statSync(targetCookies);
        log.info(`[COOKIE SYNC] Cookies file size: ${stats.size} bytes`);
      }
    } else {
      log.info('[COOKIE SYNC] No tool_controller partition found, WebView will start fresh');
    }
  } catch (error) {
    log.error('[COOKIE SYNC] Failed to sync partition data:', error);
  }

  // ==================== initialize manager ====================
  fileReader = new FileReader(win);
  webViewManager = new WebViewManager(win);

  // create initial webviews (reduced from 8 to 3)
  for (let i = 1; i <= 3; i++) {
  // create multiple webviews
  log.info(`[PROJECT BROWSER] Creating WebViews with partition: persist:user_login`);
  for (let i = 1; i <= 8; i++) {
    webViewManager.createWebview(i === 1 ? undefined : i.toString());
  }
  log.info('[PROJECT BROWSER] WebViewManager initialized with webviews');

  // ==================== set event listeners ====================
  setupWindowEventListeners();
@@ -990,7 +1183,9 @@ async function createWindow() {
    log.info('Installation needed - clearing auth storage to force carousel state');

    // Clear the persisted auth storage file to force fresh initialization with carousel
    const localStoragePath = path.join(app.getPath('userData'), 'Local Storage');
    // Main window uses partition 'persist:main_window', so data is in Partitions/main_window
    const partitionPath = path.join(app.getPath('userData'), 'Partitions', 'main_window');
    const localStoragePath = path.join(partitionPath, 'Local Storage');
    const leveldbPath = path.join(localStoragePath, 'leveldb');

    try {
@@ -1056,8 +1251,10 @@ async function createWindow() {
        (function() {
          try {
            const authStorage = localStorage.getItem('auth-storage');
            console.log('[ELECTRON DEBUG] Current auth-storage:', authStorage);
            if (authStorage) {
              const parsed = JSON.parse(authStorage);
              console.log('[ELECTRON DEBUG] Parsed state:', parsed.state);
              if (parsed.state && parsed.state.initState !== 'done') {
                console.log('[ELECTRON] Updating initState from', parsed.state.initState, 'to done');
                // Only update the initState field, preserve all other data
@@ -1071,7 +1268,11 @@ async function createWindow() {
                localStorage.setItem('auth-storage', JSON.stringify(updatedStorage));
                console.log('[ELECTRON] initState updated to done, reloading page...');
                return true; // Signal that we need to reload
              } else {
                console.log('[ELECTRON DEBUG] initState already done or state missing');
              }
            } else {
              console.log('[ELECTRON DEBUG] No auth-storage found in localStorage');
            }
            return false; // No reload needed
          } catch (e) {
@@ -1107,6 +1308,11 @@ async function createWindow() {
    });
  });

  // Mark window as ready and process any queued protocol URLs
  isWindowReady = true;
  log.info('Window is ready, processing queued protocol URLs...');
  processQueuedProtocolUrls();

  // Now check and install dependencies
  let res: PromiseReturnType = await checkAndInstallDepsOnUpdate({ win });
  if (!res.success) {
@@ -1282,7 +1488,15 @@ const handleBeforeClose = () => {
}

// ==================== app event handle ====================
app.whenReady().then(() => {
app.whenReady().then(async () => {
  // Wait for profile initialization to complete
  log.info('[MAIN] Waiting for profile initialization...');
  try {
    await profileInitPromise;
    log.info('[MAIN] Profile initialization completed');
  } catch (error) {
    log.error('[MAIN] Profile initialization failed:', error);
  }

  // ==================== download handle ====================
  session.defaultSession.on('will-download', (event, item, webContents) => {
@@ -1339,7 +1553,10 @@ app.on('window-all-closed', () => {
    webViewManager = null;
  }

  // Reset window state
  win = null;
  isWindowReady = false;
  protocolUrlQueue = [];

  if (process.platform !== 'darwin') {
    app.quit();
@@ -1363,35 +1580,42 @@ app.on('activate', () => {
app.on('before-quit', async (event) => {
  log.info('before-quit');
  log.info('quit python_process.pid: ' + python_process?.pid);

  // Prevent default quit to ensure cleanup completes
  event.preventDefault();

  try {
    // NOTE: Profile sync removed - we now use app userData directly for all partitions
    // No need to sync between different profile directories

    // Clean up resources
    if (webViewManager) {
      webViewManager.destroy();
      webViewManager = null;
    }

    if (win && !win.isDestroyed()) {
      win.destroy();
      win = null;
    }

    // Wait for Python process cleanup
    await cleanupPythonProcess();

    // Clean up file reader if exists
    if (fileReader) {
      fileReader = null;
    }

    // Clear any remaining timeouts/intervals
    if (global.gc) {
      global.gc();
    }

    // Reset protocol handling state
    isWindowReady = false;
    protocolUrlQueue = [];

    log.info('All cleanup completed, exiting...');
  } catch (error) {
    log.error('Error during cleanup:', error);

@@ -4,6 +4,7 @@ import log from 'electron-log'
import fs from 'fs'
import path from 'path'
import * as net from "net";
import * as http from "http";
import { ipcMain, BrowserWindow, app } from 'electron'
import { promisify } from 'util'
import { detectInstallationLogs, PromiseReturnType } from "./install-deps";
@@ -195,21 +196,77 @@ export async function startBackend(setPort?: (port: number) => void): Promise<an

  let started = false;
  let healthCheckInterval: NodeJS.Timeout | null = null;
  const startTimeout = setTimeout(() => {
    if (!started) {
      if (healthCheckInterval) clearInterval(healthCheckInterval);
      node_process.kill();
      reject(new Error('Backend failed to start within timeout'));
    }
  }, 30000); // 30 second timeout

  // Helper function to poll health endpoint
  const pollHealthEndpoint = (): void => {
    let attempts = 0;
    const maxAttempts = 20; // 5 seconds total (20 * 250ms)
    const intervalMs = 250;

    healthCheckInterval = setInterval(() => {
      attempts++;
      const healthUrl = `http://127.0.0.1:${port}/health`;

      const req = http.get(healthUrl, { timeout: 1000 }, (res) => {
        if (res.statusCode === 200) {
          log.info(`Backend health check passed after ${attempts} attempts`);
          started = true;
          clearTimeout(startTimeout);
          if (healthCheckInterval) clearInterval(healthCheckInterval);
          resolve(node_process);
        } else {
          // Non-200 status (e.g., 404), continue polling unless max attempts reached
          if (attempts >= maxAttempts) {
            log.error(`Backend health check failed after ${attempts} attempts with status ${res.statusCode}`);
            started = true;
            clearTimeout(startTimeout);
            if (healthCheckInterval) clearInterval(healthCheckInterval);
            node_process.kill();
            reject(new Error(`Backend health check failed: HTTP ${res.statusCode}`));
          }
        }
      });

      req.on('error', () => {
        // Connection error - backend might not be ready yet, continue polling
        if (attempts >= maxAttempts) {
          log.error(`Backend health check failed after ${attempts} attempts: unable to connect`);
          started = true;
          clearTimeout(startTimeout);
          if (healthCheckInterval) clearInterval(healthCheckInterval);
          node_process.kill();
          reject(new Error('Backend health check failed: unable to connect'));
        }
      });

      req.on('timeout', () => {
        req.destroy();
        if (attempts >= maxAttempts) {
          log.error(`Backend health check timed out after ${attempts} attempts`);
          started = true;
          clearTimeout(startTimeout);
          if (healthCheckInterval) clearInterval(healthCheckInterval);
          node_process.kill();
          reject(new Error('Backend health check timed out'));
        }
      });
    }, intervalMs);
  };

  node_process.stdout.on('data', (data) => {
    displayFilteredLogs(data);
    // check output content, judge if start success
    if (!started && data.toString().includes("Uvicorn running on")) {
      started = true;
      clearTimeout(startTimeout);
      resolve(node_process);
      log.info('Uvicorn startup detected, starting health check polling...');
      pollHealthEndpoint();
    }
  });
@@ -217,9 +274,8 @@ export async function startBackend(setPort?: (port: number) => void): Promise<an
    displayFilteredLogs(data);

    if (!started && data.toString().includes("Uvicorn running on")) {
      started = true;
      clearTimeout(startTimeout);
      resolve(node_process);
      log.info('Uvicorn startup detected (stderr), starting health check polling...');
      pollHealthEndpoint();
    }

    // Check for port binding errors
@@ -227,6 +283,7 @@ export async function startBackend(setPort?: (port: number) => void): Promise<an
        data.toString().includes("bind() failed")) {
      started = true; // Prevent multiple rejections
      clearTimeout(startTimeout);
      if (healthCheckInterval) clearInterval(healthCheckInterval);
      node_process.kill();
      reject(new Error(`Port ${port} is already in use`));
    }
@@ -234,6 +291,7 @@ export async function startBackend(setPort?: (port: number) => void): Promise<an

  node_process.on('close', (code) => {
    clearTimeout(startTimeout);
    if (healthCheckInterval) clearInterval(healthCheckInterval);
    if (!started) {
      reject(new Error(`fastapi exited with code ${code}`));
    }

@@ -6,6 +6,7 @@ import fs from 'node:fs'
import { getBackendPath, getBinaryPath, getCachePath, getVenvPath, cleanupOldVenvs, isBinaryExists, runInstallScript } from './utils/process'
import { spawn } from 'child_process'
import { safeMainWindowSend } from './utils/safeWebContentsSend'
import os from 'node:os'

const userData = app.getPath('userData');
const versionFile = path.join(userData, 'version.txt');
@@ -57,6 +58,13 @@ Promise<PromiseReturnType> => {

  return new Promise(async (resolve, reject) => {
    try {
      // Clean up cache in production environment BEFORE any checks
      // This ensures users always get fresh dependencies in production
      if (app.isPackaged) {
        log.info('[CACHE CLEANUP] Production environment detected, cleaning cache before dependency check...');
        cleanupCacheInProduction();
      }

      const versionExists: boolean = checkInstallOperations.getSavedVersion();

      // Check if command tools are installed
@@ -230,10 +238,8 @@ class InstallLogs {

  /** Display filtered logs based on severity */
  displayFilteredLogs(data: String) {
    if (!data) return;
    if (!data) return;
    const msg = data.toString().trimEnd();
    // Detect if uv sync is run
    detectInstallationLogs(msg);
    if (msg.toLowerCase().includes("error") || msg.toLowerCase().includes("traceback")) {
      log.error(`BACKEND: [DEPS INSTALL] ${msg}`);
      safeMainWindowSend('install-dependencies-log', { type: 'stderr', data: data.toString() });
@@ -282,6 +288,34 @@ class InstallLogs {
  }
}

/**
 * Clean up cache directory
 * This ensures users get fresh dependencies
 * Note: Only call this in production environment (caller should check app.isPackaged)
 */
function cleanupCacheInProduction(): void {
  try {
    const cacheBaseDir = path.join(os.homedir(), '.eigent', 'cache');

    if (!fs.existsSync(cacheBaseDir)) {
      log.info('[CACHE CLEANUP] Cache directory does not exist, nothing to clean');
      return;
    }

    log.info('[CACHE CLEANUP] Cleaning cache directory:', cacheBaseDir);

    fs.rmSync(cacheBaseDir, { recursive: true, force: true });

    log.info('[CACHE CLEANUP] Cache directory cleaned successfully');

    fs.mkdirSync(cacheBaseDir, { recursive: true });
    log.info('[CACHE CLEANUP] Empty cache directory recreated');

  } catch (error) {
    log.error('[CACHE CLEANUP] Failed to clean cache directory:', error);
  }
}

const runInstall = (extraArgs: string[], version: string) => {
  const installLogs = new InstallLogs(extraArgs, version);
  return new Promise<PromiseReturnType>((resolveInner, rejectInner) => {
@@ -358,6 +392,29 @@ export async function installDependencies(version: string): Promise<PromiseRetur
      return true; // Not an error if the toolkit isn't installed
    }

    // Check if npm dependencies are already installed
    const npmMarkerPath = path.join(toolkitPath, '.npm_dependencies_installed');
    const nodeModulesPath = path.join(toolkitPath, 'node_modules');
    const distPath = path.join(toolkitPath, 'dist');

    // Check if marker exists and verify version
    if (fs.existsSync(npmMarkerPath) && fs.existsSync(nodeModulesPath) && fs.existsSync(distPath)) {
      try {
        const markerContent = JSON.parse(fs.readFileSync(npmMarkerPath, 'utf-8'));
        if (markerContent.version === version) {
          log.info('[DEPS INSTALL] hybrid_browser_toolkit npm dependencies already installed for current version, skipping...');
          return true;
        } else {
          log.info('[DEPS INSTALL] npm dependencies installed for different version, will reinstall...');
          // Clean up old installation
          fs.unlinkSync(npmMarkerPath);
        }
      } catch (error) {
        log.warn('[DEPS INSTALL] Could not read npm marker file, will reinstall...', error);
        // If we can't read the marker, assume we need to reinstall
      }
    }

    log.info('[DEPS INSTALL] Installing hybrid_browser_toolkit npm dependencies...');
    safeMainWindowSend('install-dependencies-log', {
      type: 'stdout',
@@ -515,6 +572,13 @@ export async function installDependencies(version: string): Promise<PromiseRetur
      // Non-critical, continue
    }

    // Create marker file to indicate npm dependencies are installed
    fs.writeFileSync(npmMarkerPath, JSON.stringify({
      installedAt: new Date().toISOString(),
      version: version
    }));
    log.info('[DEPS INSTALL] Created npm dependencies marker file');

    log.info('[DEPS INSTALL] hybrid_browser_toolkit dependencies installed successfully');
    return true;
  } catch (error) {
@@ -542,6 +606,32 @@ export async function installDependencies(version: string): Promise<PromiseRetur
  // Set Installing Lock Files
  InstallLogs.setLockPath();

  // Clean up npm dependencies marker when reinstalling Python deps
  // This ensures npm deps are reinstalled when Python environment changes
  try {
    let sitePackagesPath: string | null = null;
    const libPath = path.join(venvPath, 'lib');

    if (fs.existsSync(libPath)) {
      const libContents = fs.readdirSync(libPath);
      const pythonDir = libContents.find(name => name.startsWith('python'));
      if (pythonDir) {
        sitePackagesPath = path.join(libPath, pythonDir, 'site-packages');
      }
    }

    if (sitePackagesPath) {
      const npmMarkerPath = path.join(sitePackagesPath, 'camel', 'toolkits', 'hybrid_browser_toolkit', 'ts', '.npm_dependencies_installed');
      if (fs.existsSync(npmMarkerPath)) {
        fs.unlinkSync(npmMarkerPath);
        log.info('[DEPS INSTALL] Removed npm dependencies marker for fresh installation');
      }
    }
  } catch (error) {
    log.warn('[DEPS INSTALL] Could not clean npm marker file:', error);
    // Non-critical, continue
  }

  // try default install
  const installSuccess = await runInstall([], version)
  if (installSuccess.success) {
@@ -592,6 +682,24 @@ export async function installDependencies(version: string): Promise<PromiseRetur
let dependencyInstallationDetected = false;
let installationNotificationSent = false;
export function detectInstallationLogs(msg: string) {
  // CRITICAL FIX: Use file system to check if installation is complete
  // Don't rely on module variables as they can be reset during hot reload

  // Check if dependencies are already installed
  const isAlreadyInstalled = fs.existsSync(installedLockPath);

  // If installed lock file exists, dependencies are already installed
  // Skip all detection to avoid false positives
  if (isAlreadyInstalled) {
    // Dependencies are already installed, skip detection entirely
    return;
  }

  // Also skip if notification was already sent (in current session)
  if (installationNotificationSent) {
    return;
  }

  // Check for UV dependency installation patterns
  const installPatterns = [
    "Resolved", // UV resolving dependencies
|
@ -605,18 +713,18 @@ export function detectInstallationLogs(msg:string) {
|
|||
"× No solution found when resolving dependencies", // Dependency resolution issues
|
||||
"Audited" // UV auditing dependencies
|
||||
];
|
||||
|
||||
|
||||
// Detect if UV is installing dependencies
|
||||
if (!dependencyInstallationDetected && installPatterns.some(pattern =>
|
||||
if (!dependencyInstallationDetected && installPatterns.some(pattern =>
|
||||
msg.includes(pattern) && !msg.includes("Uvicorn running on")
|
||||
)) {
|
||||
dependencyInstallationDetected = true;
|
||||
log.info('[BACKEND STARTUP] UV dependency installation detected during uvicorn startup');
|
||||
|
||||
|
||||
// Create installing lock file to maintain consistency with install-deps.ts
|
||||
InstallLogs.setLockPath();
|
||||
log.info('[BACKEND STARTUP] Created uv_installing.lock file');
|
||||
|
||||
|
||||
// Notify frontend that installation has started (only once)
|
||||
if (!installationNotificationSent) {
|
||||
installationNotificationSent = true;
|
||||
|
|
|
|||
|
|
@@ -64,6 +64,9 @@ export class WebViewManager {
    }
    const view = new WebContentsView({
      webPreferences: {
        // Use a separate session partition for webviews to isolate storage from main window
        // This ensures clearing webview storage won't affect main window's auth data
        partition: 'persist:user_login',
        nodeIntegration: false,
        contextIsolation: true,
        backgroundThrottling: true,
@@ -269,10 +272,11 @@ export class WebViewManager {

    if (!webViewInfo.view.webContents.isDestroyed()) {
      webViewInfo.view.webContents.removeAllListeners()
      // DO NOT clear storage data here!
      // Multiple webviews share the same partition 'persist:user_login'
      // Clearing storage would affect ALL webviews and remove login cookies
      // Only clear cache which is per-webContents
      webViewInfo.view.webContents.session.clearCache()
      webViewInfo.view.webContents.session.clearStorageData({
        storages: ['cookies', 'localstorage', 'websql', 'indexdb', 'serviceworkers', 'cachestorage']
      })
    }

    // remove webview from parent container

@@ -88,6 +88,7 @@ contextBridge.exposeInMainWorld('electronAPI', {
    ipcRenderer.removeAllListeners(channel);
  },
  getEmailFolderPath: (email: string) => ipcRenderer.invoke('get-email-folder-path', email),
  restartApp: () => ipcRenderer.invoke('restart-app'),
});

@@ -14,7 +14,8 @@
  "type": "module",
  "scripts": {
    "compile-babel": "cd backend && uv run pybabel compile -d lang",
    "dev": "npm run compile-babel && vite",
    "clean-cache": "rimraf node_modules/.vite",
    "dev": "npm run clean-cache && npm run compile-babel && vite",
    "build": "npm run compile-babel && tsc && vite build && electron-builder -- --publish always",
    "build:mac": "npm run compile-babel && tsc && vite build && electron-builder --mac",
    "build:win": "npm run compile-babel && tsc && vite build && electron-builder --win",
@@ -30,6 +31,7 @@
  "dependencies": {
    "@electron/notarize": "^2.5.0",
    "@fontsource/inter": "^5.2.5",
    "@gsap/react": "^2.1.2",
    "@microsoft/fetch-event-source": "^2.0.1",
    "@monaco-editor/loader": "^1.5.0",
    "@monaco-editor/react": "^4.7.0",
@@ -71,8 +73,10 @@
    "lucide-react": "^0.509.0",
    "mammoth": "^1.9.1",
    "monaco-editor": "^0.52.2",
    "motion": "^12.23.24",
    "next-themes": "^0.4.6",
    "papaparse": "^5.5.3",
    "postprocessing": "^6.37.8",
    "react-markdown": "^10.1.0",
    "react-resizable-panels": "^3.0.4",
    "react-router-dom": "^7.6.0",
@@ -81,6 +85,7 @@
    "tailwind-merge": "^3.3.0",
    "tailwindcss-animate": "^1.0.7",
    "tar": "^7.4.3",
    "three": "^0.180.0",
    "tree-kill": "^1.2.2",
    "tw-animate-css": "^1.2.9",
    "unzipper": "^0.12.3",
@@ -112,6 +117,7 @@
    "react": "^18.3.1",
    "react-dom": "^18.3.1",
    "react-i18next": "^15.7.3",
    "rimraf": "^6.0.1",
    "tailwindcss": "^3.4.15",
    "typescript": "^5.4.2",
    "vite": "^5.4.11",

@@ -5,3 +5,5 @@ database_url=postgresql://postgres:postgres@localhost:5432/postgres
# Chat Share Secret Key
CHAT_SHARE_SECRET_KEY=put-your-secret-key-here
CHAT_SHARE_SALT=put-your-encode-salt-here

@@ -1,5 +1,5 @@
# Use a Python image with uv pre-installed
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim
FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim

# Install the project into `/app`
WORKDIR /app
@@ -15,9 +15,13 @@ ENV UV_PYTHON_INSTALL_MIRROR=https://registry.npmmirror.com/-/binary/python-buil
ARG database_url
ENV database_url=$database_url

RUN apt-get update && apt-get install -y \
    gcc \
    python3-dev \
    && rm -rf /var/lib/apt/lists/*

# Copy dependency files first
COPY pyproject.toml uv.lock ./
COPY server/pyproject.toml server/uv.lock ./

# Install the project's dependencies
RUN --mount=type=cache,target=/root/.cache/uv \
@@ -25,7 +29,11 @@ RUN --mount=type=cache,target=/root/.cache/uv \

# Then, add the rest of the project source code and install it
# Installing separately from its dependencies allows optimal layer caching
COPY . /app
COPY server/ /app

# Copy the utils directory from the parent project
COPY utils /app/utils

RUN --mount=type=cache,target=/root/.cache/uv \
    uv sync --no-dev

@@ -41,7 +49,7 @@ RUN apt-get update && apt-get install -y curl netcat-openbsd && rm -rf /var/lib/
ENV PATH="/app/.venv/bin:$PATH"

# Copy and make the start script executable
COPY start.sh /app/start.sh
COPY server/start.sh /app/start.sh
RUN sed -i 's/\r$//' /app/start.sh && chmod +x /app/start.sh

# Reset the entrypoint, don't invoke `uv`

@@ -1,4 +1,11 @@
from logging.config import fileConfig
import sys
import pathlib

# Add project root to Python path to import shared utils
_project_root = pathlib.Path(__file__).parent.parent.parent
if str(_project_root) not in sys.path:
    sys.path.insert(0, str(_project_root))

from sqlalchemy import engine_from_config, pool
from alembic import context

@@ -0,0 +1,36 @@
"""modify_chat_history_add_project_id

Revision ID: eec7242b3a9b
Revises: d74ab2a44600
Create Date: 2025-10-15 14:46:47.904254

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes


# revision identifiers, used by Alembic.
revision: str = "eec7242b3a9b"
down_revision: Union[str, None] = "d74ab2a44600"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("chat_history", sa.Column("project_id", sqlmodel.sql.sqltypes.AutoString(), nullable=True))
    op.create_index(op.f("ix_chat_history_project_id"), "chat_history", ["project_id"], unique=False)
    # ### end Alembic commands ###


def downgrade() -> None:
    """Downgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index(op.f("ix_chat_history_project_id"), table_name="chat_history")
    op.drop_column("chat_history", "project_id")
    # ### end Alembic commands ###
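This migration adds a nullable, indexed project_id column to chat_history. It would normally be applied with the Alembic CLI; below is a minimal sketch of running the same upgrade programmatically, assuming an alembic.ini in the working directory (the path is an assumption, not part of this commit):

# Hypothetical helper, not part of this repo: applies pending migrations to "head".
from alembic import command
from alembic.config import Config

def upgrade_to_head(ini_path: str = "alembic.ini") -> None:
    cfg = Config(ini_path)  # reads script_location and the database URL
    command.upgrade(cfg, "head")  # equivalent to `alembic upgrade head`

if __name__ == "__main__":
    upgrade_to_head()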
@@ -1,5 +1,6 @@
from fastapi import FastAPI
from fastapi_pagination import add_pagination


api = FastAPI(swagger_ui_parameters={"persistAuthorization": True})
add_pagination(api)

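add_pagination wires fastapi-pagination's page/size query parameters into the app, after which endpoints can return Page[...] models. A minimal, self-contained sketch of that pattern; the Item model and route are illustrative, not part of this codebase:

from typing import List

from fastapi import FastAPI
from fastapi_pagination import Page, add_pagination, paginate
from pydantic import BaseModel

class Item(BaseModel):
    name: str

app = FastAPI()
add_pagination(app)  # registers the page/size query parameters

ITEMS: List[Item] = [Item(name=f"item-{i}") for i in range(100)]

@app.get("/items")
def list_items() -> Page[Item]:
    # GET /items?page=2&size=10 returns items 10..19 plus total/page metadata
    return paginate(ITEMS)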
@@ -3,50 +3,110 @@ from fastapi_pagination import Page
from fastapi_pagination.ext.sqlmodel import paginate
from app.model.chat.chat_history import ChatHistoryOut, ChatHistoryIn, ChatHistory, ChatHistoryUpdate
from fastapi_babel import _
from sqlmodel import Session, select, desc
from sqlmodel import Session, select, desc, case
from app.component.auth import Auth, auth_must
from app.component.database import session
from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("server_chat_history")

router = APIRouter(prefix="/chat", tags=["Chat History"])


@router.post("/history", name="save chat history", response_model=ChatHistoryOut)
@traceroot.trace()
def create_chat_history(data: ChatHistoryIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    data.user_id = auth.user.id
    chat_history = ChatHistory(**data.model_dump())
    session.add(chat_history)
    session.commit()
    session.refresh(chat_history)
    return chat_history
    """Save new chat history."""
    user_id = auth.user.id

    try:
        data.user_id = user_id
        chat_history = ChatHistory(**data.model_dump())
        session.add(chat_history)
        session.commit()
        session.refresh(chat_history)
        logger.info("Chat history created", extra={"user_id": user_id, "history_id": chat_history.id, "task_id": data.task_id})
        return chat_history
    except Exception as e:
        session.rollback()
        logger.error("Chat history creation failed", extra={"user_id": user_id, "task_id": data.task_id, "error": str(e)}, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")


@router.get("/histories", name="get chat history")
@traceroot.trace()
def list_chat_history(session: Session = Depends(session), auth: Auth = Depends(auth_must)) -> Page[ChatHistoryOut]:
    stmt = select(ChatHistory).where(ChatHistory.user_id == auth.user.id).order_by(desc(ChatHistory.created_at))
    return paginate(session, stmt)
    """List chat histories for current user."""
    user_id = auth.user.id

    # Order by created_at descending, but fallback to id descending for old records without timestamps
    # This ensures newer records with timestamps come first, followed by old records ordered by id
    stmt = (
        select(ChatHistory)
        .where(ChatHistory.user_id == user_id)
        .order_by(
            desc(case((ChatHistory.created_at.is_(None), 0), else_=1)),  # Non-null created_at first
            desc(ChatHistory.created_at),  # Then by created_at descending
            desc(ChatHistory.id)  # Finally by id descending for records with same/null created_at
        )
    )

    result = paginate(session, stmt)
    total = result.total if hasattr(result, 'total') else 0
    logger.debug("Chat histories listed", extra={"user_id": user_id, "total": total})
    return result

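The CASE expression above sorts timestamped rows ahead of legacy rows whose created_at is NULL, then falls back to id. A standalone sketch of the same ordering, using an illustrative model rather than the real ChatHistory:

from datetime import datetime
from typing import Optional

from sqlmodel import Field, SQLModel, case, desc, select

class Row(SQLModel, table=True):  # illustrative stand-in for ChatHistory
    id: Optional[int] = Field(default=None, primary_key=True)
    created_at: Optional[datetime] = None

# Renders as ORDER BY CASE WHEN created_at IS NULL THEN 0 ELSE 1 END DESC,
# created_at DESC, id DESC, so NULL timestamps always sort last.
stmt = select(Row).order_by(
    desc(case((Row.created_at.is_(None), 0), else_=1)),
    desc(Row.created_at),
    desc(Row.id),
)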

@router.delete("/history/{history_id}", name="delete chat history")
def delete_chat_history(history_id: str, session: Session = Depends(session)):
@traceroot.trace()
def delete_chat_history(history_id: str, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    """Delete chat history."""
    user_id = auth.user.id
    history = session.exec(select(ChatHistory).where(ChatHistory.id == history_id)).first()

    if not history:
        raise HTTPException(status_code=404, detail="Caht History not found")
    session.delete(history)
    session.commit()
    return Response(status_code=204)
        logger.warning("Chat history not found for deletion", extra={"user_id": user_id, "history_id": history_id})
        raise HTTPException(status_code=404, detail="Chat History not found")

    if history.user_id != user_id:
        logger.warning("Unauthorized deletion attempt", extra={"user_id": user_id, "history_id": history_id, "owner_id": history.user_id})
        raise HTTPException(status_code=403, detail="You are not allowed to delete this chat history")

    try:
        session.delete(history)
        session.commit()
        logger.info("Chat history deleted", extra={"user_id": user_id, "history_id": history_id})
        return Response(status_code=204)
    except Exception as e:
        session.rollback()
        logger.error("Chat history deletion failed", extra={"user_id": user_id, "history_id": history_id, "error": str(e)}, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")


@router.put("/history/{history_id}", name="update chat history", response_model=ChatHistoryOut)
@traceroot.trace()
def update_chat_history(
    history_id: int, data: ChatHistoryUpdate, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
    """Update chat history."""
    user_id = auth.user.id
    history = session.exec(select(ChatHistory).where(ChatHistory.id == history_id)).first()

    if not history:
        logger.warning("Chat history not found for update", extra={"user_id": user_id, "history_id": history_id})
        raise HTTPException(status_code=404, detail="Chat History not found")
    if history.user_id != auth.user.id:

    if history.user_id != user_id:
        logger.warning("Unauthorized update attempt", extra={"user_id": user_id, "history_id": history_id, "owner_id": history.user_id})
        raise HTTPException(status_code=403, detail="You are not allowed to update this chat history")
    update_data = data.model_dump(exclude_unset=True)
    history.update_fields(update_data)
    history.save(session)
    session.refresh(history)
    return history

    try:
        update_data = data.model_dump(exclude_unset=True)
        history.update_fields(update_data)
        history.save(session)
        session.refresh(history)
        logger.info("Chat history updated", extra={"user_id": user_id, "history_id": history_id, "fields_updated": list(update_data.keys())})
        return history
    except Exception as e:
        logger.error("Chat history update failed", extra={"user_id": user_id, "history_id": history_id, "error": str(e)}, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")

@@ -1,78 +1,107 @@
from fastapi import APIRouter, Depends, HTTPException, Response
from sqlmodel import Session, asc, select
from app.component.database import session
import json
import asyncio
from itsdangerous import SignatureExpired, BadTimeSignature
from starlette.responses import StreamingResponse
from app.model.chat.chat_share import ChatHistoryShareOut, ChatShare, ChatShareIn
from app.model.chat.chat_step import ChatStep
from app.model.chat.chat_history import ChatHistory

router = APIRouter(prefix="/chat", tags=["Chat Share"])


@router.get("/share/info/{token}", name="Get shared chat info", response_model=ChatHistoryShareOut)
def get_share_info(token: str, session: Session = Depends(session)):
    """
    Get shared chat history info by token, excluding sensitive data.
    """
    try:
        task_id = ChatShare.verify_token(token, False)
    except (SignatureExpired, BadTimeSignature):
        raise HTTPException(status_code=400, detail="Share link is invalid or has expired.")

    stmt = select(ChatHistory).where(ChatHistory.task_id == task_id)
    history = session.exec(stmt).one_or_none()

    if not history:
        raise HTTPException(status_code=404, detail="Chat history not found.")

    return history


@router.get("/share/playback/{token}", name="Playback shared chat via SSE")
async def share_playback(token: str, session: Session = Depends(session), delay_time: float = 0):
    """
    Playbacks the chat history via a sharing token (SSE).
    delay_time: control sse interval, max 5 seconds
    """
    if delay_time > 5:
        delay_time = 5
    try:
        task_id = ChatShare.verify_token(token, False)
    except SignatureExpired:
        raise HTTPException(status_code=400, detail="Share link has expired.")
    except BadTimeSignature:
        raise HTTPException(status_code=400, detail="Share link is invalid.")

    async def event_generator():
        stmt = select(ChatStep).where(ChatStep.task_id == task_id).order_by(asc(ChatStep.id))
        steps = session.exec(stmt).all()

        if not steps:
            yield f"data: {json.dumps({'error': 'No steps found for this task.'})}\n\n"
            return

        for step in steps:
            step_data = {
                "id": step.id,
                "task_id": step.task_id,
                "step": step.step,
                "data": step.data,
                "created_at": step.created_at.isoformat() if step.created_at else None,
            }
            yield f"data: {json.dumps(step_data)}\n\n"
            if delay_time > 0 and step.step != "create_agent":
                await asyncio.sleep(delay_time)

    return StreamingResponse(event_generator(), media_type="text/event-stream")


@router.post("/share", name="Generate sharable link for a task(1 day expiration)")
def create_share_link(data: ChatShareIn):
    """
    Generates a sharing token with an expiration time for the specified task_id.
    """
    share_token = ChatShare.generate_token(data.task_id)
    return {"share_token": share_token}

from fastapi import APIRouter, Depends, HTTPException, Response
from sqlmodel import Session, asc, select
from app.component.database import session
import json
import asyncio
from itsdangerous import SignatureExpired, BadTimeSignature
from starlette.responses import StreamingResponse
from app.model.chat.chat_share import ChatHistoryShareOut, ChatShare, ChatShareIn
from app.model.chat.chat_step import ChatStep
from app.model.chat.chat_history import ChatHistory
from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("server_chat_share")

router = APIRouter(prefix="/chat", tags=["Chat Share"])


@router.get("/share/info/{token}", name="Get shared chat info", response_model=ChatHistoryShareOut)
@traceroot.trace()
def get_share_info(token: str, session: Session = Depends(session)):
    """
    Get shared chat history info by token, excluding sensitive data.
    """
    try:
        task_id = ChatShare.verify_token(token, False)
    except SignatureExpired:
        logger.warning("Shared chat access failed: token expired", extra={"token_prefix": token[:10]})
        raise HTTPException(status_code=400, detail="Share link is invalid or has expired.")
    except BadTimeSignature:
        logger.warning("Shared chat access failed: invalid token", extra={"token_prefix": token[:10]})
        raise HTTPException(status_code=400, detail="Share link is invalid or has expired.")

    stmt = select(ChatHistory).where(ChatHistory.task_id == task_id)
    history = session.exec(stmt).one_or_none()

    if not history:
        logger.warning("Shared chat not found", extra={"task_id": task_id})
        raise HTTPException(status_code=404, detail="Chat history not found.")

    logger.info("Shared chat info accessed", extra={"task_id": task_id})
    return history


@router.get("/share/playback/{token}", name="Playback shared chat via SSE")
@traceroot.trace()
async def share_playback(token: str, session: Session = Depends(session), delay_time: float = 0):
    """
    Playbacks the chat history via a sharing token (SSE).
    delay_time: control sse interval, max 5 seconds
    """
    if delay_time > 5:
        logger.debug("Delay time capped", extra={"requested": delay_time, "capped": 5})
        delay_time = 5

    try:
        task_id = ChatShare.verify_token(token, False)
    except SignatureExpired:
        logger.warning("Shared chat playback failed: token expired", extra={"token_prefix": token[:10]})
        raise HTTPException(status_code=400, detail="Share link has expired.")
    except BadTimeSignature:
        logger.warning("Shared chat playback failed: invalid token", extra={"token_prefix": token[:10]})
        raise HTTPException(status_code=400, detail="Share link is invalid.")

    async def event_generator():
        try:
            stmt = select(ChatStep).where(ChatStep.task_id == task_id).order_by(asc(ChatStep.id))
            steps = session.exec(stmt).all()

            if not steps:
                logger.warning("No steps found for playback", extra={"task_id": task_id})
                yield f"data: {json.dumps({'error': 'No steps found for this task.'})}\n\n"
                return

            logger.info("Shared chat playback started", extra={"task_id": task_id, "step_count": len(steps), "delay_time": delay_time})

            for idx, step in enumerate(steps, start=1):
                step_data = {
                    "id": step.id,
                    "task_id": step.task_id,
                    "step": step.step,
                    "data": step.data,
                    "created_at": step.created_at.isoformat() if step.created_at else None,
                }
                yield f"data: {json.dumps(step_data)}\n\n"

                if delay_time > 0 and step.step != "create_agent":
                    await asyncio.sleep(delay_time)

            logger.info("Shared chat playback completed", extra={"task_id": task_id, "step_count": len(steps)})
        except Exception as e:
            logger.error("Shared chat playback error", extra={"task_id": task_id, "error": str(e)}, exc_info=True)
            yield f"data: {json.dumps({'error': 'Playback error occurred.'})}\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")


@router.post("/share", name="Generate sharable link for a task(1 day expiration)")
@traceroot.trace()
def create_share_link(data: ChatShareIn):
    """Generate sharing token with 1-day expiration for task."""
    try:
        share_token = ChatShare.generate_token(data.task_id)
        logger.info("Share link created", extra={"task_id": data.task_id, "token_prefix": share_token[:10]})
        return {"share_token": share_token}
    except Exception as e:
        logger.error("Share link creation failed", extra={"task_id": data.task_id, "error": str(e)}, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")

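ChatShare.verify_token raises itsdangerous's SignatureExpired and BadTimeSignature, so it presumably wraps a timed serializer. A sketch of that token scheme, assuming URLSafeTimedSerializer with an illustrative secret and salt (the real app reads these from configuration, and ChatShare's actual internals are not shown in this commit):

from itsdangerous import BadTimeSignature, SignatureExpired, URLSafeTimedSerializer

# Illustrative secret and salt only; never hard-code these in real code.
serializer = URLSafeTimedSerializer("example-secret", salt="example-salt")

def generate_token(task_id: str) -> str:
    # Embeds a timestamp alongside the signed payload
    return serializer.dumps(task_id)

def verify_token(token: str, max_age_seconds: int = 86400) -> str:
    # Raises SignatureExpired past max_age, BadTimeSignature on tampering
    return serializer.loads(token, max_age=max_age_seconds)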
@ -1,81 +1,138 @@
|
|||
from app.model.chat.chat_snpshot import ChatSnapshot, ChatSnapshotIn
|
||||
from typing import List, Optional
|
||||
from fastapi import Depends, HTTPException, Response, APIRouter
|
||||
from sqlmodel import Session, select
|
||||
from app.component.database import session
|
||||
from app.component.auth import Auth, auth_must
|
||||
from fastapi_babel import _
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["Chat Snapshot Management"])
|
||||
|
||||
|
||||
@router.get("/snapshots", name="list chat snapshots", response_model=List[ChatSnapshot])
|
||||
async def list_chat_snapshots(
|
||||
api_task_id: Optional[str] = None,
|
||||
camel_task_id: Optional[str] = None,
|
||||
browser_url: Optional[str] = None,
|
||||
session: Session = Depends(session),
|
||||
):
|
||||
query = select(ChatSnapshot)
|
||||
if api_task_id is not None:
|
||||
query = query.where(ChatSnapshot.api_task_id == api_task_id)
|
||||
if camel_task_id is not None:
|
||||
query = query.where(ChatSnapshot.camel_task_id == camel_task_id)
|
||||
if browser_url is not None:
|
||||
query = query.where(ChatSnapshot.browser_url == browser_url)
|
||||
snapshots = session.exec(query).all()
|
||||
return snapshots
|
||||
|
||||
|
||||
@router.get("/snapshots/{snapshot_id}", name="get chat snapshot", response_model=ChatSnapshot)
|
||||
async def get_chat_snapshot(snapshot_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
|
||||
snapshot = session.get(ChatSnapshot, snapshot_id)
|
||||
if not snapshot:
|
||||
raise HTTPException(status_code=404, detail=_("Chat snapshot not found"))
|
||||
return snapshot
|
||||
|
||||
|
||||
@router.post("/snapshots", name="create chat snapshot", response_model=ChatSnapshot)
|
||||
async def create_chat_snapshot(
|
||||
snapshot: ChatSnapshotIn, auth: Auth = Depends(auth_must), session: Session = Depends(session)
|
||||
):
|
||||
image_path = ChatSnapshotIn.save_image(auth.user.id, snapshot.api_task_id, snapshot.image_base64)
|
||||
chat_snapshot = ChatSnapshot(
|
||||
user_id=auth.user.id,
|
||||
api_task_id=snapshot.api_task_id,
|
||||
camel_task_id=snapshot.camel_task_id,
|
||||
browser_url=snapshot.browser_url,
|
||||
image_path=image_path,
|
||||
)
|
||||
session.add(chat_snapshot)
|
||||
session.commit()
|
||||
session.refresh(chat_snapshot)
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@router.put("/snapshots/{snapshot_id}", name="update chat snapshot", response_model=ChatSnapshot)
|
||||
async def update_chat_snapshot(
|
||||
snapshot_id: int,
|
||||
snapshot_update: ChatSnapshot,
|
||||
session: Session = Depends(session),
|
||||
auth: Auth = Depends(auth_must),
|
||||
):
|
||||
db_snapshot = session.get(ChatSnapshot, snapshot_id)
|
||||
if not db_snapshot:
|
||||
raise HTTPException(status_code=404, detail=_("Chat snapshot not found"))
|
||||
for key, value in snapshot_update.dict(exclude_unset=True).items():
|
||||
setattr(db_snapshot, key, value)
|
||||
session.add(db_snapshot)
|
||||
session.commit()
|
||||
session.refresh(db_snapshot)
|
||||
return db_snapshot
|
||||
|
||||
|
||||
@router.delete("/snapshots/{snapshot_id}", name="delete chat snapshot")
|
||||
async def delete_chat_snapshot(snapshot_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
|
||||
db_snapshot = session.get(ChatSnapshot, snapshot_id)
|
||||
if not db_snapshot:
|
||||
raise HTTPException(status_code=404, detail=_("Chat snapshot not found"))
|
||||
session.delete(db_snapshot)
|
||||
session.commit()
|
||||
return Response(status_code=204)
|
||||
from app.model.chat.chat_snpshot import ChatSnapshot, ChatSnapshotIn
|
||||
from typing import List, Optional
|
||||
from fastapi import Depends, HTTPException, Response, APIRouter
|
||||
from sqlmodel import Session, select
|
||||
from app.component.database import session
|
||||
from app.component.auth import Auth, auth_must
|
||||
from fastapi_babel import _
|
||||
from utils import traceroot_wrapper as traceroot
|
||||
|
||||
logger = traceroot.get_logger("server_chat_snapshot")
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["Chat Snapshot Management"])
|
||||
|
||||
|
||||
@router.get("/snapshots", name="list chat snapshots", response_model=List[ChatSnapshot])
|
||||
@traceroot.trace()
|
||||
async def list_chat_snapshots(
|
||||
api_task_id: Optional[str] = None,
|
||||
camel_task_id: Optional[str] = None,
|
||||
browser_url: Optional[str] = None,
|
||||
session: Session = Depends(session),
|
||||
):
|
||||
"""List chat snapshots with optional filtering."""
|
||||
query = select(ChatSnapshot)
|
||||
if api_task_id is not None:
|
||||
query = query.where(ChatSnapshot.api_task_id == api_task_id)
|
||||
if camel_task_id is not None:
|
||||
query = query.where(ChatSnapshot.camel_task_id == camel_task_id)
|
||||
if browser_url is not None:
|
||||
query = query.where(ChatSnapshot.browser_url == browser_url)
|
||||
|
||||
snapshots = session.exec(query).all()
|
||||
logger.debug("Snapshots listed", extra={"api_task_id": api_task_id, "camel_task_id": camel_task_id, "count": len(snapshots)})
|
||||
return snapshots
|
||||
|
||||
|
||||
@router.get("/snapshots/{snapshot_id}", name="get chat snapshot", response_model=ChatSnapshot)
|
||||
@traceroot.trace()
|
||||
async def get_chat_snapshot(snapshot_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
|
||||
"""Get specific chat snapshot."""
|
||||
user_id = auth.user.id
|
||||
snapshot = session.get(ChatSnapshot, snapshot_id)
|
||||
|
||||
if not snapshot:
|
||||
logger.warning("Snapshot not found", extra={"user_id": user_id, "snapshot_id": snapshot_id})
|
||||
raise HTTPException(status_code=404, detail=_("Chat snapshot not found"))
|
||||
|
||||
logger.debug("Snapshot retrieved", extra={"user_id": user_id, "snapshot_id": snapshot_id, "api_task_id": snapshot.api_task_id})
|
||||
return snapshot
|
||||
|
||||
|
||||
@router.post("/snapshots", name="create chat snapshot", response_model=ChatSnapshot)
|
||||
@traceroot.trace()
|
||||
async def create_chat_snapshot(
|
||||
snapshot: ChatSnapshotIn, auth: Auth = Depends(auth_must), session: Session = Depends(session)
|
||||
):
|
||||
"""Create new chat snapshot from image."""
|
||||
user_id = auth.user.id
|
||||
|
||||
try:
|
||||
image_path = ChatSnapshotIn.save_image(user_id, snapshot.api_task_id, snapshot.image_base64)
|
||||
chat_snapshot = ChatSnapshot(
|
||||
user_id=user_id,
|
||||
api_task_id=snapshot.api_task_id,
|
||||
            camel_task_id=snapshot.camel_task_id,
            browser_url=snapshot.browser_url,
            image_path=image_path,
        )
        session.add(chat_snapshot)
        session.commit()
        session.refresh(chat_snapshot)
        logger.info("Snapshot created", extra={"user_id": user_id, "snapshot_id": chat_snapshot.id, "api_task_id": snapshot.api_task_id, "image_path": image_path})
        return chat_snapshot
    except Exception as e:
        session.rollback()
        logger.error("Snapshot creation failed", extra={"user_id": user_id, "api_task_id": snapshot.api_task_id, "error": str(e)}, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")


@router.put("/snapshots/{snapshot_id}", name="update chat snapshot", response_model=ChatSnapshot)
@traceroot.trace()
async def update_chat_snapshot(
    snapshot_id: int,
    snapshot_update: ChatSnapshot,
    session: Session = Depends(session),
    auth: Auth = Depends(auth_must),
):
    """Update chat snapshot."""
    user_id = auth.user.id
    db_snapshot = session.get(ChatSnapshot, snapshot_id)

    if not db_snapshot:
        logger.warning("Snapshot not found for update", extra={"user_id": user_id, "snapshot_id": snapshot_id})
        raise HTTPException(status_code=404, detail=_("Chat snapshot not found"))

    if db_snapshot.user_id != user_id:
        logger.warning("Unauthorized snapshot update", extra={"user_id": user_id, "snapshot_id": snapshot_id, "owner_id": db_snapshot.user_id})
        raise HTTPException(status_code=403, detail=_("You are not allowed to update this snapshot"))

    try:
        update_data = snapshot_update.dict(exclude_unset=True)
        for key, value in update_data.items():
            setattr(db_snapshot, key, value)
        session.add(db_snapshot)
        session.commit()
        session.refresh(db_snapshot)
        logger.info("Snapshot updated", extra={"user_id": user_id, "snapshot_id": snapshot_id, "fields_updated": list(update_data.keys())})
        return db_snapshot
    except Exception as e:
        session.rollback()
        logger.error("Snapshot update failed", extra={"user_id": user_id, "snapshot_id": snapshot_id, "error": str(e)}, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")


@router.delete("/snapshots/{snapshot_id}", name="delete chat snapshot")
@traceroot.trace()
async def delete_chat_snapshot(snapshot_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    """Delete chat snapshot."""
    user_id = auth.user.id
    db_snapshot = session.get(ChatSnapshot, snapshot_id)

    if not db_snapshot:
        logger.warning("Snapshot not found for deletion", extra={"user_id": user_id, "snapshot_id": snapshot_id})
        raise HTTPException(status_code=404, detail=_("Chat snapshot not found"))

    if db_snapshot.user_id != user_id:
        logger.warning("Unauthorized snapshot deletion", extra={"user_id": user_id, "snapshot_id": snapshot_id, "owner_id": db_snapshot.user_id})
        raise HTTPException(status_code=403, detail=_("You are not allowed to delete this snapshot"))

    try:
        session.delete(db_snapshot)
        session.commit()
        logger.info("Snapshot deleted", extra={"user_id": user_id, "snapshot_id": snapshot_id, "image_path": db_snapshot.image_path})
        return Response(status_code=204)
    except Exception as e:
        session.rollback()
        logger.error("Snapshot deletion failed", extra={"user_id": user_id, "snapshot_id": snapshot_id, "error": str(e)}, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")
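For reviewers, a minimal sketch of how the update/delete snapshot handlers above behave, using FastAPI's TestClient. The `app.main` import path, the router prefix, and the auth header are assumptions for illustration, not taken from this diff:

# Hypothetical smoke test for the snapshot endpoints above.
# `app.main` and the bearer token are placeholders, not from this diff.
from fastapi.testclient import TestClient

from app.main import app  # assumed application entry point

client = TestClient(app)
auth_headers = {"Authorization": "Bearer <token>"}  # placeholder credentials

# Updating a snapshot owned by another user returns 403 per the handler above.
resp = client.put("/snapshots/123", json={"browser_url": "https://example.com"}, headers=auth_headers)
assert resp.status_code in (200, 403, 404)

# Deleting an existing, owned snapshot returns 204 with an empty body.
resp = client.delete("/snapshots/123", headers=auth_headers)
assert resp.status_code in (204, 403, 404)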
@@ -1,105 +1,163 @@
import asyncio
import json
from typing import List, Optional
from fastapi import Depends, HTTPException, Query, Response, APIRouter
from fastapi.responses import StreamingResponse
from sqlmodel import Session, asc, select
from app.component.database import session
from app.component.auth import Auth, auth_must
from fastapi_babel import _
from app.model.chat.chat_step import ChatStep, ChatStepOut, ChatStepIn

router = APIRouter(prefix="/chat", tags=["Chat Step Management"])


@router.get("/steps", name="list chat steps", response_model=List[ChatStepOut])
async def list_chat_steps(
    task_id: str, step: Optional[str] = None, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
    query = select(ChatStep)
    if task_id is not None:
        query = query.where(ChatStep.task_id == task_id)
    if step is not None:
        query = query.where(ChatStep.step == step)
    chat_steps = session.exec(query).all()
    return chat_steps


@router.get("/steps/playback/{task_id}", name="Playback Chat Step via SSE")
async def share_playback(
    task_id: str, delay_time: float = 0, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
    """
    Playbacks the chat steps (SSE).
    """
    if delay_time > 5:
        delay_time = 5

    async def event_generator():
        stmt = select(ChatStep).where(ChatStep.task_id == task_id).order_by(asc(ChatStep.id))
        steps = session.exec(stmt).all()

        if not steps:
            yield f"data: {json.dumps({'error': 'No steps found for this task.'})}\n\n"
            return

        for step in steps:
            step_data = {
                "id": step.id,
                "task_id": step.task_id,
                "step": step.step,
                "data": step.data,
                "created_at": step.created_at.isoformat() if step.created_at else None,
            }
            yield f"data: {json.dumps(step_data)}\n\n"
            if delay_time > 0:
                await asyncio.sleep(delay_time)

    return StreamingResponse(event_generator(), media_type="text/event-stream")


@router.get("/steps/{step_id}", name="get chat step", response_model=ChatStepOut)
async def get_chat_step(step_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    chat_step = session.get(ChatStep, step_id)
    if not chat_step:
        raise HTTPException(status_code=404, detail=_("Chat step not found"))
    return chat_step


@router.post("/steps", name="create chat step")
# TODO Limit request sources
async def create_chat_step(step: ChatStepIn, session: Session = Depends(session)):
    chat_step = ChatStep(
        task_id=step.task_id,
        step=step.step,
        data=step.data,
    )
    session.add(chat_step)
    session.commit()
    session.refresh(chat_step)
    return {"code": 200, "msg": "success"}


@router.put("/steps/{step_id}", name="update chat step", response_model=ChatStepOut)
async def update_chat_step(
    step_id: int, chat_step_update: ChatStep, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
    db_chat_step = session.get(ChatStep, step_id)
    if not db_chat_step:
        raise HTTPException(status_code=404, detail=_("Chat step not found"))
    for key, value in chat_step_update.dict(exclude_unset=True).items():
        setattr(db_chat_step, key, value)
    session.add(db_chat_step)
    session.commit()
    session.refresh(db_chat_step)
    return db_chat_step


@router.delete("/steps/{step_id}", name="delete chat step")
async def delete_chat_step(step_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    db_chat_step = session.get(ChatStep, step_id)
    if not db_chat_step:
        raise HTTPException(status_code=404, detail=_("Chat step not found"))
    session.delete(db_chat_step)
    session.commit()
    return Response(status_code=204)

import asyncio
import json
from typing import List, Optional
from fastapi import Depends, HTTPException, Query, Response, APIRouter
from fastapi.responses import StreamingResponse
from sqlmodel import Session, asc, select
from app.component.database import session
from app.component.auth import Auth, auth_must
from fastapi_babel import _
from app.model.chat.chat_step import ChatStep, ChatStepOut, ChatStepIn
from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("server_chat_step")

router = APIRouter(prefix="/chat", tags=["Chat Step Management"])


@router.get("/steps", name="list chat steps", response_model=List[ChatStepOut])
@traceroot.trace()
async def list_chat_steps(
    task_id: str, step: Optional[str] = None, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
    """List chat steps for a task with optional step type filtering."""
    user_id = auth.user.id
    query = select(ChatStep)
    if task_id is not None:
        query = query.where(ChatStep.task_id == task_id)
    if step is not None:
        query = query.where(ChatStep.step == step)

    chat_steps = session.exec(query).all()
    logger.debug("Chat steps listed", extra={"user_id": user_id, "task_id": task_id, "step_type": step, "count": len(chat_steps)})
    return chat_steps


@router.get("/steps/playback/{task_id}", name="Playback Chat Step via SSE")
@traceroot.trace()
async def share_playback(
    task_id: str, delay_time: float = 0, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
    """Playback chat steps via SSE stream."""
    user_id = auth.user.id
    if delay_time > 5:
        logger.debug("Delay time capped", extra={"user_id": user_id, "task_id": task_id, "requested": delay_time, "capped": 5})
        delay_time = 5

    async def event_generator():
        try:
            stmt = select(ChatStep).where(ChatStep.task_id == task_id).order_by(asc(ChatStep.id))
            steps = session.exec(stmt).all()

            if not steps:
                logger.warning("No steps found for playback", extra={"user_id": user_id, "task_id": task_id})
                yield f"data: {json.dumps({'error': 'No steps found for this task.'})}\n\n"
                return

            logger.info("Chat step playback started", extra={"user_id": user_id, "task_id": task_id, "step_count": len(steps), "delay_time": delay_time})

            for step in steps:
                step_data = {
                    "id": step.id,
                    "task_id": step.task_id,
                    "step": step.step,
                    "data": step.data,
                    "created_at": step.created_at.isoformat() if step.created_at else None,
                }
                yield f"data: {json.dumps(step_data)}\n\n"
                if delay_time > 0:
                    await asyncio.sleep(delay_time)

            logger.info("Chat step playback completed", extra={"user_id": user_id, "task_id": task_id, "step_count": len(steps)})
        except Exception as e:
            logger.error("Chat step playback error", extra={"user_id": user_id, "task_id": task_id, "error": str(e)}, exc_info=True)
            yield f"data: {json.dumps({'error': 'Playback error occurred.'})}\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")


@router.get("/steps/{step_id}", name="get chat step", response_model=ChatStepOut)
@traceroot.trace()
async def get_chat_step(step_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    """Get specific chat step."""
    user_id = auth.user.id
    chat_step = session.get(ChatStep, step_id)

    if not chat_step:
        logger.warning("Chat step not found", extra={"user_id": user_id, "step_id": step_id})
        raise HTTPException(status_code=404, detail=_("Chat step not found"))

    logger.debug("Chat step retrieved", extra={"user_id": user_id, "step_id": step_id, "task_id": chat_step.task_id})
    return chat_step


@router.post("/steps", name="create chat step")
@traceroot.trace()
async def create_chat_step(step: ChatStepIn, session: Session = Depends(session)):
    """Create new chat step. TODO: Implement request source validation."""
    try:
        chat_step = ChatStep(
            task_id=step.task_id,
            step=step.step,
            data=step.data,
        )
        session.add(chat_step)
        session.commit()
        session.refresh(chat_step)
        logger.info("Chat step created", extra={"step_id": chat_step.id, "task_id": step.task_id, "step_type": step.step})
        return {"code": 200, "msg": "success"}
    except Exception as e:
        session.rollback()
        logger.error("Chat step creation failed", extra={"task_id": step.task_id, "step_type": step.step, "error": str(e)}, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")


@router.put("/steps/{step_id}", name="update chat step", response_model=ChatStepOut)
@traceroot.trace()
async def update_chat_step(
    step_id: int, chat_step_update: ChatStep, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
    """Update chat step."""
    user_id = auth.user.id
    db_chat_step = session.get(ChatStep, step_id)

    if not db_chat_step:
        logger.warning("Chat step not found for update", extra={"user_id": user_id, "step_id": step_id})
        raise HTTPException(status_code=404, detail=_("Chat step not found"))

    try:
        update_data = chat_step_update.dict(exclude_unset=True)
        for key, value in update_data.items():
            setattr(db_chat_step, key, value)
        session.add(db_chat_step)
        session.commit()
        session.refresh(db_chat_step)
        logger.info("Chat step updated", extra={"user_id": user_id, "step_id": step_id, "task_id": db_chat_step.task_id, "fields_updated": list(update_data.keys())})
        return db_chat_step
    except Exception as e:
        session.rollback()
        logger.error("Chat step update failed", extra={"user_id": user_id, "step_id": step_id, "error": str(e)}, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")


@router.delete("/steps/{step_id}", name="delete chat step")
@traceroot.trace()
async def delete_chat_step(step_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    """Delete chat step."""
    user_id = auth.user.id
    db_chat_step = session.get(ChatStep, step_id)

    if not db_chat_step:
        logger.warning("Chat step not found for deletion", extra={"user_id": user_id, "step_id": step_id})
        raise HTTPException(status_code=404, detail=_("Chat step not found"))

    try:
        session.delete(db_chat_step)
        session.commit()
        logger.info("Chat step deleted", extra={"user_id": user_id, "step_id": step_id, "task_id": db_chat_step.task_id})
        return Response(status_code=204)
    except Exception as e:
        session.rollback()
        logger.error("Chat step deletion failed", extra={"user_id": user_id, "step_id": step_id, "error": str(e)}, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")
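The playback route above streams one `data: {...}` event per stored step, so a client can replay a task in near real time. A minimal consumer sketch with `httpx` (the base URL and bearer token are assumptions, not part of this diff):

# Hypothetical SSE consumer for GET /chat/steps/playback/{task_id}.
# Host and token are placeholders.
import json

import httpx

with httpx.stream(
    "GET",
    "http://localhost:8000/chat/steps/playback/task-123",
    params={"delay_time": 1},  # the handler caps this at 5 seconds
    headers={"Authorization": "Bearer <token>"},
    timeout=None,
) as resp:
    for line in resp.iter_lines():
        if line.startswith("data: "):
            event = json.loads(line[len("data: "):])
            print(event.get("step"), event.get("created_at"))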
@@ -1,121 +1,172 @@
from typing import List, Optional
from fastapi import Depends, HTTPException, Query, Response, APIRouter
from sqlmodel import Session, select, or_
from app.component.database import session
from app.component.auth import Auth, auth_must
from fastapi_babel import _
from app.model.config.config import Config, ConfigCreate, ConfigUpdate, ConfigInfo, ConfigOut

router = APIRouter(tags=["Config Management"])


@router.get("/configs", name="list configs", response_model=list[ConfigOut])
async def list_configs(
    config_group: Optional[str] = None, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
    query = select(Config)
    user_id = auth.user.id
    if user_id is not None:
        query = query.where(Config.user_id == user_id)
    if config_group is not None:
        query = query.where(Config.config_group == config_group)
    configs = session.exec(query).all()
    return configs


@router.get("/configs/{config_id}", name="get config", response_model=ConfigOut)
async def get_config(
    config_id: int,
    session: Session = Depends(session),
    auth: Auth = Depends(auth_must),
):
    query = select(Config).where(Config.user_id == auth.user.id)

    if config_id is not None:
        query = query.where(Config.id == config_id)

    config = session.exec(query).first()

    if not config:
        raise HTTPException(status_code=404, detail=_("Configuration not found"))
    return config


@router.post("/configs", name="create config", response_model=ConfigOut)
async def create_config(config: ConfigCreate, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    if not ConfigInfo.is_valid_env_var(config.config_group, config.config_name):
        raise HTTPException(status_code=400, detail=_("Config Name is valid"))

    # Check if configuration already exists
    existing_config = session.exec(
        select(Config).where(Config.user_id == auth.user.id, Config.config_name == config.config_name)
    ).first()

    if existing_config:
        raise HTTPException(status_code=400, detail=_("Configuration already exists for this user"))

    db_config = Config(
        user_id=auth.user.id,
        config_name=config.config_name,
        config_value=config.config_value,
        config_group=config.config_group,
    )
    session.add(db_config)
    session.commit()
    session.refresh(db_config)
    return db_config


@router.put("/configs/{config_id}", name="update config", response_model=ConfigOut)
async def update_config(
    config_id: int, config_update: ConfigUpdate, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
    db_config = session.exec(select(Config).where(Config.id == config_id, Config.user_id == auth.user.id)).first()

    if not db_config:
        raise HTTPException(status_code=404, detail=_("Configuration not found"))

    # Check if configuration group is valid
    if not ConfigInfo.is_valid_env_var(config_update.config_group, config_update.config_name):
        raise HTTPException(status_code=400, detail=_("Invalid configuration group"))

    # Check for conflicts with other configurations
    existing_config = session.exec(
        select(Config).where(
            Config.user_id == auth.user.id,
            Config.config_name == config_update.config_name,
            Config.id != config_id,
        )
    ).first()

    if existing_config:
        raise HTTPException(status_code=400, detail=_("Configuration already exists for this user"))

    db_config.config_name = config_update.config_name
    db_config.config_value = config_update.config_value

    session.add(db_config)
    session.commit()
    session.refresh(db_config)
    return db_config


@router.delete("/configs/{config_id}", name="delete config")
async def delete_config(config_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    db_config = session.exec(select(Config).where(Config.id == config_id, Config.user_id == auth.user.id)).first()

    if not db_config:
        raise HTTPException(status_code=404, detail=_("Configuration not found"))
    session.delete(db_config)
    session.commit()
    return Response(status_code=204)


@router.get("/config/info", name="get config info")
async def get_config_info(
    show_all: bool = Query(False, description="Show all config info, including those with empty env_vars"),
):
    configs = ConfigInfo.getinfo()
    if show_all:
        return configs
    return {k: v for k, v in configs.items() if v.get("env_vars") and len(v["env_vars"]) > 0}

from typing import List, Optional
from fastapi import Depends, HTTPException, Query, Response, APIRouter
from sqlmodel import Session, select, or_
from app.component.database import session
from app.component.auth import Auth, auth_must
from fastapi_babel import _
from app.model.config.config import Config, ConfigCreate, ConfigUpdate, ConfigInfo, ConfigOut
from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("server_config_controller")

router = APIRouter(tags=["Config Management"])


@router.get("/configs", name="list configs", response_model=list[ConfigOut])
@traceroot.trace()
async def list_configs(
    config_group: Optional[str] = None, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
    """List user's configurations with optional group filtering."""
    user_id = auth.user.id
    query = select(Config).where(Config.user_id == user_id)

    if config_group is not None:
        query = query.where(Config.config_group == config_group)

    configs = session.exec(query).all()
    logger.debug("Configs listed", extra={"user_id": user_id, "config_group": config_group, "count": len(configs)})
    return configs


@router.get("/configs/{config_id}", name="get config", response_model=ConfigOut)
@traceroot.trace()
async def get_config(
    config_id: int,
    session: Session = Depends(session),
    auth: Auth = Depends(auth_must),
):
    query = select(Config).where(Config.user_id == auth.user.id)

    if config_id is not None:
        query = query.where(Config.id == config_id)

    config = session.exec(query).first()

    if not config:
        logger.warning("Config not found")
        raise HTTPException(status_code=404, detail=_("Configuration not found"))

    logger.debug("Config retrieved")
    return config


@router.post("/configs", name="create config", response_model=ConfigOut)
@traceroot.trace()
async def create_config(config: ConfigCreate, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    """Create new configuration."""
    user_id = auth.user.id

    if not ConfigInfo.is_valid_env_var(config.config_group, config.config_name):
        logger.warning("Config validation failed", extra={"user_id": user_id, "config_group": config.config_group, "config_name": config.config_name})
        raise HTTPException(status_code=400, detail=_("Invalid config name or group"))

    # Check if configuration already exists
    existing_config = session.exec(
        select(Config).where(Config.user_id == user_id, Config.config_name == config.config_name)
    ).first()

    if existing_config:
        logger.warning("Config creation failed: already exists", extra={"user_id": user_id, "config_name": config.config_name})
        raise HTTPException(status_code=400, detail=_("Configuration already exists for this user"))

    try:
        db_config = Config(
            user_id=user_id,
            config_name=config.config_name,
            config_value=config.config_value,
            config_group=config.config_group,
        )
        session.add(db_config)
        session.commit()
        session.refresh(db_config)
        logger.info("Config created", extra={"user_id": user_id, "config_id": db_config.id, "config_group": config.config_group, "config_name": config.config_name})
        return db_config
    except Exception as e:
        session.rollback()
        logger.error("Config creation failed", extra={"user_id": user_id, "config_name": config.config_name, "error": str(e)}, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")


@router.put("/configs/{config_id}", name="update config", response_model=ConfigOut)
@traceroot.trace()
async def update_config(
    config_id: int, config_update: ConfigUpdate, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
    """Update configuration."""
    user_id = auth.user.id
    db_config = session.exec(select(Config).where(Config.id == config_id, Config.user_id == user_id)).first()

    if not db_config:
        logger.warning("Config not found for update", extra={"user_id": user_id, "config_id": config_id})
        raise HTTPException(status_code=404, detail=_("Configuration not found"))

    # Check if configuration group is valid
    if not ConfigInfo.is_valid_env_var(config_update.config_group, config_update.config_name):
        logger.warning("Config update validation failed", extra={"user_id": user_id, "config_id": config_id, "config_group": config_update.config_group})
        raise HTTPException(status_code=400, detail=_("Invalid configuration group"))

    # Check for conflicts with other configurations
    existing_config = session.exec(
        select(Config).where(
            Config.user_id == user_id,
            Config.config_name == config_update.config_name,
            Config.id != config_id,
        )
    ).first()

    if existing_config:
        logger.warning("Config update failed: duplicate name", extra={"user_id": user_id, "config_id": config_id, "config_name": config_update.config_name})
        raise HTTPException(status_code=400, detail=_("Configuration already exists for this user"))

    try:
        db_config.config_name = config_update.config_name
        db_config.config_value = config_update.config_value
        db_config.config_group = config_update.config_group
        session.add(db_config)
        session.commit()
        session.refresh(db_config)
        logger.info("Config updated", extra={"user_id": user_id, "config_id": config_id, "config_group": config_update.config_group})
        return db_config
    except Exception as e:
        session.rollback()
        logger.error("Config update failed", extra={"user_id": user_id, "config_id": config_id, "error": str(e)}, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")


@router.delete("/configs/{config_id}", name="delete config")
@traceroot.trace()
async def delete_config(config_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    """Delete configuration."""
    user_id = auth.user.id
    db_config = session.exec(select(Config).where(Config.id == config_id, Config.user_id == user_id)).first()

    if not db_config:
        logger.warning("Config not found for deletion", extra={"user_id": user_id, "config_id": config_id})
        raise HTTPException(status_code=404, detail=_("Configuration not found"))

    try:
        session.delete(db_config)
        session.commit()
        logger.info("Config deleted", extra={"user_id": user_id, "config_id": config_id, "config_name": db_config.config_name})
        return Response(status_code=204)
    except Exception as e:
        session.rollback()
        logger.error("Config deletion failed", extra={"user_id": user_id, "config_id": config_id, "error": str(e)}, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")


@router.get("/config/info", name="get config info")
@traceroot.trace()
async def get_config_info(
    show_all: bool = Query(False, description="Show all config info, including those with empty env_vars"),
):
    """Get available configuration templates and info."""
    configs = ConfigInfo.getinfo()
    if show_all:
        logger.debug("Config info retrieved", extra={"show_all": True, "count": len(configs)})
        return configs

    filtered = {k: v for k, v in configs.items() if v.get("env_vars") and len(v["env_vars"]) > 0}
    logger.debug("Config info retrieved", extra={"show_all": False, "total_count": len(configs), "filtered_count": len(filtered)})
    return filtered
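The config endpoints above validate names against `ConfigInfo.is_valid_env_var` and reject per-user duplicates, so a create call can fail with 400 before anything is written. A hedged sketch of the request shape (host, token, and the concrete group/name values are illustrative; valid pairs come from GET /config/info):

# Hypothetical create-config request against the endpoints above.
import httpx

headers = {"Authorization": "Bearer <token>"}  # placeholder credentials
payload = {
    "config_group": "OpenAI",         # assumed group name, see GET /config/info
    "config_name": "OPENAI_API_KEY",  # assumed env var name
    "config_value": "sk-...",
}

resp = httpx.post("http://localhost:8000/configs", json=payload, headers=headers)
# 400 if the name/group pair is invalid or already exists for this user.
print(resp.status_code, resp.json())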
15
server/app/controller/health_controller.py
Normal file
@@ -0,0 +1,15 @@
from fastapi import APIRouter
from pydantic import BaseModel

router = APIRouter(tags=["Health"])


class HealthResponse(BaseModel):
    status: str
    service: str


@router.get("/health", name="health check", response_model=HealthResponse)
async def health_check():
    """Health check endpoint for monitoring and container orchestration."""
    return HealthResponse(status="ok", service="eigent-server")
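Since /health takes no auth and returns a fixed JSON body, it slots directly into probes and smoke tests. A minimal check with FastAPI's TestClient (`app.main` is an assumed import path for the server app, not part of this diff):

# Minimal check of the new /health endpoint.
from fastapi.testclient import TestClient

from app.main import app  # assumed entry point

client = TestClient(app)
resp = client.get("/health")
assert resp.status_code == 200
assert resp.json() == {"status": "ok", "service": "eigent-server"}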
@@ -1,214 +1,262 @@
import os
from typing import Dict
from fastapi import Depends, HTTPException, APIRouter
from fastapi_babel import _
from fastapi_pagination import Page
from fastapi_pagination.ext.sqlmodel import paginate
from sqlmodel import Session, col, select
from sqlalchemy.orm import selectinload, with_loader_criteria
from app.component.auth import Auth, auth_must
from app.component.database import session
from app.model.mcp.mcp import Mcp, McpOut, McpType
from app.model.mcp.mcp_env import McpEnv, Status as McpEnvStatus
from app.model.mcp.mcp_user import McpImportType, McpUser, Status
from loguru import logger
from camel.toolkits.mcp_toolkit import MCPToolkit
from app.component.environment import env

from app.component.validator.McpServer import (
    McpRemoteServer,
    McpServerItem,
    validate_mcp_remote_servers,
    validate_mcp_servers,
)

router = APIRouter(tags=["Mcp Servers"])


async def pre_instantiate_mcp_toolkit(config_dict: dict) -> bool:
    """
    Pre-instantiate MCP toolkit to complete authentication process

    Args:
        config_dict: MCP server configuration dictionary

    Returns:
        bool: Whether successfully instantiated and connected
    """
    try:
        # Ensure unified auth directory for all mcp servers
        for server_config in config_dict.get("mcpServers", {}).values():
            if "env" not in server_config:
                server_config["env"] = {}
            # Set global auth directory to persist authentication across tasks
            if "MCP_REMOTE_CONFIG_DIR" not in server_config["env"]:
                server_config["env"]["MCP_REMOTE_CONFIG_DIR"] = env(
                    "MCP_REMOTE_CONFIG_DIR",
                    os.path.expanduser("~/.mcp-auth")
                )

        # Create MCP toolkit and attempt to connect
        mcp_toolkit = MCPToolkit(config_dict=config_dict, timeout=30)
        await mcp_toolkit.connect()

        # Get tools list to ensure connection is successful
        tools = mcp_toolkit.get_tools()
        logger.info(f"Successfully pre-instantiated MCP toolkit with {len(tools)} tools")

        # Disconnect, authentication info is already saved
        await mcp_toolkit.disconnect()
        return True

    except Exception as e:
        logger.warning(f"Failed to pre-instantiate MCP toolkit: {e!r}")
        return False


@router.get("/mcps", name="mcp list")
async def gets(
    keyword: str | None = None,
    category_id: int | None = None,
    mine: int | None = None,
    session: Session = Depends(session),
    auth: Auth = Depends(auth_must),
) -> Page[McpOut]:
    stmt = (
        select(Mcp)
        .where(Mcp.no_delete())
        .options(
            selectinload(Mcp.category),
            selectinload(Mcp.envs),
            with_loader_criteria(McpEnv, col(McpEnv.status) == McpEnvStatus.in_use),
        )
        # .order_by(col(Mcp.sort).desc())
    )
    if keyword:
        stmt = stmt.where(col(Mcp.key).like(f"%{keyword.lower()}%"))
    if category_id:
        stmt = stmt.where(Mcp.category_id == category_id)
    if mine and auth:
        stmt = (
            stmt.join(McpUser)
            .where(McpUser.user_id == auth.user.id)
            .options(
                selectinload(Mcp.mcp_user),
                with_loader_criteria(McpUser, col(McpUser.user_id) == auth.user.id),
            )
        )
    return paginate(session, stmt)


@router.get("/mcp", name="mcp detail", response_model=McpOut)
async def get(id: int, session: Session = Depends(session)):
    stmt = select(Mcp).where(Mcp.no_delete(), Mcp.id == id).options(selectinload(Mcp.category), selectinload(Mcp.envs))
    model = session.exec(stmt).one()
    return model


@router.post("/mcp/install", name="mcp install")
async def install(mcp_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    mcp = session.get_one(Mcp, mcp_id)
    if not mcp:
        raise HTTPException(status_code=404, detail=_("Mcp not found"))
    exists = session.exec(select(McpUser).where(McpUser.mcp_id == mcp.id, McpUser.user_id == auth.user.id)).first()
    if exists:
        raise HTTPException(status_code=400, detail=_("mcp is installed"))

    install_command: dict = mcp.install_command

    # Pre-instantiate MCP toolkit for authentication
    config_dict = {
        "mcpServers": {
            mcp.key: install_command
        }
    }

    try:
        success = await pre_instantiate_mcp_toolkit(config_dict)
        if not success:
            logger.warning(f"Pre-instantiation failed for MCP {mcp.key}, but continuing with installation")
    except Exception as e:
        logger.warning(f"Exception during pre-instantiation for MCP {mcp.key}: {e}")

    mcp_user = McpUser(
        mcp_id=mcp.id,
        user_id=auth.user.id,
        mcp_name=mcp.name,
        mcp_key=mcp.key,
        mcp_desc=mcp.description,
        type=mcp.type,
        status=Status.enable,
        command=install_command["command"],
        args=install_command["args"],
        env=install_command["env"],
        server_url=None,
    )
    mcp_user.save()
    return mcp_user


@router.post("/mcp/import/{mcp_type}", name="mcp import")
async def import_mcp(
    mcp_type: McpImportType, mcp_data: dict, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
    logger.debug(mcp_type, mcp_type.value)

    if mcp_type == McpImportType.Local:
        is_valid, res = validate_mcp_servers(mcp_data)
        if not is_valid:
            raise HTTPException(status_code=400, detail=res)
        mcp_data: Dict[str, McpServerItem] = res.mcpServers

        for name, data in mcp_data.items():
            # Pre-instantiate MCP toolkit for authentication
            config_dict = {
                "mcpServers": {
                    name: {
                        "command": data.command,
                        "args": data.args,
                        "env": data.env or {}
                    }
                }
            }

            try:
                success = await pre_instantiate_mcp_toolkit(config_dict)
                if not success:
                    logger.warning(f"Pre-instantiation failed for local MCP {name}, but continuing with installation")
            except Exception as e:
                logger.warning(f"Exception during pre-instantiation for local MCP {name}: {e}")

            mcp_user = McpUser(
                mcp_id=0,
                user_id=auth.user.id,
                mcp_name=name,
                mcp_key=name,
                mcp_desc=name,
                type=McpType.Local,
                status=Status.enable,
                command=data.command,
                args=data.args,
                env=data.env,
                server_url=None,
            )
            mcp_user.save()
        return {"message": "Local MCP servers imported successfully", "count": len(mcp_data)}
    elif mcp_type == McpImportType.Remote:
        is_valid, res = validate_mcp_remote_servers(mcp_data)
        if not is_valid:
            raise HTTPException(status_code=400, detail=res)
        data: McpRemoteServer = res

        # For remote servers, we don't need to pre-instantiate as they typically don't require authentication
        # but we can still try to validate the connection if needed

        mcp_user = McpUser(
            mcp_id=0,
            user_id=auth.user.id,
            type=McpType.Remote,
            status=Status.enable,
            mcp_name=data.server_name,
            server_url=data.server_url,
        )
        mcp_user.save()
        return mcp_user

import os
from typing import Dict
from fastapi import Depends, HTTPException, APIRouter
from fastapi_babel import _
from fastapi_pagination import Page
from fastapi_pagination.ext.sqlmodel import paginate
from sqlmodel import Session, col, select
from sqlalchemy.orm import selectinload, with_loader_criteria
from app.component.auth import Auth, auth_must
from app.component.database import session
from app.model.mcp.mcp import Mcp, McpOut, McpType
from app.model.mcp.mcp_env import McpEnv, Status as McpEnvStatus
from app.model.mcp.mcp_user import McpImportType, McpUser, Status
from camel.toolkits.mcp_toolkit import MCPToolkit
from app.component.environment import env
from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("server_mcp_controller")

from app.component.validator.McpServer import (
    McpRemoteServer,
    McpServerItem,
    validate_mcp_remote_servers,
    validate_mcp_servers,
)

router = APIRouter(tags=["Mcp Servers"])


async def pre_instantiate_mcp_toolkit(config_dict: dict) -> bool:
    """
    Pre-instantiate MCP toolkit to complete authentication process

    Args:
        config_dict: MCP server configuration dictionary

    Returns:
        bool: Whether successfully instantiated and connected
    """
    try:
        # Ensure unified auth directory for all mcp servers
        for server_config in config_dict.get("mcpServers", {}).values():
            if "env" not in server_config:
                server_config["env"] = {}
            # Set global auth directory to persist authentication across tasks
            if "MCP_REMOTE_CONFIG_DIR" not in server_config["env"]:
                server_config["env"]["MCP_REMOTE_CONFIG_DIR"] = env(
                    "MCP_REMOTE_CONFIG_DIR",
                    os.path.expanduser("~/.mcp-auth")
                )

        # Create MCP toolkit and attempt to connect
        mcp_toolkit = MCPToolkit(config_dict=config_dict, timeout=30)
        await mcp_toolkit.connect()

        # Get tools list to ensure connection is successful
        tools = mcp_toolkit.get_tools()
        logger.info("MCP toolkit pre-instantiated", extra={"tools_count": len(tools)})

        # Disconnect, authentication info is already saved
        await mcp_toolkit.disconnect()
        return True

    except Exception as e:
        logger.warning("MCP toolkit pre-instantiation failed", extra={"error": str(e)}, exc_info=True)
        return False


@router.get("/mcps", name="mcp list")
@traceroot.trace()
async def gets(
    keyword: str | None = None,
    category_id: int | None = None,
    mine: int | None = None,
    session: Session = Depends(session),
    auth: Auth = Depends(auth_must),
) -> Page[McpOut]:
    """List MCP servers with optional filtering."""
    user_id = auth.user.id
    stmt = (
        select(Mcp)
        .where(Mcp.no_delete())
        .options(
            selectinload(Mcp.category),
            selectinload(Mcp.envs),
            with_loader_criteria(McpEnv, col(McpEnv.status) == McpEnvStatus.in_use),
        )
    )
    if keyword:
        stmt = stmt.where(col(Mcp.key).like(f"%{keyword.lower()}%"))
    if category_id:
        stmt = stmt.where(Mcp.category_id == category_id)
    if mine and auth:
        stmt = (
            stmt.join(McpUser)
            .where(McpUser.user_id == user_id)
            .options(
                selectinload(Mcp.mcp_user),
                with_loader_criteria(McpUser, col(McpUser.user_id) == user_id),
            )
        )

    result = paginate(session, stmt)
    total = result.total if hasattr(result, 'total') else 0
    logger.debug("MCP list retrieved", extra={"user_id": user_id, "keyword": keyword, "category_id": category_id, "mine": mine, "total": total})
    return result


@router.get("/mcp", name="mcp detail", response_model=McpOut)
@traceroot.trace()
async def get(id: int, session: Session = Depends(session)):
    """Get MCP server details."""
    try:
        stmt = select(Mcp).where(Mcp.no_delete(), Mcp.id == id).options(selectinload(Mcp.category), selectinload(Mcp.envs))
        model = session.exec(stmt).one()
        logger.debug("MCP detail retrieved", extra={"mcp_id": id, "mcp_key": model.key})
        return model
    except Exception as e:
        logger.warning("MCP not found", extra={"mcp_id": id})
        raise HTTPException(status_code=404, detail=_("Mcp not found"))


@router.post("/mcp/install", name="mcp install")
@traceroot.trace()
async def install(mcp_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    """Install MCP server for user."""
    user_id = auth.user.id

    mcp = session.get_one(Mcp, mcp_id)
    if not mcp:
        logger.warning("MCP install failed: MCP not found", extra={"user_id": user_id, "mcp_id": mcp_id})
        raise HTTPException(status_code=404, detail=_("Mcp not found"))

    exists = session.exec(select(McpUser).where(McpUser.mcp_id == mcp.id, McpUser.user_id == user_id)).first()
    if exists:
        logger.warning("MCP install failed: already installed", extra={"user_id": user_id, "mcp_id": mcp_id, "mcp_key": mcp.key})
        raise HTTPException(status_code=400, detail=_("mcp is installed"))

    install_command: dict = mcp.install_command

    # Pre-instantiate MCP toolkit for authentication
    config_dict = {
        "mcpServers": {
            mcp.key: install_command
        }
    }

    try:
        success = await pre_instantiate_mcp_toolkit(config_dict)
        if not success:
            logger.warning("MCP pre-instantiation failed, continuing with installation", extra={"user_id": user_id, "mcp_id": mcp_id, "mcp_key": mcp.key})
        else:
            logger.debug("MCP toolkit pre-instantiated", extra={"mcp_key": mcp.key})
    except Exception as e:
        logger.warning("MCP pre-instantiation exception", extra={"user_id": user_id, "mcp_key": mcp.key, "error": str(e)}, exc_info=True)

    try:
        mcp_user = McpUser(
            mcp_id=mcp.id,
            user_id=user_id,
            mcp_name=mcp.name,
            mcp_key=mcp.key,
            mcp_desc=mcp.description,
            type=mcp.type,
            status=Status.enable,
            command=install_command["command"],
            args=install_command["args"],
            env=install_command["env"],
            server_url=None,
        )
        mcp_user.save()
        logger.info("MCP installed", extra={"user_id": user_id, "mcp_id": mcp_id, "mcp_key": mcp.key})
        return mcp_user
    except Exception as e:
        logger.error("MCP installation failed", extra={"user_id": user_id, "mcp_id": mcp_id, "mcp_key": mcp.key, "error": str(e)}, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")


@router.post("/mcp/import/{mcp_type}", name="mcp import")
@traceroot.trace()
async def import_mcp(
    mcp_type: McpImportType, mcp_data: dict, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
    """Import MCP servers (local or remote)."""
    user_id = auth.user.id

    if mcp_type == McpImportType.Local:
        logger.info("Importing local MCP servers", extra={"user_id": user_id})
        is_valid, res = validate_mcp_servers(mcp_data)
        if not is_valid:
            logger.warning("Local MCP import validation failed", extra={"user_id": user_id, "error": res})
            raise HTTPException(status_code=400, detail=res)

        mcp_data: Dict[str, McpServerItem] = res.mcpServers
        imported_count = 0

        for name, data in mcp_data.items():
            config_dict = {
                "mcpServers": {
                    name: {
                        "command": data.command,
                        "args": data.args,
                        "env": data.env or {}
                    }
                }
            }

            try:
                success = await pre_instantiate_mcp_toolkit(config_dict)
                if not success:
                    logger.warning("Local MCP pre-instantiation failed, continuing", extra={"user_id": user_id, "mcp_name": name})
            except Exception as e:
                logger.warning("Local MCP pre-instantiation exception", extra={"user_id": user_id, "mcp_name": name, "error": str(e)})

            try:
                mcp_user = McpUser(
                    mcp_id=0,
                    user_id=user_id,
                    mcp_name=name,
                    mcp_key=name,
                    mcp_desc=name,
                    type=McpType.Local,
                    status=Status.enable,
                    command=data.command,
                    args=data.args,
                    env=data.env,
                    server_url=None,
                )
                mcp_user.save()
                imported_count += 1
            except Exception as e:
                logger.error("Failed to import local MCP", extra={"user_id": user_id, "mcp_name": name, "error": str(e)}, exc_info=True)

        logger.info("Local MCPs imported", extra={"user_id": user_id, "count": imported_count})
        return {"message": "Local MCP servers imported successfully", "count": imported_count}

    elif mcp_type == McpImportType.Remote:
        logger.info("Importing remote MCP server", extra={"user_id": user_id})
        is_valid, res = validate_mcp_remote_servers(mcp_data)
        if not is_valid:
            logger.warning("Remote MCP import validation failed", extra={"user_id": user_id, "error": res})
            raise HTTPException(status_code=400, detail=res)

        data: McpRemoteServer = res

        try:
            # For remote servers, we don't need to pre-instantiate as they typically don't require authentication
            # but we can still try to validate the connection if needed
            mcp_user = McpUser(
                mcp_id=0,
                user_id=user_id,
                type=McpType.Remote,
                status=Status.enable,
                mcp_name=data.server_name,
                server_url=data.server_url,
            )
            mcp_user.save()
            logger.info("Remote MCP imported", extra={"user_id": user_id, "server_name": data.server_name, "server_url": data.server_url})
            return mcp_user
        except Exception as e:
            logger.error("Remote MCP import failed", extra={"user_id": user_id, "server_name": data.server_name, "error": str(e)}, exc_info=True)
            raise HTTPException(status_code=500, detail="Internal server error")
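The `config_dict` passed to `pre_instantiate_mcp_toolkit` follows the standard `mcpServers` layout, and the helper injects `MCP_REMOTE_CONFIG_DIR` (default `~/.mcp-auth`) so OAuth state persists across tasks. A standalone invocation sketch (the "notion" server spec and the module path are illustrative assumptions, not taken from this diff):

# Standalone call to pre_instantiate_mcp_toolkit with an example server spec.
import asyncio

from app.controller.mcp_controller import pre_instantiate_mcp_toolkit  # assumed module path

config_dict = {
    "mcpServers": {
        "notion": {  # illustrative server name
            "command": "npx",
            "args": ["-y", "mcp-remote", "https://mcp.notion.com/mcp"],
            # "env" may be omitted; the helper adds MCP_REMOTE_CONFIG_DIR itself.
        }
    }
}

ok = asyncio.run(pre_instantiate_mcp_toolkit(config_dict))
print("authenticated" if ok else "pre-instantiation failed (installation continues anyway)")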
@@ -1,173 +1,196 @@
from fastapi import APIRouter, Depends
from exa_py import Exa
from loguru import logger
from app.component.auth import key_must
from app.component.environment import env_not_empty
from app.model.mcp.proxy import ExaSearch
from typing import Any, cast
import requests

from app.model.user.key import Key


router = APIRouter(prefix="/proxy", tags=["Mcp Servers"])


@router.post("/exa")
def exa_search(search: ExaSearch, key: Key = Depends(key_must)):
    EXA_API_KEY = env_not_empty("EXA_API_KEY")
    try:
        exa = Exa(EXA_API_KEY)

        if search.num_results is not None and not 0 < search.num_results <= 100:
            raise ValueError("num_results must be between 1 and 100")

        if search.include_text is not None:
            if len(search.include_text) > 1:
                raise ValueError("include_text can only contain 1 string")
            if len(search.include_text[0].split()) > 5:
                raise ValueError("include_text string cannot be longer than 5 words")

        if search.exclude_text is not None:
            if len(search.exclude_text) > 1:
                raise ValueError("exclude_text can only contain 1 string")
            if len(search.exclude_text[0].split()) > 5:
                raise ValueError("exclude_text string cannot be longer than 5 words")

        # Call Exa API with direct parameters
        if search.text:
            results = cast(
                dict[str, Any],
                exa.search_and_contents(
                    query=search.query,
                    type=search.search_type,
                    category=search.category,
                    num_results=search.num_results,
                    include_text=search.include_text,
                    exclude_text=search.exclude_text,
                    use_autoprompt=search.use_autoprompt,
                    text=True,
                ),
            )
        else:
            results = cast(
                dict[str, Any],
                exa.search(
                    query=search.query,
                    type=search.search_type,
                    category=search.category,
                    num_results=search.num_results,
                    include_text=search.include_text,
                    exclude_text=search.exclude_text,
                    use_autoprompt=search.use_autoprompt,
                ),
            )

        return results

    except Exception as e:
        return {"error": f"Exa search failed: {e!s}"}


@router.get("/google")
def google_search(query: str, search_type: str = "web", key: Key = Depends(key_must)):
    # https://developers.google.com/custom-search/v1/overview
    GOOGLE_API_KEY = env_not_empty("GOOGLE_API_KEY")
    # https://cse.google.com/cse/all
    SEARCH_ENGINE_ID = env_not_empty("SEARCH_ENGINE_ID")

    # Using the first page
    start_page_idx = 1
    # Different language may get different result
    search_language = "en"
    # How many pages to return
    num_result_pages = 10
    # Constructing the URL
    # Doc: https://developers.google.com/custom-search/v1/using_rest
    base_url = (
        f"https://www.googleapis.com/customsearch/v1?"
        f"key={GOOGLE_API_KEY}&cx={SEARCH_ENGINE_ID}&q={query}&start="
        f"{start_page_idx}&lr={search_language}&num={num_result_pages}"
    )

    if search_type == "image":
        url = base_url + "&searchType=image"
    else:
        url = base_url

    responses = []
    # Fetch the results given the URL
    try:
        # Make the get
        result = requests.get(url)
        data = result.json()

        # Get the result items
        if "items" in data:
            search_items = data.get("items")

            # Iterate over results found
            for i, search_item in enumerate(search_items, start=1):
                if search_type == "image":
                    # Process image search results
                    title = search_item.get("title")
                    image_url = search_item.get("link")
                    display_link = search_item.get("displayLink")

                    # Get context URL (page containing the image)
                    image_info = search_item.get("image", {})
                    context_url = image_info.get("contextLink", "")

                    # Get image dimensions if available
                    width = image_info.get("width")
                    height = image_info.get("height")

                    response = {
                        "result_id": i,
                        "title": title,
                        "image_url": image_url,
                        "display_link": display_link,
                        "context_url": context_url,
                    }

                    # Add dimensions if available
                    if width:
                        response["width"] = int(width)
                    if height:
                        response["height"] = int(height)

                    responses.append(response)
                else:
                    # Process web search results (existing logic)
                    # Check metatags are present
                    if "pagemap" not in search_item:
                        continue
                    if "metatags" not in search_item["pagemap"]:
                        continue
                    if "og:description" in search_item["pagemap"]["metatags"][0]:
                        long_description = search_item["pagemap"]["metatags"][0]["og:description"]
                    else:
                        long_description = "N/A"
                    # Get the page title
                    title = search_item.get("title")
                    # Page snippet
                    snippet = search_item.get("snippet")

                    # Extract the page url
                    link = search_item.get("link")
                    response = {
                        "result_id": i,
                        "title": title,
                        "description": snippet,
                        "long_description": long_description,
                        "url": link,
                    }
                    responses.append(response)
        else:
            error_info = data.get("error", {})
            logger.error(f"Google search failed - API response: {error_info}")
            responses.append({"error": f"Google search failed - API response: {error_info}"})

    except Exception as e:
        responses.append({"error": f"google search failed: {e!s}"})
    return responses

from fastapi import APIRouter, Depends, HTTPException
from exa_py import Exa
from app.component.auth import key_must
from app.component.environment import env_not_empty
from app.model.mcp.proxy import ExaSearch
from typing import Any, cast
import requests
from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("server_proxy_controller")

from app.model.user.key import Key


router = APIRouter(prefix="/proxy", tags=["Mcp Servers"])


@router.post("/exa")
@traceroot.trace()
def exa_search(search: ExaSearch, key: Key = Depends(key_must)):
    """Search using Exa API."""
    EXA_API_KEY = env_not_empty("EXA_API_KEY")
    try:
        # Validate input parameters
        if search.num_results is not None and not 0 < search.num_results <= 100:
            logger.warning("Invalid exa search parameter", extra={"param": "num_results", "value": search.num_results})
            raise ValueError("num_results must be between 1 and 100")

        if search.include_text is not None and len(search.include_text) > 0:
            if len(search.include_text) > 1:
                logger.warning("Invalid exa search parameter", extra={"param": "include_text", "reason": "more than 1 string"})
                raise ValueError("include_text can only contain 1 string")
            if len(search.include_text[0].split()) > 5:
                logger.warning("Invalid exa search parameter", extra={"param": "include_text", "reason": "exceeds 5 words"})
                raise ValueError("include_text string cannot be longer than 5 words")

        if search.exclude_text is not None and len(search.exclude_text) > 0:
            if len(search.exclude_text) > 1:
                logger.warning("Invalid exa search parameter", extra={"param": "exclude_text", "reason": "more than 1 string"})
                raise ValueError("exclude_text can only contain 1 string")
            if len(search.exclude_text[0].split()) > 5:
                logger.warning("Invalid exa search parameter", extra={"param": "exclude_text", "reason": "exceeds 5 words"})
                raise ValueError("exclude_text string cannot be longer than 5 words")

        exa = Exa(EXA_API_KEY)

        # Call Exa API with direct parameters
        if search.text:
            results = cast(
                dict[str, Any],
                exa.search_and_contents(
                    query=search.query,
                    type=search.search_type,
                    category=search.category,
                    num_results=search.num_results,
                    include_text=search.include_text,
                    exclude_text=search.exclude_text,
                    use_autoprompt=search.use_autoprompt,
                    text=True,
                ),
            )
        else:
            results = cast(
                dict[str, Any],
                exa.search(
                    query=search.query,
                    type=search.search_type,
                    category=search.category,
                    num_results=search.num_results,
                    include_text=search.include_text,
                    exclude_text=search.exclude_text,
                    use_autoprompt=search.use_autoprompt,
                ),
            )

        result_count = len(results.get("results", [])) if "results" in results else 0
        logger.info("Exa search completed", extra={"query": search.query, "search_type": search.search_type, "result_count": result_count})
        return results

    except ValueError as e:
        logger.warning("Exa search validation error", extra={"error": str(e)})
        raise HTTPException(status_code=500, detail="Internal server error")
    except Exception as e:
        logger.error("Exa search failed", extra={"query": search.query, "error": str(e)}, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")


@router.get("/google")
@traceroot.trace()
def google_search(query: str, search_type: str = "web", key: Key = Depends(key_must)):
    """Search using Google Custom Search API."""
    # https://developers.google.com/custom-search/v1/overview
    GOOGLE_API_KEY = env_not_empty("GOOGLE_API_KEY")
    # https://cse.google.com/cse/all
    SEARCH_ENGINE_ID = env_not_empty("SEARCH_ENGINE_ID")

    # Using the first page
    start_page_idx = 1
    # Different language may get different result
    search_language = "en"
    # How many pages to return
    num_result_pages = 10

    # Constructing the URL
    # Doc: https://developers.google.com/custom-search/v1/using_rest
    base_url = (
        f"https://www.googleapis.com/customsearch/v1?"
        f"key={GOOGLE_API_KEY}&cx={SEARCH_ENGINE_ID}&q={query}&start="
        f"{start_page_idx}&lr={search_language}&num={num_result_pages}"
    )

    if search_type == "image":
        url = base_url + "&searchType=image"
    else:
        url = base_url

    responses = []

    try:
        # Make the GET request
        result = requests.get(url)
        data = result.json()

        # Get the result items
        if "items" in data:
            search_items = data.get("items")

            # Iterate over results found
            for i, search_item in enumerate(search_items, start=1):
                if search_type == "image":
                    # Process image search results
                    title = search_item.get("title")
                    image_url = search_item.get("link")
                    display_link = search_item.get("displayLink")

                    # Get context URL (page containing the image)
                    image_info = search_item.get("image", {})
                    context_url = image_info.get("contextLink", "")

                    # Get image dimensions if available
                    width = image_info.get("width")
                    height = image_info.get("height")

                    response = {
                        "result_id": i,
                        "title": title,
                        "image_url": image_url,
                        "display_link": display_link,
                        "context_url": context_url,
                    }

                    # Add dimensions if available
                    if width:
                        response["width"] = int(width)
                    if height:
                        response["height"] = int(height)

                    responses.append(response)
                else:
                    # Process web search results
                    # Check metatags are present
                    if "pagemap" not in search_item:
                        continue
                    if "metatags" not in search_item["pagemap"]:
                        continue
                    if "og:description" in search_item["pagemap"]["metatags"][0]:
                        long_description = search_item["pagemap"]["metatags"][0]["og:description"]
                    else:
                        long_description = "N/A"
                    # Get the page title
                    title = search_item.get("title")
                    # Page snippet
                    snippet = search_item.get("snippet")

                    # Extract the page url
                    link = search_item.get("link")
                    response = {
                        "result_id": i,
                        "title": title,
                        "description": snippet,
                        "long_description": long_description,
                        "url": link,
                    }
                    responses.append(response)

            logger.info("Google search completed", extra={"query": query, "search_type": search_type, "result_count": len(responses)})
        else:
            error_info = data.get("error", {})
            logger.error("Google search API error", extra={"query": query, "api_error": error_info})
            raise HTTPException(status_code=500, detail="Internal server error")

    except Exception as e:
        logger.error("Google search failed", extra={"query": query, "search_type": search_type, "error": str(e)}, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")

    return responses
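The Exa proxy above enforces num_results in 1..100 and at most one include_text/exclude_text phrase of five words or fewer before calling the SDK. A request sketch that stays inside those limits (host and auth mechanism are assumptions; the endpoint authenticates via the `key_must` dependency):

# Hypothetical call to the /proxy/exa endpoint above.
import httpx

payload = {
    "query": "open source multi-agent frameworks",
    "search_type": "auto",                     # assumed ExaSearch field value
    "num_results": 10,                         # must be within 1..100
    "include_text": ["agent orchestration"],   # at most one phrase, <= 5 words
    "text": True,                              # switches to search_and_contents
}

resp = httpx.post("http://localhost:8000/proxy/exa", json=payload)  # key auth omitted
print(resp.status_code)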
@@ -1,139 +1,181 @@
import os
from typing import List, Optional
from fastapi import Depends, HTTPException, Query, Response, APIRouter
from sqlmodel import Session, select
from app.component.database import session
from app.component.auth import Auth, auth_must
from fastapi_babel import _
from app.model.mcp.mcp_user import McpUser, McpUserIn, McpUserOut, McpUserUpdate, Status
from app.model.mcp.mcp import Mcp
from loguru import logger
from camel.toolkits.mcp_toolkit import MCPToolkit
from app.component.environment import env

router = APIRouter(tags=["McpUser Management"])


async def pre_instantiate_mcp_toolkit(config_dict: dict) -> bool:
    """
    Pre-instantiate MCP toolkit to complete authentication process

    Args:
        config_dict: MCP server configuration dictionary

    Returns:
        bool: Whether successfully instantiated and connected
    """
    try:
        # Ensure unified auth directory for all mcp servers
        for server_config in config_dict.get("mcpServers", {}).values():
            if "env" not in server_config:
                server_config["env"] = {}
            # Set global auth directory to persist authentication across tasks
            if "MCP_REMOTE_CONFIG_DIR" not in server_config["env"]:
                server_config["env"]["MCP_REMOTE_CONFIG_DIR"] = env(
                    "MCP_REMOTE_CONFIG_DIR",
                    os.path.expanduser("~/.mcp-auth")
                )

        # Create MCP toolkit and attempt to connect
        mcp_toolkit = MCPToolkit(config_dict=config_dict, timeout=30)
        await mcp_toolkit.connect()

        # Get tools list to ensure connection is successful
        tools = mcp_toolkit.get_tools()
        logger.info(f"Successfully pre-instantiated MCP toolkit with {len(tools)} tools")

        # Disconnect, authentication info is already saved
        await mcp_toolkit.disconnect()
        return True

    except Exception as e:
        logger.warning(f"Failed to pre-instantiate MCP toolkit: {e!r}")
        return False


@router.get("/mcp/users", name="list mcp users", response_model=List[McpUserOut])
async def list_mcp_users(
    mcp_id: Optional[int] = None,
    session: Session = Depends(session),
    auth: Auth = Depends(auth_must),
):
    user_id = auth.user.id
    query = select(McpUser)
    if mcp_id is not None:
        query = query.where(McpUser.mcp_id == mcp_id)
    if user_id is not None:
        query = query.where(McpUser.user_id == user_id)
    mcp_users = session.exec(query).all()
    return mcp_users


@router.get("/mcp/users/{mcp_user_id}", name="get mcp user", response_model=McpUserOut)
async def get_mcp_user(mcp_user_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    query = select(McpUser).where(McpUser.id == mcp_user_id)
    mcp_user = session.exec(query).first()
    if not mcp_user:
        raise HTTPException(status_code=404, detail=_("McpUser not found"))
    return mcp_user


@router.post("/mcp/users", name="create mcp user", response_model=McpUserOut)
async def create_mcp_user(mcp_user: McpUserIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    exists = session.exec(
        select(McpUser).where(McpUser.mcp_id == mcp_user.mcp_id, McpUser.user_id == auth.user.id)
    ).first()
    if exists:
        raise HTTPException(status_code=400, detail=_("mcp is installed"))

    # Get MCP configuration from the main Mcp table
    mcp = session.get(Mcp, mcp_user.mcp_id)
    if mcp and mcp.install_command:
        # Pre-instantiate MCP toolkit for authentication
        config_dict = {
            "mcpServers": {
                mcp.key: mcp.install_command
            }
        }

        try:
            success = await pre_instantiate_mcp_toolkit(config_dict)
            if not success:
                logger.warning(f"Pre-instantiation failed for MCP {mcp.key}, but continuing with user creation")
        except Exception as e:
            logger.warning(f"Exception during pre-instantiation for MCP {mcp.key}: {e}")

    db_mcp_user = McpUser(mcp_id=mcp_user.mcp_id, user_id=auth.user.id, env=mcp_user.env)
    session.add(db_mcp_user)
    session.commit()
    session.refresh(db_mcp_user)
    return db_mcp_user


@router.put("/mcp/users/{id}", name="update mcp user")
async def update_mcp_user(
    id: int,
    update_item: McpUserUpdate,
    session: Session = Depends(session),
    auth: Auth = Depends(auth_must),
):
    model = session.get(McpUser, id)
    if not model:
        raise HTTPException(status_code=404, detail=_("Mcp Info not found"))
    if model.user_id != auth.user.id:
        raise HTTPException(status_code=400, detail=_("current user have no permission to modify"))
    update_data = update_item.model_dump(exclude_unset=True)
    model.update_fields(update_data)
    model.save(session)
    session.refresh(model)
    return model


@router.delete("/mcp/users/{mcp_user_id}", name="delete mcp user")
async def delete_mcp_user(mcp_user_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    db_mcp_user = session.get(McpUser, mcp_user_id)
    if not db_mcp_user:
        raise HTTPException(status_code=404, detail=_("Mcp Info not found"))
    session.delete(db_mcp_user)
    session.commit()
    return Response(status_code=204)

import os
from typing import List, Optional
from fastapi import Depends, HTTPException, Query, Response, APIRouter
from sqlmodel import Session, select
from app.component.database import session
from app.component.auth import Auth, auth_must
from fastapi_babel import _
from app.model.mcp.mcp_user import McpUser, McpUserIn, McpUserOut, McpUserUpdate, Status
from app.model.mcp.mcp import Mcp
from camel.toolkits.mcp_toolkit import MCPToolkit
from app.component.environment import env
from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("server_mcp_user_controller")

router = APIRouter(tags=["McpUser Management"])


async def pre_instantiate_mcp_toolkit(config_dict: dict) -> bool:
    """
    Pre-instantiate MCP toolkit to complete authentication process

    Args:
        config_dict: MCP server configuration dictionary

    Returns:
        bool: Whether successfully instantiated and connected
    """
    try:
        # Ensure unified auth directory for all mcp servers
        for server_config in config_dict.get("mcpServers", {}).values():
            if "env" not in server_config:
                server_config["env"] = {}
            # Set global auth directory to persist authentication across tasks
            if "MCP_REMOTE_CONFIG_DIR" not in server_config["env"]:
                server_config["env"]["MCP_REMOTE_CONFIG_DIR"] = env(
                    "MCP_REMOTE_CONFIG_DIR",
                    os.path.expanduser("~/.mcp-auth")
                )

        # Create MCP toolkit and attempt to connect
        mcp_toolkit = MCPToolkit(config_dict=config_dict, timeout=30)
        await mcp_toolkit.connect()

        # Get tools list to ensure connection is successful
        tools = mcp_toolkit.get_tools()
        logger.info("MCP toolkit pre-instantiated", extra={"tools_count": len(tools)})

        # Disconnect, authentication info is already saved
        await mcp_toolkit.disconnect()
        return True

    except Exception as e:
        logger.warning("MCP toolkit pre-instantiation failed", extra={"error": str(e)}, exc_info=True)
        return False
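
# Usage sketch (an illustration, not part of this change): running the helper
# once lets the MCP auth flow complete so credentials persist under
# ~/.mcp-auth. The server key and "npx mcp-remote ..." command below are
# placeholders, not real rows from the Mcp table.
#
#     import asyncio
#     example_config = {
#         "mcpServers": {
#             "example-server": {
#                 "command": "npx",
#                 "args": ["-y", "mcp-remote", "https://example.com/mcp"],
#             }
#         }
#     }
#     ok = asyncio.run(pre_instantiate_mcp_toolkit(example_config))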

@router.get("/mcp/users", name="list mcp users", response_model=List[McpUserOut])
@traceroot.trace()
async def list_mcp_users(
    mcp_id: Optional[int] = None,
    session: Session = Depends(session),
    auth: Auth = Depends(auth_must),
):
    """List MCP users for current user."""
    user_id = auth.user.id
    query = select(McpUser)
    if mcp_id is not None:
        query = query.where(McpUser.mcp_id == mcp_id)
    if user_id is not None:
        query = query.where(McpUser.user_id == user_id)
    mcp_users = session.exec(query).all()
    logger.debug("MCP users listed", extra={"user_id": user_id, "mcp_id": mcp_id, "count": len(mcp_users)})
    return mcp_users


@router.get("/mcp/users/{mcp_user_id}", name="get mcp user", response_model=McpUserOut)
@traceroot.trace()
async def get_mcp_user(mcp_user_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    """Get MCP user details."""
    query = select(McpUser).where(McpUser.id == mcp_user_id)
    mcp_user = session.exec(query).first()
    if not mcp_user:
        logger.warning("MCP user not found", extra={"user_id": auth.user.id, "mcp_user_id": mcp_user_id})
        raise HTTPException(status_code=404, detail=_("McpUser not found"))
    logger.debug("MCP user retrieved", extra={"user_id": auth.user.id, "mcp_user_id": mcp_user_id, "mcp_id": mcp_user.mcp_id})
    return mcp_user


@router.post("/mcp/users", name="create mcp user", response_model=McpUserOut)
@traceroot.trace()
async def create_mcp_user(mcp_user: McpUserIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    """Create MCP user installation."""
    user_id = auth.user.id
    mcp_id = mcp_user.mcp_id

    exists = session.exec(
        select(McpUser).where(McpUser.mcp_id == mcp_id, McpUser.user_id == user_id)
    ).first()
    if exists:
        logger.warning("MCP already installed", extra={"user_id": user_id, "mcp_id": mcp_id})
        raise HTTPException(status_code=400, detail=_("mcp is installed"))

    # Get MCP configuration from the main Mcp table
    mcp = session.get(Mcp, mcp_id)
    if mcp and mcp.install_command:
        config_dict = {
            "mcpServers": {
                mcp.key: mcp.install_command
            }
        }

        try:
            success = await pre_instantiate_mcp_toolkit(config_dict)
            if not success:
                logger.warning("MCP pre-instantiation failed, continuing", extra={"user_id": user_id, "mcp_id": mcp_id, "mcp_key": mcp.key})
        except Exception as e:
            logger.warning("MCP pre-instantiation exception", extra={"user_id": user_id, "mcp_id": mcp_id, "error": str(e)}, exc_info=True)

    try:
        db_mcp_user = McpUser(mcp_id=mcp_id, user_id=user_id, env=mcp_user.env)
        session.add(db_mcp_user)
        session.commit()
        session.refresh(db_mcp_user)
        logger.info("MCP user created", extra={"user_id": user_id, "mcp_id": mcp_id, "mcp_user_id": db_mcp_user.id})
        return db_mcp_user
    except Exception as e:
        session.rollback()
        logger.error("MCP user creation failed", extra={"user_id": user_id, "mcp_id": mcp_id, "error": str(e)}, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")


@router.put("/mcp/users/{id}", name="update mcp user")
@traceroot.trace()
async def update_mcp_user(
    id: int,
    update_item: McpUserUpdate,
    session: Session = Depends(session),
    auth: Auth = Depends(auth_must),
):
    """Update MCP user settings."""
    user_id = auth.user.id
    model = session.get(McpUser, id)
    if not model:
        logger.warning("MCP user not found for update", extra={"user_id": user_id, "mcp_user_id": id})
        raise HTTPException(status_code=404, detail=_("Mcp Info not found"))
    if model.user_id != user_id:
        logger.warning("Unauthorized MCP user update", extra={"user_id": user_id, "mcp_user_id": id, "owner_id": model.user_id})
        raise HTTPException(status_code=400, detail=_("current user have no permission to modify"))

    try:
        update_data = update_item.model_dump(exclude_unset=True)
        model.update_fields(update_data)
        model.save(session)
        session.refresh(model)
        logger.info("MCP user updated", extra={"user_id": user_id, "mcp_user_id": id})
        return model
    except Exception as e:
        logger.error("MCP user update failed", extra={"user_id": user_id, "mcp_user_id": id, "error": str(e)}, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")


@router.delete("/mcp/users/{mcp_user_id}", name="delete mcp user")
@traceroot.trace()
async def delete_mcp_user(mcp_user_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    """Delete MCP user installation."""
    user_id = auth.user.id
    db_mcp_user = session.get(McpUser, mcp_user_id)
    if not db_mcp_user:
        logger.warning("MCP user not found for deletion", extra={"user_id": user_id, "mcp_user_id": mcp_user_id})
        raise HTTPException(status_code=404, detail=_("Mcp Info not found"))

    # Capture before deleting; the instance's attributes expire after commit.
    mcp_id = db_mcp_user.mcp_id

    try:
        session.delete(db_mcp_user)
        session.commit()
        logger.info("MCP user deleted", extra={"user_id": user_id, "mcp_user_id": mcp_user_id, "mcp_id": mcp_id})
        return Response(status_code=204)
    except Exception as e:
        session.rollback()
        logger.error("MCP user deletion failed", extra={"user_id": user_id, "mcp_user_id": mcp_user_id, "error": str(e)}, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")
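
# The exists-then-insert check in create_mcp_user can race when two requests
# install the same MCP concurrently. A database-level guard (a sketch, not a
# change made by this commit; table and constraint names are illustrative)
# would be a composite unique constraint on the pair the endpoint checks:
from sqlalchemy import UniqueConstraint
from sqlmodel import Field, SQLModel

class McpUserSketch(SQLModel, table=True):
    __table_args__ = (UniqueConstraint("mcp_id", "user_id", name="uq_mcp_user"),)

    id: int | None = Field(default=None, primary_key=True)
    mcp_id: int = Field(index=True)
    user_id: int = Field(index=True)

# With the constraint in place, a duplicate pair raises IntegrityError at
# commit time, which the endpoint could map to the same 400 response.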
@@ -1,58 +1,81 @@
from fastapi import APIRouter, Request, HTTPException
from fastapi.responses import RedirectResponse, JSONResponse, HTMLResponse
from app.component.environment import env
from app.component.oauth_adapter import OauthCallbackPayload, get_oauth_adapter
from typing import Optional

router = APIRouter(prefix="/oauth", tags=["Oauth Servers"])


@router.get("/{app}/login", name="OAuth Login Redirect")
def oauth_login(app: str, request: Request, state: Optional[str] = None):
    try:
        callback_url = str(request.url_for("OAuth Callback", app=app))
        if callback_url.startswith("http://"):
            callback_url = "https://" + callback_url[len("http://") :]
        adapter = get_oauth_adapter(app, callback_url)
        url = adapter.get_authorize_url(state)
        if not url:
            raise HTTPException(status_code=400, detail="Failed to generate authorization URL")
        return RedirectResponse(str(url))
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))


@router.get("/{app}/callback", name="OAuth Callback")
def oauth_callback(app: str, request: Request, code: Optional[str] = None, state: Optional[str] = None):
    if not code:
        raise HTTPException(status_code=400, detail="Missing code parameter")
    redirect_url = f"eigent://callback/oauth?provider={app}&code={code}&state={state}"
    html_content = f"""
    <html>
        <head>
            <title>OAuth Callback</title>
        </head>
        <body>
            <script type='text/javascript'>
                window.location.href = '{redirect_url}';
            </script>
            <p>Redirecting, please wait...</p>
            <button onclick='window.close()'>Close this window</button>
        </body>
    </html>
    """
    return HTMLResponse(content=html_content)


@router.post("/{app}/token", name="OAuth Fetch Token")
def fetch_token(app: str, request: Request, data: OauthCallbackPayload):
    try:
        callback_url = str(request.url_for("OAuth Callback", app=app))
        if callback_url.startswith("http://"):
            callback_url = "https://" + callback_url[len("http://") :]

        adapter = get_oauth_adapter(app, callback_url)
        token_data = adapter.fetch_token(data.code)
        return JSONResponse(token_data)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

from fastapi import APIRouter, Request, HTTPException
from fastapi.responses import RedirectResponse, JSONResponse, HTMLResponse
from app.component.environment import env
from app.component.oauth_adapter import OauthCallbackPayload, get_oauth_adapter
from typing import Optional
from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("server_oauth_controller")

router = APIRouter(prefix="/oauth", tags=["Oauth Servers"])


@router.get("/{app}/login", name="OAuth Login Redirect")
@traceroot.trace()
def oauth_login(app: str, request: Request, state: Optional[str] = None):
    """Redirect user to OAuth provider's authorization endpoint."""
    try:
        callback_url = str(request.url_for("OAuth Callback", app=app))
        if callback_url.startswith("http://"):
            callback_url = "https://" + callback_url[len("http://") :]

        adapter = get_oauth_adapter(app, callback_url)
        url = adapter.get_authorize_url(state)

        if not url:
            logger.error("Failed to generate authorization URL", extra={"provider": app, "callback_url": callback_url})
            raise HTTPException(status_code=400, detail="Failed to generate authorization URL")

        logger.info("OAuth login initiated", extra={"provider": app})
        return RedirectResponse(str(url))
    except HTTPException:
        raise
    except Exception as e:
        logger.error("OAuth login failed", extra={"provider": app, "error": str(e)}, exc_info=True)
        raise HTTPException(status_code=400, detail="OAuth login failed")


@router.get("/{app}/callback", name="OAuth Callback")
@traceroot.trace()
def oauth_callback(app: str, request: Request, code: Optional[str] = None, state: Optional[str] = None):
    """Handle OAuth provider callback and redirect to client app."""
    if not code:
        logger.warning("OAuth callback missing code", extra={"provider": app})
        raise HTTPException(status_code=400, detail="Missing code parameter")

    logger.info("OAuth callback received", extra={"provider": app, "has_state": state is not None})

    redirect_url = f"eigent://callback/oauth?provider={app}&code={code}&state={state}"
    html_content = f"""
    <html>
        <head>
            <title>OAuth Callback</title>
        </head>
        <body>
            <script type='text/javascript'>
                window.location.href = '{redirect_url}';
            </script>
            <p>Redirecting, please wait...</p>
            <button onclick='window.close()'>Close this window</button>
        </body>
    </html>
    """
    return HTMLResponse(content=html_content)


@router.post("/{app}/token", name="OAuth Fetch Token")
@traceroot.trace()
def fetch_token(app: str, request: Request, data: OauthCallbackPayload):
    """Exchange authorization code for access token."""
    try:
        callback_url = str(request.url_for("OAuth Callback", app=app))
        if callback_url.startswith("http://"):
            callback_url = "https://" + callback_url[len("http://") :]

        adapter = get_oauth_adapter(app, callback_url)
        token_data = adapter.fetch_token(data.code)
        logger.info("OAuth token fetched", extra={"provider": app})
        return JSONResponse(token_data)
    except Exception as e:
        logger.error("OAuth token fetch failed", extra={"provider": app, "error": str(e)}, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")
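
# End to end, the desktop client drives the three routes above. A rough sketch
# of the exchange (httpx, the base URL, and the "notion" provider name are
# illustrative; the eigent:// deep link is handled by the app shell):
import httpx

BASE = "http://localhost:3001"  # placeholder backend address
provider = "notion"             # any provider get_oauth_adapter() recognizes

# 1. The browser opens /oauth/{app}/login and is redirected to the provider.
# 2. The provider hits /oauth/{app}/callback, which serves HTML that jumps to
#    eigent://callback/oauth?provider=...&code=...&state=...
# 3. The app posts the code back to exchange it for tokens (assuming
#    OauthCallbackPayload accepts a "code" field, as fetch_token reads data.code):
code_from_deep_link = "AUTH_CODE_FROM_CALLBACK"  # placeholder
tokens = httpx.post(f"{BASE}/oauth/{provider}/token", json={"code": code_from_deep_link}).json()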
@@ -1,100 +1,140 @@
from typing import List, Optional
from fastapi import Depends, HTTPException, Query, Response, APIRouter
from fastapi_babel import _
from fastapi_pagination import Page
from fastapi_pagination.ext.sqlmodel import paginate
from sqlalchemy import update
from sqlmodel import Session, select, col
from sqlalchemy.exc import SQLAlchemyError

from app.component.database import session
from app.component.auth import Auth, auth_must
from app.model.provider.provider import Provider, ProviderIn, ProviderOut, ProviderPreferIn


router = APIRouter(tags=["Provider Management"])


@router.get("/providers", name="list providers", response_model=Page[ProviderOut])
async def gets(
    keyword: str | None = None,
    prefer: Optional[bool] = Query(None, description="Filter by prefer status"),
    session: Session = Depends(session),
    auth: Auth = Depends(auth_must),
) -> Page[ProviderOut]:
    user_id = auth.user.id
    stmt = select(Provider).where(Provider.user_id == user_id, Provider.no_delete())
    if keyword:
        stmt = stmt.where(col(Provider.provider_name).like(f"%{keyword}%"))
    if prefer is not None:
        stmt = stmt.where(Provider.prefer == prefer)
    stmt = stmt.order_by(col(Provider.created_at).desc(), col(Provider.id).desc())  # Added for consistent pagination
    return paginate(session, stmt)


@router.get("/provider", name="get provider detail", response_model=ProviderOut)
async def get(id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    user_id = auth.user.id
    stmt = select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == id)
    model = session.exec(stmt).one_or_none()
    if not model:
        raise HTTPException(status_code=404, detail=_("Provider not found"))
    return model


@router.post("/provider", name="create provider", response_model=ProviderOut)
async def post(data: ProviderIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    user_id = auth.user.id
    model = Provider(**data.model_dump(), user_id=user_id)
    model.save(session)
    return model


@router.put("/provider/{id}", name="update provider", response_model=ProviderOut)
async def put(id: int, data: ProviderIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    user_id = auth.user.id
    model = session.exec(
        select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == id)
    ).one_or_none()
    if not model:
        raise HTTPException(status_code=404, detail=_("Provider not found"))
    model.model_type = data.model_type
    model.provider_name = data.provider_name
    model.api_key = data.api_key
    model.endpoint_url = data.endpoint_url
    model.encrypted_config = data.encrypted_config
    model.is_vaild = data.is_vaild
    model.save(session)
    session.refresh(model)
    return model


@router.delete("/provider/{id}", name="delete provider")
async def delete(id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    user_id = auth.user.id
    model = session.exec(
        select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == id)
    ).one_or_none()
    if not model:
        raise HTTPException(status_code=404, detail=_("Provider not found"))
    model.delete(session)
    return Response(status_code=204)


@router.post("/provider/prefer", name="set provider prefer")
async def set_prefer(data: ProviderPreferIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    user_id = auth.user.id
    try:
        # 1. current user's all provider prefer set to false
        session.exec(update(Provider).where(Provider.user_id == user_id, Provider.no_delete()).values(prefer=False))
        # 2. set the prefer of the specified provider_id to true
        session.exec(
            update(Provider)
            .where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == data.provider_id)
            .values(prefer=True)
        )
        session.commit()
        return {"success": True}
    except SQLAlchemyError as e:
        session.rollback()
        raise HTTPException(status_code=500, detail=str(e))

from typing import List, Optional
from fastapi import Depends, HTTPException, Query, Response, APIRouter
from fastapi_babel import _
from fastapi_pagination import Page
from fastapi_pagination.ext.sqlmodel import paginate
from sqlalchemy import update
from sqlmodel import Session, select, col
from sqlalchemy.exc import SQLAlchemyError

from app.component.database import session
from app.component.auth import Auth, auth_must
from app.model.provider.provider import Provider, ProviderIn, ProviderOut, ProviderPreferIn
from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("server_provider_controller")

router = APIRouter(tags=["Provider Management"])


@router.get("/providers", name="list providers", response_model=Page[ProviderOut])
@traceroot.trace()
async def gets(
    keyword: str | None = None,
    prefer: Optional[bool] = Query(None, description="Filter by prefer status"),
    session: Session = Depends(session),
    auth: Auth = Depends(auth_must),
) -> Page[ProviderOut]:
    """List user's providers with optional filtering."""
    user_id = auth.user.id
    stmt = select(Provider).where(Provider.user_id == user_id, Provider.no_delete())
    if keyword:
        stmt = stmt.where(col(Provider.provider_name).like(f"%{keyword}%"))
    if prefer is not None:
        stmt = stmt.where(Provider.prefer == prefer)
    stmt = stmt.order_by(col(Provider.created_at).desc(), col(Provider.id).desc())
    logger.debug("Providers listed", extra={"user_id": user_id, "keyword": keyword, "prefer_filter": prefer})
    return paginate(session, stmt)


@router.get("/provider", name="get provider detail", response_model=ProviderOut)
@traceroot.trace()
async def get(id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    """Get provider details."""
    user_id = auth.user.id
    stmt = select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == id)
    model = session.exec(stmt).one_or_none()
    if not model:
        logger.warning("Provider not found", extra={"user_id": user_id, "provider_id": id})
        raise HTTPException(status_code=404, detail=_("Provider not found"))
    logger.debug("Provider retrieved", extra={"user_id": user_id, "provider_id": id})
    return model


@router.post("/provider", name="create provider", response_model=ProviderOut)
@traceroot.trace()
async def post(data: ProviderIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    """Create a new provider."""
    user_id = auth.user.id
    try:
        model = Provider(**data.model_dump(), user_id=user_id)
        model.save(session)
        logger.info("Provider created", extra={"user_id": user_id, "provider_id": model.id, "provider_name": data.provider_name})
        return model
    except Exception as e:
        logger.error("Provider creation failed", extra={"user_id": user_id, "error": str(e)}, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")


@router.put("/provider/{id}", name="update provider", response_model=ProviderOut)
@traceroot.trace()
async def put(id: int, data: ProviderIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    """Update provider details."""
    user_id = auth.user.id
    model = session.exec(
        select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == id)
    ).one_or_none()
    if not model:
        logger.warning("Provider not found for update", extra={"user_id": user_id, "provider_id": id})
        raise HTTPException(status_code=404, detail=_("Provider not found"))

    try:
        model.model_type = data.model_type
        model.provider_name = data.provider_name
        model.api_key = data.api_key
        model.endpoint_url = data.endpoint_url
        model.encrypted_config = data.encrypted_config
        model.is_vaild = data.is_vaild
        model.save(session)
        session.refresh(model)
        logger.info("Provider updated", extra={"user_id": user_id, "provider_id": id, "provider_name": data.provider_name})
        return model
    except Exception as e:
        logger.error("Provider update failed", extra={"user_id": user_id, "provider_id": id, "error": str(e)}, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")


@router.delete("/provider/{id}", name="delete provider")
@traceroot.trace()
async def delete(id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    """Delete a provider."""
    user_id = auth.user.id
    model = session.exec(
        select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == id)
    ).one_or_none()
    if not model:
        logger.warning("Provider not found for deletion", extra={"user_id": user_id, "provider_id": id})
        raise HTTPException(status_code=404, detail=_("Provider not found"))

    try:
        model.delete(session)
        logger.info("Provider deleted", extra={"user_id": user_id, "provider_id": id})
        return Response(status_code=204)
    except Exception as e:
        logger.error("Provider deletion failed", extra={"user_id": user_id, "provider_id": id, "error": str(e)}, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")


@router.post("/provider/prefer", name="set provider prefer")
@traceroot.trace()
async def set_prefer(data: ProviderPreferIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    """Set preferred provider for user."""
    user_id = auth.user.id
    provider_id = data.provider_id

    try:
        # 1. Set all current user's providers prefer to false
        session.exec(update(Provider).where(Provider.user_id == user_id, Provider.no_delete()).values(prefer=False))
        # 2. Set the prefer of the specified provider_id to true
        session.exec(
            update(Provider)
            .where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == provider_id)
            .values(prefer=True)
        )
        session.commit()
        logger.info("Preferred provider set", extra={"user_id": user_id, "provider_id": provider_id})
        return {"success": True}
    except SQLAlchemyError as e:
        session.rollback()
        logger.error("Failed to set preferred provider", extra={"user_id": user_id, "provider_id": provider_id, "error": str(e)}, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")
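
# The clear-then-set pattern in set_prefer holds the "at most one preferred
# provider per user" invariant only because both UPDATEs share one commit. A
# check like this (a sketch reusing the imports above) could back a regression
# test for that invariant:
from sqlalchemy import func

def count_preferred(session: Session, user_id: int) -> int:
    stmt = (
        select(func.count())
        .select_from(Provider)
        .where(Provider.user_id == user_id, Provider.no_delete(), Provider.prefer == True)  # noqa: E712
    )
    return session.exec(stmt).one()

# Expected after any set_prefer call: count_preferred(session, user_id) <= 1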
@@ -1,90 +1,114 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi_babel import _
from sqlmodel import Session
from app.component import code
from app.component.auth import Auth
from app.component.database import session
from app.component.encrypt import password_verify
from app.component.stack_auth import StackAuth
from app.exception.exception import UserException
from app.model.user.user import LoginByPasswordIn, LoginResponse, Status, User, RegisterIn
from loguru import logger
from app.component.environment import env


router = APIRouter(tags=["Login/Registration"])


@router.post("/login", name="login by email or password")
async def by_password(data: LoginByPasswordIn, session: Session = Depends(session)) -> LoginResponse:
    """
    User login with email and password
    """
    user = User.by(User.email == data.email, s=session).one_or_none()
    if not user or not password_verify(data.password, user.password):
        raise UserException(code.password, _("Account or password error"))
    return LoginResponse(token=Auth.create_access_token(user.id), email=user.email)


@router.post("/login-by_stack", name="login by stack")
async def by_stack_auth(
    token: str,
    type: str = "signup",
    invite_code: str | None = None,
    session: Session = Depends(session),
):
    try:
        stack_id = await StackAuth.user_id(token)
        info = await StackAuth.user_info(token)
    except Exception as e:
        logger.error(e)
        raise HTTPException(500, detail=_(f"{e}"))
    user = User.by(User.stack_id == stack_id, s=session).one_or_none()

    if not user:
        # Only signup can create user
        if type != "signup":
            raise UserException(code.error, _("User not found"))
        with session as s:
            try:
                user = User(
                    username=info["username"] if "username" in info else None,
                    nickname=info["display_name"],
                    email=info["primary_email"],
                    avatar=info["profile_image_url"],
                    stack_id=stack_id,
                )
                s.add(user)
                s.commit()
                session.refresh(user)
                return LoginResponse(token=Auth.create_access_token(user.id), email=user.email)
            except Exception as e:
                s.rollback()
                logger.error(f"Failed to register: {e}")
                raise UserException(code.error, _("Failed to register"))
    else:
        if user.status == Status.Block:
            raise UserException(code.error, _("Your account has been blocked."))
        return LoginResponse(token=Auth.create_access_token(user.id), email=user.email)


@router.post("/register", name="register by email/password")
async def register(data: RegisterIn, session: Session = Depends(session)):
    # Check if email is already registered
    if User.by(User.email == data.email, s=session).one_or_none():
        raise UserException(code.error, _("Email already registered"))

    with session as s:
        try:
            user = User(
                email=data.email,
                password=data.password,
            )
            s.add(user)
            s.commit()
            s.refresh(user)
        except Exception as e:
            s.rollback()
            logger.error(f"Failed to register: {e}")
            raise UserException(code.error, _("Failed to register"))
    return {"status": "success"}

from fastapi import APIRouter, Depends, HTTPException
from fastapi_babel import _
from sqlmodel import Session
from app.component import code
from app.component.auth import Auth
from app.component.database import session
from app.component.encrypt import password_verify
from app.component.stack_auth import StackAuth
from app.exception.exception import UserException
from app.model.user.user import LoginByPasswordIn, LoginResponse, Status, User, RegisterIn
from app.component.environment import env
from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("server_login_controller")


router = APIRouter(tags=["Login/Registration"])


@router.post("/login", name="login by email or password")
@traceroot.trace()
async def by_password(data: LoginByPasswordIn, session: Session = Depends(session)) -> LoginResponse:
    """
    User login with email and password
    """
    email = data.email
    user = User.by(User.email == email, s=session).one_or_none()

    if not user:
        logger.warning("Login failed: user not found", extra={"email": email})
        raise UserException(code.password, _("Account or password error"))

    if not password_verify(data.password, user.password):
        logger.warning("Login failed: invalid password", extra={"user_id": user.id, "email": email})
        raise UserException(code.password, _("Account or password error"))

    logger.info("User login successful", extra={"user_id": user.id, "email": email})
    return LoginResponse(token=Auth.create_access_token(user.id), email=user.email)


@router.post("/login-by_stack", name="login by stack")
@traceroot.trace()
async def by_stack_auth(
    token: str,
    type: str = "signup",
    invite_code: str | None = None,
    session: Session = Depends(session),
):
    try:
        stack_id = await StackAuth.user_id(token)
        info = await StackAuth.user_info(token)
    except Exception as e:
        logger.error("Stack auth failed", extra={"type": type, "error": str(e)}, exc_info=True)
        raise HTTPException(500, detail=_("Authentication failed"))

    user = User.by(User.stack_id == stack_id, s=session).one_or_none()

    if not user:
        if type != "signup":
            logger.warning("Stack auth signup blocked: user not found", extra={"stack_id": stack_id, "type": type})
            raise UserException(code.error, _("User not found"))

        with session as s:
            try:
                user = User(
                    username=info["username"] if "username" in info else None,
                    nickname=info["display_name"],
                    email=info["primary_email"],
                    avatar=info["profile_image_url"],
                    stack_id=stack_id,
                )
                s.add(user)
                s.commit()
                s.refresh(user)
                logger.info("New user registered via stack", extra={"user_id": user.id, "email": user.email, "stack_id": stack_id})
                return LoginResponse(token=Auth.create_access_token(user.id), email=user.email)
            except Exception as e:
                s.rollback()
                logger.error("Stack auth registration failed", extra={"stack_id": stack_id, "error": str(e)}, exc_info=True)
                raise UserException(code.error, _("Failed to register"))
    else:
        if user.status == Status.Block:
            logger.warning("Blocked user login attempt", extra={"user_id": user.id, "stack_id": stack_id})
            raise UserException(code.error, _("Your account has been blocked."))

        logger.info("User login via stack successful", extra={"user_id": user.id, "email": user.email, "stack_id": stack_id})
        return LoginResponse(token=Auth.create_access_token(user.id), email=user.email)


@router.post("/register", name="register by email/password")
@traceroot.trace()
async def register(data: RegisterIn, session: Session = Depends(session)):
    email = data.email

    if User.by(User.email == email, s=session).one_or_none():
        logger.warning("Registration failed: email already exists", extra={"email": email})
        raise UserException(code.error, _("Email already registered"))

    with session as s:
        try:
            user = User(
                email=email,
                password=data.password,
            )
            s.add(user)
            s.commit()
            s.refresh(user)
            logger.info("User registered successfully", extra={"user_id": user.id, "email": email})
        except Exception as e:
            s.rollback()
            logger.error("User registration failed", extra={"email": email, "error": str(e)}, exc_info=True)
            raise UserException(code.error, _("Failed to register"))

    return {"status": "success"}
@@ -1,115 +1,151 @@
from fastapi import APIRouter, Depends
from sqlalchemy import func
from sqlmodel import Session, select
from app.component.auth import Auth, auth_must
from app.component.database import session
from app.model.user.privacy import UserPrivacy, UserPrivacySettings
from app.model.user.user import User, UserIn, UserOut, UserProfile
from app.model.user.user_stat import UserStat, UserStatActionIn, UserStatOut
from app.model.chat.chat_history import ChatHistory
from app.model.mcp.mcp_user import McpUser
from app.model.config.config import Config
from app.model.chat.chat_snpshot import ChatSnapshot
from app.model.user.user_credits_record import UserCreditsRecord


router = APIRouter(tags=["User"])


@router.get("/user", name="user info", response_model=UserOut)
def get(auth: Auth = Depends(auth_must), session: Session = Depends(session)):
    # Refresh credits when fetching user info
    user: User = auth.user
    user.refresh_credits_on_active(session)
    return user


@router.put("/user", name="update user info", response_model=UserOut)
def put(data: UserIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    model = auth.user
    model.username = data.username
    model.save(session)
    return model


@router.put("/user/profile", name="update user profile", response_model=UserProfile)
def put_profile(data: UserProfile, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    model = auth.user
    model.nickname = data.nickname
    model.fullname = data.fullname
    model.work_desc = data.work_desc
    model.save(session)
    return model


@router.get("/user/privacy", name="get user privacy")
def get_privacy(session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    user_id = auth.user.id
    stmt = select(UserPrivacy).where(UserPrivacy.user_id == user_id)
    model = session.exec(stmt).one_or_none()

    if not model:
        return UserPrivacySettings.default_settings()
    return model.pricacy_setting


@router.put("/user/privacy", name="update user privacy")
def put_privacy(data: UserPrivacySettings, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    user_id = auth.user.id
    stmt = select(UserPrivacy).where(UserPrivacy.user_id == user_id)
    model = session.exec(stmt).one_or_none()
    default_settings = UserPrivacySettings.default_settings()

    if model:
        model.pricacy_setting = {**model.pricacy_setting, **data.model_dump()}
        model.save(session)
    else:
        model = UserPrivacy(user_id=user_id, pricacy_setting={**default_settings, **data.model_dump()})
        model.save(session)

    return model.pricacy_setting


@router.get("/user/current_credits", name="get user current credits")
def get_user_credits(auth: Auth = Depends(auth_must), session: Session = Depends(session)):
    user = auth.user
    user.refresh_credits_on_active(session)
    credits = user.credits
    daily_credits: UserCreditsRecord | None = UserCreditsRecord.get_daily_balance(user.id)
    current_daily_credits = 0
    if daily_credits:
        current_daily_credits = daily_credits.amount - daily_credits.balance
    credits += current_daily_credits if current_daily_credits > 0 else 0
    return {"credits": credits, "daily_credits": current_daily_credits}


@router.get("/user/stat", name="get user stat", response_model=UserStatOut)
def get_user_stat(auth: Auth = Depends(auth_must), session: Session = Depends(session)):
    """Get current user's operation statistics."""
    stat = session.exec(select(UserStat).where(UserStat.user_id == auth.user.id)).first()
    data = UserStatOut()
    if stat:
        data = UserStatOut(**stat.model_dump())
    else:
        data = UserStatOut(user_id=auth.user.id)
    data.task_queries = ChatHistory.count(ChatHistory.user_id == auth.user.id, s=session)
    mcp = McpUser.count(McpUser.user_id == auth.user.id, s=session)
    tool: list = session.exec(
        select(func.count("*")).where(Config.user_id == auth.user.id).group_by(Config.config_group)
    ).all()
    tool = tool.__len__()
    data.mcp_install_count = mcp + tool
    data.storage_used = ChatSnapshot.caclDir(ChatSnapshot.get_user_dir(auth.user.id))
    return data


@router.post("/user/stat", name="record user stat")
def record_user_stat(
    data: UserStatActionIn,
    auth: Auth = Depends(auth_must),
    session: Session = Depends(session),
):
    """Record or update current user's operation statistics."""
    data.user_id = auth.user.id
    stat = UserStat.record_action(session, data)
    return stat

from fastapi import APIRouter, Depends
from sqlalchemy import func
from sqlmodel import Session, select
from app.component.auth import Auth, auth_must
from app.component.database import session
from app.model.user.privacy import UserPrivacy, UserPrivacySettings
from app.model.user.user import User, UserIn, UserOut, UserProfile
from app.model.user.user_stat import UserStat, UserStatActionIn, UserStatOut
from app.model.chat.chat_history import ChatHistory
from app.model.mcp.mcp_user import McpUser
from app.model.config.config import Config
from app.model.chat.chat_snpshot import ChatSnapshot
from app.model.user.user_credits_record import UserCreditsRecord
from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("server_user_controller")

router = APIRouter(tags=["User"])


@router.get("/user", name="user info", response_model=UserOut)
@traceroot.trace()
def get(auth: Auth = Depends(auth_must), session: Session = Depends(session)):
    """Get current user information and refresh credits."""
    user: User = auth.user
    user.refresh_credits_on_active(session)
    logger.debug("User info retrieved", extra={"user_id": user.id})
    return user


@router.put("/user", name="update user info", response_model=UserOut)
@traceroot.trace()
def put(data: UserIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    """Update user basic information."""
    model = auth.user
    model.username = data.username
    model.save(session)
    logger.info("User info updated", extra={"user_id": model.id, "username": data.username})
    return model


@router.put("/user/profile", name="update user profile", response_model=UserProfile)
@traceroot.trace()
def put_profile(data: UserProfile, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    """Update user profile details."""
    model = auth.user
    model.nickname = data.nickname
    model.fullname = data.fullname
    model.work_desc = data.work_desc
    model.save(session)
    logger.info("User profile updated", extra={"user_id": model.id, "nickname": data.nickname})
    return model


@router.get("/user/privacy", name="get user privacy")
@traceroot.trace()
def get_privacy(session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    """Get user privacy settings."""
    user_id = auth.user.id
    stmt = select(UserPrivacy).where(UserPrivacy.user_id == user_id)
    model = session.exec(stmt).one_or_none()

    if not model:
        logger.debug("Privacy settings not found, returning defaults", extra={"user_id": user_id})
        return UserPrivacySettings.default_settings()

    logger.debug("Privacy settings retrieved", extra={"user_id": user_id})
    return model.pricacy_setting


@router.put("/user/privacy", name="update user privacy")
@traceroot.trace()
def put_privacy(data: UserPrivacySettings, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
    """Update user privacy settings."""
    user_id = auth.user.id
    stmt = select(UserPrivacy).where(UserPrivacy.user_id == user_id)
    model = session.exec(stmt).one_or_none()
    default_settings = UserPrivacySettings.default_settings()

    if model:
        model.pricacy_setting = {**model.pricacy_setting, **data.model_dump()}
        model.save(session)
        logger.info("Privacy settings updated", extra={"user_id": user_id})
    else:
        model = UserPrivacy(user_id=user_id, pricacy_setting={**default_settings, **data.model_dump()})
        model.save(session)
        logger.info("Privacy settings created", extra={"user_id": user_id})

    return model.pricacy_setting


@router.get("/user/current_credits", name="get user current credits")
@traceroot.trace()
def get_user_credits(auth: Auth = Depends(auth_must), session: Session = Depends(session)):
    """Get user's current credit balance."""
    user = auth.user
    user.refresh_credits_on_active(session)
    credits = user.credits
    daily_credits: UserCreditsRecord | None = UserCreditsRecord.get_daily_balance(user.id)
    current_daily_credits = 0
    if daily_credits:
        current_daily_credits = daily_credits.amount - daily_credits.balance
    credits += current_daily_credits if current_daily_credits > 0 else 0

    logger.debug("Credits retrieved", extra={"user_id": user.id, "total_credits": credits, "daily_credits": current_daily_credits})
    return {"credits": credits, "daily_credits": current_daily_credits}


@router.get("/user/stat", name="get user stat", response_model=UserStatOut)
@traceroot.trace()
def get_user_stat(auth: Auth = Depends(auth_must), session: Session = Depends(session)):
    """Get current user's operation statistics."""
    user_id = auth.user.id
    stat = session.exec(select(UserStat).where(UserStat.user_id == user_id)).first()
    data = UserStatOut()

    if stat:
        data = UserStatOut(**stat.model_dump())
    else:
        data = UserStatOut(user_id=user_id)

    data.task_queries = ChatHistory.count(ChatHistory.user_id == user_id, s=session)
    mcp = McpUser.count(McpUser.user_id == user_id, s=session)
    tool_groups: list = session.exec(
        select(func.count("*")).where(Config.user_id == user_id).group_by(Config.config_group)
    ).all()
    data.mcp_install_count = mcp + len(tool_groups)
    data.storage_used = ChatSnapshot.caclDir(ChatSnapshot.get_user_dir(user_id))

    logger.debug("User stats retrieved", extra={
        "user_id": user_id,
        "task_queries": data.task_queries,
        "mcp_install_count": data.mcp_install_count,
        "storage_used": data.storage_used
    })
    return data


@router.post("/user/stat", name="record user stat")
@traceroot.trace()
def record_user_stat(
    data: UserStatActionIn,
    auth: Auth = Depends(auth_must),
    session: Session = Depends(session),
):
    """Record or update current user's operation statistics."""
    data.user_id = auth.user.id
    stat = UserStat.record_action(session, data)
    logger.info("User stat recorded", extra={"user_id": data.user_id, "action": getattr(data, "action", "unknown")})
    return stat
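
# get_user_stat counts configured tool groups by grouping and then measuring
# the returned row list. A single-round-trip alternative (a sketch against the
# same Config model, not a change made by this commit; session and user_id as
# in get_user_stat) pushes the distinct count into SQL:
from sqlalchemy import distinct

tool_count = session.exec(
    select(func.count(distinct(Config.config_group))).where(Config.user_id == user_id)
).one()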
@@ -1,24 +1,36 @@
from fastapi import APIRouter, Depends
from sqlmodel import Session

from app.component import code
from app.component.auth import Auth, auth_must
from app.component.database import session
from app.component.encrypt import password_hash, password_verify
from app.exception.exception import UserException
from app.model.user.user import UpdatePassword, UserOut
from fastapi_babel import _

router = APIRouter(tags=["User"])


@router.put("/user/update-password", name="update password", response_model=UserOut)
def update_password(data: UpdatePassword, auth: Auth = Depends(auth_must), session: Session = Depends(session)):
    model = auth.user
    if not password_verify(data.password, model.password):
        raise UserException(code.error, _("Password is incorrect"))
    if data.new_password != data.re_new_password:
        raise UserException(code.error, _("The two passwords do not match"))
    model.password = password_hash(data.new_password)
    model.save(session)
    return model

from fastapi import APIRouter, Depends
from sqlmodel import Session

from app.component import code
from app.component.auth import Auth, auth_must
from app.component.database import session
from app.component.encrypt import password_hash, password_verify
from app.exception.exception import UserException
from app.model.user.user import UpdatePassword, UserOut
from fastapi_babel import _
from utils import traceroot_wrapper as traceroot

logger = traceroot.get_logger("server_password_controller")

router = APIRouter(tags=["User"])


@router.put("/user/update-password", name="update password", response_model=UserOut)
@traceroot.trace()
def update_password(data: UpdatePassword, auth: Auth = Depends(auth_must), session: Session = Depends(session)):
    """Update user password after verifying current password."""
    user_id = auth.user.id
    model = auth.user

    if not password_verify(data.password, model.password):
        logger.warning("Password update failed: incorrect current password", extra={"user_id": user_id})
        raise UserException(code.error, _("Password is incorrect"))

    if data.new_password != data.re_new_password:
        logger.warning("Password update failed: new passwords do not match", extra={"user_id": user_id})
        raise UserException(code.error, _("The two passwords do not match"))

    model.password = password_hash(data.new_password)
    model.save(session)
    logger.info("Password updated successfully", extra={"user_id": user_id})
    return model
@@ -2,9 +2,10 @@ from sqlalchemy import Float, Integer
from sqlmodel import Field, SmallInteger, Column, JSON, String
from typing import Optional
from enum import IntEnum
from datetime import datetime
from sqlalchemy_utils import ChoiceType
from app.model.abstract.model import AbstractModel, DefaultTimes
from pydantic import BaseModel
from pydantic import BaseModel, model_validator


class ChatStatus(IntEnum):
@@ -13,9 +14,20 @@ class ChatStatus(IntEnum):

class ChatHistory(AbstractModel, DefaultTimes, table=True):
    """
    Chat history model with timestamp tracking.

    Inherits from DefaultTimes which provides:
    - created_at: timestamp when record is created (auto-populated)
    - updated_at: timestamp when record is last modified (auto-updated)
    - deleted_at: timestamp for soft deletion (nullable)

    For legacy records without timestamps, sorting falls back to id ordering.
    """
    id: int = Field(default=None, primary_key=True)
    user_id: int = Field(index=True)
    task_id: str = Field(index=True, unique=True)
    project_id: str = Field(index=True, unique=False, nullable=True)
    question: str
    language: str
    model_platform: str
@@ -34,6 +46,7 @@ class ChatHistory(AbstractModel, DefaultTimes, table=True):
class ChatHistoryIn(BaseModel):
    task_id: str
    project_id: str | None = None
    user_id: int | None = None
    question: str
    language: str
@@ -54,6 +67,7 @@
class ChatHistoryOut(BaseModel):
    id: int
    task_id: str
    project_id: str | None = None
    question: str
    language: str
    model_platform: str
@@ -67,6 +81,22 @@
    summary: str | None = None
    tokens: int
    status: int
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    @model_validator(mode="after")
    def fill_project_id_from_task_id(self):
        """Fill project_id from task_id when project_id is None"""
        if self.project_id is None:
            self.project_id = self.task_id
        return self

    @model_validator(mode="after")
    def handle_legacy_timestamps(self):
        """Handle legacy records that might not have timestamp fields"""
        # For old records without timestamps, we rely on database-level defaults
        # The sorting in the controller will handle ordering appropriately
        return self


class ChatHistoryUpdate(BaseModel):
@ -74,3 +104,4 @@ class ChatHistoryUpdate(BaseModel):
|
|||
summary: str | None = None
|
||||
tokens: int | None = None
|
||||
status: int | None = None
|
||||
project_id: str | None = None
|
||||
|
|
|
|||
|
|
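The two model_validator hooks added to ChatHistoryOut are straightforward to sanity-check in isolation. The stand-in model below keeps only the two fields the backfill validator touches (the real ChatHistoryOut has more required fields), so it is an illustration rather than the project's model:

from typing import Optional
from pydantic import BaseModel, model_validator


class OutStandIn(BaseModel):
    task_id: str
    project_id: Optional[str] = None

    @model_validator(mode="after")
    def fill_project_id_from_task_id(self):
        # Same rule as ChatHistoryOut: legacy rows without a project get one
        if self.project_id is None:
            self.project_id = self.task_id
        return self


assert OutStandIn(task_id="task-abc").project_id == "task-abc"
assert OutStandIn(task_id="t", project_id="p").project_id == "p"  # explicit value wins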
@@ -124,10 +124,10 @@ class ConfigInfo:
            "env_vars": [],
            "toolkit": "google_drive_mcp_toolkit",
        },
-        # ConfigGroup.GOOGLE_GMAIL_MCP.value: {
-        #     "env_vars": [],
-        #     "toolkit": "google_gmail_mcp_toolkit",
-        # },
+        ConfigGroup.GOOGLE_GMAIL_MCP.value: {
+            "env_vars": ["GOOGLE_CLIENT_ID", "GOOGLE_CLIENT_SECRET", "GOOGLE_REFRESH_TOKEN"],
+            "toolkit": "google_gmail_native_toolkit",
+        },
        ConfigGroup.IMAGE_ANALYSIS.value: {
            "env_vars": [],
            "toolkit": "image_analysis_toolkit",
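The Gmail entry now lists three required env vars instead of none. A hypothetical pre-flight check sketches how a loader could gate the toolkit on them; the diff only declares the names, not the enabling logic:

import os

# Names come from the config entry above; the gating itself is an assumption.
REQUIRED_GMAIL_VARS = ["GOOGLE_CLIENT_ID", "GOOGLE_CLIENT_SECRET", "GOOGLE_REFRESH_TOKEN"]

missing = [name for name in REQUIRED_GMAIL_VARS if not os.environ.get(name)]
if missing:
    print(f"google_gmail_native_toolkit unavailable, missing env vars: {missing}")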
@@ -1,381 +1,383 @@
 from enum import IntEnum
 from typing import Optional
 from pydantic import BaseModel
 from sqlmodel import Relationship, SQLModel, Field, Column, col, select, Session
 from sqlalchemy_utils import ChoiceType
 from sqlalchemy import Boolean, SmallInteger, text
 from app.model.abstract.model import AbstractModel, DefaultTimes
 from datetime import date, datetime, timedelta
 from app.model.user.key import ModelType
 from app.component.database import session_make
-from loguru import logger
+from utils import traceroot_wrapper as traceroot
+
+logger = traceroot.get_logger("user_credits_record")
 
 
 class CreditsChannel(IntEnum):
     register = 1  # granted on registration
     invite = 2  # granted for inviting another user
     daily = 3  # daily refreshed credits
     monthly = 4  # monthly refreshed credits
     paid = 5  # paid credits
     addon = 6  # add-on pack
     consume = 7  # task consumption
 
 
 class CreditsPriority(IntEnum):
     daily = 1  # daily refreshed credits
     monthly = 2  # monthly refreshed credits
     paid = 3  # paid credits
     addon = 4  # add-on pack
 
 
 class CreditsPoint(IntEnum):
     register = 1000
     invite = 500
     special_register = 1500  # 1000 register + 500 invite credit
 
 
 class UserCreditsRecord(AbstractModel, DefaultTimes, table=True):
     id: int = Field(default=None, primary_key=True)
     user_id: int = Field(index=True, foreign_key="user.id")
     invite_by: int = Field(default=None, nullable=True, description="invite by user id")
     invite_code: str = Field(default="", max_length=255)
     amount: int = Field(default=0)
     balance: int = Field(default=0)
     channel: CreditsChannel = Field(
         default=CreditsChannel.register.value, sa_column=Column(ChoiceType(CreditsChannel, SmallInteger()))
     )
     source_id: int = Field(default=0, description="source id")
     remark: str = Field(default="", max_length=255)
     expire_at: datetime = Field(default=None, nullable=True, description="Expiration time")
     used: bool = Field(
         default=False,
         sa_column=Column(Boolean, server_default=text("false")),
         description="Is this record used/expired",
     )
     used_at: datetime = Field(default=None, nullable=True, description="Time when this record was used/expired")
 
     @classmethod
     def get_permanent_credits(cls, user_id: int) -> int:
         """
         Get the total available tokens with a single SQL SUM aggregation.
         Returns:
             int: total available tokens
         """
         session = session_make()
         from sqlalchemy import func
 
         statement = (
             select(func.sum(UserCreditsRecord.amount))
             .where(UserCreditsRecord.user_id == user_id)
             .where(
                 UserCreditsRecord.channel.in_(
                     [
                         CreditsChannel.register,
                         CreditsChannel.invite,
                         CreditsChannel.paid,
                         CreditsChannel.addon,
                         CreditsChannel.monthly,
                     ]
                 )
             )
             .where(UserCreditsRecord.used == False)
             .where((UserCreditsRecord.expire_at.is_(None)) | (col(UserCreditsRecord.expire_at) > datetime.now()))
         )
         result = session.exec(statement).first()
         return result or 0
 
     @classmethod
     def get_temp_credits(cls, user_id: int) -> tuple[int, date]:
         """
         1. Get the total available temporary tokens, computed from credits and the model_type.
         2. Temporary credits may only be granted once per day.
 
         Returns:
             int: total available temporary tokens
         """
         session = session_make()
         statement = (
             select(UserCreditsRecord)
             .where(UserCreditsRecord.user_id == user_id)
             .where(UserCreditsRecord.channel == CreditsChannel.daily)
             .where(UserCreditsRecord.used == False)
             .where(UserCreditsRecord.expire_at.is_not(None))
             .where(col(UserCreditsRecord.expire_at) > datetime.now())
         )
         record: UserCreditsRecord = session.exec(statement).first()
         if record is None:
             return 0, None
         return record.amount - record.balance, record.expire_at
 
     @classmethod
     def consume_credits(cls, user_id: int, amount: int, session: Session, source_id: int = 0, remark: str = ""):
         """
         Consume credits, spending daily credits first, then monthly, paid, addon, and so on.
         Consumption updates the balance field on UserCreditsRecord to track how many credits were spent.
         It also writes a consumption record and updates the user's credits field (daily credits excluded).
         Avoids writing duplicate consumption records and double-deducting credits.
         """
 
         # Check whether a consumption record already exists
         existing_consume_record = None
         if source_id > 0:
             existing_consume_record = session.exec(
                 select(UserCreditsRecord)
                 .where(UserCreditsRecord.user_id == user_id)
                 .where(UserCreditsRecord.channel == CreditsChannel.consume)
                 .where(UserCreditsRecord.source_id == source_id)
             ).first()
 
         if existing_consume_record:
             # If the new amount is larger, extra credits must be consumed
             if amount > 0:
                 existing_consume_record.amount -= amount
                 session.add(existing_consume_record)
                 # Apply the extra consumption directly without creating a new record
                 cls._consume_credits_internal_update(user_id, amount, session, source_id, remark)
             # If the new amount is smaller, credits could be refunded (depends on business requirements)
             else:
                 # Refund logic is intentionally left unimplemented; add it if needed
                 pass
 
             session.commit()
             return
 
         # No existing record: run the normal consumption flow
         cls._consume_credits_internal(user_id, amount, session, source_id, remark)
 
     @classmethod
     def _consume_credits_internal(
         cls, user_id: int, amount: int, session: Session, source_id: int = 0, remark: str = ""
     ):
         """
         Internal consumption logic that performs the actual deduction.
         """
         from app.model.user.user import User
 
         remain = amount
         now = datetime.now()
         consumed_from_daily = 0
         consumed_from_other = 0
 
         # Consume daily credits first
         statement = (
             select(UserCreditsRecord)
             .where(UserCreditsRecord.user_id == user_id)
             .where(UserCreditsRecord.channel == CreditsChannel.daily)
             .where(UserCreditsRecord.used == False)
             .where(UserCreditsRecord.expire_at.is_not(None))
             .where(col(UserCreditsRecord.expire_at) > now)
             .order_by(UserCreditsRecord.expire_at)
         )
         daily_records = session.exec(statement).first()
         if daily_records:
             can_consume = daily_records.amount - daily_records.balance
             use = min(remain, can_consume)
             daily_records.balance += use
             session.add(daily_records)
             remain -= use
             consumed_from_daily = use
             if remain == 0:
                 # Write the consumption record
                 consume_record = UserCreditsRecord(
                     user_id=user_id,
                     amount=-amount,
                     channel=CreditsChannel.consume,
                     source_id=source_id,
                     remark=remark or f"Consumed {amount} credits (daily: {consumed_from_daily})",
                 )
                 session.add(consume_record)
                 session.commit()
                 return
 
         # If daily credits are not enough, continue with monthly/paid/addon
         if remain > 0:
             statement = (
                 select(UserCreditsRecord)
                 .where(UserCreditsRecord.user_id == user_id)
                 .where(
                     UserCreditsRecord.channel.in_(
                         [
                             CreditsChannel.monthly,
                             CreditsChannel.paid,
                             CreditsChannel.addon,
                             CreditsChannel.register,
                             CreditsChannel.invite,
                         ]
                     )
                 )
                 .where(UserCreditsRecord.used == False)
                 .where((UserCreditsRecord.expire_at.is_(None)) | (col(UserCreditsRecord.expire_at) > now))
                 .order_by(UserCreditsRecord.expire_at)
             )
             other_records = session.exec(statement).all()
             for record in other_records:
                 can_consume = record.amount - record.balance
                 if can_consume <= 0:
                     continue
                 use = min(remain, can_consume)
                 record.balance += use
                 session.add(record)
                 remain -= use
                 consumed_from_other += use
                 if remain == 0:
                     break
 
         # Update the user's credits field (deduct only the non-daily portion)
         if consumed_from_other > 0:
             user = session.exec(select(User).where(User.id == user_id)).first()
             if user:
                 user.credits -= consumed_from_other
                 session.add(user)
 
         # Write the consumption record
         consume_record = UserCreditsRecord(
             user_id=user_id,
             amount=-amount,
             channel=CreditsChannel.consume,
             source_id=source_id,
             remark=remark or f"Consumed {amount} credits (daily: {consumed_from_daily}, other: {consumed_from_other})",
         )
         session.add(consume_record)
         session.commit()
 
         if remain > 0:
             raise Exception(f"Insufficient credits: need {amount}, remain {remain}")
 
     @classmethod
     def _consume_credits_internal_update(
         cls, user_id: int, amount: int, session: Session, source_id: int = 0, remark: str = ""
     ):
         """
         Internal consumption logic (update mode): performs the actual deduction without
         creating a new consumption record. Used for extra consumption when an existing
         consumption record is being updated.
         """
         from app.model.user.user import User
 
         remain = amount
         now = datetime.now()
         consumed_from_daily = 0
         consumed_from_other = 0
 
         # Consume daily credits first
         statement = (
             select(UserCreditsRecord)
             .where(UserCreditsRecord.user_id == user_id)
             .where(UserCreditsRecord.channel == CreditsChannel.daily)
             .where(UserCreditsRecord.used == False)
             .where(UserCreditsRecord.expire_at.is_not(None))
             .where(col(UserCreditsRecord.expire_at) > now)
             .order_by(UserCreditsRecord.expire_at)
         )
         daily_records = session.exec(statement).first()
         if daily_records:
             can_consume = daily_records.amount - daily_records.balance
             use = min(remain, can_consume)
             daily_records.balance += use
             session.add(daily_records)
             remain -= use
             consumed_from_daily = use
             if remain == 0:
                 # Do not create a new consumption record; only the existing one is updated
                 return
 
         # If daily credits are not enough, continue with monthly/paid/addon
         if remain > 0:
             statement = (
                 select(UserCreditsRecord)
                 .where(UserCreditsRecord.user_id == user_id)
                 .where(
                     UserCreditsRecord.channel.in_(
                         [
                             CreditsChannel.monthly,
                             CreditsChannel.paid,
                             CreditsChannel.addon,
                             CreditsChannel.register,
                             CreditsChannel.invite,
                         ]
                     )
                 )
                 .where(UserCreditsRecord.used == False)
                 .where((UserCreditsRecord.expire_at.is_(None)) | (col(UserCreditsRecord.expire_at) > now))
                 .order_by(UserCreditsRecord.expire_at)
             )
             other_records = session.exec(statement).all()
             for record in other_records:
                 can_consume = record.amount - record.balance
                 if can_consume <= 0:
                     continue
                 use = min(remain, can_consume)
                 record.balance += use
                 session.add(record)
                 remain -= use
                 consumed_from_other += use
                 if remain == 0:
                     break
         logger.info(f"consumed_from_other: {consumed_from_other}")
         # Update the user's credits field (deduct only the non-daily portion)
         if consumed_from_other > 0:
             user = session.exec(select(User).where(User.id == user_id)).first()
             if user:
                 user.credits -= consumed_from_other
                 session.add(user)
 
         # No new consumption record: the existing record was already updated by the caller
 
         if remain > 0:
             raise Exception(f"Insufficient credits: need {amount}, remain {remain}")
 
     @classmethod
     def get_daily_balance_sum(cls, user_id: int) -> int:
         """
         Sum of the balance field across all of the user's daily-channel credit records.
         """
         session = session_make()
         statement = (
             select(UserCreditsRecord.balance)
             .where(UserCreditsRecord.user_id == user_id)
             .where(UserCreditsRecord.channel == CreditsChannel.daily)
         )
         balances = session.exec(statement).all()
         return sum(balances) if balances else 0
 
     @classmethod
     def get_daily_balance(cls, user_id: int) -> int:
         """
         Get the user's current daily credit record.
         """
         session = session_make()
         statement = (
             select(UserCreditsRecord)
             .where(UserCreditsRecord.user_id == user_id)
             .where(UserCreditsRecord.channel == CreditsChannel.daily)
             .where(UserCreditsRecord.used == False)
         )
         record = session.exec(statement).first()
         return record
 
 
 class UserCreditsRecordWithChatOut(BaseModel):
     """Extended credit-record output model that also carries chat-history info."""
 
     amount: int
     balance: int
     channel: CreditsChannel
     source_id: int
     expire_at: Optional[datetime] = None
     created_at: datetime
     updated_at: Optional[datetime] = None
     # Chat-history fields (populated when channel is consume and source_id is valid)
     chat_project_name: Optional[str] = None
     chat_tokens: Optional[int] = None
 
 
 class UserCreditsRecordOut(BaseModel):
     amount: int
     balance: int
     channel: CreditsChannel
     source_id: int
     remark: str
     expire_at: datetime | None
     created_at: datetime
     updated_at: datetime | None
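The consumption order implemented above (drain daily credits first, then the longer-lived channels in expire_at order) can be illustrated without a database. The dict records below are hypothetical stand-ins for UserCreditsRecord rows; only the amount/balance bookkeeping is reproduced:

def consume(records: list[dict], amount: int) -> int:
    # Drain daily records first, then everything else, mirroring
    # _consume_credits_internal; returns the unsatisfied remainder.
    remain = amount
    ordered = [r for r in records if r["channel"] == "daily"] + [
        r for r in records if r["channel"] != "daily"
    ]
    for r in ordered:
        use = min(remain, r["amount"] - r["balance"])
        r["balance"] += use
        remain -= use
        if remain == 0:
            break
    return remain  # > 0 would raise "Insufficient credits" in the real code


records = [
    {"channel": "paid", "amount": 500, "balance": 0},
    {"channel": "daily", "amount": 100, "balance": 0},
]
assert consume(records, 150) == 0
assert records[1]["balance"] == 100  # daily drained first
assert records[0]["balance"] == 50   # remainder taken from paid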
@@ -21,7 +21,7 @@ class ConfigGroup(str, Enum):
     GITHUB = "Github"
     GOOGLE_CALENDAR = "Google Calendar"
     GOOGLE_DRIVE_MCP = "Google Drive MCP"
-    GOOGLE_GMAIL_MCP = "Google Gmail MCP"
+    GOOGLE_GMAIL_MCP = "Google Gmail"
     IMAGE_ANALYSIS = "Image Analysis"
     MCP_SEARCH = "MCP Search"
     PPTX = "PPTX"
@@ -25,8 +25,8 @@ services:
   # FastAPI Application
   api:
     build:
-      context: .
-      dockerfile: Dockerfile
+      context: ..
+      dockerfile: server/Dockerfile
       args:
         database_url: postgresql://postgres:123456@postgres:5432/eigent
     container_name: eigent_api
@@ -1,30 +1,36 @@
-from app import api
-from app.component.environment import auto_include_routers, env
-from loguru import logger
-import os
-from fastapi.staticfiles import StaticFiles
-
-prefix = env("url_prefix", "")
-auto_include_routers(api, prefix, "app/controller")
-public_dir = os.environ.get("PUBLIC_DIR") or os.path.join(os.path.dirname(__file__), "app", "public")
-# Ensure static directory exists or gracefully skip mounting
-if not os.path.isdir(public_dir):
-    try:
-        os.makedirs(public_dir, exist_ok=True)
-        logger.warning(f"Public directory did not exist. Created: {public_dir}")
-    except Exception as e:
-        logger.error(f"Public directory missing and could not be created: {public_dir}. Error: {e}")
-        public_dir = None
-
-if public_dir and os.path.isdir(public_dir):
-    api.mount("/public", StaticFiles(directory=public_dir), name="public")
-else:
-    logger.warning("Skipping /public mount because public directory is unavailable")
-
-logger.add(
-    "runtime/log/app.log",
-    rotation="10 MB",
-    retention="10 days",
-    level="DEBUG",
-    enqueue=True,
-)
+import os
+import sys
+import pathlib
+
+# Add project root to Python path to import shared utils
+_project_root = pathlib.Path(__file__).parent.parent
+if str(_project_root) not in sys.path:
+    sys.path.insert(0, str(_project_root))
+
+from utils import traceroot_wrapper as traceroot
+from app import api
+from app.component.environment import auto_include_routers, env
+from fastapi.staticfiles import StaticFiles
+
+# Only initialize traceroot if enabled
+if traceroot.is_enabled():
+    from traceroot.integrations.fastapi import connect_fastapi
+    connect_fastapi(api)
+
+logger = traceroot.get_logger("server_main")
+
+prefix = env("url_prefix", "")
+auto_include_routers(api, prefix, "app/controller")
+public_dir = os.environ.get("PUBLIC_DIR") or os.path.join(os.path.dirname(__file__), "app", "public")
+if not os.path.isdir(public_dir):
+    try:
+        os.makedirs(public_dir, exist_ok=True)
+        logger.warning(f"Public directory did not exist. Created: {public_dir}")
+    except Exception as e:
+        logger.error(f"Public directory missing and could not be created: {public_dir}. Error: {e}")
+        public_dir = None
+
+if public_dir and os.path.isdir(public_dir):
+    api.mount("/public", StaticFiles(directory=public_dir), name="public")
+else:
+    logger.warning("Skipping /public mount because public directory is unavailable")
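The new entrypoint depends on just two functions from utils.traceroot_wrapper: is_enabled() and get_logger(name). A hypothetical fallback-only sketch of that interface follows; the real wrapper presumably delegates to the traceroot SDK when tracing is configured, which this sketch deliberately does not attempt:

import logging
import os


def is_enabled() -> bool:
    # Assumption: tracing is gated on a TRACEROOT_TOKEN-style env var.
    return bool(os.environ.get("TRACEROOT_TOKEN"))


def get_logger(name: str) -> logging.Logger:
    # Fallback behavior only: always hand back a stdlib logger so callers
    # like the entrypoint above work whether or not traceroot is configured.
    return logging.getLogger(name)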
@@ -1,40 +1,41 @@
-[project]
-name = "Eigent"
-version = "0.1.0"
-description = "Eigent"
-readme = "README.md"
-requires-python = ">=3.13"
-dependencies = [
-    "alembic>=1.15.2",
-    "click>=8.1.8",
-    "fastapi>=0.115.12",
-    "fastapi-babel>=1.0.0",
-    "fastapi-pagination>=0.12.34",
-    "passlib[bcrypt]>=1.7.4",
-    "bcrypt==4.0.1",
-    "pydantic-i18n>=0.4.5",
-    "pydantic[email]>=2.11.1",
-    "pyjwt>=2.10.1",
-    "python-dotenv>=1.1.0",
-    "sqlalchemy-utils>=0.41.2",
-    "sqlmodel>=0.0.24",
-    "pandas>=2.2.3",
-    "openpyxl>=3.1.5",
-    "pandas>=2.2.3",
-    "arrow>=1.3.0",
-    "fastapi-filter>=2.0.1",
-    "psycopg2-binary>=2.9.10",
-    "convert-case>=1.2.3",
-    "python-multipart>=0.0.20",
-    "loguru>=0.7.3",
-    "httpx>=0.28.1",
-    "pydash>=8.0.5",
-    "requests>=2.32.4",
-    "itsdangerous>=2.2.0",
-    "cryptography>=45.0.4",
-    "sqids>=0.5.2",
-    "exa-py>=1.14.16",
-]
-
-[tool.ruff]
-line-length = 120
+[project]
+name = "Eigent"
+version = "0.1.0"
+description = "Eigent"
+readme = "README.md"
+requires-python = ">=3.12,<3.13"
+dependencies = [
+    "alembic>=1.15.2",
+    "openai>=1.99.3,<2",
+    "camel-ai==0.2.76a13",
+    "pydantic[email]>=2.11.1",
+    "click>=8.1.8",
+    "fastapi>=0.115.12",
+    "fastapi-babel>=1.0.0",
+    "fastapi-pagination>=0.12.34",
+    "passlib[bcrypt]>=1.7.4",
+    "bcrypt==4.0.1",
+    "pydantic-i18n>=0.4.5",
+    "pyjwt>=2.10.1",
+    "python-dotenv>=1.1.0",
+    "sqlalchemy-utils>=0.41.2",
+    "sqlmodel>=0.0.24",
+    "pandas>=2.2.3",
+    "openpyxl>=3.1.5",
+    "arrow>=1.3.0",
+    "fastapi-filter>=2.0.1",
+    "psycopg2-binary>=2.9.10",
+    "convert-case>=1.2.3",
+    "python-multipart>=0.0.20",
+    "httpx>=0.28.1",
+    "pydash>=8.0.5",
+    "requests>=2.32.4",
+    "itsdangerous>=2.2.0",
+    "cryptography>=45.0.4",
+    "sqids>=0.5.2",
+    "exa-py>=1.14.16",
+    "traceroot>=0.0.7",
+]
+
+[tool.ruff]
+line-length = 120
server/uv.lock: 1353 changed lines (generated file; diff suppressed because it is too large).
@@ -170,7 +170,6 @@ async function proxyFetchRequest(
     ...customHeaders,
   }
 
-  console.debug('url', url, token)
   if (!url.includes('http://') && !url.includes('https://') && token) {
     headers['Authorization'] = `Bearer ${token}`
   }
Some files were not shown because too many files have changed in this diff.