From 712f20a8fa2c3ff0b8fdfa7f36863915cffa6258 Mon Sep 17 00:00:00 2001 From: Tong Chen Date: Tue, 24 Mar 2026 18:05:52 +0800 Subject: [PATCH] Feat: Server refactor v1 (#1509) --- .github/workflows/build-view.yml | 51 +- backend/app/controller/tool_controller.py | 9 +- backend/app/router.py | 5 + server/alembic/env.py | 28 +- ...26_02_06_0440-9464b9d89de7_feat_trigger.py | 12 +- server/app/__init__.py | 32 +- server/app/api/__init__.py | 15 + server/app/api/demo_controller.py | 27 + .../redirect_controller.py | 31 +- server/app/component/auth.py | 107 --- server/app/component/permission.py | 89 --- .../app/component/service/trigger/__init__.py | 54 -- .../service/trigger/trigger_service.py | 391 ---------- server/app/component/stack_auth.py | 74 -- server/app/component/time_friendly.py | 38 - .../app/controller/chat/history_controller.py | 454 ----------- .../app/controller/chat/share_controller.py | 136 ---- .../controller/chat/snapshot_controller.py | 195 ----- server/app/controller/chat/step_controller.py | 207 ----- .../controller/config/config_controller.py | 221 ------ server/app/controller/mcp/mcp_controller.py | 294 ------- server/app/controller/mcp/user_controller.py | 210 ----- .../app/controller/oauth/oauth_controller.py | 121 --- .../provider/provider_controller.py | 165 ---- .../controller/trigger/slack_controller.py | 135 ---- .../controller/trigger/trigger_controller.py | 731 ------------------ .../app/controller/user/login_controller.py | 226 ------ server/app/controller/user/user_controller.py | 168 ---- server/app/{component => core}/babel.py | 24 +- server/app/{component => core}/celery.py | 8 +- server/app/{component => core}/code.py | 24 +- server/app/{component => core}/database.py | 26 +- server/app/{component => core}/encrypt.py | 24 +- server/app/{component => core}/environment.py | 24 +- .../app/{component => core}/oauth_adapter.py | 2 +- .../app/{component => core}/pydantic/i18n.py | 26 +- .../pydantic/translations/en_US.json | 0 .../pydantic/translations/zh_CN.json | 0 server/app/{component => core}/redis_utils.py | 0 server/app/{component => core}/sqids.py | 24 +- .../app/{component => core}/trigger_utils.py | 2 +- .../validator/McpServer.py | 24 +- server/app/domains/__init__.py | 15 + server/app/domains/chat/__init__.py | 15 + .../trigger => domains/chat/api}/__init__.py | 0 .../domains/chat/api/history_controller.py | 158 ++++ .../app/domains/chat/api/share_controller.py | 91 +++ .../domains/chat/api/snapshot_controller.py | 122 +++ .../app/domains/chat/api/step_controller.py | 162 ++++ server/app/domains/chat/schema/__init__.py | 27 + .../chat/schema/schemas.py} | 21 +- .../chat/service}/__init__.py | 31 +- .../app/domains/chat/service/chat_service.py | 242 ++++++ server/app/domains/config/__init__.py | 15 + .../config/api}/__init__.py | 1 + .../domains/config/api/config_controller.py | 93 +++ server/app/domains/config/schema/__init__.py | 14 + server/app/domains/config/service/__init__.py | 17 + .../domains/config/service/config_service.py | 120 +++ server/app/domains/mcp/__init__.py | 15 + server/app/domains/mcp/api/__init__.py | 14 + .../mcp/api}/category_controller.py | 30 +- server/app/domains/mcp/api/mcp_controller.py | 101 +++ .../mcp/api}/proxy_controller.py | 120 +-- server/app/domains/mcp/api/user_controller.py | 99 +++ server/app/domains/mcp/schema/__init__.py | 14 + server/app/domains/mcp/service/__init__.py | 17 + .../domains/mcp/service/mcp_user_service.py | 175 +++++ server/app/domains/model_provider/__init__.py | 15 + .../domains/model_provider/api/__init__.py | 14 + .../model_provider/api/provider_controller.py | 91 +++ .../domains/model_provider/schema/__init__.py | 14 + .../model_provider/service/__init__.py | 17 + .../service/provider_service.py | 121 +++ server/app/domains/oauth/__init__.py | 14 + server/app/domains/oauth/api/__init__.py | 14 + .../app/domains/oauth/api/oauth_controller.py | 97 +++ server/app/domains/oauth/schema/__init__.py | 27 + server/app/domains/oauth/schema/schemas.py | 36 + server/app/domains/oauth/service/__init__.py | 17 + .../domains/oauth/service/oauth_service.py | 55 ++ server/app/domains/trigger/__init__.py | 15 + server/app/domains/trigger/api/__init__.py | 14 + .../domains/trigger/api/slack_controller.py | 57 ++ .../domains/trigger/api/trigger_controller.py | 225 ++++++ .../api}/trigger_execution_controller.py | 448 +++-------- .../trigger/api}/webhook_controller.py | 41 +- server/app/domains/trigger/schema/__init__.py | 14 + .../trigger/service}/__init__.py | 20 +- .../trigger/service}/app_handler_service.py | 6 +- .../domains/trigger/service/slack_service.py | 83 ++ .../trigger/service/trigger_crud_service.py | 594 ++++++++++++++ .../service}/trigger_schedule_service.py | 8 +- .../trigger/service}/trigger_schedule_task.py | 93 +-- .../trigger/service}/trigger_service.py | 11 +- server/app/domains/user/__init__.py | 15 + server/app/domains/user/api/__init__.py | 15 + .../app/domains/user/api/login_controller.py | 129 ++++ .../app/domains/user/api/logout_controller.py | 51 ++ .../user/api/password_controller.py} | 29 +- .../app/domains/user/api/user_controller.py | 116 +++ server/app/domains/user/schema/__init__.py | 33 + server/app/domains/user/schema/schemas.py | 52 ++ server/app/domains/user/service/__init__.py | 18 + .../domains/user/service/user_auth_service.py | 118 +++ server/app/exception/handler.py | 65 -- server/app/model/abstract/model.py | 30 +- server/app/model/chat/chat_snpshot.py | 34 +- server/app/model/chat/chat_step.py | 30 +- server/app/model/config/config.py | 2 +- server/app/model/mcp/mcp.py | 26 +- .../model/trigger/app_configs/base_config.py | 2 +- .../trigger/app_configs/config_registry.py | 2 +- .../model/trigger/app_configs/slack_config.py | 2 +- server/app/model/trigger/trigger.py | 2 +- server/app/model/trigger/trigger_execution.py | 2 +- server/app/model/user/key.py | 2 +- server/app/model/user/user.py | 2 +- server/app/model/user/user_credits_record.py | 2 +- .../service/trigger/app_handler_service.py | 448 ----------- .../trigger/trigger_schedule_service.py | 430 ----------- server/app/shared/__init__.py | 20 + server/app/shared/auth/__init__.py | 34 + server/app/shared/auth/admin_auth.py | 124 +++ server/app/shared/auth/ownership.py | 37 + server/app/shared/auth/token_blacklist.py | 63 ++ server/app/shared/auth/user_auth.py | 177 +++++ server/app/shared/context.py | 59 ++ .../exception/__init__.py} | 37 +- server/app/shared/exception/handlers.py | 65 ++ server/app/shared/http/__init__.py | 25 + server/app/shared/http/client.py | 92 +++ server/app/shared/logging/__init__.py | 29 + server/app/shared/logging/logging_utils.py | 56 ++ server/app/shared/middleware/__init__.py | 31 + server/app/shared/middleware/cors.py | 52 ++ server/app/shared/middleware/rate_limit.py | 43 ++ server/app/shared/middleware/trace.py | 49 ++ server/app/shared/redis_publish.py | 30 + server/app/shared/redis_sync.py | 37 + server/app/shared/types/__init__.py | 14 + .../{type => shared/types}/config_group.py | 0 server/app/{type => shared/types}/pydantic.py | 24 +- .../{type => shared/types}/trigger_types.py | 0 server/celery/beat/start | 2 +- server/celery/worker/start | 2 +- server/cli.py | 26 +- server/main.py | 74 +- server/pyproject.toml | 1 + server/start_server.sh | 33 +- .../app/{component => core}/test_encrypt.py | 2 +- server/tests/test_auth.py | 34 +- src/components/AddWorker/ToolSelect.tsx | 46 +- src/components/ChatBox/index.tsx | 10 +- src/components/Folder/index.tsx | 2 +- src/components/GroupedHistoryView/index.tsx | 4 +- src/components/HistorySidebar/index.tsx | 6 +- src/components/IntegrationList/index.tsx | 14 +- src/components/TopBar/index.tsx | 4 +- .../Trigger/DynamicTriggerConfig.tsx | 8 +- src/hooks/useExecutionSubscription.ts | 2 +- src/hooks/useIntegrationManagement.ts | 19 +- src/hooks/useTriggerTaskExecutor.ts | 2 +- src/lib/oauth.ts | 2 +- src/lib/share.ts | 2 +- src/pages/Agents/Models.tsx | 28 +- src/pages/Connectors/MCP.tsx | 70 +- src/pages/Connectors/Search.tsx | 16 +- src/pages/Connectors/components/MCPMarket.tsx | 40 +- src/pages/Login.tsx | 12 +- src/pages/Projects/Project.tsx | 10 +- src/pages/Setting/API.tsx | 6 +- src/pages/Setting/Privacy.tsx | 6 +- src/pages/SignUp.tsx | 12 +- src/routers/index.tsx | 2 +- src/service/historyApi.ts | 8 +- src/service/triggerApi.ts | 24 +- src/stack/client.ts | 2 +- src/store/chatStore.ts | 46 +- 179 files changed, 5593 insertions(+), 6063 deletions(-) create mode 100644 server/app/api/__init__.py create mode 100644 server/app/api/demo_controller.py rename server/app/{controller => api}/redirect_controller.py (91%) delete mode 100644 server/app/component/auth.py delete mode 100644 server/app/component/permission.py delete mode 100644 server/app/component/service/trigger/__init__.py delete mode 100644 server/app/component/service/trigger/trigger_service.py delete mode 100644 server/app/component/stack_auth.py delete mode 100644 server/app/component/time_friendly.py delete mode 100644 server/app/controller/chat/history_controller.py delete mode 100644 server/app/controller/chat/share_controller.py delete mode 100644 server/app/controller/chat/snapshot_controller.py delete mode 100644 server/app/controller/chat/step_controller.py delete mode 100644 server/app/controller/config/config_controller.py delete mode 100644 server/app/controller/mcp/mcp_controller.py delete mode 100644 server/app/controller/mcp/user_controller.py delete mode 100644 server/app/controller/oauth/oauth_controller.py delete mode 100644 server/app/controller/provider/provider_controller.py delete mode 100644 server/app/controller/trigger/slack_controller.py delete mode 100644 server/app/controller/trigger/trigger_controller.py delete mode 100644 server/app/controller/user/login_controller.py delete mode 100644 server/app/controller/user/user_controller.py rename server/app/{component => core}/babel.py (98%) rename server/app/{component => core}/celery.py (85%) rename server/app/{component => core}/code.py (98%) rename server/app/{component => core}/database.py (96%) rename server/app/{component => core}/encrypt.py (98%) rename server/app/{component => core}/environment.py (99%) rename server/app/{component => core}/oauth_adapter.py (99%) rename server/app/{component => core}/pydantic/i18n.py (97%) rename server/app/{component => core}/pydantic/translations/en_US.json (100%) rename server/app/{component => core}/pydantic/translations/zh_CN.json (100%) rename server/app/{component => core}/redis_utils.py (100%) rename server/app/{component => core}/sqids.py (98%) rename server/app/{component => core}/trigger_utils.py (98%) rename server/app/{component => core}/validator/McpServer.py (99%) create mode 100644 server/app/domains/__init__.py create mode 100644 server/app/domains/chat/__init__.py rename server/app/{controller/trigger => domains/chat/api}/__init__.py (100%) create mode 100644 server/app/domains/chat/api/history_controller.py create mode 100644 server/app/domains/chat/api/share_controller.py create mode 100644 server/app/domains/chat/api/snapshot_controller.py create mode 100644 server/app/domains/chat/api/step_controller.py create mode 100644 server/app/domains/chat/schema/__init__.py rename server/app/{controller/health_controller.py => domains/chat/schema/schemas.py} (66%) rename server/app/{middleware => domains/chat/service}/__init__.py (79%) create mode 100644 server/app/domains/chat/service/chat_service.py create mode 100644 server/app/domains/config/__init__.py rename server/app/{controller => domains/config/api}/__init__.py (99%) create mode 100644 server/app/domains/config/api/config_controller.py create mode 100644 server/app/domains/config/schema/__init__.py create mode 100644 server/app/domains/config/service/__init__.py create mode 100644 server/app/domains/config/service/config_service.py create mode 100644 server/app/domains/mcp/__init__.py create mode 100644 server/app/domains/mcp/api/__init__.py rename server/app/{controller/mcp => domains/mcp/api}/category_controller.py (89%) create mode 100644 server/app/domains/mcp/api/mcp_controller.py rename server/app/{controller/mcp => domains/mcp/api}/proxy_controller.py (61%) create mode 100644 server/app/domains/mcp/api/user_controller.py create mode 100644 server/app/domains/mcp/schema/__init__.py create mode 100644 server/app/domains/mcp/service/__init__.py create mode 100644 server/app/domains/mcp/service/mcp_user_service.py create mode 100644 server/app/domains/model_provider/__init__.py create mode 100644 server/app/domains/model_provider/api/__init__.py create mode 100644 server/app/domains/model_provider/api/provider_controller.py create mode 100644 server/app/domains/model_provider/schema/__init__.py create mode 100644 server/app/domains/model_provider/service/__init__.py create mode 100644 server/app/domains/model_provider/service/provider_service.py create mode 100644 server/app/domains/oauth/__init__.py create mode 100644 server/app/domains/oauth/api/__init__.py create mode 100644 server/app/domains/oauth/api/oauth_controller.py create mode 100644 server/app/domains/oauth/schema/__init__.py create mode 100644 server/app/domains/oauth/schema/schemas.py create mode 100644 server/app/domains/oauth/service/__init__.py create mode 100644 server/app/domains/oauth/service/oauth_service.py create mode 100644 server/app/domains/trigger/__init__.py create mode 100644 server/app/domains/trigger/api/__init__.py create mode 100644 server/app/domains/trigger/api/slack_controller.py create mode 100644 server/app/domains/trigger/api/trigger_controller.py rename server/app/{controller/trigger => domains/trigger/api}/trigger_execution_controller.py (55%) rename server/app/{controller/trigger => domains/trigger/api}/webhook_controller.py (92%) create mode 100644 server/app/domains/trigger/schema/__init__.py rename server/app/{service/trigger => domains/trigger/service}/__init__.py (69%) rename server/app/{component/service/trigger => domains/trigger/service}/app_handler_service.py (98%) create mode 100644 server/app/domains/trigger/service/slack_service.py create mode 100644 server/app/domains/trigger/service/trigger_crud_service.py rename server/app/{component/service/trigger => domains/trigger/service}/trigger_schedule_service.py (98%) rename server/app/{schedule => domains/trigger/service}/trigger_schedule_task.py (70%) rename server/app/{service/trigger => domains/trigger/service}/trigger_service.py (97%) create mode 100644 server/app/domains/user/__init__.py create mode 100644 server/app/domains/user/api/__init__.py create mode 100644 server/app/domains/user/api/login_controller.py create mode 100644 server/app/domains/user/api/logout_controller.py rename server/app/{controller/user/user_password_controller.py => domains/user/api/password_controller.py} (61%) create mode 100644 server/app/domains/user/api/user_controller.py create mode 100644 server/app/domains/user/schema/__init__.py create mode 100644 server/app/domains/user/schema/schemas.py create mode 100644 server/app/domains/user/service/__init__.py create mode 100644 server/app/domains/user/service/user_auth_service.py delete mode 100644 server/app/exception/handler.py delete mode 100644 server/app/service/trigger/app_handler_service.py delete mode 100644 server/app/service/trigger/trigger_schedule_service.py create mode 100644 server/app/shared/__init__.py create mode 100644 server/app/shared/auth/__init__.py create mode 100644 server/app/shared/auth/admin_auth.py create mode 100644 server/app/shared/auth/ownership.py create mode 100644 server/app/shared/auth/token_blacklist.py create mode 100644 server/app/shared/auth/user_auth.py create mode 100644 server/app/shared/context.py rename server/app/{exception/exception.py => shared/exception/__init__.py} (81%) create mode 100644 server/app/shared/exception/handlers.py create mode 100644 server/app/shared/http/__init__.py create mode 100644 server/app/shared/http/client.py create mode 100644 server/app/shared/logging/__init__.py create mode 100644 server/app/shared/logging/logging_utils.py create mode 100644 server/app/shared/middleware/__init__.py create mode 100644 server/app/shared/middleware/cors.py create mode 100644 server/app/shared/middleware/rate_limit.py create mode 100644 server/app/shared/middleware/trace.py create mode 100644 server/app/shared/redis_publish.py create mode 100644 server/app/shared/redis_sync.py create mode 100644 server/app/shared/types/__init__.py rename server/app/{type => shared/types}/config_group.py (100%) rename server/app/{type => shared/types}/pydantic.py (98%) rename server/app/{type => shared/types}/trigger_types.py (100%) rename server/tests/app/{component => core}/test_encrypt.py (97%) diff --git a/.github/workflows/build-view.yml b/.github/workflows/build-view.yml index f6b44cd1..001a4f52 100644 --- a/.github/workflows/build-view.yml +++ b/.github/workflows/build-view.yml @@ -12,6 +12,7 @@ jobs: timeout-minutes: 120 strategy: + fail-fast: false matrix: include: - os: macos-latest @@ -39,6 +40,22 @@ jobs: run: | rm -rf node_modules rm -rf release + rm -rf dist out build .vite + rm -rf node_modules/.cache || true + + # Clean build outputs on GitHub-hosted runners to avoid stale artifacts in current job + - name: Clean build outputs (non-Windows) + if: "!contains(matrix.os, 'self-hosted') && runner.os != 'Windows'" + run: | + rm -rf release dist out .vite + rm -rf node_modules/.cache || true + + - name: Clean build outputs (Windows) + if: "!contains(matrix.os, 'self-hosted') && runner.os == 'Windows'" + shell: pwsh + run: | + Remove-Item -Recurse -Force release, dist, out, .vite -ErrorAction SilentlyContinue + Remove-Item -Recurse -Force node_modules/.cache -ErrorAction SilentlyContinue - name: Setup Node.js if: "!contains(matrix.os, 'self-hosted')" @@ -92,6 +109,20 @@ jobs: echo "LLVM_DIR=$(brew --prefix llvm@20)/lib/cmake/llvm" >> $GITHUB_ENV echo "CMAKE_PREFIX_PATH=$(brew --prefix llvm@20)/lib/cmake/llvm" >> $GITHUB_ENV + # Prebuild separately on macOS so signing/package issues are isolated + - name: Build Release Files (macOS prebuild) + if: runner.os == 'macOS' + timeout-minutes: 45 + run: | + npm run prebuild + env: + VITE_BASE_URL: ${{ secrets.VITE_BASE_URL }} + VITE_PROXY_URL: ${{ secrets.VITE_PROXY_URL }} + VITE_STACK_PROJECT_ID: ${{ secrets.VITE_STACK_PROJECT_ID }} + VITE_STACK_PUBLISHABLE_CLIENT_KEY: ${{ secrets.VITE_STACK_PUBLISHABLE_CLIENT_KEY }} + VITE_STACK_SECRET_SERVER_KEY: ${{ secrets.VITE_STACK_SECRET_SERVER_KEY }} + USE_NPM_INSTALL_BUN: 'true' + # Step for macOS builds with signing - name: Build Release Files (macOS with signing) if: runner.os == 'macOS' @@ -104,8 +135,20 @@ jobs: fi ulimit -n 65536 2>/dev/null || ulimit -n 10240 2>/dev/null || true echo "File descriptor limit: $(ulimit -n) (hard: $(ulimit -Hn 2>/dev/null || echo 'N/A'))" - npm run prebuild - npx electron-builder --mac --${{ matrix.arch }} --publish never + + set +e + npx electron-builder --mac dmg --${{ matrix.arch }} --publish never + BUILD_EXIT=$? + + if [ $BUILD_EXIT -ne 0 ]; then + echo "First attempt failed with exit code $BUILD_EXIT" + echo "Retrying once in 5 seconds..." + sleep 5 + npx electron-builder --mac dmg --${{ matrix.arch }} --publish never + BUILD_EXIT=$? + fi + + exit $BUILD_EXIT env: CSC_LINK: ${{ secrets.CERT_P12 }} CSC_KEY_PASSWORD: ${{ secrets.CERT_PASSWORD }} @@ -113,6 +156,7 @@ jobs: APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }} APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }} VITE_BASE_URL: ${{ secrets.VITE_BASE_URL }} + VITE_PROXY_URL: ${{ secrets.VITE_PROXY_URL }} VITE_STACK_PROJECT_ID: ${{ secrets.VITE_STACK_PROJECT_ID }} VITE_STACK_PUBLISHABLE_CLIENT_KEY: ${{ secrets.VITE_STACK_PUBLISHABLE_CLIENT_KEY }} VITE_STACK_SECRET_SERVER_KEY: ${{ secrets.VITE_STACK_SECRET_SERVER_KEY }} @@ -127,6 +171,7 @@ jobs: npx electron-builder --win --${{ matrix.arch }} --publish never env: VITE_BASE_URL: ${{ secrets.VITE_BASE_URL }} + VITE_PROXY_URL: ${{ secrets.VITE_PROXY_URL }} VITE_STACK_PROJECT_ID: ${{ secrets.VITE_STACK_PROJECT_ID }} VITE_STACK_PUBLISHABLE_CLIENT_KEY: ${{ secrets.VITE_STACK_PUBLISHABLE_CLIENT_KEY }} VITE_STACK_SECRET_SERVER_KEY: ${{ secrets.VITE_STACK_SECRET_SERVER_KEY }} @@ -141,6 +186,7 @@ jobs: npx electron-builder --linux --${{ matrix.arch }} --publish never env: VITE_BASE_URL: ${{ secrets.VITE_BASE_URL }} + VITE_PROXY_URL: ${{ secrets.VITE_PROXY_URL }} VITE_STACK_PROJECT_ID: ${{ secrets.VITE_STACK_PROJECT_ID }} VITE_STACK_PUBLISHABLE_CLIENT_KEY: ${{ secrets.VITE_STACK_PUBLISHABLE_CLIENT_KEY }} VITE_STACK_SECRET_SERVER_KEY: ${{ secrets.VITE_STACK_SECRET_SERVER_KEY }} @@ -195,6 +241,7 @@ jobs: path: | release/*.AppImage retention-days: 5 + merge-release: needs: build runs-on: ubuntu-latest diff --git a/backend/app/controller/tool_controller.py b/backend/app/controller/tool_controller.py index f4370627..89fbac87 100644 --- a/backend/app/controller/tool_controller.py +++ b/backend/app/controller/tool_controller.py @@ -15,6 +15,7 @@ import logging import os import shutil +import threading import time from fastapi import APIRouter, HTTPException @@ -779,15 +780,13 @@ async def open_browser_login(): bufsize=1, # Line buffered ) - # Create async task to log Electron output - async def log_electron_output(): + def log_electron_output(): for line in iter(process.stdout.readline, ""): if line: logger.info(f"[ELECTRON OUTPUT] {line.strip()}") - import asyncio - - asyncio.create_task(log_electron_output()) + log_thread = threading.Thread(target=log_electron_output, daemon=True) + log_thread.start() # Wait a bit for Electron to start import asyncio diff --git a/backend/app/router.py b/backend/app/router.py index f257dde0..2909eaae 100644 --- a/backend/app/router.py +++ b/backend/app/router.py @@ -73,6 +73,11 @@ def register_routers(app: FastAPI, prefix: str = "") -> None: }, ] + app.include_router(health_controller.router, tags=["Health"]) + logger.info( + "Registered Health router at root level for Docker health checks" + ) + for config in routers_config: app.include_router( config["router"], prefix=prefix, tags=config["tags"] diff --git a/server/alembic/env.py b/server/alembic/env.py index 4c7ae10b..1157f975 100644 --- a/server/alembic/env.py +++ b/server/alembic/env.py @@ -1,16 +1,16 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= import pathlib import sys @@ -25,7 +25,7 @@ from sqlalchemy import engine_from_config, pool from sqlmodel import SQLModel from alembic import context -from app.component.environment import auto_import, env_not_empty +from app.core.environment import auto_import, env_not_empty # this is the Alembic Config object, which provides # access to the values within the .ini file in use. diff --git a/server/alembic/versions/2026_02_06_0440-9464b9d89de7_feat_trigger.py b/server/alembic/versions/2026_02_06_0440-9464b9d89de7_feat_trigger.py index ed1d0d87..f65b36ec 100644 --- a/server/alembic/versions/2026_02_06_0440-9464b9d89de7_feat_trigger.py +++ b/server/alembic/versions/2026_02_06_0440-9464b9d89de7_feat_trigger.py @@ -24,12 +24,12 @@ from typing import Sequence, Union from alembic import op import sqlalchemy as sa import sqlmodel.sql.sqltypes -from app.type.trigger_types import ExecutionStatus -from app.type.trigger_types import ExecutionType -from app.type.trigger_types import ListenerType -from app.type.trigger_types import RequestType -from app.type.trigger_types import TriggerStatus -from app.type.trigger_types import TriggerType +from app.shared.types.trigger_types import ExecutionStatus +from app.shared.types.trigger_types import ExecutionType +from app.shared.types.trigger_types import ListenerType +from app.shared.types.trigger_types import RequestType +from app.shared.types.trigger_types import TriggerStatus +from app.shared.types.trigger_types import TriggerType from sqlalchemy_utils.types import ChoiceType # revision identifiers, used by Alembic. diff --git a/server/app/__init__.py b/server/app/__init__.py index 0cda6dd9..2a2efae7 100644 --- a/server/app/__init__.py +++ b/server/app/__init__.py @@ -1,22 +1,22 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= from contextlib import asynccontextmanager -from fastapi import FastAPI +from fastapi import APIRouter, FastAPI from fastapi_pagination import add_pagination from fastapi_limiter import FastAPILimiter -from app.component.environment import env_or_fail +from app.core.environment import env_or_fail from redis import asyncio as aioredis import logging @@ -43,3 +43,5 @@ api = FastAPI( lifespan=lifespan ) add_pagination(api) + +router = APIRouter() diff --git a/server/app/api/__init__.py b/server/app/api/__init__.py new file mode 100644 index 00000000..a030087a --- /dev/null +++ b/server/app/api/__init__.py @@ -0,0 +1,15 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""Non-domain v1 API endpoints.""" diff --git a/server/app/api/demo_controller.py b/server/app/api/demo_controller.py new file mode 100644 index 00000000..e8001d59 --- /dev/null +++ b/server/app/api/demo_controller.py @@ -0,0 +1,27 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""Demo endpoint - uses v1 auth.""" + +from fastapi import APIRouter, Depends +from fastapi_babel import _ + +from app.shared.auth import auth_optional + +router = APIRouter(tags=["demo"]) + + +@router.get("/demo") +def get(user=Depends(auth_optional)): + return {"message": user.id if user else _("no auth"), "content": _("hello")} diff --git a/server/app/controller/redirect_controller.py b/server/app/api/redirect_controller.py similarity index 91% rename from server/app/controller/redirect_controller.py rename to server/app/api/redirect_controller.py index f11e740a..7655ed5c 100644 --- a/server/app/controller/redirect_controller.py +++ b/server/app/api/redirect_controller.py @@ -1,16 +1,18 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""Redirect controller - H11 XSS fix with json.dumps + encodeURIComponent.""" import json @@ -24,6 +26,7 @@ router = APIRouter(tags=["Redirect"]) def redirect_callback(code: str, request: Request): cookies = request.cookies cookies_json = json.dumps(cookies) + safe_code = json.dumps(code) html_content = f""" @@ -72,10 +75,10 @@ def redirect_callback(code: str, request: Request): diff --git a/server/app/component/auth.py b/server/app/component/auth.py deleted file mode 100644 index 9a91792e..00000000 --- a/server/app/component/auth.py +++ /dev/null @@ -1,107 +0,0 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= - -from datetime import datetime, timedelta - -import jwt -from fastapi import Depends, Header -from fastapi.security import OAuth2PasswordBearer -from fastapi_babel import _ -from jwt.exceptions import InvalidTokenError -from sqlmodel import Session, select - -from app.component import code -from app.component.database import session -from app.component.environment import env, env_not_empty -from app.exception.exception import ( - NoPermissionException, - TokenException, -) -from app.model.mcp.proxy import ApiKey -from app.model.user.key import Key -from app.model.user.user import User - - -class Auth: - SECRET_KEY = env_not_empty("secret_key") - - def __init__(self, id: int, expired_at: datetime): - self.id = id - self.expired_at = expired_at - self._user: User | None = None - - @property - def user(self): - if self._user is None: - raise NoPermissionException("未查询到登录用户") - return self._user - - @classmethod - def decode_token(cls, token: str): - try: - payload = jwt.decode(token, Auth.SECRET_KEY, algorithms=["HS256"]) - id = payload["id"] - if payload["exp"] < int(datetime.now().timestamp()): - raise TokenException(code.token_expired, _("Validate credentials expired")) - except InvalidTokenError: - raise TokenException(code.token_invalid, _("Could not validate credentials")) - return Auth(id, payload["exp"]) - - @classmethod - def create_access_token(cls, user_id: int, expires_delta: timedelta | None = None): - to_encode: dict = {"id": user_id} - if expires_delta: - expire = datetime.now() + expires_delta - else: - expire = datetime.now() + timedelta(days=30) - to_encode.update({"exp": expire}) - encoded_jwt = jwt.encode(to_encode, Auth.SECRET_KEY, algorithm="HS256") - return encoded_jwt - - -oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{env('url_prefix', '')}/dev_login", auto_error=False) - - -async def auth( - token: str | None = Depends(oauth2_scheme), - session: Session = Depends(session), -) -> Auth | None: - if token is None: - return None - try: - model = Auth.decode_token(token) - user = session.get(User, model.id) - model._user = user - return model - except Exception: - return None - - -async def auth_must( - token: str | None = Depends(oauth2_scheme), - session: Session = Depends(session), -) -> Auth: - if token is None: - raise TokenException(code.token_invalid, _("Authentication required")) - model = Auth.decode_token(token) - user = session.get(User, model.id) - model._user = user - return model - - -async def key_must(headers: ApiKey = Header(), session: Session = Depends(session)): - model = session.exec(select(Key).where(Key.value == headers.api_key)).one_or_none() - if model is None: - raise TokenException(code.token_invalid, _(f"Could not validate key credentials: {headers.api_key}")) - return model diff --git a/server/app/component/permission.py b/server/app/component/permission.py deleted file mode 100644 index 001095fc..00000000 --- a/server/app/component/permission.py +++ /dev/null @@ -1,89 +0,0 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= - -from fastapi_babel import _ - -""" -权限定义: -当存在子权限的时候,父权限则不生效,应该全部放至子权限中定义处理 -""" - - -def permissions(): - return [ - { - "name": _("User"), - "description": _("User manager"), - "children": [ - { - "identity": "user:view", - "name": _("User Manage"), - "description": _("View users"), - }, - { - "identity": "user:edit", - "name": _("User Edit"), - "description": _("Manage users"), # 修改用户信息,邀请用户(限本组织下) - }, - ], - }, - { - "name": _("Admin"), - "description": _("Admin manager"), - "children": [ - { - "identity": "admin:view", - "name": _("Admin View"), - "description": _("View admins"), # 修改项目,工作区,角色,用户 - }, - { - "identity": "admin:edit", - "name": _("Admin Edit"), - "description": _("Edit admins"), - }, - ], - }, - { - "name": _("Role"), - "description": _("Role manager"), - "children": [ - { - "identity": "role:view", - "name": _("Role View"), - "description": _("View roles"), # 修改项目和工作区中的角色,创建新的角色 - }, - { - "identity": "role:edit", - "name": _("Role Edit"), - "description": _("Edit roles"), # 修改角色 - }, - ], - }, - { - "name": _("Mcp"), - "description": _("Mcp manager"), - "children": [ - { - "identity": "mcp:edit", - "name": _("Mcp Edit"), - "description": _("Edit mcp service"), - }, - { - "identity": "mcp-category:edit", - "name": _("Mcp Category Edit"), - "description": _("Edit mcp category"), - }, - ], - }, - ] diff --git a/server/app/component/service/trigger/__init__.py b/server/app/component/service/trigger/__init__.py deleted file mode 100644 index f8bbaca2..00000000 --- a/server/app/component/service/trigger/__init__.py +++ /dev/null @@ -1,54 +0,0 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= - -""" -Trigger Service Package - -Contains services for managing triggers including: -- TriggerService: Main service for trigger operations -- TriggerScheduleService: Service for scheduled trigger operations -- App Handlers: Handlers for different trigger types (Slack, Webhook, Schedule) -""" - -from app.service.trigger.trigger_service import TriggerService, get_trigger_service -from app.service.trigger.trigger_schedule_service import TriggerScheduleService -from app.service.trigger.app_handler_service import ( - BaseAppHandler, - SlackAppHandler, - DefaultWebhookHandler, - ScheduleAppHandler, - AppHandlerResult, - get_app_handler, - get_schedule_handler, - register_app_handler, - get_supported_trigger_types, -) - -__all__ = [ - # Services - "TriggerService", - "get_trigger_service", - "TriggerScheduleService", - # Handlers - "BaseAppHandler", - "SlackAppHandler", - "DefaultWebhookHandler", - "ScheduleAppHandler", - "AppHandlerResult", - # Handler functions - "get_app_handler", - "get_schedule_handler", - "register_app_handler", - "get_supported_trigger_types", -] \ No newline at end of file diff --git a/server/app/component/service/trigger/trigger_service.py b/server/app/component/service/trigger/trigger_service.py deleted file mode 100644 index 8a6559ff..00000000 --- a/server/app/component/service/trigger/trigger_service.py +++ /dev/null @@ -1,391 +0,0 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= - -from datetime import datetime, timedelta, timezone -from typing import Optional, List, Dict, Any -from sqlmodel import select, and_, or_ -from uuid import uuid4 -import logging - -from app.model.trigger.trigger import Trigger -from app.model.trigger.trigger_execution import TriggerExecution -from app.type.trigger_types import TriggerType, TriggerStatus, ExecutionType, ExecutionStatus -from app.component.database import session_make -from app.service.trigger.trigger_schedule_service import TriggerScheduleService -from app.component.trigger_utils import SCHEDULED_FETCH_BATCH_SIZE, check_rate_limits -from app.model.trigger.app_configs import ScheduleTriggerConfig, WebhookTriggerConfig -from app.model.trigger.app_configs.base_config import BaseTriggerConfig - - - -class TriggerService: - """Service for managing trigger operations and scheduling.""" - - def __init__(self, session=None): - self.session = session or session_make() - self.schedule_service = TriggerScheduleService(self.session) - - def create_execution( - self, - trigger: Trigger, - execution_type: ExecutionType, - input_data: Optional[Dict[str, Any]] = None - ) -> TriggerExecution: - """Create a new trigger execution.""" - execution_id = str(uuid4()) - - execution = TriggerExecution( - trigger_id=trigger.id, - execution_id=execution_id, - execution_type=execution_type, - status=ExecutionStatus.pending, - input_data=input_data or {}, - started_at=datetime.now(timezone.utc) - ) - - self.session.add(execution) - self.session.commit() - self.session.refresh(execution) - - # Update trigger statistics - trigger.last_executed_at = datetime.now(timezone.utc) - trigger.last_execution_status = "pending" - self.session.add(trigger) - self.session.commit() - - logger.info("Execution created", extra={ - "trigger_id": trigger.id, - "execution_id": execution_id, - "execution_type": execution_type.value - }) - - return execution - - def update_execution_status( - self, - execution: TriggerExecution, - status: ExecutionStatus, - output_data: Optional[Dict[str, Any]] = None, - error_message: Optional[str] = None, - tokens_used: Optional[int] = None, - tools_executed: Optional[Dict[str, Any]] = None - ) -> TriggerExecution: - """Update execution status and metadata.""" - execution.status = status - - # Set completed_at and duration for terminal statuses - if status in [ExecutionStatus.completed, ExecutionStatus.failed, ExecutionStatus.cancelled, ExecutionStatus.missed]: - execution.completed_at = datetime.now(timezone.utc) - if execution.started_at: - # Ensure started_at is timezone-aware for subtraction - started_at = execution.started_at - if started_at.tzinfo is None: - started_at = started_at.replace(tzinfo=timezone.utc) - execution.duration_seconds = (execution.completed_at - started_at).total_seconds() - - if output_data: - execution.output_data = output_data - - if error_message: - execution.error_message = error_message - - if tokens_used: - execution.tokens_used = tokens_used - - if tools_executed: - execution.tools_executed = tools_executed - - self.session.add(execution) - self.session.commit() - - # Update trigger status and handle auto-disable logic - trigger = self.session.get(Trigger, execution.trigger_id) - if trigger: - if status == ExecutionStatus.failed: - trigger.last_execution_status = "failed" - trigger.consecutive_failures += 1 - - # Check for auto-disable based on max_failure_count in config - self._check_auto_disable(trigger) - - elif status == ExecutionStatus.completed: - trigger.last_execution_status = "completed" - # Reset consecutive failures on success - trigger.consecutive_failures = 0 - elif status == ExecutionStatus.cancelled: - trigger.last_execution_status = "cancelled" - elif status == ExecutionStatus.missed: - trigger.last_execution_status = "missed" - - self.session.add(trigger) - self.session.commit() - - logger.info("Execution status updated", extra={ - "execution_id": execution.execution_id, - "status": status.name, - "duration": execution.duration_seconds - }) - - return execution - - def _check_auto_disable(self, trigger: Trigger) -> bool: - """ - Check if trigger should be auto-disabled based on consecutive failures. - - Args: - trigger: The trigger to check - - Returns: - True if trigger was auto-disabled, False otherwise - """ - if not trigger.config: - return False - - try: - # Get the appropriate config class based on trigger type - config: BaseTriggerConfig - if trigger.trigger_type == TriggerType.schedule: - config = ScheduleTriggerConfig(**trigger.config) - elif trigger.trigger_type == TriggerType.webhook: - config = WebhookTriggerConfig(**trigger.config) - else: - # For other trigger types, use base config - config = BaseTriggerConfig(**trigger.config) - - # Check if auto-disable should happen - if config.should_auto_disable(trigger.consecutive_failures): - trigger.status = TriggerStatus.inactive - trigger.auto_disabled_at = datetime.now(timezone.utc) - - logger.warning( - "Trigger auto-disabled due to max failures", - extra={ - "trigger_id": trigger.id, - "trigger_name": trigger.name, - "consecutive_failures": trigger.consecutive_failures, - "max_failure_count": config.max_failure_count - } - ) - return True - - except Exception as e: - logger.error( - "Failed to check auto-disable for trigger", - extra={ - "trigger_id": trigger.id, - "error": str(e) - } - ) - - return False - - def get_pending_executions(self) -> List[TriggerExecution]: - """Get all pending executions that need to be processed.""" - executions = self.session.exec( - select(TriggerExecution).where( - TriggerExecution.status == ExecutionStatus.pending - ).order_by(TriggerExecution.created_at) - ).all() - - return list(executions) - - def get_failed_executions_for_retry(self) -> List[TriggerExecution]: - """Get failed executions that can be retried.""" - executions = self.session.exec( - select(TriggerExecution).where( - and_( - TriggerExecution.status == ExecutionStatus.failed, - TriggerExecution.attempts < TriggerExecution.max_retries - ) - ).order_by(TriggerExecution.created_at) - ).all() - - return list(executions) - - def get_due_scheduled_triggers(self, limit: Optional[int] = None) -> List[Trigger]: - """ - Fetch scheduled triggers that are due for execution. - - Args: - limit: Maximum number of triggers to fetch (defaults to SCHEDULED_FETCH_BATCH_SIZE) - - Returns: - List of triggers that are due for execution - """ - current_time = datetime.now(timezone.utc) - limit = limit or SCHEDULED_FETCH_BATCH_SIZE - - # Query triggers that: - # 1. Are scheduled type - # 2. Are active - # 3. Have a cron expression - # 4. next_run_at is null (never run) or next_run_at <= now - triggers = self.session.exec( - select(Trigger) - .where( - and_( - Trigger.trigger_type == TriggerType.schedule, - Trigger.status == TriggerStatus.active, - Trigger.custom_cron_expression.is_not(None), - or_( - Trigger.next_run_at.is_(None), - Trigger.next_run_at <= current_time - ) - ) - ) - .limit(limit) - ).all() - - return list(triggers) - - def execute_scheduled_triggers(self) -> int: - """ - Execute all due scheduled triggers. - Uses TriggerScheduleService for the actual execution logic. - """ - due_triggers = self.get_due_scheduled_triggers() - - if not due_triggers: - return 0 - - dispatched_count, rate_limited_count = self.schedule_service.process_schedules(due_triggers) - - logger.info( - "Scheduled triggers execution completed", - extra={ - "dispatched": dispatched_count, - "rate_limited": rate_limited_count - } - ) - - return dispatched_count - - def process_slack_trigger( - self, - trigger: Trigger, - slack_data: Dict[str, Any] - ) -> Optional[TriggerExecution]: - """Process a Slack trigger event.""" - if trigger.trigger_type != TriggerType.slack_trigger: - raise ValueError("Trigger is not a Slack trigger") - - if trigger.status != TriggerStatus.active: - logger.warning("Slack trigger is not active", extra={ - "trigger_id": trigger.id - }) - return None - - if not check_rate_limits(self.session, trigger): - logger.warning("Slack trigger execution skipped due to rate limits", extra={ - "trigger_id": trigger.id - }) - return None - - try: - execution = self.create_execution( - trigger=trigger, - execution_type=ExecutionType.slack, - input_data=slack_data - ) - - # TODO: Queue the actual task execution - - logger.info("Slack trigger executed", extra={ - "trigger_id": trigger.id, - "execution_id": execution.execution_id - }) - - return execution - - except Exception as e: - logger.error("Slack trigger execution failed", extra={ - "trigger_id": trigger.id, - "error": str(e) - }, exc_info=True) - return None - - def cleanup_old_executions(self, days_to_keep: int = 30) -> int: - """Clean up old execution records.""" - cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep) - - old_executions = self.session.exec( - select(TriggerExecution).where( - and_( - TriggerExecution.created_at < cutoff_date, - TriggerExecution.status.in_([ - ExecutionStatus.completed, - ExecutionStatus.failed, - ExecutionStatus.cancelled - ]) - ) - ) - ).all() - - count = len(old_executions) - - for execution in old_executions: - self.session.delete(execution) - - self.session.commit() - - logger.info("Old executions cleaned up", extra={ - "count": count, - "days_to_keep": days_to_keep - }) - - return count - - def get_trigger_statistics(self, trigger_id: int) -> Dict[str, Any]: - """Get statistics for a specific trigger.""" - trigger = self.session.get(Trigger, trigger_id) - if not trigger: - raise ValueError("Trigger not found") - - # Get execution counts by status - executions = self.session.exec( - select(TriggerExecution).where( - TriggerExecution.trigger_id == trigger_id - ) - ).all() - - stats = { - "trigger_id": trigger_id, - "name": trigger.name, - "trigger_type": trigger.trigger_type.value, - "status": trigger.status.name, - "total_executions": len(executions), - "successful_executions": len([e for e in executions if e.status == ExecutionStatus.completed]), - "failed_executions": len([e for e in executions if e.status == ExecutionStatus.failed]), - "pending_executions": len([e for e in executions if e.status == ExecutionStatus.pending]), - "cancelled_executions": len([e for e in executions if e.status == ExecutionStatus.cancelled]), - "last_executed_at": trigger.last_executed_at.isoformat() if trigger.last_executed_at else None, - "created_at": trigger.created_at.isoformat() if trigger.created_at else None - } - - # Calculate average execution time for completed executions - completed_executions = [e for e in executions if e.status == ExecutionStatus.completed and e.duration_seconds] - if completed_executions: - avg_duration = sum(e.duration_seconds for e in completed_executions) / len(completed_executions) - stats["average_execution_time_seconds"] = round(avg_duration, 2) - - # Calculate total tokens used - total_tokens = sum(e.tokens_used for e in executions if e.tokens_used) - if total_tokens: - stats["total_tokens_used"] = total_tokens - - return stats - -def get_trigger_service(session=None) -> TriggerService: - """Factory function to create a TriggerService instance with a fresh session.""" - return TriggerService(session) \ No newline at end of file diff --git a/server/app/component/stack_auth.py b/server/app/component/stack_auth.py deleted file mode 100644 index 60d87caf..00000000 --- a/server/app/component/stack_auth.py +++ /dev/null @@ -1,74 +0,0 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= - -import asyncio - -import httpx -import jwt - -from app.component import code -from app.component.environment import env_not_empty -from app.exception.exception import UserException - - -class StackAuth: - _signing_key_cache = {} - - @staticmethod - async def user_id(token: str): - header = jwt.get_unverified_header(token) - kid = header.get("kid") - if not kid: - raise jwt.InvalidTokenError("Token is missing 'kid' in header") - - signed = await StackAuth.stack_signing_key(kid) - payload = jwt.decode( - token, - signed.key, - algorithms=["ES256"], - audience=env_not_empty("stack_project_id"), - # issuer="https://access-token.jwt-signature.stack-auth.com", - ) - return payload["sub"] - - @staticmethod - async def user_info(token: str): - headers = { - "X-Stack-Access-Type": "server", - "X-Stack-Project-Id": env_not_empty("stack_project_id"), - "X-Stack-Secret-Server-Key": env_not_empty("stack_secret_server_key"), - "X-Stack-Access-Token": token, - } - url = "https://api.stack-auth.com/api/v1/users/me" - async with httpx.AsyncClient() as client: - response = await client.get(url, headers=headers) - return response.json() - - @staticmethod - async def stack_signing_key(kid: str): - if kid in StackAuth._signing_key_cache: - return StackAuth._signing_key_cache[kid] - - jwks_endpoint = ( - f"https://api.stack-auth.com/api/v1/projects/{env_not_empty('stack_project_id')}/.well-known/jwks.json" - ) - loop = asyncio.get_running_loop() - jwks_client = jwt.PyJWKClient(jwks_endpoint) - - try: - signing_key = await loop.run_in_executor(None, jwks_client.get_signing_key, kid) - StackAuth._signing_key_cache[kid] = signing_key - return signing_key - except jwt.exceptions.PyJWKClientError as e: - raise UserException(code.token_invalid, str(e)) diff --git a/server/app/component/time_friendly.py b/server/app/component/time_friendly.py deleted file mode 100644 index b88d41a4..00000000 --- a/server/app/component/time_friendly.py +++ /dev/null @@ -1,38 +0,0 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= - -from datetime import datetime, timedelta - -import arrow - - -def to_date(time: str, format: str | None = None): - try: - if format: - return arrow.get(time, format).date() - else: - return arrow.get(time).date() - except Exception: - return None - - -def monday_start_time() -> datetime: - # 获取当前时间 - now = datetime.now() - # 计算今天是本周的第几天(星期一是0,星期天是6) - weekday = now.weekday() - # 计算本周一的日期 - monday = now - timedelta(days=weekday) - # 设置时间为 0 点 - return monday.replace(hour=0, minute=0, second=0, microsecond=0) diff --git a/server/app/controller/chat/history_controller.py b/server/app/controller/chat/history_controller.py deleted file mode 100644 index 84254a86..00000000 --- a/server/app/controller/chat/history_controller.py +++ /dev/null @@ -1,454 +0,0 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= - -import logging -from collections import defaultdict -from typing import Optional - -from fastapi import APIRouter, Depends, HTTPException, Query, Response -from fastapi_pagination import Page -from fastapi_pagination.ext.sqlmodel import paginate -from app.model.trigger.trigger import Trigger -from app.model.trigger.trigger_execution import TriggerExecution -from sqlmodel import Session, case, desc, select, func, delete - -from app.component.auth import Auth, auth_must -from app.component.database import session -from app.model.chat.chat_history import ( - ChatHistory, - ChatHistoryIn, - ChatHistoryOut, - ChatHistoryUpdate, - ChatStatus, -) -from app.model.chat.chat_history_grouped import ( - GroupedHistoryResponse, - ProjectGroup, -) - -logger = logging.getLogger("server_chat_history") - -router = APIRouter(prefix="/chat", tags=["Chat History"]) - -def is_real_task(history: ChatHistory) -> bool: - """ - Check if a task is a real task vs a placeholder/trigger-created task. - Excludes placeholder tasks created during trigger creation. - """ - # Has actual token usage - if history.tokens and history.tokens > 0: - return True - - # Has real model configuration (not placeholder "none" values) - if (history.model_platform and history.model_platform != "none" and - history.model_type and history.model_type != "none" and - history.installed_mcp and history.installed_mcp != "none"): - return True - - # Check if question starts with trigger placeholder prefix - if history.question and history.question.startswith("Project created via trigger:"): - return False - - # Default to real task if no placeholder indicators - return True - -@router.post("/history", name="save chat history", response_model=ChatHistoryOut) -def create_chat_history(data: ChatHistoryIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)): - """Save new chat history.""" - user_id = auth.user.id - - try: - data.user_id = user_id - chat_history = ChatHistory(**data.model_dump()) - session.add(chat_history) - session.commit() - session.refresh(chat_history) - logger.info( - "Chat history created", extra={"user_id": user_id, "history_id": chat_history.id, "task_id": data.task_id} - ) - return chat_history - except Exception as e: - session.rollback() - logger.error( - "Chat history creation failed", - extra={"user_id": user_id, "task_id": data.task_id, "error": str(e)}, - exc_info=True, - ) - raise HTTPException(status_code=500, detail="Internal server error") - - -@router.get("/histories", name="get chat history") -def list_chat_history(session: Session = Depends(session), auth: Auth = Depends(auth_must)) -> Page[ChatHistoryOut]: - """List chat histories for current user.""" - user_id = auth.user.id - - # Order by created_at descending, but fallback to id descending for old records without timestamps - # This ensures newer records with timestamps come first, followed by old records ordered by id - stmt = ( - select(ChatHistory) - .where(ChatHistory.user_id == user_id) - .order_by( - desc(case((ChatHistory.created_at.is_(None), 0), else_=1)), # Non-null created_at first - desc(ChatHistory.created_at), # Then by created_at descending - desc(ChatHistory.id), # Finally by id descending for records with same/null created_at - ) - ) - - result = paginate(session, stmt) - total = result.total if hasattr(result, "total") else 0 - logger.debug("Chat histories listed", extra={"user_id": user_id, "total": total}) - return result - - -@router.get("/histories/grouped", name="get grouped chat history") -def list_grouped_chat_history( - include_tasks: Optional[bool] = Query(True, description="Whether to include individual tasks in groups"), - session: Session = Depends(session), - auth: Auth = Depends(auth_must) -) -> GroupedHistoryResponse: - """List chat histories grouped by project_id for current user.""" - user_id = auth.user.id - - # Get all histories for the user, ordered by creation time - stmt = ( - select(ChatHistory) - .where(ChatHistory.user_id == user_id) - .order_by( - desc(case((ChatHistory.created_at.is_(None), 0), else_=1)), # Non-null created_at first - desc(ChatHistory.created_at), # Then by created_at descending - desc(ChatHistory.id) # Finally by id descending for records with same/null created_at - ) - ) - - histories = session.exec(stmt).all() - - # Get trigger counts per project - trigger_count_stmt = ( - select(Trigger.project_id, func.count(Trigger.id).label('count')) - .where(Trigger.user_id == str(user_id)) - .group_by(Trigger.project_id) - ) - trigger_counts = session.exec(trigger_count_stmt).all() - trigger_count_map = {project_id: count for project_id, count in trigger_counts} - - # Group histories by project_id - project_map = defaultdict(lambda: { - 'project_id': '', - 'project_name': None, - 'total_tokens': 0, - 'task_count': 0, - 'latest_task_date': '', - 'last_prompt': None, - 'tasks': [], - 'total_completed_tasks': 0, - 'total_ongoing_tasks': 0, - 'average_tokens_per_task': 0, - 'total_triggers': 0 - }) - - for history in histories: - # Use project_id if available, fallback to task_id - project_id = history.project_id if history.project_id else history.task_id - project_data = project_map[project_id] - - # Initialize project data - if not project_data['project_id']: - project_data['project_id'] = project_id - project_data['project_name'] = history.project_name or f"Project {project_id}" - project_data['latest_task_date'] = history.created_at.isoformat() if history.created_at else '' - project_data['last_prompt'] = history.question # Set the most recent question - - # Convert to ChatHistoryOut format - history_out = ChatHistoryOut(**history.model_dump()) - - # Add task to project if requested (only real tasks) - if include_tasks and is_real_task(history): - project_data['tasks'].append(history_out) - - # Update project statistics (only for real tasks) - if is_real_task(history): - project_data['task_count'] += 1 - project_data['total_tokens'] += history.tokens or 0 - - if history.status == ChatStatus.done: - project_data['total_completed_tasks'] += 1 - elif history.status == ChatStatus.ongoing: - project_data['total_ongoing_tasks'] += 1 - - # Update latest task date and last prompt - if history.created_at: - task_date = history.created_at.isoformat() - if not project_data['latest_task_date'] or task_date > project_data['latest_task_date']: - project_data['latest_task_date'] = task_date - project_data['last_prompt'] = history.question - - # Convert to ProjectGroup objects and sort - projects = [] - for project_data in project_map.values(): - # Sort tasks within each project by creation date (oldest first) - if include_tasks: - project_data['tasks'].sort(key=lambda x: (x.created_at is None, x.created_at or ''), reverse=False) - - # Set trigger count from trigger_count_map - project_id = project_data['project_id'] - project_data['total_triggers'] = trigger_count_map.get(project_id, 0) - - project_group = ProjectGroup(**project_data) - projects.append(project_group) - - # Sort projects by latest task date (newest first) - projects.sort(key=lambda x: x.latest_task_date, reverse=True) - - response = GroupedHistoryResponse(projects=projects) - - logger.debug("Grouped chat histories listed", extra={ - "user_id": user_id, - "total_projects": response.total_projects, - "total_tasks": response.total_tasks, - "include_tasks": include_tasks - }) - - return response - - -@router.get("/histories/grouped/{project_id}", name="get single grouped project") -def get_grouped_project( - project_id: str, - include_tasks: Optional[bool] = Query(True, description="Whether to include individual tasks in the project"), - session: Session = Depends(session), - auth: Auth = Depends(auth_must) -) -> ProjectGroup: - """Get a single project group by project_id for current user.""" - user_id = auth.user.id - - # Get all histories for the specific project - stmt = ( - select(ChatHistory) - .where(ChatHistory.user_id == user_id) - .where(ChatHistory.project_id == project_id) - .order_by( - desc(case((ChatHistory.created_at.is_(None), 0), else_=1)), - desc(ChatHistory.created_at), - desc(ChatHistory.id) - ) - ) - - histories = session.exec(stmt).all() - - if not histories: - raise HTTPException(status_code=404, detail="Project not found") - - # Get trigger count for this project - trigger_count_stmt = ( - select(func.count(Trigger.id)) - .where(Trigger.user_id == str(user_id)) - .where(Trigger.project_id == project_id) - ) - trigger_count = session.exec(trigger_count_stmt).first() or 0 - - # Build project data - project_data = { - 'project_id': project_id, - 'project_name': None, - 'total_tokens': 0, - 'task_count': 0, - 'latest_task_date': '', - 'last_prompt': None, - 'tasks': [], - 'total_completed_tasks': 0, - 'total_ongoing_tasks': 0, - 'average_tokens_per_task': 0, - 'total_triggers': trigger_count - } - - for history in histories: - # Initialize project name from first history - if not project_data['project_name']: - project_data['project_name'] = history.project_name or f"Project {project_id}" - project_data['latest_task_date'] = history.created_at.isoformat() if history.created_at else '' - project_data['last_prompt'] = history.question - - # Convert to ChatHistoryOut format - history_out = ChatHistoryOut(**history.model_dump()) - - # Add task to project if requested (only real tasks) - if include_tasks and is_real_task(history): - project_data['tasks'].append(history_out) - - # Update project statistics (only for real tasks) - if is_real_task(history): - project_data['task_count'] += 1 - project_data['total_tokens'] += history.tokens or 0 - - if history.status == ChatStatus.done: - project_data['total_completed_tasks'] += 1 - elif history.status == ChatStatus.ongoing: - project_data['total_ongoing_tasks'] += 1 - - # Update latest task date and last prompt - if history.created_at: - task_date = history.created_at.isoformat() - if not project_data['latest_task_date'] or task_date > project_data['latest_task_date']: - project_data['latest_task_date'] = task_date - project_data['last_prompt'] = history.question - - # Sort tasks within the project by creation date (oldest first) - if include_tasks: - project_data['tasks'].sort(key=lambda x: (x.created_at is None, x.created_at or ''), reverse=False) - - project_group = ProjectGroup(**project_data) - - logger.debug("Single grouped project retrieved", extra={ - "user_id": user_id, - "project_id": project_id, - "task_count": project_group.task_count, - "include_tasks": include_tasks - }) - - return project_group - - -@router.delete("/history/{history_id}", name="delete chat history") -def delete_chat_history(history_id: str, session: Session = Depends(session), auth: Auth = Depends(auth_must)): - """Delete chat history.""" - user_id = auth.user.id - history = session.exec(select(ChatHistory).where(ChatHistory.id == history_id)).first() - - if not history: - logger.warning("Chat history not found for deletion", extra={"user_id": user_id, "history_id": history_id}) - raise HTTPException(status_code=404, detail="Chat History not found") - - if history.user_id != user_id: - logger.warning( - "Unauthorized deletion attempt", - extra={"user_id": user_id, "history_id": history_id, "owner_id": history.user_id}, - ) - raise HTTPException(status_code=403, detail="You are not allowed to delete this chat history") - - try: - # Determine the project this history belongs to - project_id = history.project_id if history.project_id else history.task_id - - # Check if this is the last history in the project - sibling_count = ( - session.exec( - select(func.count(ChatHistory.id)).where( - ChatHistory.id != history_id, - ChatHistory.project_id == project_id if history.project_id else ChatHistory.task_id == project_id, - ) - ).first() - or 0 - ) - - session.delete(history) - - if sibling_count == 0: - # Last history in the project — delete all related triggers - triggers = session.exec(select(Trigger).where(Trigger.project_id == project_id)).all() - for trigger in triggers: - session.exec(delete(TriggerExecution).where(TriggerExecution.trigger_id == trigger.id)) - session.delete(trigger) - logger.info( - "Deleted triggers for removed project", extra={"project_id": project_id, "trigger_count": len(triggers)} - ) - - session.commit() - logger.info("Chat history deleted", extra={"user_id": user_id, "history_id": history_id}) - return Response(status_code=204) - except Exception as e: - session.rollback() - logger.error( - "Chat history deletion failed", - extra={"user_id": user_id, "history_id": history_id, "error": str(e)}, - exc_info=True, - ) - raise HTTPException(status_code=500, detail="Internal server error") - - -@router.put("/history/{history_id}", name="update chat history", response_model=ChatHistoryOut) -def update_chat_history( - history_id: int, data: ChatHistoryUpdate, session: Session = Depends(session), auth: Auth = Depends(auth_must) -): - """Update chat history.""" - user_id = auth.user.id - history = session.exec(select(ChatHistory).where(ChatHistory.id == history_id)).first() - - if not history: - logger.warning("Chat history not found for update", extra={"user_id": user_id, "history_id": history_id}) - raise HTTPException(status_code=404, detail="Chat History not found") - - if history.user_id != user_id: - logger.warning( - "Unauthorized update attempt", - extra={"user_id": user_id, "history_id": history_id, "owner_id": history.user_id}, - ) - raise HTTPException(status_code=403, detail="You are not allowed to update this chat history") - - try: - update_data = data.model_dump(exclude_unset=True) - history.update_fields(update_data) - history.save(session) - session.refresh(history) - logger.info( - "Chat history updated", - extra={"user_id": user_id, "history_id": history_id, "fields_updated": list(update_data.keys())}, - ) - return history - except Exception as e: - logger.error( - "Chat history update failed", - extra={"user_id": user_id, "history_id": history_id, "error": str(e)}, - exc_info=True, - ) - raise HTTPException(status_code=500, detail="Internal server error") - - -@router.put("/project/{project_id}/name", name="update project name") -def update_project_name( - project_id: str, new_name: str, session: Session = Depends(session), auth: Auth = Depends(auth_must) -): - """Update project name for all tasks in a project.""" - user_id = auth.user.id - - # Get all histories for this project - stmt = select(ChatHistory).where(ChatHistory.project_id == project_id).where(ChatHistory.user_id == user_id) - - histories = session.exec(stmt).all() - - if not histories: - logger.warning("No histories found for project", extra={"user_id": user_id, "project_id": project_id}) - raise HTTPException(status_code=404, detail="Project not found or access denied") - - try: - # Update all histories for this project - for history in histories: - history.project_name = new_name - session.add(history) - - session.commit() - - logger.info( - "Project name updated", - extra={"user_id": user_id, "project_id": project_id, "new_name": new_name, "updated_count": len(histories)}, - ) - - return Response(status_code=200) - except Exception as e: - session.rollback() - logger.error( - "Project name update failed", - extra={"user_id": user_id, "project_id": project_id, "error": str(e)}, - exc_info=True, - ) - raise HTTPException(status_code=500, detail="Internal server error") diff --git a/server/app/controller/chat/share_controller.py b/server/app/controller/chat/share_controller.py deleted file mode 100644 index 4aee4489..00000000 --- a/server/app/controller/chat/share_controller.py +++ /dev/null @@ -1,136 +0,0 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= - -import asyncio -import json -import logging - -from fastapi import APIRouter, Depends, HTTPException -from itsdangerous import BadTimeSignature, SignatureExpired -from sqlmodel import Session, asc, select -from starlette.responses import StreamingResponse - -from app.component.auth import Auth, auth_must -from app.component.database import session -from app.model.chat.chat_history import ChatHistory -from app.model.chat.chat_share import ( - ChatHistoryShareOut, - ChatShare, - ChatShareIn, -) -from app.model.chat.chat_step import ChatStep - -logger = logging.getLogger("server_chat_share") - -router = APIRouter(prefix="/chat", tags=["Chat Share"]) - - -@router.get("/share/info/{token}", name="Get shared chat info", response_model=ChatHistoryShareOut) -def get_share_info(token: str, session: Session = Depends(session)): - """ - Get shared chat history info by token, excluding sensitive data. - """ - try: - task_id = ChatShare.verify_token(token, False) - except SignatureExpired: - logger.warning("Shared chat access failed: token expired", extra={"token_prefix": token[:10]}) - raise HTTPException(status_code=400, detail="Share link is invalid or has expired.") - except BadTimeSignature: - logger.warning("Shared chat access failed: invalid token", extra={"token_prefix": token[:10]}) - raise HTTPException(status_code=400, detail="Share link is invalid or has expired.") - - stmt = select(ChatHistory).where(ChatHistory.task_id == task_id) - history = session.exec(stmt).one_or_none() - - if not history: - logger.warning("Shared chat not found", extra={"task_id": task_id}) - raise HTTPException(status_code=404, detail="Chat history not found.") - - logger.info("Shared chat info accessed", extra={"task_id": task_id}) - return history - - -@router.get("/share/playback/{token}", name="Playback shared chat via SSE") -async def share_playback(token: str, session: Session = Depends(session), delay_time: float = 0): - """ - Playbacks the chat history via a sharing token (SSE). - delay_time: control sse interval, max 5 seconds - """ - if delay_time > 5: - logger.debug("Delay time capped", extra={"requested": delay_time, "capped": 5}) - delay_time = 5 - - try: - task_id = ChatShare.verify_token(token, False) - except SignatureExpired: - logger.warning("Shared chat playback failed: token expired", extra={"token_prefix": token[:10]}) - raise HTTPException(status_code=400, detail="Share link has expired.") - except BadTimeSignature: - logger.warning("Shared chat playback failed: invalid token", extra={"token_prefix": token[:10]}) - raise HTTPException(status_code=400, detail="Share link is invalid.") - - async def event_generator(): - try: - stmt = select(ChatStep).where(ChatStep.task_id == task_id).order_by(asc(ChatStep.id)) - steps = session.exec(stmt).all() - - if not steps: - logger.warning("No steps found for playback", extra={"task_id": task_id}) - yield f"data: {json.dumps({'error': 'No steps found for this task.'})}\n\n" - return - - logger.info( - "Shared chat playback started", - extra={"task_id": task_id, "step_count": len(steps), "delay_time": delay_time}, - ) - - for idx, step in enumerate(steps, start=1): - step_data = { - "id": step.id, - "task_id": step.task_id, - "step": step.step, - "data": step.data, - "created_at": step.created_at.isoformat() if step.created_at else None, - } - yield f"data: {json.dumps(step_data)}\n\n" - - if delay_time > 0 and step.step != "create_agent": - await asyncio.sleep(delay_time) - - logger.info("Shared chat playback completed", extra={"task_id": task_id, "step_count": len(steps)}) - except Exception as e: - logger.error("Shared chat playback error", extra={"task_id": task_id, "error": str(e)}, exc_info=True) - yield f"data: {json.dumps({'error': 'Playback error occurred.'})}\n\n" - - return StreamingResponse(event_generator(), media_type="text/event-stream") - - -@router.post("/share", name="Generate sharable link for a task(1 day expiration)") -def create_share_link(data: ChatShareIn, auth: Auth = Depends(auth_must)): - """Generate sharing token with 1-day expiration for task.""" - user_id = auth.user.id - try: - share_token = ChatShare.generate_token(data.task_id) - logger.info( - "Share link created", - extra={"user_id": user_id, "task_id": data.task_id, "token_prefix": share_token[:10]}, - ) - return {"share_token": share_token} - except Exception as e: - logger.error( - "Share link creation failed", - extra={"user_id": user_id, "task_id": data.task_id, "error": str(e)}, - exc_info=True, - ) - raise HTTPException(status_code=500, detail="Internal server error") diff --git a/server/app/controller/chat/snapshot_controller.py b/server/app/controller/chat/snapshot_controller.py deleted file mode 100644 index 3f52c7c5..00000000 --- a/server/app/controller/chat/snapshot_controller.py +++ /dev/null @@ -1,195 +0,0 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= - -import logging - -from fastapi import APIRouter, Depends, HTTPException, Response -from fastapi_babel import _ -from sqlmodel import Session, select - -from app.component.auth import Auth, auth_must -from app.component.database import session -from app.model.chat.chat_snpshot import ChatSnapshot, ChatSnapshotIn - -logger = logging.getLogger("server_chat_snapshot") - -router = APIRouter(prefix="/chat", tags=["Chat Snapshot Management"]) - - -@router.get("/snapshots", name="list chat snapshots", response_model=list[ChatSnapshot]) -async def list_chat_snapshots( - api_task_id: str | None = None, - camel_task_id: str | None = None, - browser_url: str | None = None, - session: Session = Depends(session), - auth: Auth = Depends(auth_must), -): - """List chat snapshots with optional filtering.""" - user_id = auth.user.id - query = select(ChatSnapshot).where(ChatSnapshot.user_id == user_id) - if api_task_id is not None: - query = query.where(ChatSnapshot.api_task_id == api_task_id) - if camel_task_id is not None: - query = query.where(ChatSnapshot.camel_task_id == camel_task_id) - if browser_url is not None: - query = query.where(ChatSnapshot.browser_url == browser_url) - - snapshots = session.exec(query).all() - logger.debug( - "Snapshots listed", - extra={"user_id": user_id, "api_task_id": api_task_id, "camel_task_id": camel_task_id, "count": len(snapshots)}, - ) - return snapshots - - -@router.get("/snapshots/{snapshot_id}", name="get chat snapshot", response_model=ChatSnapshot) -async def get_chat_snapshot(snapshot_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)): - """Get specific chat snapshot.""" - user_id = auth.user.id - snapshot = session.get(ChatSnapshot, snapshot_id) - - if not snapshot: - logger.warning("Snapshot not found", extra={"user_id": user_id, "snapshot_id": snapshot_id}) - raise HTTPException(status_code=404, detail=_("Chat snapshot not found")) - - if snapshot.user_id != user_id: - logger.warning( - "Unauthorized snapshot access", - extra={"user_id": user_id, "snapshot_id": snapshot_id, "owner_id": snapshot.user_id}, - ) - raise HTTPException(status_code=403, detail=_("You are not allowed to view this snapshot")) - - logger.debug( - "Snapshot retrieved", - extra={"user_id": user_id, "snapshot_id": snapshot_id, "api_task_id": snapshot.api_task_id}, - ) - return snapshot - - -@router.post("/snapshots", name="create chat snapshot", response_model=ChatSnapshot) -async def create_chat_snapshot( - snapshot: ChatSnapshotIn, auth: Auth = Depends(auth_must), session: Session = Depends(session) -): - """Create new chat snapshot from image.""" - user_id = auth.user.id - - try: - image_path = ChatSnapshotIn.save_image(user_id, snapshot.api_task_id, snapshot.image_base64) - chat_snapshot = ChatSnapshot( - user_id=user_id, - api_task_id=snapshot.api_task_id, - camel_task_id=snapshot.camel_task_id, - browser_url=snapshot.browser_url, - image_path=image_path, - ) - session.add(chat_snapshot) - session.commit() - session.refresh(chat_snapshot) - logger.info( - "Snapshot created", - extra={ - "user_id": user_id, - "snapshot_id": chat_snapshot.id, - "api_task_id": snapshot.api_task_id, - "image_path": image_path, - }, - ) - return chat_snapshot - except Exception as e: - session.rollback() - logger.error( - "Snapshot creation failed", - extra={"user_id": user_id, "api_task_id": snapshot.api_task_id, "error": str(e)}, - exc_info=True, - ) - raise HTTPException(status_code=500, detail="Internal server error") - - -@router.put("/snapshots/{snapshot_id}", name="update chat snapshot", response_model=ChatSnapshot) -async def update_chat_snapshot( - snapshot_id: int, - snapshot_update: ChatSnapshot, - session: Session = Depends(session), - auth: Auth = Depends(auth_must), -): - """Update chat snapshot.""" - user_id = auth.user.id - db_snapshot = session.get(ChatSnapshot, snapshot_id) - - if not db_snapshot: - logger.warning("Snapshot not found for update", extra={"user_id": user_id, "snapshot_id": snapshot_id}) - raise HTTPException(status_code=404, detail=_("Chat snapshot not found")) - - if db_snapshot.user_id != user_id: - logger.warning( - "Unauthorized snapshot update", - extra={"user_id": user_id, "snapshot_id": snapshot_id, "owner_id": db_snapshot.user_id}, - ) - raise HTTPException(status_code=403, detail=_("You are not allowed to update this snapshot")) - - try: - update_data = snapshot_update.dict(exclude_unset=True) - for key, value in update_data.items(): - setattr(db_snapshot, key, value) - session.add(db_snapshot) - session.commit() - session.refresh(db_snapshot) - logger.info( - "Snapshot updated", - extra={"user_id": user_id, "snapshot_id": snapshot_id, "fields_updated": list(update_data.keys())}, - ) - return db_snapshot - except Exception as e: - session.rollback() - logger.error( - "Snapshot update failed", - extra={"user_id": user_id, "snapshot_id": snapshot_id, "error": str(e)}, - exc_info=True, - ) - raise HTTPException(status_code=500, detail="Internal server error") - - -@router.delete("/snapshots/{snapshot_id}", name="delete chat snapshot") -async def delete_chat_snapshot(snapshot_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)): - """Delete chat snapshot.""" - user_id = auth.user.id - db_snapshot = session.get(ChatSnapshot, snapshot_id) - - if not db_snapshot: - logger.warning("Snapshot not found for deletion", extra={"user_id": user_id, "snapshot_id": snapshot_id}) - raise HTTPException(status_code=404, detail=_("Chat snapshot not found")) - - if db_snapshot.user_id != user_id: - logger.warning( - "Unauthorized snapshot deletion", - extra={"user_id": user_id, "snapshot_id": snapshot_id, "owner_id": db_snapshot.user_id}, - ) - raise HTTPException(status_code=403, detail=_("You are not allowed to delete this snapshot")) - - try: - session.delete(db_snapshot) - session.commit() - logger.info( - "Snapshot deleted", - extra={"user_id": user_id, "snapshot_id": snapshot_id, "image_path": db_snapshot.image_path}, - ) - return Response(status_code=204) - except Exception as e: - session.rollback() - logger.error( - "Snapshot deletion failed", - extra={"user_id": user_id, "snapshot_id": snapshot_id, "error": str(e)}, - exc_info=True, - ) - raise HTTPException(status_code=500, detail="Internal server error") diff --git a/server/app/controller/chat/step_controller.py b/server/app/controller/chat/step_controller.py deleted file mode 100644 index 60c1bda1..00000000 --- a/server/app/controller/chat/step_controller.py +++ /dev/null @@ -1,207 +0,0 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= - -import asyncio -import json -import logging - -from fastapi import APIRouter, Depends, HTTPException, Response -from fastapi.responses import StreamingResponse -from fastapi_babel import _ -from sqlalchemy.sql.expression import case -from sqlmodel import Session, asc, select - -from app.component.auth import Auth, auth_must -from app.component.database import session -from app.model.chat.chat_step import ChatStep, ChatStepIn, ChatStepOut - -logger = logging.getLogger("server_chat_step") - -router = APIRouter(prefix="/chat", tags=["Chat Step Management"]) - - -@router.get("/steps", name="list chat steps", response_model=list[ChatStepOut]) -async def list_chat_steps( - task_id: str, step: str | None = None, session: Session = Depends(session), auth: Auth = Depends(auth_must) -): - """List chat steps for a task with optional step type filtering.""" - user_id = auth.user.id - query = select(ChatStep) - if task_id is not None: - query = query.where(ChatStep.task_id == task_id) - if step is not None: - query = query.where(ChatStep.step == step) - - chat_steps = session.exec(query).all() - logger.debug( - "Chat steps listed", extra={"user_id": user_id, "task_id": task_id, "step_type": step, "count": len(chat_steps)} - ) - return chat_steps - - -@router.get("/steps/playback/{task_id}", name="Playback Chat Step via SSE") -async def share_playback( - task_id: str, delay_time: float = 0, session: Session = Depends(session), auth: Auth = Depends(auth_must) -): - """Playback chat steps via SSE stream.""" - user_id = auth.user.id - if delay_time > 5: - logger.debug( - "Delay time capped", extra={"user_id": user_id, "task_id": task_id, "requested": delay_time, "capped": 5} - ) - delay_time = 5 - - async def event_generator(): - try: - stmt = ( - select(ChatStep) - .where(ChatStep.task_id == task_id) - .order_by( - asc(case((ChatStep.timestamp.is_(None), 1), else_=0)), asc(ChatStep.timestamp), asc(ChatStep.id) - ) - ) - steps = session.exec(stmt).all() - - if not steps: - logger.warning("No steps found for playback", extra={"user_id": user_id, "task_id": task_id}) - yield f"data: {json.dumps({'error': 'No steps found for this task.'})}\n\n" - return - - logger.info( - "Chat step playback started", - extra={"user_id": user_id, "task_id": task_id, "step_count": len(steps), "delay_time": delay_time}, - ) - - for step in steps: - step_data = { - "id": step.id, - "task_id": step.task_id, - "step": step.step, - "data": step.data, - "created_at": step.created_at.isoformat() if step.created_at else None, - } - yield f"data: {json.dumps(step_data)}\n\n" - if delay_time > 0: - await asyncio.sleep(delay_time) - - logger.info( - "Chat step playback completed", extra={"user_id": user_id, "task_id": task_id, "step_count": len(steps)} - ) - except Exception as e: - logger.error( - "Chat step playback error", - extra={"user_id": user_id, "task_id": task_id, "error": str(e)}, - exc_info=True, - ) - yield f"data: {json.dumps({'error': 'Playback error occurred.'})}\n\n" - - return StreamingResponse(event_generator(), media_type="text/event-stream") - - -@router.get("/steps/{step_id}", name="get chat step", response_model=ChatStepOut) -async def get_chat_step(step_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)): - """Get specific chat step.""" - user_id = auth.user.id - chat_step = session.get(ChatStep, step_id) - - if not chat_step: - logger.warning("Chat step not found", extra={"user_id": user_id, "step_id": step_id}) - raise HTTPException(status_code=404, detail=_("Chat step not found")) - - logger.debug("Chat step retrieved", extra={"user_id": user_id, "step_id": step_id, "task_id": chat_step.task_id}) - return chat_step - - -@router.post("/steps", name="create chat step") -async def create_chat_step(step: ChatStepIn, session: Session = Depends(session)): - """Create new chat step. TODO: Implement request source validation.""" - try: - chat_step = ChatStep(task_id=step.task_id, step=step.step, data=step.data, timestamp=step.timestamp) - session.add(chat_step) - session.commit() - session.refresh(chat_step) - logger.info( - "Chat step created", extra={"step_id": chat_step.id, "task_id": step.task_id, "step_type": step.step} - ) - return {"code": 200, "msg": "success"} - except Exception as e: - session.rollback() - logger.error( - "Chat step creation failed", - extra={"task_id": step.task_id, "step_type": step.step, "error": str(e)}, - exc_info=True, - ) - raise HTTPException(status_code=500, detail="Internal server error") - - -@router.put("/steps/{step_id}", name="update chat step", response_model=ChatStepOut) -async def update_chat_step( - step_id: int, chat_step_update: ChatStep, session: Session = Depends(session), auth: Auth = Depends(auth_must) -): - """Update chat step.""" - user_id = auth.user.id - db_chat_step = session.get(ChatStep, step_id) - - if not db_chat_step: - logger.warning("Chat step not found for update", extra={"user_id": user_id, "step_id": step_id}) - raise HTTPException(status_code=404, detail=_("Chat step not found")) - - try: - update_data = chat_step_update.dict(exclude_unset=True) - for key, value in update_data.items(): - setattr(db_chat_step, key, value) - session.add(db_chat_step) - session.commit() - session.refresh(db_chat_step) - logger.info( - "Chat step updated", - extra={ - "user_id": user_id, - "step_id": step_id, - "task_id": db_chat_step.task_id, - "fields_updated": list(update_data.keys()), - }, - ) - return db_chat_step - except Exception as e: - session.rollback() - logger.error( - "Chat step update failed", extra={"user_id": user_id, "step_id": step_id, "error": str(e)}, exc_info=True - ) - raise HTTPException(status_code=500, detail="Internal server error") - - -@router.delete("/steps/{step_id}", name="delete chat step") -async def delete_chat_step(step_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)): - """Delete chat step.""" - user_id = auth.user.id - db_chat_step = session.get(ChatStep, step_id) - - if not db_chat_step: - logger.warning("Chat step not found for deletion", extra={"user_id": user_id, "step_id": step_id}) - raise HTTPException(status_code=404, detail=_("Chat step not found")) - - try: - session.delete(db_chat_step) - session.commit() - logger.info( - "Chat step deleted", extra={"user_id": user_id, "step_id": step_id, "task_id": db_chat_step.task_id} - ) - return Response(status_code=204) - except Exception as e: - session.rollback() - logger.error( - "Chat step deletion failed", extra={"user_id": user_id, "step_id": step_id, "error": str(e)}, exc_info=True - ) - raise HTTPException(status_code=500, detail="Internal server error") diff --git a/server/app/controller/config/config_controller.py b/server/app/controller/config/config_controller.py deleted file mode 100644 index 6812b2d8..00000000 --- a/server/app/controller/config/config_controller.py +++ /dev/null @@ -1,221 +0,0 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= - -import logging - -from fastapi import APIRouter, Depends, HTTPException, Query, Response -from fastapi_babel import _ -from sqlmodel import Session, select - -from app.component.auth import Auth, auth_must -from app.component.database import session -from app.model.config.config import ( - Config, - ConfigCreate, - ConfigInfo, - ConfigOut, - ConfigUpdate, -) - -logger = logging.getLogger("server_config_controller") - -router = APIRouter(tags=["Config Management"]) - - -@router.get("/configs", name="list configs", response_model=list[ConfigOut]) -async def list_configs( - config_group: str | None = None, session: Session = Depends(session), auth: Auth = Depends(auth_must) -): - """List user's configurations with optional group filtering.""" - user_id = auth.user.id - query = select(Config).where(Config.user_id == user_id) - - if config_group is not None: - query = query.where(Config.config_group == config_group) - - configs = session.exec(query).all() - logger.debug("Configs listed", extra={"user_id": user_id, "config_group": config_group, "count": len(configs)}) - return configs - - -@router.get("/configs/{config_id}", name="get config", response_model=ConfigOut) -async def get_config( - config_id: int, - session: Session = Depends(session), - auth: Auth = Depends(auth_must), -): - query = select(Config).where(Config.user_id == auth.user.id) - - if config_id is not None: - query = query.where(Config.id == config_id) - - config = session.exec(query).first() - - if not config: - logger.warning("Config not found") - raise HTTPException(status_code=404, detail=_("Configuration not found")) - - logger.debug("Config retrieved") - return config - - -@router.post("/configs", name="create config", response_model=ConfigOut) -async def create_config(config: ConfigCreate, session: Session = Depends(session), auth: Auth = Depends(auth_must)): - """Create new configuration.""" - user_id = auth.user.id - - if not ConfigInfo.is_valid_env_var(config.config_group, config.config_name): - logger.warning( - "Config validation failed", - extra={"user_id": user_id, "config_group": config.config_group, "config_name": config.config_name}, - ) - raise HTTPException(status_code=400, detail=_("Invalid config name or group")) - - # Check if configuration already exists - existing_config = session.exec( - select(Config).where(Config.user_id == user_id, Config.config_name == config.config_name) - ).first() - - if existing_config: - logger.warning( - "Config creation failed: already exists", extra={"user_id": user_id, "config_name": config.config_name} - ) - raise HTTPException(status_code=400, detail=_("Configuration already exists for this user")) - - try: - db_config = Config( - user_id=user_id, - config_name=config.config_name, - config_value=config.config_value, - config_group=config.config_group, - ) - session.add(db_config) - session.commit() - session.refresh(db_config) - logger.info( - "Config created", - extra={ - "user_id": user_id, - "config_id": db_config.id, - "config_group": config.config_group, - "config_name": config.config_name, - }, - ) - return db_config - except Exception as e: - session.rollback() - logger.error( - "Config creation failed", - extra={"user_id": user_id, "config_name": config.config_name, "error": str(e)}, - exc_info=True, - ) - raise HTTPException(status_code=500, detail="Internal server error") - - -@router.put("/configs/{config_id}", name="update config", response_model=ConfigOut) -async def update_config( - config_id: int, config_update: ConfigUpdate, session: Session = Depends(session), auth: Auth = Depends(auth_must) -): - """Update configuration.""" - user_id = auth.user.id - db_config = session.exec(select(Config).where(Config.id == config_id, Config.user_id == user_id)).first() - - if not db_config: - logger.warning("Config not found for update", extra={"user_id": user_id, "config_id": config_id}) - raise HTTPException(status_code=404, detail=_("Configuration not found")) - - # Check if configuration group is valid - if not ConfigInfo.is_valid_env_var(config_update.config_group, config_update.config_name): - logger.warning( - "Config update validation failed", - extra={"user_id": user_id, "config_id": config_id, "config_group": config_update.config_group}, - ) - raise HTTPException(status_code=400, detail=_("Invalid configuration group")) - - # Check for conflicts with other configurations - existing_config = session.exec( - select(Config).where( - Config.user_id == user_id, - Config.config_name == config_update.config_name, - Config.id != config_id, - ) - ).first() - - if existing_config: - logger.warning( - "Config update failed: duplicate name", - extra={"user_id": user_id, "config_id": config_id, "config_name": config_update.config_name}, - ) - raise HTTPException(status_code=400, detail=_("Configuration already exists for this user")) - - try: - db_config.config_name = config_update.config_name - db_config.config_value = config_update.config_value - db_config.config_group = config_update.config_group - session.add(db_config) - session.commit() - session.refresh(db_config) - logger.info( - "Config updated", - extra={"user_id": user_id, "config_id": config_id, "config_group": config_update.config_group}, - ) - return db_config - except Exception as e: - session.rollback() - logger.error( - "Config update failed", extra={"user_id": user_id, "config_id": config_id, "error": str(e)}, exc_info=True - ) - raise HTTPException(status_code=500, detail="Internal server error") - - -@router.delete("/configs/{config_id}", name="delete config") -async def delete_config(config_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)): - """Delete configuration.""" - user_id = auth.user.id - db_config = session.exec(select(Config).where(Config.id == config_id, Config.user_id == user_id)).first() - - if not db_config: - logger.warning("Config not found for deletion", extra={"user_id": user_id, "config_id": config_id}) - raise HTTPException(status_code=404, detail=_("Configuration not found")) - - try: - session.delete(db_config) - session.commit() - logger.info( - "Config deleted", extra={"user_id": user_id, "config_id": config_id, "config_name": db_config.config_name} - ) - return Response(status_code=204) - except Exception as e: - session.rollback() - logger.error( - "Config deletion failed", extra={"user_id": user_id, "config_id": config_id, "error": str(e)}, exc_info=True - ) - raise HTTPException(status_code=500, detail="Internal server error") - - -@router.get("/config/info", name="get config info") -async def get_config_info( - show_all: bool = Query(False, description="Show all config info, including those with empty env_vars"), -): - """Get available configuration templates and info.""" - configs = ConfigInfo.getinfo() - if show_all: - logger.debug("Config info retrieved", extra={"show_all": True, "count": len(configs)}) - return configs - - filtered = {k: v for k, v in configs.items() if v.get("env_vars") and len(v["env_vars"]) > 0} - logger.debug( - "Config info retrieved", extra={"show_all": False, "total_count": len(configs), "filtered_count": len(filtered)} - ) - return filtered diff --git a/server/app/controller/mcp/mcp_controller.py b/server/app/controller/mcp/mcp_controller.py deleted file mode 100644 index 17e92dc2..00000000 --- a/server/app/controller/mcp/mcp_controller.py +++ /dev/null @@ -1,294 +0,0 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= - -import logging -import os - -from camel.toolkits.mcp_toolkit import MCPToolkit -from fastapi import APIRouter, Depends, HTTPException -from fastapi_babel import _ -from fastapi_pagination import Page -from fastapi_pagination.ext.sqlmodel import paginate -from sqlalchemy.orm import selectinload, with_loader_criteria -from sqlmodel import Session, col, select - -from app.component.auth import Auth, auth_must -from app.component.database import session -from app.component.environment import env -from app.model.mcp.mcp import Mcp, McpOut, McpType -from app.model.mcp.mcp_env import McpEnv, Status as McpEnvStatus -from app.model.mcp.mcp_user import McpImportType, McpUser, Status - -logger = logging.getLogger("server_mcp_controller") - -from app.component.validator.McpServer import ( - McpRemoteServer, - McpServerItem, - validate_mcp_remote_servers, - validate_mcp_servers, -) - -router = APIRouter(tags=["Mcp Servers"]) - - -async def pre_instantiate_mcp_toolkit(config_dict: dict) -> bool: - """ - Pre-instantiate MCP toolkit to complete authentication process - - Args: - config_dict: MCP server configuration dictionary - - Returns: - bool: Whether successfully instantiated and connected - """ - try: - # Ensure unified auth directory for all mcp servers - for server_config in config_dict.get("mcpServers", {}).values(): - if "env" not in server_config: - server_config["env"] = {} - # Set global auth directory to persist authentication across tasks - if "MCP_REMOTE_CONFIG_DIR" not in server_config["env"]: - server_config["env"]["MCP_REMOTE_CONFIG_DIR"] = env( - "MCP_REMOTE_CONFIG_DIR", os.path.expanduser("~/.mcp-auth") - ) - - # Create MCP toolkit and attempt to connect - mcp_toolkit = MCPToolkit(config_dict=config_dict, timeout=30) - await mcp_toolkit.connect() - - # Get tools list to ensure connection is successful - tools = mcp_toolkit.get_tools() - logger.info("MCP toolkit pre-instantiated", extra={"tools_count": len(tools)}) - - # Disconnect, authentication info is already saved - await mcp_toolkit.disconnect() - return True - - except Exception as e: - logger.warning("MCP toolkit pre-instantiation failed", extra={"error": str(e)}, exc_info=True) - return False - - -@router.get("/mcps", name="mcp list") -async def gets( - keyword: str | None = None, - category_id: int | None = None, - mine: int | None = None, - session: Session = Depends(session), - auth: Auth = Depends(auth_must), -) -> Page[McpOut]: - """List MCP servers with optional filtering.""" - user_id = auth.user.id - stmt = ( - select(Mcp) - .where(Mcp.no_delete()) - .options( - selectinload(Mcp.category), - selectinload(Mcp.envs), - with_loader_criteria(McpEnv, col(McpEnv.status) == McpEnvStatus.in_use), - ) - ) - if keyword: - stmt = stmt.where(col(Mcp.key).like(f"%{keyword.lower()}%")) - if category_id: - stmt = stmt.where(Mcp.category_id == category_id) - if mine and auth: - stmt = ( - stmt.join(McpUser) - .where(McpUser.user_id == user_id) - .options( - selectinload(Mcp.mcp_user), - with_loader_criteria(McpUser, col(McpUser.user_id) == user_id), - ) - ) - - result = paginate(session, stmt) - total = result.total if hasattr(result, "total") else 0 - logger.debug( - "MCP list retrieved", - extra={"user_id": user_id, "keyword": keyword, "category_id": category_id, "mine": mine, "total": total}, - ) - return result - - -@router.get("/mcp", name="mcp detail", response_model=McpOut) -async def get(id: int, session: Session = Depends(session)): - """Get MCP server details.""" - try: - stmt = ( - select(Mcp).where(Mcp.no_delete(), Mcp.id == id).options(selectinload(Mcp.category), selectinload(Mcp.envs)) - ) - model = session.exec(stmt).one() - logger.debug("MCP detail retrieved", extra={"mcp_id": id, "mcp_key": model.key}) - return model - except Exception: - logger.warning("MCP not found", extra={"mcp_id": id}) - raise HTTPException(status_code=404, detail=_("Mcp not found")) - - -@router.post("/mcp/install", name="mcp install") -async def install(mcp_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)): - """Install MCP server for user.""" - user_id = auth.user.id - - mcp = session.get_one(Mcp, mcp_id) - if not mcp: - logger.warning("MCP install failed: MCP not found", extra={"user_id": user_id, "mcp_id": mcp_id}) - raise HTTPException(status_code=404, detail=_("Mcp not found")) - - exists = session.exec(select(McpUser).where(McpUser.mcp_id == mcp.id, McpUser.user_id == user_id)).first() - if exists: - logger.warning( - "MCP install failed: already installed", extra={"user_id": user_id, "mcp_id": mcp_id, "mcp_key": mcp.key} - ) - raise HTTPException(status_code=400, detail=_("mcp is installed")) - - install_command: dict = mcp.install_command - - # Pre-instantiate MCP toolkit for authentication - config_dict = {"mcpServers": {mcp.key: install_command}} - - try: - success = await pre_instantiate_mcp_toolkit(config_dict) - if not success: - logger.warning( - "MCP pre-instantiation failed, continuing with installation", - extra={"user_id": user_id, "mcp_id": mcp_id, "mcp_key": mcp.key}, - ) - else: - logger.debug("MCP toolkit pre-instantiated", extra={"mcp_key": mcp.key}) - except Exception as e: - logger.warning( - "MCP pre-instantiation exception", - extra={"user_id": user_id, "mcp_key": mcp.key, "error": str(e)}, - exc_info=True, - ) - - try: - mcp_user = McpUser( - mcp_id=mcp.id, - user_id=user_id, - mcp_name=mcp.name, - mcp_key=mcp.key, - mcp_desc=mcp.description, - type=mcp.type, - status=Status.enable, - command=install_command["command"], - args=install_command["args"], - env=install_command["env"], - server_url=None, - ) - mcp_user.save() - logger.info("MCP installed", extra={"user_id": user_id, "mcp_id": mcp_id, "mcp_key": mcp.key}) - return mcp_user - except Exception as e: - logger.error( - "MCP installation failed", - extra={"user_id": user_id, "mcp_id": mcp_id, "mcp_key": mcp.key, "error": str(e)}, - exc_info=True, - ) - raise HTTPException(status_code=500, detail="Internal server error") - - -@router.post("/mcp/import/{mcp_type}", name="mcp import") -async def import_mcp( - mcp_type: McpImportType, mcp_data: dict, session: Session = Depends(session), auth: Auth = Depends(auth_must) -): - """Import MCP servers (local or remote).""" - user_id = auth.user.id - - if mcp_type == McpImportType.Local: - logger.info("Importing local MCP servers", extra={"user_id": user_id}) - is_valid, res = validate_mcp_servers(mcp_data) - if not is_valid: - logger.warning("Local MCP import validation failed", extra={"user_id": user_id, "error": res}) - raise HTTPException(status_code=400, detail=res) - - mcp_data: dict[str, McpServerItem] = res.mcpServers - imported_count = 0 - - for name, data in mcp_data.items(): - config_dict = {"mcpServers": {name: {"command": data.command, "args": data.args, "env": data.env or {}}}} - - try: - success = await pre_instantiate_mcp_toolkit(config_dict) - if not success: - logger.warning( - "Local MCP pre-instantiation failed, continuing", extra={"user_id": user_id, "mcp_name": name} - ) - except Exception as e: - logger.warning( - "Local MCP pre-instantiation exception", - extra={"user_id": user_id, "mcp_name": name, "error": str(e)}, - ) - - try: - mcp_user = McpUser( - mcp_id=0, - user_id=user_id, - mcp_name=name, - mcp_key=name, - mcp_desc=name, - type=McpType.Local, - status=Status.enable, - command=data.command, - args=data.args, - env=data.env, - server_url=None, - ) - mcp_user.save() - imported_count += 1 - except Exception as e: - logger.error( - "Failed to import local MCP", - extra={"user_id": user_id, "mcp_name": name, "error": str(e)}, - exc_info=True, - ) - - logger.info("Local MCPs imported", extra={"user_id": user_id, "count": imported_count}) - return {"message": "Local MCP servers imported successfully", "count": imported_count} - - elif mcp_type == McpImportType.Remote: - logger.info("Importing remote MCP server", extra={"user_id": user_id}) - is_valid, res = validate_mcp_remote_servers(mcp_data) - if not is_valid: - logger.warning("Remote MCP import validation failed", extra={"user_id": user_id, "error": res}) - raise HTTPException(status_code=400, detail=res) - - data: McpRemoteServer = res - - try: - # For remote servers, we don't need to pre-instantiate as they typically don't require authentication - # but we can still try to validate the connection if needed - mcp_user = McpUser( - mcp_id=0, - user_id=user_id, - type=McpType.Remote, - status=Status.enable, - mcp_name=data.server_name, - server_url=data.server_url, - ) - mcp_user.save() - logger.info( - "Remote MCP imported", - extra={"user_id": user_id, "server_name": data.server_name, "server_url": data.server_url}, - ) - return mcp_user - except Exception as e: - logger.error( - "Remote MCP import failed", - extra={"user_id": user_id, "server_name": data.server_name, "error": str(e)}, - exc_info=True, - ) - raise HTTPException(status_code=500, detail="Internal server error") diff --git a/server/app/controller/mcp/user_controller.py b/server/app/controller/mcp/user_controller.py deleted file mode 100644 index 0d8ca33b..00000000 --- a/server/app/controller/mcp/user_controller.py +++ /dev/null @@ -1,210 +0,0 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= - -import logging -import os - -from camel.toolkits.mcp_toolkit import MCPToolkit -from fastapi import APIRouter, Depends, HTTPException, Response -from fastapi_babel import _ -from sqlmodel import Session, select - -from app.component.auth import Auth, auth_must -from app.component.database import session -from app.component.environment import env -from app.model.mcp.mcp import Mcp -from app.model.mcp.mcp_user import ( - McpUser, - McpUserIn, - McpUserOut, - McpUserUpdate, -) - -logger = logging.getLogger("server_mcp_user_controller") - -router = APIRouter(tags=["McpUser Management"]) - - -async def pre_instantiate_mcp_toolkit(config_dict: dict) -> bool: - """ - Pre-instantiate MCP toolkit to complete authentication process - - Args: - config_dict: MCP server configuration dictionary - - Returns: - bool: Whether successfully instantiated and connected - """ - try: - # Ensure unified auth directory for all mcp servers - for server_config in config_dict.get("mcpServers", {}).values(): - if "env" not in server_config: - server_config["env"] = {} - # Set global auth directory to persist authentication across tasks - if "MCP_REMOTE_CONFIG_DIR" not in server_config["env"]: - server_config["env"]["MCP_REMOTE_CONFIG_DIR"] = env( - "MCP_REMOTE_CONFIG_DIR", os.path.expanduser("~/.mcp-auth") - ) - - # Create MCP toolkit and attempt to connect - mcp_toolkit = MCPToolkit(config_dict=config_dict, timeout=30) - await mcp_toolkit.connect() - - # Get tools list to ensure connection is successful - tools = mcp_toolkit.get_tools() - logger.info("MCP toolkit pre-instantiated", extra={"tools_count": len(tools)}) - - # Disconnect, authentication info is already saved - await mcp_toolkit.disconnect() - return True - - except Exception as e: - logger.warning("MCP toolkit pre-instantiation failed", extra={"error": str(e)}, exc_info=True) - return False - - -@router.get("/mcp/users", name="list mcp users", response_model=list[McpUserOut]) -async def list_mcp_users( - mcp_id: int | None = None, - session: Session = Depends(session), - auth: Auth = Depends(auth_must), -): - """List MCP users for current user.""" - user_id = auth.user.id - query = select(McpUser) - if mcp_id is not None: - query = query.where(McpUser.mcp_id == mcp_id) - if user_id is not None: - query = query.where(McpUser.user_id == user_id) - mcp_users = session.exec(query).all() - logger.debug("MCP users listed", extra={"user_id": user_id, "mcp_id": mcp_id, "count": len(mcp_users)}) - return mcp_users - - -@router.get("/mcp/users/{mcp_user_id}", name="get mcp user", response_model=McpUserOut) -async def get_mcp_user(mcp_user_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)): - """Get MCP user details.""" - query = select(McpUser).where(McpUser.id == mcp_user_id) - mcp_user = session.exec(query).first() - if not mcp_user: - logger.warning("MCP user not found", extra={"user_id": auth.user.id, "mcp_user_id": mcp_user_id}) - raise HTTPException(status_code=404, detail=_("McpUser not found")) - logger.debug( - "MCP user retrieved", extra={"user_id": auth.user.id, "mcp_user_id": mcp_user_id, "mcp_id": mcp_user.mcp_id} - ) - return mcp_user - - -@router.post("/mcp/users", name="create mcp user", response_model=McpUserOut) -async def create_mcp_user(mcp_user: McpUserIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)): - """Create MCP user installation.""" - user_id = auth.user.id - mcp_id = mcp_user.mcp_id - - exists = session.exec(select(McpUser).where(McpUser.mcp_id == mcp_id, McpUser.user_id == user_id)).first() - if exists: - logger.warning("MCP already installed", extra={"user_id": user_id, "mcp_id": mcp_id}) - raise HTTPException(status_code=400, detail=_("mcp is installed")) - - # Get MCP configuration from the main Mcp table - mcp = session.get(Mcp, mcp_id) - if mcp and mcp.install_command: - config_dict = {"mcpServers": {mcp.key: mcp.install_command}} - - try: - success = await pre_instantiate_mcp_toolkit(config_dict) - if not success: - logger.warning( - "MCP pre-instantiation failed, continuing", - extra={"user_id": user_id, "mcp_id": mcp_id, "mcp_key": mcp.key}, - ) - except Exception as e: - logger.warning( - "MCP pre-instantiation exception", - extra={"user_id": user_id, "mcp_id": mcp_id, "error": str(e)}, - exc_info=True, - ) - - try: - db_mcp_user = McpUser(mcp_id=mcp_id, user_id=user_id, env=mcp_user.env) - session.add(db_mcp_user) - session.commit() - session.refresh(db_mcp_user) - logger.info("MCP user created", extra={"user_id": user_id, "mcp_id": mcp_id, "mcp_user_id": db_mcp_user.id}) - return db_mcp_user - except Exception as e: - session.rollback() - logger.error( - "MCP user creation failed", extra={"user_id": user_id, "mcp_id": mcp_id, "error": str(e)}, exc_info=True - ) - raise HTTPException(status_code=500, detail="Internal server error") - - -@router.put("/mcp/users/{id}", name="update mcp user") -async def update_mcp_user( - id: int, - update_item: McpUserUpdate, - session: Session = Depends(session), - auth: Auth = Depends(auth_must), -): - """Update MCP user settings.""" - user_id = auth.user.id - model = session.get(McpUser, id) - if not model: - logger.warning("MCP user not found for update", extra={"user_id": user_id, "mcp_user_id": id}) - raise HTTPException(status_code=404, detail=_("Mcp Info not found")) - if model.user_id != user_id: - logger.warning( - "Unauthorized MCP user update", extra={"user_id": user_id, "mcp_user_id": id, "owner_id": model.user_id} - ) - raise HTTPException(status_code=400, detail=_("current user have no permission to modify")) - - try: - update_data = update_item.model_dump(exclude_unset=True) - model.update_fields(update_data) - model.save(session) - session.refresh(model) - logger.info("MCP user updated", extra={"user_id": user_id, "mcp_user_id": id}) - return model - except Exception as e: - logger.error( - "MCP user update failed", extra={"user_id": user_id, "mcp_user_id": id, "error": str(e)}, exc_info=True - ) - raise HTTPException(status_code=500, detail="Internal server error") - - -@router.delete("/mcp/users/{mcp_user_id}", name="delete mcp user") -async def delete_mcp_user(mcp_user_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)): - """Delete MCP user installation.""" - user_id = auth.user.id - db_mcp_user = session.get(McpUser, mcp_user_id) - if not db_mcp_user: - logger.warning("MCP user not found for deletion", extra={"user_id": user_id, "mcp_user_id": mcp_user_id}) - raise HTTPException(status_code=404, detail=_("Mcp Info not found")) - - try: - session.delete(db_mcp_user) - session.commit() - logger.info( - "MCP user deleted", extra={"user_id": user_id, "mcp_user_id": mcp_user_id, "mcp_id": db_mcp_user.mcp_id} - ) - return Response(status_code=204) - except Exception as e: - session.rollback() - logger.error( - "MCP user deletion failed", - extra={"user_id": user_id, "mcp_user_id": mcp_user_id, "error": str(e)}, - exc_info=True, - ) - raise HTTPException(status_code=500, detail="Internal server error") diff --git a/server/app/controller/oauth/oauth_controller.py b/server/app/controller/oauth/oauth_controller.py deleted file mode 100644 index a93ca5fb..00000000 --- a/server/app/controller/oauth/oauth_controller.py +++ /dev/null @@ -1,121 +0,0 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= - -import logging -import re -from urllib.parse import quote, urlencode - -from fastapi import APIRouter, HTTPException, Request -from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse - -from app.component.oauth_adapter import OauthCallbackPayload, get_oauth_adapter - -# Allowed OAuth provider names (alphanumeric and hyphens only) -ALLOWED_PROVIDER_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+$") - -logger = logging.getLogger("server_oauth_controller") - -router = APIRouter(prefix="/oauth", tags=["Oauth Servers"]) - - -@router.get("/{app}/login", name="OAuth Login Redirect") -def oauth_login(app: str, request: Request, state: str | None = None): - """Redirect user to OAuth provider's authorization endpoint.""" - try: - callback_url = str(request.url_for("OAuth Callback", app=app)) - if callback_url.startswith("http://"): - callback_url = "https://" + callback_url[len("http://") :] - - adapter = get_oauth_adapter(app, callback_url) - url = adapter.get_authorize_url(state) - - if not url: - logger.error("Failed to generate authorization URL", extra={"provider": app, "callback_url": callback_url}) - raise HTTPException(status_code=400, detail="Failed to generate authorization URL") - - logger.info("OAuth login initiated", extra={"provider": app}) - return RedirectResponse(str(url)) - except HTTPException: - raise - except Exception as e: - logger.error("OAuth login failed", extra={"provider": app, "error": str(e)}, exc_info=True) - raise HTTPException(status_code=400, detail="OAuth login failed") - - -@router.get("/{app}/callback", name="OAuth Callback") -def oauth_callback( - app: str, - request: Request, - code: str | None = None, - state: str | None = None, -): - """Handle OAuth provider callback and redirect to client app.""" - # Security: Validate provider name to prevent injection - if not ALLOWED_PROVIDER_PATTERN.match(app): - logger.warning( - "OAuth callback invalid provider name", - extra={"provider": app[:50]}, # Truncate for logging - ) - raise HTTPException(status_code=400, detail="Invalid provider") - - if not code: - logger.warning("OAuth callback missing code", extra={"provider": app}) - raise HTTPException(status_code=400, detail="Missing code parameter") - - logger.info( - "OAuth callback received", - extra={"provider": app, "has_state": state is not None}, - ) - - # Security: URL-encode all parameters to prevent XSS - params = {"provider": app, "code": code} - if state is not None: - params["state"] = state - redirect_url = f"eigent://callback/oauth?{urlencode(params)}" - - # Security: Use a safe redirect approach without embedding in JavaScript - # The redirect URL uses a custom protocol, so we encode it safely - safe_redirect_url = quote(redirect_url, safe=":/&?=") - - html_content = f""" - - - - OAuth Callback - - - -

Redirecting, please wait...

-

If you are not redirected, click here.

- - -""" - return HTMLResponse(content=html_content) - - -@router.post("/{app}/token", name="OAuth Fetch Token") -def fetch_token(app: str, request: Request, data: OauthCallbackPayload): - """Exchange authorization code for access token.""" - try: - callback_url = str(request.url_for("OAuth Callback", app=app)) - if callback_url.startswith("http://"): - callback_url = "https://" + callback_url[len("http://") :] - - adapter = get_oauth_adapter(app, callback_url) - token_data = adapter.fetch_token(data.code) - logger.info("OAuth token fetched", extra={"provider": app}) - return JSONResponse(token_data) - except Exception as e: - logger.error("OAuth token fetch failed", extra={"provider": app, "error": str(e)}, exc_info=True) - raise HTTPException(status_code=500, detail="Internal server error") diff --git a/server/app/controller/provider/provider_controller.py b/server/app/controller/provider/provider_controller.py deleted file mode 100644 index 6ed32729..00000000 --- a/server/app/controller/provider/provider_controller.py +++ /dev/null @@ -1,165 +0,0 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= - -import logging - -from fastapi import APIRouter, Depends, HTTPException, Query, Response -from fastapi_babel import _ -from fastapi_pagination import Page -from fastapi_pagination.ext.sqlmodel import paginate -from sqlalchemy import update -from sqlalchemy.exc import SQLAlchemyError -from sqlmodel import Session, col, select - -from app.component.auth import Auth, auth_must -from app.component.database import session -from app.model.provider.provider import ( - Provider, - ProviderIn, - ProviderOut, - ProviderPreferIn, -) - -logger = logging.getLogger("server_provider_controller") - -router = APIRouter(tags=["Provider Management"]) - - -@router.get("/providers", name="list providers", response_model=Page[ProviderOut]) -async def gets( - keyword: str | None = None, - prefer: bool | None = Query(None, description="Filter by prefer status"), - session: Session = Depends(session), - auth: Auth = Depends(auth_must), -) -> Page[ProviderOut]: - """List user's providers with optional filtering.""" - user_id = auth.user.id - stmt = select(Provider).where(Provider.user_id == user_id, Provider.no_delete()) - if keyword: - stmt = stmt.where(col(Provider.provider_name).like(f"%{keyword}%")) - if prefer is not None: - stmt = stmt.where(Provider.prefer == prefer) - stmt = stmt.order_by(col(Provider.created_at).desc(), col(Provider.id).desc()) - logger.debug("Providers listed", extra={"user_id": user_id, "keyword": keyword, "prefer_filter": prefer}) - return paginate(session, stmt) - - -@router.get("/provider", name="get provider detail", response_model=ProviderOut) -async def get(id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)): - """Get provider details.""" - user_id = auth.user.id - stmt = select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == id) - model = session.exec(stmt).one_or_none() - if not model: - logger.warning("Provider not found", extra={"user_id": user_id, "provider_id": id}) - raise HTTPException(status_code=404, detail=_("Provider not found")) - logger.debug("Provider retrieved", extra={"user_id": user_id, "provider_id": id}) - return model - - -@router.post("/provider", name="create provider", response_model=ProviderOut) -async def post(data: ProviderIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)): - """Create a new provider.""" - user_id = auth.user.id - try: - model = Provider(**data.model_dump(), user_id=user_id) - model.save(session) - logger.info( - "Provider created", extra={"user_id": user_id, "provider_id": model.id, "provider_name": data.provider_name} - ) - return model - except Exception as e: - logger.error("Provider creation failed", extra={"user_id": user_id, "error": str(e)}, exc_info=True) - raise HTTPException(status_code=500, detail="Internal server error") - - -@router.put("/provider/{id}", name="update provider", response_model=ProviderOut) -async def put(id: int, data: ProviderIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)): - """Update provider details.""" - user_id = auth.user.id - model = session.exec( - select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == id) - ).one_or_none() - if not model: - logger.warning("Provider not found for update", extra={"user_id": user_id, "provider_id": id}) - raise HTTPException(status_code=404, detail=_("Provider not found")) - - try: - model.model_type = data.model_type - model.provider_name = data.provider_name - model.api_key = data.api_key - model.endpoint_url = data.endpoint_url - model.encrypted_config = data.encrypted_config - model.is_vaild = data.is_vaild - model.save(session) - session.refresh(model) - logger.info( - "Provider updated", extra={"user_id": user_id, "provider_id": id, "provider_name": data.provider_name} - ) - return model - except Exception as e: - logger.error( - "Provider update failed", extra={"user_id": user_id, "provider_id": id, "error": str(e)}, exc_info=True - ) - raise HTTPException(status_code=500, detail="Internal server error") - - -@router.delete("/provider/{id}", name="delete provider") -async def delete(id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)): - """Delete a provider.""" - user_id = auth.user.id - model = session.exec( - select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == id) - ).one_or_none() - if not model: - logger.warning("Provider not found for deletion", extra={"user_id": user_id, "provider_id": id}) - raise HTTPException(status_code=404, detail=_("Provider not found")) - - try: - model.delete(session) - logger.info("Provider deleted", extra={"user_id": user_id, "provider_id": id}) - return Response(status_code=204) - except Exception as e: - logger.error( - "Provider deletion failed", extra={"user_id": user_id, "provider_id": id, "error": str(e)}, exc_info=True - ) - raise HTTPException(status_code=500, detail="Internal server error") - - -@router.post("/provider/prefer", name="set provider prefer") -async def set_prefer(data: ProviderPreferIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)): - """Set preferred provider for user.""" - user_id = auth.user.id - provider_id = data.provider_id - - try: - # 1. Set all current user's providers prefer to false - session.exec(update(Provider).where(Provider.user_id == user_id, Provider.no_delete()).values(prefer=False)) - # 2. Set the prefer of the specified provider_id to true - session.exec( - update(Provider) - .where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == provider_id) - .values(prefer=True) - ) - session.commit() - logger.info("Preferred provider set", extra={"user_id": user_id, "provider_id": provider_id}) - return {"success": True} - except SQLAlchemyError as e: - session.rollback() - logger.error( - "Failed to set preferred provider", - extra={"user_id": user_id, "provider_id": provider_id, "error": str(e)}, - exc_info=True, - ) - raise HTTPException(status_code=500, detail="Internal server error") diff --git a/server/app/controller/trigger/slack_controller.py b/server/app/controller/trigger/slack_controller.py deleted file mode 100644 index 8981c44f..00000000 --- a/server/app/controller/trigger/slack_controller.py +++ /dev/null @@ -1,135 +0,0 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= - -from fastapi import APIRouter, Depends, HTTPException -from sqlmodel import Session, select, and_ -from typing import Optional, List -import logging -from pydantic import BaseModel - -from app.model.config.config import Config -from app.type.config_group import ConfigGroup -from app.component.auth import Auth, auth_must -from app.component.database import session - -logger = logging.getLogger("server_slack_controller") - - -class SlackChannelOut(BaseModel): - """Output model for Slack channels.""" - id: str - name: str - is_private: bool = False - is_member: bool = False - num_members: Optional[int] = None - - -class SlackChannelsResponse(BaseModel): - """Response model for Slack channels list.""" - channels: List[SlackChannelOut] - has_credentials: bool - - -router = APIRouter(prefix="/trigger/slack", tags=["Slack Integration"]) - - -@router.get("/channels", name="get slack channels") -def get_slack_channels( - session: Session = Depends(session), - auth: Auth = Depends(auth_must) -) -> SlackChannelsResponse: - """ - Get list of Slack channels for the authenticated user. - - This endpoint fetches channels from the user's Slack workspace using their - stored credentials. Requires SLACK_BOT_TOKEN to be configured in user configs. - """ - user_id = auth.user.id - - # Get Slack credentials from config - configs = session.exec( - select(Config).where( - and_( - Config.user_id == int(user_id), - Config.config_group == ConfigGroup.SLACK.value - ) - ) - ).all() - - credentials = {config.config_name: config.config_value for config in configs} - bot_token = credentials.get("SLACK_BOT_TOKEN") - - if not bot_token: - logger.warning("Slack credentials not found", extra={"user_id": user_id}) - return SlackChannelsResponse(channels=[], has_credentials=False) - - try: - from slack_sdk import WebClient - from slack_sdk.errors import SlackApiError - - client = WebClient(token=bot_token) - - # Fetch all channels (public and private the bot has access to) - channels = [] - cursor = None - - while True: - response = client.conversations_list( - types="public_channel,private_channel", - cursor=cursor, - limit=200 - ) - - for channel in response.get("channels", []): - channels.append(SlackChannelOut( - id=channel.get("id"), - name=channel.get("name"), - is_private=channel.get("is_private", False), - is_member=channel.get("is_member", False), - num_members=channel.get("num_members") - )) - - # Check for pagination - cursor = response.get("response_metadata", {}).get("next_cursor") - if not cursor: - break - - logger.info("Slack channels fetched", extra={ - "user_id": user_id, - "channel_count": len(channels) - }) - - return SlackChannelsResponse(channels=channels, has_credentials=True) - - except ImportError: - logger.error("slack_sdk not installed") - raise HTTPException( - status_code=500, - detail="Slack SDK not installed on server" - ) - except SlackApiError as e: - logger.error("Slack API error", extra={ - "user_id": user_id, - "error": str(e) - }) - raise HTTPException( - status_code=400, - detail=f"Slack API error: {e.response.get('error', 'Unknown error')}" - ) - except Exception as e: - logger.error("Error fetching Slack channels", extra={ - "user_id": user_id, - "error": str(e) - }, exc_info=True) - raise HTTPException(status_code=500, detail="Failed to fetch Slack channels") diff --git a/server/app/controller/trigger/trigger_controller.py b/server/app/controller/trigger/trigger_controller.py deleted file mode 100644 index 36e02cac..00000000 --- a/server/app/controller/trigger/trigger_controller.py +++ /dev/null @@ -1,731 +0,0 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= - -from fastapi import APIRouter, Depends, HTTPException, Response, Query -from fastapi_pagination import Page -from fastapi_pagination.ext.sqlmodel import paginate -from sqlmodel import Session, select, desc, and_, delete -from typing import Optional -from uuid import uuid4 -import logging -from pydantic import ValidationError - -from app.model.trigger.trigger import Trigger, TriggerIn, TriggerOut, TriggerUpdate, TriggerConfigSchemaOut -from app.model.trigger.trigger_execution import TriggerExecution, TriggerExecutionOut -from app.model.trigger.app_configs import ( - get_config_schema, - validate_config, - has_config, - validate_activation, - ActivationError, -) -from app.model.trigger.app_configs.config_registry import requires_authentication -from app.model.chat.chat_history import ChatHistory -from app.type.trigger_types import TriggerType, TriggerStatus -from app.component.auth import Auth, auth_must -from app.component.database import session -from app.component.redis_utils import get_redis_manager -from app.service.trigger.trigger_schedule_service import TriggerScheduleService -from fastapi_babel import _ -from sqlalchemy import func - -logger = logging.getLogger("server_trigger_controller") - - -ACTIVE_STATUSES = (TriggerStatus.active, TriggerStatus.pending_verification) -MAX_ACTIVE_PER_USER = 25 -MAX_ACTIVE_PER_PROJECT = 5 - - -def get_active_trigger_counts(session: Session, user_id: str, project_id: str | None = None) -> tuple[int, int]: - """Return (user_active_count, project_active_count) for active/pending triggers.""" - user_count = session.exec( - select(func.count(Trigger.id)).where( - and_( - Trigger.user_id == user_id, - Trigger.status.in_(ACTIVE_STATUSES), # type: ignore[attr-defined] - ) - ) - ).one() - - project_count = 0 - if project_id: - project_count = session.exec( - select(func.count(Trigger.id)).where( - and_( - Trigger.user_id == user_id, - Trigger.project_id == project_id, - Trigger.status.in_(ACTIVE_STATUSES), # type: ignore[attr-defined] - ) - ) - ).one() - - return user_count, project_count - - -def get_execution_counts(session: Session, trigger_ids: list[int]) -> dict[int, int]: - """Get execution counts for multiple triggers in a single query.""" - if not trigger_ids: - return {} - - result = session.exec( - select(TriggerExecution.trigger_id, func.count(TriggerExecution.id)) - .where(TriggerExecution.trigger_id.in_(trigger_ids)) - .group_by(TriggerExecution.trigger_id) - ).all() - - return {trigger_id: count for trigger_id, count in result} - - -def trigger_to_out(trigger: Trigger, execution_count: int = 0) -> TriggerOut: - """Convert Trigger model to TriggerOut with execution count.""" - return TriggerOut( - id=trigger.id, - user_id=trigger.user_id, - project_id=trigger.project_id, - name=trigger.name, - description=trigger.description, - trigger_type=trigger.trigger_type, - status=trigger.status, - execution_count=execution_count, - webhook_url=trigger.webhook_url, - webhook_method=trigger.webhook_method, - custom_cron_expression=trigger.custom_cron_expression, - listener_type=trigger.listener_type, - agent_model=trigger.agent_model, - task_prompt=trigger.task_prompt, - config=trigger.config, - max_executions_per_hour=trigger.max_executions_per_hour, - max_executions_per_day=trigger.max_executions_per_day, - is_single_execution=trigger.is_single_execution, - last_executed_at=trigger.last_executed_at, - next_run_at=trigger.next_run_at, - last_execution_status=trigger.last_execution_status, - created_at=trigger.created_at, - updated_at=trigger.updated_at, - ) - - -router = APIRouter(prefix="/trigger", tags=["Triggers"]) - -@router.post("/", name="create trigger", response_model=TriggerOut) -def create_trigger( - data: TriggerIn, - session: Session = Depends(session), - auth: Auth = Depends(auth_must) -): - """Create a new trigger.""" - user_id = auth.user.id - - try: - # Check if project_id exists in chat_history, if not create one - if data.project_id: - existing_chat = session.exec( - select(ChatHistory).where(ChatHistory.project_id == data.project_id) - ).first() - - if not existing_chat: - # Create a new chat_history for this project - chat_history = ChatHistory( - user_id=user_id, - task_id=data.project_id, # Using project_id as task_id - project_id=data.project_id, - question=f"Project created via trigger: {data.name}", - language="en", - model_platform=data.agent_model or "none", - model_type=data.agent_model or "none", - installed_mcp="none", #Expects String - api_key="", - api_url="", - max_retries=3, - project_name=data.name, - summary=data.description or "", - tokens=0, - spend=0, - status=2 # completed status - ) - session.add(chat_history) - session.commit() - session.refresh(chat_history) - - logger.info("Chat history created for new project", extra={ - "user_id": user_id, - "project_id": data.project_id, - "chat_history_id": chat_history.id - }) - - # Send WebSocket notification about new project - try: - redis_manager = get_redis_manager() - redis_manager.publish_execution_event({ - "type": "project_created", - "project_id": data.project_id, - "project_name": data.name, - "chat_history_id": chat_history.id, - "trigger_name": data.name, - "user_id": str(user_id), - "created_at": chat_history.created_at.isoformat() if chat_history.created_at else None - }) - logger.debug("WebSocket notification sent for new project", extra={ - "user_id": user_id, - "project_id": data.project_id - }) - except Exception as e: - logger.warning("Failed to send WebSocket notification for new project", extra={ - "user_id": user_id, - "project_id": data.project_id, - "error": str(e) - }) - - # Generate webhook URL for webhook-based triggers - webhook_url = None - if data.trigger_type in (TriggerType.webhook, TriggerType.slack_trigger): - webhook_url = f"/webhook/trigger/{uuid4()}" - - # Validate trigger-type specific config - if data.config and has_config(data.trigger_type): - try: - validate_config(data.trigger_type, data.config) - except ValidationError as e: - logger.warning("Invalid trigger config", extra={ - "user_id": user_id, - "trigger_type": data.trigger_type.value, - "errors": e.errors() - }) - raise HTTPException( - status_code=400, - detail=f"Invalid config for {data.trigger_type.value}: {e.errors()}" - ) - - # Create trigger instance - trigger_data = data.model_dump() - trigger_data["user_id"] = str(user_id) - trigger_data["webhook_url"] = webhook_url - - # Determine desired initial status based on auth requirements - if has_config(data.trigger_type) and data.config and requires_authentication(data.trigger_type, data.config): - desired_status = TriggerStatus.pending_verification - else: - desired_status = TriggerStatus.active - - # Check concurrent active-trigger limits before auto-activating - user_active, project_active = get_active_trigger_counts( - session, str(user_id), data.project_id - ) - if user_active >= MAX_ACTIVE_PER_USER or ( - data.project_id and project_active >= MAX_ACTIVE_PER_PROJECT - ): - logger.info( - "Active trigger limit reached — new trigger created as inactive", - extra={ - "user_id": user_id, - "project_id": data.project_id, - "user_active": user_active, - "project_active": project_active, - }, - ) - trigger_data["status"] = TriggerStatus.inactive - else: - trigger_data["status"] = desired_status - - trigger = Trigger(**trigger_data) - session.add(trigger) - session.commit() - session.refresh(trigger) - - # Calculate next_run_at for scheduled triggers - if trigger.trigger_type == TriggerType.schedule and trigger.custom_cron_expression: - schedule_service = TriggerScheduleService(session) - trigger.next_run_at = schedule_service.calculate_next_run_at(trigger) - session.add(trigger) - session.commit() - session.refresh(trigger) - - logger.info("Trigger created", extra={ - "user_id": user_id, - "trigger_id": trigger.id, - "trigger_type": data.trigger_type.value, - "next_run_at": trigger.next_run_at.isoformat() if trigger.next_run_at else None - }) - - return trigger_to_out(trigger, 0) # New trigger has 0 executions - - except Exception as e: - session.rollback() - logger.error("Trigger creation failed", extra={ - "user_id": user_id, - "error": str(e) - }, exc_info=True) - raise HTTPException(status_code=500, detail="Internal server error") - - -@router.get("/", name="list triggers") -def list_triggers( - trigger_type: Optional[TriggerType] = Query(None, description="Filter by trigger type"), - status: Optional[TriggerStatus] = Query(None, description="Filter by status"), - project_id: Optional[str] = Query(None, description="Filter by project ID"), - session: Session = Depends(session), - auth: Auth = Depends(auth_must) -) -> Page[TriggerOut]: - """List triggers for current user.""" - user_id = auth.user.id - - # Build query with filters - conditions = [Trigger.user_id == str(user_id)] - - if trigger_type: - conditions.append(Trigger.trigger_type == trigger_type) - - if status is not None: - conditions.append(Trigger.status == status) - - if project_id: - conditions.append(Trigger.project_id == project_id) - - stmt = ( - select(Trigger) - .where(and_(*conditions)) - .order_by(desc(Trigger.created_at)) - ) - - result = paginate(session, stmt) - total = result.total if hasattr(result, 'total') else 0 - - # Get execution counts for all triggers in the result - trigger_ids = [t.id for t in result.items] - counts = get_execution_counts(session, trigger_ids) - - # Convert triggers to TriggerOut with execution counts - result.items = [trigger_to_out(t, counts.get(t.id, 0)) for t in result.items] - - logger.debug("Triggers listed", extra={ - "user_id": user_id, - "total": total, - "filters": { - "trigger_type": trigger_type.value if trigger_type else None, - "status": status.value if status is not None else None, - "project_id": project_id - } - }) - - return result - -@router.get("/{trigger_id}", name="get trigger", response_model=TriggerOut) -def get_trigger( - trigger_id: int, - session: Session = Depends(session), - auth: Auth = Depends(auth_must) -): - """Get a specific trigger by ID.""" - user_id = auth.user.id - - trigger = session.exec( - select(Trigger).where( - and_(Trigger.id == trigger_id, Trigger.user_id == str(user_id)) - ) - ).first() - - if not trigger: - logger.warning("Trigger not found", extra={ - "user_id": user_id, - "trigger_id": trigger_id - }) - raise HTTPException(status_code=404, detail="Trigger not found") - - # Get execution count - counts = get_execution_counts(session, [trigger_id]) - execution_count = counts.get(trigger_id, 0) - - logger.debug("Trigger retrieved", extra={ - "user_id": user_id, - "trigger_id": trigger_id - }) - - return trigger_to_out(trigger, execution_count) - - -@router.put("/{trigger_id}", name="update trigger", response_model=TriggerOut) -def update_trigger( - trigger_id: int, - data: TriggerUpdate, - session: Session = Depends(session), - auth: Auth = Depends(auth_must) -): - """Update a trigger.""" - user_id = auth.user.id - - trigger = session.exec( - select(Trigger).where( - and_(Trigger.id == trigger_id, Trigger.user_id == str(user_id)) - ) - ).first() - - if not trigger: - logger.warning("Trigger not found for update", extra={ - "user_id": user_id, - "trigger_id": trigger_id - }) - raise HTTPException(status_code=404, detail="Trigger not found") - - try: - update_data = data.model_dump(exclude_unset=True) - - # Validate config if being updated - if "config" in update_data and update_data["config"] is not None: - if has_config(trigger.trigger_type): - try: - validate_config(trigger.trigger_type, update_data["config"]) - except ValidationError as e: - logger.warning("Invalid trigger config on update", extra={ - "user_id": user_id, - "trigger_id": trigger_id, - "trigger_type": trigger.trigger_type.value, - "errors": e.errors() - }) - raise HTTPException( - status_code=400, - detail=f"Invalid config for {trigger.trigger_type.value}: {e.errors()}" - ) - - for key, value in update_data.items(): - setattr(trigger, key, value) - - # Recalculate next_run_at if cron expression or status changed for scheduled triggers - if trigger.trigger_type == TriggerType.schedule: - if "custom_cron_expression" in update_data or "status" in update_data: - if trigger.status == TriggerStatus.active and trigger.custom_cron_expression: - schedule_service = TriggerScheduleService(session) - trigger.next_run_at = schedule_service.calculate_next_run_at(trigger) - - session.add(trigger) - session.commit() - session.refresh(trigger) - - # Get execution count - counts = get_execution_counts(session, [trigger_id]) - execution_count = counts.get(trigger_id, 0) - - logger.info("Trigger updated", extra={ - "user_id": user_id, - "trigger_id": trigger_id, - "fields_updated": list(update_data.keys()), - "next_run_at": trigger.next_run_at.isoformat() if trigger.next_run_at else None - }) - - return trigger_to_out(trigger, execution_count) - - except Exception as e: - session.rollback() - logger.error("Trigger update failed", extra={ - "user_id": user_id, - "trigger_id": trigger_id, - "error": str(e) - }, exc_info=True) - raise HTTPException(status_code=500, detail="Internal server error") - - -@router.delete("/{trigger_id}", name="delete trigger") -def delete_trigger( - trigger_id: int, - session: Session = Depends(session), - auth: Auth = Depends(auth_must) -): - """Delete a trigger.""" - user_id = auth.user.id - - trigger = session.exec( - select(Trigger).where( - and_(Trigger.id == trigger_id, Trigger.user_id == str(user_id)) - ) - ).first() - - if not trigger: - logger.warning("Trigger not found for deletion", extra={ - "user_id": user_id, - "trigger_id": trigger_id - }) - raise HTTPException(status_code=404, detail="Trigger not found") - - try: - # Delete execution logs first (bulk delete) - session.exec( - delete(TriggerExecution).where( - TriggerExecution.trigger_id == trigger_id - ) - ) - - # Then delete the trigger - session.delete(trigger) - - session.commit() - - logger.info("Trigger deleted", extra={ - "user_id": user_id, - "trigger_id": trigger_id - }) - - return Response(status_code=204) - - except Exception as e: - session.rollback() - logger.error("Trigger deletion failed", extra={ - "user_id": user_id, - "trigger_id": trigger_id, - "error": str(e) - }, exc_info=True) - raise HTTPException(status_code=500, detail="Internal server error") - - -@router.post("/{trigger_id}/activate", name="activate trigger", response_model=TriggerOut) -def activate_trigger( - trigger_id: int, - session: Session = Depends(session), - auth: Auth = Depends(auth_must) -): - """Activate a trigger.""" - user_id = auth.user.id - - trigger = session.exec( - select(Trigger).where( - and_(Trigger.id == trigger_id, Trigger.user_id == str(user_id)) - ) - ).first() - - if not trigger: - logger.warning("Trigger not found for activation", extra={ - "user_id": user_id, - "trigger_id": trigger_id - }) - raise HTTPException(status_code=404, detail="Trigger not found") - - try: - # --- Concurrent active-trigger limits --- - user_active, project_active = get_active_trigger_counts( - session, str(user_id), trigger.project_id - ) - if user_active >= MAX_ACTIVE_PER_USER: - logger.warning("User active trigger limit reached", extra={ - "user_id": user_id, - "trigger_id": trigger_id, - "current_active": user_active, - "limit": MAX_ACTIVE_PER_USER, - }) - raise HTTPException( - status_code=400, - detail=f"Maximum number of concurrent active triggers ({MAX_ACTIVE_PER_USER}) reached for this user" - ) - - if trigger.project_id and project_active >= MAX_ACTIVE_PER_PROJECT: - logger.warning("Project active trigger limit reached", extra={ - "user_id": user_id, - "trigger_id": trigger_id, - "project_id": trigger.project_id, - "current_active": project_active, - "limit": MAX_ACTIVE_PER_PROJECT, - }) - raise HTTPException( - status_code=400, - detail=f"Maximum number of concurrent active triggers ({MAX_ACTIVE_PER_PROJECT}) reached for this project" - ) - - # Check if authentication is required first — auth-required triggers - # go straight to pending_verification (credentials are provided via - # the auth flow, so missing-credential errors are expected). - if has_config(trigger.trigger_type) and requires_authentication(trigger.trigger_type, trigger.config): - trigger.status = TriggerStatus.pending_verification - logger.info("Trigger set to pending verification (authentication required)", extra={ - "user_id": user_id, - "trigger_id": trigger_id, - "trigger_type": trigger.trigger_type.value - }) - # Save the status change before raising the exception - session.add(trigger) - session.commit() - session.refresh(trigger) - raise HTTPException( - status_code=401, - detail={ - "message": "Authentication required for this trigger type", - "missing_requirements": ["authentication"], - "trigger_type": trigger.trigger_type.value - } - ) - - # For non-auth triggers, validate activation requirements - if has_config(trigger.trigger_type): - try: - validate_activation( - trigger_type=trigger.trigger_type, - config_data=trigger.config, - user_id=int(user_id), - session=session - ) - except ActivationError as e: - logger.warning("Trigger activation requirements not met", extra={ - "user_id": user_id, - "trigger_id": trigger_id, - "trigger_type": trigger.trigger_type.value, - "missing_requirements": e.missing_requirements - }) - raise HTTPException( - status_code=400, - detail={ - "message": e.message, - "missing_requirements": e.missing_requirements, - "trigger_type": trigger.trigger_type.value - } - ) - - trigger.status = TriggerStatus.active - session.add(trigger) - session.commit() - session.refresh(trigger) - - # Get execution count - counts = get_execution_counts(session, [trigger_id]) - execution_count = counts.get(trigger_id, 0) - - logger.info("Trigger status updated", extra={ - "user_id": user_id, - "trigger_id": trigger_id, - "status": trigger.status.value - }) - - return trigger_to_out(trigger, execution_count) - - except HTTPException: - raise - except Exception as e: - session.rollback() - logger.error("Trigger activation failed", extra={ - "user_id": user_id, - "trigger_id": trigger_id, - "error": str(e) - }, exc_info=True) - raise HTTPException(status_code=500, detail="Internal server error") - - -@router.post("/{trigger_id}/deactivate", name="deactivate trigger", response_model=TriggerOut) -def deactivate_trigger( - trigger_id: int, - session: Session = Depends(session), - auth: Auth = Depends(auth_must) -): - """Deactivate a trigger.""" - user_id = auth.user.id - - trigger = session.exec( - select(Trigger).where( - and_(Trigger.id == trigger_id, Trigger.user_id == str(user_id)) - ) - ).first() - - if not trigger: - logger.warning("Trigger not found for deactivation", extra={ - "user_id": user_id, - "trigger_id": trigger_id - }) - raise HTTPException(status_code=404, detail="Trigger not found") - - try: - trigger.status = TriggerStatus.inactive - session.add(trigger) - session.commit() - session.refresh(trigger) - - # Get execution count - counts = get_execution_counts(session, [trigger_id]) - execution_count = counts.get(trigger_id, 0) - - logger.info("Trigger deactivated", extra={ - "user_id": user_id, - "trigger_id": trigger_id - }) - - return trigger_to_out(trigger, execution_count) - - except Exception as e: - session.rollback() - logger.error("Trigger deactivation failed", extra={ - "user_id": user_id, - "trigger_id": trigger_id, - "error": str(e) - }, exc_info=True) - raise HTTPException(status_code=500, detail="Internal server error") - - -@router.get("/{trigger_id}/executions", name="list trigger executions") -def list_trigger_executions( - trigger_id: int, - session: Session = Depends(session), - auth: Auth = Depends(auth_must) -) -> Page[TriggerExecutionOut]: - """List executions for a specific trigger.""" - user_id = auth.user.id - - # First verify the trigger belongs to the user - trigger = session.exec( - select(Trigger).where( - and_(Trigger.id == trigger_id, Trigger.user_id == str(user_id)) - ) - ).first() - - if not trigger: - logger.warning("Trigger not found for executions list", extra={ - "user_id": user_id, - "trigger_id": trigger_id - }) - raise HTTPException(status_code=404, detail="Trigger not found") - - # Get executions for this trigger - stmt = ( - select(TriggerExecution) - .where(TriggerExecution.trigger_id == trigger_id) - .order_by(desc(TriggerExecution.created_at)) - ) - - result = paginate(session, stmt) - total = result.total if hasattr(result, 'total') else 0 - - logger.debug("Trigger executions listed", extra={ - "user_id": user_id, - "trigger_id": trigger_id, - "total": total - }) - - return result - - -# ============================================================================ -# Trigger Config Endpoints -# ============================================================================ - -@router.get("/{trigger_type}/config", name="get trigger type config schema") -def get_trigger_type_config( - trigger_type: TriggerType, - auth: Auth = Depends(auth_must) -) -> TriggerConfigSchemaOut: - """ - Get the configuration schema for a specific trigger type. - - This endpoint returns the JSON schema for the trigger type's config field, - which can be used by the frontend to dynamically render configuration forms. - """ - schema = get_config_schema(trigger_type) - - return TriggerConfigSchemaOut( - trigger_type=trigger_type.value, - has_config=has_config(trigger_type), - schema_=schema - ) \ No newline at end of file diff --git a/server/app/controller/user/login_controller.py b/server/app/controller/user/login_controller.py deleted file mode 100644 index 0d23b413..00000000 --- a/server/app/controller/user/login_controller.py +++ /dev/null @@ -1,226 +0,0 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= - -import logging - -from fastapi import APIRouter, Depends, Form, HTTPException -from fastapi_babel import _ -from sqlmodel import Session - -from app.component import code -from app.component.auth import Auth -from app.component.database import session -from app.component.encrypt import password_verify -from app.component.stack_auth import StackAuth -from app.exception.exception import UserException -from app.model.user.user import ( - LoginByPasswordIn, - LoginResponse, - RegisterIn, - Status, - User, -) - -logger = logging.getLogger("server_login_controller") - -router = APIRouter(tags=["Login/Registration"]) - - -@router.post("/login", name="login by email or password") -async def by_password(data: LoginByPasswordIn, session: Session = Depends(session)) -> LoginResponse: - """ - User login with email and password - """ - email = data.email - user = User.by(User.email == email, s=session).one_or_none() - - if not user: - logger.warning("Login failed: user not found", extra={"email": email}) - raise UserException(code.password, _("Account or password error")) - - if not password_verify(data.password, user.password): - logger.warning("Login failed: invalid password", extra={"user_id": user.id, "email": email}) - raise UserException(code.password, _("Account or password error")) - - logger.info("User login successful", extra={"user_id": user.id, "email": email}) - return LoginResponse(token=Auth.create_access_token(user.id), email=user.email) - - -@router.post("/dev_login", name="OAuth2 password flow login (for Swagger UI)") -async def dev_login( - username: str = Form(...), # OAuth2 uses 'username' but we accept email - password: str = Form(...), - session: Session = Depends(session), -) -> dict: - """ - OAuth2 password flow compatible login endpoint for Swagger UI. - This endpoint accepts form data (username/password) and returns an access token. - """ - # Use username as email (OAuth2 standard uses 'username' field) - email = username - user = User.by(User.email == email, s=session).one_or_none() - - if not user: - logger.warning("OAuth2 login failed: user not found", extra={"email": email}) - raise HTTPException(status_code=401, detail="Incorrect username or password") - - if not password_verify(password, user.password): - logger.warning( - "OAuth2 login failed: invalid password", - extra={"user_id": user.id, "email": email}, - ) - raise HTTPException(status_code=401, detail="Incorrect username or password") - - token = Auth.create_access_token(user.id) - logger.info("OAuth2 login successful", extra={"user_id": user.id, "email": email}) - - # Return OAuth2 compatible response - return {"access_token": token, "token_type": "bearer"} - - -@router.post("/login-by_stack", name="login by stack") -async def by_stack_auth( - token: str, - type: str = "signup", - invite_code: str | None = None, - session: Session = Depends(session), -): - try: - stack_id = await StackAuth.user_id(token) - info = await StackAuth.user_info(token) - except Exception as e: - logger.error("Stack auth failed", extra={"type": type, "error": str(e)}, exc_info=True) - raise HTTPException(500, detail=_("Authentication failed")) - - user = User.by(User.stack_id == stack_id, s=session).one_or_none() - - if not user: - if type != "signup": - logger.warning( - "Stack auth signup blocked: user not found", - extra={"stack_id": stack_id, "type": type}, - ) - raise UserException(code.error, _("User not found")) - - with session as s: - try: - user = User( - username=info["username"] if "username" in info else None, - nickname=info["display_name"], - email=info["primary_email"], - avatar=info["profile_image_url"], - stack_id=stack_id, - ) - s.add(user) - s.commit() - s.refresh(user) - logger.info( - "New user registered via stack", - extra={ - "user_id": user.id, - "email": user.email, - "stack_id": stack_id, - }, - ) - return LoginResponse(token=Auth.create_access_token(user.id), email=user.email) - except Exception as e: - s.rollback() - logger.error( - "Stack auth registration failed", - extra={"stack_id": stack_id, "error": str(e)}, - exc_info=True, - ) - raise UserException(code.error, _("Failed to register")) - else: - if user.status == Status.Block: - logger.warning( - "Blocked user login attempt", - extra={"user_id": user.id, "stack_id": stack_id}, - ) - raise UserException(code.error, _("Your account has been blocked.")) - - logger.info( - "User login via stack successful", - extra={"user_id": user.id, "email": user.email, "stack_id": stack_id}, - ) - return LoginResponse(token=Auth.create_access_token(user.id), email=user.email) - - -@router.post("/auto-login", name="auto login for local mode") -async def auto_login(session: Session = Depends(session)) -> LoginResponse: - """ - Auto login for fully local mode (VITE_USE_LOCAL_PROXY=true). - Returns the most recently active user, or creates a default admin user if none exists. - """ - # Find the most recently active user - user = User.by( - User.status == Status.Normal, - order_by=User.updated_at.desc(), - limit=1, - s=session, - ).one_or_none() - - if not user: - # Create default admin user - with session as s: - try: - user = User( - email="admin@eigent.local", - username="admin", - nickname="Admin", - ) - s.add(user) - s.commit() - s.refresh(user) - logger.info("Default admin user created", extra={"user_id": user.id}) - except Exception as e: - s.rollback() - logger.error("Failed to create default admin user", extra={"error": str(e)}, exc_info=True) - raise UserException(code.error, _("Failed to create default user")) - - logger.info("Auto login successful", extra={"user_id": user.id, "email": user.email}) - return LoginResponse(token=Auth.create_access_token(user.id), email=user.email) - - -@router.post("/register", name="register by email/password") -async def register(data: RegisterIn, session: Session = Depends(session)): - email = data.email - - if User.by(User.email == email, s=session).one_or_none(): - logger.warning("Registration failed: email already exists", extra={"email": email}) - raise UserException(code.error, _("Email already registered")) - - with session as s: - try: - user = User( - email=email, - password=data.password, - ) - s.add(user) - s.commit() - s.refresh(user) - logger.info( - "User registered successfully", - extra={"user_id": user.id, "email": email}, - ) - except Exception as e: - s.rollback() - logger.error( - "User registration failed", - extra={"email": email, "error": str(e)}, - exc_info=True, - ) - raise UserException(code.error, _("Failed to register")) - - return {"status": "success"} diff --git a/server/app/controller/user/user_controller.py b/server/app/controller/user/user_controller.py deleted file mode 100644 index 9f75d74b..00000000 --- a/server/app/controller/user/user_controller.py +++ /dev/null @@ -1,168 +0,0 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= - -import logging - -from fastapi import APIRouter, Depends -from sqlalchemy import func -from sqlmodel import Session, select - -from app.component.auth import Auth, auth_must -from app.component.database import session -from app.model.chat.chat_history import ChatHistory -from app.model.chat.chat_snpshot import ChatSnapshot -from app.model.config.config import Config -from app.model.mcp.mcp_user import McpUser -from app.model.user.privacy import UserPrivacy, UserPrivacySettings -from app.model.user.user import User, UserIn, UserOut, UserProfile -from app.model.user.user_credits_record import UserCreditsRecord -from app.model.user.user_stat import UserStat, UserStatActionIn, UserStatOut - -logger = logging.getLogger("server_user_controller") - -router = APIRouter(tags=["User"]) - - -@router.get("/user", name="user info", response_model=UserOut) -def get(auth: Auth = Depends(auth_must), session: Session = Depends(session)): - """Get current user information and refresh credits.""" - user: User = auth.user - user.refresh_credits_on_active(session) - logger.debug("User info retrieved", extra={"user_id": user.id}) - return user - - -@router.put("/user", name="update user info", response_model=UserOut) -def put(data: UserIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)): - """Update user basic information.""" - model = auth.user - model.username = data.username - model.save(session) - logger.info("User info updated", extra={"user_id": model.id, "username": data.username}) - return model - - -@router.put("/user/profile", name="update user profile", response_model=UserProfile) -def put_profile(data: UserProfile, session: Session = Depends(session), auth: Auth = Depends(auth_must)): - """Update user profile details.""" - model = auth.user - model.nickname = data.nickname - model.fullname = data.fullname - model.work_desc = data.work_desc - model.save(session) - logger.info("User profile updated", extra={"user_id": model.id, "nickname": data.nickname}) - return model - - -@router.get("/user/privacy", name="get user privacy") -def get_privacy(session: Session = Depends(session), auth: Auth = Depends(auth_must)): - """Get user privacy settings.""" - user_id = auth.user.id - stmt = select(UserPrivacy).where(UserPrivacy.user_id == user_id) - model = session.exec(stmt).one_or_none() - - if not model: - logger.debug("Privacy settings not found, returning defaults", extra={"user_id": user_id}) - return UserPrivacySettings.default_settings() - - logger.debug("Privacy settings retrieved", extra={"user_id": user_id}) - return UserPrivacySettings(**model.pricacy_setting).to_response() - - -@router.put("/user/privacy", name="update user privacy") -def put_privacy(data: UserPrivacySettings, session: Session = Depends(session), auth: Auth = Depends(auth_must)): - """Update user privacy settings.""" - user_id = auth.user.id - stmt = select(UserPrivacy).where(UserPrivacy.user_id == user_id) - model = session.exec(stmt).one_or_none() - default_settings = UserPrivacySettings.default_settings() - - if model: - model.pricacy_setting = {**model.pricacy_setting, **data.model_dump()} - model.save(session) - logger.info("Privacy settings updated", extra={"user_id": user_id}) - else: - model = UserPrivacy(user_id=user_id, pricacy_setting={**default_settings, **data.model_dump()}) - model.save(session) - logger.info("Privacy settings created", extra={"user_id": user_id}) - - return UserPrivacySettings(**model.pricacy_setting).to_response() - - -@router.get("/user/current_credits", name="get user current credits") -def get_user_credits(auth: Auth = Depends(auth_must), session: Session = Depends(session)): - """Get user's current credit balance.""" - user = auth.user - user.refresh_credits_on_active(session) - credits = user.credits - daily_credits: UserCreditsRecord | None = UserCreditsRecord.get_daily_balance(user.id) - current_daily_credits = 0 - if daily_credits: - current_daily_credits = daily_credits.amount - daily_credits.balance - credits += current_daily_credits if current_daily_credits > 0 else 0 - - logger.debug( - "Credits retrieved", - extra={"user_id": user.id, "total_credits": credits, "daily_credits": current_daily_credits}, - ) - return {"credits": credits, "daily_credits": current_daily_credits} - - -@router.get("/user/stat", name="get user stat", response_model=UserStatOut) -def get_user_stat(auth: Auth = Depends(auth_must), session: Session = Depends(session)): - """Get current user's operation statistics.""" - user_id = auth.user.id - stat = session.exec(select(UserStat).where(UserStat.user_id == user_id)).first() - data = UserStatOut() - - if stat: - data = UserStatOut(**stat.model_dump()) - else: - data = UserStatOut(user_id=user_id) - - data.task_queries = ChatHistory.count(ChatHistory.user_id == user_id, s=session) - mcp = McpUser.count(McpUser.user_id == user_id, s=session) - tool: list = session.exec( - select(func.count("*")).where(Config.user_id == user_id).group_by(Config.config_group) - ).all() - tool = tool.__len__() - data.mcp_install_count = mcp + tool - data.storage_used = ChatSnapshot.caclDir(ChatSnapshot.get_user_dir(user_id)) - - logger.debug( - "User stats retrieved", - extra={ - "user_id": user_id, - "task_queries": data.task_queries, - "mcp_install_count": data.mcp_install_count, - "storage_used": data.storage_used, - }, - ) - return data - - -@router.post("/user/stat", name="record user stat") -def record_user_stat( - data: UserStatActionIn, - auth: Auth = Depends(auth_must), - session: Session = Depends(session), -): - """Record or update current user's operation statistics.""" - data.user_id = auth.user.id - stat = UserStat.record_action(session, data) - logger.info( - "User stat recorded", - extra={"user_id": data.user_id, "action": data.action if hasattr(data, "action") else "unknown"}, - ) - return stat diff --git a/server/app/component/babel.py b/server/app/core/babel.py similarity index 98% rename from server/app/component/babel.py rename to server/app/core/babel.py index 0a14de79..6dea770c 100644 --- a/server/app/component/babel.py +++ b/server/app/core/babel.py @@ -1,15 +1,15 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= from pathlib import Path diff --git a/server/app/component/celery.py b/server/app/core/celery.py similarity index 85% rename from server/app/component/celery.py rename to server/app/core/celery.py index c5c28d4f..1bf1fd8e 100644 --- a/server/app/component/celery.py +++ b/server/app/core/celery.py @@ -13,7 +13,7 @@ # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= from celery import Celery -from app.component.environment import env_or_fail, env +from app.core.environment import env_or_fail, env celery = Celery( __name__, @@ -23,7 +23,7 @@ celery = Celery( # Configure Celery to autodiscover tasks celery.conf.imports = [ - "app.schedule.trigger_schedule_task", + "app.domains.trigger.service.trigger_schedule_task", ] # Configure Celery Beat schedule @@ -37,14 +37,14 @@ celery.conf.beat_schedule = {} if ENABLE_TRIGGER_SCHEDULE_POLLER: celery.conf.beat_schedule["poll-trigger-schedules"] = { - "task": "app.schedule.trigger_schedule_task.poll_trigger_schedules", + "task": "app.domains.trigger.service.trigger_schedule_task.poll_trigger_schedules", "schedule": TRIGGER_SCHEDULE_POLLER_INTERVAL * 60.0, # Convert minutes to seconds "options": {"queue": "poll_trigger_schedules"}, } if ENABLE_EXECUTION_TIMEOUT_CHECKER: celery.conf.beat_schedule["check-execution-timeouts"] = { - "task": "app.schedule.trigger_schedule_task.check_execution_timeouts", + "task": "app.domains.trigger.service.trigger_schedule_task.check_execution_timeouts", "schedule": EXECUTION_TIMEOUT_CHECKER_INTERVAL * 60.0, # Convert minutes to seconds "options": {"queue": "check_execution_timeouts"}, } diff --git a/server/app/component/code.py b/server/app/core/code.py similarity index 98% rename from server/app/component/code.py rename to server/app/core/code.py index ff54e19b..caa38c96 100644 --- a/server/app/component/code.py +++ b/server/app/core/code.py @@ -1,15 +1,15 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= success = 0 # success response code diff --git a/server/app/component/database.py b/server/app/core/database.py similarity index 96% rename from server/app/component/database.py rename to server/app/core/database.py index 6c68abfd..4eaf078d 100644 --- a/server/app/component/database.py +++ b/server/app/core/database.py @@ -1,22 +1,22 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= import logging from sqlmodel import Session, create_engine -from app.component.environment import env, env_or_fail +from app.core.environment import env, env_or_fail logger = logging.getLogger("database") diff --git a/server/app/component/encrypt.py b/server/app/core/encrypt.py similarity index 98% rename from server/app/component/encrypt.py rename to server/app/core/encrypt.py index 6e7ea536..cf82f386 100644 --- a/server/app/component/encrypt.py +++ b/server/app/core/encrypt.py @@ -1,15 +1,15 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= from passlib.context import CryptContext diff --git a/server/app/component/environment.py b/server/app/core/environment.py similarity index 99% rename from server/app/component/environment.py rename to server/app/core/environment.py index 5f03a54c..7fef56ce 100644 --- a/server/app/component/environment.py +++ b/server/app/core/environment.py @@ -1,15 +1,15 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= import importlib diff --git a/server/app/component/oauth_adapter.py b/server/app/core/oauth_adapter.py similarity index 99% rename from server/app/component/oauth_adapter.py rename to server/app/core/oauth_adapter.py index 9a5a946c..ab568816 100644 --- a/server/app/component/oauth_adapter.py +++ b/server/app/core/oauth_adapter.py @@ -20,7 +20,7 @@ from urllib.parse import urlencode import httpx from pydantic import BaseModel -from app.component.environment import env +from app.core.environment import env class OAuthAdapter(ABC): diff --git a/server/app/component/pydantic/i18n.py b/server/app/core/pydantic/i18n.py similarity index 97% rename from server/app/component/pydantic/i18n.py rename to server/app/core/pydantic/i18n.py index a59b1380..384a9664 100644 --- a/server/app/component/pydantic/i18n.py +++ b/server/app/core/pydantic/i18n.py @@ -1,15 +1,15 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= import os @@ -19,7 +19,7 @@ from pathlib import Path from fastapi_babel.middleware import LANGUAGES_PATTERN from pydantic_i18n import JsonLoader, PydanticI18n -from app.component.babel import babel, babel_configs +from app.core.babel import babel, babel_configs def get_language(lang_code: str | None = None): diff --git a/server/app/component/pydantic/translations/en_US.json b/server/app/core/pydantic/translations/en_US.json similarity index 100% rename from server/app/component/pydantic/translations/en_US.json rename to server/app/core/pydantic/translations/en_US.json diff --git a/server/app/component/pydantic/translations/zh_CN.json b/server/app/core/pydantic/translations/zh_CN.json similarity index 100% rename from server/app/component/pydantic/translations/zh_CN.json rename to server/app/core/pydantic/translations/zh_CN.json diff --git a/server/app/component/redis_utils.py b/server/app/core/redis_utils.py similarity index 100% rename from server/app/component/redis_utils.py rename to server/app/core/redis_utils.py diff --git a/server/app/component/sqids.py b/server/app/core/sqids.py similarity index 98% rename from server/app/component/sqids.py rename to server/app/core/sqids.py index 65bf1fe0..7ecc9ef2 100644 --- a/server/app/component/sqids.py +++ b/server/app/core/sqids.py @@ -1,15 +1,15 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= from sqids import Sqids diff --git a/server/app/component/trigger_utils.py b/server/app/core/trigger_utils.py similarity index 98% rename from server/app/component/trigger_utils.py rename to server/app/core/trigger_utils.py index b37cf98f..917559e1 100644 --- a/server/app/component/trigger_utils.py +++ b/server/app/core/trigger_utils.py @@ -19,7 +19,7 @@ import logging from sqlmodel import select, and_ from app.model.trigger.trigger_execution import TriggerExecution -from app.component.environment import env +from app.core.environment import env logger = logging.getLogger("server_trigger_utils") diff --git a/server/app/component/validator/McpServer.py b/server/app/core/validator/McpServer.py similarity index 99% rename from server/app/component/validator/McpServer.py rename to server/app/core/validator/McpServer.py index 70d6be17..641e3a87 100644 --- a/server/app/component/validator/McpServer.py +++ b/server/app/core/validator/McpServer.py @@ -1,15 +1,15 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= diff --git a/server/app/domains/__init__.py b/server/app/domains/__init__.py new file mode 100644 index 00000000..a019e0e1 --- /dev/null +++ b/server/app/domains/__init__.py @@ -0,0 +1,15 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""v1 domain modules.""" diff --git a/server/app/domains/chat/__init__.py b/server/app/domains/chat/__init__.py new file mode 100644 index 00000000..435cdb70 --- /dev/null +++ b/server/app/domains/chat/__init__.py @@ -0,0 +1,15 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""Chat domain: history, steps, snapshots, files, share, logs.""" diff --git a/server/app/controller/trigger/__init__.py b/server/app/domains/chat/api/__init__.py similarity index 100% rename from server/app/controller/trigger/__init__.py rename to server/app/domains/chat/api/__init__.py diff --git a/server/app/domains/chat/api/history_controller.py b/server/app/domains/chat/api/history_controller.py new file mode 100644 index 00000000..f2e0c053 --- /dev/null +++ b/server/app/domains/chat/api/history_controller.py @@ -0,0 +1,158 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""Chat History controller. Uses ChatService for grouping.""" + +from typing import Optional + +from fastapi import APIRouter, Depends, HTTPException, Query, Response +from fastapi_pagination import Page +from fastapi_pagination.ext.sqlmodel import paginate +from loguru import logger +from sqlmodel import Session, case, delete, desc, func, select +from fastapi_babel import _ + +from app.core.database import session +from app.model.chat.chat_history import ChatHistory, ChatHistoryIn, ChatHistoryOut, ChatHistoryUpdate, ChatStatus +from app.model.chat.chat_history_grouped import GroupedHistoryResponse, ProjectGroup +from app.model.trigger.trigger import Trigger +from app.model.trigger.trigger_execution import TriggerExecution +from app.model.user.key import Key +from app.shared.auth import auth_must +from app.shared.auth.user_auth import V1UserAuth +from app.domains.chat.service.chat_service import ChatService + +router = APIRouter(prefix="/chat", tags=["Chat History"]) + + +@router.post("/history", name="save chat history", response_model=ChatHistoryOut) +def create_chat_history(data: ChatHistoryIn, db_session: Session = Depends(session), auth: V1UserAuth = Depends(auth_must)): + data.user_id = auth.id + chat_history = ChatHistory(**data.model_dump()) + db_session.add(chat_history) + db_session.commit() + db_session.refresh(chat_history) + return chat_history + + +@router.get("/histories", name="get chat history") +def list_chat_history(db_session: Session = Depends(session), auth: V1UserAuth = Depends(auth_must)) -> Page[ChatHistoryOut]: + stmt = ( + select(ChatHistory) + .where(ChatHistory.user_id == auth.id) + .order_by( + desc(case((ChatHistory.created_at.is_(None), 0), else_=1)), + desc(ChatHistory.created_at), + desc(ChatHistory.id), + ) + ) + return paginate(db_session, stmt) + + +@router.get("/histories/grouped", name="get grouped chat history") +def list_grouped_chat_history( + include_tasks: Optional[bool] = Query(True, description="Whether to include individual tasks in groups"), + db_session: Session = Depends(session), + auth: V1UserAuth = Depends(auth_must), +) -> GroupedHistoryResponse: + return ChatService.get_grouped_histories(auth.id, include_tasks, db_session) + + +@router.get("/histories/grouped/{project_id}", name="get single grouped project") +def get_grouped_project( + project_id: str, + include_tasks: Optional[bool] = Query(True, description="Whether to include individual tasks in the project"), + db_session: Session = Depends(session), + auth: V1UserAuth = Depends(auth_must), +) -> ProjectGroup: + result = ChatService.get_grouped_project(auth.id, project_id, include_tasks, db_session) + if result is None: + raise HTTPException(status_code=404, detail="Project not found") + return result + + +@router.delete("/history/{history_id}", name="delete chat history") +def delete_chat_history(history_id: int, db_session: Session = Depends(session), auth: V1UserAuth = Depends(auth_must)): + history = db_session.exec(select(ChatHistory).where(ChatHistory.id == history_id)).first() + if not history: + raise HTTPException(status_code=404, detail="Chat History not found") + if history.user_id != auth.id: + raise HTTPException(status_code=403, detail="You are not allowed to delete this chat history") + + project_id = history.project_id if history.project_id else history.task_id + + sibling_count = ( + db_session.exec( + select(func.count(ChatHistory.id)).where( + ChatHistory.id != history_id, + ChatHistory.project_id == project_id if history.project_id else ChatHistory.task_id == project_id, + ) + ).first() + or 0 + ) + + db_session.delete(history) + + if sibling_count == 0: + triggers = db_session.exec(select(Trigger).where(Trigger.project_id == project_id)).all() + for trigger in triggers: + db_session.exec(delete(TriggerExecution).where(TriggerExecution.trigger_id == trigger.id)) + db_session.delete(trigger) + logger.info( + "Deleted triggers for removed project", extra={"project_id": project_id, "trigger_count": len(triggers)} + ) + + db_session.commit() + return Response(status_code=204) + + +@router.put("/history/{history_id}", name="update chat history", response_model=ChatHistoryOut) +async def update_chat_history( + history_id: int, data: ChatHistoryUpdate, db_session: Session = Depends(session), auth: V1UserAuth = Depends(auth_must) +): + history = db_session.exec(select(ChatHistory).where(ChatHistory.id == history_id)).first() + if not history: + raise HTTPException(status_code=404, detail="Chat History not found") + if history.user_id != auth.id: + raise HTTPException(status_code=403, detail="You are not allowed to update this chat history") + + update_data = data.model_dump(exclude_unset=True) + history.update_fields(update_data) + history.save(db_session) + + db_session.refresh(history) + return history + + +@router.put("/project/{project_id}/name", name="update project name") +def update_project_name( + project_id: str, new_name: str, db_session: Session = Depends(session), auth: V1UserAuth = Depends(auth_must) +): + user_id = auth.id + stmt = select(ChatHistory).where(ChatHistory.project_id == project_id).where(ChatHistory.user_id == user_id) + histories = db_session.exec(stmt).all() + + if not histories: + raise HTTPException(status_code=404, detail="Project not found or access denied") + + try: + for history in histories: + history.project_name = new_name + db_session.add(history) + db_session.commit() + return Response(status_code=200) + except Exception as e: + db_session.rollback() + logger.error("Project name update failed", extra={"user_id": user_id, "project_id": project_id, "error": str(e)}) + raise HTTPException(status_code=500, detail="Internal server error") diff --git a/server/app/domains/chat/api/share_controller.py b/server/app/domains/chat/api/share_controller.py new file mode 100644 index 00000000..674c7222 --- /dev/null +++ b/server/app/domains/chat/api/share_controller.py @@ -0,0 +1,91 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""Chat Share controller with auth and task ownership on create.""" + +import asyncio +import json + +from fastapi import APIRouter, Depends, HTTPException +from sqlmodel import Session, select +from starlette.responses import StreamingResponse + +from app.core.database import session +from app.model.chat.chat_share import ChatHistoryShareOut, ChatShare, ChatShareIn +from app.model.chat.chat_step import ChatStep +from app.model.chat.chat_history import ChatHistory + +from app.shared.auth import auth_must +from app.domains.chat.service import ChatService +from app.domains.chat.schema import TaskOwnershipCheckReq +from itsdangerous import BadTimeSignature, SignatureExpired + +router = APIRouter(prefix="/chat", tags=["V1 Chat Share"]) + + +@router.get("/share/info/{token}", name="Get shared chat info", response_model=ChatHistoryShareOut) +def get_share_info(token: str, db_session: Session = Depends(session)): + try: + task_id = ChatShare.verify_token(token, False) + except (SignatureExpired, BadTimeSignature): + raise HTTPException(status_code=400, detail="Share link is invalid or has expired.") + stmt = select(ChatHistory).where(ChatHistory.task_id == task_id) + history = db_session.exec(stmt).one_or_none() + if not history: + raise HTTPException(status_code=404, detail="Chat history not found.") + return history + + +@router.get("/share/playback/{token}", name="Playback shared chat via SSE") +async def share_playback(token: str, db_session: Session = Depends(session), delay_time: float = 0): + if delay_time > 5: + delay_time = 5 + try: + task_id = ChatShare.verify_token(token, False) + except SignatureExpired: + raise HTTPException(status_code=400, detail="Share link has expired.") + except BadTimeSignature: + raise HTTPException(status_code=400, detail="Share link is invalid.") + + async def event_generator(): + stmt = select(ChatStep).where(ChatStep.task_id == task_id).order_by(ChatStep.id) + steps = db_session.exec(stmt).all() + if not steps: + yield f"data: {json.dumps({'error': 'No steps found for this task.'})}\n\n" + return + for s in steps: + step_data = { + "id": s.id, + "task_id": s.task_id, + "step": s.step, + "data": s.data, + "created_at": s.created_at.isoformat() if s.created_at else None, + } + yield f"data: {json.dumps(step_data)}\n\n" + if delay_time > 0 and s.step != "create_agent": + await asyncio.sleep(delay_time) + + return StreamingResponse(event_generator(), media_type="text/event-stream") + + +@router.post("/share", name="Generate sharable link for a task(1 day expiration)") +def create_share_link( + data: ChatShareIn, + db_session: Session = Depends(session), + auth=Depends(auth_must), +): + if not ChatService.verify_task_ownership(TaskOwnershipCheckReq(task_id=data.task_id, user_id=auth.user.id)): + raise HTTPException(status_code=403, detail="Task not found or access denied.") + share_token = ChatShare.generate_token(data.task_id) + return {"share_token": share_token} diff --git a/server/app/domains/chat/api/snapshot_controller.py b/server/app/domains/chat/api/snapshot_controller.py new file mode 100644 index 00000000..a4900ab6 --- /dev/null +++ b/server/app/domains/chat/api/snapshot_controller.py @@ -0,0 +1,122 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""v1 Chat Snapshot - H3 auth, H4 ownership, H19 path traversal, P2 Update model. +STATUS: full-rewrite (security: H3, H4, H19, P2 Update model) +""" + +import re +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Response +from fastapi_babel import _ +from sqlmodel import Session, select + +from app.core.database import session +from app.model.chat.chat_snpshot import ChatSnapshot, ChatSnapshotIn, ChatSnapshotUpdate + +from app.shared.auth import auth_must +from app.shared.auth.ownership import require_owner + +router = APIRouter(prefix="/chat", tags=["V1 Chat Snapshot"]) + +# H19: api_task_id must be safe for path - only alphanumeric, dash, underscore +API_TASK_ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,128}$") + + +def _validate_api_task_id(value: str) -> None: + """Reject path traversal: api_task_id must match safe charset.""" + if not value or not API_TASK_ID_PATTERN.match(value): + raise HTTPException(status_code=400, detail=_("Invalid api_task_id: only letters, numbers, - and _ allowed")) + +@router.get("/snapshots", name="list chat snapshots", response_model=List[ChatSnapshot]) +async def list_chat_snapshots( + api_task_id: Optional[str] = None, + camel_task_id: Optional[str] = None, + browser_url: Optional[str] = None, + db_session: Session = Depends(session), + auth=Depends(auth_must), +): + query = select(ChatSnapshot).where(ChatSnapshot.user_id == auth.user.id) + if api_task_id is not None: + query = query.where(ChatSnapshot.api_task_id == api_task_id) + if camel_task_id is not None: + query = query.where(ChatSnapshot.camel_task_id == camel_task_id) + if browser_url is not None: + query = query.where(ChatSnapshot.browser_url == browser_url) + return list(db_session.exec(query).all()) + + +@router.get("/snapshots/{snapshot_id}", name="get chat snapshot", response_model=ChatSnapshot) +async def get_chat_snapshot( + snapshot_id: int, + db_session: Session = Depends(session), + auth=Depends(auth_must), +): + snapshot = db_session.get(ChatSnapshot, snapshot_id) + require_owner(snapshot, auth.user.id) + return snapshot + + +@router.post("/snapshots", name="create chat snapshot", response_model=ChatSnapshot) +async def create_chat_snapshot( + snapshot: ChatSnapshotIn, + db_session: Session = Depends(session), + auth=Depends(auth_must), +): + _validate_api_task_id(snapshot.api_task_id) + image_path = ChatSnapshotIn.save_image(auth.user.id, snapshot.api_task_id, snapshot.image_base64) + chat_snapshot = ChatSnapshot( + user_id=auth.user.id, + api_task_id=snapshot.api_task_id, + camel_task_id=snapshot.camel_task_id, + browser_url=snapshot.browser_url, + image_path=image_path, + ) + db_session.add(chat_snapshot) + db_session.commit() + db_session.refresh(chat_snapshot) + return chat_snapshot + + +@router.put("/snapshots/{snapshot_id}", name="update chat snapshot", response_model=ChatSnapshot) +async def update_chat_snapshot( + snapshot_id: int, + snapshot_update: ChatSnapshotUpdate, + db_session: Session = Depends(session), + auth=Depends(auth_must), +): + db_snapshot = db_session.get(ChatSnapshot, snapshot_id) + require_owner(db_snapshot, auth.user.id) + for key, value in snapshot_update.model_dump(exclude_unset=True).items(): + if key == "api_task_id" and value is not None: + _validate_api_task_id(str(value)) + setattr(db_snapshot, key, value) + db_session.add(db_snapshot) + db_session.commit() + db_session.refresh(db_snapshot) + return db_snapshot + + +@router.delete("/snapshots/{snapshot_id}", name="delete chat snapshot") +async def delete_chat_snapshot( + snapshot_id: int, + db_session: Session = Depends(session), + auth=Depends(auth_must), +): + db_snapshot = db_session.get(ChatSnapshot, snapshot_id) + require_owner(db_snapshot, auth.user.id) + db_session.delete(db_snapshot) + db_session.commit() + return Response(status_code=204) diff --git a/server/app/domains/chat/api/step_controller.py b/server/app/domains/chat/api/step_controller.py new file mode 100644 index 00000000..d14c84ad --- /dev/null +++ b/server/app/domains/chat/api/step_controller.py @@ -0,0 +1,162 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""v1 Chat Step - H1 auth, H2 ownership, P2 dedicated Update model. +STATUS: full-rewrite (security: H1, H2, P2 Update model) +""" + +import asyncio +import json +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Response +from fastapi.responses import StreamingResponse +from sqlalchemy.sql.expression import case +from sqlmodel import Session, asc, select + +from app.core.database import session +from app.model.chat.chat_step import ChatStep, ChatStepOut, ChatStepIn, ChatStepUpdate +from app.model.chat.chat_history import ChatHistory + +from app.shared.auth import auth_must +from app.domains.chat.service import ChatService +from app.domains.chat.schema import TaskOwnershipCheckReq + +router = APIRouter(prefix="/chat", tags=["V1 Chat Step"]) + + +def _task_owned_by_user(db: Session, task_id: str, user_id: int) -> bool: + return ChatService.verify_task_ownership(TaskOwnershipCheckReq(task_id=task_id, user_id=user_id)) + + +@router.get("/steps", name="list chat steps", response_model=List[ChatStepOut]) +async def list_chat_steps( + task_id: str, + step: Optional[str] = None, + db_session: Session = Depends(session), + auth=Depends(auth_must), +): + if not _task_owned_by_user(db_session, task_id, auth.user.id): + return [] + query = select(ChatStep).where(ChatStep.task_id == task_id) + if step is not None: + query = query.where(ChatStep.step == step) + return list(db_session.exec(query).all()) + + +@router.get("/steps/playback/{task_id}", name="Playback Chat Step via SSE") +async def share_playback( + task_id: str, + delay_time: float = 0, + db_session: Session = Depends(session), + auth=Depends(auth_must), +): + if delay_time > 5: + delay_time = 5 + if not _task_owned_by_user(db_session, task_id, auth.user.id): + raise HTTPException(status_code=404, detail="Task not found") + + async def event_generator(): + stmt = select(ChatStep).where(ChatStep.task_id == task_id).order_by( + asc(case((ChatStep.timestamp.is_(None), 1), else_=0)), + asc(ChatStep.timestamp), + asc(ChatStep.id), + ) + steps = db_session.exec(stmt).all() + if not steps: + yield f"data: {json.dumps({'error': 'No steps found for this task.'})}\n\n" + return + for s in steps: + step_data = { + "id": s.id, + "task_id": s.task_id, + "step": s.step, + "data": s.data, + "created_at": s.created_at.isoformat() if s.created_at else None, + } + yield f"data: {json.dumps(step_data)}\n\n" + if delay_time > 0: + await asyncio.sleep(delay_time) + + return StreamingResponse(event_generator(), media_type="text/event-stream") + + +@router.get("/steps/{step_id}", name="get chat step", response_model=ChatStepOut) +async def get_chat_step( + step_id: int, + db_session: Session = Depends(session), + auth=Depends(auth_must), +): + chat_step = db_session.get(ChatStep, step_id) + if not chat_step: + raise HTTPException(status_code=404, detail="Chat step not found") + if not _task_owned_by_user(db_session, chat_step.task_id, auth.user.id): + raise HTTPException(status_code=404, detail="Chat step not found") + return chat_step + + +@router.post("/steps", name="create chat step") +async def create_chat_step( + step: ChatStepIn, + db_session: Session = Depends(session), + auth=Depends(auth_must), +): + if not _task_owned_by_user(db_session, step.task_id, auth.user.id): + raise HTTPException(status_code=403, detail="Task not found or access denied") + chat_step = ChatStep( + task_id=step.task_id, + step=step.step, + data=step.data, + timestamp=step.timestamp, + ) + db_session.add(chat_step) + db_session.commit() + db_session.refresh(chat_step) + return {"code": 200, "msg": "success"} + + +@router.put("/steps/{step_id}", name="update chat step", response_model=ChatStepOut) +async def update_chat_step( + step_id: int, + chat_step_update: ChatStepUpdate, + db_session: Session = Depends(session), + auth=Depends(auth_must), +): + db_chat_step = db_session.get(ChatStep, step_id) + if not db_chat_step: + raise HTTPException(status_code=404, detail="Chat step not found") + if not _task_owned_by_user(db_session, db_chat_step.task_id, auth.user.id): + raise HTTPException(status_code=404, detail="Chat step not found") + for key, value in chat_step_update.model_dump(exclude_unset=True).items(): + setattr(db_chat_step, key, value) + db_session.add(db_chat_step) + db_session.commit() + db_session.refresh(db_chat_step) + return db_chat_step + + +@router.delete("/steps/{step_id}", name="delete chat step") +async def delete_chat_step( + step_id: int, + db_session: Session = Depends(session), + auth=Depends(auth_must), +): + db_chat_step = db_session.get(ChatStep, step_id) + if not db_chat_step: + raise HTTPException(status_code=404, detail="Chat step not found") + if not _task_owned_by_user(db_session, db_chat_step.task_id, auth.user.id): + raise HTTPException(status_code=404, detail="Chat step not found") + db_session.delete(db_chat_step) + db_session.commit() + return Response(status_code=204) diff --git a/server/app/domains/chat/schema/__init__.py b/server/app/domains/chat/schema/__init__.py new file mode 100644 index 00000000..7908ed90 --- /dev/null +++ b/server/app/domains/chat/schema/__init__.py @@ -0,0 +1,27 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""Chat domain schemas.""" + +from app.domains.chat.schema.schemas import ( + TaskOwnershipCheckReq, + FileValidationReq, + FileValidationResult, +) + +__all__ = [ + "TaskOwnershipCheckReq", + "FileValidationReq", + "FileValidationResult", +] diff --git a/server/app/controller/health_controller.py b/server/app/domains/chat/schema/schemas.py similarity index 66% rename from server/app/controller/health_controller.py rename to server/app/domains/chat/schema/schemas.py index 9c5fc9ff..351dc458 100644 --- a/server/app/controller/health_controller.py +++ b/server/app/domains/chat/schema/schemas.py @@ -12,18 +12,21 @@ # limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -from fastapi import APIRouter +"""v1 ChatService request/response schemas.""" + from pydantic import BaseModel -router = APIRouter(tags=["Health"]) + +class TaskOwnershipCheckReq(BaseModel): + task_id: str + user_id: int -class HealthResponse(BaseModel): - status: str - service: str +class FileValidationReq(BaseModel): + filename: str + file_size: int -@router.get("/health", name="health check", response_model=HealthResponse) -async def health_check(): - """Health check endpoint for monitoring and container orchestration.""" - return HealthResponse(status="ok", service="eigent-server") +class FileValidationResult(BaseModel): + valid: bool + error: str | None = None diff --git a/server/app/middleware/__init__.py b/server/app/domains/chat/service/__init__.py similarity index 79% rename from server/app/middleware/__init__.py rename to server/app/domains/chat/service/__init__.py index b6527261..d65c1575 100644 --- a/server/app/middleware/__init__.py +++ b/server/app/domains/chat/service/__init__.py @@ -1,20 +1,19 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -from fastapi_babel import BabelMiddleware +"""Chat domain services.""" -from app import api -from app.component.babel import babel_configs +from app.domains.chat.service.chat_service import ChatService -api.add_middleware(BabelMiddleware, babel_configs=babel_configs) +__all__ = ["ChatService"] diff --git a/server/app/domains/chat/service/chat_service.py b/server/app/domains/chat/service/chat_service.py new file mode 100644 index 00000000..e16c1819 --- /dev/null +++ b/server/app/domains/chat/service/chat_service.py @@ -0,0 +1,242 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""ChatService: task ownership, file validation, history grouping. No billing in eigent.""" + +from collections import defaultdict +from datetime import datetime +from typing import Dict, List + +from loguru import logger +from sqlmodel import Session, case, desc, func, select + +from app.core.database import session_make +from app.model.chat.chat_history import ChatHistory, ChatHistoryOut, ChatStatus +from app.model.chat.chat_history_grouped import GroupedHistoryResponse, ProjectGroup +from app.model.trigger.trigger import Trigger +from app.domains.chat.schema import TaskOwnershipCheckReq, FileValidationReq, FileValidationResult + +ALLOWED_EXTENSIONS = { + "jpg", "jpeg", "png", "gif", "webp", + "pdf", "txt", "md", "csv", + "json", "xml", "yaml", "yml", + "doc", "docx", "xls", "xlsx", + "zip", +} +MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB + + +class ChatService: + """Chat domain business logic - static methods, self-managed session.""" + + @staticmethod + def verify_task_ownership(req: TaskOwnershipCheckReq) -> bool: + """Check if task_id belongs to user_id.""" + with session_make() as s: + h = s.exec( + select(ChatHistory) + .where(ChatHistory.task_id == req.task_id, ChatHistory.user_id == req.user_id) + ).first() + return h is not None + + @staticmethod + def validate_file(req: FileValidationReq) -> FileValidationResult: + """Validate filename extension and file size.""" + if not req.filename: + return FileValidationResult(valid=False, error="Filename is required") + ext = req.filename.rsplit(".", 1)[-1].lower() if "." in req.filename else "" + if ext not in ALLOWED_EXTENSIONS: + return FileValidationResult( + valid=False, + error=f"File type '.{ext}' is not allowed. Allowed: {', '.join(sorted(ALLOWED_EXTENSIONS))}", + ) + if req.file_size > MAX_FILE_SIZE: + return FileValidationResult( + valid=False, + error=f"File size ({req.file_size} bytes) exceeds max {MAX_FILE_SIZE // (1024 * 1024)} MB", + ) + return FileValidationResult(valid=True) + + @staticmethod + async def reconcile_if_needed( + history: ChatHistory, status_changed_to_done: bool, trace_id: str | None = None + ) -> None: + """No-op: eigent does not have billing/credits.""" + return None + + @staticmethod + def upload_file( + user_id: int, task_id: str, filename: str, file_content: bytes, file_type: str | None, s: Session + ) -> "ChatFile": + """Validate, upload to S3, and create ChatFile record. Caller must commit.""" + from app.model.chat.chat_file import ChatFile, ChatFileIn + + validation = ChatService.validate_file(FileValidationReq(filename=filename, file_size=len(file_content))) + if not validation.valid: + raise ValueError(validation.error or "Invalid file") + + file_info = ChatFileIn.save_file_to_s3( + user_id=user_id, + task_id=task_id, + filename=filename, + file_content=file_content, + file_type=file_type, + ) + chat_file = ChatFile( + user_id=user_id, + task_id=task_id, + filename=file_info["filename"], + file_size=file_info["file_size"], + file_type=file_info["file_type"], + s3_key=file_info["s3_key"], + s3_bucket=file_info["s3_bucket"], + ) + s.add(chat_file) + s.commit() + s.refresh(chat_file) + return chat_file + + @staticmethod + def is_real_task(history: ChatHistory) -> bool: + """Check if a task is a real task vs a placeholder/trigger-created task.""" + if history.spend and history.spend > 0: + return True + if history.tokens and history.tokens > 0: + return True + if ( + history.model_platform + and history.model_platform != "none" + and history.model_type + and history.model_type != "none" + and history.installed_mcp + and history.installed_mcp != "none" + ): + return True + if history.question and history.question.startswith("Project created via trigger:"): + return False + return True + + @staticmethod + def _build_project_data( + histories: list[ChatHistory], + trigger_count_map: dict[str, int], + include_tasks: bool, + project_id_override: str | None = None, + ) -> list[ProjectGroup]: + """Build ProjectGroup list from histories. Shared by grouped list and single project endpoints.""" + project_map: Dict[str, Dict] = defaultdict( + lambda: { + "project_id": "", + "project_name": None, + "total_tokens": 0, + "task_count": 0, + "latest_task_date": "", + "last_prompt": None, + "tasks": [], + "total_completed_tasks": 0, + "total_ongoing_tasks": 0, + "average_tokens_per_task": 0, + "total_triggers": 0, + } + ) + + for history in histories: + project_id = project_id_override or (history.project_id if history.project_id else history.task_id) + project_data = project_map[project_id] + + if not project_data["project_id"]: + project_data["project_id"] = project_id + project_data["project_name"] = history.project_name or f"Project {project_id}" + project_data["latest_task_date"] = history.created_at.isoformat() if history.created_at else "" + project_data["last_prompt"] = history.question + + if include_tasks and ChatService.is_real_task(history): + project_data["tasks"].append(ChatHistoryOut(**history.model_dump())) + + if ChatService.is_real_task(history): + project_data["task_count"] += 1 + project_data["total_tokens"] += history.tokens or 0 + + if history.status == ChatStatus.done: + project_data["total_completed_tasks"] += 1 + elif history.status == ChatStatus.ongoing: + project_data["total_ongoing_tasks"] += 1 + + if history.created_at: + task_date = history.created_at.isoformat() + if not project_data["latest_task_date"] or task_date > project_data["latest_task_date"]: + project_data["latest_task_date"] = task_date + project_data["last_prompt"] = history.question + + projects = [] + for project_data in project_map.values(): + if include_tasks: + project_data["tasks"].sort(key=lambda x: (x.created_at is None, x.created_at or ""), reverse=False) + pid = project_data["project_id"] + project_data["total_triggers"] = trigger_count_map.get(pid, 0) + projects.append(ProjectGroup(**project_data)) + + projects.sort(key=lambda x: x.latest_task_date, reverse=True) + return projects + + @staticmethod + def get_grouped_histories(user_id: int, include_tasks: bool, s: Session) -> GroupedHistoryResponse: + """Get all chat histories grouped by project for a user.""" + stmt = ( + select(ChatHistory) + .where(ChatHistory.user_id == user_id) + .order_by( + desc(case((ChatHistory.created_at.is_(None), 0), else_=1)), + desc(ChatHistory.created_at), + desc(ChatHistory.id), + ) + ) + histories = s.exec(stmt).all() + + trigger_count_stmt = ( + select(Trigger.project_id, func.count(Trigger.id).label("count")) + .where(Trigger.user_id == str(user_id)) + .group_by(Trigger.project_id) + ) + trigger_counts = s.exec(trigger_count_stmt).all() + trigger_count_map = {project_id: count for project_id, count in trigger_counts} + + projects = ChatService._build_project_data(histories, trigger_count_map, include_tasks) + return GroupedHistoryResponse(projects=projects) + + @staticmethod + def get_grouped_project(user_id: int, project_id: str, include_tasks: bool, s: Session) -> ProjectGroup | None: + """Get a single project group by project_id.""" + stmt = ( + select(ChatHistory) + .where(ChatHistory.user_id == user_id) + .where(ChatHistory.project_id == project_id) + .order_by( + desc(case((ChatHistory.created_at.is_(None), 0), else_=1)), + desc(ChatHistory.created_at), + desc(ChatHistory.id), + ) + ) + histories = s.exec(stmt).all() + if not histories: + return None + + trigger_count_stmt = ( + select(func.count(Trigger.id)).where(Trigger.user_id == str(user_id)).where(Trigger.project_id == project_id) + ) + trigger_count = s.exec(trigger_count_stmt).first() or 0 + trigger_count_map = {project_id: trigger_count} + + projects = ChatService._build_project_data(histories, trigger_count_map, include_tasks, project_id_override=project_id) + return projects[0] if projects else None diff --git a/server/app/domains/config/__init__.py b/server/app/domains/config/__init__.py new file mode 100644 index 00000000..3ebb6a9d --- /dev/null +++ b/server/app/domains/config/__init__.py @@ -0,0 +1,15 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""Config domain: plans, config, providers.""" diff --git a/server/app/controller/__init__.py b/server/app/domains/config/api/__init__.py similarity index 99% rename from server/app/controller/__init__.py rename to server/app/domains/config/api/__init__.py index fa7455a0..3a4d90c0 100644 --- a/server/app/controller/__init__.py +++ b/server/app/domains/config/api/__init__.py @@ -11,3 +11,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + diff --git a/server/app/domains/config/api/config_controller.py b/server/app/domains/config/api/config_controller.py new file mode 100644 index 00000000..90bdf7b0 --- /dev/null +++ b/server/app/domains/config/api/config_controller.py @@ -0,0 +1,93 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +# STATUS: full-rewrite (uses ConfigService, self-managed session) +from typing import Optional +from fastapi import APIRouter, Depends, HTTPException, Query, Response +from fastapi_babel import _ + +from app.model.config.config import ConfigCreate, ConfigUpdate, ConfigOut +from app.shared.auth import auth_must +from app.shared.auth.user_auth import V1UserAuth +from app.domains.config.service.config_service import ConfigService + +router = APIRouter(tags=["Config Management"]) + + +@router.get("/configs", name="list configs", response_model=list[ConfigOut]) +async def list_configs( + config_group: Optional[str] = None, + auth: V1UserAuth = Depends(auth_must), +): + return ConfigService.list_for_user(auth.id, config_group) + + +@router.get("/configs/{config_id}", name="get config", response_model=ConfigOut) +async def get_config(config_id: int, auth: V1UserAuth = Depends(auth_must)): + config = ConfigService.get(config_id, auth.id) + if not config: + raise HTTPException(status_code=404, detail=_("Configuration not found")) + return config + + +@router.post("/configs", name="create config", response_model=ConfigOut) +async def create_config(config: ConfigCreate, auth: V1UserAuth = Depends(auth_must)): + result = ConfigService.create( + user_id=auth.id, + config_name=config.config_name, + config_value=config.config_value, + config_group=config.config_group, + ) + if not result["success"]: + error_map = { + "CONFIG_INVALID_NAME": (400, _("Config Name is invalid")), + "CONFIG_DUPLICATE": (400, _("Configuration already exists for this user")), + } + status, detail = error_map.get(result["error_code"], (500, "Config error")) + raise HTTPException(status_code=status, detail=detail) + return result["config"] + + +@router.put("/configs/{config_id}", name="update config", response_model=ConfigOut) +async def update_config(config_id: int, config_update: ConfigUpdate, auth: V1UserAuth = Depends(auth_must)): + result = ConfigService.update( + config_id=config_id, + user_id=auth.id, + config_name=config_update.config_name, + config_value=config_update.config_value, + config_group=config_update.config_group, + ) + if not result["success"]: + error_map = { + "CONFIG_NOT_FOUND": (404, _("Configuration not found")), + "CONFIG_INVALID_NAME": (400, _("Invalid configuration group")), + "CONFIG_DUPLICATE": (400, _("Configuration already exists for this user")), + } + status, detail = error_map.get(result["error_code"], (500, "Config error")) + raise HTTPException(status_code=status, detail=detail) + return result["config"] + + +@router.delete("/configs/{config_id}", name="delete config") +async def delete_config(config_id: int, auth: V1UserAuth = Depends(auth_must)): + if not ConfigService.delete(config_id, auth.id): + raise HTTPException(status_code=404, detail=_("Configuration not found")) + return Response(status_code=204) + + +@router.get("/config/info", name="get config info") +async def get_config_info( + show_all: bool = Query(False, description="Show all config info, including those with empty env_vars"), +): + return ConfigService.get_config_info(show_all) diff --git a/server/app/domains/config/schema/__init__.py b/server/app/domains/config/schema/__init__.py new file mode 100644 index 00000000..3a4d90c0 --- /dev/null +++ b/server/app/domains/config/schema/__init__.py @@ -0,0 +1,14 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + diff --git a/server/app/domains/config/service/__init__.py b/server/app/domains/config/service/__init__.py new file mode 100644 index 00000000..ac746c6b --- /dev/null +++ b/server/app/domains/config/service/__init__.py @@ -0,0 +1,17 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +from .config_service import ConfigService + +__all__ = ["ConfigService"] diff --git a/server/app/domains/config/service/config_service.py b/server/app/domains/config/service/config_service.py new file mode 100644 index 00000000..06fd791d --- /dev/null +++ b/server/app/domains/config/service/config_service.py @@ -0,0 +1,120 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""ConfigService: user config CRUD with validation. Follows CreditsService pattern.""" + +from sqlmodel import select +from loguru import logger + +from app.core.database import session_make +from app.model.config.config import Config, ConfigInfo + + +class ConfigService: + """User configuration management - static methods, self-managed session.""" + + @staticmethod + def list_for_user(user_id: int, config_group: str | None = None) -> list[Config]: + """List user configs, optionally filtered by group.""" + with session_make() as s: + query = select(Config).where(Config.user_id == user_id) + if config_group is not None: + query = query.where(Config.config_group == config_group) + return list(s.exec(query).all()) + + @staticmethod + def get(config_id: int, user_id: int) -> Config | None: + """Get a single config by id, scoped to user.""" + with session_make() as s: + return s.exec( + select(Config).where(Config.id == config_id, Config.user_id == user_id) + ).first() + + @staticmethod + def create(user_id: int, config_name: str, config_value: str, config_group: str | None = None) -> dict: + """Create config: env var name validation + duplicate check. + Returns {"success": True, "config": Config} or {"success": False, "error_code": str}. + """ + if not ConfigInfo.is_valid_env_var(config_group, config_name): + return {"success": False, "error_code": "CONFIG_INVALID_NAME"} + + from sqlalchemy.exc import IntegrityError + + with session_make() as s: + db_config = Config( + user_id=user_id, + config_name=config_name, + config_value=config_value, + config_group=config_group, + ) + s.add(db_config) + try: + s.commit() + except IntegrityError: + s.rollback() + return {"success": False, "error_code": "CONFIG_DUPLICATE"} + s.refresh(db_config) + return {"success": True, "config": db_config} + + @staticmethod + def update(config_id: int, user_id: int, config_name: str, config_value: str, config_group: str | None = None) -> dict: + """Update config: ownership + validation + duplicate check.""" + if not ConfigInfo.is_valid_env_var(config_group, config_name): + return {"success": False, "error_code": "CONFIG_INVALID_NAME"} + + with session_make() as s: + db_config = s.exec( + select(Config).where(Config.id == config_id, Config.user_id == user_id) + ).first() + if not db_config: + return {"success": False, "error_code": "CONFIG_NOT_FOUND"} + + # Check for name conflict with other configs + conflict = s.exec( + select(Config).where( + Config.user_id == user_id, + Config.config_name == config_name, + Config.id != config_id, + ) + ).first() + if conflict: + return {"success": False, "error_code": "CONFIG_DUPLICATE"} + + db_config.config_name = config_name + db_config.config_value = config_value + s.add(db_config) + s.commit() + s.refresh(db_config) + return {"success": True, "config": db_config} + + @staticmethod + def delete(config_id: int, user_id: int) -> bool: + """Delete config: ownership check.""" + with session_make() as s: + db_config = s.exec( + select(Config).where(Config.id == config_id, Config.user_id == user_id) + ).first() + if not db_config: + return False + s.delete(db_config) + s.commit() + return True + + @staticmethod + def get_config_info(show_all: bool = False) -> dict: + """Return available config metadata.""" + configs = ConfigInfo.getinfo() + if show_all: + return configs + return {k: v for k, v in configs.items() if v.get("env_vars") and len(v["env_vars"]) > 0} diff --git a/server/app/domains/mcp/__init__.py b/server/app/domains/mcp/__init__.py new file mode 100644 index 00000000..cbea39ec --- /dev/null +++ b/server/app/domains/mcp/__init__.py @@ -0,0 +1,15 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""MCP domain: MCP components, categories, proxy, user mappings.""" diff --git a/server/app/domains/mcp/api/__init__.py b/server/app/domains/mcp/api/__init__.py new file mode 100644 index 00000000..3a4d90c0 --- /dev/null +++ b/server/app/domains/mcp/api/__init__.py @@ -0,0 +1,14 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + diff --git a/server/app/controller/mcp/category_controller.py b/server/app/domains/mcp/api/category_controller.py similarity index 89% rename from server/app/controller/mcp/category_controller.py rename to server/app/domains/mcp/api/category_controller.py index e120bc4e..8f315d7e 100644 --- a/server/app/controller/mcp/category_controller.py +++ b/server/app/domains/mcp/api/category_controller.py @@ -1,27 +1,27 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= from fastapi import APIRouter, Depends from sqlmodel import Session, col, select -from app.component.database import session +from app.core.database import session from app.model.mcp.category import Category, CategoryOut router = APIRouter(prefix="/mcp", tags=["Mcp Category"]) @router.get("/categories", name="category list", response_model=list[CategoryOut]) -def gets(session: Session = Depends(session)): +def gets(db_session: Session = Depends(session)): stmt = select(Category).where(Category.no_delete()).order_by(col(Category.priority).asc()) - return session.exec(stmt) + return db_session.exec(stmt) diff --git a/server/app/domains/mcp/api/mcp_controller.py b/server/app/domains/mcp/api/mcp_controller.py new file mode 100644 index 00000000..03b6d034 --- /dev/null +++ b/server/app/domains/mcp/api/mcp_controller.py @@ -0,0 +1,101 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""MCP controller. Uses McpUserService for install and import.""" + +from fastapi import Depends, HTTPException, APIRouter +from fastapi_babel import _ +from fastapi_pagination import Page +from fastapi_pagination.ext.sqlmodel import paginate +from sqlmodel import Session, col, select +from sqlalchemy.orm import selectinload, with_loader_criteria + +from app.core.database import session +from app.model.mcp.mcp import Mcp, McpOut +from app.model.mcp.mcp_env import McpEnv, Status as McpEnvStatus +from app.model.mcp.mcp_user import McpImportType, McpUser +from app.shared.auth import auth_must +from app.shared.auth.user_auth import V1UserAuth +from app.shared.middleware.rate_limit import install_rate_limiter +from app.domains.mcp.service.mcp_user_service import McpUserService + +router = APIRouter(tags=["Mcp Servers"]) + +_INSTALL_ERROR_MAP = { + "MCP_NOT_FOUND": (404, "Mcp not found"), + "MCP_ALREADY_INSTALLED": (400, "mcp is installed"), + "MCP_INVALID_INSTALL_COMMAND": (400, "Install command is not valid json"), +} + + +@router.get("/mcps", name="mcp list") +async def gets( + keyword: str | None = None, + category_id: int | None = None, + mine: int | None = None, + db_session: Session = Depends(session), + auth: V1UserAuth = Depends(auth_must), +) -> Page[McpOut]: + stmt = ( + select(Mcp) + .where(Mcp.no_delete()) + .options( + selectinload(Mcp.category), + selectinload(Mcp.envs), + with_loader_criteria(McpEnv, col(McpEnv.status) == McpEnvStatus.in_use), + ) + ) + if keyword: + stmt = stmt.where(col(Mcp.key).like(f"%{keyword.lower()}%")) + if category_id: + stmt = stmt.where(Mcp.category_id == category_id) + if mine and auth: + stmt = ( + stmt.join(McpUser) + .where(McpUser.user_id == auth.id) + .options( + selectinload(Mcp.mcp_user), + with_loader_criteria(McpUser, col(McpUser.user_id) == auth.id), + ) + ) + return paginate(db_session, stmt) + + +@router.get("/mcp", name="mcp detail", response_model=McpOut) +async def get(id: int, db_session: Session = Depends(session), auth: V1UserAuth = Depends(auth_must)): + stmt = select(Mcp).where(Mcp.no_delete(), Mcp.id == id).options(selectinload(Mcp.category), selectinload(Mcp.envs)) + model = db_session.exec(stmt).first() + if not model: + raise HTTPException(status_code=404, detail=_("Mcp not found")) + return model + + +@router.post("/mcp/install", name="mcp install", dependencies=[install_rate_limiter]) +async def install(mcp_id: int, db_session: Session = Depends(session), auth: V1UserAuth = Depends(auth_must)): + result = McpUserService.install(mcp_id, auth.id, db_session) + if not result["success"]: + status, msg = _INSTALL_ERROR_MAP.get(result["error_code"], (400, "Install failed")) + raise HTTPException(status_code=status, detail=_(msg)) + return result["mcp_user"] + + +@router.post("/mcp/import/{mcp_type}", name="mcp import", dependencies=[install_rate_limiter]) +async def import_mcp( + mcp_type: McpImportType, mcp_data: dict, auth: V1UserAuth = Depends(auth_must) +): + result = McpUserService.import_mcp(mcp_type, mcp_data, auth.id) + if not result["success"]: + detail = result.get("detail", "Import failed") + raise HTTPException(status_code=400, detail=detail) + return result diff --git a/server/app/controller/mcp/proxy_controller.py b/server/app/domains/mcp/api/proxy_controller.py similarity index 61% rename from server/app/controller/mcp/proxy_controller.py rename to server/app/domains/mcp/api/proxy_controller.py index b32dd01c..764a25bb 100644 --- a/server/app/controller/mcp/proxy_controller.py +++ b/server/app/domains/mcp/api/proxy_controller.py @@ -12,19 +12,14 @@ # limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -import logging -from typing import Any, cast -from urllib.parse import quote_plus - -import requests +from fastapi import APIRouter, Depends from exa_py import Exa -from fastapi import APIRouter, Depends, HTTPException - -from app.component.auth import key_must -from app.component.environment import env_not_empty +from loguru import logger +from app.shared.auth.user_auth import key_must +from app.core.environment import env_not_empty from app.model.mcp.proxy import ExaSearch - -logger = logging.getLogger("server_proxy_controller") +from typing import Any, cast +import requests from app.model.user.key import Key @@ -33,41 +28,25 @@ router = APIRouter(prefix="/proxy", tags=["Mcp Servers"]) @router.post("/exa") def exa_search(search: ExaSearch, key: Key = Depends(key_must)): - """Search using Exa API.""" EXA_API_KEY = env_not_empty("EXA_API_KEY") try: - # Validate input parameters - if search.num_results is not None and not 0 < search.num_results <= 100: - logger.warning("Invalid exa search parameter", extra={"param": "num_results", "value": search.num_results}) - raise ValueError("num_results must be between 1 and 100") - - if search.include_text is not None and len(search.include_text) > 0: - if len(search.include_text) > 1: - logger.warning( - "Invalid exa search parameter", extra={"param": "include_text", "reason": "more than 1 string"} - ) - raise ValueError("include_text can only contain 1 string") - if len(search.include_text[0].split()) > 5: - logger.warning( - "Invalid exa search parameter", extra={"param": "include_text", "reason": "exceeds 5 words"} - ) - raise ValueError("include_text string cannot be longer than 5 words") - - if search.exclude_text is not None and len(search.exclude_text) > 0: - if len(search.exclude_text) > 1: - logger.warning( - "Invalid exa search parameter", extra={"param": "exclude_text", "reason": "more than 1 string"} - ) - raise ValueError("exclude_text can only contain 1 string") - if len(search.exclude_text[0].split()) > 5: - logger.warning( - "Invalid exa search parameter", extra={"param": "exclude_text", "reason": "exceeds 5 words"} - ) - raise ValueError("exclude_text string cannot be longer than 5 words") - exa = Exa(EXA_API_KEY) - # Call Exa API with direct parameters + if search.num_results is not None and not 0 < search.num_results <= 100: + raise ValueError("num_results must be between 1 and 100") + + if search.include_text is not None: + if len(search.include_text) > 1: + raise ValueError("include_text can only contain 1 string") + if len(search.include_text[0].split()) > 5: + raise ValueError("include_text string cannot be longer than 5 words") + + if search.exclude_text is not None: + if len(search.exclude_text) > 1: + raise ValueError("exclude_text can only contain 1 string") + if len(search.exclude_text[0].split()) > 5: + raise ValueError("exclude_text string cannot be longer than 5 words") + if search.text: results = cast( dict[str, Any], @@ -96,41 +75,23 @@ def exa_search(search: ExaSearch, key: Key = Depends(key_must)): ), ) - result_count = len(results.get("results", [])) if "results" in results else 0 - logger.info( - "Exa search completed", - extra={"query": search.query, "search_type": search.search_type, "result_count": result_count}, - ) return results - except ValueError as e: - logger.warning("Exa search validation error", extra={"error": str(e)}) - raise HTTPException(status_code=500, detail="Internal server error") except Exception as e: - logger.error("Exa search failed", extra={"query": search.query, "error": str(e)}, exc_info=True) - raise HTTPException(status_code=500, detail="Internal server error") + return {"error": f"Exa search failed: {e!s}"} @router.get("/google") def google_search(query: str, search_type: str = "web", key: Key = Depends(key_must)): - """Search using Google Custom Search API.""" - # https://developers.google.com/custom-search/v1/overview GOOGLE_API_KEY = env_not_empty("GOOGLE_API_KEY") - # https://cse.google.com/cse/all SEARCH_ENGINE_ID = env_not_empty("SEARCH_ENGINE_ID") - # Using the first page start_page_idx = 1 - # Different language may get different result search_language = "en" - # How many pages to return num_result_pages = 10 - - # Constructing the URL - # Doc: https://developers.google.com/custom-search/v1/using_rest base_url = ( f"https://www.googleapis.com/customsearch/v1?" - f"key={GOOGLE_API_KEY}&cx={SEARCH_ENGINE_ID}&q={quote_plus(query)}&start=" + f"key={GOOGLE_API_KEY}&cx={SEARCH_ENGINE_ID}&q={query}&start=" f"{start_page_idx}&lr={search_language}&num={num_result_pages}" ) @@ -140,32 +101,21 @@ def google_search(query: str, search_type: str = "web", key: Key = Depends(key_m url = base_url responses = [] - try: - # Make the GET request result = requests.get(url) data = result.json() - # Get the result items if "items" in data: search_items = data.get("items") - - # Iterate over results found for i, search_item in enumerate(search_items, start=1): if search_type == "image": - # Process image search results title = search_item.get("title") image_url = search_item.get("link") display_link = search_item.get("displayLink") - - # Get context URL (page containing the image) image_info = search_item.get("image", {}) context_url = image_info.get("contextLink", "") - - # Get image dimensions if available width = image_info.get("width") height = image_info.get("height") - response = { "result_id": i, "title": title, @@ -173,17 +123,12 @@ def google_search(query: str, search_type: str = "web", key: Key = Depends(key_m "display_link": display_link, "context_url": context_url, } - - # Add dimensions if available if width: response["width"] = int(width) if height: response["height"] = int(height) - responses.append(response) else: - # Process web search results - # Check metatags are present if "pagemap" not in search_item: continue if "metatags" not in search_item["pagemap"]: @@ -192,12 +137,8 @@ def google_search(query: str, search_type: str = "web", key: Key = Depends(key_m long_description = search_item["pagemap"]["metatags"][0]["og:description"] else: long_description = "N/A" - # Get the page title title = search_item.get("title") - # Page snippet snippet = search_item.get("snippet") - - # Extract the page url link = search_item.get("link") response = { "result_id": i, @@ -207,20 +148,11 @@ def google_search(query: str, search_type: str = "web", key: Key = Depends(key_m "url": link, } responses.append(response) - - logger.info( - "Google search completed", - extra={"query": query, "search_type": search_type, "result_count": len(responses)}, - ) else: error_info = data.get("error", {}) - logger.error("Google search API error", extra={"query": query, "api_error": error_info}) - raise HTTPException(status_code=500, detail="Internal server error") + logger.error(f"Google search failed - API response: {error_info}") + responses.append({"error": f"Google search failed - API response: {error_info}"}) except Exception as e: - logger.error( - "Google search failed", extra={"query": query, "search_type": search_type, "error": str(e)}, exc_info=True - ) - raise HTTPException(status_code=500, detail="Internal server error") - + responses.append({"error": f"google search failed: {e!s}"}) return responses diff --git a/server/app/domains/mcp/api/user_controller.py b/server/app/domains/mcp/api/user_controller.py new file mode 100644 index 00000000..9b324a59 --- /dev/null +++ b/server/app/domains/mcp/api/user_controller.py @@ -0,0 +1,99 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""MCP User controller - H15 ownership on GET/DELETE.""" + +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Response +from fastapi_babel import _ +from sqlmodel import Session, select + +from app.core.database import session +from app.model.mcp.mcp_user import McpUser, McpUserIn, McpUserOut, McpUserUpdate + +from app.shared.auth import auth_must +from app.shared.auth.ownership import require_owner + +router = APIRouter(tags=["V1 McpUser Management"]) + + +@router.get("/mcp/users", name="list mcp users", response_model=List[McpUserOut]) +async def list_mcp_users( + mcp_id: Optional[int] = None, + db_session: Session = Depends(session), + auth=Depends(auth_must), +): + query = select(McpUser).where(McpUser.user_id == auth.user.id) + if mcp_id is not None: + query = query.where(McpUser.mcp_id == mcp_id) + return list(db_session.exec(query).all()) + + +@router.get("/mcp/users/{mcp_user_id}", name="get mcp user", response_model=McpUserOut) +async def get_mcp_user( + mcp_user_id: int, + db_session: Session = Depends(session), + auth=Depends(auth_must), +): + mcp_user = db_session.get(McpUser, mcp_user_id) + require_owner(mcp_user, auth.user.id) + return mcp_user + + +@router.post("/mcp/users", name="create mcp user", response_model=McpUserOut) +async def create_mcp_user( + mcp_user: McpUserIn, + db_session: Session = Depends(session), + auth=Depends(auth_must), +): + exists = db_session.exec( + select(McpUser).where(McpUser.mcp_id == mcp_user.mcp_id, McpUser.user_id == auth.user.id) + ).first() + if exists: + raise HTTPException(status_code=400, detail=_("mcp is installed")) + db_mcp_user = McpUser(mcp_id=mcp_user.mcp_id, user_id=auth.user.id, env=mcp_user.env) + db_session.add(db_mcp_user) + db_session.commit() + db_session.refresh(db_mcp_user) + return db_mcp_user + + +@router.put("/mcp/users/{id}", name="update mcp user") +async def update_mcp_user( + id: int, + update_item: McpUserUpdate, + db_session: Session = Depends(session), + auth=Depends(auth_must), +): + model = db_session.get(McpUser, id) + require_owner(model, auth.user.id) + update_data = update_item.model_dump(exclude_unset=True) + model.update_fields(update_data) + model.save(db_session) + db_session.refresh(model) + return model + + +@router.delete("/mcp/users/{mcp_user_id}", name="delete mcp user") +async def delete_mcp_user( + mcp_user_id: int, + db_session: Session = Depends(session), + auth=Depends(auth_must), +): + db_mcp_user = db_session.get(McpUser, mcp_user_id) + require_owner(db_mcp_user, auth.user.id) + db_session.delete(db_mcp_user) + db_session.commit() + return Response(status_code=204) diff --git a/server/app/domains/mcp/schema/__init__.py b/server/app/domains/mcp/schema/__init__.py new file mode 100644 index 00000000..3a4d90c0 --- /dev/null +++ b/server/app/domains/mcp/schema/__init__.py @@ -0,0 +1,14 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + diff --git a/server/app/domains/mcp/service/__init__.py b/server/app/domains/mcp/service/__init__.py new file mode 100644 index 00000000..1d3470ad --- /dev/null +++ b/server/app/domains/mcp/service/__init__.py @@ -0,0 +1,17 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +from .mcp_user_service import McpUserService + +__all__ = ["McpUserService"] diff --git a/server/app/domains/mcp/service/mcp_user_service.py b/server/app/domains/mcp/service/mcp_user_service.py new file mode 100644 index 00000000..e1195b0c --- /dev/null +++ b/server/app/domains/mcp/service/mcp_user_service.py @@ -0,0 +1,175 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""McpUserService: MCP install/uninstall/import with dedup.""" + +import json + +from sqlmodel import select +from loguru import logger + +from app.core.database import session_make +from app.model.mcp.mcp import Mcp, McpType +from app.model.mcp.mcp_user import McpImportType, McpUser, Status +from app.core.validator.McpServer import ( + McpRemoteServer, + validate_mcp_remote_servers, + validate_mcp_servers, +) + + +class McpUserService: + """MCP user installation management - static methods, self-managed session.""" + + @staticmethod + def list_for_user(user_id: int, mcp_id: int | None = None) -> list[McpUser]: + """List user's MCP installations, optionally filtered by mcp_id.""" + with session_make() as s: + query = select(McpUser).where(McpUser.user_id == user_id) + if mcp_id is not None: + query = query.where(McpUser.mcp_id == mcp_id) + return list(s.exec(query).all()) + + @staticmethod + def get(mcp_user_id: int, user_id: int) -> McpUser | None: + """Get a single MCP user installation, scoped to user.""" + with session_make() as s: + model = s.get(McpUser, mcp_user_id) + if model and model.user_id == user_id: + return model + return None + + @staticmethod + def install(mcp_id: int, user_id: int, s) -> dict: + """Install MCP from store: dedup check, parse install_command, create McpUser. + Returns {"success": True, "mcp_user": McpUser} or {"success": False, "error_code": str}. + """ + mcp = s.get(Mcp, mcp_id) + if not mcp: + return {"success": False, "error_code": "MCP_NOT_FOUND"} + + exists = s.exec( + select(McpUser).where(McpUser.mcp_id == mcp.id, McpUser.user_id == user_id) + ).first() + if exists: + return {"success": False, "error_code": "MCP_ALREADY_INSTALLED"} + + install_command_raw = mcp.install_command or "{}" + try: + install_command = ( + json.loads(install_command_raw) if isinstance(install_command_raw, str) else install_command_raw or {} + ) + except json.JSONDecodeError: + return {"success": False, "error_code": "MCP_INVALID_INSTALL_COMMAND"} + if not isinstance(install_command, dict) or "command" not in install_command: + return {"success": False, "error_code": "MCP_INVALID_INSTALL_COMMAND"} + + mcp_user = McpUser( + mcp_id=mcp.id, + user_id=user_id, + mcp_name=mcp.name, + mcp_key=mcp.key, + mcp_desc=mcp.description, + type=mcp.type, + status=Status.enable, + command=install_command["command"], + args=install_command.get("args"), + env=install_command.get("env"), + server_url=None, + ) + mcp_user.save() + return {"success": True, "mcp_user": mcp_user} + + @staticmethod + def import_mcp(mcp_type: McpImportType, mcp_data: dict, user_id: int) -> dict: + """Import MCP servers from user config. Returns result dict with imported/failed lists.""" + if mcp_type == McpImportType.Local: + is_valid, res = validate_mcp_servers(mcp_data) + if not is_valid: + return {"success": False, "error_code": "MCP_INVALID_DATA", "detail": res} + + mcp_data_parsed = getattr(res, "mcpServers", mcp_data) + imported_names = [] + failed_names = [] + for name, data in mcp_data_parsed.items(): + command = data.command if hasattr(data, "command") else data.get("command") + args = data.args if hasattr(data, "args") else data.get("args") + env = data.env if hasattr(data, "env") else data.get("env") + try: + mcp_user = McpUser( + mcp_id=0, + user_id=user_id, + mcp_name=name, + mcp_key=name, + mcp_desc=name, + type=McpType.Local, + status=Status.enable, + command=command, + args=args, + env=env, + server_url=None, + ) + mcp_user.save() + imported_names.append(name) + except Exception as e: + logger.error("Failed to import local MCP", extra={"mcp_name": name, "error": str(e)}) + failed_names.append(name) + return { + "success": True, + "message": "Local MCP servers imported successfully", + "count": len(imported_names), + "imported": imported_names, + "failed": failed_names, + } + + elif mcp_type == McpImportType.Remote: + is_valid, res = validate_mcp_remote_servers(mcp_data) + if not is_valid: + return {"success": False, "error_code": "MCP_INVALID_DATA", "detail": res} + data: McpRemoteServer = res + mcp_user = McpUser( + mcp_id=0, + user_id=user_id, + type=McpType.Remote, + status=Status.enable, + mcp_name=data.server_name, + server_url=data.server_url, + ) + mcp_user.save() + return {"success": True, "mcp_user": mcp_user} + + return {"success": False, "error_code": "MCP_INVALID_TYPE"} + + @staticmethod + def update(mcp_user_id: int, user_id: int, data: dict) -> dict: + """Update MCP user: ownership check.""" + with session_make() as s: + model = s.get(McpUser, mcp_user_id) + if not model or model.user_id != user_id: + return {"success": False, "error_code": "MCP_USER_NOT_FOUND"} + model.update_fields(data) + model.save(s) + s.refresh(model) + return {"success": True, "mcp_user": model} + + @staticmethod + def uninstall(mcp_user_id: int, user_id: int) -> bool: + """Uninstall MCP: ownership check + delete.""" + with session_make() as s: + model = s.get(McpUser, mcp_user_id) + if not model or model.user_id != user_id: + return False + s.delete(model) + s.commit() + return True diff --git a/server/app/domains/model_provider/__init__.py b/server/app/domains/model_provider/__init__.py new file mode 100644 index 00000000..b46f4783 --- /dev/null +++ b/server/app/domains/model_provider/__init__.py @@ -0,0 +1,15 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""Model Provider domain - API key / model supplier management.""" diff --git a/server/app/domains/model_provider/api/__init__.py b/server/app/domains/model_provider/api/__init__.py new file mode 100644 index 00000000..3a4d90c0 --- /dev/null +++ b/server/app/domains/model_provider/api/__init__.py @@ -0,0 +1,14 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + diff --git a/server/app/domains/model_provider/api/provider_controller.py b/server/app/domains/model_provider/api/provider_controller.py new file mode 100644 index 00000000..c28664db --- /dev/null +++ b/server/app/domains/model_provider/api/provider_controller.py @@ -0,0 +1,91 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +# STATUS: full-rewrite (uses ProviderService, self-managed session) +from typing import Optional +from fastapi import APIRouter, Depends, HTTPException, Query, Response +from fastapi_babel import _ +from fastapi_pagination import Page +from fastapi_pagination.ext.sqlmodel import paginate +from sqlmodel import Session, select, col + +from app.core.database import session +from app.model.provider.provider import Provider, ProviderIn, ProviderOut, ProviderPreferIn +from app.shared.auth import auth_must +from app.shared.auth.user_auth import V1UserAuth +from app.domains.model_provider.service.provider_service import ProviderService + +router = APIRouter(tags=["Provider Management"]) + + +@router.get("/providers", name="list providers", response_model=Page[ProviderOut]) +async def gets( + keyword: str | None = None, + prefer: Optional[bool] = Query(None, description="Filter by prefer status"), + db_session: Session = Depends(session), + auth: V1UserAuth = Depends(auth_must), +) -> Page[ProviderOut]: + # Pagination still needs session for paginate() — use session directly for this read-only query + stmt = select(Provider).where(Provider.user_id == auth.id, Provider.no_delete()) + if keyword: + stmt = stmt.where(col(Provider.provider_name).like(f"%{keyword}%")) + if prefer is not None: + stmt = stmt.where(Provider.prefer == prefer) + stmt = stmt.order_by(col(Provider.created_at).desc(), col(Provider.id).desc()) + return paginate(db_session, stmt) + + +@router.get("/provider", name="get provider detail", response_model=ProviderOut) +async def get(id: int, auth: V1UserAuth = Depends(auth_must)): + model = ProviderService.get(id, auth.id) + if not model: + raise HTTPException(status_code=404, detail=_("Provider not found")) + return model + + +@router.post("/provider", name="create provider", response_model=ProviderOut) +async def post(data: ProviderIn, auth: V1UserAuth = Depends(auth_must)): + result = ProviderService.create(auth.id, data.model_dump()) + return result["provider"] + + +@router.put("/provider/{id}", name="update provider", response_model=ProviderOut) +async def put(id: int, data: ProviderIn, auth: V1UserAuth = Depends(auth_must)): + result = ProviderService.update(id, auth.id, data.model_dump()) + if not result["success"]: + raise HTTPException(status_code=404, detail=_("Provider not found")) + return result["provider"] + + +@router.delete("/provider/{id}", name="delete provider") +async def delete(id: int, auth: V1UserAuth = Depends(auth_must)): + if not ProviderService.delete(id, auth.id): + raise HTTPException(status_code=404, detail=_("Provider not found")) + return Response(status_code=204) + + +@router.patch("/provider/{id}/invalidate", name="invalidate provider") +async def invalidate(id: int, auth: V1UserAuth = Depends(auth_must)): + result = ProviderService.invalidate(id, auth.id) + if not result["success"]: + raise HTTPException(status_code=404, detail=_("Provider not found")) + return {"success": True} + + +@router.post("/provider/prefer", name="set provider prefer") +async def set_prefer(data: ProviderPreferIn, auth: V1UserAuth = Depends(auth_must)): + result = ProviderService.set_prefer(data.provider_id, auth.id) + if not result["success"]: + raise HTTPException(status_code=500, detail="Failed to set prefer") + return {"success": True} diff --git a/server/app/domains/model_provider/schema/__init__.py b/server/app/domains/model_provider/schema/__init__.py new file mode 100644 index 00000000..3a4d90c0 --- /dev/null +++ b/server/app/domains/model_provider/schema/__init__.py @@ -0,0 +1,14 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + diff --git a/server/app/domains/model_provider/service/__init__.py b/server/app/domains/model_provider/service/__init__.py new file mode 100644 index 00000000..2bb1a24c --- /dev/null +++ b/server/app/domains/model_provider/service/__init__.py @@ -0,0 +1,17 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +from .provider_service import ProviderService + +__all__ = ["ProviderService"] diff --git a/server/app/domains/model_provider/service/provider_service.py b/server/app/domains/model_provider/service/provider_service.py new file mode 100644 index 00000000..599d3f72 --- /dev/null +++ b/server/app/domains/model_provider/service/provider_service.py @@ -0,0 +1,121 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""ProviderService: model provider CRUD with prefer/invalidate. Follows CreditsService pattern.""" + +from sqlalchemy import update +from sqlmodel import select, col +from loguru import logger + +from app.core.database import session_make +from app.model.provider.provider import Provider, VaildStatus + + +class ProviderService: + """Model provider management - static methods, self-managed session.""" + + @staticmethod + def list_for_user(user_id: int, keyword: str | None = None, prefer: bool | None = None) -> list[Provider]: + """List user providers, supports keyword search and prefer filter.""" + with session_make() as s: + stmt = select(Provider).where(Provider.user_id == user_id, Provider.no_delete()) + if keyword: + safe_keyword = keyword.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + stmt = stmt.where(col(Provider.provider_name).like(f"%{safe_keyword}%")) + if prefer is not None: + stmt = stmt.where(Provider.prefer == prefer) + stmt = stmt.order_by(col(Provider.created_at).desc(), col(Provider.id).desc()) + return list(s.exec(stmt).all()) + + @staticmethod + def get(provider_id: int, user_id: int) -> Provider | None: + """Get a single provider by id, scoped to user.""" + with session_make() as s: + model = s.exec( + select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == provider_id) + ).first() + return model + + @staticmethod + def create(user_id: int, data: dict) -> dict: + """Create provider. Returns {"success": True, "provider": Provider}.""" + with session_make() as s: + model = Provider(**data, user_id=user_id) + model.save(s) + s.refresh(model) + return {"success": True, "provider": model} + + @staticmethod + def update(provider_id: int, user_id: int, data: dict) -> dict: + """Update provider: ownership check.""" + with session_make() as s: + model = s.exec( + select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == provider_id) + ).first() + if not model: + return {"success": False, "error_code": "PROVIDER_NOT_FOUND"} + # H10: only allow updating safe fields + _UPDATABLE_FIELDS = {"provider_name", "api_key", "api_base", "extra_config", "prefer", "is_vaild"} + for key, value in data.items(): + if key in _UPDATABLE_FIELDS: + setattr(model, key, value) + model.save(s) + s.refresh(model) + return {"success": True, "provider": model} + + @staticmethod + def delete(provider_id: int, user_id: int) -> bool: + """Soft-delete provider.""" + with session_make() as s: + model = s.exec( + select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == provider_id) + ).first() + if not model: + return False + model.delete(s) + return True + + @staticmethod + def invalidate(provider_id: int, user_id: int) -> dict: + """Mark provider as not_valid.""" + with session_make() as s: + model = s.exec( + select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == provider_id) + ).first() + if not model: + return {"success": False, "error_code": "PROVIDER_NOT_FOUND"} + model.is_vaild = VaildStatus.not_valid + model.save(s) + s.refresh(model) + logger.info("Provider invalidated", extra={"user_id": user_id, "provider_id": provider_id}) + return {"success": True, "provider": model} + + @staticmethod + def set_prefer(provider_id: int, user_id: int) -> dict: + """Set preferred provider: lock user rows, clear all prefer flags, then set the specified one.""" + with session_make() as s: + # Lock all provider rows for this user to prevent concurrent prefer changes + s.exec( + select(Provider).where(Provider.user_id == user_id, Provider.no_delete()).with_for_update() + ).all() + # Clear all prefer flags for this user + s.exec(update(Provider).where(Provider.user_id == user_id, Provider.no_delete()).values(prefer=False)) + # Set the specified provider as preferred + s.exec( + update(Provider) + .where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == provider_id) + .values(prefer=True) + ) + s.commit() + return {"success": True} diff --git a/server/app/domains/oauth/__init__.py b/server/app/domains/oauth/__init__.py new file mode 100644 index 00000000..3a4d90c0 --- /dev/null +++ b/server/app/domains/oauth/__init__.py @@ -0,0 +1,14 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + diff --git a/server/app/domains/oauth/api/__init__.py b/server/app/domains/oauth/api/__init__.py new file mode 100644 index 00000000..3a4d90c0 --- /dev/null +++ b/server/app/domains/oauth/api/__init__.py @@ -0,0 +1,14 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + diff --git a/server/app/domains/oauth/api/oauth_controller.py b/server/app/domains/oauth/api/oauth_controller.py new file mode 100644 index 00000000..6803519f --- /dev/null +++ b/server/app/domains/oauth/api/oauth_controller.py @@ -0,0 +1,97 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +# STATUS: full-rewrite (uses OAuthService, H12 code/state XSS validation) +import json +from typing import Optional +from urllib.parse import quote + +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse + +from app.core.oauth_adapter import OauthCallbackPayload +from app.domains.oauth.schema import OAuthAuthorizeReq, OAuthTokenReq +from app.domains.oauth.service.oauth_service import OAuthService + +router = APIRouter(prefix="/oauth", tags=["Oauth Servers"]) + +_OAUTH_CODE_MAX_LEN = 2048 +_OAUTH_STATE_MAX_LEN = 512 +_OAUTH_SAFE_CHARS = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.") + + +def _validate_oauth_param(value: Optional[str], name: str, max_len: int) -> str: + """Validate and sanitize OAuth callback params to prevent XSS injection (H12).""" + if value is None: + return "" + s = str(value).strip() + if len(s) > max_len: + raise HTTPException(status_code=400, detail=f"Invalid {name}: too long") + if s and not all(c in _OAUTH_SAFE_CHARS or c in " /+=" for c in s): + raise HTTPException(status_code=400, detail=f"Invalid {name}: invalid characters") + return s + + +@router.get("/{app}/login", name="OAuth Login Redirect") +def oauth_login(app: str, request: Request, state: Optional[str] = None): + callback_url = str(request.url_for("OAuth Callback", app=app)) + if callback_url.startswith("http://"): + callback_url = "https://" + callback_url[len("http://"):] + + result = OAuthService.get_authorize_url( + OAuthAuthorizeReq(provider=app, redirect_uri=callback_url, state=state) + ) + if not result.success: + raise HTTPException(status_code=400, detail="Failed to generate authorization URL") + return RedirectResponse(str(result.authorize_url)) + + +@router.get("/{app}/callback", name="OAuth Callback") +def oauth_callback(app: str, request: Request, code: Optional[str] = None, state: Optional[str] = None): + if not code: + raise HTTPException(status_code=400, detail="Missing code parameter") + safe_code = _validate_oauth_param(code, "code", _OAUTH_CODE_MAX_LEN) + safe_state = _validate_oauth_param(state, "state", _OAUTH_STATE_MAX_LEN) + safe_app = _validate_oauth_param(app, "provider", 64) or "" + query = f"provider={quote(safe_app, safe='')}&code={quote(safe_code, safe='')}&state={quote(safe_state, safe='')}" + redirect_url = f"eigent://callback/oauth?{query}" + html_content = f""" + + + OAuth Callback + + + +

Redirecting, please wait...

+ + + + """ + return HTMLResponse(content=html_content) + + +@router.post("/{app}/token", name="OAuth Fetch Token") +def fetch_token(app: str, request: Request, data: OauthCallbackPayload): + callback_url = str(request.url_for("OAuth Callback", app=app)) + if callback_url.startswith("http://"): + callback_url = "https://" + callback_url[len("http://"):] + + result = OAuthService.exchange_token( + OAuthTokenReq(provider=app, code=data.code, redirect_uri=callback_url) + ) + if not result.success: + raise HTTPException(status_code=500, detail="Token exchange failed") + return JSONResponse(result.token_data) diff --git a/server/app/domains/oauth/schema/__init__.py b/server/app/domains/oauth/schema/__init__.py new file mode 100644 index 00000000..66d7bacf --- /dev/null +++ b/server/app/domains/oauth/schema/__init__.py @@ -0,0 +1,27 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""OAuth domain models - re-exports v1 schemas.""" + +from app.domains.oauth.schema.schemas import ( + OAuthAuthorizeReq, + OAuthTokenReq, + OAuthResult, +) + +__all__ = [ + "OAuthAuthorizeReq", + "OAuthTokenReq", + "OAuthResult", +] diff --git a/server/app/domains/oauth/schema/schemas.py b/server/app/domains/oauth/schema/schemas.py new file mode 100644 index 00000000..2163b237 --- /dev/null +++ b/server/app/domains/oauth/schema/schemas.py @@ -0,0 +1,36 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""v1 OAuth request/response schemas.""" + +from pydantic import BaseModel + + +class OAuthAuthorizeReq(BaseModel): + provider: str + redirect_uri: str | None = None + state: str | None = None + + +class OAuthTokenReq(BaseModel): + provider: str + code: str + redirect_uri: str | None = None + + +class OAuthResult(BaseModel): + success: bool + authorize_url: str | None = None + token_data: dict | None = None + error_code: str | None = None diff --git a/server/app/domains/oauth/service/__init__.py b/server/app/domains/oauth/service/__init__.py new file mode 100644 index 00000000..2dd79fc4 --- /dev/null +++ b/server/app/domains/oauth/service/__init__.py @@ -0,0 +1,17 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +from .oauth_service import OAuthService + +__all__ = ["OAuthService"] diff --git a/server/app/domains/oauth/service/oauth_service.py b/server/app/domains/oauth/service/oauth_service.py new file mode 100644 index 00000000..fad35658 --- /dev/null +++ b/server/app/domains/oauth/service/oauth_service.py @@ -0,0 +1,55 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""OAuthService: adapter factory wrapper for unified OAuth flow. Follows CreditsService pattern.""" + +from loguru import logger + +from app.core.oauth_adapter import get_oauth_adapter, OAUTH_ADAPTERS +from app.domains.oauth.schema import OAuthAuthorizeReq, OAuthTokenReq, OAuthResult + + +class OAuthService: + """OAuth operations - static methods, wraps adapter factory.""" + + @staticmethod + def get_authorize_url(req: OAuthAuthorizeReq) -> OAuthResult: + """Get OAuth authorization URL via adapter factory.""" + try: + adapter = get_oauth_adapter(req.provider, req.redirect_uri) + url = adapter.get_authorize_url(req.state) + if not url: + return OAuthResult(success=False, error_code="OAUTH_URL_FAILED") + return OAuthResult(success=True, authorize_url=url) + except Exception as e: + logger.error(f"OAuth authorize failed provider={req.provider}: {e}") + return OAuthResult(success=False, error_code="OAUTH_PROVIDER_ERROR") + + @staticmethod + def exchange_token(req: OAuthTokenReq) -> OAuthResult: + """Exchange authorization code for token via adapter factory.""" + try: + adapter = get_oauth_adapter(req.provider, req.redirect_uri) + token_data = adapter.fetch_token(req.code) + if not token_data: + return OAuthResult(success=False, error_code="OAUTH_TOKEN_FAILED") + return OAuthResult(success=True, token_data=token_data) + except Exception as e: + logger.error(f"OAuth token exchange failed provider={req.provider}: {e}") + return OAuthResult(success=False, error_code="OAUTH_PROVIDER_ERROR") + + @staticmethod + def list_providers() -> list[str]: + """Return list of supported OAuth providers.""" + return list(set(OAUTH_ADAPTERS.keys())) diff --git a/server/app/domains/trigger/__init__.py b/server/app/domains/trigger/__init__.py new file mode 100644 index 00000000..b6366081 --- /dev/null +++ b/server/app/domains/trigger/__init__.py @@ -0,0 +1,15 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""Trigger domain: triggers, executions, webhooks, schedules, slack.""" diff --git a/server/app/domains/trigger/api/__init__.py b/server/app/domains/trigger/api/__init__.py new file mode 100644 index 00000000..3a4d90c0 --- /dev/null +++ b/server/app/domains/trigger/api/__init__.py @@ -0,0 +1,14 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + diff --git a/server/app/domains/trigger/api/slack_controller.py b/server/app/domains/trigger/api/slack_controller.py new file mode 100644 index 00000000..20019c4e --- /dev/null +++ b/server/app/domains/trigger/api/slack_controller.py @@ -0,0 +1,57 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""Slack integration controller. Uses SlackService.""" + +from fastapi import APIRouter, Depends, HTTPException +from typing import Optional, List +from pydantic import BaseModel + +from app.shared.auth import auth_must +from app.shared.auth.user_auth import V1UserAuth +from app.domains.trigger.service.slack_service import SlackService + + +class SlackChannelOut(BaseModel): + """Output model for Slack channels.""" + id: str + name: str + is_private: bool = False + is_member: bool = False + num_members: Optional[int] = None + + +class SlackChannelsResponse(BaseModel): + """Response model for Slack channels list.""" + channels: List[SlackChannelOut] + has_credentials: bool + + +router = APIRouter(prefix="/trigger/slack", tags=["Slack Integration"]) + + +@router.get("/channels", name="get slack channels") +def get_slack_channels(auth: V1UserAuth = Depends(auth_must)) -> SlackChannelsResponse: + result = SlackService.get_channels(auth.id) + if not result["success"]: + error_map = { + "SLACK_SDK_NOT_INSTALLED": (500, "Slack SDK not installed on server"), + "SLACK_API_ERROR": (400, f"Slack API error: {result.get('detail', 'Unknown error')}"), + } + status, msg = error_map.get(result["error_code"], (500, "Failed to fetch Slack channels")) + raise HTTPException(status_code=status, detail=msg) + return SlackChannelsResponse( + channels=[SlackChannelOut(**ch) for ch in result["channels"]], + has_credentials=result["has_credentials"], + ) diff --git a/server/app/domains/trigger/api/trigger_controller.py b/server/app/domains/trigger/api/trigger_controller.py new file mode 100644 index 00000000..5249bb23 --- /dev/null +++ b/server/app/domains/trigger/api/trigger_controller.py @@ -0,0 +1,225 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""Trigger controller. Uses TriggerCrudService for business logic.""" + +from fastapi import APIRouter, Depends, HTTPException, Response, Query +from fastapi_pagination import Page +from fastapi_pagination.ext.sqlmodel import paginate +from sqlmodel import Session, select, desc, and_, delete +from typing import Optional +from loguru import logger + +from app.model.trigger.trigger import Trigger, TriggerIn, TriggerOut, TriggerUpdate, TriggerConfigSchemaOut +from app.model.trigger.trigger_execution import TriggerExecution, TriggerExecutionOut +from app.model.trigger.app_configs import get_config_schema, has_config +from app.shared.types.trigger_types import TriggerType, TriggerStatus +from app.shared.auth import auth_must +from app.shared.auth.user_auth import V1UserAuth +from app.core.database import session +from app.domains.trigger.service.trigger_crud_service import TriggerCrudService + +router = APIRouter(prefix="/trigger", tags=["Triggers"]) + + +def _raise_on_error(result: dict) -> None: + """Convert service error dict to HTTPException.""" + if result["success"]: + return + status_code = result.get("status_code", 500) + error = result.get("error", "Internal server error") + raise HTTPException(status_code=status_code, detail=error) + + +@router.post("/", name="create trigger", response_model=TriggerOut) +def create_trigger( + data: TriggerIn, + db_session: Session = Depends(session), + auth: V1UserAuth = Depends(auth_must), +): + """Create a new trigger.""" + try: + result = TriggerCrudService.create(data, auth.id, db_session) + _raise_on_error(result) + return result["trigger_out"] + except HTTPException: + raise + except Exception as e: + db_session.rollback() + logger.error("Trigger creation failed", extra={"user_id": auth.id, "error": str(e)}, exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error") + + +@router.get("/", name="list triggers") +def list_triggers( + trigger_type: Optional[TriggerType] = Query(None, description="Filter by trigger type"), + status: Optional[TriggerStatus] = Query(None, description="Filter by status"), + project_id: Optional[str] = Query(None, description="Filter by project ID"), + db_session: Session = Depends(session), + auth: V1UserAuth = Depends(auth_must), +) -> Page[TriggerOut]: + """List triggers for current user.""" + user_id = auth.id + conditions = [Trigger.user_id == str(user_id)] + if trigger_type: + conditions.append(Trigger.trigger_type == trigger_type) + if status is not None: + conditions.append(Trigger.status == status) + if project_id: + conditions.append(Trigger.project_id == project_id) + + stmt = select(Trigger).where(and_(*conditions)).order_by(desc(Trigger.created_at)) + result = paginate(db_session, stmt) + + # Enrich with execution counts + trigger_ids = [t.id for t in result.items] + counts = TriggerCrudService.get_execution_counts(db_session, trigger_ids) + result.items = [TriggerCrudService.trigger_to_out(t, counts.get(t.id, 0)) for t in result.items] + + return result + + +@router.get("/{trigger_id}", name="get trigger", response_model=TriggerOut) +def get_trigger( + trigger_id: int, + db_session: Session = Depends(session), + auth: V1UserAuth = Depends(auth_must), +): + """Get a specific trigger by ID.""" + trigger = db_session.exec( + select(Trigger).where(and_(Trigger.id == trigger_id, Trigger.user_id == str(auth.id))) + ).first() + if not trigger: + raise HTTPException(status_code=404, detail="Trigger not found") + + counts = TriggerCrudService.get_execution_counts(db_session, [trigger_id]) + return TriggerCrudService.trigger_to_out(trigger, counts.get(trigger_id, 0)) + + +@router.put("/{trigger_id}", name="update trigger", response_model=TriggerOut) +def update_trigger( + trigger_id: int, + data: TriggerUpdate, + db_session: Session = Depends(session), + auth: V1UserAuth = Depends(auth_must), +): + """Update a trigger.""" + try: + result = TriggerCrudService.update(trigger_id, data, auth.id, db_session) + _raise_on_error(result) + return result["trigger_out"] + except HTTPException: + raise + except Exception as e: + db_session.rollback() + logger.error("Trigger update failed", extra={"user_id": auth.id, "trigger_id": trigger_id, "error": str(e)}, exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error") + + +@router.delete("/{trigger_id}", name="delete trigger") +def delete_trigger( + trigger_id: int, + db_session: Session = Depends(session), + auth: V1UserAuth = Depends(auth_must), +): + """Delete a trigger and its executions.""" + trigger = db_session.exec( + select(Trigger).where(and_(Trigger.id == trigger_id, Trigger.user_id == str(auth.id))) + ).first() + if not trigger: + raise HTTPException(status_code=404, detail="Trigger not found") + + try: + db_session.exec(delete(TriggerExecution).where(TriggerExecution.trigger_id == trigger_id)) + db_session.delete(trigger) + db_session.commit() + logger.info("Trigger deleted", extra={"user_id": auth.id, "trigger_id": trigger_id}) + return Response(status_code=204) + except Exception as e: + db_session.rollback() + logger.error("Trigger deletion failed", extra={"user_id": auth.id, "trigger_id": trigger_id, "error": str(e)}, exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error") + + +@router.post("/{trigger_id}/activate", name="activate trigger", response_model=TriggerOut) +def activate_trigger( + trigger_id: int, + db_session: Session = Depends(session), + auth: V1UserAuth = Depends(auth_must), +): + """Activate a trigger.""" + try: + result = TriggerCrudService.activate(trigger_id, auth.id, db_session) + _raise_on_error(result) + return result["trigger_out"] + except HTTPException: + raise + except Exception as e: + db_session.rollback() + logger.error("Trigger activation failed", extra={"user_id": auth.id, "trigger_id": trigger_id, "error": str(e)}, exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error") + + +@router.post("/{trigger_id}/deactivate", name="deactivate trigger", response_model=TriggerOut) +def deactivate_trigger( + trigger_id: int, + db_session: Session = Depends(session), + auth: V1UserAuth = Depends(auth_must), +): + """Deactivate a trigger.""" + try: + result = TriggerCrudService.deactivate(trigger_id, auth.id, db_session) + _raise_on_error(result) + return result["trigger_out"] + except HTTPException: + raise + except Exception as e: + db_session.rollback() + logger.error("Trigger deactivation failed", extra={"user_id": auth.id, "trigger_id": trigger_id, "error": str(e)}, exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error") + + +@router.get("/{trigger_id}/executions", name="list trigger executions") +def list_trigger_executions( + trigger_id: int, + db_session: Session = Depends(session), + auth: V1UserAuth = Depends(auth_must), +) -> Page[TriggerExecutionOut]: + """List executions for a specific trigger.""" + trigger = db_session.exec( + select(Trigger).where(and_(Trigger.id == trigger_id, Trigger.user_id == str(auth.id))) + ).first() + if not trigger: + raise HTTPException(status_code=404, detail="Trigger not found") + + stmt = ( + select(TriggerExecution) + .where(TriggerExecution.trigger_id == trigger_id) + .order_by(desc(TriggerExecution.created_at)) + ) + return paginate(db_session, stmt) + + +@router.get("/{trigger_type}/config", name="get trigger type config schema") +def get_trigger_type_config( + trigger_type: TriggerType, + auth: V1UserAuth = Depends(auth_must), +) -> TriggerConfigSchemaOut: + """Get the configuration schema for a specific trigger type.""" + schema = get_config_schema(trigger_type) + return TriggerConfigSchemaOut( + trigger_type=trigger_type.value, + has_config=has_config(trigger_type), + schema_=schema, + ) diff --git a/server/app/controller/trigger/trigger_execution_controller.py b/server/app/domains/trigger/api/trigger_execution_controller.py similarity index 55% rename from server/app/controller/trigger/trigger_execution_controller.py rename to server/app/domains/trigger/api/trigger_execution_controller.py index 1c838549..1802aa77 100644 --- a/server/app/controller/trigger/trigger_execution_controller.py +++ b/server/app/domains/trigger/api/trigger_execution_controller.py @@ -12,111 +12,62 @@ # limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +"""Trigger Execution controller. Uses TriggerCrudService for REST, WebSocket handled locally.""" + from fastapi import APIRouter, Depends, HTTPException, Response, WebSocket, WebSocketDisconnect from fastapi_pagination import Page from fastapi_pagination.ext.sqlmodel import paginate from sqlmodel import Session, select, desc, and_ from typing import Optional, Dict, Any from datetime import datetime, timezone -from uuid import uuid4 -import logging +from loguru import logger import asyncio from app.model.trigger.trigger_execution import ( - TriggerExecution, - TriggerExecutionIn, - TriggerExecutionOut, - TriggerExecutionUpdate + TriggerExecution, + TriggerExecutionIn, + TriggerExecutionOut, + TriggerExecutionUpdate, ) from app.model.trigger.trigger import Trigger from app.model.user.user import User -from app.type.trigger_types import ExecutionStatus, ExecutionType -from app.component.auth import Auth, auth_must -from app.component.database import session -from app.component.redis_utils import get_redis_manager -from app.service.trigger.trigger_service import TriggerService - -logger = logging.getLogger("server_trigger_execution_controller") +from app.shared.types.trigger_types import ExecutionStatus, ExecutionType +from app.shared.auth import auth_must +from app.shared.auth.user_auth import V1UserAuth +from app.core.database import session +from app.core.redis_utils import get_redis_manager +from app.domains.trigger.service.trigger_crud_service import TriggerCrudService # Store active WebSocket connections per session (WebSocket objects only, metadata in Redis) -# Format: {session_id: WebSocket} -# This is per-worker, and Redis pub/sub is used to broadcast across workers active_websockets: Dict[str, WebSocket] = {} - -# Background task for Redis pub/sub _pubsub_task = None router = APIRouter(prefix="/execution", tags=["Trigger Executions"]) +def _raise_on_error(result: dict) -> None: + """Convert service error dict to HTTPException.""" + if result["success"]: + return + raise HTTPException(status_code=result.get("status_code", 500), detail=result.get("error", "Internal server error")) + + @router.post("/", name="create trigger execution", response_model=TriggerExecutionOut) async def create_trigger_execution( data: TriggerExecutionIn, - session: Session = Depends(session), - auth: Auth = Depends(auth_must) + db_session: Session = Depends(session), + auth: V1UserAuth = Depends(auth_must), ): """Create a new trigger execution.""" - user_id = auth.user.id - - # Verify the trigger exists and belongs to the user - trigger = session.exec( - select(Trigger).where( - and_(Trigger.id == data.trigger_id, Trigger.user_id == str(user_id)) - ) - ).first() - - if not trigger: - logger.warning("Trigger not found for execution creation", extra={ - "user_id": user_id, - "trigger_id": data.trigger_id - }) - raise HTTPException(status_code=404, detail="Trigger not found") - try: - execution_data = data.model_dump() - execution = TriggerExecution(**execution_data) - - session.add(execution) - session.commit() - session.refresh(execution) - - # Update trigger last executed timestamp - trigger.last_executed_at = datetime.now(timezone.utc) - session.add(trigger) - session.commit() - - logger.info("Trigger execution created", extra={ - "user_id": user_id, - "trigger_id": data.trigger_id, - "execution_id": execution.execution_id, - "execution_type": data.execution_type.value - }) - - # Publish to Redis pub/sub (broadcasts to all workers) - redis_manager = get_redis_manager() - redis_manager.publish_execution_event({ - "type": "execution_created", - "execution_id": execution.execution_id, - "trigger_id": trigger.id, - "trigger_type": trigger.trigger_type.value if trigger.trigger_type else "unknown", - "task_prompt": trigger.task_prompt, - "status": execution.status.value, - "input_data": execution.input_data, - "execution_type": data.execution_type.value, - "user_id": str(user_id), - "timestamp": datetime.now(timezone.utc).isoformat(), - "project_id": str(trigger.project_id) - }) - - return execution - + result = TriggerCrudService.create_execution(data, auth.id, db_session) + _raise_on_error(result) + return result["execution"] + except HTTPException: + raise except Exception as e: - session.rollback() - logger.error("Trigger execution creation failed", extra={ - "user_id": user_id, - "trigger_id": data.trigger_id, - "error": str(e) - }, exc_info=True) + db_session.rollback() + logger.error("Trigger execution creation failed", extra={"user_id": auth.id, "trigger_id": data.trigger_id, "error": str(e)}, exc_info=True) raise HTTPException(status_code=500, detail="Internal server error") @@ -125,90 +76,43 @@ def list_executions( trigger_id: Optional[int] = None, status: Optional[ExecutionStatus] = None, execution_type: Optional[ExecutionType] = None, - session: Session = Depends(session), - auth: Auth = Depends(auth_must) + db_session: Session = Depends(session), + auth: V1UserAuth = Depends(auth_must), ) -> Page[TriggerExecutionOut]: """List trigger executions for current user.""" - user_id = auth.user.id - - # Get all trigger IDs that belong to the user - user_trigger_ids = session.exec( - select(Trigger.id).where(Trigger.user_id == str(user_id)) - ).all() - + user_id = auth.id + user_trigger_ids = db_session.exec(select(Trigger.id).where(Trigger.user_id == str(user_id))).all() if not user_trigger_ids: - # User has no triggers, return empty result return Page(items=[], total=0, page=1, size=50, pages=0) - - # Build conditions + conditions = [TriggerExecution.trigger_id.in_(user_trigger_ids)] - if trigger_id: if trigger_id not in user_trigger_ids: raise HTTPException(status_code=404, detail="Trigger not found") conditions.append(TriggerExecution.trigger_id == trigger_id) - if status is not None: conditions.append(TriggerExecution.status == status) - if execution_type: conditions.append(TriggerExecution.execution_type == execution_type) - - stmt = ( - select(TriggerExecution) - .where(and_(*conditions)) - .order_by(desc(TriggerExecution.created_at)) - ) - - result = paginate(session, stmt) - total = result.total if hasattr(result, 'total') else 0 - - logger.debug("Executions listed", extra={ - "user_id": user_id, - "total": total, - "filters": { - "trigger_id": trigger_id, - "status": status.value if status is not None else None, - "execution_type": execution_type.value if execution_type else None - } - }) - - return result + + stmt = select(TriggerExecution).where(and_(*conditions)).order_by(desc(TriggerExecution.created_at)) + return paginate(db_session, stmt) @router.get("/{execution_id}", name="get execution", response_model=TriggerExecutionOut) def get_execution( execution_id: str, - session: Session = Depends(session), - auth: Auth = Depends(auth_must) + db_session: Session = Depends(session), + auth: V1UserAuth = Depends(auth_must), ): """Get a specific execution by execution ID.""" - user_id = auth.user.id - - # Get the execution and verify ownership through trigger - execution = session.exec( - select(TriggerExecution) - .join(Trigger) - .where( - and_( - TriggerExecution.execution_id == execution_id, - Trigger.user_id == str(user_id) - ) + execution = db_session.exec( + select(TriggerExecution).join(Trigger).where( + and_(TriggerExecution.execution_id == execution_id, Trigger.user_id == str(auth.id)) ) ).first() - if not execution: - logger.warning("Execution not found", extra={ - "user_id": user_id, - "execution_id": execution_id - }) raise HTTPException(status_code=404, detail="Execution not found") - - logger.debug("Execution retrieved", extra={ - "user_id": user_id, - "execution_id": execution_id - }) - return execution @@ -216,245 +120,67 @@ def get_execution( async def update_execution( execution_id: str, data: TriggerExecutionUpdate, - session: Session = Depends(session), - auth: Auth = Depends(auth_must) + db_session: Session = Depends(session), + auth: V1UserAuth = Depends(auth_must), ): """Update a trigger execution.""" - user_id = auth.user.id - - # Get the execution and verify ownership through trigger - execution = session.exec( - select(TriggerExecution) - .join(Trigger) - .where( - and_( - TriggerExecution.execution_id == execution_id, - Trigger.user_id == str(user_id) - ) - ) - ).first() - - if not execution: - logger.warning("Execution not found for update", extra={ - "user_id": user_id, - "execution_id": execution_id - }) - raise HTTPException(status_code=404, detail="Execution not found") - try: - update_data = data.model_dump(exclude_unset=True) - - # Check if status is being updated - use TriggerService for proper failure tracking - if "status" in update_data: - trigger_service = TriggerService(session) - # Convert status string back to enum for TriggerService - status_value = ExecutionStatus(update_data["status"]) if isinstance(update_data["status"], str) else update_data["status"] - trigger_service.update_execution_status( - execution=execution, - status=status_value, - output_data=update_data.get("output_data"), - error_message=update_data.get("error_message"), - tokens_used=update_data.get("tokens_used"), - tools_executed=update_data.get("tools_executed") - ) - # Remove status-related fields from update_data since TriggerService handled them - for key in ["status", "output_data", "error_message", "tokens_used", "tools_executed"]: - update_data.pop(key, None) - - # Update remaining fields - if update_data: - # Auto-calculate duration if both started_at and completed_at are set - if ("started_at" in update_data or "completed_at" in update_data) and execution.started_at: - completed_at = update_data.get("completed_at") or execution.completed_at - if completed_at: - # Ensure both datetimes are timezone-aware for subtraction - started_at = execution.started_at - if started_at.tzinfo is None: - started_at = started_at.replace(tzinfo=timezone.utc) - if completed_at.tzinfo is None: - completed_at = completed_at.replace(tzinfo=timezone.utc) - duration = (completed_at - started_at).total_seconds() - update_data["duration_seconds"] = duration - - for key, value in update_data.items(): - setattr(execution, key, value) - - session.add(execution) - session.commit() - - session.refresh(execution) - - # Get trigger for event publishing - trigger = session.get(Trigger, execution.trigger_id) - - logger.info("Execution updated", extra={ - "user_id": user_id, - "execution_id": execution_id, - "fields_updated": list(data.model_dump(exclude_unset=True).keys()) - }) - - # Publish to Redis pub/sub (broadcasts to all workers) - redis_manager = get_redis_manager() - redis_manager.publish_execution_event({ - "type": "execution_updated", - "execution_id": execution_id, - "trigger_id": execution.trigger_id, - "status": execution.status.value, - "updated_fields": list(update_data.keys()), - "user_id": str(user_id), - "timestamp": datetime.now(timezone.utc).isoformat(), - "project_id": str(trigger.project_id) if trigger else None - }) - - return execution - + result = TriggerCrudService.update_execution(execution_id, data, auth.id, db_session) + _raise_on_error(result) + return result["execution"] + except HTTPException: + raise except Exception as e: - session.rollback() - logger.error("Execution update failed", extra={ - "user_id": user_id, - "execution_id": execution_id, - "error": str(e) - }, exc_info=True) + db_session.rollback() + logger.error("Execution update failed", extra={"user_id": auth.id, "execution_id": execution_id, "error": str(e)}, exc_info=True) raise HTTPException(status_code=500, detail="Internal server error") @router.delete("/{execution_id}", name="delete execution") def delete_execution( execution_id: str, - session: Session = Depends(session), - auth: Auth = Depends(auth_must) + db_session: Session = Depends(session), + auth: V1UserAuth = Depends(auth_must), ): """Delete a trigger execution.""" - user_id = auth.user.id - - # Get the execution and verify ownership through trigger - execution = session.exec( - select(TriggerExecution) - .join(Trigger) - .where( - and_( - TriggerExecution.execution_id == execution_id, - Trigger.user_id == str(user_id) - ) + execution = db_session.exec( + select(TriggerExecution).join(Trigger).where( + and_(TriggerExecution.execution_id == execution_id, Trigger.user_id == str(auth.id)) ) ).first() - if not execution: - logger.warning("Execution not found for deletion", extra={ - "user_id": user_id, - "execution_id": execution_id - }) raise HTTPException(status_code=404, detail="Execution not found") - try: - session.delete(execution) - session.commit() - - logger.info("Execution deleted", extra={ - "user_id": user_id, - "execution_id": execution_id - }) - + db_session.delete(execution) + db_session.commit() return Response(status_code=204) - except Exception as e: - session.rollback() - logger.error("Execution deletion failed", extra={ - "user_id": user_id, - "execution_id": execution_id, - "error": str(e) - }, exc_info=True) + db_session.rollback() + logger.error("Execution deletion failed", extra={"user_id": auth.id, "execution_id": execution_id, "error": str(e)}, exc_info=True) raise HTTPException(status_code=500, detail="Internal server error") @router.post("/{execution_id}/retry", name="retry execution", response_model=TriggerExecutionOut) def retry_execution( execution_id: str, - session: Session = Depends(session), - auth: Auth = Depends(auth_must) + db_session: Session = Depends(session), + auth: V1UserAuth = Depends(auth_must), ): """Retry a failed execution.""" - user_id = auth.user.id - - # Get the execution and verify ownership through trigger - execution = session.exec( - select(TriggerExecution) - .join(Trigger) - .where( - and_( - TriggerExecution.execution_id == execution_id, - Trigger.user_id == str(user_id) - ) - ) - ).first() - - if not execution: - logger.warning("Execution not found for retry", extra={ - "user_id": user_id, - "execution_id": execution_id - }) - raise HTTPException(status_code=404, detail="Execution not found") - - if execution.status != ExecutionStatus.failed: - raise HTTPException(status_code=400, detail="Only failed executions can be retried") - - if execution.attempts >= execution.max_retries: - raise HTTPException(status_code=400, detail="Maximum retry attempts exceeded") - try: - # Create a new execution for the retry - new_execution_id = str(uuid4()) - new_execution = TriggerExecution( - trigger_id=execution.trigger_id, - execution_id=new_execution_id, - execution_type=execution.execution_type, - input_data=execution.input_data, - attempts=execution.attempts + 1, - max_retries=execution.max_retries - ) - - session.add(new_execution) - session.commit() - session.refresh(new_execution) - - # Get trigger for event publishing - trigger = session.get(Trigger, execution.trigger_id) - - logger.info("Execution retry created", extra={ - "user_id": user_id, - "original_execution_id": execution_id, - "new_execution_id": new_execution_id, - "attempts": new_execution.attempts - }) - - # Publish to Redis pub/sub (broadcasts to all workers) - redis_manager = get_redis_manager() - redis_manager.publish_execution_event({ - "type": "execution_created", - "execution_id": new_execution.execution_id, - "trigger_id": trigger.id if trigger else execution.trigger_id, - "trigger_type": trigger.trigger_type.value if trigger and trigger.trigger_type else "unknown", - "task_prompt": trigger.task_prompt if trigger else None, - "status": new_execution.status.value, - "input_data": new_execution.input_data, - "execution_type": new_execution.execution_type.value, - "user_id": str(user_id), - "timestamp": datetime.now(timezone.utc).isoformat(), - "project_id": str(trigger.project_id) if trigger else None - }) - - return new_execution - + result = TriggerCrudService.retry_execution(execution_id, auth.id, db_session) + _raise_on_error(result) + return result["execution"] + except HTTPException: + raise except Exception as e: - session.rollback() - logger.error("Execution retry failed", extra={ - "user_id": user_id, - "execution_id": execution_id, - "error": str(e) - }, exc_info=True) + db_session.rollback() + logger.error("Execution retry failed", extra={"user_id": auth.id, "execution_id": execution_id, "error": str(e)}, exc_info=True) raise HTTPException(status_code=500, detail="Internal server error") +# ---- WebSocket (kept in controller due to process-level state) ---- + @router.websocket("/subscribe") async def subscribe_executions(websocket: WebSocket): """Subscribe to trigger execution events via WebSocket. @@ -471,12 +197,12 @@ async def subscribe_executions(websocket: WebSocket): await websocket.accept() session_id = None user_id = None - db_session = None - + ws_db_session = None + try: # Create database session manually for WebSocket - from app.component.database import session_make - db_session = session_make() + from app.core.database import session_make + ws_db_session = session_make() # Wait for subscription message data = await websocket.receive_json() @@ -501,14 +227,18 @@ async def subscribe_executions(websocket: WebSocket): return try: - from app.component.auth import Auth - # Decode token and fetch user - auth = Auth.decode_token(auth_token) - user = db_session.get(User, auth.id) + from app.shared.auth.user_auth import V1UserAuth, _get_jti + from app.shared.auth.token_blacklist import is_blacklisted as _is_blacklisted + # Decode token and check blacklist + auth = V1UserAuth.decode_token(auth_token) + jti = _get_jti(auth_token) + if jti and await _is_blacklisted(jti): + raise Exception("Token has been revoked") + user = ws_db_session.get(User, auth.id) if not user: raise Exception("User not found") auth._user = user - user_id = auth.user.id + user_id = auth.id logger.info(f"User authenticated for WebSocket {user_id} and {session_id}", extra={ "user_id": user_id, "session_id": session_id @@ -562,7 +292,7 @@ async def subscribe_executions(websocket: WebSocket): redis_manager.remove_pending_execution(session_id, execution_id) # Update execution status to running - execution = db_session.exec( + execution = ws_db_session.exec( select(TriggerExecution).where( TriggerExecution.execution_id == execution_id ) @@ -571,8 +301,8 @@ async def subscribe_executions(websocket: WebSocket): if execution and execution.status == ExecutionStatus.pending: execution.status = ExecutionStatus.running execution.started_at = datetime.now(timezone.utc) - db_session.add(execution) - db_session.commit() + ws_db_session.add(execution) + ws_db_session.commit() logger.info("Execution acknowledged and started", extra={ "session_id": session_id, @@ -636,8 +366,8 @@ async def subscribe_executions(websocket: WebSocket): logger.info("Session cleaned up", extra={"session_id": session_id}) # Close database session - if db_session: - db_session.close() + if ws_db_session: + ws_db_session.close() async def handle_pubsub_message(event_data: Dict[str, Any]): diff --git a/server/app/controller/trigger/webhook_controller.py b/server/app/domains/trigger/api/webhook_controller.py similarity index 92% rename from server/app/controller/trigger/webhook_controller.py rename to server/app/domains/trigger/api/webhook_controller.py index b786cc85..15bb3bee 100644 --- a/server/app/controller/trigger/webhook_controller.py +++ b/server/app/domains/trigger/api/webhook_controller.py @@ -23,17 +23,15 @@ from sqlmodel import Session, select, and_, or_ from uuid import uuid4 from datetime import datetime, timezone import json -import logging +from loguru import logger from fastapi_limiter.depends import RateLimiter from app.model.trigger.trigger import Trigger from app.model.trigger.trigger_execution import TriggerExecution -from app.type.trigger_types import TriggerType, TriggerStatus, ExecutionType, ExecutionStatus -from app.component.database import session -from app.component.trigger_utils import check_rate_limits -from app.service.trigger.app_handler_service import get_app_handler - -logger = logging.getLogger("server_webhook_controller") +from app.shared.types.trigger_types import TriggerType, TriggerStatus, ExecutionType, ExecutionStatus +from app.core.database import session +from app.core.trigger_utils import check_rate_limits +from app.domains.trigger.service.app_handler_service import get_app_handler router = APIRouter(prefix="/webhook", tags=["Webhook"]) @@ -58,13 +56,14 @@ async def webhook_trigger( input_data = {"raw_body": body.decode()} headers = dict(request.headers) - webhook_url = f"/webhook/trigger/{webhook_uuid}" - - # Find the trigger (allow active and pending_verification for verification flows) + webhook_url_new = f"/v1/webhook/trigger/{webhook_uuid}" + webhook_url_old = f"/webhook/trigger/{webhook_uuid}" + + # Find the trigger (match both old and new URL formats) trigger = db_session.exec( select(Trigger).where( and_( - Trigger.webhook_url == webhook_url, + or_(Trigger.webhook_url == webhook_url_new, Trigger.webhook_url == webhook_url_old), Trigger.trigger_type.in_(WEBHOOK_TRIGGER_TYPES), Trigger.status.in_([TriggerStatus.active, TriggerStatus.pending_verification]) ) @@ -111,7 +110,7 @@ async def webhook_trigger( # Notify Redis subscribers of successful activation try: - from app.component.redis_utils import get_redis_manager + from app.core.redis_utils import get_redis_manager redis_manager = get_redis_manager() redis_manager.publish_execution_event({ "type": "trigger_activated", @@ -233,7 +232,7 @@ async def webhook_trigger( # Notify WebSocket subscribers and wait for delivery confirmation try: - from app.component.redis_utils import get_redis_manager + from app.core.redis_utils import get_redis_manager redis_manager = get_redis_manager() # Check if user has any active WebSocket sessions @@ -302,7 +301,7 @@ async def webhook_trigger( "execution_id": execution_id, "message": "Webhook trigger processed but WebSocket notification failed", "delivered": False, - "reason": "websocket_notification_error" + "reason": str(e) } except HTTPException: @@ -321,12 +320,13 @@ def get_webhook_info( db_session: Session = Depends(session) ): """Get information about a webhook trigger (public endpoint).""" - webhook_url = f"/webhook/trigger/{webhook_uuid}" - + webhook_url_new = f"/v1/webhook/trigger/{webhook_uuid}" + webhook_url_old = f"/webhook/trigger/{webhook_uuid}" + trigger = db_session.exec( select(Trigger).where( and_( - Trigger.webhook_url == webhook_url, + or_(Trigger.webhook_url == webhook_url_new, Trigger.webhook_url == webhook_url_old), Trigger.trigger_type.in_(WEBHOOK_TRIGGER_TYPES) ) ) @@ -335,14 +335,11 @@ def get_webhook_info( if not trigger: raise HTTPException(status_code=404, detail="Webhook not found") + # Only expose minimal info on public endpoint to avoid information disclosure return { - "name": trigger.name, - "description": trigger.description, - "status": trigger.status.value, - "trigger_type": trigger.trigger_type.value, "is_active": trigger.status == TriggerStatus.active, + "trigger_type": trigger.trigger_type.value, "webhook_method": trigger.webhook_method.value if trigger.webhook_method else None, - "last_executed_at": trigger.last_executed_at.isoformat() if trigger.last_executed_at else None, } diff --git a/server/app/domains/trigger/schema/__init__.py b/server/app/domains/trigger/schema/__init__.py new file mode 100644 index 00000000..3a4d90c0 --- /dev/null +++ b/server/app/domains/trigger/schema/__init__.py @@ -0,0 +1,14 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + diff --git a/server/app/service/trigger/__init__.py b/server/app/domains/trigger/service/__init__.py similarity index 69% rename from server/app/service/trigger/__init__.py rename to server/app/domains/trigger/service/__init__.py index f8bbaca2..1178087e 100644 --- a/server/app/service/trigger/__init__.py +++ b/server/app/domains/trigger/service/__init__.py @@ -12,18 +12,11 @@ # limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -""" -Trigger Service Package +"""Trigger domain services.""" -Contains services for managing triggers including: -- TriggerService: Main service for trigger operations -- TriggerScheduleService: Service for scheduled trigger operations -- App Handlers: Handlers for different trigger types (Slack, Webhook, Schedule) -""" - -from app.service.trigger.trigger_service import TriggerService, get_trigger_service -from app.service.trigger.trigger_schedule_service import TriggerScheduleService -from app.service.trigger.app_handler_service import ( +from app.domains.trigger.service.trigger_service import TriggerService, get_trigger_service +from app.domains.trigger.service.trigger_schedule_service import TriggerScheduleService +from app.domains.trigger.service.app_handler_service import ( BaseAppHandler, SlackAppHandler, DefaultWebhookHandler, @@ -36,19 +29,16 @@ from app.service.trigger.app_handler_service import ( ) __all__ = [ - # Services "TriggerService", "get_trigger_service", "TriggerScheduleService", - # Handlers "BaseAppHandler", "SlackAppHandler", "DefaultWebhookHandler", "ScheduleAppHandler", "AppHandlerResult", - # Handler functions "get_app_handler", "get_schedule_handler", "register_app_handler", "get_supported_trigger_types", -] \ No newline at end of file +] diff --git a/server/app/component/service/trigger/app_handler_service.py b/server/app/domains/trigger/service/app_handler_service.py similarity index 98% rename from server/app/component/service/trigger/app_handler_service.py rename to server/app/domains/trigger/service/app_handler_service.py index 0fd1c7a2..e31a2a01 100644 --- a/server/app/component/service/trigger/app_handler_service.py +++ b/server/app/domains/trigger/service/app_handler_service.py @@ -24,13 +24,13 @@ from typing import Optional from dataclasses import dataclass from fastapi import Request from sqlmodel import Session, select, and_ -import logging +from loguru import logger from app.model.trigger.trigger import Trigger from app.model.config.config import Config from app.model.trigger.app_configs import SlackTriggerConfig, WebhookTriggerConfig, ScheduleTriggerConfig -from app.type.trigger_types import TriggerType, ExecutionType, TriggerStatus -from app.type.config_group import ConfigGroup +from app.shared.types.trigger_types import TriggerType, ExecutionType, TriggerStatus +from app.shared.types.config_group import ConfigGroup @dataclass diff --git a/server/app/domains/trigger/service/slack_service.py b/server/app/domains/trigger/service/slack_service.py new file mode 100644 index 00000000..76b0de7b --- /dev/null +++ b/server/app/domains/trigger/service/slack_service.py @@ -0,0 +1,83 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""SlackService: Slack integration helpers (channel listing).""" + +from loguru import logger +from sqlmodel import select, and_ + +from app.core.database import session_make +from app.model.config.config import Config +from app.shared.types.config_group import ConfigGroup + + +class SlackService: + """Slack integration - static methods.""" + + @staticmethod + def get_channels(user_id: int) -> dict: + """Fetch Slack channels for user. Returns {"success": True, "channels": [...], "has_credentials": True} or error.""" + with session_make() as s: + configs = s.exec( + select(Config).where( + and_( + Config.user_id == int(user_id), + Config.config_group == ConfigGroup.SLACK.value, + ) + ) + ).all() + + credentials = {config.config_name: config.config_value for config in configs} + bot_token = credentials.get("SLACK_BOT_TOKEN") + + if not bot_token: + logger.warning("Slack credentials not found", extra={"user_id": user_id}) + return {"success": True, "channels": [], "has_credentials": False} + + try: + from slack_sdk import WebClient + from slack_sdk.errors import SlackApiError + + client = WebClient(token=bot_token) + channels = [] + cursor = None + + while True: + response = client.conversations_list( + types="public_channel,private_channel", + cursor=cursor, + limit=200, + ) + for channel in response.get("channels", []): + channels.append({ + "id": channel.get("id"), + "name": channel.get("name"), + "is_private": channel.get("is_private", False), + "is_member": channel.get("is_member", False), + "num_members": channel.get("num_members"), + }) + cursor = response.get("response_metadata", {}).get("next_cursor") + if not cursor: + break + + logger.info("Slack channels fetched", extra={"user_id": user_id, "channel_count": len(channels)}) + return {"success": True, "channels": channels, "has_credentials": True} + + except ImportError: + logger.error("slack_sdk not installed") + return {"success": False, "error_code": "SLACK_SDK_NOT_INSTALLED"} + except Exception as e: + error_detail = getattr(e, "response", {}).get("error", str(e)) if hasattr(e, "response") else str(e) + logger.error("Slack API error", extra={"user_id": user_id, "error": error_detail}) + return {"success": False, "error_code": "SLACK_API_ERROR", "detail": error_detail} diff --git a/server/app/domains/trigger/service/trigger_crud_service.py b/server/app/domains/trigger/service/trigger_crud_service.py new file mode 100644 index 00000000..20a67251 --- /dev/null +++ b/server/app/domains/trigger/service/trigger_crud_service.py @@ -0,0 +1,594 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""TriggerCrudService: trigger CRUD + execution business logic.""" + +from datetime import datetime, timezone +from uuid import uuid4 +from typing import Optional + +from loguru import logger +from pydantic import ValidationError +from sqlmodel import Session, select, and_ +from sqlalchemy import func + +from app.model.trigger.trigger import Trigger, TriggerIn, TriggerOut, TriggerUpdate +from app.model.trigger.trigger_execution import TriggerExecution, TriggerExecutionIn, TriggerExecutionUpdate +from app.shared.types.trigger_types import ExecutionStatus, ExecutionType +from app.model.trigger.app_configs import ( + get_config_schema, + validate_config, + has_config, + validate_activation, + ActivationError, +) +from app.model.trigger.app_configs.config_registry import requires_authentication +from app.model.chat.chat_history import ChatHistory +from app.shared.types.trigger_types import TriggerType, TriggerStatus +from app.core.redis_utils import get_redis_manager +from app.domains.trigger.service.trigger_schedule_service import TriggerScheduleService +from app.domains.trigger.service.trigger_service import TriggerService + + +ACTIVE_STATUSES = (TriggerStatus.active, TriggerStatus.pending_verification) +MAX_ACTIVE_PER_USER = 25 +MAX_ACTIVE_PER_PROJECT = 5 + + +class TriggerCrudService: + """Trigger CRUD business logic - static methods, caller-managed session.""" + + @staticmethod + def get_active_trigger_counts(s: Session, user_id: str, project_id: str | None = None) -> tuple[int, int]: + """Return (user_active_count, project_active_count) for active/pending triggers.""" + user_count = s.exec( + select(func.count(Trigger.id)).where( + and_( + Trigger.user_id == user_id, + Trigger.status.in_(ACTIVE_STATUSES), + ) + ) + ).one() + + project_count = 0 + if project_id: + project_count = s.exec( + select(func.count(Trigger.id)).where( + and_( + Trigger.user_id == user_id, + Trigger.project_id == project_id, + Trigger.status.in_(ACTIVE_STATUSES), + ) + ) + ).one() + + return user_count, project_count + + @staticmethod + def get_execution_counts(s: Session, trigger_ids: list[int]) -> dict[int, int]: + """Get execution counts for multiple triggers in a single query.""" + if not trigger_ids: + return {} + result = s.exec( + select(TriggerExecution.trigger_id, func.count(TriggerExecution.id)) + .where(TriggerExecution.trigger_id.in_(trigger_ids)) + .group_by(TriggerExecution.trigger_id) + ).all() + return {trigger_id: count for trigger_id, count in result} + + @staticmethod + def trigger_to_out(trigger: Trigger, execution_count: int = 0) -> TriggerOut: + """Convert Trigger model to TriggerOut with execution count.""" + return TriggerOut( + id=trigger.id, + user_id=trigger.user_id, + project_id=trigger.project_id, + name=trigger.name, + description=trigger.description, + trigger_type=trigger.trigger_type, + status=trigger.status, + execution_count=execution_count, + webhook_url=trigger.webhook_url, + webhook_method=trigger.webhook_method, + custom_cron_expression=trigger.custom_cron_expression, + listener_type=trigger.listener_type, + agent_model=trigger.agent_model, + task_prompt=trigger.task_prompt, + config=trigger.config, + max_executions_per_hour=trigger.max_executions_per_hour, + max_executions_per_day=trigger.max_executions_per_day, + is_single_execution=trigger.is_single_execution, + last_executed_at=trigger.last_executed_at, + next_run_at=trigger.next_run_at, + last_execution_status=trigger.last_execution_status, + created_at=trigger.created_at, + updated_at=trigger.updated_at, + ) + + @staticmethod + def _ensure_project_chat_history(data: TriggerIn, user_id: int, s: Session) -> None: + """Create placeholder ChatHistory for a new project if it doesn't exist, and notify via WebSocket.""" + if not data.project_id: + return + + existing_chat = s.exec( + select(ChatHistory).where(ChatHistory.project_id == data.project_id) + ).first() + if existing_chat: + return + + chat_history = ChatHistory( + user_id=user_id, + task_id=data.project_id, + project_id=data.project_id, + question=f"Project created via trigger: {data.name}", + language="en", + model_platform=data.agent_model or "none", + model_type=data.agent_model or "none", + installed_mcp="none", + api_key="", + api_url="", + max_retries=3, + project_name=data.name, + summary=data.description or "", + tokens=0, + spend=0, + status=2, + ) + s.add(chat_history) + s.commit() + s.refresh(chat_history) + + logger.info("Chat history created for new project", extra={ + "user_id": user_id, + "project_id": data.project_id, + "chat_history_id": chat_history.id, + }) + + # WebSocket notification (best effort) + try: + redis_manager = get_redis_manager() + redis_manager.publish_execution_event({ + "type": "project_created", + "project_id": data.project_id, + "project_name": data.name, + "chat_history_id": chat_history.id, + "trigger_name": data.name, + "user_id": str(user_id), + "created_at": chat_history.created_at.isoformat() if chat_history.created_at else None, + }) + except Exception as e: + logger.warning("Failed to send WebSocket notification for new project", extra={ + "user_id": user_id, + "project_id": data.project_id, + "error": str(e), + }) + + @staticmethod + def _validate_trigger_config(trigger_type: TriggerType, config: dict | None) -> None: + """Validate trigger-type specific config. Raises HTTPException-compatible ValueError on failure.""" + if config and has_config(trigger_type): + try: + validate_config(trigger_type, config) + except ValidationError as e: + raise ValueError(f"Invalid config for {trigger_type.value}: {e.errors()}") + + @staticmethod + def _determine_initial_status( + data: TriggerIn, user_id: int, s: Session + ) -> TriggerStatus: + """Determine initial trigger status based on auth requirements and concurrency limits.""" + # Desired status from auth requirements + if has_config(data.trigger_type) and data.config and requires_authentication(data.trigger_type, data.config): + desired_status = TriggerStatus.pending_verification + else: + desired_status = TriggerStatus.active + + # Check concurrency limits + user_active, project_active = TriggerCrudService.get_active_trigger_counts( + s, str(user_id), data.project_id + ) + if user_active >= MAX_ACTIVE_PER_USER or ( + data.project_id and project_active >= MAX_ACTIVE_PER_PROJECT + ): + logger.info( + "Active trigger limit reached — new trigger created as inactive", + extra={ + "user_id": user_id, + "project_id": data.project_id, + "user_active": user_active, + "project_active": project_active, + }, + ) + return TriggerStatus.inactive + + return desired_status + + @staticmethod + def create(data: TriggerIn, user_id: int, s: Session) -> dict: + """ + Create a new trigger with all business rules. + Returns {"success": True, "trigger_out": TriggerOut} or {"success": False, "error": str, "status_code": int}. + """ + # 1. Ensure project ChatHistory exists + TriggerCrudService._ensure_project_chat_history(data, user_id, s) + + # 2. Generate webhook URL + webhook_url = None + if data.trigger_type in (TriggerType.webhook, TriggerType.slack_trigger): + webhook_url = f"/v1/webhook/trigger/{uuid4()}" + + # 3. Validate config + try: + TriggerCrudService._validate_trigger_config(data.trigger_type, data.config) + except ValueError as e: + return {"success": False, "error": str(e), "status_code": 400} + + # 4. Determine initial status + initial_status = TriggerCrudService._determine_initial_status(data, user_id, s) + + # 5. Create trigger + trigger_data = data.model_dump() + trigger_data["user_id"] = str(user_id) + trigger_data["webhook_url"] = webhook_url + trigger_data["status"] = initial_status + + trigger = Trigger(**trigger_data) + s.add(trigger) + s.commit() + s.refresh(trigger) + + # 6. Calculate next_run_at for scheduled triggers + if trigger.trigger_type == TriggerType.schedule and trigger.custom_cron_expression: + schedule_service = TriggerScheduleService(s) + trigger.next_run_at = schedule_service.calculate_next_run_at(trigger) + s.add(trigger) + s.commit() + s.refresh(trigger) + + logger.info("Trigger created", extra={ + "user_id": user_id, + "trigger_id": trigger.id, + "trigger_type": data.trigger_type.value, + "next_run_at": trigger.next_run_at.isoformat() if trigger.next_run_at else None, + }) + + return {"success": True, "trigger_out": TriggerCrudService.trigger_to_out(trigger, 0)} + + @staticmethod + def update(trigger_id: int, data: TriggerUpdate, user_id: int, s: Session) -> dict: + """ + Update a trigger with config validation and schedule recalculation. + Returns {"success": True, "trigger_out": TriggerOut} or {"success": False, "error": str, "status_code": int}. + """ + trigger = s.exec( + select(Trigger).where(and_(Trigger.id == trigger_id, Trigger.user_id == str(user_id))) + ).first() + if not trigger: + return {"success": False, "error": "Trigger not found", "status_code": 404} + + update_data = data.model_dump(exclude_unset=True) + + # Validate config if being updated + if "config" in update_data and update_data["config"] is not None: + try: + TriggerCrudService._validate_trigger_config(trigger.trigger_type, update_data["config"]) + except ValueError as e: + return {"success": False, "error": str(e), "status_code": 400} + + for key, value in update_data.items(): + setattr(trigger, key, value) + + # Recalculate next_run_at if cron expression or status changed for scheduled triggers + if trigger.trigger_type == TriggerType.schedule: + if "custom_cron_expression" in update_data or "status" in update_data: + if trigger.status == TriggerStatus.active and trigger.custom_cron_expression: + schedule_service = TriggerScheduleService(s) + trigger.next_run_at = schedule_service.calculate_next_run_at(trigger) + + s.add(trigger) + s.commit() + s.refresh(trigger) + + counts = TriggerCrudService.get_execution_counts(s, [trigger_id]) + execution_count = counts.get(trigger_id, 0) + + logger.info("Trigger updated", extra={ + "user_id": user_id, + "trigger_id": trigger_id, + "fields_updated": list(update_data.keys()), + "next_run_at": trigger.next_run_at.isoformat() if trigger.next_run_at else None, + }) + + return {"success": True, "trigger_out": TriggerCrudService.trigger_to_out(trigger, execution_count)} + + @staticmethod + def activate(trigger_id: int, user_id: int, s: Session) -> dict: + """ + Activate a trigger with limit checks, auth requirements, and activation validation. + Returns {"success": True, "trigger_out": TriggerOut} + or {"success": False, "error": str/dict, "status_code": int}. + """ + trigger = s.exec( + select(Trigger).where(and_(Trigger.id == trigger_id, Trigger.user_id == str(user_id))) + ).first() + if not trigger: + return {"success": False, "error": "Trigger not found", "status_code": 404} + + # 1. Check concurrency limits + user_active, project_active = TriggerCrudService.get_active_trigger_counts( + s, str(user_id), trigger.project_id + ) + if user_active >= MAX_ACTIVE_PER_USER: + return { + "success": False, + "error": f"Maximum number of concurrent active triggers ({MAX_ACTIVE_PER_USER}) reached for this user", + "status_code": 400, + } + if trigger.project_id and project_active >= MAX_ACTIVE_PER_PROJECT: + return { + "success": False, + "error": f"Maximum number of concurrent active triggers ({MAX_ACTIVE_PER_PROJECT}) reached for this project", + "status_code": 400, + } + + # 2. Check if authentication is required — go straight to pending_verification + if has_config(trigger.trigger_type) and requires_authentication(trigger.trigger_type, trigger.config): + trigger.status = TriggerStatus.pending_verification + s.add(trigger) + s.commit() + s.refresh(trigger) + return { + "success": False, + "error": { + "message": "Authentication required for this trigger type", + "missing_requirements": ["authentication"], + "trigger_type": trigger.trigger_type.value, + }, + "status_code": 401, + } + + # 3. Validate activation requirements (non-auth triggers) + if has_config(trigger.trigger_type): + try: + validate_activation( + trigger_type=trigger.trigger_type, + config_data=trigger.config, + user_id=int(user_id), + session=s, + ) + except ActivationError as e: + return { + "success": False, + "error": { + "message": e.message, + "missing_requirements": e.missing_requirements, + "trigger_type": trigger.trigger_type.value, + }, + "status_code": 400, + } + + # 4. Activate + trigger.status = TriggerStatus.active + s.add(trigger) + s.commit() + s.refresh(trigger) + + counts = TriggerCrudService.get_execution_counts(s, [trigger_id]) + execution_count = counts.get(trigger_id, 0) + + logger.info("Trigger activated", extra={ + "user_id": user_id, + "trigger_id": trigger_id, + "status": trigger.status.value, + }) + + return {"success": True, "trigger_out": TriggerCrudService.trigger_to_out(trigger, execution_count)} + + @staticmethod + def deactivate(trigger_id: int, user_id: int, s: Session) -> dict: + """ + Deactivate a trigger. + Returns {"success": True, "trigger_out": TriggerOut} or {"success": False, ...}. + """ + trigger = s.exec( + select(Trigger).where(and_(Trigger.id == trigger_id, Trigger.user_id == str(user_id))) + ).first() + if not trigger: + return {"success": False, "error": "Trigger not found", "status_code": 404} + + trigger.status = TriggerStatus.inactive + s.add(trigger) + s.commit() + s.refresh(trigger) + + counts = TriggerCrudService.get_execution_counts(s, [trigger_id]) + execution_count = counts.get(trigger_id, 0) + + logger.info("Trigger deactivated", extra={ + "user_id": user_id, + "trigger_id": trigger_id, + }) + + return {"success": True, "trigger_out": TriggerCrudService.trigger_to_out(trigger, execution_count)} + + # ---- Execution CRUD ---- + + @staticmethod + def _publish_execution_event(event_type: str, execution: TriggerExecution, trigger: Trigger, user_id: int, extra: dict | None = None) -> None: + """Publish execution event to Redis pub/sub (best effort).""" + try: + payload = { + "type": event_type, + "execution_id": execution.execution_id, + "trigger_id": trigger.id, + "trigger_type": trigger.trigger_type.value if trigger.trigger_type else "unknown", + "task_prompt": trigger.task_prompt, + "status": execution.status.value, + "input_data": execution.input_data, + "execution_type": execution.execution_type.value, + "user_id": str(user_id), + "timestamp": datetime.now(timezone.utc).isoformat(), + "project_id": str(trigger.project_id), + } + if extra: + payload.update(extra) + get_redis_manager().publish_execution_event(payload) + except Exception as e: + logger.warning("Failed to publish execution event", extra={"execution_id": execution.execution_id, "error": str(e)}) + + @staticmethod + def create_execution(data: TriggerExecutionIn, user_id: int, s: Session) -> dict: + """ + Create a trigger execution: verify ownership, create record, update trigger timestamp, publish event. + Returns {"success": True, "execution": TriggerExecution} or {"success": False, ...}. + """ + trigger = s.exec( + select(Trigger).where(and_(Trigger.id == data.trigger_id, Trigger.user_id == str(user_id))) + ).first() + if not trigger: + return {"success": False, "error": "Trigger not found", "status_code": 404} + + execution = TriggerExecution(**data.model_dump()) + s.add(execution) + s.commit() + s.refresh(execution) + + # Update trigger timestamp + trigger.last_executed_at = datetime.now(timezone.utc) + s.add(trigger) + s.commit() + + logger.info("Trigger execution created", extra={ + "user_id": user_id, + "trigger_id": data.trigger_id, + "execution_id": execution.execution_id, + "execution_type": data.execution_type.value, + }) + + TriggerCrudService._publish_execution_event("execution_created", execution, trigger, user_id) + + return {"success": True, "execution": execution} + + @staticmethod + def update_execution(execution_id: str, data: TriggerExecutionUpdate, user_id: int, s: Session) -> dict: + """ + Update a trigger execution: ownership check, status via TriggerService, duration calc, publish event. + Returns {"success": True, "execution": TriggerExecution} or {"success": False, ...}. + """ + execution = s.exec( + select(TriggerExecution) + .join(Trigger) + .where(and_(TriggerExecution.execution_id == execution_id, Trigger.user_id == str(user_id))) + ).first() + if not execution: + return {"success": False, "error": "Execution not found", "status_code": 404} + + update_data = data.model_dump(exclude_unset=True) + + # Delegate status update to TriggerService for proper failure tracking + if "status" in update_data: + trigger_service = TriggerService(s) + status_value = ExecutionStatus(update_data["status"]) if isinstance(update_data["status"], str) else update_data["status"] + trigger_service.update_execution_status( + execution=execution, + status=status_value, + output_data=update_data.get("output_data"), + error_message=update_data.get("error_message"), + tokens_used=update_data.get("tokens_used"), + tools_executed=update_data.get("tools_executed"), + ) + for key in ["status", "output_data", "error_message", "tokens_used", "tools_executed"]: + update_data.pop(key, None) + + # Update remaining fields + auto-calculate duration + if update_data: + if ("started_at" in update_data or "completed_at" in update_data) and execution.started_at: + completed_at = update_data.get("completed_at") or execution.completed_at + if completed_at: + started_at = execution.started_at + if started_at.tzinfo is None: + started_at = started_at.replace(tzinfo=timezone.utc) + if completed_at.tzinfo is None: + completed_at = completed_at.replace(tzinfo=timezone.utc) + update_data["duration_seconds"] = (completed_at - started_at).total_seconds() + + for key, value in update_data.items(): + setattr(execution, key, value) + s.add(execution) + s.commit() + + s.refresh(execution) + + # Publish event + trigger = s.get(Trigger, execution.trigger_id) + logger.info("Execution updated", extra={ + "user_id": user_id, + "execution_id": execution_id, + "fields_updated": list(data.model_dump(exclude_unset=True).keys()), + }) + + if trigger: + TriggerCrudService._publish_execution_event( + "execution_updated", execution, trigger, user_id, + extra={"updated_fields": list(update_data.keys())}, + ) + + return {"success": True, "execution": execution} + + @staticmethod + def retry_execution(execution_id: str, user_id: int, s: Session) -> dict: + """ + Retry a failed execution: validate status, create new execution, publish event. + Returns {"success": True, "execution": TriggerExecution} or {"success": False, ...}. + """ + execution = s.exec( + select(TriggerExecution) + .join(Trigger) + .where(and_(TriggerExecution.execution_id == execution_id, Trigger.user_id == str(user_id))) + ).first() + if not execution: + return {"success": False, "error": "Execution not found", "status_code": 404} + + if execution.status != ExecutionStatus.failed: + return {"success": False, "error": "Only failed executions can be retried", "status_code": 400} + + if execution.attempts >= execution.max_retries: + return {"success": False, "error": "Maximum retry attempts exceeded", "status_code": 400} + + new_execution = TriggerExecution( + trigger_id=execution.trigger_id, + execution_id=str(uuid4()), + execution_type=execution.execution_type, + input_data=execution.input_data, + attempts=execution.attempts + 1, + max_retries=execution.max_retries, + ) + s.add(new_execution) + s.commit() + s.refresh(new_execution) + + trigger = s.get(Trigger, execution.trigger_id) + + logger.info("Execution retry created", extra={ + "user_id": user_id, + "original_execution_id": execution_id, + "new_execution_id": new_execution.execution_id, + "attempts": new_execution.attempts, + }) + + if trigger: + TriggerCrudService._publish_execution_event("execution_created", new_execution, trigger, user_id) + + return {"success": True, "execution": new_execution} diff --git a/server/app/component/service/trigger/trigger_schedule_service.py b/server/app/domains/trigger/service/trigger_schedule_service.py similarity index 98% rename from server/app/component/service/trigger/trigger_schedule_service.py rename to server/app/domains/trigger/service/trigger_schedule_service.py index c18bdff0..bfae64e3 100644 --- a/server/app/component/service/trigger/trigger_schedule_service.py +++ b/server/app/domains/trigger/service/trigger_schedule_service.py @@ -14,7 +14,7 @@ from datetime import datetime, timedelta, timezone from typing import List, Tuple, Optional -import logging +from loguru import logger from croniter import croniter from uuid import uuid4 import asyncio @@ -22,8 +22,8 @@ from sqlmodel import select from app.model.trigger.trigger import Trigger from app.model.trigger.trigger_execution import TriggerExecution -from app.type.trigger_types import TriggerStatus, ExecutionType, ExecutionStatus, TriggerType -from app.component.trigger_utils import check_rate_limits, MAX_DISPATCH_PER_TICK +from app.shared.types.trigger_types import TriggerStatus, ExecutionType, ExecutionStatus, TriggerType +from app.core.trigger_utils import check_rate_limits, MAX_DISPATCH_PER_TICK from app.model.trigger.app_configs import ScheduleTriggerConfig @@ -206,7 +206,7 @@ class TriggerScheduleService: # Using asyncio.run() to run async code from sync Celery worker context try: # Notify WebSocket subscribers via Redis pub/sub - from app.component.redis_utils import get_redis_manager + from app.core.redis_utils import get_redis_manager redis_manager = get_redis_manager() redis_manager.publish_execution_event({ "type": "execution_created", diff --git a/server/app/schedule/trigger_schedule_task.py b/server/app/domains/trigger/service/trigger_schedule_task.py similarity index 70% rename from server/app/schedule/trigger_schedule_task.py rename to server/app/domains/trigger/service/trigger_schedule_task.py index 88f1a17a..9079b677 100644 --- a/server/app/schedule/trigger_schedule_task.py +++ b/server/app/domains/trigger/service/trigger_schedule_task.py @@ -12,44 +12,40 @@ # limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -from celery import shared_task +""" +Celery tasks for trigger scheduling: poll due triggers and check execution timeouts. +""" + import logging from datetime import datetime, timezone + +from celery import shared_task from sqlmodel import select, or_ -from app.component.database import session_make -from app.component.environment import env -from app.service.trigger.trigger_schedule_service import TriggerScheduleService -from app.service.trigger.trigger_service import TriggerService -from app.component.trigger_utils import MAX_DISPATCH_PER_TICK -from app.component.redis_utils import get_redis_manager +from app.core.database import session_make +from app.core.environment import env +from app.core.trigger_utils import MAX_DISPATCH_PER_TICK +from app.core.redis_utils import get_redis_manager from app.model.trigger.trigger_execution import TriggerExecution from app.model.trigger.trigger import Trigger -from app.type.trigger_types import ExecutionStatus +from app.shared.types.trigger_types import ExecutionStatus +from app.domains.trigger.service.trigger_schedule_service import TriggerScheduleService +from app.domains.trigger.service.trigger_service import TriggerService -# Timeout configuration from environment variables EXECUTION_PENDING_TIMEOUT_SECONDS = int(env("EXECUTION_PENDING_TIMEOUT_SECONDS", "60")) -EXECUTION_RUNNING_TIMEOUT_SECONDS = int(env("EXECUTION_RUNNING_TIMEOUT_SECONDS", "600")) # 10 minutes +EXECUTION_RUNNING_TIMEOUT_SECONDS = int(env("EXECUTION_RUNNING_TIMEOUT_SECONDS", "600")) logger = logging.getLogger("server_trigger_schedule_task") + @shared_task(queue="poll_trigger_schedules") def poll_trigger_schedules() -> None: - """ - Celery task to poll and execute scheduled triggers. - This runs periodically to check for triggers that are due for execution. - - This is a lightweight wrapper around TriggerScheduleService that handles - session management and delegates all business logic to the service layer. - """ + """Poll and execute scheduled triggers.""" logger.info("Starting poll_trigger_schedules task") - + session = session_make() try: - # Create service instance with session schedule_service = TriggerScheduleService(session) - - # Delegate all logic to the service schedule_service.poll_and_execute_due_triggers( max_dispatch_per_tick=MAX_DISPATCH_PER_TICK ) @@ -59,28 +55,19 @@ def poll_trigger_schedules() -> None: @shared_task(queue="check_execution_timeouts") def check_execution_timeouts() -> None: - """ - Celery task to check for timed-out pending and running executions. - - This runs periodically to find: - - Pending executions that haven't been acknowledged within EXECUTION_PENDING_TIMEOUT_SECONDS - - Running executions that have been stuck for more than EXECUTION_RUNNING_TIMEOUT_SECONDS - - These are marked as missed/failed respectively. - """ + """Check for timed-out pending and running executions.""" logger.info("Starting check_execution_timeouts task", extra={ "pending_timeout": EXECUTION_PENDING_TIMEOUT_SECONDS, "running_timeout": EXECUTION_RUNNING_TIMEOUT_SECONDS }) - + session = session_make() redis_manager = get_redis_manager() trigger_service = TriggerService(session) - + try: now = datetime.now(timezone.utc) - - # Find all pending and running executions + executions = session.exec( select(TriggerExecution).where( or_( @@ -89,28 +76,26 @@ def check_execution_timeouts() -> None: ) ) ).all() - + timed_out_pending_count = 0 timed_out_running_count = 0 - + for execution in executions: is_pending = execution.status == ExecutionStatus.pending is_running = execution.status == ExecutionStatus.running - - # Determine the reference time and timeout based on status + if is_pending: reference_time = execution.created_at timeout_seconds = EXECUTION_PENDING_TIMEOUT_SECONDS - else: # running + else: reference_time = execution.started_at or execution.created_at timeout_seconds = EXECUTION_RUNNING_TIMEOUT_SECONDS - + if reference_time.tzinfo is None: reference_time = reference_time.replace(tzinfo=timezone.utc) time_elapsed = (now - reference_time).total_seconds() - + if time_elapsed > timeout_seconds: - # Determine the new status and error message if is_pending: new_status = ExecutionStatus.missed error_message = f"Execution acknowledgment timeout ({timeout_seconds} seconds)" @@ -119,34 +104,26 @@ def check_execution_timeouts() -> None: new_status = ExecutionStatus.failed error_message = f"Execution running timeout ({timeout_seconds} seconds) - no completion received" timed_out_running_count += 1 - - # Use TriggerService.update_execution_status for proper failure tracking + trigger_service.update_execution_status( execution=execution, status=new_status, error_message=error_message ) - - # Remove from Redis pending list (best effort, may not exist) + try: - # Get all sessions for this execution's user trigger = session.get(Trigger, execution.trigger_id) if trigger and trigger.user_id: user_session_ids = redis_manager.get_user_sessions(trigger.user_id) for session_id in user_session_ids: redis_manager.remove_pending_execution(session_id, execution.execution_id) - elif not trigger: - logger.warning("Trigger not found for execution", extra={ - "execution_id": execution.execution_id, - "trigger_id": execution.trigger_id - }) except Exception as e: logger.warning("Failed to remove execution from Redis", extra={ "execution_id": execution.execution_id, "trigger_id": execution.trigger_id, "error": str(e) }) - + logger.info("Execution timed out", extra={ "execution_id": execution.execution_id, "trigger_id": execution.trigger_id, @@ -154,7 +131,7 @@ def check_execution_timeouts() -> None: "new_status": new_status.value, "time_elapsed": time_elapsed }) - + total_timed_out = timed_out_pending_count + timed_out_running_count if total_timed_out > 0: logger.info("Marked executions as timed out", extra={ @@ -162,15 +139,13 @@ def check_execution_timeouts() -> None: "timed_out_running_count": timed_out_running_count, "total_timed_out": total_timed_out }) - else: - logger.debug("No timed-out executions found") - + except Exception as e: logger.error("Error checking execution timeouts", extra={ "error": str(e), "error_type": type(e).__name__ }, exc_info=True) session.rollback() - + finally: - session.close() \ No newline at end of file + session.close() diff --git a/server/app/service/trigger/trigger_service.py b/server/app/domains/trigger/service/trigger_service.py similarity index 97% rename from server/app/service/trigger/trigger_service.py rename to server/app/domains/trigger/service/trigger_service.py index 0f49de79..e32380e3 100644 --- a/server/app/service/trigger/trigger_service.py +++ b/server/app/domains/trigger/service/trigger_service.py @@ -16,18 +16,17 @@ from datetime import datetime, timedelta, timezone from typing import Optional, List, Dict, Any from sqlmodel import select, and_, or_ from uuid import uuid4 -import logging +from loguru import logger from app.model.trigger.trigger import Trigger from app.model.trigger.trigger_execution import TriggerExecution -from app.type.trigger_types import TriggerType, TriggerStatus, ExecutionType, ExecutionStatus -from app.component.database import session_make -from app.service.trigger.trigger_schedule_service import TriggerScheduleService -from app.component.trigger_utils import SCHEDULED_FETCH_BATCH_SIZE, check_rate_limits +from app.shared.types.trigger_types import TriggerType, TriggerStatus, ExecutionType, ExecutionStatus +from app.core.database import session_make +from app.domains.trigger.service.trigger_schedule_service import TriggerScheduleService +from app.core.trigger_utils import SCHEDULED_FETCH_BATCH_SIZE, check_rate_limits from app.model.trigger.app_configs import ScheduleTriggerConfig, WebhookTriggerConfig from app.model.trigger.app_configs.base_config import BaseTriggerConfig -logger = logging.getLogger("server_trigger_service") class TriggerService: diff --git a/server/app/domains/user/__init__.py b/server/app/domains/user/__init__.py new file mode 100644 index 00000000..7e380f4d --- /dev/null +++ b/server/app/domains/user/__init__.py @@ -0,0 +1,15 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""User domain: users, admins, auth, invites, keys.""" diff --git a/server/app/domains/user/api/__init__.py b/server/app/domains/user/api/__init__.py new file mode 100644 index 00000000..1942bf3e --- /dev/null +++ b/server/app/domains/user/api/__init__.py @@ -0,0 +1,15 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""User domain API.""" diff --git a/server/app/domains/user/api/login_controller.py b/server/app/domains/user/api/login_controller.py new file mode 100644 index 00000000..a733a39e --- /dev/null +++ b/server/app/domains/user/api/login_controller.py @@ -0,0 +1,129 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""v1 Login - 1h access token, refresh token, rate limit.""" + +import time + +from fastapi import APIRouter, Depends, Form, HTTPException +from fastapi_babel import _ +from loguru import logger +from pydantic import BaseModel +from sqlmodel import Session + +from app.core import code +from app.core.database import session +from app.core.encrypt import password_verify +from app.core.environment import env +from app.model.user.user import LoginByPasswordIn, LoginResponse, Status, User +from app.shared.auth import create_access_token, create_refresh_token +from app.shared.auth.token_blacklist import blacklist_token +from app.shared.auth.user_auth import decode_refresh_token +from app.shared.exception import TokenException, UserException +from app.shared.middleware.rate_limit import login_rate_limiter + +router = APIRouter(prefix="/user", tags=["V1 Login"]) + + +@router.post("/dev_login", name="dev login (Swagger only)", include_in_schema=True) +async def dev_login(username: str | None = Form(default=None), password: str | None = Form(default=None)): + """Debug-only login for Swagger Authorize. Accepts OAuth2 password form.""" + if env("debug", "") != "on": + raise HTTPException(status_code=404) + return {"access_token": create_access_token(1), "token_type": "bearer"} + + +@router.post("/auto-login", name="auto login for local mode") +async def auto_login(db_session: Session = Depends(session)) -> LoginResponse: + """Auto login for fully local mode. Returns most recently active user or creates default.""" + user = User.by( + User.status == Status.Normal, + order_by=User.updated_at.desc(), + limit=1, + s=db_session, + ).one_or_none() + + if not user: + with db_session as s: + try: + user = User( + email="admin@local.eigent.ai", + username="admin", + nickname="Admin", + avatar="", + fullname="", + work_desc="", + ) + s.add(user) + s.commit() + s.refresh(user) + logger.info("Default admin user created", extra={"user_id": user.id}) + except Exception as e: + s.rollback() + logger.error("Failed to create default admin user", extra={"error": str(e)}, exc_info=True) + raise UserException(code.error, _("Failed to create default user")) + + logger.info("Auto login successful", extra={"user_id": user.id, "email": user.email}) + return LoginResponse(token=create_access_token(user.id), email=user.email) + + +class RefreshTokenIn(BaseModel): + refresh_token: str + + +@router.post("/login", name="login by email or password", dependencies=[login_rate_limiter]) +async def by_password(data: LoginByPasswordIn, db_session: Session = Depends(session)) -> dict: + """User login with email and password. Returns access_token (1h) and refresh_token (30d).""" + user = User.by(User.email == data.email, s=db_session).one_or_none() + if not user or not password_verify(data.password, user.password): + raise UserException(code.password, _("Account or password error")) + if user.status == Status.Block: + raise UserException(code.error, _("Your account has been blocked.")) + if not user.is_active: + raise UserException(code.error, _("Please activate your account via the email link.")) + + access_token = create_access_token(user.id) + refresh_token = create_refresh_token(user.id) + return { + "access_token": access_token, + "refresh_token": refresh_token, + "token_type": "bearer", + "email": user.email, + } + + +@router.post("/refresh", name="refresh tokens", dependencies=[login_rate_limiter]) +async def refresh(data: RefreshTokenIn, db_session: Session = Depends(session)) -> dict: + """Exchange valid refresh_token for new access_token and refresh_token.""" + if not data.refresh_token: + raise TokenException(code.token_need, _("Refresh token required")) + user_id, jti, exp_ts = await decode_refresh_token(data.refresh_token) + user = db_session.get(User, user_id) + if not user: + raise TokenException(code.token_invalid, _("User not found")) + if user.status == Status.Block: + raise UserException(code.error, _("Your account has been blocked.")) + if not user.is_active: + raise UserException(code.error, _("Please activate your account via the email link.")) + if jti: + ttl = max(0, exp_ts - int(time.time())) + await blacklist_token(jti, ttl) + access_token = create_access_token(user.id) + refresh_token = create_refresh_token(user.id) + return { + "access_token": access_token, + "refresh_token": refresh_token, + "token_type": "bearer", + "email": user.email, + } diff --git a/server/app/domains/user/api/logout_controller.py b/server/app/domains/user/api/logout_controller.py new file mode 100644 index 00000000..b602dde3 --- /dev/null +++ b/server/app/domains/user/api/logout_controller.py @@ -0,0 +1,51 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""v1 Logout - token blacklist, audit log.""" + +from datetime import datetime + +import jwt +from fastapi import APIRouter, Depends +from loguru import logger + +from app.core.environment import env_not_empty +from app.shared.auth.token_blacklist import blacklist_token +from app.shared.auth.user_auth import _get_jti, oauth2_scheme + +router = APIRouter(prefix="/user", tags=["V1 Auth"]) + + +@router.post("/logout", name="logout") +async def logout(token: str = Depends(oauth2_scheme)): + """Revoke current token. Requires Bearer token.""" + if not token: + logger.info("logout: no token provided") + return {"success": True, "message": "No token to revoke"} + jti = _get_jti(token) + user_id = None + if jti: + try: + payload = jwt.decode( + token, env_not_empty("secret_key"), algorithms=["HS256"], options={"verify_exp": False} + ) + user_id = payload.get("id") + exp = payload.get("exp") + ttl = max(0, int(exp) - int(datetime.utcnow().timestamp())) if exp else 3600 + await blacklist_token(jti, ttl) + except Exception as e: + logger.warning("logout: token decode/blacklist failed", extra={"error": str(e)}) + jti_safe = (jti[:8] + "...") if jti and len(jti) >= 8 else (jti or None) + logger.info("logout", extra={"user_id": user_id, "jti_preview": jti_safe}) + return {"success": True, "message": "Logged out"} diff --git a/server/app/controller/user/user_password_controller.py b/server/app/domains/user/api/password_controller.py similarity index 61% rename from server/app/controller/user/user_password_controller.py rename to server/app/domains/user/api/password_controller.py index 55e6b1d5..f71be47c 100644 --- a/server/app/controller/user/user_password_controller.py +++ b/server/app/domains/user/api/password_controller.py @@ -12,39 +12,30 @@ # limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -import logging - from fastapi import APIRouter, Depends from fastapi_babel import _ from sqlmodel import Session -from app.component import code -from app.component.auth import Auth, auth_must -from app.component.database import session -from app.component.encrypt import password_hash, password_verify -from app.exception.exception import UserException +from app.core import code +from app.core.database import session +from app.core.encrypt import password_hash, password_verify from app.model.user.user import UpdatePassword, UserOut - -logger = logging.getLogger("server_password_controller") +from app.shared.auth import auth_must +from app.shared.auth.user_auth import V1UserAuth +from app.shared.exception import UserException router = APIRouter(tags=["User"]) @router.put("/user/update-password", name="update password", response_model=UserOut) -def update_password(data: UpdatePassword, auth: Auth = Depends(auth_must), session: Session = Depends(session)): - """Update user password after verifying current password.""" - user_id = auth.user.id +def update_password( + data: UpdatePassword, auth: V1UserAuth = Depends(auth_must), db_session: Session = Depends(session) +): model = auth.user - if not password_verify(data.password, model.password): - logger.warning("Password update failed: incorrect current password", extra={"user_id": user_id}) raise UserException(code.error, _("Password is incorrect")) - if data.new_password != data.re_new_password: - logger.warning("Password update failed: new passwords do not match", extra={"user_id": user_id}) raise UserException(code.error, _("The two passwords do not match")) - model.password = password_hash(data.new_password) - model.save(session) - logger.info("Password updated successfully", extra={"user_id": user_id}) + model.save(db_session) return model diff --git a/server/app/domains/user/api/user_controller.py b/server/app/domains/user/api/user_controller.py new file mode 100644 index 00000000..c6f884ef --- /dev/null +++ b/server/app/domains/user/api/user_controller.py @@ -0,0 +1,116 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +from fastapi import APIRouter, Depends +from sqlalchemy import func +from sqlmodel import Session, select + +from app.core.database import session +from app.model.chat.chat_history import ChatHistory +from app.model.chat.chat_snpshot import ChatSnapshot +from app.model.config.config import Config +from app.model.mcp.mcp_user import McpUser +from app.model.user.privacy import UserPrivacy, UserPrivacySettings +from app.model.user.user import User, UserIn, UserOut, UserProfile +from app.model.user.user_stat import UserStat, UserStatActionIn, UserStatOut +from app.shared.auth import auth_must +from app.shared.auth.user_auth import V1UserAuth + +router = APIRouter(tags=["User"]) + + +@router.get("/user", name="user info", response_model=UserOut) +def get(db_session: Session = Depends(session), auth: V1UserAuth = Depends(auth_must)): + user: User = auth.user + db_session.refresh(user) + return user + + +@router.put("/user", name="update user info", response_model=UserOut) +def put(data: UserIn, db_session: Session = Depends(session), auth: V1UserAuth = Depends(auth_must)): + model = auth.user + model.username = data.username + model.save(db_session) + return model + + +@router.put("/user/profile", name="update user profile", response_model=UserProfile) +def put_profile(data: UserProfile, db_session: Session = Depends(session), auth: V1UserAuth = Depends(auth_must)): + model = auth.user + model.nickname = data.nickname + model.fullname = data.fullname + model.work_desc = data.work_desc + model.save(db_session) + return model + + +@router.get("/user/privacy", name="get user privacy") +def get_privacy(db_session: Session = Depends(session), auth: V1UserAuth = Depends(auth_must)): + user_id = auth.id + stmt = select(UserPrivacy).where(UserPrivacy.user_id == user_id) + model = db_session.exec(stmt).one_or_none() + if not model: + return UserPrivacySettings.default_settings() + return UserPrivacySettings(**model.pricacy_setting).to_response() + + +@router.put("/user/privacy", name="update user privacy") +def put_privacy( + data: UserPrivacySettings, db_session: Session = Depends(session), auth: V1UserAuth = Depends(auth_must) +): + user_id = auth.id + stmt = select(UserPrivacy).where(UserPrivacy.user_id == user_id) + model = db_session.exec(stmt).one_or_none() + default_settings = UserPrivacySettings.default_settings() + + if model: + model.pricacy_setting = {**model.pricacy_setting, **data.model_dump(exclude_unset=True)} + model.save(db_session) + else: + model = UserPrivacy( + user_id=user_id, pricacy_setting={**default_settings, **data.model_dump(exclude_unset=True)} + ) + model.save(db_session) + + return UserPrivacySettings(**model.pricacy_setting).to_response() + + +@router.get("/user/stat", name="get user stat", response_model=UserStatOut) +def get_user_stat(db_session: Session = Depends(session), auth: V1UserAuth = Depends(auth_must)): + stat = db_session.exec(select(UserStat).where(UserStat.user_id == auth.id)).first() + data = UserStatOut() + if stat: + data = UserStatOut(**stat.model_dump()) + else: + data = UserStatOut(user_id=auth.id) + data.task_queries = ChatHistory.count(ChatHistory.user_id == auth.id, s=db_session) + mcp = McpUser.count(McpUser.user_id == auth.id, s=db_session) + tool: list = db_session.exec( + select(func.count("*")).where(Config.user_id == auth.id).group_by(Config.config_group) + ).all() + tool = tool.__len__() + data.mcp_install_count = mcp + tool + data.storage_used = ChatSnapshot.caclDir(ChatSnapshot.get_user_dir(auth.id)) + return data + + +@router.post("/user/stat", name="record user stat") +def record_user_stat( + data: UserStatActionIn, + db_session: Session = Depends(session), + auth: V1UserAuth = Depends(auth_must), +): + data.user_id = auth.id + stat = UserStat.record_action(db_session, data) + return stat diff --git a/server/app/domains/user/schema/__init__.py b/server/app/domains/user/schema/__init__.py new file mode 100644 index 00000000..3ef1999d --- /dev/null +++ b/server/app/domains/user/schema/__init__.py @@ -0,0 +1,33 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""User domain schemas.""" + +from app.domains.user.schema.schemas import ( + LoginReq, + RefreshTokenReq, + LogoutReq, + AuthResult, + GetKeyReq, + KeyResult, +) + +__all__ = [ + "LoginReq", + "RefreshTokenReq", + "LogoutReq", + "AuthResult", + "GetKeyReq", + "KeyResult", +] diff --git a/server/app/domains/user/schema/schemas.py b/server/app/domains/user/schema/schemas.py new file mode 100644 index 00000000..3f3dd050 --- /dev/null +++ b/server/app/domains/user/schema/schemas.py @@ -0,0 +1,52 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""v1 UserAuthService and KeyService request/response schemas.""" + +from pydantic import BaseModel + + +# UserAuthService DTOs +class LoginReq(BaseModel): + email: str + password: str + + +class RefreshTokenReq(BaseModel): + refresh_token: str + + +class LogoutReq(BaseModel): + token: str + + +class AuthResult(BaseModel): + success: bool + access_token: str | None = None + refresh_token: str | None = None + email: str | None = None + error_code: str | None = None + + +# KeyService DTOs +class GetKeyReq(BaseModel): + user_id: int + + +class KeyResult(BaseModel): + success: bool + key_value: str | None = None + warning_code: str | None = None + warning_text: str | None = None + error_code: str | None = None diff --git a/server/app/domains/user/service/__init__.py b/server/app/domains/user/service/__init__.py new file mode 100644 index 00000000..14e117a0 --- /dev/null +++ b/server/app/domains/user/service/__init__.py @@ -0,0 +1,18 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +from .key_service import KeyService +from .user_auth_service import UserAuthService + +__all__ = ["KeyService", "UserAuthService"] diff --git a/server/app/domains/user/service/user_auth_service.py b/server/app/domains/user/service/user_auth_service.py new file mode 100644 index 00000000..f82be548 --- /dev/null +++ b/server/app/domains/user/service/user_auth_service.py @@ -0,0 +1,118 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""UserAuthService: login, refresh, logout. Follows CreditsService pattern.""" + +import time +from datetime import datetime + +import jwt +from loguru import logger + +from app.core.database import session_make +from app.core.encrypt import password_verify +from app.core.environment import env_not_empty +from app.model.user.user import Status, User + +from app.shared.auth import create_access_token, create_refresh_token +from app.shared.auth.user_auth import decode_refresh_token, _get_jti +from app.shared.auth.token_blacklist import blacklist_token, is_blacklisted +from app.domains.user.schema import AuthResult + + +class UserAuthService: + """User authentication operations - static methods.""" + + @staticmethod + def login(email: str, password: str) -> AuthResult: + """Password login + credits refresh. Returns tokens or error.""" + with session_make() as s: + user = s.exec( + User.by(User.email == email, s=s) + ).one_or_none() + + if not user or not password_verify(password, user.password): + return AuthResult(success=False, error_code="AUTH_INVALID_CREDENTIALS") + if user.status == Status.Block: + return AuthResult(success=False, error_code="AUTH_ACCOUNT_BLOCKED") + if not user.is_active: + return AuthResult(success=False, error_code="AUTH_ACCOUNT_INACTIVE") + + # No credits refresh in eigent (no billing domain) + + access_token = create_access_token(user.id) + refresh_token = create_refresh_token(user.id) + return AuthResult( + success=True, + access_token=access_token, + refresh_token=refresh_token, + email=user.email, + ) + + @staticmethod + async def refresh(refresh_token_str: str) -> AuthResult: + """Exchange valid refresh_token for new token pair.""" + if not refresh_token_str: + return AuthResult(success=False, error_code="AUTH_REFRESH_TOKEN_REQUIRED") + + user_id, jti, exp_ts = decode_refresh_token(refresh_token_str) + if jti and await is_blacklisted(jti): + return AuthResult(success=False, error_code="AUTH_TOKEN_REVOKED") + + with session_make() as s: + user = s.get(User, user_id) + + if not user: + return AuthResult(success=False, error_code="AUTH_USER_NOT_FOUND") + if user.status == Status.Block: + return AuthResult(success=False, error_code="AUTH_ACCOUNT_BLOCKED") + if not user.is_active: + return AuthResult(success=False, error_code="AUTH_ACCOUNT_INACTIVE") + + # Blacklist old refresh token + if jti: + ttl = max(0, exp_ts - int(time.time())) + await blacklist_token(jti, ttl) + + access_token = create_access_token(user.id) + new_refresh_token = create_refresh_token(user.id) + return AuthResult( + success=True, + access_token=access_token, + refresh_token=new_refresh_token, + email=user.email, + ) + + @staticmethod + async def logout(token: str) -> bool: + """Revoke token by adding to blacklist.""" + if not token: + return True + jti = _get_jti(token) + if jti: + try: + payload = jwt.decode( + token, + env_not_empty("secret_key"), + algorithms=["HS256"], + options={"verify_exp": False}, + ) + user_id = payload.get("id") + exp = payload.get("exp") + ttl = max(0, int(exp) - int(datetime.utcnow().timestamp())) if exp else 3600 + await blacklist_token(jti, ttl) + logger.info("logout", extra={"user_id": user_id}) + except Exception as e: + logger.warning("logout: token decode/blacklist failed", extra={"error": str(e)}) + return True diff --git a/server/app/exception/handler.py b/server/app/exception/handler.py deleted file mode 100644 index b9354724..00000000 --- a/server/app/exception/handler.py +++ /dev/null @@ -1,65 +0,0 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= - -from fastapi import Request -from fastapi.encoders import jsonable_encoder -from fastapi.exceptions import RequestValidationError -from fastapi.responses import JSONResponse -from sqlalchemy.exc import NoResultFound - -from app import api -from app.component import code -from app.component.pydantic.i18n import get_language, trans -from app.exception.exception import ( - NoPermissionException, - TokenException, - UserException, -) - - -@api.exception_handler(RequestValidationError) -async def request_exception(request: Request, e: RequestValidationError): - if (lang := get_language(request.headers.get("Accept-Language"))) is None: - lang = "en_US" - return JSONResponse( - content={ - "code": code.form_error, - "error": jsonable_encoder(trans.translate(list(e.errors()), locale=lang)), - } - ) - - -@api.exception_handler(TokenException) -async def token_exception(request: Request, e: TokenException): - return JSONResponse(content={"code": e.code, "text": e.text}) - - -@api.exception_handler(UserException) -async def user_exception(request: Request, e: UserException): - return JSONResponse(content={"code": e.code, "text": e.description}) - - -@api.exception_handler(NoPermissionException) -async def no_permission(request: Request, exception: NoPermissionException): - return JSONResponse( - status_code=200, - content={"code": code.no_permission_error, "text": exception.text}, - ) - - -async def no_results(request: Request, exception: NoResultFound): - return JSONResponse( - status_code=200, - content={"code": code.not_found, "text": exception._message()}, - ) diff --git a/server/app/model/abstract/model.py b/server/app/model/abstract/model.py index 0508441e..da80f341 100644 --- a/server/app/model/abstract/model.py +++ b/server/app/model/abstract/model.py @@ -1,15 +1,15 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= import logging @@ -33,9 +33,9 @@ from sqlmodel import ( text, ) -from app.component import code -from app.component.database import engine -from app.exception.exception import UserException +from app.core import code +from app.core.database import engine +from app.shared.exception import UserException logger = logging.getLogger("abstract_model") diff --git a/server/app/model/chat/chat_snpshot.py b/server/app/model/chat/chat_snpshot.py index b6588760..de8031b5 100644 --- a/server/app/model/chat/chat_snpshot.py +++ b/server/app/model/chat/chat_snpshot.py @@ -1,15 +1,15 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= import base64 @@ -20,7 +20,7 @@ from pydantic import BaseModel from sqlalchemy import Column, Integer, text from sqlmodel import Field -from app.component.sqids import encode_user_id +from app.core.sqids import encode_user_id from app.model.abstract.model import AbstractModel, DefaultTimes @@ -68,3 +68,11 @@ class ChatSnapshotIn(BaseModel): with open(file_path, "wb") as f: f.write(base64.b64decode(image_base64)) return f"/public/upload/{user_dir}/{api_task_id}/{filename}" + + +class ChatSnapshotUpdate(BaseModel): + """Update model - only updatable fields.""" + api_task_id: str | None = None + camel_task_id: str | None = None + browser_url: str | None = None + image_path: str | None = None diff --git a/server/app/model/chat/chat_step.py b/server/app/model/chat/chat_step.py index 244d61c1..5755b611 100644 --- a/server/app/model/chat/chat_step.py +++ b/server/app/model/chat/chat_step.py @@ -1,15 +1,15 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= import json @@ -59,3 +59,9 @@ class ChatStepOut(BaseModel): step: str data: Any timestamp: float | None = None + + +class ChatStepUpdate(BaseModel): + step: str | None = None + data: Any | None = None + timestamp: float | None = None diff --git a/server/app/model/config/config.py b/server/app/model/config/config.py index 1c9cefe9..18141b62 100644 --- a/server/app/model/config/config.py +++ b/server/app/model/config/config.py @@ -15,7 +15,7 @@ from sqlmodel import Field, SQLModel, UniqueConstraint from app.model.abstract.model import AbstractModel, DefaultTimes -from app.type.config_group import ConfigGroup +from app.shared.types.config_group import ConfigGroup class Config(AbstractModel, DefaultTimes, table=True): diff --git a/server/app/model/mcp/mcp.py b/server/app/model/mcp/mcp.py index 5389c200..31e3d4ba 100644 --- a/server/app/model/mcp/mcp.py +++ b/server/app/model/mcp/mcp.py @@ -1,15 +1,15 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= from enum import IntEnum @@ -24,7 +24,7 @@ from sqlmodel import JSON, Field, Relationship from app.model.abstract.model import AbstractModel, DefaultTimes from app.model.mcp.category import Category, CategoryOut from app.model.mcp.mcp_env import McpEnv -from app.type.pydantic import HttpUrlStr +from app.shared.types.pydantic import HttpUrlStr if TYPE_CHECKING: from app.model.mcp.mcp_user import McpUser diff --git a/server/app/model/trigger/app_configs/base_config.py b/server/app/model/trigger/app_configs/base_config.py index 9e10d306..b2e2d153 100644 --- a/server/app/model/trigger/app_configs/base_config.py +++ b/server/app/model/trigger/app_configs/base_config.py @@ -23,7 +23,7 @@ import re from typing import Optional, List, Dict, Any, TYPE_CHECKING from pydantic import BaseModel, Field, field_validator -from app.type.config_group import ConfigGroup +from app.shared.types.config_group import ConfigGroup if TYPE_CHECKING: from sqlmodel import Session diff --git a/server/app/model/trigger/app_configs/config_registry.py b/server/app/model/trigger/app_configs/config_registry.py index 9f3c21ba..ea88f311 100644 --- a/server/app/model/trigger/app_configs/config_registry.py +++ b/server/app/model/trigger/app_configs/config_registry.py @@ -21,7 +21,7 @@ Used for validation and JSON schema generation. from typing import Type, Optional, Dict, Any, TYPE_CHECKING -from app.type.trigger_types import TriggerType +from app.shared.types.trigger_types import TriggerType from app.model.trigger.app_configs.base_config import BaseTriggerConfig from app.model.trigger.app_configs.slack_config import SlackTriggerConfig from app.model.trigger.app_configs.webhook_config import WebhookTriggerConfig diff --git a/server/app/model/trigger/app_configs/slack_config.py b/server/app/model/trigger/app_configs/slack_config.py index 328c0ed9..2b555d38 100644 --- a/server/app/model/trigger/app_configs/slack_config.py +++ b/server/app/model/trigger/app_configs/slack_config.py @@ -25,7 +25,7 @@ from typing import Optional, List, TYPE_CHECKING from pydantic import Field from app.model.trigger.app_configs.base_config import BaseTriggerConfig -from app.type.config_group import ConfigGroup +from app.shared.types.config_group import ConfigGroup if TYPE_CHECKING: from sqlmodel import Session diff --git a/server/app/model/trigger/trigger.py b/server/app/model/trigger/trigger.py index 8a59f186..98790b32 100644 --- a/server/app/model/trigger/trigger.py +++ b/server/app/model/trigger/trigger.py @@ -18,7 +18,7 @@ from sqlmodel import Field, Column, SmallInteger, JSON, String from sqlalchemy_utils import ChoiceType from pydantic import BaseModel from app.model.abstract.model import AbstractModel, DefaultTimes -from app.type.trigger_types import TriggerType, TriggerStatus, ListenerType, RequestType +from app.shared.types.trigger_types import TriggerType, TriggerStatus, ListenerType, RequestType class Trigger(AbstractModel, DefaultTimes, table=True): diff --git a/server/app/model/trigger/trigger_execution.py b/server/app/model/trigger/trigger_execution.py index 75c48d89..c337a98f 100644 --- a/server/app/model/trigger/trigger_execution.py +++ b/server/app/model/trigger/trigger_execution.py @@ -18,7 +18,7 @@ from sqlmodel import Field, Column, SmallInteger, JSON, String, Float from sqlalchemy_utils import ChoiceType from pydantic import BaseModel from app.model.abstract.model import AbstractModel, DefaultTimes -from app.type.trigger_types import ExecutionType, ExecutionStatus +from app.shared.types.trigger_types import ExecutionType, ExecutionStatus class TriggerExecution(AbstractModel, DefaultTimes, table=True): diff --git a/server/app/model/user/key.py b/server/app/model/user/key.py index 6d5ffc55..1ee51c3a 100644 --- a/server/app/model/user/key.py +++ b/server/app/model/user/key.py @@ -18,7 +18,7 @@ from pydantic import BaseModel, computed_field from sqlalchemy_utils import ChoiceType from sqlmodel import Column, Field, SmallInteger -from app.component.environment import env_not_empty +from app.core.environment import env_not_empty from app.model.abstract.model import AbstractModel, DefaultTimes diff --git a/server/app/model/user/user.py b/server/app/model/user/user.py index 6e2d63c0..ed49fb48 100644 --- a/server/app/model/user/user.py +++ b/server/app/model/user/user.py @@ -20,7 +20,7 @@ from sqlalchemy import Integer, SmallInteger, text from sqlalchemy_utils import ChoiceType from sqlmodel import Column, Field -from app.component.encrypt import password_hash +from app.core.encrypt import password_hash from app.model.abstract.model import AbstractModel, DefaultTimes diff --git a/server/app/model/user/user_credits_record.py b/server/app/model/user/user_credits_record.py index 146bc85b..244702b0 100644 --- a/server/app/model/user/user_credits_record.py +++ b/server/app/model/user/user_credits_record.py @@ -21,7 +21,7 @@ from sqlalchemy import Boolean, SmallInteger, text from sqlalchemy_utils import ChoiceType from sqlmodel import Column, Field, Session, col, select -from app.component.database import session_make +from app.core.database import session_make from app.model.abstract.model import AbstractModel, DefaultTimes logger = logging.getLogger("user_credits_record") diff --git a/server/app/service/trigger/app_handler_service.py b/server/app/service/trigger/app_handler_service.py deleted file mode 100644 index be99502b..00000000 --- a/server/app/service/trigger/app_handler_service.py +++ /dev/null @@ -1,448 +0,0 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= - -""" -Trigger App Handler Service - -Modular service for handling app-specific webhook authentication, -filtering, and payload normalization based on trigger_type. -""" - -import re -from typing import Optional -from dataclasses import dataclass -from fastapi import Request -from sqlmodel import Session, select, and_ -import logging - -from app.model.trigger.trigger import Trigger -from app.model.config.config import Config -from app.model.trigger.app_configs import SlackTriggerConfig, WebhookTriggerConfig, ScheduleTriggerConfig -from app.type.trigger_types import TriggerType, ExecutionType, TriggerStatus -from app.type.config_group import ConfigGroup - -logger = logging.getLogger("server_app_handler_service") - - -@dataclass -class AppHandlerResult: - """Result from app handler operations.""" - success: bool - data: Optional[dict] = None - reason: Optional[str] = None - - -class BaseAppHandler: - """Base class for app-specific handlers.""" - - trigger_type: TriggerType - execution_type: ExecutionType = ExecutionType.webhook - config_group: Optional[str] = None - - async def get_credentials(self, session: Session, user_id: str) -> dict: - """Get user credentials from config table.""" - if not self.config_group: - return {} - - configs = session.exec( - select(Config).where( - and_( - Config.user_id == int(user_id), - Config.config_group == self.config_group - ) - ) - ).all() - return {config.config_name: config.config_value for config in configs} - - async def authenticate( - self, - request: Request, - body: bytes, - trigger: Trigger, - session: Session - ) -> AppHandlerResult: - """ - Authenticate the incoming webhook request. - Returns (success, challenge_response or None) - """ - return AppHandlerResult(success=True) - - async def filter_event( - self, - payload: dict, - trigger: Trigger - ) -> AppHandlerResult: - """ - Filter events based on trigger configuration. - Returns (should_process, reason) - """ - return AppHandlerResult(success=True, reason="ok") - - def normalize_payload( - self, - payload: dict, - trigger: Trigger, - request_meta: dict = None - ) -> dict: - """Normalize the payload for execution input.""" - return payload - - -class SlackAppHandler(BaseAppHandler): - """Handler for Slack triggers.""" - - trigger_type = TriggerType.slack_trigger - execution_type = ExecutionType.slack - config_group = ConfigGroup.SLACK.value - - async def authenticate( - self, - request: Request, - body: bytes, - trigger: Trigger, - session: Session - ) -> AppHandlerResult: - """Handle Slack authentication and URL verification.""" - from camel.auth.slack_auth import SlackAuth - - credentials = await self.get_credentials(session, trigger.user_id) - - slack_auth = SlackAuth( - signing_secret=credentials.get("SLACK_SIGNING_SECRET"), - bot_token=credentials.get("SLACK_BOT_TOKEN"), - api_token=credentials.get("SLACK_API_TOKEN"), - ) - - # Check for URL verification challenge - challenge_response = slack_auth.get_verification_response(request, body) - if challenge_response: - # Return the challenge response (already in correct format: {"challenge": "..."}) - logger.info(f"Slack URL verification - challenge_response: {challenge_response}") - return AppHandlerResult(success=True, data=challenge_response) - - # Verify webhook signature - if not slack_auth.verify_webhook_request(request, body): - logger.warning("Invalid Slack webhook signature", extra={ - "trigger_id": trigger.id - }) - return AppHandlerResult(success=False, reason="invalid_signature") - - return AppHandlerResult(success=True) - - async def filter_event( - self, - payload: dict, - trigger: Trigger - ) -> AppHandlerResult: - """Filter Slack events based on trigger config.""" - # Prefer 'config' field - config_data = trigger.config or {} - config = SlackTriggerConfig(**config_data) - event = payload.get("event", {}) - event_type = event.get("type", "") - - # Check event type - if not config.should_trigger(event_type): - return AppHandlerResult(success=False, reason="event_type_not_configured") - - # Check channel filter (if channel_id is set, only trigger for that channel) - if config.channel_id: - if event.get("channel") != config.channel_id: - return AppHandlerResult(success=False, reason="channel_not_matched") - - # Check bot message filter - if config.ignore_bot_messages: - if event.get("bot_id") or event.get("subtype") == "bot_message": - return AppHandlerResult(success=False, reason="bot_message_ignored") - - # Check user filter - if config.ignore_users and event.get("user") in config.ignore_users: - return AppHandlerResult(success=False, reason="user_filtered") - - # Check message filter regex - if config.message_filter and event.get("text"): - if not re.search(config.message_filter, event.get("text", ""), re.IGNORECASE): - return AppHandlerResult(success=False, reason="message_filter_not_matched") - - return AppHandlerResult(success=True, reason="ok") - - def normalize_payload( - self, - payload: dict, - trigger: Trigger, - request_meta: dict = None - ) -> dict: - """Normalize Slack event payload.""" - logger.info("Normalizing payload", extra={"payload": payload}) - # Prefer 'config' field - config_data = trigger.config or {} - config = SlackTriggerConfig(**config_data) - event = payload.get("event", {}) - - normalized = { - "event_type": event.get("type"), - "event_ts": event.get("event_ts"), - "team_id": payload.get("team_id"), - "user_id": event.get("user"), - "channel_id": event.get("channel"), - "text": event.get("text"), - "message_ts": event.get("ts"), - "thread_ts": event.get("thread_ts"), - "reaction": event.get("reaction"), - "files": event.get("files"), - "event_id": payload.get("event_id") or payload.get("id") - } - - # if config.include_raw_payload: - # normalized["raw_payload"] = payload - - return normalized - - -class DefaultWebhookHandler(BaseAppHandler): - """Default handler for generic webhooks with config-based filtering.""" - - trigger_type = TriggerType.webhook - execution_type = ExecutionType.webhook - - async def filter_event( - self, - payload: dict, - trigger: Trigger, - headers: dict = None, - body_raw: str = None - ) -> AppHandlerResult: - """Filter webhook events based on trigger config.""" - config_data = trigger.config or {} - config = WebhookTriggerConfig(**config_data) - - # Get text content for message_filter (check body for text field or stringify) - text = None - if isinstance(payload, dict): - text = payload.get("text") or payload.get("message") or payload.get("content") - if text is None and body_raw: - text = body_raw - - # Use the config's should_trigger method - should_trigger, reason = config.should_trigger( - body=body_raw or "", - headers=headers or {}, - text=text - ) - - if not should_trigger: - return AppHandlerResult(success=False, reason=reason) - - return AppHandlerResult(success=True, reason="ok") - - def normalize_payload( - self, - payload: dict, - trigger: Trigger, - request_meta: dict = None - ) -> dict: - """Normalize generic webhook payload with full request metadata.""" - config_data = trigger.config or {} - config = WebhookTriggerConfig(**config_data) - - result = {"body": payload} - - if request_meta: - # Include headers if configured - if config.include_headers and "headers" in request_meta: - result["headers"] = request_meta["headers"] - - # Include query params if configured - if config.include_query_params and "query_params" in request_meta: - result["query_params"] = request_meta["query_params"] - - # Include request metadata if configured - if config.include_request_metadata: - if "method" in request_meta: - result["method"] = request_meta["method"] - if "url" in request_meta: - result["url"] = request_meta["url"] - if "client_ip" in request_meta: - result["client_ip"] = request_meta["client_ip"] - - return result - - -class ScheduleAppHandler(BaseAppHandler): - """ - Handler for scheduled triggers. - - Manages schedule-specific logic including: - - Expiration checking (expirationDate for recurring schedules) - - Date validation for one-time executions (date field) - """ - - trigger_type = TriggerType.schedule - execution_type = ExecutionType.scheduled - - async def filter_event( - self, - payload: dict, - trigger: Trigger - ) -> AppHandlerResult: - """ - Filter scheduled events based on trigger config. - - Checks: - - If one-time (date set) and date has passed - - If recurring with expirationDate and it has passed - """ - config_data = trigger.config or {} - - try: - config = ScheduleTriggerConfig(**config_data) - except Exception as e: - logger.warning( - "Invalid schedule config", - extra={"trigger_id": trigger.id, "error": str(e)} - ) - # Allow execution if config is missing/invalid (backwards compatibility) - return AppHandlerResult(success=True, reason="ok") - - # Check if schedule should execute - should_execute, reason = config.should_execute() - - if not should_execute: - return AppHandlerResult(success=False, reason=reason) - - return AppHandlerResult(success=True, reason="ok") - - def normalize_payload( - self, - payload: dict, - trigger: Trigger, - request_meta: dict = None - ) -> dict: - """Normalize scheduled trigger payload.""" - config_data = trigger.config or {} - - normalized = { - "scheduled_at": payload.get("scheduled_at"), - "trigger_id": trigger.id, - "trigger_name": trigger.name, - "is_single_execution": trigger.is_single_execution, - } - - # Include config details if present - if config_data: - if config_data.get("date"): - normalized["date"] = config_data.get("date") - if config_data.get("expirationDate"): - normalized["expirationDate"] = config_data.get("expirationDate") - - return normalized - - def check_and_handle_expiration( - self, - trigger: Trigger, - session: Session - ) -> bool: - """ - Check if a schedule has expired and handle accordingly. - - Args: - trigger: The trigger to check - session: Database session for updates - - Returns: - True if trigger is expired and was deactivated, False otherwise - """ - config_data = trigger.config or {} - - try: - config = ScheduleTriggerConfig(**config_data) - except Exception as e: - logger.warning( - "Invalid schedule config during expiration check", - extra={"trigger_id": trigger.id, "error": str(e)} - ) - return False - - if config.is_expired(): - # Deactivate the trigger - trigger.status = TriggerStatus.completed - session.add(trigger) - session.commit() - - logger.info( - "Schedule trigger expired and deactivated", - extra={ - "trigger_id": trigger.id, - "trigger_name": trigger.name, - "expiration_date": config.expirationDate or config.date - } - ) - - return True - - return False - - def validate_schedule_for_execution( - self, - trigger: Trigger - ) -> tuple[bool, str]: - """ - Validate that a scheduled trigger is valid for execution. - - Args: - trigger: The trigger to validate - - Returns: - Tuple of (is_valid, reason) - """ - config_data = trigger.config or {} - - try: - config = ScheduleTriggerConfig(**config_data) - except Exception as e: - return False, f"invalid_config: {str(e)}" - - # Check expiration - if config.is_expired(): - return False, "schedule_expired" - - return True, "ok" - - -# Registry of handlers by trigger_type -_HANDLERS: dict[TriggerType, BaseAppHandler] = { - TriggerType.slack_trigger: SlackAppHandler(), - TriggerType.webhook: DefaultWebhookHandler(), - TriggerType.schedule: ScheduleAppHandler(), -} - - -def get_app_handler(trigger_type: TriggerType) -> Optional[BaseAppHandler]: - """Get the handler for a trigger type.""" - return _HANDLERS.get(trigger_type) - - -def register_app_handler(trigger_type: TriggerType, handler: BaseAppHandler): - """Register a new app handler.""" - _HANDLERS[trigger_type] = handler - - -def get_supported_trigger_types() -> list[TriggerType]: - """Get list of trigger types with webhook support.""" - return list(_HANDLERS.keys()) - - -def get_schedule_handler() -> ScheduleAppHandler: - """Get the schedule handler instance.""" - return _HANDLERS.get(TriggerType.schedule) diff --git a/server/app/service/trigger/trigger_schedule_service.py b/server/app/service/trigger/trigger_schedule_service.py deleted file mode 100644 index 4db83b46..00000000 --- a/server/app/service/trigger/trigger_schedule_service.py +++ /dev/null @@ -1,430 +0,0 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= - -from datetime import datetime, timedelta, timezone -from typing import List, Tuple, Optional -import logging -from croniter import croniter -from uuid import uuid4 -import asyncio -from sqlmodel import select - -from app.model.trigger.trigger import Trigger -from app.model.trigger.trigger_execution import TriggerExecution -from app.type.trigger_types import TriggerStatus, ExecutionType, ExecutionStatus, TriggerType -from app.component.trigger_utils import check_rate_limits, MAX_DISPATCH_PER_TICK -from app.model.trigger.app_configs import ScheduleTriggerConfig - -logger = logging.getLogger("server_trigger_schedule_service") - - -class TriggerScheduleService: - """Service for managing scheduled trigger operations. - This service mainly delegates schedule business logic - from the main trigger_service.py. - - Handles tasks from the Celery beat scheduler. - - Mainly handles: - - Polling for due schedules - - Dispatching scheduled triggers - - Calculating next run times based on cron expressions - """ - - def __init__(self, session): - """ - Initialize the schedule service with a database session. - - Args: - session: SQLModel session for database operations - """ - self.session = session - - def fetch_due_schedules(self, limit: Optional[int] = 100) -> List[Trigger]: - """ - Fetch triggers that are due for execution. - - Args: - limit: Maximum number of triggers to fetch - - Returns: - List of triggers that need to be executed - """ - now = datetime.now(timezone.utc) - - try: - statement = ( - select(Trigger) - .where(Trigger.trigger_type == TriggerType.schedule) - .where(Trigger.status == TriggerStatus.active) - .where(Trigger.next_run_at <= now) - .order_by(Trigger.next_run_at) - .limit(limit) - ) - - results = self.session.exec(statement).all() - - logger.debug( - "Fetched due schedules", - extra={ - "count": len(results), - "current_time": now.isoformat() - } - ) - - return list(results) - - except Exception as e: - logger.error( - "Failed to fetch due schedules", - extra={"error": str(e)}, - exc_info=True - ) - return [] - - def calculate_next_run_at( - self, - trigger: Trigger, - base_time: Optional[datetime] = None - ) -> datetime: - """ - Calculate the next run time for a trigger based on its cron expression. - - Args: - trigger: The trigger to calculate next run time for - base_time: Base time to calculate from (defaults to now) - - Returns: - The next scheduled run time - - Raises: - ValueError: If trigger has no cron expression or invalid expression - """ - if not trigger.custom_cron_expression: - raise ValueError(f"Trigger {trigger.id} has no cron expression") - - if base_time is None: - base_time = datetime.now(timezone.utc) - - try: - cron = croniter(trigger.custom_cron_expression, base_time) - next_run = cron.get_next(datetime) - return next_run - except Exception as e: - logger.error( - "Failed to calculate next run time", - extra={ - "trigger_id": trigger.id, - "cron_expression": trigger.custom_cron_expression, - "error": str(e) - } - ) - raise - - def dispatch_trigger(self, trigger: Trigger) -> bool: - """ - Dispatch a trigger for execution. - - Args: - trigger: The trigger to dispatch - - Returns: - True if dispatched successfully, False otherwise - """ - try: - # Check schedule expiration before dispatching - if not self._check_schedule_valid(trigger): - logger.info( - "Schedule trigger expired, skipping dispatch", - extra={"trigger_id": trigger.id, "trigger_name": trigger.name} - ) - return False - - # Create execution record - execution_id = str(uuid4()) - execution = TriggerExecution( - trigger_id=trigger.id, - execution_id=execution_id, - execution_type=ExecutionType.scheduled, - status=ExecutionStatus.pending, - input_data={"scheduled_at": datetime.now(timezone.utc).isoformat()}, - started_at=datetime.now(timezone.utc) - ) - - self.session.add(execution) - - # Update trigger statistics - trigger.last_executed_at = datetime.now(timezone.utc) - trigger.last_execution_status = "pending" - - # Calculate and set next run time - try: - trigger.next_run_at = self.calculate_next_run_at(trigger, datetime.now(timezone.utc)) - except Exception as e: - logger.error( - "Failed to calculate next run time, trigger will be skipped", - extra={"trigger_id": trigger.id, "error": str(e)} - ) - # Set next_run_at far in the future to prevent immediate re-execution - trigger.next_run_at = datetime.now(timezone.utc) + timedelta(days=365) - - # If single execution, deactivate the trigger - if trigger.is_single_execution: - trigger.status = TriggerStatus.inactive - logger.info( - "Trigger deactivated after single execution", - extra={"trigger_id": trigger.id} - ) - - self.session.add(trigger) - self.session.commit() - - # TODO: Queue the actual task execution - # This would integrate with a task queue (e.g., Celery) to execute the trigger's action - # For now event is sent to client for execution - - logger.info( - "Trigger dispatched successfully", - extra={ - "trigger_id": trigger.id, - "trigger_name": trigger.name, - "execution_id": execution_id, - "next_run_at": trigger.next_run_at.isoformat() if trigger.next_run_at else None - } - ) - - # Notify WebSocket subscribers - # Using asyncio.run() to run async code from sync Celery worker context - try: - # Notify WebSocket subscribers via Redis pub/sub - from app.component.redis_utils import get_redis_manager - redis_manager = get_redis_manager() - redis_manager.publish_execution_event({ - "type": "execution_created", - "execution_id": execution_id, - "trigger_id": trigger.id, - "trigger_type": "schedule", - "status": "pending", - "input_data": execution.input_data, - "task_prompt": trigger.task_prompt, - "execution_type": "schedule", - "user_id": str(trigger.user_id), - "project_id": str(trigger.project_id) - }) - - logger.debug("WebSocket notification sent", extra={ - "execution_id": execution_id, - "trigger_id": trigger.id - }) - except Exception as e: - # Don't fail the trigger dispatch if notification fails - logger.warning("Failed to send WebSocket notification", extra={ - "trigger_id": trigger.id, - "execution_id": execution_id, - "error": str(e) - }) - - return True - - except Exception as e: - logger.error( - "Failed to dispatch trigger", - extra={ - "trigger_id": trigger.id, - "error": str(e) - }, - exc_info=True - ) - self.session.rollback() - return False - - def process_schedules(self, due_schedules: List[Trigger]) -> Tuple[int, int]: - """ - Process due schedules, checking rate limits and dispatching. - - Args: - due_schedules: List of triggers that are due for execution - - Returns: - Tuple of (dispatched_count, rate_limited_count) - """ - dispatched_count = 0 - rate_limited_count = 0 - - for trigger in due_schedules: - # Check rate limits - if not check_rate_limits(self.session, trigger): - rate_limited_count += 1 - - # Still update next_run_at even if rate limited, so we don't keep checking - try: - trigger.next_run_at = self.calculate_next_run_at(trigger, datetime.now(timezone.utc)) - self.session.add(trigger) - self.session.commit() - except Exception as e: - logger.error( - "Failed to update next_run_at for rate limited trigger", - extra={"trigger_id": trigger.id, "error": str(e)} - ) - - continue - - # Dispatch the trigger - if self.dispatch_trigger(trigger): - dispatched_count += 1 - - return dispatched_count, rate_limited_count - - def poll_and_execute_due_triggers( - self, - max_dispatch_per_tick: Optional[int] = None - ) -> Tuple[int, int]: - """ - Poll for due triggers and execute them in batches. - - Args: - max_dispatch_per_tick: Maximum number of triggers to dispatch in this tick - (defaults to MAX_DISPATCH_PER_TICK) - - Returns: - Tuple of (total_dispatched, total_rate_limited) - """ - max_dispatch = max_dispatch_per_tick or MAX_DISPATCH_PER_TICK - total_dispatched = 0 - total_rate_limited = 0 - - # Process in batches until we've handled all due schedules or hit the limit - while True: - due_schedules = self.fetch_due_schedules() - - if not due_schedules: - break - - dispatched_count, rate_limited_count = self.process_schedules(due_schedules) - total_dispatched += dispatched_count - total_rate_limited += rate_limited_count - - logger.debug( - "Batch processed", - extra={ - "dispatched": dispatched_count, - "rate_limited": rate_limited_count - } - ) - - # Check if we've hit the per-tick limit (if enabled) - if max_dispatch > 0 and total_dispatched >= max_dispatch: - logger.warning( - "Circuit breaker activated: reached dispatch limit, will continue next tick", - extra={"limit": max_dispatch} - ) - break - - if total_dispatched > 0 or total_rate_limited > 0: - logger.info( - "Trigger schedule poll completed", - extra={ - "total_dispatched": total_dispatched, - "total_rate_limited": total_rate_limited - } - ) - - return total_dispatched, total_rate_limited - - def _check_schedule_valid(self, trigger: Trigger) -> bool: - """ - Check if a scheduled trigger is valid for execution. - - Validates: - - For one-time (date set): Checks if the scheduled date has passed - - For recurring (expirationDate set): Checks if expirationDate has passed - - If expired, the trigger will be marked as completed. - - Args: - trigger: The trigger to check - - Returns: - True if trigger is valid for execution, False if expired - """ - config_data = trigger.config or {} - - # If no config or empty config, allow execution (no expiration) - if not config_data: - return True - - try: - config = ScheduleTriggerConfig(**config_data) - except Exception as e: - logger.warning( - "Invalid schedule config", - extra={"trigger_id": trigger.id, "error": str(e)} - ) - return False - - # Check if schedule has expired - if config.is_expired(): - # Mark trigger as completed - trigger.status = TriggerStatus.completed - self.session.add(trigger) - self.session.commit() - - logger.info( - "Schedule trigger expired and marked as completed", - extra={ - "trigger_id": trigger.id, - "trigger_name": trigger.name, - "expiration_info": config.expirationDate or config.date - } - ) - return False - - return True - - def update_trigger_next_run(self, trigger: Trigger) -> None: - """ - Update a trigger's next_run_at based on its cron expression. - - Args: - trigger: The trigger to update - """ - try: - # Check if schedule is expired before updating next run - if not self._check_schedule_valid(trigger): - logger.info( - "Trigger expired, not updating next_run_at", - extra={"trigger_id": trigger.id} - ) - return - - trigger.next_run_at = self.calculate_next_run_at(trigger) - self.session.add(trigger) - self.session.commit() - - logger.info( - "Trigger next_run_at updated", - extra={ - "trigger_id": trigger.id, - "next_run_at": trigger.next_run_at.isoformat() - } - ) - except Exception as e: - logger.error( - "Failed to update trigger next_run_at", - extra={ - "trigger_id": trigger.id, - "error": str(e) - } - ) - self.session.rollback() diff --git a/server/app/shared/__init__.py b/server/app/shared/__init__.py new file mode 100644 index 00000000..0dddcc57 --- /dev/null +++ b/server/app/shared/__init__.py @@ -0,0 +1,20 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +""" +v1 shared infrastructure layer. + +Provides trace context, middleware, auth, http client, logging, and exception handling +for the v1 API. Does not depend on any domains. +""" diff --git a/server/app/shared/auth/__init__.py b/server/app/shared/auth/__init__.py new file mode 100644 index 00000000..65c61bbd --- /dev/null +++ b/server/app/shared/auth/__init__.py @@ -0,0 +1,34 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +""" +Unified auth layer for v1. + +- auth_must: user token, 1 week expiry, blacklist check, type: user claim +- admin_must: admin token, separate claim, blacklist check +- require_owner: IDOR protection utility +""" + +from app.shared.auth.user_auth import auth_must, auth_optional, create_access_token, create_refresh_token +from app.shared.auth.admin_auth import admin_must +from app.shared.auth.ownership import require_owner + +__all__ = [ + "auth_must", + "auth_optional", + "create_access_token", + "create_refresh_token", + "admin_must", + "require_owner", +] diff --git a/server/app/shared/auth/admin_auth.py b/server/app/shared/auth/admin_auth.py new file mode 100644 index 00000000..7ac3995f --- /dev/null +++ b/server/app/shared/auth/admin_auth.py @@ -0,0 +1,124 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""v1 admin auth with type: admin claim and blacklist check.""" + +import uuid +from datetime import datetime, timedelta + +import jwt +from fastapi import Depends, Request +from fastapi.security import OAuth2PasswordBearer, SecurityScopes +from fastapi_babel import _ +from jwt.exceptions import InvalidTokenError +from sqlmodel import Session + +from app.core import code +from app.core.database import session +from app.core.environment import env, env_not_empty +from app.shared.exception import NoPermissionException, TokenException +from app.model.user.admin import Admin + +from app.shared.auth.token_blacklist import is_blacklisted + +SECRET_KEY = env_not_empty("secret_key") +TOKEN_EXPIRY = timedelta(hours=1) +TOKEN_TYPE_ADMIN = "admin" + +oauth2_scheme = OAuth2PasswordBearer( + tokenUrl=f"{env('url_prefix', '')}/v1/user/login", + auto_error=False, +) + + +class V1AdminAuth: + """v1 admin auth context.""" + + def __init__(self, admin_id: int, expired_at: datetime): + self.admin_id = admin_id + self.expired_at = expired_at + self._admin: Admin | None = None + + @property + def user(self) -> Admin: + if self._admin is None: + raise NoPermissionException(_("Admin user not found")) + return self._admin + + def can(self, *args: str) -> bool: + if len(args) == 0: + return True + for item in self.user.roles: + if set(item.permissions) & set(args): + return True + return False + + def check_permission(self, security_scopes: SecurityScopes, request: Request, db_session: Session) -> None: + if not self.can(*security_scopes.scopes): + raise NoPermissionException(_("Your are not authorized to access this function")) + + @classmethod + def decode_token(cls, token: str) -> "V1AdminAuth": + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"]) + token_type = payload.get("type", "admin") + if token_type != TOKEN_TYPE_ADMIN: + raise TokenException(code.token_invalid, _("Invalid token type - admin required")) + admin_id = payload["admin_id"] + if payload["exp"] < int(datetime.utcnow().timestamp()): + raise TokenException(code.token_expired, _("Validate credentials expired")) + return V1AdminAuth(admin_id, datetime.fromtimestamp(payload["exp"])) + except InvalidTokenError: + raise TokenException(code.token_invalid, _("Could not validate credentials")) + + @classmethod + def create_access_token(cls, admin_id: int, expires_delta: timedelta | None = None) -> str: + """Create admin token with type: admin claim.""" + expire = datetime.utcnow() + (expires_delta or TOKEN_EXPIRY) + to_encode = { + "admin_id": admin_id, + "type": TOKEN_TYPE_ADMIN, + "jti": str(uuid.uuid4()), + "exp": expire, + } + return jwt.encode(to_encode, SECRET_KEY, algorithm="HS256") + + +def _get_jti(token: str) -> str | None: + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"], options={"verify_exp": False}) + return payload.get("jti") + except Exception: + return None + + +async def admin_must( + security_scopes: SecurityScopes, + request: Request, + token: str = Depends(oauth2_scheme), + db_session: Session = Depends(session), +) -> V1AdminAuth: + """Require valid admin token. Raises TokenException if invalid or blacklisted.""" + if not token: + raise TokenException(code.token_need, _("Token required")) + model = V1AdminAuth.decode_token(token) + jti = _get_jti(token) + if jti and await is_blacklisted(jti): + raise TokenException(code.token_blocked, _("Token has been revoked")) + admin = db_session.get(Admin, model.admin_id) + if not admin: + raise TokenException(code.token_invalid, _("Admin not found")) + model._admin = admin + model.check_permission(security_scopes, request, db_session) + return model diff --git a/server/app/shared/auth/ownership.py b/server/app/shared/auth/ownership.py new file mode 100644 index 00000000..d0bb16dd --- /dev/null +++ b/server/app/shared/auth/ownership.py @@ -0,0 +1,37 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +""" +Ownership check utility for IDOR protection. + +All resource-access endpoints should verify the authenticated user owns the resource. +""" + +from app.shared.exception import NoPermissionException + + +def require_owner(resource, user_id: int, field: str = "user_id") -> None: + """ + Raise NoPermissionException if resource does not belong to user. + + :param resource: Model instance with ownership field (e.g. user_id) + :param user_id: Authenticated user's ID + :param field: Name of the ownership field (default: user_id) + :raises NoPermissionException: If resource is None or ownership mismatch + """ + if resource is None: + raise NoPermissionException("Resource not found") + owner_id = getattr(resource, field, None) + if owner_id is None or owner_id != user_id: + raise NoPermissionException("Resource not found") diff --git a/server/app/shared/auth/token_blacklist.py b/server/app/shared/auth/token_blacklist.py new file mode 100644 index 00000000..b1084ce7 --- /dev/null +++ b/server/app/shared/auth/token_blacklist.py @@ -0,0 +1,63 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +""" +Token blacklist for logout support (H16). + +Uses Redis to store revoked token JTIs. Keys: token:blacklist:{jti} +TTL matches remaining token lifetime. +""" + +from app.core.environment import env_or_fail +from redis import asyncio as aioredis + +_redis: aioredis.Redis | None = None +BLACKLIST_PREFIX = "token:blacklist:" + + +def _get_redis() -> aioredis.Redis: + global _redis + if _redis is None: + _redis = aioredis.from_url(env_or_fail("redis_url"), encoding="utf-8", decode_responses=True) + return _redis + + +async def is_blacklisted(jti: str) -> bool: + """Check if token JTI is in blacklist. Fail-closed: reject token if Redis is unavailable.""" + try: + r = _get_redis() + key = f"{BLACKLIST_PREFIX}{jti}" + return await r.exists(key) > 0 + except Exception as e: + from loguru import logger + logger.warning(f"Redis blacklist check failed (fail-closed): {e}") + return True + + +async def blacklist_token(jti: str, ttl_seconds: int) -> None: + """ + Add token JTI to blacklist. + + :param jti: JWT ID claim + :param ttl_seconds: Seconds until token would have expired (blacklist entry TTL) + """ + if ttl_seconds <= 0: + return + try: + r = _get_redis() + key = f"{BLACKLIST_PREFIX}{jti}" + await r.set(key, "1", ex=ttl_seconds) + except Exception as e: + from loguru import logger + logger.error(f"Redis blacklist_token failed: {e}") diff --git a/server/app/shared/auth/user_auth.py b/server/app/shared/auth/user_auth.py new file mode 100644 index 00000000..20315867 --- /dev/null +++ b/server/app/shared/auth/user_auth.py @@ -0,0 +1,177 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +""" +User auth with 1 week access token, refresh token, blacklist check, type claim. +""" + +import uuid +from datetime import datetime, timedelta + +import jwt +from fastapi import Depends, Header +from fastapi.security import OAuth2PasswordBearer +from fastapi_babel import _ +from jwt.exceptions import InvalidTokenError +from sqlmodel import Session, select + +from app.core import code +from app.core.database import session +from app.core.environment import env, env_not_empty +from app.model.mcp.proxy import ApiKey +from app.model.user.key import Key +from app.model.user.user import User +from app.shared.auth.token_blacklist import is_blacklisted +from app.shared.exception import NoPermissionException, TokenException + +SECRET_KEY = env_not_empty("secret_key") +TOKEN_EXPIRY = timedelta(weeks=1) # 1 week +REFRESH_EXPIRY = timedelta(days=30) +TOKEN_TYPE_USER = "user" +TOKEN_TYPE_REFRESH = "refresh" + +oauth2_scheme = OAuth2PasswordBearer( + tokenUrl=f"{env('url_prefix', '')}/v1/user/dev_login", + auto_error=False, +) + + +class V1UserAuth: + """v1 user auth context.""" + + def __init__(self, id: int, expired_at: datetime): + self.id = id + self.expired_at = expired_at + self._user: User | None = None + + @property + def user(self) -> User: + if self._user is None: + raise NoPermissionException("未查询到登录用户") + return self._user + + @classmethod + def decode_token(cls, token: str) -> "V1UserAuth": + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"]) + token_type = payload.get("type", "user") + if token_type != TOKEN_TYPE_USER: + raise TokenException(code.token_invalid, _("Invalid token type")) + user_id = payload["id"] + if payload["exp"] < int(datetime.utcnow().timestamp()): + raise TokenException(code.token_expired, _("Validate credentials expired")) + return V1UserAuth(user_id, datetime.fromtimestamp(payload["exp"])) + except InvalidTokenError: + raise TokenException(code.token_invalid, _("Could not validate credentials")) + + @classmethod + def create_access_token(cls, user_id: int, expires_delta: timedelta | None = None) -> str: + """Create access token with 1 week expiry and type: user claim (M3, M4).""" + expire = datetime.utcnow() + (expires_delta or TOKEN_EXPIRY) + to_encode = { + "id": user_id, + "type": TOKEN_TYPE_USER, + "jti": str(uuid.uuid4()), + "exp": expire, + } + return jwt.encode(to_encode, SECRET_KEY, algorithm="HS256") + + @classmethod + def create_refresh_token(cls, user_id: int) -> str: + """Create refresh token with 30d expiry (M3).""" + expire = datetime.utcnow() + REFRESH_EXPIRY + to_encode = { + "id": user_id, + "type": TOKEN_TYPE_REFRESH, + "jti": str(uuid.uuid4()), + "exp": expire, + } + return jwt.encode(to_encode, SECRET_KEY, algorithm="HS256") + + +def _get_jti(token: str) -> str | None: + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"], options={"verify_exp": False}) + return payload.get("jti") + except Exception: + return None + + +async def decode_refresh_token(token: str) -> tuple[int, str | None, int]: + """ + Validate refresh token, check blacklist, and return (user_id, jti, exp_timestamp). + + :raises TokenException: if invalid, wrong type, or blacklisted. + """ + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"]) + if payload.get("type") != TOKEN_TYPE_REFRESH: + raise TokenException(code.token_invalid, _("Invalid token type - refresh required")) + user_id = payload["id"] + jti = payload.get("jti") + exp = payload["exp"] + if jti and await is_blacklisted(jti): + raise TokenException(code.token_blocked, _("Token has been revoked")) + return user_id, jti, exp + except InvalidTokenError: + raise TokenException(code.token_invalid, _("Could not validate credentials")) + + +async def auth_must( + token: str = Depends(oauth2_scheme), + db_session: Session = Depends(session), +) -> V1UserAuth: + """Require valid user token. Raises TokenException if invalid or blacklisted.""" + if not token: + raise TokenException(code.token_need, _("Token required")) + model = V1UserAuth.decode_token(token) + jti = _get_jti(token) + if jti and await is_blacklisted(jti): + raise TokenException(code.token_blocked, _("Token has been revoked")) + user = db_session.get(User, model.id) + if not user: + raise TokenException(code.token_invalid, _("User not found")) + model._user = user + return model + + +def create_access_token(user_id: int) -> str: + """Convenience: create access token with default 1 week expiry.""" + return V1UserAuth.create_access_token(user_id) + + +def create_refresh_token(user_id: int) -> str: + """Create refresh token for token renewal.""" + return V1UserAuth.create_refresh_token(user_id) + + +async def auth_optional( + token: str | None = Depends(oauth2_scheme), + db_session: Session = Depends(session), +) -> V1UserAuth | None: + """Optional auth. Returns None if no token or invalid. Catches TokenException only (L5).""" + if token is None: + return None + try: + return await auth_must(token, db_session) + except TokenException: + return None + + +async def key_must(headers: ApiKey = Header(), db_session: Session = Depends(session)) -> Key: + """Validate API key from request headers.""" + model = db_session.exec(select(Key).where(Key.value == headers.api_key)).one_or_none() + if model is None: + raise TokenException(code.token_invalid, _(f"Could not validate key credentials: {headers.api_key}")) + return model diff --git a/server/app/shared/context.py b/server/app/shared/context.py new file mode 100644 index 00000000..05a4a927 --- /dev/null +++ b/server/app/shared/context.py @@ -0,0 +1,59 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +""" +Trace ID context for v1 request lifecycle. + +Uses contextvars to propagate trace_id through the request and into downstream +calls (httpx, Celery, Redis pub/sub). +""" + +import uuid +from contextvars import ContextVar + +_trace_id_ctx: ContextVar[str | None] = ContextVar("trace_id", default=None) + + +def get_trace_id() -> str | None: + """Get the current trace ID from context, or None if not set.""" + return _trace_id_ctx.get() + + +def set_trace_id(trace_id: str) -> None: + """Set the trace ID in the current context.""" + _trace_id_ctx.set(trace_id) + + +def generate_trace_id() -> str: + """Generate a new UUID v4 trace ID.""" + return str(uuid.uuid4()) + + +def ensure_trace_id(trace_id: str | None) -> str: + """ + Use provided trace_id if valid (UUID format), otherwise generate new one. + Returns the trace_id to use. Does NOT set context - caller must call set_trace_id. + """ + if trace_id and _is_valid_uuid(trace_id): + return trace_id + return generate_trace_id() + + +def _is_valid_uuid(value: str) -> bool: + """Check if string is a valid UUID v4.""" + try: + uuid.UUID(value, version=4) + return True + except (ValueError, TypeError): + return False diff --git a/server/app/exception/exception.py b/server/app/shared/exception/__init__.py similarity index 81% rename from server/app/exception/exception.py rename to server/app/shared/exception/__init__.py index 706fd617..80d7a72d 100644 --- a/server/app/exception/exception.py +++ b/server/app/shared/exception/__init__.py @@ -1,16 +1,18 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""v1 exception classes and handler registration.""" class UserException(Exception): @@ -33,3 +35,14 @@ class NoPermissionException(Exception): class ProgramException(Exception): def __init__(self, text: str): self.text = text + + +from app.shared.exception.handlers import register_exception_handlers + +__all__ = [ + "UserException", + "TokenException", + "NoPermissionException", + "ProgramException", + "register_exception_handlers", +] diff --git a/server/app/shared/exception/handlers.py b/server/app/shared/exception/handlers.py new file mode 100644 index 00000000..fd11f2df --- /dev/null +++ b/server/app/shared/exception/handlers.py @@ -0,0 +1,65 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""v1 exception handler registration.""" + +from fastapi import Request, FastAPI +from fastapi.encoders import jsonable_encoder +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse +from sqlalchemy.exc import NoResultFound + +from app.core import code +from app.core.pydantic.i18n import trans, get_language +from app.shared.exception import ( + NoPermissionException, + TokenException, + UserException, +) + + +def register_exception_handlers(app: FastAPI) -> None: + """Register all v1 exception handlers on the given FastAPI app.""" + + @app.exception_handler(RequestValidationError) + async def request_exception(request: Request, e: RequestValidationError): + lang = get_language(request.headers.get("Accept-Language")) or "en_US" + return JSONResponse( + content={ + "code": code.form_error, + "error": jsonable_encoder(trans.translate(list(e.errors()), locale=lang)), + } + ) + + @app.exception_handler(TokenException) + async def token_exception(request: Request, e: TokenException): + return JSONResponse(content={"code": e.code, "text": e.text}) + + @app.exception_handler(UserException) + async def user_exception(request: Request, e: UserException): + return JSONResponse(content={"code": e.code, "text": e.description}) + + @app.exception_handler(NoPermissionException) + async def no_permission(request: Request, exception: NoPermissionException): + return JSONResponse( + status_code=200, + content={"code": code.no_permission_error, "text": exception.text}, + ) + + @app.exception_handler(NoResultFound) + async def no_results(request: Request, exception: NoResultFound): + return JSONResponse( + status_code=200, + content={"code": code.not_found, "text": exception._message()}, + ) diff --git a/server/app/shared/http/__init__.py b/server/app/shared/http/__init__.py new file mode 100644 index 00000000..4e4459d7 --- /dev/null +++ b/server/app/shared/http/__init__.py @@ -0,0 +1,25 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +""" +v1 HTTP client with automatic X-Trace-ID injection. +""" + +from app.shared.http.client import trace_httpx_client, trace_httpx_post, trace_httpx_get + +__all__ = [ + "trace_httpx_client", + "trace_httpx_post", + "trace_httpx_get", +] diff --git a/server/app/shared/http/client.py b/server/app/shared/http/client.py new file mode 100644 index 00000000..9957d9d9 --- /dev/null +++ b/server/app/shared/http/client.py @@ -0,0 +1,92 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +""" +httpx wrapper that auto-injects X-Trace-ID header from context. + +All external calls (litellm, OAuth adapters, Stack Auth, Google Search) +should route through this client for trace propagation. +""" + +from typing import Any + +import httpx + +from app.shared.context import get_trace_id + +TRACE_HEADER = "X-Trace-ID" + + +def _trace_headers(extra: dict | None = None) -> dict: + """Build headers dict with X-Trace-ID from context.""" + headers = dict(extra or {}) + trace_id = get_trace_id() + if trace_id: + headers[TRACE_HEADER] = trace_id + return headers + + +async def trace_httpx_post( + url: str, + *, + json: dict | None = None, + data: Any = None, + headers: dict | None = None, + timeout: float = 30.0, + **kwargs, +) -> httpx.Response: + """ + POST request with X-Trace-ID injected. + """ + merged_headers = _trace_headers(headers) + async with httpx.AsyncClient(timeout=timeout) as client: + return await client.post(url, json=json, data=data, headers=merged_headers, **kwargs) + + +async def trace_httpx_get( + url: str, + *, + params: dict | None = None, + headers: dict | None = None, + timeout: float = 30.0, + **kwargs, +) -> httpx.Response: + """ + GET request with X-Trace-ID injected. + """ + merged_headers = _trace_headers(headers) + async with httpx.AsyncClient(timeout=timeout) as client: + return await client.get(url, params=params, headers=merged_headers, **kwargs) + + +def trace_httpx_client( + base_url: str | None = None, + headers: dict | None = None, + timeout: float = 30.0, + **kwargs, +) -> httpx.AsyncClient: + """ + Create httpx.AsyncClient with default headers including X-Trace-ID. + + Use as context manager: + async with trace_httpx_client(base_url="https://...") as client: + r = await client.post("/path", json={...}) + """ + default_headers = _trace_headers(headers) + return httpx.AsyncClient( + base_url=base_url, + headers=default_headers, + timeout=timeout, + **kwargs, + ) diff --git a/server/app/shared/logging/__init__.py b/server/app/shared/logging/__init__.py new file mode 100644 index 00000000..827f83c8 --- /dev/null +++ b/server/app/shared/logging/__init__.py @@ -0,0 +1,29 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +"""v1 trace-aware logging with sensitive data masking.""" + +from app.shared.logging.logging_utils import ( + MASK_PATTERNS, + mask_sensitive, + trace_filter, + configure_v1_logging, +) + +__all__ = [ + "MASK_PATTERNS", + "mask_sensitive", + "trace_filter", + "configure_v1_logging", +] diff --git a/server/app/shared/logging/logging_utils.py b/server/app/shared/logging/logging_utils.py new file mode 100644 index 00000000..dba6e4a8 --- /dev/null +++ b/server/app/shared/logging/logging_utils.py @@ -0,0 +1,56 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +""" +v1 trace-aware logging with sensitive data masking. + +- Binds trace_id from contextvars to all log records +- Masks API keys, emails, Redis URLs, payment info in log output +""" + +import re +from loguru import logger + +from app.shared.context import get_trace_id + + +# Patterns for sensitive data - replace with masked placeholder +MASK_PATTERNS = [ + (re.compile(r'(api[_-]?key|apikey|secret[_-]?key|authorization|bearer)\s*[:=]\s*["\']?([a-zA-Z0-9_\-\.]{8,})["\']?', re.I), r'\1=***MASKED***'), + (re.compile(r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'), r'***@***'), + (re.compile(r'redis://[^\s\'"&\]]+'), 'redis://***MASKED***'), + (re.compile(r'(sk_|pk_)[a-zA-Z0-9]+'), r'\1***MASKED***'), # Stripe keys & primary keys +] + + +def mask_sensitive(text: str) -> str: + """Mask sensitive data in log message.""" + if not isinstance(text, str): + return str(text) + result = text + for pattern, repl in MASK_PATTERNS: + result = pattern.sub(repl, result) + return result + + +def trace_filter(record: dict) -> bool: + """Loguru filter: bind trace_id and mask sensitive data.""" + record.setdefault("extra", {})["trace_id"] = get_trace_id() or "-" + record["message"] = mask_sensitive(record["message"]) + return True + + +def configure_v1_logging(): + """Add trace_id to default log format. Call during v1 app setup.""" + logger.configure(extra={"trace_id": "-"}) diff --git a/server/app/shared/middleware/__init__.py b/server/app/shared/middleware/__init__.py new file mode 100644 index 00000000..16616ebb --- /dev/null +++ b/server/app/shared/middleware/__init__.py @@ -0,0 +1,31 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +""" +v1 middleware layer. + +- TraceID: propagates X-Trace-ID through request lifecycle +- CORS: configurable allow_origins from env +- Rate limit: factory for per-route limiters +""" + +from app.shared.middleware.cors import get_cors_middleware +from app.shared.middleware.rate_limit import rate_limiter_factory +from app.shared.middleware.trace import TraceIDMiddleware + +__all__ = [ + "TraceIDMiddleware", + "get_cors_middleware", + "rate_limiter_factory", +] diff --git a/server/app/shared/middleware/cors.py b/server/app/shared/middleware/cors.py new file mode 100644 index 00000000..63a3d84f --- /dev/null +++ b/server/app/shared/middleware/cors.py @@ -0,0 +1,52 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +""" +CORS middleware with configurable origins from env. + +Reads CORS_ALLOW_ORIGINS from environment. Comma-separated list. +Defaults to ["*"] for development when not set. +When origins is ["*"], allow_credentials is False (CORS spec forbids * + credentials). +P2: Production warning when CORS_ALLOW_ORIGINS not set. +""" + +import os + +from loguru import logger + + +def get_cors_middleware(): + """ + Return kwargs for CORSMiddleware. Use: add_middleware(CORSMiddleware, **get_cors_middleware()) + + Env: CORS_ALLOW_ORIGINS (comma-separated, e.g. "https://app.example.com,https://admin.example.com") + Default: ["*"] when not set (development). Production should set explicit origins. + """ + raw = os.getenv("CORS_ALLOW_ORIGINS", "").strip() + if raw: + origins = [o.strip() for o in raw.split(",") if o.strip()] + else: + origins = ["*"] + from app.core.environment import env + + if env("debug", "") != "on": + logger.warning('CORS_ALLOW_ORIGINS not set, using ["*"]. Production should set explicit origins.') + # CORS spec: Access-Control-Allow-Origin: * cannot be used with credentials + allow_credentials = origins != ["*"] + return { + "allow_origins": origins, + "allow_credentials": allow_credentials, + "allow_methods": ["*"], + "allow_headers": ["*"], + } diff --git a/server/app/shared/middleware/rate_limit.py b/server/app/shared/middleware/rate_limit.py new file mode 100644 index 00000000..671448c4 --- /dev/null +++ b/server/app/shared/middleware/rate_limit.py @@ -0,0 +1,43 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +""" +Rate limiter factory for per-route limiting. + +Uses fastapi-limiter (initialized by main app lifespan). +Provides convenient factory for common limits: login, register, webhook. +""" + +from collections.abc import Callable + +from fastapi import Depends +from fastapi_limiter.depends import RateLimiter + + +def rate_limiter_factory(times: int = 10, seconds: int = 60) -> Callable: + """ + Create a RateLimiter dependency with given limits. + + :param times: Max requests allowed in the window + :param seconds: Window size in seconds + :return: FastAPI Depends-compatible callable + """ + return Depends(RateLimiter(times=times, seconds=seconds)) + + +# Predefined limiters for common use cases +login_rate_limiter = rate_limiter_factory(times=5, seconds=60) +register_rate_limiter = rate_limiter_factory(times=3, seconds=60) +webhook_rate_limiter = rate_limiter_factory(times=10, seconds=60) +install_rate_limiter = rate_limiter_factory(times=10, seconds=60) diff --git a/server/app/shared/middleware/trace.py b/server/app/shared/middleware/trace.py new file mode 100644 index 00000000..bb5cc72e --- /dev/null +++ b/server/app/shared/middleware/trace.py @@ -0,0 +1,49 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +""" +Trace ID middleware. + +Reads X-Trace-ID from request header (or generates UUID v4), injects into +request.state and contextvars, and adds to response headers. +""" + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request + +from app.shared.context import ( + ensure_trace_id, + set_trace_id, + get_trace_id, +) + +TRACE_HEADER = "X-Trace-ID" + + +class TraceIDMiddleware(BaseHTTPMiddleware): + """ASGI middleware that propagates X-Trace-ID through the request lifecycle.""" + + async def dispatch(self, request: Request, call_next): + incoming = request.headers.get(TRACE_HEADER) + trace_id = ensure_trace_id(incoming) + set_trace_id(trace_id) + + request.state.trace_id = trace_id + + response = await call_next(request) + + if TRACE_HEADER not in response.headers: + response.headers[TRACE_HEADER] = trace_id + + return response diff --git a/server/app/shared/redis_publish.py b/server/app/shared/redis_publish.py new file mode 100644 index 00000000..f002b9f3 --- /dev/null +++ b/server/app/shared/redis_publish.py @@ -0,0 +1,30 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +""" +Trace-aware Redis pub/sub publish helper. + +When publishing events from v1 code, use publish_with_trace to inject trace_id +into the event payload. Old code paths continue to use app.core.redis_utils directly. +""" + +from app.shared.context import get_trace_id + + +def inject_trace_into_event(event_data: dict) -> dict: + """Add trace_id to event_data if not present. Use before publishing.""" + data = dict(event_data) + if "trace_id" not in data: + data["trace_id"] = get_trace_id() or "" + return data diff --git a/server/app/shared/redis_sync.py b/server/app/shared/redis_sync.py new file mode 100644 index 00000000..e4d90993 --- /dev/null +++ b/server/app/shared/redis_sync.py @@ -0,0 +1,37 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +""" +Shared sync Redis client for v1 (CreditsService idempotency, etc). + +Reuse connection instead of creating new one per call. +Uses same redis_url as app.core.database.redis (async). +""" + +from redis import Redis + +from app.core.environment import env_or_fail + +_redis_sync: Redis | None = None + + +def get_redis_sync() -> Redis: + """Get shared sync Redis client. Lazy init, reused across calls.""" + global _redis_sync + if _redis_sync is None: + _redis_sync = Redis.from_url( + env_or_fail("redis_url"), + decode_responses=True, + ) + return _redis_sync diff --git a/server/app/shared/types/__init__.py b/server/app/shared/types/__init__.py new file mode 100644 index 00000000..3a4d90c0 --- /dev/null +++ b/server/app/shared/types/__init__.py @@ -0,0 +1,14 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + diff --git a/server/app/type/config_group.py b/server/app/shared/types/config_group.py similarity index 100% rename from server/app/type/config_group.py rename to server/app/shared/types/config_group.py diff --git a/server/app/type/pydantic.py b/server/app/shared/types/pydantic.py similarity index 98% rename from server/app/type/pydantic.py rename to server/app/shared/types/pydantic.py index 9a40a4f0..0b79dfcc 100644 --- a/server/app/type/pydantic.py +++ b/server/app/shared/types/pydantic.py @@ -1,15 +1,15 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= from typing import Annotated, Literal diff --git a/server/app/type/trigger_types.py b/server/app/shared/types/trigger_types.py similarity index 100% rename from server/app/type/trigger_types.py rename to server/app/shared/types/trigger_types.py diff --git a/server/celery/beat/start b/server/celery/beat/start index 0f16c685..39220b9e 100644 --- a/server/celery/beat/start +++ b/server/celery/beat/start @@ -4,4 +4,4 @@ set -o errexit set -o nounset rm -f './celerybeat.pid' -celery -A app.component.celery beat -l info \ No newline at end of file +celery -A app.core.celery beat -l info \ No newline at end of file diff --git a/server/celery/worker/start b/server/celery/worker/start index 71d24f19..25f28e8a 100644 --- a/server/celery/worker/start +++ b/server/celery/worker/start @@ -3,4 +3,4 @@ set -o errexit set -o nounset -celery -A app.component.celery worker --loglevel=info --queues=celery,poll_trigger_schedules,check_execution_timeouts \ No newline at end of file +celery -A app.core.celery worker --loglevel=info --queues=celery,poll_trigger_schedules,check_execution_timeouts \ No newline at end of file diff --git a/server/cli.py b/server/cli.py index c647ec2c..ee338d19 100644 --- a/server/cli.py +++ b/server/cli.py @@ -1,20 +1,20 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= from app.command import cli -from app.component.environment import auto_import +from app.core.environment import auto_import auto_import("app.command") diff --git a/server/main.py b/server/main.py index f3a305af..ee64b68a 100644 --- a/server/main.py +++ b/server/main.py @@ -22,21 +22,79 @@ if str(_project_root) not in sys.path: sys.path.insert(0, str(_project_root)) import logging +import sys from fastapi.staticfiles import StaticFiles +from fastapi_pagination import add_pagination +from loguru import logger as loguru_logger -# Import exception handlers to register them -import app.exception.handler # noqa: F401 +from fastapi_babel import BabelMiddleware -# Import middleware to register BabelMiddleware -import app.middleware # noqa: F401 -from app import api -from app.component.environment import auto_include_routers, env +from app import api, router +from app.core.babel import babel_configs +from app.core.environment import auto_include_routers, env +from app.shared.exception.handlers import register_exception_handlers +from app.shared.middleware import TraceIDMiddleware +from app.shared.logging import trace_filter -logger = logging.getLogger("server_main") +# Register exception handlers and i18n middleware +register_exception_handlers(api) +api.add_middleware(BabelMiddleware, babel_configs=babel_configs) + +std_logger = logging.getLogger("server_main") prefix = env("url_prefix", "") -auto_include_routers(api, prefix, "app/controller") + +# Routes: domain-driven architecture +auto_include_routers(router, "", "app/domains") +auto_include_routers(router, "", "app/api") +api.include_router(router, prefix=f"{prefix}/v1") + +# Health check at root level for Docker healthcheck (GET /health) +@api.get("/health", tags=["Health"]) +async def health_check(): + return {"status": "ok", "service": "eigent-server"} + +# Backward-compatible webhook route (/api/webhook/...) +from app.domains.trigger.api.webhook_controller import router as webhook_router +api.include_router(webhook_router, prefix=prefix) + +# Pagination +add_pagination(api) + +# TraceID middleware +api.add_middleware(TraceIDMiddleware) + +# Loguru: trace_id injection + sensitive data filtering +LOG_FORMAT = "{time:YYYY-MM-DD HH:mm:ss} | {level} | {extra[trace_id]} | {message}" +loguru_logger.configure(extra={"trace_id": "-"}) +loguru_logger.remove() +loguru_logger.add(sys.stderr, level="DEBUG", filter=trace_filter, format=LOG_FORMAT) +loguru_logger.add( + "runtime/log/app.log", + rotation="10 MB", + retention="10 days", + level="DEBUG", + enqueue=True, + filter=trace_filter, + format=LOG_FORMAT, +) + +# Intercept stdlib logging and redirect to loguru +class _InterceptHandler(logging.Handler): + def emit(self, record: logging.LogRecord) -> None: + try: + level = loguru_logger.level(record.levelname).name + except ValueError: + level = record.levelno + loguru_logger.opt(depth=6, exception=record.exc_info).log(level, record.getMessage()) + +logging.basicConfig(handlers=[_InterceptHandler()], level=logging.INFO, force=True) +for _name in ("uvicorn", "uvicorn.access", "uvicorn.error", "sqlalchemy.engine", "sqlalchemy.engine.Engine"): + _lg = logging.getLogger(_name) + _lg.handlers = [_InterceptHandler()] + _lg.propagate = False + public_dir = os.environ.get("PUBLIC_DIR") or os.path.join(os.path.dirname(__file__), "app", "public") if not os.path.isdir(public_dir): try: diff --git a/server/pyproject.toml b/server/pyproject.toml index 28042c3e..9428a9a7 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "slack-sdk>=3.39.0", "celery>=5.6.2", "redis>=7.2.1", + "loguru>=0.7.3", ] [project.optional-dependencies] diff --git a/server/start_server.sh b/server/start_server.sh index 4422b950..7b364c40 100644 --- a/server/start_server.sh +++ b/server/start_server.sh @@ -96,14 +96,39 @@ else exit 1 fi -# Start service -echo -e "${YELLOW}[5/5] Starting FastAPI service...${NC}" +# Cleanup function to stop background processes on exit +cleanup() { + echo -e "\n${YELLOW}Shutting down services...${NC}" + if [ -n "$CELERY_WORKER_PID" ] && kill -0 "$CELERY_WORKER_PID" 2>/dev/null; then + kill "$CELERY_WORKER_PID" 2>/dev/null + echo -e "${GREEN}Celery worker stopped${NC}" + fi + if [ -n "$CELERY_BEAT_PID" ] && kill -0 "$CELERY_BEAT_PID" 2>/dev/null; then + kill "$CELERY_BEAT_PID" 2>/dev/null + echo -e "${GREEN}Celery beat stopped${NC}" + fi + exit 0 +} +trap cleanup SIGINT SIGTERM + +# Start services +echo -e "${YELLOW}[5/7] Starting Celery worker...${NC}" +uv run celery -A app.core.celery worker --loglevel=info --queues=celery,poll_trigger_schedules,check_execution_timeouts & +CELERY_WORKER_PID=$! +echo -e "${GREEN}Celery worker started (PID: $CELERY_WORKER_PID)${NC}" + +echo -e "${YELLOW}[6/7] Starting Celery beat...${NC}" +uv run celery -A app.core.celery beat --loglevel=info & +CELERY_BEAT_PID=$! +echo -e "${GREEN}Celery beat started (PID: $CELERY_BEAT_PID)${NC}" + +echo -e "${YELLOW}[7/7] Starting FastAPI service...${NC}" echo -e "${CYAN}Service will start at http://localhost:3001${NC}" -echo -e "${CYAN}Press Ctrl+C to stop the service${NC}" +echo -e "${CYAN}Press Ctrl+C to stop all services${NC}" echo -e "${GREEN}========================================${NC}" if ! uv run uvicorn main:api --reload --port 3001 --host 0.0.0.0; then echo -e "${RED}Service startup failed${NC}" - read -p "Press Enter to exit" + cleanup exit 1 fi diff --git a/server/tests/app/component/test_encrypt.py b/server/tests/app/core/test_encrypt.py similarity index 97% rename from server/tests/app/component/test_encrypt.py rename to server/tests/app/core/test_encrypt.py index 9e8341d5..574cde72 100644 --- a/server/tests/app/component/test_encrypt.py +++ b/server/tests/app/core/test_encrypt.py @@ -12,7 +12,7 @@ # limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -from app.component.encrypt import password_hash, password_verify +from app.core.encrypt import password_hash, password_verify def test_hash_returns_string(): diff --git a/server/tests/test_auth.py b/server/tests/test_auth.py index ffcb2794..0c79a39d 100644 --- a/server/tests/test_auth.py +++ b/server/tests/test_auth.py @@ -16,7 +16,7 @@ import inspect import pytest -from app.controller.chat.share_controller import ( +from app.domains.chat.api.share_controller import ( create_share_link, get_share_info, share_playback, @@ -35,7 +35,7 @@ class TestAuthMustNoneTokenHandling: def test_auth_must_has_none_type_annotation(self): """auth_must should accept Optional[str] since oauth2_scheme may return None with auto_error=False.""" - from app.component.auth import auth_must + from app.shared.auth.user_auth import auth_must sig = inspect.signature(auth_must) token_param = sig.parameters["token"] @@ -47,43 +47,39 @@ class TestAuthMustNoneTokenHandling: """auth_must should raise TokenException immediately when token is None, not pass it to jwt.decode().""" import asyncio - from unittest.mock import MagicMock, patch + from unittest.mock import MagicMock - from app.component.auth import auth_must - from app.exception.exception import TokenException + from app.shared.auth.user_auth import auth_must + from app.shared.exception import TokenException mock_session = MagicMock() with pytest.raises(TokenException): - asyncio.run(auth_must(token=None, session=mock_session)) + asyncio.run(auth_must(token=None, db_session=mock_session)) def test_auth_must_does_not_call_decode_on_none(self): """Verify jwt.decode is never called with None token.""" import asyncio from unittest.mock import MagicMock, patch - from app.component.auth import auth_must + from app.shared.auth.user_auth import auth_must mock_session = MagicMock() - with patch("app.component.auth.Auth.decode_token") as mock_decode: + with patch("app.shared.auth.user_auth.V1UserAuth.decode_token") as mock_decode: try: - asyncio.run(auth_must(token=None, session=mock_session)) + asyncio.run(auth_must(token=None, db_session=mock_session)) except Exception: pass mock_decode.assert_not_called() class TestSnapshotEndpointAuthRequirements: - """Tests verifying that all snapshot CRUD endpoints require authentication. - - The list endpoint was previously missing the auth dependency, allowing - unauthenticated users to enumerate all snapshots across all users. - """ + """Tests verifying that all snapshot CRUD endpoints require authentication.""" def test_list_snapshots_requires_auth_dependency(self): """GET /snapshots must include auth_must as a dependency.""" - from app.controller.chat.snapshot_controller import list_chat_snapshots + from app.domains.chat.api.snapshot_controller import list_chat_snapshots sig = inspect.signature(list_chat_snapshots) param_names = list(sig.parameters.keys()) @@ -94,7 +90,7 @@ class TestSnapshotEndpointAuthRequirements: def test_get_snapshot_requires_auth_dependency(self): """GET /snapshots/{id} must include auth_must as a dependency.""" - from app.controller.chat.snapshot_controller import get_chat_snapshot + from app.domains.chat.api.snapshot_controller import get_chat_snapshot sig = inspect.signature(get_chat_snapshot) param_names = list(sig.parameters.keys()) @@ -102,7 +98,7 @@ class TestSnapshotEndpointAuthRequirements: def test_create_snapshot_requires_auth_dependency(self): """POST /snapshots must include auth_must as a dependency.""" - from app.controller.chat.snapshot_controller import create_chat_snapshot + from app.domains.chat.api.snapshot_controller import create_chat_snapshot sig = inspect.signature(create_chat_snapshot) param_names = list(sig.parameters.keys()) @@ -110,7 +106,7 @@ class TestSnapshotEndpointAuthRequirements: def test_update_snapshot_requires_auth_dependency(self): """PUT /snapshots/{id} must include auth_must as a dependency.""" - from app.controller.chat.snapshot_controller import update_chat_snapshot + from app.domains.chat.api.snapshot_controller import update_chat_snapshot sig = inspect.signature(update_chat_snapshot) param_names = list(sig.parameters.keys()) @@ -118,7 +114,7 @@ class TestSnapshotEndpointAuthRequirements: def test_delete_snapshot_requires_auth_dependency(self): """DELETE /snapshots/{id} must include auth_must as a dependency.""" - from app.controller.chat.snapshot_controller import delete_chat_snapshot + from app.domains.chat.api.snapshot_controller import delete_chat_snapshot sig = inspect.signature(delete_chat_snapshot) param_names = list(sig.parameters.keys()) diff --git a/src/components/AddWorker/ToolSelect.tsx b/src/components/AddWorker/ToolSelect.tsx index f09338d3..5bc86348 100644 --- a/src/components/AddWorker/ToolSelect.tsx +++ b/src/components/AddWorker/ToolSelect.tsx @@ -103,7 +103,7 @@ const ToolSelect = forwardRef< const fetchIntegrationsData = useCallback( (keyword?: string) => { - proxyFetchGet('/api/config/info') + proxyFetchGet('/api/v1/config/info') .then((res) => { if (res && typeof res === 'object' && !res.error) { const baseURL = getProxyBaseURL(); @@ -131,7 +131,7 @@ const ToolSelect = forwardRef< // Still proceed but log the warning } // Save to config to mark as installed - await proxyFetchPost('/api/configs', { + await proxyFetchPost('/api/v1/configs', { config_group: 'Notion', config_name: 'MCP_REMOTE_CONFIG_DIR', config_value: @@ -181,7 +181,7 @@ const ToolSelect = forwardRef< } try { const existingConfigs = - await proxyFetchGet('/api/configs'); + await proxyFetchGet('/api/v1/configs'); const existing = Array.isArray(existingConfigs) ? existingConfigs.find( (c: any) => @@ -199,11 +199,14 @@ const ToolSelect = forwardRef< if (existing) { await proxyFetchPut( - `/api/configs/${existing.id}`, + `/api/v1/configs/${existing.id}`, configPayload ); } else { - await proxyFetchPost('/api/configs', configPayload); + await proxyFetchPost( + '/api/v1/configs', + configPayload + ); } } catch (configError) { console.warn( @@ -253,7 +256,7 @@ const ToolSelect = forwardRef< } else { onInstall = () => window.open( - `${baseURL}/api/oauth/${key.toLowerCase()}/login`, + `${baseURL}/api/v1/oauth/${key.toLowerCase()}/login`, '_blank', 'width=600,height=700' ); @@ -313,7 +316,7 @@ const ToolSelect = forwardRef< // data fetching const fetchData = useCallback((keyword?: string) => { - proxyFetchGet('/api/mcps', { + proxyFetchGet('/api/v1/mcps', { keyword: keyword || '', page: 1, size: 100, @@ -334,7 +337,7 @@ const ToolSelect = forwardRef< }, []); const fetchInstalledMcps = useCallback(() => { - proxyFetchGet('/api/mcp/users') + proxyFetchGet('/api/v1/mcp/users') .then((res) => { let dataList = []; let ids: number[] = []; @@ -378,7 +381,7 @@ const ToolSelect = forwardRef< value: string ) => { // First fetch current configs to check for existing ones - const configsRes = await proxyFetchGet('/api/configs'); + const configsRes = await proxyFetchGet('/api/v1/configs'); const configs = Array.isArray(configsRes) ? configsRes : []; const configPayload = { @@ -396,10 +399,13 @@ const ToolSelect = forwardRef< if (existingConfig) { // Update existing config - await proxyFetchPut(`/api/configs/${existingConfig.id}`, configPayload); + await proxyFetchPut( + `/api/v1/configs/${existingConfig.id}`, + configPayload + ); } else { // Create new config - await proxyFetchPost('/api/configs', configPayload); + await proxyFetchPost('/api/v1/configs', configPayload); } if (window.electronAPI?.envWrite) { @@ -476,7 +482,7 @@ const ToolSelect = forwardRef< if (response.success) { console.log('[ToolSelect installMcp] Immediate success'); // Mark as successfully installed by writing refresh token marker - const existingConfigs = await proxyFetchGet('/api/configs'); + const existingConfigs = await proxyFetchGet('/api/v1/configs'); const existing = Array.isArray(existingConfigs) ? existingConfigs.find( (c: any) => @@ -492,9 +498,12 @@ const ToolSelect = forwardRef< }; if (existing) { - await proxyFetchPut(`/api/configs/${existing.id}`, configPayload); + await proxyFetchPut( + `/api/v1/configs/${existing.id}`, + configPayload + ); } else { - await proxyFetchPost('/api/configs', configPayload); + await proxyFetchPost('/api/v1/configs', configPayload); } // Refresh integrations to update install status @@ -541,7 +550,8 @@ const ToolSelect = forwardRef< ); if (retryResponse.success) { // Mark as successfully installed - const existingConfigs = await proxyFetchGet('/api/configs'); + const existingConfigs = + await proxyFetchGet('/api/v1/configs'); const existing = Array.isArray(existingConfigs) ? existingConfigs.find( (c: any) => @@ -559,11 +569,11 @@ const ToolSelect = forwardRef< if (existing) { await proxyFetchPut( - `/api/configs/${existing.id}`, + `/api/v1/configs/${existing.id}`, configPayload ); } else { - await proxyFetchPost('/api/configs', configPayload); + await proxyFetchPost('/api/v1/configs', configPayload); } fetchIntegrationsData(); @@ -622,7 +632,7 @@ const ToolSelect = forwardRef< } setInstalling((prev) => ({ ...prev, [id]: true })); try { - await proxyFetchPost('/api/mcp/install?mcp_id=' + id); + await proxyFetchPost('/api/v1/mcp/install?mcp_id=' + id); setInstalled((prev) => ({ ...prev, [id]: true })); const installedMcp = mcpList.find((mcp) => mcp.id === id); if (window.ipcRenderer && installedMcp) { diff --git a/src/components/ChatBox/index.tsx b/src/components/ChatBox/index.tsx index 3e646d04..dedbede2 100644 --- a/src/components/ChatBox/index.tsx +++ b/src/components/ChatBox/index.tsx @@ -87,11 +87,11 @@ export default function ChatBox(): JSX.Element { try { if (modelType === 'cloud') { // For cloud model, check if API key exists - const res = await proxyFetchGet('/api/user/key'); + const res = await proxyFetchGet('/api/v1/user/key'); setHasModel(!!res.value); } else if (modelType === 'local' || modelType === 'custom') { // For local/custom model, check if provider exists - const res = await proxyFetchGet('/api/providers', { prefer: true }); + const res = await proxyFetchGet('/api/v1/providers', { prefer: true }); const providerList = res.items || []; setHasModel(providerList.length > 0); } else { @@ -107,7 +107,7 @@ export default function ChatBox(): JSX.Element { // Check model config on mount and when modelType changes useEffect(() => { - proxyFetchGet('/api/configs') + proxyFetchGet('/api/v1/configs') .then((configsRes) => { const configs = Array.isArray(configsRes) ? configsRes : []; const _hasApiKey = configs.find( @@ -337,7 +337,7 @@ export default function ChatBox(): JSX.Element { let taskId: string = token.split('__')[1]; chatStore.create(taskId, 'share'); chatStore.setHasMessages(taskId, true); - const res = await proxyFetchGet(`/api/chat/share/info/${_token}`); + const res = await proxyFetchGet(`/api/v1/chat/share/info/${_token}`); if (res?.question) { chatStore.addMessages(taskId, { id: generateUniqueId(), @@ -906,7 +906,7 @@ export default function ChatBox(): JSX.Element { const history_id = projectStore.getHistoryId(projectId); if (history_id) { try { - await proxyFetchDelete(`/api/chat/history/${history_id}`); + await proxyFetchDelete(`/api/v1/chat/history/${history_id}`); } catch (error) { console.error( `Failed to delete chat history (ID: ${history_id}) for project ${projectId}:`, diff --git a/src/components/Folder/index.tsx b/src/components/Folder/index.tsx index 3bbcb097..c2d32771 100644 --- a/src/components/Folder/index.tsx +++ b/src/components/Folder/index.tsx @@ -431,7 +431,7 @@ export default function Folder({ data: _data }: { data?: Agent }) { } else { if (!hasFetchedRemote.current) { //TODO(file): rename endpoint to use project_id - res = await proxyFetchGet('/api/chat/files', { + res = await proxyFetchGet('/api/v1/chat/files', { task_id: projectStore.activeProjectId as string, }); hasFetchedRemote.current = true; diff --git a/src/components/GroupedHistoryView/index.tsx b/src/components/GroupedHistoryView/index.tsx index 25528d82..a6cb232e 100644 --- a/src/components/GroupedHistoryView/index.tsx +++ b/src/components/GroupedHistoryView/index.tsx @@ -140,7 +140,7 @@ export default function GroupedHistoryView({ // Delete each task one by one for (const history of targetProject.tasks) { try { - await proxyFetchDelete(`/api/chat/history/${history.id}`); + await proxyFetchDelete(`/api/v1/chat/history/${history.id}`); console.log(`Successfully deleted task ${history.task_id}`); // Also delete local files for this task if available (via Electron IPC) @@ -220,7 +220,7 @@ export default function GroupedHistoryView({ // Call API to update project name try { const response = await proxyFetchPut( - `/api/chat/project/${projectId}/name?new_name=${encodeURIComponent(newName)}` + `/api/v1/chat/project/${projectId}/name?new_name=${encodeURIComponent(newName)}` ); if (response && response.code !== undefined && response.code !== 0) { diff --git a/src/components/HistorySidebar/index.tsx b/src/components/HistorySidebar/index.tsx index 14772504..501df675 100644 --- a/src/components/HistorySidebar/index.tsx +++ b/src/components/HistorySidebar/index.tsx @@ -67,7 +67,7 @@ export default function HistorySidebar() { useEffect(() => { if (!chatStore) return; fetchGroupedHistoryTasks(setHistoryTasks); - }, [chatStore?.updateCount, chatStore]); + }, [chatStore?.updateCount]); // Group ongoing tasks by project const ongoingProjects = useMemo(() => { @@ -190,7 +190,7 @@ export default function HistorySidebar() { historyId: string ) => { try { - const res = await proxyFetchDelete(`/api/chat/history/${historyId}`); + const res = await proxyFetchDelete(`/api/v1/chat/history/${historyId}`); console.log(res); // also delete local files for this task if available (via Electron IPC) const { email } = getAuthStore(); @@ -236,7 +236,7 @@ export default function HistorySidebar() { ); try { const deleteRes = await proxyFetchDelete( - `/api/chat/history/${history.id}` + `/api/v1/chat/history/${history.id}` ); console.log( `Successfully deleted task ${history.task_id}:`, diff --git a/src/components/IntegrationList/index.tsx b/src/components/IntegrationList/index.tsx index aadfe3f5..6247f9a8 100644 --- a/src/components/IntegrationList/index.tsx +++ b/src/components/IntegrationList/index.tsx @@ -134,7 +134,7 @@ export default function IntegrationList({ if (item.key === 'LinkedIn') { // Open LinkedIn OAuth login via the remote server (same pattern as other OAuth providers) const baseUrl = getProxyBaseURL(); - const oauthUrl = `${baseUrl}/api/oauth/linkedin/login`; + const oauthUrl = `${baseUrl}/api/v1/oauth/linkedin/login`; window.open(oauthUrl, '_blank', 'width=600,height=700'); return; } @@ -340,7 +340,7 @@ export default function IntegrationList({ } > {isSelectMode ? ( -
+
{(isSelectMode || showStatusDot) && (
) : ( -
-
+
+
{showStatusDot && (
-
+
{showConfigButton && (