mirror of
https://github.com/eigent-ai/eigent.git
synced 2026-04-28 03:30:06 +00:00
Feat: Server refactor v1 (#1509)
Some checks are pending
Pre-commit / pre-commit (push) Waiting to run
CodeQL Advanced / Analyze (actions) (push) Waiting to run
CodeQL Advanced / Analyze (javascript-typescript) (push) Waiting to run
CodeQL Advanced / Analyze (python) (push) Waiting to run
Test / Run Python Tests (push) Waiting to run
Some checks are pending
Pre-commit / pre-commit (push) Waiting to run
CodeQL Advanced / Analyze (actions) (push) Waiting to run
CodeQL Advanced / Analyze (javascript-typescript) (push) Waiting to run
CodeQL Advanced / Analyze (python) (push) Waiting to run
Test / Run Python Tests (push) Waiting to run
This commit is contained in:
parent
1e542f9d27
commit
712f20a8fa
179 changed files with 5593 additions and 6063 deletions
51
.github/workflows/build-view.yml
vendored
51
.github/workflows/build-view.yml
vendored
|
|
@ -12,6 +12,7 @@ jobs:
|
|||
timeout-minutes: 120
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- os: macos-latest
|
||||
|
|
@ -39,6 +40,22 @@ jobs:
|
|||
run: |
|
||||
rm -rf node_modules
|
||||
rm -rf release
|
||||
rm -rf dist out build .vite
|
||||
rm -rf node_modules/.cache || true
|
||||
|
||||
# Clean build outputs on GitHub-hosted runners to avoid stale artifacts in current job
|
||||
- name: Clean build outputs (non-Windows)
|
||||
if: "!contains(matrix.os, 'self-hosted') && runner.os != 'Windows'"
|
||||
run: |
|
||||
rm -rf release dist out .vite
|
||||
rm -rf node_modules/.cache || true
|
||||
|
||||
- name: Clean build outputs (Windows)
|
||||
if: "!contains(matrix.os, 'self-hosted') && runner.os == 'Windows'"
|
||||
shell: pwsh
|
||||
run: |
|
||||
Remove-Item -Recurse -Force release, dist, out, .vite -ErrorAction SilentlyContinue
|
||||
Remove-Item -Recurse -Force node_modules/.cache -ErrorAction SilentlyContinue
|
||||
|
||||
- name: Setup Node.js
|
||||
if: "!contains(matrix.os, 'self-hosted')"
|
||||
|
|
@ -92,6 +109,20 @@ jobs:
|
|||
echo "LLVM_DIR=$(brew --prefix llvm@20)/lib/cmake/llvm" >> $GITHUB_ENV
|
||||
echo "CMAKE_PREFIX_PATH=$(brew --prefix llvm@20)/lib/cmake/llvm" >> $GITHUB_ENV
|
||||
|
||||
# Prebuild separately on macOS so signing/package issues are isolated
|
||||
- name: Build Release Files (macOS prebuild)
|
||||
if: runner.os == 'macOS'
|
||||
timeout-minutes: 45
|
||||
run: |
|
||||
npm run prebuild
|
||||
env:
|
||||
VITE_BASE_URL: ${{ secrets.VITE_BASE_URL }}
|
||||
VITE_PROXY_URL: ${{ secrets.VITE_PROXY_URL }}
|
||||
VITE_STACK_PROJECT_ID: ${{ secrets.VITE_STACK_PROJECT_ID }}
|
||||
VITE_STACK_PUBLISHABLE_CLIENT_KEY: ${{ secrets.VITE_STACK_PUBLISHABLE_CLIENT_KEY }}
|
||||
VITE_STACK_SECRET_SERVER_KEY: ${{ secrets.VITE_STACK_SECRET_SERVER_KEY }}
|
||||
USE_NPM_INSTALL_BUN: 'true'
|
||||
|
||||
# Step for macOS builds with signing
|
||||
- name: Build Release Files (macOS with signing)
|
||||
if: runner.os == 'macOS'
|
||||
|
|
@ -104,8 +135,20 @@ jobs:
|
|||
fi
|
||||
ulimit -n 65536 2>/dev/null || ulimit -n 10240 2>/dev/null || true
|
||||
echo "File descriptor limit: $(ulimit -n) (hard: $(ulimit -Hn 2>/dev/null || echo 'N/A'))"
|
||||
npm run prebuild
|
||||
npx electron-builder --mac --${{ matrix.arch }} --publish never
|
||||
|
||||
set +e
|
||||
npx electron-builder --mac dmg --${{ matrix.arch }} --publish never
|
||||
BUILD_EXIT=$?
|
||||
|
||||
if [ $BUILD_EXIT -ne 0 ]; then
|
||||
echo "First attempt failed with exit code $BUILD_EXIT"
|
||||
echo "Retrying once in 5 seconds..."
|
||||
sleep 5
|
||||
npx electron-builder --mac dmg --${{ matrix.arch }} --publish never
|
||||
BUILD_EXIT=$?
|
||||
fi
|
||||
|
||||
exit $BUILD_EXIT
|
||||
env:
|
||||
CSC_LINK: ${{ secrets.CERT_P12 }}
|
||||
CSC_KEY_PASSWORD: ${{ secrets.CERT_PASSWORD }}
|
||||
|
|
@ -113,6 +156,7 @@ jobs:
|
|||
APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }}
|
||||
APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
|
||||
VITE_BASE_URL: ${{ secrets.VITE_BASE_URL }}
|
||||
VITE_PROXY_URL: ${{ secrets.VITE_PROXY_URL }}
|
||||
VITE_STACK_PROJECT_ID: ${{ secrets.VITE_STACK_PROJECT_ID }}
|
||||
VITE_STACK_PUBLISHABLE_CLIENT_KEY: ${{ secrets.VITE_STACK_PUBLISHABLE_CLIENT_KEY }}
|
||||
VITE_STACK_SECRET_SERVER_KEY: ${{ secrets.VITE_STACK_SECRET_SERVER_KEY }}
|
||||
|
|
@ -127,6 +171,7 @@ jobs:
|
|||
npx electron-builder --win --${{ matrix.arch }} --publish never
|
||||
env:
|
||||
VITE_BASE_URL: ${{ secrets.VITE_BASE_URL }}
|
||||
VITE_PROXY_URL: ${{ secrets.VITE_PROXY_URL }}
|
||||
VITE_STACK_PROJECT_ID: ${{ secrets.VITE_STACK_PROJECT_ID }}
|
||||
VITE_STACK_PUBLISHABLE_CLIENT_KEY: ${{ secrets.VITE_STACK_PUBLISHABLE_CLIENT_KEY }}
|
||||
VITE_STACK_SECRET_SERVER_KEY: ${{ secrets.VITE_STACK_SECRET_SERVER_KEY }}
|
||||
|
|
@ -141,6 +186,7 @@ jobs:
|
|||
npx electron-builder --linux --${{ matrix.arch }} --publish never
|
||||
env:
|
||||
VITE_BASE_URL: ${{ secrets.VITE_BASE_URL }}
|
||||
VITE_PROXY_URL: ${{ secrets.VITE_PROXY_URL }}
|
||||
VITE_STACK_PROJECT_ID: ${{ secrets.VITE_STACK_PROJECT_ID }}
|
||||
VITE_STACK_PUBLISHABLE_CLIENT_KEY: ${{ secrets.VITE_STACK_PUBLISHABLE_CLIENT_KEY }}
|
||||
VITE_STACK_SECRET_SERVER_KEY: ${{ secrets.VITE_STACK_SECRET_SERVER_KEY }}
|
||||
|
|
@ -195,6 +241,7 @@ jobs:
|
|||
path: |
|
||||
release/*.AppImage
|
||||
retention-days: 5
|
||||
|
||||
merge-release:
|
||||
needs: build
|
||||
runs-on: ubuntu-latest
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import threading
|
||||
import time
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
|
@ -779,15 +780,13 @@ async def open_browser_login():
|
|||
bufsize=1, # Line buffered
|
||||
)
|
||||
|
||||
# Create async task to log Electron output
|
||||
async def log_electron_output():
|
||||
def log_electron_output():
|
||||
for line in iter(process.stdout.readline, ""):
|
||||
if line:
|
||||
logger.info(f"[ELECTRON OUTPUT] {line.strip()}")
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.create_task(log_electron_output())
|
||||
log_thread = threading.Thread(target=log_electron_output, daemon=True)
|
||||
log_thread.start()
|
||||
|
||||
# Wait a bit for Electron to start
|
||||
import asyncio
|
||||
|
|
|
|||
|
|
@ -73,6 +73,11 @@ def register_routers(app: FastAPI, prefix: str = "") -> None:
|
|||
},
|
||||
]
|
||||
|
||||
app.include_router(health_controller.router, tags=["Health"])
|
||||
logger.info(
|
||||
"Registered Health router at root level for Docker health checks"
|
||||
)
|
||||
|
||||
for config in routers_config:
|
||||
app.include_router(
|
||||
config["router"], prefix=prefix, tags=config["tags"]
|
||||
|
|
|
|||
|
|
@ -1,16 +1,16 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
import pathlib
|
||||
import sys
|
||||
|
|
@ -25,7 +25,7 @@ from sqlalchemy import engine_from_config, pool
|
|||
from sqlmodel import SQLModel
|
||||
|
||||
from alembic import context
|
||||
from app.component.environment import auto_import, env_not_empty
|
||||
from app.core.environment import auto_import, env_not_empty
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
|
|
|
|||
|
|
@ -24,12 +24,12 @@ from typing import Sequence, Union
|
|||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel.sql.sqltypes
|
||||
from app.type.trigger_types import ExecutionStatus
|
||||
from app.type.trigger_types import ExecutionType
|
||||
from app.type.trigger_types import ListenerType
|
||||
from app.type.trigger_types import RequestType
|
||||
from app.type.trigger_types import TriggerStatus
|
||||
from app.type.trigger_types import TriggerType
|
||||
from app.shared.types.trigger_types import ExecutionStatus
|
||||
from app.shared.types.trigger_types import ExecutionType
|
||||
from app.shared.types.trigger_types import ListenerType
|
||||
from app.shared.types.trigger_types import RequestType
|
||||
from app.shared.types.trigger_types import TriggerStatus
|
||||
from app.shared.types.trigger_types import TriggerType
|
||||
from sqlalchemy_utils.types import ChoiceType
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
|
|
|||
|
|
@ -1,22 +1,22 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from fastapi_pagination import add_pagination
|
||||
from fastapi_limiter import FastAPILimiter
|
||||
from app.component.environment import env_or_fail
|
||||
from app.core.environment import env_or_fail
|
||||
from redis import asyncio as aioredis
|
||||
import logging
|
||||
|
||||
|
|
@ -43,3 +43,5 @@ api = FastAPI(
|
|||
lifespan=lifespan
|
||||
)
|
||||
add_pagination(api)
|
||||
|
||||
router = APIRouter()
|
||||
|
|
|
|||
15
server/app/api/__init__.py
Normal file
15
server/app/api/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""Non-domain v1 API endpoints."""
|
||||
27
server/app/api/demo_controller.py
Normal file
27
server/app/api/demo_controller.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""Demo endpoint - uses v1 auth."""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi_babel import _
|
||||
|
||||
from app.shared.auth import auth_optional
|
||||
|
||||
router = APIRouter(tags=["demo"])
|
||||
|
||||
|
||||
@router.get("/demo")
|
||||
def get(user=Depends(auth_optional)):
|
||||
return {"message": user.id if user else _("no auth"), "content": _("hello")}
|
||||
|
|
@ -1,16 +1,18 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""Redirect controller - H11 XSS fix with json.dumps + encodeURIComponent."""
|
||||
|
||||
import json
|
||||
|
||||
|
|
@ -24,6 +26,7 @@ router = APIRouter(tags=["Redirect"])
|
|||
def redirect_callback(code: str, request: Request):
|
||||
cookies = request.cookies
|
||||
cookies_json = json.dumps(cookies)
|
||||
safe_code = json.dumps(code)
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
|
|
@ -72,10 +75,10 @@ def redirect_callback(code: str, request: Request):
|
|||
<script>
|
||||
(function() {{
|
||||
const allCookies = {cookies_json};
|
||||
const baseUrl = "eigent://callback?code={code}";
|
||||
const code = {safe_code};
|
||||
const baseUrl = "eigent://callback?code=" + encodeURIComponent(code);
|
||||
let finalUrl = baseUrl;
|
||||
|
||||
// 自动跳转到应用
|
||||
window.location.href = finalUrl;
|
||||
}})();
|
||||
</script>
|
||||
|
|
@ -1,107 +0,0 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import jwt
|
||||
from fastapi import Depends, Header
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from fastapi_babel import _
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from app.component import code
|
||||
from app.component.database import session
|
||||
from app.component.environment import env, env_not_empty
|
||||
from app.exception.exception import (
|
||||
NoPermissionException,
|
||||
TokenException,
|
||||
)
|
||||
from app.model.mcp.proxy import ApiKey
|
||||
from app.model.user.key import Key
|
||||
from app.model.user.user import User
|
||||
|
||||
|
||||
class Auth:
|
||||
SECRET_KEY = env_not_empty("secret_key")
|
||||
|
||||
def __init__(self, id: int, expired_at: datetime):
|
||||
self.id = id
|
||||
self.expired_at = expired_at
|
||||
self._user: User | None = None
|
||||
|
||||
@property
|
||||
def user(self):
|
||||
if self._user is None:
|
||||
raise NoPermissionException("未查询到登录用户")
|
||||
return self._user
|
||||
|
||||
@classmethod
|
||||
def decode_token(cls, token: str):
|
||||
try:
|
||||
payload = jwt.decode(token, Auth.SECRET_KEY, algorithms=["HS256"])
|
||||
id = payload["id"]
|
||||
if payload["exp"] < int(datetime.now().timestamp()):
|
||||
raise TokenException(code.token_expired, _("Validate credentials expired"))
|
||||
except InvalidTokenError:
|
||||
raise TokenException(code.token_invalid, _("Could not validate credentials"))
|
||||
return Auth(id, payload["exp"])
|
||||
|
||||
@classmethod
|
||||
def create_access_token(cls, user_id: int, expires_delta: timedelta | None = None):
|
||||
to_encode: dict = {"id": user_id}
|
||||
if expires_delta:
|
||||
expire = datetime.now() + expires_delta
|
||||
else:
|
||||
expire = datetime.now() + timedelta(days=30)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, Auth.SECRET_KEY, algorithm="HS256")
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{env('url_prefix', '')}/dev_login", auto_error=False)
|
||||
|
||||
|
||||
async def auth(
|
||||
token: str | None = Depends(oauth2_scheme),
|
||||
session: Session = Depends(session),
|
||||
) -> Auth | None:
|
||||
if token is None:
|
||||
return None
|
||||
try:
|
||||
model = Auth.decode_token(token)
|
||||
user = session.get(User, model.id)
|
||||
model._user = user
|
||||
return model
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
async def auth_must(
|
||||
token: str | None = Depends(oauth2_scheme),
|
||||
session: Session = Depends(session),
|
||||
) -> Auth:
|
||||
if token is None:
|
||||
raise TokenException(code.token_invalid, _("Authentication required"))
|
||||
model = Auth.decode_token(token)
|
||||
user = session.get(User, model.id)
|
||||
model._user = user
|
||||
return model
|
||||
|
||||
|
||||
async def key_must(headers: ApiKey = Header(), session: Session = Depends(session)):
|
||||
model = session.exec(select(Key).where(Key.value == headers.api_key)).one_or_none()
|
||||
if model is None:
|
||||
raise TokenException(code.token_invalid, _(f"Could not validate key credentials: {headers.api_key}"))
|
||||
return model
|
||||
|
|
@ -1,89 +0,0 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from fastapi_babel import _
|
||||
|
||||
"""
|
||||
权限定义:
|
||||
当存在子权限的时候,父权限则不生效,应该全部放至子权限中定义处理
|
||||
"""
|
||||
|
||||
|
||||
def permissions():
|
||||
return [
|
||||
{
|
||||
"name": _("User"),
|
||||
"description": _("User manager"),
|
||||
"children": [
|
||||
{
|
||||
"identity": "user:view",
|
||||
"name": _("User Manage"),
|
||||
"description": _("View users"),
|
||||
},
|
||||
{
|
||||
"identity": "user:edit",
|
||||
"name": _("User Edit"),
|
||||
"description": _("Manage users"), # 修改用户信息,邀请用户(限本组织下)
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": _("Admin"),
|
||||
"description": _("Admin manager"),
|
||||
"children": [
|
||||
{
|
||||
"identity": "admin:view",
|
||||
"name": _("Admin View"),
|
||||
"description": _("View admins"), # 修改项目,工作区,角色,用户
|
||||
},
|
||||
{
|
||||
"identity": "admin:edit",
|
||||
"name": _("Admin Edit"),
|
||||
"description": _("Edit admins"),
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": _("Role"),
|
||||
"description": _("Role manager"),
|
||||
"children": [
|
||||
{
|
||||
"identity": "role:view",
|
||||
"name": _("Role View"),
|
||||
"description": _("View roles"), # 修改项目和工作区中的角色,创建新的角色
|
||||
},
|
||||
{
|
||||
"identity": "role:edit",
|
||||
"name": _("Role Edit"),
|
||||
"description": _("Edit roles"), # 修改角色
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": _("Mcp"),
|
||||
"description": _("Mcp manager"),
|
||||
"children": [
|
||||
{
|
||||
"identity": "mcp:edit",
|
||||
"name": _("Mcp Edit"),
|
||||
"description": _("Edit mcp service"),
|
||||
},
|
||||
{
|
||||
"identity": "mcp-category:edit",
|
||||
"name": _("Mcp Category Edit"),
|
||||
"description": _("Edit mcp category"),
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
|
@ -1,54 +0,0 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""
|
||||
Trigger Service Package
|
||||
|
||||
Contains services for managing triggers including:
|
||||
- TriggerService: Main service for trigger operations
|
||||
- TriggerScheduleService: Service for scheduled trigger operations
|
||||
- App Handlers: Handlers for different trigger types (Slack, Webhook, Schedule)
|
||||
"""
|
||||
|
||||
from app.service.trigger.trigger_service import TriggerService, get_trigger_service
|
||||
from app.service.trigger.trigger_schedule_service import TriggerScheduleService
|
||||
from app.service.trigger.app_handler_service import (
|
||||
BaseAppHandler,
|
||||
SlackAppHandler,
|
||||
DefaultWebhookHandler,
|
||||
ScheduleAppHandler,
|
||||
AppHandlerResult,
|
||||
get_app_handler,
|
||||
get_schedule_handler,
|
||||
register_app_handler,
|
||||
get_supported_trigger_types,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Services
|
||||
"TriggerService",
|
||||
"get_trigger_service",
|
||||
"TriggerScheduleService",
|
||||
# Handlers
|
||||
"BaseAppHandler",
|
||||
"SlackAppHandler",
|
||||
"DefaultWebhookHandler",
|
||||
"ScheduleAppHandler",
|
||||
"AppHandlerResult",
|
||||
# Handler functions
|
||||
"get_app_handler",
|
||||
"get_schedule_handler",
|
||||
"register_app_handler",
|
||||
"get_supported_trigger_types",
|
||||
]
|
||||
|
|
@ -1,391 +0,0 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlmodel import select, and_, or_
|
||||
from uuid import uuid4
|
||||
import logging
|
||||
|
||||
from app.model.trigger.trigger import Trigger
|
||||
from app.model.trigger.trigger_execution import TriggerExecution
|
||||
from app.type.trigger_types import TriggerType, TriggerStatus, ExecutionType, ExecutionStatus
|
||||
from app.component.database import session_make
|
||||
from app.service.trigger.trigger_schedule_service import TriggerScheduleService
|
||||
from app.component.trigger_utils import SCHEDULED_FETCH_BATCH_SIZE, check_rate_limits
|
||||
from app.model.trigger.app_configs import ScheduleTriggerConfig, WebhookTriggerConfig
|
||||
from app.model.trigger.app_configs.base_config import BaseTriggerConfig
|
||||
|
||||
|
||||
|
||||
class TriggerService:
|
||||
"""Service for managing trigger operations and scheduling."""
|
||||
|
||||
def __init__(self, session=None):
|
||||
self.session = session or session_make()
|
||||
self.schedule_service = TriggerScheduleService(self.session)
|
||||
|
||||
def create_execution(
|
||||
self,
|
||||
trigger: Trigger,
|
||||
execution_type: ExecutionType,
|
||||
input_data: Optional[Dict[str, Any]] = None
|
||||
) -> TriggerExecution:
|
||||
"""Create a new trigger execution."""
|
||||
execution_id = str(uuid4())
|
||||
|
||||
execution = TriggerExecution(
|
||||
trigger_id=trigger.id,
|
||||
execution_id=execution_id,
|
||||
execution_type=execution_type,
|
||||
status=ExecutionStatus.pending,
|
||||
input_data=input_data or {},
|
||||
started_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
self.session.add(execution)
|
||||
self.session.commit()
|
||||
self.session.refresh(execution)
|
||||
|
||||
# Update trigger statistics
|
||||
trigger.last_executed_at = datetime.now(timezone.utc)
|
||||
trigger.last_execution_status = "pending"
|
||||
self.session.add(trigger)
|
||||
self.session.commit()
|
||||
|
||||
logger.info("Execution created", extra={
|
||||
"trigger_id": trigger.id,
|
||||
"execution_id": execution_id,
|
||||
"execution_type": execution_type.value
|
||||
})
|
||||
|
||||
return execution
|
||||
|
||||
def update_execution_status(
|
||||
self,
|
||||
execution: TriggerExecution,
|
||||
status: ExecutionStatus,
|
||||
output_data: Optional[Dict[str, Any]] = None,
|
||||
error_message: Optional[str] = None,
|
||||
tokens_used: Optional[int] = None,
|
||||
tools_executed: Optional[Dict[str, Any]] = None
|
||||
) -> TriggerExecution:
|
||||
"""Update execution status and metadata."""
|
||||
execution.status = status
|
||||
|
||||
# Set completed_at and duration for terminal statuses
|
||||
if status in [ExecutionStatus.completed, ExecutionStatus.failed, ExecutionStatus.cancelled, ExecutionStatus.missed]:
|
||||
execution.completed_at = datetime.now(timezone.utc)
|
||||
if execution.started_at:
|
||||
# Ensure started_at is timezone-aware for subtraction
|
||||
started_at = execution.started_at
|
||||
if started_at.tzinfo is None:
|
||||
started_at = started_at.replace(tzinfo=timezone.utc)
|
||||
execution.duration_seconds = (execution.completed_at - started_at).total_seconds()
|
||||
|
||||
if output_data:
|
||||
execution.output_data = output_data
|
||||
|
||||
if error_message:
|
||||
execution.error_message = error_message
|
||||
|
||||
if tokens_used:
|
||||
execution.tokens_used = tokens_used
|
||||
|
||||
if tools_executed:
|
||||
execution.tools_executed = tools_executed
|
||||
|
||||
self.session.add(execution)
|
||||
self.session.commit()
|
||||
|
||||
# Update trigger status and handle auto-disable logic
|
||||
trigger = self.session.get(Trigger, execution.trigger_id)
|
||||
if trigger:
|
||||
if status == ExecutionStatus.failed:
|
||||
trigger.last_execution_status = "failed"
|
||||
trigger.consecutive_failures += 1
|
||||
|
||||
# Check for auto-disable based on max_failure_count in config
|
||||
self._check_auto_disable(trigger)
|
||||
|
||||
elif status == ExecutionStatus.completed:
|
||||
trigger.last_execution_status = "completed"
|
||||
# Reset consecutive failures on success
|
||||
trigger.consecutive_failures = 0
|
||||
elif status == ExecutionStatus.cancelled:
|
||||
trigger.last_execution_status = "cancelled"
|
||||
elif status == ExecutionStatus.missed:
|
||||
trigger.last_execution_status = "missed"
|
||||
|
||||
self.session.add(trigger)
|
||||
self.session.commit()
|
||||
|
||||
logger.info("Execution status updated", extra={
|
||||
"execution_id": execution.execution_id,
|
||||
"status": status.name,
|
||||
"duration": execution.duration_seconds
|
||||
})
|
||||
|
||||
return execution
|
||||
|
||||
def _check_auto_disable(self, trigger: Trigger) -> bool:
|
||||
"""
|
||||
Check if trigger should be auto-disabled based on consecutive failures.
|
||||
|
||||
Args:
|
||||
trigger: The trigger to check
|
||||
|
||||
Returns:
|
||||
True if trigger was auto-disabled, False otherwise
|
||||
"""
|
||||
if not trigger.config:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Get the appropriate config class based on trigger type
|
||||
config: BaseTriggerConfig
|
||||
if trigger.trigger_type == TriggerType.schedule:
|
||||
config = ScheduleTriggerConfig(**trigger.config)
|
||||
elif trigger.trigger_type == TriggerType.webhook:
|
||||
config = WebhookTriggerConfig(**trigger.config)
|
||||
else:
|
||||
# For other trigger types, use base config
|
||||
config = BaseTriggerConfig(**trigger.config)
|
||||
|
||||
# Check if auto-disable should happen
|
||||
if config.should_auto_disable(trigger.consecutive_failures):
|
||||
trigger.status = TriggerStatus.inactive
|
||||
trigger.auto_disabled_at = datetime.now(timezone.utc)
|
||||
|
||||
logger.warning(
|
||||
"Trigger auto-disabled due to max failures",
|
||||
extra={
|
||||
"trigger_id": trigger.id,
|
||||
"trigger_name": trigger.name,
|
||||
"consecutive_failures": trigger.consecutive_failures,
|
||||
"max_failure_count": config.max_failure_count
|
||||
}
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to check auto-disable for trigger",
|
||||
extra={
|
||||
"trigger_id": trigger.id,
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
def get_pending_executions(self) -> List[TriggerExecution]:
|
||||
"""Get all pending executions that need to be processed."""
|
||||
executions = self.session.exec(
|
||||
select(TriggerExecution).where(
|
||||
TriggerExecution.status == ExecutionStatus.pending
|
||||
).order_by(TriggerExecution.created_at)
|
||||
).all()
|
||||
|
||||
return list(executions)
|
||||
|
||||
def get_failed_executions_for_retry(self) -> List[TriggerExecution]:
|
||||
"""Get failed executions that can be retried."""
|
||||
executions = self.session.exec(
|
||||
select(TriggerExecution).where(
|
||||
and_(
|
||||
TriggerExecution.status == ExecutionStatus.failed,
|
||||
TriggerExecution.attempts < TriggerExecution.max_retries
|
||||
)
|
||||
).order_by(TriggerExecution.created_at)
|
||||
).all()
|
||||
|
||||
return list(executions)
|
||||
|
||||
def get_due_scheduled_triggers(self, limit: Optional[int] = None) -> List[Trigger]:
|
||||
"""
|
||||
Fetch scheduled triggers that are due for execution.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of triggers to fetch (defaults to SCHEDULED_FETCH_BATCH_SIZE)
|
||||
|
||||
Returns:
|
||||
List of triggers that are due for execution
|
||||
"""
|
||||
current_time = datetime.now(timezone.utc)
|
||||
limit = limit or SCHEDULED_FETCH_BATCH_SIZE
|
||||
|
||||
# Query triggers that:
|
||||
# 1. Are scheduled type
|
||||
# 2. Are active
|
||||
# 3. Have a cron expression
|
||||
# 4. next_run_at is null (never run) or next_run_at <= now
|
||||
triggers = self.session.exec(
|
||||
select(Trigger)
|
||||
.where(
|
||||
and_(
|
||||
Trigger.trigger_type == TriggerType.schedule,
|
||||
Trigger.status == TriggerStatus.active,
|
||||
Trigger.custom_cron_expression.is_not(None),
|
||||
or_(
|
||||
Trigger.next_run_at.is_(None),
|
||||
Trigger.next_run_at <= current_time
|
||||
)
|
||||
)
|
||||
)
|
||||
.limit(limit)
|
||||
).all()
|
||||
|
||||
return list(triggers)
|
||||
|
||||
def execute_scheduled_triggers(self) -> int:
|
||||
"""
|
||||
Execute all due scheduled triggers.
|
||||
Uses TriggerScheduleService for the actual execution logic.
|
||||
"""
|
||||
due_triggers = self.get_due_scheduled_triggers()
|
||||
|
||||
if not due_triggers:
|
||||
return 0
|
||||
|
||||
dispatched_count, rate_limited_count = self.schedule_service.process_schedules(due_triggers)
|
||||
|
||||
logger.info(
|
||||
"Scheduled triggers execution completed",
|
||||
extra={
|
||||
"dispatched": dispatched_count,
|
||||
"rate_limited": rate_limited_count
|
||||
}
|
||||
)
|
||||
|
||||
return dispatched_count
|
||||
|
||||
def process_slack_trigger(
|
||||
self,
|
||||
trigger: Trigger,
|
||||
slack_data: Dict[str, Any]
|
||||
) -> Optional[TriggerExecution]:
|
||||
"""Process a Slack trigger event."""
|
||||
if trigger.trigger_type != TriggerType.slack_trigger:
|
||||
raise ValueError("Trigger is not a Slack trigger")
|
||||
|
||||
if trigger.status != TriggerStatus.active:
|
||||
logger.warning("Slack trigger is not active", extra={
|
||||
"trigger_id": trigger.id
|
||||
})
|
||||
return None
|
||||
|
||||
if not check_rate_limits(self.session, trigger):
|
||||
logger.warning("Slack trigger execution skipped due to rate limits", extra={
|
||||
"trigger_id": trigger.id
|
||||
})
|
||||
return None
|
||||
|
||||
try:
|
||||
execution = self.create_execution(
|
||||
trigger=trigger,
|
||||
execution_type=ExecutionType.slack,
|
||||
input_data=slack_data
|
||||
)
|
||||
|
||||
# TODO: Queue the actual task execution
|
||||
|
||||
logger.info("Slack trigger executed", extra={
|
||||
"trigger_id": trigger.id,
|
||||
"execution_id": execution.execution_id
|
||||
})
|
||||
|
||||
return execution
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Slack trigger execution failed", extra={
|
||||
"trigger_id": trigger.id,
|
||||
"error": str(e)
|
||||
}, exc_info=True)
|
||||
return None
|
||||
|
||||
def cleanup_old_executions(self, days_to_keep: int = 30) -> int:
|
||||
"""Clean up old execution records."""
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
|
||||
|
||||
old_executions = self.session.exec(
|
||||
select(TriggerExecution).where(
|
||||
and_(
|
||||
TriggerExecution.created_at < cutoff_date,
|
||||
TriggerExecution.status.in_([
|
||||
ExecutionStatus.completed,
|
||||
ExecutionStatus.failed,
|
||||
ExecutionStatus.cancelled
|
||||
])
|
||||
)
|
||||
)
|
||||
).all()
|
||||
|
||||
count = len(old_executions)
|
||||
|
||||
for execution in old_executions:
|
||||
self.session.delete(execution)
|
||||
|
||||
self.session.commit()
|
||||
|
||||
logger.info("Old executions cleaned up", extra={
|
||||
"count": count,
|
||||
"days_to_keep": days_to_keep
|
||||
})
|
||||
|
||||
return count
|
||||
|
||||
def get_trigger_statistics(self, trigger_id: int) -> Dict[str, Any]:
|
||||
"""Get statistics for a specific trigger."""
|
||||
trigger = self.session.get(Trigger, trigger_id)
|
||||
if not trigger:
|
||||
raise ValueError("Trigger not found")
|
||||
|
||||
# Get execution counts by status
|
||||
executions = self.session.exec(
|
||||
select(TriggerExecution).where(
|
||||
TriggerExecution.trigger_id == trigger_id
|
||||
)
|
||||
).all()
|
||||
|
||||
stats = {
|
||||
"trigger_id": trigger_id,
|
||||
"name": trigger.name,
|
||||
"trigger_type": trigger.trigger_type.value,
|
||||
"status": trigger.status.name,
|
||||
"total_executions": len(executions),
|
||||
"successful_executions": len([e for e in executions if e.status == ExecutionStatus.completed]),
|
||||
"failed_executions": len([e for e in executions if e.status == ExecutionStatus.failed]),
|
||||
"pending_executions": len([e for e in executions if e.status == ExecutionStatus.pending]),
|
||||
"cancelled_executions": len([e for e in executions if e.status == ExecutionStatus.cancelled]),
|
||||
"last_executed_at": trigger.last_executed_at.isoformat() if trigger.last_executed_at else None,
|
||||
"created_at": trigger.created_at.isoformat() if trigger.created_at else None
|
||||
}
|
||||
|
||||
# Calculate average execution time for completed executions
|
||||
completed_executions = [e for e in executions if e.status == ExecutionStatus.completed and e.duration_seconds]
|
||||
if completed_executions:
|
||||
avg_duration = sum(e.duration_seconds for e in completed_executions) / len(completed_executions)
|
||||
stats["average_execution_time_seconds"] = round(avg_duration, 2)
|
||||
|
||||
# Calculate total tokens used
|
||||
total_tokens = sum(e.tokens_used for e in executions if e.tokens_used)
|
||||
if total_tokens:
|
||||
stats["total_tokens_used"] = total_tokens
|
||||
|
||||
return stats
|
||||
|
||||
def get_trigger_service(session=None) -> TriggerService:
|
||||
"""Factory function to create a TriggerService instance with a fresh session."""
|
||||
return TriggerService(session)
|
||||
|
|
@ -1,74 +0,0 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
import asyncio
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
|
||||
from app.component import code
|
||||
from app.component.environment import env_not_empty
|
||||
from app.exception.exception import UserException
|
||||
|
||||
|
||||
class StackAuth:
|
||||
_signing_key_cache = {}
|
||||
|
||||
@staticmethod
|
||||
async def user_id(token: str):
|
||||
header = jwt.get_unverified_header(token)
|
||||
kid = header.get("kid")
|
||||
if not kid:
|
||||
raise jwt.InvalidTokenError("Token is missing 'kid' in header")
|
||||
|
||||
signed = await StackAuth.stack_signing_key(kid)
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
signed.key,
|
||||
algorithms=["ES256"],
|
||||
audience=env_not_empty("stack_project_id"),
|
||||
# issuer="https://access-token.jwt-signature.stack-auth.com",
|
||||
)
|
||||
return payload["sub"]
|
||||
|
||||
@staticmethod
|
||||
async def user_info(token: str):
|
||||
headers = {
|
||||
"X-Stack-Access-Type": "server",
|
||||
"X-Stack-Project-Id": env_not_empty("stack_project_id"),
|
||||
"X-Stack-Secret-Server-Key": env_not_empty("stack_secret_server_key"),
|
||||
"X-Stack-Access-Token": token,
|
||||
}
|
||||
url = "https://api.stack-auth.com/api/v1/users/me"
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url, headers=headers)
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
async def stack_signing_key(kid: str):
|
||||
if kid in StackAuth._signing_key_cache:
|
||||
return StackAuth._signing_key_cache[kid]
|
||||
|
||||
jwks_endpoint = (
|
||||
f"https://api.stack-auth.com/api/v1/projects/{env_not_empty('stack_project_id')}/.well-known/jwks.json"
|
||||
)
|
||||
loop = asyncio.get_running_loop()
|
||||
jwks_client = jwt.PyJWKClient(jwks_endpoint)
|
||||
|
||||
try:
|
||||
signing_key = await loop.run_in_executor(None, jwks_client.get_signing_key, kid)
|
||||
StackAuth._signing_key_cache[kid] = signing_key
|
||||
return signing_key
|
||||
except jwt.exceptions.PyJWKClientError as e:
|
||||
raise UserException(code.token_invalid, str(e))
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import arrow
|
||||
|
||||
|
||||
def to_date(time: str, format: str | None = None):
|
||||
try:
|
||||
if format:
|
||||
return arrow.get(time, format).date()
|
||||
else:
|
||||
return arrow.get(time).date()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def monday_start_time() -> datetime:
|
||||
# 获取当前时间
|
||||
now = datetime.now()
|
||||
# 计算今天是本周的第几天(星期一是0,星期天是6)
|
||||
weekday = now.weekday()
|
||||
# 计算本周一的日期
|
||||
monday = now - timedelta(days=weekday)
|
||||
# 设置时间为 0 点
|
||||
return monday.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
|
@ -1,454 +0,0 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response
|
||||
from fastapi_pagination import Page
|
||||
from fastapi_pagination.ext.sqlmodel import paginate
|
||||
from app.model.trigger.trigger import Trigger
|
||||
from app.model.trigger.trigger_execution import TriggerExecution
|
||||
from sqlmodel import Session, case, desc, select, func, delete
|
||||
|
||||
from app.component.auth import Auth, auth_must
|
||||
from app.component.database import session
|
||||
from app.model.chat.chat_history import (
|
||||
ChatHistory,
|
||||
ChatHistoryIn,
|
||||
ChatHistoryOut,
|
||||
ChatHistoryUpdate,
|
||||
ChatStatus,
|
||||
)
|
||||
from app.model.chat.chat_history_grouped import (
|
||||
GroupedHistoryResponse,
|
||||
ProjectGroup,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("server_chat_history")
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["Chat History"])
|
||||
|
||||
def is_real_task(history: ChatHistory) -> bool:
|
||||
"""
|
||||
Check if a task is a real task vs a placeholder/trigger-created task.
|
||||
Excludes placeholder tasks created during trigger creation.
|
||||
"""
|
||||
# Has actual token usage
|
||||
if history.tokens and history.tokens > 0:
|
||||
return True
|
||||
|
||||
# Has real model configuration (not placeholder "none" values)
|
||||
if (history.model_platform and history.model_platform != "none" and
|
||||
history.model_type and history.model_type != "none" and
|
||||
history.installed_mcp and history.installed_mcp != "none"):
|
||||
return True
|
||||
|
||||
# Check if question starts with trigger placeholder prefix
|
||||
if history.question and history.question.startswith("Project created via trigger:"):
|
||||
return False
|
||||
|
||||
# Default to real task if no placeholder indicators
|
||||
return True
|
||||
|
||||
@router.post("/history", name="save chat history", response_model=ChatHistoryOut)
|
||||
def create_chat_history(data: ChatHistoryIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
|
||||
"""Save new chat history."""
|
||||
user_id = auth.user.id
|
||||
|
||||
try:
|
||||
data.user_id = user_id
|
||||
chat_history = ChatHistory(**data.model_dump())
|
||||
session.add(chat_history)
|
||||
session.commit()
|
||||
session.refresh(chat_history)
|
||||
logger.info(
|
||||
"Chat history created", extra={"user_id": user_id, "history_id": chat_history.id, "task_id": data.task_id}
|
||||
)
|
||||
return chat_history
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(
|
||||
"Chat history creation failed",
|
||||
extra={"user_id": user_id, "task_id": data.task_id, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/histories", name="get chat history")
|
||||
def list_chat_history(session: Session = Depends(session), auth: Auth = Depends(auth_must)) -> Page[ChatHistoryOut]:
|
||||
"""List chat histories for current user."""
|
||||
user_id = auth.user.id
|
||||
|
||||
# Order by created_at descending, but fallback to id descending for old records without timestamps
|
||||
# This ensures newer records with timestamps come first, followed by old records ordered by id
|
||||
stmt = (
|
||||
select(ChatHistory)
|
||||
.where(ChatHistory.user_id == user_id)
|
||||
.order_by(
|
||||
desc(case((ChatHistory.created_at.is_(None), 0), else_=1)), # Non-null created_at first
|
||||
desc(ChatHistory.created_at), # Then by created_at descending
|
||||
desc(ChatHistory.id), # Finally by id descending for records with same/null created_at
|
||||
)
|
||||
)
|
||||
|
||||
result = paginate(session, stmt)
|
||||
total = result.total if hasattr(result, "total") else 0
|
||||
logger.debug("Chat histories listed", extra={"user_id": user_id, "total": total})
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/histories/grouped", name="get grouped chat history")
|
||||
def list_grouped_chat_history(
|
||||
include_tasks: Optional[bool] = Query(True, description="Whether to include individual tasks in groups"),
|
||||
session: Session = Depends(session),
|
||||
auth: Auth = Depends(auth_must)
|
||||
) -> GroupedHistoryResponse:
|
||||
"""List chat histories grouped by project_id for current user."""
|
||||
user_id = auth.user.id
|
||||
|
||||
# Get all histories for the user, ordered by creation time
|
||||
stmt = (
|
||||
select(ChatHistory)
|
||||
.where(ChatHistory.user_id == user_id)
|
||||
.order_by(
|
||||
desc(case((ChatHistory.created_at.is_(None), 0), else_=1)), # Non-null created_at first
|
||||
desc(ChatHistory.created_at), # Then by created_at descending
|
||||
desc(ChatHistory.id) # Finally by id descending for records with same/null created_at
|
||||
)
|
||||
)
|
||||
|
||||
histories = session.exec(stmt).all()
|
||||
|
||||
# Get trigger counts per project
|
||||
trigger_count_stmt = (
|
||||
select(Trigger.project_id, func.count(Trigger.id).label('count'))
|
||||
.where(Trigger.user_id == str(user_id))
|
||||
.group_by(Trigger.project_id)
|
||||
)
|
||||
trigger_counts = session.exec(trigger_count_stmt).all()
|
||||
trigger_count_map = {project_id: count for project_id, count in trigger_counts}
|
||||
|
||||
# Group histories by project_id
|
||||
project_map = defaultdict(lambda: {
|
||||
'project_id': '',
|
||||
'project_name': None,
|
||||
'total_tokens': 0,
|
||||
'task_count': 0,
|
||||
'latest_task_date': '',
|
||||
'last_prompt': None,
|
||||
'tasks': [],
|
||||
'total_completed_tasks': 0,
|
||||
'total_ongoing_tasks': 0,
|
||||
'average_tokens_per_task': 0,
|
||||
'total_triggers': 0
|
||||
})
|
||||
|
||||
for history in histories:
|
||||
# Use project_id if available, fallback to task_id
|
||||
project_id = history.project_id if history.project_id else history.task_id
|
||||
project_data = project_map[project_id]
|
||||
|
||||
# Initialize project data
|
||||
if not project_data['project_id']:
|
||||
project_data['project_id'] = project_id
|
||||
project_data['project_name'] = history.project_name or f"Project {project_id}"
|
||||
project_data['latest_task_date'] = history.created_at.isoformat() if history.created_at else ''
|
||||
project_data['last_prompt'] = history.question # Set the most recent question
|
||||
|
||||
# Convert to ChatHistoryOut format
|
||||
history_out = ChatHistoryOut(**history.model_dump())
|
||||
|
||||
# Add task to project if requested (only real tasks)
|
||||
if include_tasks and is_real_task(history):
|
||||
project_data['tasks'].append(history_out)
|
||||
|
||||
# Update project statistics (only for real tasks)
|
||||
if is_real_task(history):
|
||||
project_data['task_count'] += 1
|
||||
project_data['total_tokens'] += history.tokens or 0
|
||||
|
||||
if history.status == ChatStatus.done:
|
||||
project_data['total_completed_tasks'] += 1
|
||||
elif history.status == ChatStatus.ongoing:
|
||||
project_data['total_ongoing_tasks'] += 1
|
||||
|
||||
# Update latest task date and last prompt
|
||||
if history.created_at:
|
||||
task_date = history.created_at.isoformat()
|
||||
if not project_data['latest_task_date'] or task_date > project_data['latest_task_date']:
|
||||
project_data['latest_task_date'] = task_date
|
||||
project_data['last_prompt'] = history.question
|
||||
|
||||
# Convert to ProjectGroup objects and sort
|
||||
projects = []
|
||||
for project_data in project_map.values():
|
||||
# Sort tasks within each project by creation date (oldest first)
|
||||
if include_tasks:
|
||||
project_data['tasks'].sort(key=lambda x: (x.created_at is None, x.created_at or ''), reverse=False)
|
||||
|
||||
# Set trigger count from trigger_count_map
|
||||
project_id = project_data['project_id']
|
||||
project_data['total_triggers'] = trigger_count_map.get(project_id, 0)
|
||||
|
||||
project_group = ProjectGroup(**project_data)
|
||||
projects.append(project_group)
|
||||
|
||||
# Sort projects by latest task date (newest first)
|
||||
projects.sort(key=lambda x: x.latest_task_date, reverse=True)
|
||||
|
||||
response = GroupedHistoryResponse(projects=projects)
|
||||
|
||||
logger.debug("Grouped chat histories listed", extra={
|
||||
"user_id": user_id,
|
||||
"total_projects": response.total_projects,
|
||||
"total_tasks": response.total_tasks,
|
||||
"include_tasks": include_tasks
|
||||
})
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/histories/grouped/{project_id}", name="get single grouped project")
|
||||
def get_grouped_project(
|
||||
project_id: str,
|
||||
include_tasks: Optional[bool] = Query(True, description="Whether to include individual tasks in the project"),
|
||||
session: Session = Depends(session),
|
||||
auth: Auth = Depends(auth_must)
|
||||
) -> ProjectGroup:
|
||||
"""Get a single project group by project_id for current user."""
|
||||
user_id = auth.user.id
|
||||
|
||||
# Get all histories for the specific project
|
||||
stmt = (
|
||||
select(ChatHistory)
|
||||
.where(ChatHistory.user_id == user_id)
|
||||
.where(ChatHistory.project_id == project_id)
|
||||
.order_by(
|
||||
desc(case((ChatHistory.created_at.is_(None), 0), else_=1)),
|
||||
desc(ChatHistory.created_at),
|
||||
desc(ChatHistory.id)
|
||||
)
|
||||
)
|
||||
|
||||
histories = session.exec(stmt).all()
|
||||
|
||||
if not histories:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
# Get trigger count for this project
|
||||
trigger_count_stmt = (
|
||||
select(func.count(Trigger.id))
|
||||
.where(Trigger.user_id == str(user_id))
|
||||
.where(Trigger.project_id == project_id)
|
||||
)
|
||||
trigger_count = session.exec(trigger_count_stmt).first() or 0
|
||||
|
||||
# Build project data
|
||||
project_data = {
|
||||
'project_id': project_id,
|
||||
'project_name': None,
|
||||
'total_tokens': 0,
|
||||
'task_count': 0,
|
||||
'latest_task_date': '',
|
||||
'last_prompt': None,
|
||||
'tasks': [],
|
||||
'total_completed_tasks': 0,
|
||||
'total_ongoing_tasks': 0,
|
||||
'average_tokens_per_task': 0,
|
||||
'total_triggers': trigger_count
|
||||
}
|
||||
|
||||
for history in histories:
|
||||
# Initialize project name from first history
|
||||
if not project_data['project_name']:
|
||||
project_data['project_name'] = history.project_name or f"Project {project_id}"
|
||||
project_data['latest_task_date'] = history.created_at.isoformat() if history.created_at else ''
|
||||
project_data['last_prompt'] = history.question
|
||||
|
||||
# Convert to ChatHistoryOut format
|
||||
history_out = ChatHistoryOut(**history.model_dump())
|
||||
|
||||
# Add task to project if requested (only real tasks)
|
||||
if include_tasks and is_real_task(history):
|
||||
project_data['tasks'].append(history_out)
|
||||
|
||||
# Update project statistics (only for real tasks)
|
||||
if is_real_task(history):
|
||||
project_data['task_count'] += 1
|
||||
project_data['total_tokens'] += history.tokens or 0
|
||||
|
||||
if history.status == ChatStatus.done:
|
||||
project_data['total_completed_tasks'] += 1
|
||||
elif history.status == ChatStatus.ongoing:
|
||||
project_data['total_ongoing_tasks'] += 1
|
||||
|
||||
# Update latest task date and last prompt
|
||||
if history.created_at:
|
||||
task_date = history.created_at.isoformat()
|
||||
if not project_data['latest_task_date'] or task_date > project_data['latest_task_date']:
|
||||
project_data['latest_task_date'] = task_date
|
||||
project_data['last_prompt'] = history.question
|
||||
|
||||
# Sort tasks within the project by creation date (oldest first)
|
||||
if include_tasks:
|
||||
project_data['tasks'].sort(key=lambda x: (x.created_at is None, x.created_at or ''), reverse=False)
|
||||
|
||||
project_group = ProjectGroup(**project_data)
|
||||
|
||||
logger.debug("Single grouped project retrieved", extra={
|
||||
"user_id": user_id,
|
||||
"project_id": project_id,
|
||||
"task_count": project_group.task_count,
|
||||
"include_tasks": include_tasks
|
||||
})
|
||||
|
||||
return project_group
|
||||
|
||||
|
||||
@router.delete("/history/{history_id}", name="delete chat history")
|
||||
def delete_chat_history(history_id: str, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
|
||||
"""Delete chat history."""
|
||||
user_id = auth.user.id
|
||||
history = session.exec(select(ChatHistory).where(ChatHistory.id == history_id)).first()
|
||||
|
||||
if not history:
|
||||
logger.warning("Chat history not found for deletion", extra={"user_id": user_id, "history_id": history_id})
|
||||
raise HTTPException(status_code=404, detail="Chat History not found")
|
||||
|
||||
if history.user_id != user_id:
|
||||
logger.warning(
|
||||
"Unauthorized deletion attempt",
|
||||
extra={"user_id": user_id, "history_id": history_id, "owner_id": history.user_id},
|
||||
)
|
||||
raise HTTPException(status_code=403, detail="You are not allowed to delete this chat history")
|
||||
|
||||
try:
|
||||
# Determine the project this history belongs to
|
||||
project_id = history.project_id if history.project_id else history.task_id
|
||||
|
||||
# Check if this is the last history in the project
|
||||
sibling_count = (
|
||||
session.exec(
|
||||
select(func.count(ChatHistory.id)).where(
|
||||
ChatHistory.id != history_id,
|
||||
ChatHistory.project_id == project_id if history.project_id else ChatHistory.task_id == project_id,
|
||||
)
|
||||
).first()
|
||||
or 0
|
||||
)
|
||||
|
||||
session.delete(history)
|
||||
|
||||
if sibling_count == 0:
|
||||
# Last history in the project — delete all related triggers
|
||||
triggers = session.exec(select(Trigger).where(Trigger.project_id == project_id)).all()
|
||||
for trigger in triggers:
|
||||
session.exec(delete(TriggerExecution).where(TriggerExecution.trigger_id == trigger.id))
|
||||
session.delete(trigger)
|
||||
logger.info(
|
||||
"Deleted triggers for removed project", extra={"project_id": project_id, "trigger_count": len(triggers)}
|
||||
)
|
||||
|
||||
session.commit()
|
||||
logger.info("Chat history deleted", extra={"user_id": user_id, "history_id": history_id})
|
||||
return Response(status_code=204)
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(
|
||||
"Chat history deletion failed",
|
||||
extra={"user_id": user_id, "history_id": history_id, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.put("/history/{history_id}", name="update chat history", response_model=ChatHistoryOut)
|
||||
def update_chat_history(
|
||||
history_id: int, data: ChatHistoryUpdate, session: Session = Depends(session), auth: Auth = Depends(auth_must)
|
||||
):
|
||||
"""Update chat history."""
|
||||
user_id = auth.user.id
|
||||
history = session.exec(select(ChatHistory).where(ChatHistory.id == history_id)).first()
|
||||
|
||||
if not history:
|
||||
logger.warning("Chat history not found for update", extra={"user_id": user_id, "history_id": history_id})
|
||||
raise HTTPException(status_code=404, detail="Chat History not found")
|
||||
|
||||
if history.user_id != user_id:
|
||||
logger.warning(
|
||||
"Unauthorized update attempt",
|
||||
extra={"user_id": user_id, "history_id": history_id, "owner_id": history.user_id},
|
||||
)
|
||||
raise HTTPException(status_code=403, detail="You are not allowed to update this chat history")
|
||||
|
||||
try:
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
history.update_fields(update_data)
|
||||
history.save(session)
|
||||
session.refresh(history)
|
||||
logger.info(
|
||||
"Chat history updated",
|
||||
extra={"user_id": user_id, "history_id": history_id, "fields_updated": list(update_data.keys())},
|
||||
)
|
||||
return history
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Chat history update failed",
|
||||
extra={"user_id": user_id, "history_id": history_id, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.put("/project/{project_id}/name", name="update project name")
|
||||
def update_project_name(
|
||||
project_id: str, new_name: str, session: Session = Depends(session), auth: Auth = Depends(auth_must)
|
||||
):
|
||||
"""Update project name for all tasks in a project."""
|
||||
user_id = auth.user.id
|
||||
|
||||
# Get all histories for this project
|
||||
stmt = select(ChatHistory).where(ChatHistory.project_id == project_id).where(ChatHistory.user_id == user_id)
|
||||
|
||||
histories = session.exec(stmt).all()
|
||||
|
||||
if not histories:
|
||||
logger.warning("No histories found for project", extra={"user_id": user_id, "project_id": project_id})
|
||||
raise HTTPException(status_code=404, detail="Project not found or access denied")
|
||||
|
||||
try:
|
||||
# Update all histories for this project
|
||||
for history in histories:
|
||||
history.project_name = new_name
|
||||
session.add(history)
|
||||
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
"Project name updated",
|
||||
extra={"user_id": user_id, "project_id": project_id, "new_name": new_name, "updated_count": len(histories)},
|
||||
)
|
||||
|
||||
return Response(status_code=200)
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(
|
||||
"Project name update failed",
|
||||
extra={"user_id": user_id, "project_id": project_id, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
|
@ -1,136 +0,0 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from itsdangerous import BadTimeSignature, SignatureExpired
|
||||
from sqlmodel import Session, asc, select
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from app.component.auth import Auth, auth_must
|
||||
from app.component.database import session
|
||||
from app.model.chat.chat_history import ChatHistory
|
||||
from app.model.chat.chat_share import (
|
||||
ChatHistoryShareOut,
|
||||
ChatShare,
|
||||
ChatShareIn,
|
||||
)
|
||||
from app.model.chat.chat_step import ChatStep
|
||||
|
||||
logger = logging.getLogger("server_chat_share")
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["Chat Share"])
|
||||
|
||||
|
||||
@router.get("/share/info/{token}", name="Get shared chat info", response_model=ChatHistoryShareOut)
|
||||
def get_share_info(token: str, session: Session = Depends(session)):
|
||||
"""
|
||||
Get shared chat history info by token, excluding sensitive data.
|
||||
"""
|
||||
try:
|
||||
task_id = ChatShare.verify_token(token, False)
|
||||
except SignatureExpired:
|
||||
logger.warning("Shared chat access failed: token expired", extra={"token_prefix": token[:10]})
|
||||
raise HTTPException(status_code=400, detail="Share link is invalid or has expired.")
|
||||
except BadTimeSignature:
|
||||
logger.warning("Shared chat access failed: invalid token", extra={"token_prefix": token[:10]})
|
||||
raise HTTPException(status_code=400, detail="Share link is invalid or has expired.")
|
||||
|
||||
stmt = select(ChatHistory).where(ChatHistory.task_id == task_id)
|
||||
history = session.exec(stmt).one_or_none()
|
||||
|
||||
if not history:
|
||||
logger.warning("Shared chat not found", extra={"task_id": task_id})
|
||||
raise HTTPException(status_code=404, detail="Chat history not found.")
|
||||
|
||||
logger.info("Shared chat info accessed", extra={"task_id": task_id})
|
||||
return history
|
||||
|
||||
|
||||
@router.get("/share/playback/{token}", name="Playback shared chat via SSE")
|
||||
async def share_playback(token: str, session: Session = Depends(session), delay_time: float = 0):
|
||||
"""
|
||||
Playbacks the chat history via a sharing token (SSE).
|
||||
delay_time: control sse interval, max 5 seconds
|
||||
"""
|
||||
if delay_time > 5:
|
||||
logger.debug("Delay time capped", extra={"requested": delay_time, "capped": 5})
|
||||
delay_time = 5
|
||||
|
||||
try:
|
||||
task_id = ChatShare.verify_token(token, False)
|
||||
except SignatureExpired:
|
||||
logger.warning("Shared chat playback failed: token expired", extra={"token_prefix": token[:10]})
|
||||
raise HTTPException(status_code=400, detail="Share link has expired.")
|
||||
except BadTimeSignature:
|
||||
logger.warning("Shared chat playback failed: invalid token", extra={"token_prefix": token[:10]})
|
||||
raise HTTPException(status_code=400, detail="Share link is invalid.")
|
||||
|
||||
async def event_generator():
|
||||
try:
|
||||
stmt = select(ChatStep).where(ChatStep.task_id == task_id).order_by(asc(ChatStep.id))
|
||||
steps = session.exec(stmt).all()
|
||||
|
||||
if not steps:
|
||||
logger.warning("No steps found for playback", extra={"task_id": task_id})
|
||||
yield f"data: {json.dumps({'error': 'No steps found for this task.'})}\n\n"
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Shared chat playback started",
|
||||
extra={"task_id": task_id, "step_count": len(steps), "delay_time": delay_time},
|
||||
)
|
||||
|
||||
for idx, step in enumerate(steps, start=1):
|
||||
step_data = {
|
||||
"id": step.id,
|
||||
"task_id": step.task_id,
|
||||
"step": step.step,
|
||||
"data": step.data,
|
||||
"created_at": step.created_at.isoformat() if step.created_at else None,
|
||||
}
|
||||
yield f"data: {json.dumps(step_data)}\n\n"
|
||||
|
||||
if delay_time > 0 and step.step != "create_agent":
|
||||
await asyncio.sleep(delay_time)
|
||||
|
||||
logger.info("Shared chat playback completed", extra={"task_id": task_id, "step_count": len(steps)})
|
||||
except Exception as e:
|
||||
logger.error("Shared chat playback error", extra={"task_id": task_id, "error": str(e)}, exc_info=True)
|
||||
yield f"data: {json.dumps({'error': 'Playback error occurred.'})}\n\n"
|
||||
|
||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/share", name="Generate sharable link for a task(1 day expiration)")
|
||||
def create_share_link(data: ChatShareIn, auth: Auth = Depends(auth_must)):
|
||||
"""Generate sharing token with 1-day expiration for task."""
|
||||
user_id = auth.user.id
|
||||
try:
|
||||
share_token = ChatShare.generate_token(data.task_id)
|
||||
logger.info(
|
||||
"Share link created",
|
||||
extra={"user_id": user_id, "task_id": data.task_id, "token_prefix": share_token[:10]},
|
||||
)
|
||||
return {"share_token": share_token}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Share link creation failed",
|
||||
extra={"user_id": user_id, "task_id": data.task_id, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
|
@ -1,195 +0,0 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response
|
||||
from fastapi_babel import _
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from app.component.auth import Auth, auth_must
|
||||
from app.component.database import session
|
||||
from app.model.chat.chat_snpshot import ChatSnapshot, ChatSnapshotIn
|
||||
|
||||
logger = logging.getLogger("server_chat_snapshot")
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["Chat Snapshot Management"])
|
||||
|
||||
|
||||
@router.get("/snapshots", name="list chat snapshots", response_model=list[ChatSnapshot])
|
||||
async def list_chat_snapshots(
|
||||
api_task_id: str | None = None,
|
||||
camel_task_id: str | None = None,
|
||||
browser_url: str | None = None,
|
||||
session: Session = Depends(session),
|
||||
auth: Auth = Depends(auth_must),
|
||||
):
|
||||
"""List chat snapshots with optional filtering."""
|
||||
user_id = auth.user.id
|
||||
query = select(ChatSnapshot).where(ChatSnapshot.user_id == user_id)
|
||||
if api_task_id is not None:
|
||||
query = query.where(ChatSnapshot.api_task_id == api_task_id)
|
||||
if camel_task_id is not None:
|
||||
query = query.where(ChatSnapshot.camel_task_id == camel_task_id)
|
||||
if browser_url is not None:
|
||||
query = query.where(ChatSnapshot.browser_url == browser_url)
|
||||
|
||||
snapshots = session.exec(query).all()
|
||||
logger.debug(
|
||||
"Snapshots listed",
|
||||
extra={"user_id": user_id, "api_task_id": api_task_id, "camel_task_id": camel_task_id, "count": len(snapshots)},
|
||||
)
|
||||
return snapshots
|
||||
|
||||
|
||||
@router.get("/snapshots/{snapshot_id}", name="get chat snapshot", response_model=ChatSnapshot)
|
||||
async def get_chat_snapshot(snapshot_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
|
||||
"""Get specific chat snapshot."""
|
||||
user_id = auth.user.id
|
||||
snapshot = session.get(ChatSnapshot, snapshot_id)
|
||||
|
||||
if not snapshot:
|
||||
logger.warning("Snapshot not found", extra={"user_id": user_id, "snapshot_id": snapshot_id})
|
||||
raise HTTPException(status_code=404, detail=_("Chat snapshot not found"))
|
||||
|
||||
if snapshot.user_id != user_id:
|
||||
logger.warning(
|
||||
"Unauthorized snapshot access",
|
||||
extra={"user_id": user_id, "snapshot_id": snapshot_id, "owner_id": snapshot.user_id},
|
||||
)
|
||||
raise HTTPException(status_code=403, detail=_("You are not allowed to view this snapshot"))
|
||||
|
||||
logger.debug(
|
||||
"Snapshot retrieved",
|
||||
extra={"user_id": user_id, "snapshot_id": snapshot_id, "api_task_id": snapshot.api_task_id},
|
||||
)
|
||||
return snapshot
|
||||
|
||||
|
||||
@router.post("/snapshots", name="create chat snapshot", response_model=ChatSnapshot)
|
||||
async def create_chat_snapshot(
|
||||
snapshot: ChatSnapshotIn, auth: Auth = Depends(auth_must), session: Session = Depends(session)
|
||||
):
|
||||
"""Create new chat snapshot from image."""
|
||||
user_id = auth.user.id
|
||||
|
||||
try:
|
||||
image_path = ChatSnapshotIn.save_image(user_id, snapshot.api_task_id, snapshot.image_base64)
|
||||
chat_snapshot = ChatSnapshot(
|
||||
user_id=user_id,
|
||||
api_task_id=snapshot.api_task_id,
|
||||
camel_task_id=snapshot.camel_task_id,
|
||||
browser_url=snapshot.browser_url,
|
||||
image_path=image_path,
|
||||
)
|
||||
session.add(chat_snapshot)
|
||||
session.commit()
|
||||
session.refresh(chat_snapshot)
|
||||
logger.info(
|
||||
"Snapshot created",
|
||||
extra={
|
||||
"user_id": user_id,
|
||||
"snapshot_id": chat_snapshot.id,
|
||||
"api_task_id": snapshot.api_task_id,
|
||||
"image_path": image_path,
|
||||
},
|
||||
)
|
||||
return chat_snapshot
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(
|
||||
"Snapshot creation failed",
|
||||
extra={"user_id": user_id, "api_task_id": snapshot.api_task_id, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.put("/snapshots/{snapshot_id}", name="update chat snapshot", response_model=ChatSnapshot)
|
||||
async def update_chat_snapshot(
|
||||
snapshot_id: int,
|
||||
snapshot_update: ChatSnapshot,
|
||||
session: Session = Depends(session),
|
||||
auth: Auth = Depends(auth_must),
|
||||
):
|
||||
"""Update chat snapshot."""
|
||||
user_id = auth.user.id
|
||||
db_snapshot = session.get(ChatSnapshot, snapshot_id)
|
||||
|
||||
if not db_snapshot:
|
||||
logger.warning("Snapshot not found for update", extra={"user_id": user_id, "snapshot_id": snapshot_id})
|
||||
raise HTTPException(status_code=404, detail=_("Chat snapshot not found"))
|
||||
|
||||
if db_snapshot.user_id != user_id:
|
||||
logger.warning(
|
||||
"Unauthorized snapshot update",
|
||||
extra={"user_id": user_id, "snapshot_id": snapshot_id, "owner_id": db_snapshot.user_id},
|
||||
)
|
||||
raise HTTPException(status_code=403, detail=_("You are not allowed to update this snapshot"))
|
||||
|
||||
try:
|
||||
update_data = snapshot_update.dict(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(db_snapshot, key, value)
|
||||
session.add(db_snapshot)
|
||||
session.commit()
|
||||
session.refresh(db_snapshot)
|
||||
logger.info(
|
||||
"Snapshot updated",
|
||||
extra={"user_id": user_id, "snapshot_id": snapshot_id, "fields_updated": list(update_data.keys())},
|
||||
)
|
||||
return db_snapshot
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(
|
||||
"Snapshot update failed",
|
||||
extra={"user_id": user_id, "snapshot_id": snapshot_id, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.delete("/snapshots/{snapshot_id}", name="delete chat snapshot")
|
||||
async def delete_chat_snapshot(snapshot_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
|
||||
"""Delete chat snapshot."""
|
||||
user_id = auth.user.id
|
||||
db_snapshot = session.get(ChatSnapshot, snapshot_id)
|
||||
|
||||
if not db_snapshot:
|
||||
logger.warning("Snapshot not found for deletion", extra={"user_id": user_id, "snapshot_id": snapshot_id})
|
||||
raise HTTPException(status_code=404, detail=_("Chat snapshot not found"))
|
||||
|
||||
if db_snapshot.user_id != user_id:
|
||||
logger.warning(
|
||||
"Unauthorized snapshot deletion",
|
||||
extra={"user_id": user_id, "snapshot_id": snapshot_id, "owner_id": db_snapshot.user_id},
|
||||
)
|
||||
raise HTTPException(status_code=403, detail=_("You are not allowed to delete this snapshot"))
|
||||
|
||||
try:
|
||||
session.delete(db_snapshot)
|
||||
session.commit()
|
||||
logger.info(
|
||||
"Snapshot deleted",
|
||||
extra={"user_id": user_id, "snapshot_id": snapshot_id, "image_path": db_snapshot.image_path},
|
||||
)
|
||||
return Response(status_code=204)
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(
|
||||
"Snapshot deletion failed",
|
||||
extra={"user_id": user_id, "snapshot_id": snapshot_id, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
|
@ -1,207 +0,0 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi_babel import _
|
||||
from sqlalchemy.sql.expression import case
|
||||
from sqlmodel import Session, asc, select
|
||||
|
||||
from app.component.auth import Auth, auth_must
|
||||
from app.component.database import session
|
||||
from app.model.chat.chat_step import ChatStep, ChatStepIn, ChatStepOut
|
||||
|
||||
logger = logging.getLogger("server_chat_step")
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["Chat Step Management"])
|
||||
|
||||
|
||||
@router.get("/steps", name="list chat steps", response_model=list[ChatStepOut])
|
||||
async def list_chat_steps(
|
||||
task_id: str, step: str | None = None, session: Session = Depends(session), auth: Auth = Depends(auth_must)
|
||||
):
|
||||
"""List chat steps for a task with optional step type filtering."""
|
||||
user_id = auth.user.id
|
||||
query = select(ChatStep)
|
||||
if task_id is not None:
|
||||
query = query.where(ChatStep.task_id == task_id)
|
||||
if step is not None:
|
||||
query = query.where(ChatStep.step == step)
|
||||
|
||||
chat_steps = session.exec(query).all()
|
||||
logger.debug(
|
||||
"Chat steps listed", extra={"user_id": user_id, "task_id": task_id, "step_type": step, "count": len(chat_steps)}
|
||||
)
|
||||
return chat_steps
|
||||
|
||||
|
||||
@router.get("/steps/playback/{task_id}", name="Playback Chat Step via SSE")
|
||||
async def share_playback(
|
||||
task_id: str, delay_time: float = 0, session: Session = Depends(session), auth: Auth = Depends(auth_must)
|
||||
):
|
||||
"""Playback chat steps via SSE stream."""
|
||||
user_id = auth.user.id
|
||||
if delay_time > 5:
|
||||
logger.debug(
|
||||
"Delay time capped", extra={"user_id": user_id, "task_id": task_id, "requested": delay_time, "capped": 5}
|
||||
)
|
||||
delay_time = 5
|
||||
|
||||
async def event_generator():
|
||||
try:
|
||||
stmt = (
|
||||
select(ChatStep)
|
||||
.where(ChatStep.task_id == task_id)
|
||||
.order_by(
|
||||
asc(case((ChatStep.timestamp.is_(None), 1), else_=0)), asc(ChatStep.timestamp), asc(ChatStep.id)
|
||||
)
|
||||
)
|
||||
steps = session.exec(stmt).all()
|
||||
|
||||
if not steps:
|
||||
logger.warning("No steps found for playback", extra={"user_id": user_id, "task_id": task_id})
|
||||
yield f"data: {json.dumps({'error': 'No steps found for this task.'})}\n\n"
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Chat step playback started",
|
||||
extra={"user_id": user_id, "task_id": task_id, "step_count": len(steps), "delay_time": delay_time},
|
||||
)
|
||||
|
||||
for step in steps:
|
||||
step_data = {
|
||||
"id": step.id,
|
||||
"task_id": step.task_id,
|
||||
"step": step.step,
|
||||
"data": step.data,
|
||||
"created_at": step.created_at.isoformat() if step.created_at else None,
|
||||
}
|
||||
yield f"data: {json.dumps(step_data)}\n\n"
|
||||
if delay_time > 0:
|
||||
await asyncio.sleep(delay_time)
|
||||
|
||||
logger.info(
|
||||
"Chat step playback completed", extra={"user_id": user_id, "task_id": task_id, "step_count": len(steps)}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Chat step playback error",
|
||||
extra={"user_id": user_id, "task_id": task_id, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
yield f"data: {json.dumps({'error': 'Playback error occurred.'})}\n\n"
|
||||
|
||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.get("/steps/{step_id}", name="get chat step", response_model=ChatStepOut)
|
||||
async def get_chat_step(step_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
|
||||
"""Get specific chat step."""
|
||||
user_id = auth.user.id
|
||||
chat_step = session.get(ChatStep, step_id)
|
||||
|
||||
if not chat_step:
|
||||
logger.warning("Chat step not found", extra={"user_id": user_id, "step_id": step_id})
|
||||
raise HTTPException(status_code=404, detail=_("Chat step not found"))
|
||||
|
||||
logger.debug("Chat step retrieved", extra={"user_id": user_id, "step_id": step_id, "task_id": chat_step.task_id})
|
||||
return chat_step
|
||||
|
||||
|
||||
@router.post("/steps", name="create chat step")
|
||||
async def create_chat_step(step: ChatStepIn, session: Session = Depends(session)):
|
||||
"""Create new chat step. TODO: Implement request source validation."""
|
||||
try:
|
||||
chat_step = ChatStep(task_id=step.task_id, step=step.step, data=step.data, timestamp=step.timestamp)
|
||||
session.add(chat_step)
|
||||
session.commit()
|
||||
session.refresh(chat_step)
|
||||
logger.info(
|
||||
"Chat step created", extra={"step_id": chat_step.id, "task_id": step.task_id, "step_type": step.step}
|
||||
)
|
||||
return {"code": 200, "msg": "success"}
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(
|
||||
"Chat step creation failed",
|
||||
extra={"task_id": step.task_id, "step_type": step.step, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.put("/steps/{step_id}", name="update chat step", response_model=ChatStepOut)
|
||||
async def update_chat_step(
|
||||
step_id: int, chat_step_update: ChatStep, session: Session = Depends(session), auth: Auth = Depends(auth_must)
|
||||
):
|
||||
"""Update chat step."""
|
||||
user_id = auth.user.id
|
||||
db_chat_step = session.get(ChatStep, step_id)
|
||||
|
||||
if not db_chat_step:
|
||||
logger.warning("Chat step not found for update", extra={"user_id": user_id, "step_id": step_id})
|
||||
raise HTTPException(status_code=404, detail=_("Chat step not found"))
|
||||
|
||||
try:
|
||||
update_data = chat_step_update.dict(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(db_chat_step, key, value)
|
||||
session.add(db_chat_step)
|
||||
session.commit()
|
||||
session.refresh(db_chat_step)
|
||||
logger.info(
|
||||
"Chat step updated",
|
||||
extra={
|
||||
"user_id": user_id,
|
||||
"step_id": step_id,
|
||||
"task_id": db_chat_step.task_id,
|
||||
"fields_updated": list(update_data.keys()),
|
||||
},
|
||||
)
|
||||
return db_chat_step
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(
|
||||
"Chat step update failed", extra={"user_id": user_id, "step_id": step_id, "error": str(e)}, exc_info=True
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.delete("/steps/{step_id}", name="delete chat step")
|
||||
async def delete_chat_step(step_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
|
||||
"""Delete chat step."""
|
||||
user_id = auth.user.id
|
||||
db_chat_step = session.get(ChatStep, step_id)
|
||||
|
||||
if not db_chat_step:
|
||||
logger.warning("Chat step not found for deletion", extra={"user_id": user_id, "step_id": step_id})
|
||||
raise HTTPException(status_code=404, detail=_("Chat step not found"))
|
||||
|
||||
try:
|
||||
session.delete(db_chat_step)
|
||||
session.commit()
|
||||
logger.info(
|
||||
"Chat step deleted", extra={"user_id": user_id, "step_id": step_id, "task_id": db_chat_step.task_id}
|
||||
)
|
||||
return Response(status_code=204)
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(
|
||||
"Chat step deletion failed", extra={"user_id": user_id, "step_id": step_id, "error": str(e)}, exc_info=True
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
|
@ -1,221 +0,0 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response
|
||||
from fastapi_babel import _
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from app.component.auth import Auth, auth_must
|
||||
from app.component.database import session
|
||||
from app.model.config.config import (
|
||||
Config,
|
||||
ConfigCreate,
|
||||
ConfigInfo,
|
||||
ConfigOut,
|
||||
ConfigUpdate,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("server_config_controller")
|
||||
|
||||
router = APIRouter(tags=["Config Management"])
|
||||
|
||||
|
||||
@router.get("/configs", name="list configs", response_model=list[ConfigOut])
|
||||
async def list_configs(
|
||||
config_group: str | None = None, session: Session = Depends(session), auth: Auth = Depends(auth_must)
|
||||
):
|
||||
"""List user's configurations with optional group filtering."""
|
||||
user_id = auth.user.id
|
||||
query = select(Config).where(Config.user_id == user_id)
|
||||
|
||||
if config_group is not None:
|
||||
query = query.where(Config.config_group == config_group)
|
||||
|
||||
configs = session.exec(query).all()
|
||||
logger.debug("Configs listed", extra={"user_id": user_id, "config_group": config_group, "count": len(configs)})
|
||||
return configs
|
||||
|
||||
|
||||
@router.get("/configs/{config_id}", name="get config", response_model=ConfigOut)
|
||||
async def get_config(
|
||||
config_id: int,
|
||||
session: Session = Depends(session),
|
||||
auth: Auth = Depends(auth_must),
|
||||
):
|
||||
query = select(Config).where(Config.user_id == auth.user.id)
|
||||
|
||||
if config_id is not None:
|
||||
query = query.where(Config.id == config_id)
|
||||
|
||||
config = session.exec(query).first()
|
||||
|
||||
if not config:
|
||||
logger.warning("Config not found")
|
||||
raise HTTPException(status_code=404, detail=_("Configuration not found"))
|
||||
|
||||
logger.debug("Config retrieved")
|
||||
return config
|
||||
|
||||
|
||||
@router.post("/configs", name="create config", response_model=ConfigOut)
|
||||
async def create_config(config: ConfigCreate, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
|
||||
"""Create new configuration."""
|
||||
user_id = auth.user.id
|
||||
|
||||
if not ConfigInfo.is_valid_env_var(config.config_group, config.config_name):
|
||||
logger.warning(
|
||||
"Config validation failed",
|
||||
extra={"user_id": user_id, "config_group": config.config_group, "config_name": config.config_name},
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=_("Invalid config name or group"))
|
||||
|
||||
# Check if configuration already exists
|
||||
existing_config = session.exec(
|
||||
select(Config).where(Config.user_id == user_id, Config.config_name == config.config_name)
|
||||
).first()
|
||||
|
||||
if existing_config:
|
||||
logger.warning(
|
||||
"Config creation failed: already exists", extra={"user_id": user_id, "config_name": config.config_name}
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=_("Configuration already exists for this user"))
|
||||
|
||||
try:
|
||||
db_config = Config(
|
||||
user_id=user_id,
|
||||
config_name=config.config_name,
|
||||
config_value=config.config_value,
|
||||
config_group=config.config_group,
|
||||
)
|
||||
session.add(db_config)
|
||||
session.commit()
|
||||
session.refresh(db_config)
|
||||
logger.info(
|
||||
"Config created",
|
||||
extra={
|
||||
"user_id": user_id,
|
||||
"config_id": db_config.id,
|
||||
"config_group": config.config_group,
|
||||
"config_name": config.config_name,
|
||||
},
|
||||
)
|
||||
return db_config
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(
|
||||
"Config creation failed",
|
||||
extra={"user_id": user_id, "config_name": config.config_name, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.put("/configs/{config_id}", name="update config", response_model=ConfigOut)
|
||||
async def update_config(
|
||||
config_id: int, config_update: ConfigUpdate, session: Session = Depends(session), auth: Auth = Depends(auth_must)
|
||||
):
|
||||
"""Update configuration."""
|
||||
user_id = auth.user.id
|
||||
db_config = session.exec(select(Config).where(Config.id == config_id, Config.user_id == user_id)).first()
|
||||
|
||||
if not db_config:
|
||||
logger.warning("Config not found for update", extra={"user_id": user_id, "config_id": config_id})
|
||||
raise HTTPException(status_code=404, detail=_("Configuration not found"))
|
||||
|
||||
# Check if configuration group is valid
|
||||
if not ConfigInfo.is_valid_env_var(config_update.config_group, config_update.config_name):
|
||||
logger.warning(
|
||||
"Config update validation failed",
|
||||
extra={"user_id": user_id, "config_id": config_id, "config_group": config_update.config_group},
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=_("Invalid configuration group"))
|
||||
|
||||
# Check for conflicts with other configurations
|
||||
existing_config = session.exec(
|
||||
select(Config).where(
|
||||
Config.user_id == user_id,
|
||||
Config.config_name == config_update.config_name,
|
||||
Config.id != config_id,
|
||||
)
|
||||
).first()
|
||||
|
||||
if existing_config:
|
||||
logger.warning(
|
||||
"Config update failed: duplicate name",
|
||||
extra={"user_id": user_id, "config_id": config_id, "config_name": config_update.config_name},
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=_("Configuration already exists for this user"))
|
||||
|
||||
try:
|
||||
db_config.config_name = config_update.config_name
|
||||
db_config.config_value = config_update.config_value
|
||||
db_config.config_group = config_update.config_group
|
||||
session.add(db_config)
|
||||
session.commit()
|
||||
session.refresh(db_config)
|
||||
logger.info(
|
||||
"Config updated",
|
||||
extra={"user_id": user_id, "config_id": config_id, "config_group": config_update.config_group},
|
||||
)
|
||||
return db_config
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(
|
||||
"Config update failed", extra={"user_id": user_id, "config_id": config_id, "error": str(e)}, exc_info=True
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.delete("/configs/{config_id}", name="delete config")
|
||||
async def delete_config(config_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
|
||||
"""Delete configuration."""
|
||||
user_id = auth.user.id
|
||||
db_config = session.exec(select(Config).where(Config.id == config_id, Config.user_id == user_id)).first()
|
||||
|
||||
if not db_config:
|
||||
logger.warning("Config not found for deletion", extra={"user_id": user_id, "config_id": config_id})
|
||||
raise HTTPException(status_code=404, detail=_("Configuration not found"))
|
||||
|
||||
try:
|
||||
session.delete(db_config)
|
||||
session.commit()
|
||||
logger.info(
|
||||
"Config deleted", extra={"user_id": user_id, "config_id": config_id, "config_name": db_config.config_name}
|
||||
)
|
||||
return Response(status_code=204)
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(
|
||||
"Config deletion failed", extra={"user_id": user_id, "config_id": config_id, "error": str(e)}, exc_info=True
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/config/info", name="get config info")
|
||||
async def get_config_info(
|
||||
show_all: bool = Query(False, description="Show all config info, including those with empty env_vars"),
|
||||
):
|
||||
"""Get available configuration templates and info."""
|
||||
configs = ConfigInfo.getinfo()
|
||||
if show_all:
|
||||
logger.debug("Config info retrieved", extra={"show_all": True, "count": len(configs)})
|
||||
return configs
|
||||
|
||||
filtered = {k: v for k, v in configs.items() if v.get("env_vars") and len(v["env_vars"]) > 0}
|
||||
logger.debug(
|
||||
"Config info retrieved", extra={"show_all": False, "total_count": len(configs), "filtered_count": len(filtered)}
|
||||
)
|
||||
return filtered
|
||||
|
|
@ -1,294 +0,0 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from camel.toolkits.mcp_toolkit import MCPToolkit
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi_babel import _
|
||||
from fastapi_pagination import Page
|
||||
from fastapi_pagination.ext.sqlmodel import paginate
|
||||
from sqlalchemy.orm import selectinload, with_loader_criteria
|
||||
from sqlmodel import Session, col, select
|
||||
|
||||
from app.component.auth import Auth, auth_must
|
||||
from app.component.database import session
|
||||
from app.component.environment import env
|
||||
from app.model.mcp.mcp import Mcp, McpOut, McpType
|
||||
from app.model.mcp.mcp_env import McpEnv, Status as McpEnvStatus
|
||||
from app.model.mcp.mcp_user import McpImportType, McpUser, Status
|
||||
|
||||
logger = logging.getLogger("server_mcp_controller")
|
||||
|
||||
from app.component.validator.McpServer import (
|
||||
McpRemoteServer,
|
||||
McpServerItem,
|
||||
validate_mcp_remote_servers,
|
||||
validate_mcp_servers,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["Mcp Servers"])
|
||||
|
||||
|
||||
async def pre_instantiate_mcp_toolkit(config_dict: dict) -> bool:
|
||||
"""
|
||||
Pre-instantiate MCP toolkit to complete authentication process
|
||||
|
||||
Args:
|
||||
config_dict: MCP server configuration dictionary
|
||||
|
||||
Returns:
|
||||
bool: Whether successfully instantiated and connected
|
||||
"""
|
||||
try:
|
||||
# Ensure unified auth directory for all mcp servers
|
||||
for server_config in config_dict.get("mcpServers", {}).values():
|
||||
if "env" not in server_config:
|
||||
server_config["env"] = {}
|
||||
# Set global auth directory to persist authentication across tasks
|
||||
if "MCP_REMOTE_CONFIG_DIR" not in server_config["env"]:
|
||||
server_config["env"]["MCP_REMOTE_CONFIG_DIR"] = env(
|
||||
"MCP_REMOTE_CONFIG_DIR", os.path.expanduser("~/.mcp-auth")
|
||||
)
|
||||
|
||||
# Create MCP toolkit and attempt to connect
|
||||
mcp_toolkit = MCPToolkit(config_dict=config_dict, timeout=30)
|
||||
await mcp_toolkit.connect()
|
||||
|
||||
# Get tools list to ensure connection is successful
|
||||
tools = mcp_toolkit.get_tools()
|
||||
logger.info("MCP toolkit pre-instantiated", extra={"tools_count": len(tools)})
|
||||
|
||||
# Disconnect, authentication info is already saved
|
||||
await mcp_toolkit.disconnect()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("MCP toolkit pre-instantiation failed", extra={"error": str(e)}, exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
@router.get("/mcps", name="mcp list")
|
||||
async def gets(
|
||||
keyword: str | None = None,
|
||||
category_id: int | None = None,
|
||||
mine: int | None = None,
|
||||
session: Session = Depends(session),
|
||||
auth: Auth = Depends(auth_must),
|
||||
) -> Page[McpOut]:
|
||||
"""List MCP servers with optional filtering."""
|
||||
user_id = auth.user.id
|
||||
stmt = (
|
||||
select(Mcp)
|
||||
.where(Mcp.no_delete())
|
||||
.options(
|
||||
selectinload(Mcp.category),
|
||||
selectinload(Mcp.envs),
|
||||
with_loader_criteria(McpEnv, col(McpEnv.status) == McpEnvStatus.in_use),
|
||||
)
|
||||
)
|
||||
if keyword:
|
||||
stmt = stmt.where(col(Mcp.key).like(f"%{keyword.lower()}%"))
|
||||
if category_id:
|
||||
stmt = stmt.where(Mcp.category_id == category_id)
|
||||
if mine and auth:
|
||||
stmt = (
|
||||
stmt.join(McpUser)
|
||||
.where(McpUser.user_id == user_id)
|
||||
.options(
|
||||
selectinload(Mcp.mcp_user),
|
||||
with_loader_criteria(McpUser, col(McpUser.user_id) == user_id),
|
||||
)
|
||||
)
|
||||
|
||||
result = paginate(session, stmt)
|
||||
total = result.total if hasattr(result, "total") else 0
|
||||
logger.debug(
|
||||
"MCP list retrieved",
|
||||
extra={"user_id": user_id, "keyword": keyword, "category_id": category_id, "mine": mine, "total": total},
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/mcp", name="mcp detail", response_model=McpOut)
|
||||
async def get(id: int, session: Session = Depends(session)):
|
||||
"""Get MCP server details."""
|
||||
try:
|
||||
stmt = (
|
||||
select(Mcp).where(Mcp.no_delete(), Mcp.id == id).options(selectinload(Mcp.category), selectinload(Mcp.envs))
|
||||
)
|
||||
model = session.exec(stmt).one()
|
||||
logger.debug("MCP detail retrieved", extra={"mcp_id": id, "mcp_key": model.key})
|
||||
return model
|
||||
except Exception:
|
||||
logger.warning("MCP not found", extra={"mcp_id": id})
|
||||
raise HTTPException(status_code=404, detail=_("Mcp not found"))
|
||||
|
||||
|
||||
@router.post("/mcp/install", name="mcp install")
|
||||
async def install(mcp_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
|
||||
"""Install MCP server for user."""
|
||||
user_id = auth.user.id
|
||||
|
||||
mcp = session.get_one(Mcp, mcp_id)
|
||||
if not mcp:
|
||||
logger.warning("MCP install failed: MCP not found", extra={"user_id": user_id, "mcp_id": mcp_id})
|
||||
raise HTTPException(status_code=404, detail=_("Mcp not found"))
|
||||
|
||||
exists = session.exec(select(McpUser).where(McpUser.mcp_id == mcp.id, McpUser.user_id == user_id)).first()
|
||||
if exists:
|
||||
logger.warning(
|
||||
"MCP install failed: already installed", extra={"user_id": user_id, "mcp_id": mcp_id, "mcp_key": mcp.key}
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=_("mcp is installed"))
|
||||
|
||||
install_command: dict = mcp.install_command
|
||||
|
||||
# Pre-instantiate MCP toolkit for authentication
|
||||
config_dict = {"mcpServers": {mcp.key: install_command}}
|
||||
|
||||
try:
|
||||
success = await pre_instantiate_mcp_toolkit(config_dict)
|
||||
if not success:
|
||||
logger.warning(
|
||||
"MCP pre-instantiation failed, continuing with installation",
|
||||
extra={"user_id": user_id, "mcp_id": mcp_id, "mcp_key": mcp.key},
|
||||
)
|
||||
else:
|
||||
logger.debug("MCP toolkit pre-instantiated", extra={"mcp_key": mcp.key})
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"MCP pre-instantiation exception",
|
||||
extra={"user_id": user_id, "mcp_key": mcp.key, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
try:
|
||||
mcp_user = McpUser(
|
||||
mcp_id=mcp.id,
|
||||
user_id=user_id,
|
||||
mcp_name=mcp.name,
|
||||
mcp_key=mcp.key,
|
||||
mcp_desc=mcp.description,
|
||||
type=mcp.type,
|
||||
status=Status.enable,
|
||||
command=install_command["command"],
|
||||
args=install_command["args"],
|
||||
env=install_command["env"],
|
||||
server_url=None,
|
||||
)
|
||||
mcp_user.save()
|
||||
logger.info("MCP installed", extra={"user_id": user_id, "mcp_id": mcp_id, "mcp_key": mcp.key})
|
||||
return mcp_user
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"MCP installation failed",
|
||||
extra={"user_id": user_id, "mcp_id": mcp_id, "mcp_key": mcp.key, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.post("/mcp/import/{mcp_type}", name="mcp import")
|
||||
async def import_mcp(
|
||||
mcp_type: McpImportType, mcp_data: dict, session: Session = Depends(session), auth: Auth = Depends(auth_must)
|
||||
):
|
||||
"""Import MCP servers (local or remote)."""
|
||||
user_id = auth.user.id
|
||||
|
||||
if mcp_type == McpImportType.Local:
|
||||
logger.info("Importing local MCP servers", extra={"user_id": user_id})
|
||||
is_valid, res = validate_mcp_servers(mcp_data)
|
||||
if not is_valid:
|
||||
logger.warning("Local MCP import validation failed", extra={"user_id": user_id, "error": res})
|
||||
raise HTTPException(status_code=400, detail=res)
|
||||
|
||||
mcp_data: dict[str, McpServerItem] = res.mcpServers
|
||||
imported_count = 0
|
||||
|
||||
for name, data in mcp_data.items():
|
||||
config_dict = {"mcpServers": {name: {"command": data.command, "args": data.args, "env": data.env or {}}}}
|
||||
|
||||
try:
|
||||
success = await pre_instantiate_mcp_toolkit(config_dict)
|
||||
if not success:
|
||||
logger.warning(
|
||||
"Local MCP pre-instantiation failed, continuing", extra={"user_id": user_id, "mcp_name": name}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Local MCP pre-instantiation exception",
|
||||
extra={"user_id": user_id, "mcp_name": name, "error": str(e)},
|
||||
)
|
||||
|
||||
try:
|
||||
mcp_user = McpUser(
|
||||
mcp_id=0,
|
||||
user_id=user_id,
|
||||
mcp_name=name,
|
||||
mcp_key=name,
|
||||
mcp_desc=name,
|
||||
type=McpType.Local,
|
||||
status=Status.enable,
|
||||
command=data.command,
|
||||
args=data.args,
|
||||
env=data.env,
|
||||
server_url=None,
|
||||
)
|
||||
mcp_user.save()
|
||||
imported_count += 1
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to import local MCP",
|
||||
extra={"user_id": user_id, "mcp_name": name, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
logger.info("Local MCPs imported", extra={"user_id": user_id, "count": imported_count})
|
||||
return {"message": "Local MCP servers imported successfully", "count": imported_count}
|
||||
|
||||
elif mcp_type == McpImportType.Remote:
|
||||
logger.info("Importing remote MCP server", extra={"user_id": user_id})
|
||||
is_valid, res = validate_mcp_remote_servers(mcp_data)
|
||||
if not is_valid:
|
||||
logger.warning("Remote MCP import validation failed", extra={"user_id": user_id, "error": res})
|
||||
raise HTTPException(status_code=400, detail=res)
|
||||
|
||||
data: McpRemoteServer = res
|
||||
|
||||
try:
|
||||
# For remote servers, we don't need to pre-instantiate as they typically don't require authentication
|
||||
# but we can still try to validate the connection if needed
|
||||
mcp_user = McpUser(
|
||||
mcp_id=0,
|
||||
user_id=user_id,
|
||||
type=McpType.Remote,
|
||||
status=Status.enable,
|
||||
mcp_name=data.server_name,
|
||||
server_url=data.server_url,
|
||||
)
|
||||
mcp_user.save()
|
||||
logger.info(
|
||||
"Remote MCP imported",
|
||||
extra={"user_id": user_id, "server_name": data.server_name, "server_url": data.server_url},
|
||||
)
|
||||
return mcp_user
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Remote MCP import failed",
|
||||
extra={"user_id": user_id, "server_name": data.server_name, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
|
@ -1,210 +0,0 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from camel.toolkits.mcp_toolkit import MCPToolkit
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response
|
||||
from fastapi_babel import _
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from app.component.auth import Auth, auth_must
|
||||
from app.component.database import session
|
||||
from app.component.environment import env
|
||||
from app.model.mcp.mcp import Mcp
|
||||
from app.model.mcp.mcp_user import (
|
||||
McpUser,
|
||||
McpUserIn,
|
||||
McpUserOut,
|
||||
McpUserUpdate,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("server_mcp_user_controller")
|
||||
|
||||
router = APIRouter(tags=["McpUser Management"])
|
||||
|
||||
|
||||
async def pre_instantiate_mcp_toolkit(config_dict: dict) -> bool:
|
||||
"""
|
||||
Pre-instantiate MCP toolkit to complete authentication process
|
||||
|
||||
Args:
|
||||
config_dict: MCP server configuration dictionary
|
||||
|
||||
Returns:
|
||||
bool: Whether successfully instantiated and connected
|
||||
"""
|
||||
try:
|
||||
# Ensure unified auth directory for all mcp servers
|
||||
for server_config in config_dict.get("mcpServers", {}).values():
|
||||
if "env" not in server_config:
|
||||
server_config["env"] = {}
|
||||
# Set global auth directory to persist authentication across tasks
|
||||
if "MCP_REMOTE_CONFIG_DIR" not in server_config["env"]:
|
||||
server_config["env"]["MCP_REMOTE_CONFIG_DIR"] = env(
|
||||
"MCP_REMOTE_CONFIG_DIR", os.path.expanduser("~/.mcp-auth")
|
||||
)
|
||||
|
||||
# Create MCP toolkit and attempt to connect
|
||||
mcp_toolkit = MCPToolkit(config_dict=config_dict, timeout=30)
|
||||
await mcp_toolkit.connect()
|
||||
|
||||
# Get tools list to ensure connection is successful
|
||||
tools = mcp_toolkit.get_tools()
|
||||
logger.info("MCP toolkit pre-instantiated", extra={"tools_count": len(tools)})
|
||||
|
||||
# Disconnect, authentication info is already saved
|
||||
await mcp_toolkit.disconnect()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("MCP toolkit pre-instantiation failed", extra={"error": str(e)}, exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
@router.get("/mcp/users", name="list mcp users", response_model=list[McpUserOut])
|
||||
async def list_mcp_users(
|
||||
mcp_id: int | None = None,
|
||||
session: Session = Depends(session),
|
||||
auth: Auth = Depends(auth_must),
|
||||
):
|
||||
"""List MCP users for current user."""
|
||||
user_id = auth.user.id
|
||||
query = select(McpUser)
|
||||
if mcp_id is not None:
|
||||
query = query.where(McpUser.mcp_id == mcp_id)
|
||||
if user_id is not None:
|
||||
query = query.where(McpUser.user_id == user_id)
|
||||
mcp_users = session.exec(query).all()
|
||||
logger.debug("MCP users listed", extra={"user_id": user_id, "mcp_id": mcp_id, "count": len(mcp_users)})
|
||||
return mcp_users
|
||||
|
||||
|
||||
@router.get("/mcp/users/{mcp_user_id}", name="get mcp user", response_model=McpUserOut)
|
||||
async def get_mcp_user(mcp_user_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
|
||||
"""Get MCP user details."""
|
||||
query = select(McpUser).where(McpUser.id == mcp_user_id)
|
||||
mcp_user = session.exec(query).first()
|
||||
if not mcp_user:
|
||||
logger.warning("MCP user not found", extra={"user_id": auth.user.id, "mcp_user_id": mcp_user_id})
|
||||
raise HTTPException(status_code=404, detail=_("McpUser not found"))
|
||||
logger.debug(
|
||||
"MCP user retrieved", extra={"user_id": auth.user.id, "mcp_user_id": mcp_user_id, "mcp_id": mcp_user.mcp_id}
|
||||
)
|
||||
return mcp_user
|
||||
|
||||
|
||||
@router.post("/mcp/users", name="create mcp user", response_model=McpUserOut)
|
||||
async def create_mcp_user(mcp_user: McpUserIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
|
||||
"""Create MCP user installation."""
|
||||
user_id = auth.user.id
|
||||
mcp_id = mcp_user.mcp_id
|
||||
|
||||
exists = session.exec(select(McpUser).where(McpUser.mcp_id == mcp_id, McpUser.user_id == user_id)).first()
|
||||
if exists:
|
||||
logger.warning("MCP already installed", extra={"user_id": user_id, "mcp_id": mcp_id})
|
||||
raise HTTPException(status_code=400, detail=_("mcp is installed"))
|
||||
|
||||
# Get MCP configuration from the main Mcp table
|
||||
mcp = session.get(Mcp, mcp_id)
|
||||
if mcp and mcp.install_command:
|
||||
config_dict = {"mcpServers": {mcp.key: mcp.install_command}}
|
||||
|
||||
try:
|
||||
success = await pre_instantiate_mcp_toolkit(config_dict)
|
||||
if not success:
|
||||
logger.warning(
|
||||
"MCP pre-instantiation failed, continuing",
|
||||
extra={"user_id": user_id, "mcp_id": mcp_id, "mcp_key": mcp.key},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"MCP pre-instantiation exception",
|
||||
extra={"user_id": user_id, "mcp_id": mcp_id, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
try:
|
||||
db_mcp_user = McpUser(mcp_id=mcp_id, user_id=user_id, env=mcp_user.env)
|
||||
session.add(db_mcp_user)
|
||||
session.commit()
|
||||
session.refresh(db_mcp_user)
|
||||
logger.info("MCP user created", extra={"user_id": user_id, "mcp_id": mcp_id, "mcp_user_id": db_mcp_user.id})
|
||||
return db_mcp_user
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(
|
||||
"MCP user creation failed", extra={"user_id": user_id, "mcp_id": mcp_id, "error": str(e)}, exc_info=True
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.put("/mcp/users/{id}", name="update mcp user")
|
||||
async def update_mcp_user(
|
||||
id: int,
|
||||
update_item: McpUserUpdate,
|
||||
session: Session = Depends(session),
|
||||
auth: Auth = Depends(auth_must),
|
||||
):
|
||||
"""Update MCP user settings."""
|
||||
user_id = auth.user.id
|
||||
model = session.get(McpUser, id)
|
||||
if not model:
|
||||
logger.warning("MCP user not found for update", extra={"user_id": user_id, "mcp_user_id": id})
|
||||
raise HTTPException(status_code=404, detail=_("Mcp Info not found"))
|
||||
if model.user_id != user_id:
|
||||
logger.warning(
|
||||
"Unauthorized MCP user update", extra={"user_id": user_id, "mcp_user_id": id, "owner_id": model.user_id}
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=_("current user have no permission to modify"))
|
||||
|
||||
try:
|
||||
update_data = update_item.model_dump(exclude_unset=True)
|
||||
model.update_fields(update_data)
|
||||
model.save(session)
|
||||
session.refresh(model)
|
||||
logger.info("MCP user updated", extra={"user_id": user_id, "mcp_user_id": id})
|
||||
return model
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"MCP user update failed", extra={"user_id": user_id, "mcp_user_id": id, "error": str(e)}, exc_info=True
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.delete("/mcp/users/{mcp_user_id}", name="delete mcp user")
|
||||
async def delete_mcp_user(mcp_user_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
|
||||
"""Delete MCP user installation."""
|
||||
user_id = auth.user.id
|
||||
db_mcp_user = session.get(McpUser, mcp_user_id)
|
||||
if not db_mcp_user:
|
||||
logger.warning("MCP user not found for deletion", extra={"user_id": user_id, "mcp_user_id": mcp_user_id})
|
||||
raise HTTPException(status_code=404, detail=_("Mcp Info not found"))
|
||||
|
||||
try:
|
||||
session.delete(db_mcp_user)
|
||||
session.commit()
|
||||
logger.info(
|
||||
"MCP user deleted", extra={"user_id": user_id, "mcp_user_id": mcp_user_id, "mcp_id": db_mcp_user.mcp_id}
|
||||
)
|
||||
return Response(status_code=204)
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(
|
||||
"MCP user deletion failed",
|
||||
extra={"user_id": user_id, "mcp_user_id": mcp_user_id, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
|
@ -1,121 +0,0 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
import logging
|
||||
import re
|
||||
from urllib.parse import quote, urlencode
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
|
||||
|
||||
from app.component.oauth_adapter import OauthCallbackPayload, get_oauth_adapter
|
||||
|
||||
# Allowed OAuth provider names (alphanumeric and hyphens only)
|
||||
ALLOWED_PROVIDER_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+$")
|
||||
|
||||
logger = logging.getLogger("server_oauth_controller")
|
||||
|
||||
router = APIRouter(prefix="/oauth", tags=["Oauth Servers"])
|
||||
|
||||
|
||||
@router.get("/{app}/login", name="OAuth Login Redirect")
|
||||
def oauth_login(app: str, request: Request, state: str | None = None):
|
||||
"""Redirect user to OAuth provider's authorization endpoint."""
|
||||
try:
|
||||
callback_url = str(request.url_for("OAuth Callback", app=app))
|
||||
if callback_url.startswith("http://"):
|
||||
callback_url = "https://" + callback_url[len("http://") :]
|
||||
|
||||
adapter = get_oauth_adapter(app, callback_url)
|
||||
url = adapter.get_authorize_url(state)
|
||||
|
||||
if not url:
|
||||
logger.error("Failed to generate authorization URL", extra={"provider": app, "callback_url": callback_url})
|
||||
raise HTTPException(status_code=400, detail="Failed to generate authorization URL")
|
||||
|
||||
logger.info("OAuth login initiated", extra={"provider": app})
|
||||
return RedirectResponse(str(url))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("OAuth login failed", extra={"provider": app, "error": str(e)}, exc_info=True)
|
||||
raise HTTPException(status_code=400, detail="OAuth login failed")
|
||||
|
||||
|
||||
@router.get("/{app}/callback", name="OAuth Callback")
|
||||
def oauth_callback(
|
||||
app: str,
|
||||
request: Request,
|
||||
code: str | None = None,
|
||||
state: str | None = None,
|
||||
):
|
||||
"""Handle OAuth provider callback and redirect to client app."""
|
||||
# Security: Validate provider name to prevent injection
|
||||
if not ALLOWED_PROVIDER_PATTERN.match(app):
|
||||
logger.warning(
|
||||
"OAuth callback invalid provider name",
|
||||
extra={"provider": app[:50]}, # Truncate for logging
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Invalid provider")
|
||||
|
||||
if not code:
|
||||
logger.warning("OAuth callback missing code", extra={"provider": app})
|
||||
raise HTTPException(status_code=400, detail="Missing code parameter")
|
||||
|
||||
logger.info(
|
||||
"OAuth callback received",
|
||||
extra={"provider": app, "has_state": state is not None},
|
||||
)
|
||||
|
||||
# Security: URL-encode all parameters to prevent XSS
|
||||
params = {"provider": app, "code": code}
|
||||
if state is not None:
|
||||
params["state"] = state
|
||||
redirect_url = f"eigent://callback/oauth?{urlencode(params)}"
|
||||
|
||||
# Security: Use a safe redirect approach without embedding in JavaScript
|
||||
# The redirect URL uses a custom protocol, so we encode it safely
|
||||
safe_redirect_url = quote(redirect_url, safe=":/&?=")
|
||||
|
||||
html_content = f"""<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<title>OAuth Callback</title>
|
||||
<meta http-equiv="refresh" content="0;url={safe_redirect_url}">
|
||||
</head>
|
||||
<body>
|
||||
<p>Redirecting, please wait...</p>
|
||||
<p>If you are not redirected, <a href="{safe_redirect_url}">click here</a>.</p>
|
||||
<button onclick="window.close()">Close this window</button>
|
||||
</body>
|
||||
</html>"""
|
||||
return HTMLResponse(content=html_content)
|
||||
|
||||
|
||||
@router.post("/{app}/token", name="OAuth Fetch Token")
|
||||
def fetch_token(app: str, request: Request, data: OauthCallbackPayload):
|
||||
"""Exchange authorization code for access token."""
|
||||
try:
|
||||
callback_url = str(request.url_for("OAuth Callback", app=app))
|
||||
if callback_url.startswith("http://"):
|
||||
callback_url = "https://" + callback_url[len("http://") :]
|
||||
|
||||
adapter = get_oauth_adapter(app, callback_url)
|
||||
token_data = adapter.fetch_token(data.code)
|
||||
logger.info("OAuth token fetched", extra={"provider": app})
|
||||
return JSONResponse(token_data)
|
||||
except Exception as e:
|
||||
logger.error("OAuth token fetch failed", extra={"provider": app, "error": str(e)}, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
|
@ -1,165 +0,0 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response
|
||||
from fastapi_babel import _
|
||||
from fastapi_pagination import Page
|
||||
from fastapi_pagination.ext.sqlmodel import paginate
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlmodel import Session, col, select
|
||||
|
||||
from app.component.auth import Auth, auth_must
|
||||
from app.component.database import session
|
||||
from app.model.provider.provider import (
|
||||
Provider,
|
||||
ProviderIn,
|
||||
ProviderOut,
|
||||
ProviderPreferIn,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("server_provider_controller")
|
||||
|
||||
router = APIRouter(tags=["Provider Management"])
|
||||
|
||||
|
||||
@router.get("/providers", name="list providers", response_model=Page[ProviderOut])
|
||||
async def gets(
|
||||
keyword: str | None = None,
|
||||
prefer: bool | None = Query(None, description="Filter by prefer status"),
|
||||
session: Session = Depends(session),
|
||||
auth: Auth = Depends(auth_must),
|
||||
) -> Page[ProviderOut]:
|
||||
"""List user's providers with optional filtering."""
|
||||
user_id = auth.user.id
|
||||
stmt = select(Provider).where(Provider.user_id == user_id, Provider.no_delete())
|
||||
if keyword:
|
||||
stmt = stmt.where(col(Provider.provider_name).like(f"%{keyword}%"))
|
||||
if prefer is not None:
|
||||
stmt = stmt.where(Provider.prefer == prefer)
|
||||
stmt = stmt.order_by(col(Provider.created_at).desc(), col(Provider.id).desc())
|
||||
logger.debug("Providers listed", extra={"user_id": user_id, "keyword": keyword, "prefer_filter": prefer})
|
||||
return paginate(session, stmt)
|
||||
|
||||
|
||||
@router.get("/provider", name="get provider detail", response_model=ProviderOut)
|
||||
async def get(id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
|
||||
"""Get provider details."""
|
||||
user_id = auth.user.id
|
||||
stmt = select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == id)
|
||||
model = session.exec(stmt).one_or_none()
|
||||
if not model:
|
||||
logger.warning("Provider not found", extra={"user_id": user_id, "provider_id": id})
|
||||
raise HTTPException(status_code=404, detail=_("Provider not found"))
|
||||
logger.debug("Provider retrieved", extra={"user_id": user_id, "provider_id": id})
|
||||
return model
|
||||
|
||||
|
||||
@router.post("/provider", name="create provider", response_model=ProviderOut)
|
||||
async def post(data: ProviderIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
|
||||
"""Create a new provider."""
|
||||
user_id = auth.user.id
|
||||
try:
|
||||
model = Provider(**data.model_dump(), user_id=user_id)
|
||||
model.save(session)
|
||||
logger.info(
|
||||
"Provider created", extra={"user_id": user_id, "provider_id": model.id, "provider_name": data.provider_name}
|
||||
)
|
||||
return model
|
||||
except Exception as e:
|
||||
logger.error("Provider creation failed", extra={"user_id": user_id, "error": str(e)}, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.put("/provider/{id}", name="update provider", response_model=ProviderOut)
|
||||
async def put(id: int, data: ProviderIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
|
||||
"""Update provider details."""
|
||||
user_id = auth.user.id
|
||||
model = session.exec(
|
||||
select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == id)
|
||||
).one_or_none()
|
||||
if not model:
|
||||
logger.warning("Provider not found for update", extra={"user_id": user_id, "provider_id": id})
|
||||
raise HTTPException(status_code=404, detail=_("Provider not found"))
|
||||
|
||||
try:
|
||||
model.model_type = data.model_type
|
||||
model.provider_name = data.provider_name
|
||||
model.api_key = data.api_key
|
||||
model.endpoint_url = data.endpoint_url
|
||||
model.encrypted_config = data.encrypted_config
|
||||
model.is_vaild = data.is_vaild
|
||||
model.save(session)
|
||||
session.refresh(model)
|
||||
logger.info(
|
||||
"Provider updated", extra={"user_id": user_id, "provider_id": id, "provider_name": data.provider_name}
|
||||
)
|
||||
return model
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Provider update failed", extra={"user_id": user_id, "provider_id": id, "error": str(e)}, exc_info=True
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.delete("/provider/{id}", name="delete provider")
|
||||
async def delete(id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
|
||||
"""Delete a provider."""
|
||||
user_id = auth.user.id
|
||||
model = session.exec(
|
||||
select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == id)
|
||||
).one_or_none()
|
||||
if not model:
|
||||
logger.warning("Provider not found for deletion", extra={"user_id": user_id, "provider_id": id})
|
||||
raise HTTPException(status_code=404, detail=_("Provider not found"))
|
||||
|
||||
try:
|
||||
model.delete(session)
|
||||
logger.info("Provider deleted", extra={"user_id": user_id, "provider_id": id})
|
||||
return Response(status_code=204)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Provider deletion failed", extra={"user_id": user_id, "provider_id": id, "error": str(e)}, exc_info=True
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.post("/provider/prefer", name="set provider prefer")
|
||||
async def set_prefer(data: ProviderPreferIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
|
||||
"""Set preferred provider for user."""
|
||||
user_id = auth.user.id
|
||||
provider_id = data.provider_id
|
||||
|
||||
try:
|
||||
# 1. Set all current user's providers prefer to false
|
||||
session.exec(update(Provider).where(Provider.user_id == user_id, Provider.no_delete()).values(prefer=False))
|
||||
# 2. Set the prefer of the specified provider_id to true
|
||||
session.exec(
|
||||
update(Provider)
|
||||
.where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == provider_id)
|
||||
.values(prefer=True)
|
||||
)
|
||||
session.commit()
|
||||
logger.info("Preferred provider set", extra={"user_id": user_id, "provider_id": provider_id})
|
||||
return {"success": True}
|
||||
except SQLAlchemyError as e:
|
||||
session.rollback()
|
||||
logger.error(
|
||||
"Failed to set preferred provider",
|
||||
extra={"user_id": user_id, "provider_id": provider_id, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
|
@ -1,135 +0,0 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlmodel import Session, select, and_
|
||||
from typing import Optional, List
|
||||
import logging
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.model.config.config import Config
|
||||
from app.type.config_group import ConfigGroup
|
||||
from app.component.auth import Auth, auth_must
|
||||
from app.component.database import session
|
||||
|
||||
logger = logging.getLogger("server_slack_controller")
|
||||
|
||||
|
||||
class SlackChannelOut(BaseModel):
|
||||
"""Output model for Slack channels."""
|
||||
id: str
|
||||
name: str
|
||||
is_private: bool = False
|
||||
is_member: bool = False
|
||||
num_members: Optional[int] = None
|
||||
|
||||
|
||||
class SlackChannelsResponse(BaseModel):
|
||||
"""Response model for Slack channels list."""
|
||||
channels: List[SlackChannelOut]
|
||||
has_credentials: bool
|
||||
|
||||
|
||||
router = APIRouter(prefix="/trigger/slack", tags=["Slack Integration"])
|
||||
|
||||
|
||||
@router.get("/channels", name="get slack channels")
|
||||
def get_slack_channels(
|
||||
session: Session = Depends(session),
|
||||
auth: Auth = Depends(auth_must)
|
||||
) -> SlackChannelsResponse:
|
||||
"""
|
||||
Get list of Slack channels for the authenticated user.
|
||||
|
||||
This endpoint fetches channels from the user's Slack workspace using their
|
||||
stored credentials. Requires SLACK_BOT_TOKEN to be configured in user configs.
|
||||
"""
|
||||
user_id = auth.user.id
|
||||
|
||||
# Get Slack credentials from config
|
||||
configs = session.exec(
|
||||
select(Config).where(
|
||||
and_(
|
||||
Config.user_id == int(user_id),
|
||||
Config.config_group == ConfigGroup.SLACK.value
|
||||
)
|
||||
)
|
||||
).all()
|
||||
|
||||
credentials = {config.config_name: config.config_value for config in configs}
|
||||
bot_token = credentials.get("SLACK_BOT_TOKEN")
|
||||
|
||||
if not bot_token:
|
||||
logger.warning("Slack credentials not found", extra={"user_id": user_id})
|
||||
return SlackChannelsResponse(channels=[], has_credentials=False)
|
||||
|
||||
try:
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
client = WebClient(token=bot_token)
|
||||
|
||||
# Fetch all channels (public and private the bot has access to)
|
||||
channels = []
|
||||
cursor = None
|
||||
|
||||
while True:
|
||||
response = client.conversations_list(
|
||||
types="public_channel,private_channel",
|
||||
cursor=cursor,
|
||||
limit=200
|
||||
)
|
||||
|
||||
for channel in response.get("channels", []):
|
||||
channels.append(SlackChannelOut(
|
||||
id=channel.get("id"),
|
||||
name=channel.get("name"),
|
||||
is_private=channel.get("is_private", False),
|
||||
is_member=channel.get("is_member", False),
|
||||
num_members=channel.get("num_members")
|
||||
))
|
||||
|
||||
# Check for pagination
|
||||
cursor = response.get("response_metadata", {}).get("next_cursor")
|
||||
if not cursor:
|
||||
break
|
||||
|
||||
logger.info("Slack channels fetched", extra={
|
||||
"user_id": user_id,
|
||||
"channel_count": len(channels)
|
||||
})
|
||||
|
||||
return SlackChannelsResponse(channels=channels, has_credentials=True)
|
||||
|
||||
except ImportError:
|
||||
logger.error("slack_sdk not installed")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Slack SDK not installed on server"
|
||||
)
|
||||
except SlackApiError as e:
|
||||
logger.error("Slack API error", extra={
|
||||
"user_id": user_id,
|
||||
"error": str(e)
|
||||
})
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Slack API error: {e.response.get('error', 'Unknown error')}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error fetching Slack channels", extra={
|
||||
"user_id": user_id,
|
||||
"error": str(e)
|
||||
}, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Failed to fetch Slack channels")
|
||||
|
|
@ -1,731 +0,0 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, Query
|
||||
from fastapi_pagination import Page
|
||||
from fastapi_pagination.ext.sqlmodel import paginate
|
||||
from sqlmodel import Session, select, desc, and_, delete
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
import logging
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.model.trigger.trigger import Trigger, TriggerIn, TriggerOut, TriggerUpdate, TriggerConfigSchemaOut
|
||||
from app.model.trigger.trigger_execution import TriggerExecution, TriggerExecutionOut
|
||||
from app.model.trigger.app_configs import (
|
||||
get_config_schema,
|
||||
validate_config,
|
||||
has_config,
|
||||
validate_activation,
|
||||
ActivationError,
|
||||
)
|
||||
from app.model.trigger.app_configs.config_registry import requires_authentication
|
||||
from app.model.chat.chat_history import ChatHistory
|
||||
from app.type.trigger_types import TriggerType, TriggerStatus
|
||||
from app.component.auth import Auth, auth_must
|
||||
from app.component.database import session
|
||||
from app.component.redis_utils import get_redis_manager
|
||||
from app.service.trigger.trigger_schedule_service import TriggerScheduleService
|
||||
from fastapi_babel import _
|
||||
from sqlalchemy import func
|
||||
|
||||
logger = logging.getLogger("server_trigger_controller")
|
||||
|
||||
|
||||
ACTIVE_STATUSES = (TriggerStatus.active, TriggerStatus.pending_verification)
|
||||
MAX_ACTIVE_PER_USER = 25
|
||||
MAX_ACTIVE_PER_PROJECT = 5
|
||||
|
||||
|
||||
def get_active_trigger_counts(session: Session, user_id: str, project_id: str | None = None) -> tuple[int, int]:
|
||||
"""Return (user_active_count, project_active_count) for active/pending triggers."""
|
||||
user_count = session.exec(
|
||||
select(func.count(Trigger.id)).where(
|
||||
and_(
|
||||
Trigger.user_id == user_id,
|
||||
Trigger.status.in_(ACTIVE_STATUSES), # type: ignore[attr-defined]
|
||||
)
|
||||
)
|
||||
).one()
|
||||
|
||||
project_count = 0
|
||||
if project_id:
|
||||
project_count = session.exec(
|
||||
select(func.count(Trigger.id)).where(
|
||||
and_(
|
||||
Trigger.user_id == user_id,
|
||||
Trigger.project_id == project_id,
|
||||
Trigger.status.in_(ACTIVE_STATUSES), # type: ignore[attr-defined]
|
||||
)
|
||||
)
|
||||
).one()
|
||||
|
||||
return user_count, project_count
|
||||
|
||||
|
||||
def get_execution_counts(session: Session, trigger_ids: list[int]) -> dict[int, int]:
|
||||
"""Get execution counts for multiple triggers in a single query."""
|
||||
if not trigger_ids:
|
||||
return {}
|
||||
|
||||
result = session.exec(
|
||||
select(TriggerExecution.trigger_id, func.count(TriggerExecution.id))
|
||||
.where(TriggerExecution.trigger_id.in_(trigger_ids))
|
||||
.group_by(TriggerExecution.trigger_id)
|
||||
).all()
|
||||
|
||||
return {trigger_id: count for trigger_id, count in result}
|
||||
|
||||
|
||||
def trigger_to_out(trigger: Trigger, execution_count: int = 0) -> TriggerOut:
|
||||
"""Convert Trigger model to TriggerOut with execution count."""
|
||||
return TriggerOut(
|
||||
id=trigger.id,
|
||||
user_id=trigger.user_id,
|
||||
project_id=trigger.project_id,
|
||||
name=trigger.name,
|
||||
description=trigger.description,
|
||||
trigger_type=trigger.trigger_type,
|
||||
status=trigger.status,
|
||||
execution_count=execution_count,
|
||||
webhook_url=trigger.webhook_url,
|
||||
webhook_method=trigger.webhook_method,
|
||||
custom_cron_expression=trigger.custom_cron_expression,
|
||||
listener_type=trigger.listener_type,
|
||||
agent_model=trigger.agent_model,
|
||||
task_prompt=trigger.task_prompt,
|
||||
config=trigger.config,
|
||||
max_executions_per_hour=trigger.max_executions_per_hour,
|
||||
max_executions_per_day=trigger.max_executions_per_day,
|
||||
is_single_execution=trigger.is_single_execution,
|
||||
last_executed_at=trigger.last_executed_at,
|
||||
next_run_at=trigger.next_run_at,
|
||||
last_execution_status=trigger.last_execution_status,
|
||||
created_at=trigger.created_at,
|
||||
updated_at=trigger.updated_at,
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter(prefix="/trigger", tags=["Triggers"])
|
||||
|
||||
@router.post("/", name="create trigger", response_model=TriggerOut)
|
||||
def create_trigger(
|
||||
data: TriggerIn,
|
||||
session: Session = Depends(session),
|
||||
auth: Auth = Depends(auth_must)
|
||||
):
|
||||
"""Create a new trigger."""
|
||||
user_id = auth.user.id
|
||||
|
||||
try:
|
||||
# Check if project_id exists in chat_history, if not create one
|
||||
if data.project_id:
|
||||
existing_chat = session.exec(
|
||||
select(ChatHistory).where(ChatHistory.project_id == data.project_id)
|
||||
).first()
|
||||
|
||||
if not existing_chat:
|
||||
# Create a new chat_history for this project
|
||||
chat_history = ChatHistory(
|
||||
user_id=user_id,
|
||||
task_id=data.project_id, # Using project_id as task_id
|
||||
project_id=data.project_id,
|
||||
question=f"Project created via trigger: {data.name}",
|
||||
language="en",
|
||||
model_platform=data.agent_model or "none",
|
||||
model_type=data.agent_model or "none",
|
||||
installed_mcp="none", #Expects String
|
||||
api_key="",
|
||||
api_url="",
|
||||
max_retries=3,
|
||||
project_name=data.name,
|
||||
summary=data.description or "",
|
||||
tokens=0,
|
||||
spend=0,
|
||||
status=2 # completed status
|
||||
)
|
||||
session.add(chat_history)
|
||||
session.commit()
|
||||
session.refresh(chat_history)
|
||||
|
||||
logger.info("Chat history created for new project", extra={
|
||||
"user_id": user_id,
|
||||
"project_id": data.project_id,
|
||||
"chat_history_id": chat_history.id
|
||||
})
|
||||
|
||||
# Send WebSocket notification about new project
|
||||
try:
|
||||
redis_manager = get_redis_manager()
|
||||
redis_manager.publish_execution_event({
|
||||
"type": "project_created",
|
||||
"project_id": data.project_id,
|
||||
"project_name": data.name,
|
||||
"chat_history_id": chat_history.id,
|
||||
"trigger_name": data.name,
|
||||
"user_id": str(user_id),
|
||||
"created_at": chat_history.created_at.isoformat() if chat_history.created_at else None
|
||||
})
|
||||
logger.debug("WebSocket notification sent for new project", extra={
|
||||
"user_id": user_id,
|
||||
"project_id": data.project_id
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning("Failed to send WebSocket notification for new project", extra={
|
||||
"user_id": user_id,
|
||||
"project_id": data.project_id,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
# Generate webhook URL for webhook-based triggers
|
||||
webhook_url = None
|
||||
if data.trigger_type in (TriggerType.webhook, TriggerType.slack_trigger):
|
||||
webhook_url = f"/webhook/trigger/{uuid4()}"
|
||||
|
||||
# Validate trigger-type specific config
|
||||
if data.config and has_config(data.trigger_type):
|
||||
try:
|
||||
validate_config(data.trigger_type, data.config)
|
||||
except ValidationError as e:
|
||||
logger.warning("Invalid trigger config", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_type": data.trigger_type.value,
|
||||
"errors": e.errors()
|
||||
})
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid config for {data.trigger_type.value}: {e.errors()}"
|
||||
)
|
||||
|
||||
# Create trigger instance
|
||||
trigger_data = data.model_dump()
|
||||
trigger_data["user_id"] = str(user_id)
|
||||
trigger_data["webhook_url"] = webhook_url
|
||||
|
||||
# Determine desired initial status based on auth requirements
|
||||
if has_config(data.trigger_type) and data.config and requires_authentication(data.trigger_type, data.config):
|
||||
desired_status = TriggerStatus.pending_verification
|
||||
else:
|
||||
desired_status = TriggerStatus.active
|
||||
|
||||
# Check concurrent active-trigger limits before auto-activating
|
||||
user_active, project_active = get_active_trigger_counts(
|
||||
session, str(user_id), data.project_id
|
||||
)
|
||||
if user_active >= MAX_ACTIVE_PER_USER or (
|
||||
data.project_id and project_active >= MAX_ACTIVE_PER_PROJECT
|
||||
):
|
||||
logger.info(
|
||||
"Active trigger limit reached — new trigger created as inactive",
|
||||
extra={
|
||||
"user_id": user_id,
|
||||
"project_id": data.project_id,
|
||||
"user_active": user_active,
|
||||
"project_active": project_active,
|
||||
},
|
||||
)
|
||||
trigger_data["status"] = TriggerStatus.inactive
|
||||
else:
|
||||
trigger_data["status"] = desired_status
|
||||
|
||||
trigger = Trigger(**trigger_data)
|
||||
session.add(trigger)
|
||||
session.commit()
|
||||
session.refresh(trigger)
|
||||
|
||||
# Calculate next_run_at for scheduled triggers
|
||||
if trigger.trigger_type == TriggerType.schedule and trigger.custom_cron_expression:
|
||||
schedule_service = TriggerScheduleService(session)
|
||||
trigger.next_run_at = schedule_service.calculate_next_run_at(trigger)
|
||||
session.add(trigger)
|
||||
session.commit()
|
||||
session.refresh(trigger)
|
||||
|
||||
logger.info("Trigger created", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger.id,
|
||||
"trigger_type": data.trigger_type.value,
|
||||
"next_run_at": trigger.next_run_at.isoformat() if trigger.next_run_at else None
|
||||
})
|
||||
|
||||
return trigger_to_out(trigger, 0) # New trigger has 0 executions
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error("Trigger creation failed", extra={
|
||||
"user_id": user_id,
|
||||
"error": str(e)
|
||||
}, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/", name="list triggers")
|
||||
def list_triggers(
|
||||
trigger_type: Optional[TriggerType] = Query(None, description="Filter by trigger type"),
|
||||
status: Optional[TriggerStatus] = Query(None, description="Filter by status"),
|
||||
project_id: Optional[str] = Query(None, description="Filter by project ID"),
|
||||
session: Session = Depends(session),
|
||||
auth: Auth = Depends(auth_must)
|
||||
) -> Page[TriggerOut]:
|
||||
"""List triggers for current user."""
|
||||
user_id = auth.user.id
|
||||
|
||||
# Build query with filters
|
||||
conditions = [Trigger.user_id == str(user_id)]
|
||||
|
||||
if trigger_type:
|
||||
conditions.append(Trigger.trigger_type == trigger_type)
|
||||
|
||||
if status is not None:
|
||||
conditions.append(Trigger.status == status)
|
||||
|
||||
if project_id:
|
||||
conditions.append(Trigger.project_id == project_id)
|
||||
|
||||
stmt = (
|
||||
select(Trigger)
|
||||
.where(and_(*conditions))
|
||||
.order_by(desc(Trigger.created_at))
|
||||
)
|
||||
|
||||
result = paginate(session, stmt)
|
||||
total = result.total if hasattr(result, 'total') else 0
|
||||
|
||||
# Get execution counts for all triggers in the result
|
||||
trigger_ids = [t.id for t in result.items]
|
||||
counts = get_execution_counts(session, trigger_ids)
|
||||
|
||||
# Convert triggers to TriggerOut with execution counts
|
||||
result.items = [trigger_to_out(t, counts.get(t.id, 0)) for t in result.items]
|
||||
|
||||
logger.debug("Triggers listed", extra={
|
||||
"user_id": user_id,
|
||||
"total": total,
|
||||
"filters": {
|
||||
"trigger_type": trigger_type.value if trigger_type else None,
|
||||
"status": status.value if status is not None else None,
|
||||
"project_id": project_id
|
||||
}
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
@router.get("/{trigger_id}", name="get trigger", response_model=TriggerOut)
|
||||
def get_trigger(
|
||||
trigger_id: int,
|
||||
session: Session = Depends(session),
|
||||
auth: Auth = Depends(auth_must)
|
||||
):
|
||||
"""Get a specific trigger by ID."""
|
||||
user_id = auth.user.id
|
||||
|
||||
trigger = session.exec(
|
||||
select(Trigger).where(
|
||||
and_(Trigger.id == trigger_id, Trigger.user_id == str(user_id))
|
||||
)
|
||||
).first()
|
||||
|
||||
if not trigger:
|
||||
logger.warning("Trigger not found", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger_id
|
||||
})
|
||||
raise HTTPException(status_code=404, detail="Trigger not found")
|
||||
|
||||
# Get execution count
|
||||
counts = get_execution_counts(session, [trigger_id])
|
||||
execution_count = counts.get(trigger_id, 0)
|
||||
|
||||
logger.debug("Trigger retrieved", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger_id
|
||||
})
|
||||
|
||||
return trigger_to_out(trigger, execution_count)
|
||||
|
||||
|
||||
@router.put("/{trigger_id}", name="update trigger", response_model=TriggerOut)
|
||||
def update_trigger(
|
||||
trigger_id: int,
|
||||
data: TriggerUpdate,
|
||||
session: Session = Depends(session),
|
||||
auth: Auth = Depends(auth_must)
|
||||
):
|
||||
"""Update a trigger."""
|
||||
user_id = auth.user.id
|
||||
|
||||
trigger = session.exec(
|
||||
select(Trigger).where(
|
||||
and_(Trigger.id == trigger_id, Trigger.user_id == str(user_id))
|
||||
)
|
||||
).first()
|
||||
|
||||
if not trigger:
|
||||
logger.warning("Trigger not found for update", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger_id
|
||||
})
|
||||
raise HTTPException(status_code=404, detail="Trigger not found")
|
||||
|
||||
try:
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
|
||||
# Validate config if being updated
|
||||
if "config" in update_data and update_data["config"] is not None:
|
||||
if has_config(trigger.trigger_type):
|
||||
try:
|
||||
validate_config(trigger.trigger_type, update_data["config"])
|
||||
except ValidationError as e:
|
||||
logger.warning("Invalid trigger config on update", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger_id,
|
||||
"trigger_type": trigger.trigger_type.value,
|
||||
"errors": e.errors()
|
||||
})
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid config for {trigger.trigger_type.value}: {e.errors()}"
|
||||
)
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(trigger, key, value)
|
||||
|
||||
# Recalculate next_run_at if cron expression or status changed for scheduled triggers
|
||||
if trigger.trigger_type == TriggerType.schedule:
|
||||
if "custom_cron_expression" in update_data or "status" in update_data:
|
||||
if trigger.status == TriggerStatus.active and trigger.custom_cron_expression:
|
||||
schedule_service = TriggerScheduleService(session)
|
||||
trigger.next_run_at = schedule_service.calculate_next_run_at(trigger)
|
||||
|
||||
session.add(trigger)
|
||||
session.commit()
|
||||
session.refresh(trigger)
|
||||
|
||||
# Get execution count
|
||||
counts = get_execution_counts(session, [trigger_id])
|
||||
execution_count = counts.get(trigger_id, 0)
|
||||
|
||||
logger.info("Trigger updated", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger_id,
|
||||
"fields_updated": list(update_data.keys()),
|
||||
"next_run_at": trigger.next_run_at.isoformat() if trigger.next_run_at else None
|
||||
})
|
||||
|
||||
return trigger_to_out(trigger, execution_count)
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error("Trigger update failed", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger_id,
|
||||
"error": str(e)
|
||||
}, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.delete("/{trigger_id}", name="delete trigger")
|
||||
def delete_trigger(
|
||||
trigger_id: int,
|
||||
session: Session = Depends(session),
|
||||
auth: Auth = Depends(auth_must)
|
||||
):
|
||||
"""Delete a trigger."""
|
||||
user_id = auth.user.id
|
||||
|
||||
trigger = session.exec(
|
||||
select(Trigger).where(
|
||||
and_(Trigger.id == trigger_id, Trigger.user_id == str(user_id))
|
||||
)
|
||||
).first()
|
||||
|
||||
if not trigger:
|
||||
logger.warning("Trigger not found for deletion", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger_id
|
||||
})
|
||||
raise HTTPException(status_code=404, detail="Trigger not found")
|
||||
|
||||
try:
|
||||
# Delete execution logs first (bulk delete)
|
||||
session.exec(
|
||||
delete(TriggerExecution).where(
|
||||
TriggerExecution.trigger_id == trigger_id
|
||||
)
|
||||
)
|
||||
|
||||
# Then delete the trigger
|
||||
session.delete(trigger)
|
||||
|
||||
session.commit()
|
||||
|
||||
logger.info("Trigger deleted", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger_id
|
||||
})
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error("Trigger deletion failed", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger_id,
|
||||
"error": str(e)
|
||||
}, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.post("/{trigger_id}/activate", name="activate trigger", response_model=TriggerOut)
|
||||
def activate_trigger(
|
||||
trigger_id: int,
|
||||
session: Session = Depends(session),
|
||||
auth: Auth = Depends(auth_must)
|
||||
):
|
||||
"""Activate a trigger."""
|
||||
user_id = auth.user.id
|
||||
|
||||
trigger = session.exec(
|
||||
select(Trigger).where(
|
||||
and_(Trigger.id == trigger_id, Trigger.user_id == str(user_id))
|
||||
)
|
||||
).first()
|
||||
|
||||
if not trigger:
|
||||
logger.warning("Trigger not found for activation", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger_id
|
||||
})
|
||||
raise HTTPException(status_code=404, detail="Trigger not found")
|
||||
|
||||
try:
|
||||
# --- Concurrent active-trigger limits ---
|
||||
user_active, project_active = get_active_trigger_counts(
|
||||
session, str(user_id), trigger.project_id
|
||||
)
|
||||
if user_active >= MAX_ACTIVE_PER_USER:
|
||||
logger.warning("User active trigger limit reached", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger_id,
|
||||
"current_active": user_active,
|
||||
"limit": MAX_ACTIVE_PER_USER,
|
||||
})
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Maximum number of concurrent active triggers ({MAX_ACTIVE_PER_USER}) reached for this user"
|
||||
)
|
||||
|
||||
if trigger.project_id and project_active >= MAX_ACTIVE_PER_PROJECT:
|
||||
logger.warning("Project active trigger limit reached", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger_id,
|
||||
"project_id": trigger.project_id,
|
||||
"current_active": project_active,
|
||||
"limit": MAX_ACTIVE_PER_PROJECT,
|
||||
})
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Maximum number of concurrent active triggers ({MAX_ACTIVE_PER_PROJECT}) reached for this project"
|
||||
)
|
||||
|
||||
# Check if authentication is required first — auth-required triggers
|
||||
# go straight to pending_verification (credentials are provided via
|
||||
# the auth flow, so missing-credential errors are expected).
|
||||
if has_config(trigger.trigger_type) and requires_authentication(trigger.trigger_type, trigger.config):
|
||||
trigger.status = TriggerStatus.pending_verification
|
||||
logger.info("Trigger set to pending verification (authentication required)", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger_id,
|
||||
"trigger_type": trigger.trigger_type.value
|
||||
})
|
||||
# Save the status change before raising the exception
|
||||
session.add(trigger)
|
||||
session.commit()
|
||||
session.refresh(trigger)
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail={
|
||||
"message": "Authentication required for this trigger type",
|
||||
"missing_requirements": ["authentication"],
|
||||
"trigger_type": trigger.trigger_type.value
|
||||
}
|
||||
)
|
||||
|
||||
# For non-auth triggers, validate activation requirements
|
||||
if has_config(trigger.trigger_type):
|
||||
try:
|
||||
validate_activation(
|
||||
trigger_type=trigger.trigger_type,
|
||||
config_data=trigger.config,
|
||||
user_id=int(user_id),
|
||||
session=session
|
||||
)
|
||||
except ActivationError as e:
|
||||
logger.warning("Trigger activation requirements not met", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger_id,
|
||||
"trigger_type": trigger.trigger_type.value,
|
||||
"missing_requirements": e.missing_requirements
|
||||
})
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"message": e.message,
|
||||
"missing_requirements": e.missing_requirements,
|
||||
"trigger_type": trigger.trigger_type.value
|
||||
}
|
||||
)
|
||||
|
||||
trigger.status = TriggerStatus.active
|
||||
session.add(trigger)
|
||||
session.commit()
|
||||
session.refresh(trigger)
|
||||
|
||||
# Get execution count
|
||||
counts = get_execution_counts(session, [trigger_id])
|
||||
execution_count = counts.get(trigger_id, 0)
|
||||
|
||||
logger.info("Trigger status updated", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger_id,
|
||||
"status": trigger.status.value
|
||||
})
|
||||
|
||||
return trigger_to_out(trigger, execution_count)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error("Trigger activation failed", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger_id,
|
||||
"error": str(e)
|
||||
}, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.post("/{trigger_id}/deactivate", name="deactivate trigger", response_model=TriggerOut)
|
||||
def deactivate_trigger(
|
||||
trigger_id: int,
|
||||
session: Session = Depends(session),
|
||||
auth: Auth = Depends(auth_must)
|
||||
):
|
||||
"""Deactivate a trigger."""
|
||||
user_id = auth.user.id
|
||||
|
||||
trigger = session.exec(
|
||||
select(Trigger).where(
|
||||
and_(Trigger.id == trigger_id, Trigger.user_id == str(user_id))
|
||||
)
|
||||
).first()
|
||||
|
||||
if not trigger:
|
||||
logger.warning("Trigger not found for deactivation", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger_id
|
||||
})
|
||||
raise HTTPException(status_code=404, detail="Trigger not found")
|
||||
|
||||
try:
|
||||
trigger.status = TriggerStatus.inactive
|
||||
session.add(trigger)
|
||||
session.commit()
|
||||
session.refresh(trigger)
|
||||
|
||||
# Get execution count
|
||||
counts = get_execution_counts(session, [trigger_id])
|
||||
execution_count = counts.get(trigger_id, 0)
|
||||
|
||||
logger.info("Trigger deactivated", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger_id
|
||||
})
|
||||
|
||||
return trigger_to_out(trigger, execution_count)
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error("Trigger deactivation failed", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger_id,
|
||||
"error": str(e)
|
||||
}, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/{trigger_id}/executions", name="list trigger executions")
|
||||
def list_trigger_executions(
|
||||
trigger_id: int,
|
||||
session: Session = Depends(session),
|
||||
auth: Auth = Depends(auth_must)
|
||||
) -> Page[TriggerExecutionOut]:
|
||||
"""List executions for a specific trigger."""
|
||||
user_id = auth.user.id
|
||||
|
||||
# First verify the trigger belongs to the user
|
||||
trigger = session.exec(
|
||||
select(Trigger).where(
|
||||
and_(Trigger.id == trigger_id, Trigger.user_id == str(user_id))
|
||||
)
|
||||
).first()
|
||||
|
||||
if not trigger:
|
||||
logger.warning("Trigger not found for executions list", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger_id
|
||||
})
|
||||
raise HTTPException(status_code=404, detail="Trigger not found")
|
||||
|
||||
# Get executions for this trigger
|
||||
stmt = (
|
||||
select(TriggerExecution)
|
||||
.where(TriggerExecution.trigger_id == trigger_id)
|
||||
.order_by(desc(TriggerExecution.created_at))
|
||||
)
|
||||
|
||||
result = paginate(session, stmt)
|
||||
total = result.total if hasattr(result, 'total') else 0
|
||||
|
||||
logger.debug("Trigger executions listed", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger_id,
|
||||
"total": total
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Trigger Config Endpoints
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/{trigger_type}/config", name="get trigger type config schema")
|
||||
def get_trigger_type_config(
|
||||
trigger_type: TriggerType,
|
||||
auth: Auth = Depends(auth_must)
|
||||
) -> TriggerConfigSchemaOut:
|
||||
"""
|
||||
Get the configuration schema for a specific trigger type.
|
||||
|
||||
This endpoint returns the JSON schema for the trigger type's config field,
|
||||
which can be used by the frontend to dynamically render configuration forms.
|
||||
"""
|
||||
schema = get_config_schema(trigger_type)
|
||||
|
||||
return TriggerConfigSchemaOut(
|
||||
trigger_type=trigger_type.value,
|
||||
has_config=has_config(trigger_type),
|
||||
schema_=schema
|
||||
)
|
||||
|
|
@ -1,226 +0,0 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException
|
||||
from fastapi_babel import _
|
||||
from sqlmodel import Session
|
||||
|
||||
from app.component import code
|
||||
from app.component.auth import Auth
|
||||
from app.component.database import session
|
||||
from app.component.encrypt import password_verify
|
||||
from app.component.stack_auth import StackAuth
|
||||
from app.exception.exception import UserException
|
||||
from app.model.user.user import (
|
||||
LoginByPasswordIn,
|
||||
LoginResponse,
|
||||
RegisterIn,
|
||||
Status,
|
||||
User,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("server_login_controller")
|
||||
|
||||
router = APIRouter(tags=["Login/Registration"])
|
||||
|
||||
|
||||
@router.post("/login", name="login by email or password")
|
||||
async def by_password(data: LoginByPasswordIn, session: Session = Depends(session)) -> LoginResponse:
|
||||
"""
|
||||
User login with email and password
|
||||
"""
|
||||
email = data.email
|
||||
user = User.by(User.email == email, s=session).one_or_none()
|
||||
|
||||
if not user:
|
||||
logger.warning("Login failed: user not found", extra={"email": email})
|
||||
raise UserException(code.password, _("Account or password error"))
|
||||
|
||||
if not password_verify(data.password, user.password):
|
||||
logger.warning("Login failed: invalid password", extra={"user_id": user.id, "email": email})
|
||||
raise UserException(code.password, _("Account or password error"))
|
||||
|
||||
logger.info("User login successful", extra={"user_id": user.id, "email": email})
|
||||
return LoginResponse(token=Auth.create_access_token(user.id), email=user.email)
|
||||
|
||||
|
||||
@router.post("/dev_login", name="OAuth2 password flow login (for Swagger UI)")
|
||||
async def dev_login(
|
||||
username: str = Form(...), # OAuth2 uses 'username' but we accept email
|
||||
password: str = Form(...),
|
||||
session: Session = Depends(session),
|
||||
) -> dict:
|
||||
"""
|
||||
OAuth2 password flow compatible login endpoint for Swagger UI.
|
||||
This endpoint accepts form data (username/password) and returns an access token.
|
||||
"""
|
||||
# Use username as email (OAuth2 standard uses 'username' field)
|
||||
email = username
|
||||
user = User.by(User.email == email, s=session).one_or_none()
|
||||
|
||||
if not user:
|
||||
logger.warning("OAuth2 login failed: user not found", extra={"email": email})
|
||||
raise HTTPException(status_code=401, detail="Incorrect username or password")
|
||||
|
||||
if not password_verify(password, user.password):
|
||||
logger.warning(
|
||||
"OAuth2 login failed: invalid password",
|
||||
extra={"user_id": user.id, "email": email},
|
||||
)
|
||||
raise HTTPException(status_code=401, detail="Incorrect username or password")
|
||||
|
||||
token = Auth.create_access_token(user.id)
|
||||
logger.info("OAuth2 login successful", extra={"user_id": user.id, "email": email})
|
||||
|
||||
# Return OAuth2 compatible response
|
||||
return {"access_token": token, "token_type": "bearer"}
|
||||
|
||||
|
||||
@router.post("/login-by_stack", name="login by stack")
|
||||
async def by_stack_auth(
|
||||
token: str,
|
||||
type: str = "signup",
|
||||
invite_code: str | None = None,
|
||||
session: Session = Depends(session),
|
||||
):
|
||||
try:
|
||||
stack_id = await StackAuth.user_id(token)
|
||||
info = await StackAuth.user_info(token)
|
||||
except Exception as e:
|
||||
logger.error("Stack auth failed", extra={"type": type, "error": str(e)}, exc_info=True)
|
||||
raise HTTPException(500, detail=_("Authentication failed"))
|
||||
|
||||
user = User.by(User.stack_id == stack_id, s=session).one_or_none()
|
||||
|
||||
if not user:
|
||||
if type != "signup":
|
||||
logger.warning(
|
||||
"Stack auth signup blocked: user not found",
|
||||
extra={"stack_id": stack_id, "type": type},
|
||||
)
|
||||
raise UserException(code.error, _("User not found"))
|
||||
|
||||
with session as s:
|
||||
try:
|
||||
user = User(
|
||||
username=info["username"] if "username" in info else None,
|
||||
nickname=info["display_name"],
|
||||
email=info["primary_email"],
|
||||
avatar=info["profile_image_url"],
|
||||
stack_id=stack_id,
|
||||
)
|
||||
s.add(user)
|
||||
s.commit()
|
||||
s.refresh(user)
|
||||
logger.info(
|
||||
"New user registered via stack",
|
||||
extra={
|
||||
"user_id": user.id,
|
||||
"email": user.email,
|
||||
"stack_id": stack_id,
|
||||
},
|
||||
)
|
||||
return LoginResponse(token=Auth.create_access_token(user.id), email=user.email)
|
||||
except Exception as e:
|
||||
s.rollback()
|
||||
logger.error(
|
||||
"Stack auth registration failed",
|
||||
extra={"stack_id": stack_id, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
raise UserException(code.error, _("Failed to register"))
|
||||
else:
|
||||
if user.status == Status.Block:
|
||||
logger.warning(
|
||||
"Blocked user login attempt",
|
||||
extra={"user_id": user.id, "stack_id": stack_id},
|
||||
)
|
||||
raise UserException(code.error, _("Your account has been blocked."))
|
||||
|
||||
logger.info(
|
||||
"User login via stack successful",
|
||||
extra={"user_id": user.id, "email": user.email, "stack_id": stack_id},
|
||||
)
|
||||
return LoginResponse(token=Auth.create_access_token(user.id), email=user.email)
|
||||
|
||||
|
||||
@router.post("/auto-login", name="auto login for local mode")
|
||||
async def auto_login(session: Session = Depends(session)) -> LoginResponse:
|
||||
"""
|
||||
Auto login for fully local mode (VITE_USE_LOCAL_PROXY=true).
|
||||
Returns the most recently active user, or creates a default admin user if none exists.
|
||||
"""
|
||||
# Find the most recently active user
|
||||
user = User.by(
|
||||
User.status == Status.Normal,
|
||||
order_by=User.updated_at.desc(),
|
||||
limit=1,
|
||||
s=session,
|
||||
).one_or_none()
|
||||
|
||||
if not user:
|
||||
# Create default admin user
|
||||
with session as s:
|
||||
try:
|
||||
user = User(
|
||||
email="admin@eigent.local",
|
||||
username="admin",
|
||||
nickname="Admin",
|
||||
)
|
||||
s.add(user)
|
||||
s.commit()
|
||||
s.refresh(user)
|
||||
logger.info("Default admin user created", extra={"user_id": user.id})
|
||||
except Exception as e:
|
||||
s.rollback()
|
||||
logger.error("Failed to create default admin user", extra={"error": str(e)}, exc_info=True)
|
||||
raise UserException(code.error, _("Failed to create default user"))
|
||||
|
||||
logger.info("Auto login successful", extra={"user_id": user.id, "email": user.email})
|
||||
return LoginResponse(token=Auth.create_access_token(user.id), email=user.email)
|
||||
|
||||
|
||||
@router.post("/register", name="register by email/password")
|
||||
async def register(data: RegisterIn, session: Session = Depends(session)):
|
||||
email = data.email
|
||||
|
||||
if User.by(User.email == email, s=session).one_or_none():
|
||||
logger.warning("Registration failed: email already exists", extra={"email": email})
|
||||
raise UserException(code.error, _("Email already registered"))
|
||||
|
||||
with session as s:
|
||||
try:
|
||||
user = User(
|
||||
email=email,
|
||||
password=data.password,
|
||||
)
|
||||
s.add(user)
|
||||
s.commit()
|
||||
s.refresh(user)
|
||||
logger.info(
|
||||
"User registered successfully",
|
||||
extra={"user_id": user.id, "email": email},
|
||||
)
|
||||
except Exception as e:
|
||||
s.rollback()
|
||||
logger.error(
|
||||
"User registration failed",
|
||||
extra={"email": email, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
raise UserException(code.error, _("Failed to register"))
|
||||
|
||||
return {"status": "success"}
|
||||
|
|
@ -1,168 +0,0 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import func
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from app.component.auth import Auth, auth_must
|
||||
from app.component.database import session
|
||||
from app.model.chat.chat_history import ChatHistory
|
||||
from app.model.chat.chat_snpshot import ChatSnapshot
|
||||
from app.model.config.config import Config
|
||||
from app.model.mcp.mcp_user import McpUser
|
||||
from app.model.user.privacy import UserPrivacy, UserPrivacySettings
|
||||
from app.model.user.user import User, UserIn, UserOut, UserProfile
|
||||
from app.model.user.user_credits_record import UserCreditsRecord
|
||||
from app.model.user.user_stat import UserStat, UserStatActionIn, UserStatOut
|
||||
|
||||
logger = logging.getLogger("server_user_controller")
|
||||
|
||||
router = APIRouter(tags=["User"])
|
||||
|
||||
|
||||
@router.get("/user", name="user info", response_model=UserOut)
|
||||
def get(auth: Auth = Depends(auth_must), session: Session = Depends(session)):
|
||||
"""Get current user information and refresh credits."""
|
||||
user: User = auth.user
|
||||
user.refresh_credits_on_active(session)
|
||||
logger.debug("User info retrieved", extra={"user_id": user.id})
|
||||
return user
|
||||
|
||||
|
||||
@router.put("/user", name="update user info", response_model=UserOut)
|
||||
def put(data: UserIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
|
||||
"""Update user basic information."""
|
||||
model = auth.user
|
||||
model.username = data.username
|
||||
model.save(session)
|
||||
logger.info("User info updated", extra={"user_id": model.id, "username": data.username})
|
||||
return model
|
||||
|
||||
|
||||
@router.put("/user/profile", name="update user profile", response_model=UserProfile)
|
||||
def put_profile(data: UserProfile, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
|
||||
"""Update user profile details."""
|
||||
model = auth.user
|
||||
model.nickname = data.nickname
|
||||
model.fullname = data.fullname
|
||||
model.work_desc = data.work_desc
|
||||
model.save(session)
|
||||
logger.info("User profile updated", extra={"user_id": model.id, "nickname": data.nickname})
|
||||
return model
|
||||
|
||||
|
||||
@router.get("/user/privacy", name="get user privacy")
|
||||
def get_privacy(session: Session = Depends(session), auth: Auth = Depends(auth_must)):
|
||||
"""Get user privacy settings."""
|
||||
user_id = auth.user.id
|
||||
stmt = select(UserPrivacy).where(UserPrivacy.user_id == user_id)
|
||||
model = session.exec(stmt).one_or_none()
|
||||
|
||||
if not model:
|
||||
logger.debug("Privacy settings not found, returning defaults", extra={"user_id": user_id})
|
||||
return UserPrivacySettings.default_settings()
|
||||
|
||||
logger.debug("Privacy settings retrieved", extra={"user_id": user_id})
|
||||
return UserPrivacySettings(**model.pricacy_setting).to_response()
|
||||
|
||||
|
||||
@router.put("/user/privacy", name="update user privacy")
|
||||
def put_privacy(data: UserPrivacySettings, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
|
||||
"""Update user privacy settings."""
|
||||
user_id = auth.user.id
|
||||
stmt = select(UserPrivacy).where(UserPrivacy.user_id == user_id)
|
||||
model = session.exec(stmt).one_or_none()
|
||||
default_settings = UserPrivacySettings.default_settings()
|
||||
|
||||
if model:
|
||||
model.pricacy_setting = {**model.pricacy_setting, **data.model_dump()}
|
||||
model.save(session)
|
||||
logger.info("Privacy settings updated", extra={"user_id": user_id})
|
||||
else:
|
||||
model = UserPrivacy(user_id=user_id, pricacy_setting={**default_settings, **data.model_dump()})
|
||||
model.save(session)
|
||||
logger.info("Privacy settings created", extra={"user_id": user_id})
|
||||
|
||||
return UserPrivacySettings(**model.pricacy_setting).to_response()
|
||||
|
||||
|
||||
@router.get("/user/current_credits", name="get user current credits")
|
||||
def get_user_credits(auth: Auth = Depends(auth_must), session: Session = Depends(session)):
|
||||
"""Get user's current credit balance."""
|
||||
user = auth.user
|
||||
user.refresh_credits_on_active(session)
|
||||
credits = user.credits
|
||||
daily_credits: UserCreditsRecord | None = UserCreditsRecord.get_daily_balance(user.id)
|
||||
current_daily_credits = 0
|
||||
if daily_credits:
|
||||
current_daily_credits = daily_credits.amount - daily_credits.balance
|
||||
credits += current_daily_credits if current_daily_credits > 0 else 0
|
||||
|
||||
logger.debug(
|
||||
"Credits retrieved",
|
||||
extra={"user_id": user.id, "total_credits": credits, "daily_credits": current_daily_credits},
|
||||
)
|
||||
return {"credits": credits, "daily_credits": current_daily_credits}
|
||||
|
||||
|
||||
@router.get("/user/stat", name="get user stat", response_model=UserStatOut)
|
||||
def get_user_stat(auth: Auth = Depends(auth_must), session: Session = Depends(session)):
|
||||
"""Get current user's operation statistics."""
|
||||
user_id = auth.user.id
|
||||
stat = session.exec(select(UserStat).where(UserStat.user_id == user_id)).first()
|
||||
data = UserStatOut()
|
||||
|
||||
if stat:
|
||||
data = UserStatOut(**stat.model_dump())
|
||||
else:
|
||||
data = UserStatOut(user_id=user_id)
|
||||
|
||||
data.task_queries = ChatHistory.count(ChatHistory.user_id == user_id, s=session)
|
||||
mcp = McpUser.count(McpUser.user_id == user_id, s=session)
|
||||
tool: list = session.exec(
|
||||
select(func.count("*")).where(Config.user_id == user_id).group_by(Config.config_group)
|
||||
).all()
|
||||
tool = tool.__len__()
|
||||
data.mcp_install_count = mcp + tool
|
||||
data.storage_used = ChatSnapshot.caclDir(ChatSnapshot.get_user_dir(user_id))
|
||||
|
||||
logger.debug(
|
||||
"User stats retrieved",
|
||||
extra={
|
||||
"user_id": user_id,
|
||||
"task_queries": data.task_queries,
|
||||
"mcp_install_count": data.mcp_install_count,
|
||||
"storage_used": data.storage_used,
|
||||
},
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@router.post("/user/stat", name="record user stat")
|
||||
def record_user_stat(
|
||||
data: UserStatActionIn,
|
||||
auth: Auth = Depends(auth_must),
|
||||
session: Session = Depends(session),
|
||||
):
|
||||
"""Record or update current user's operation statistics."""
|
||||
data.user_id = auth.user.id
|
||||
stat = UserStat.record_action(session, data)
|
||||
logger.info(
|
||||
"User stat recorded",
|
||||
extra={"user_id": data.user_id, "action": data.action if hasattr(data, "action") else "unknown"},
|
||||
)
|
||||
return stat
|
||||
|
|
@ -1,15 +1,15 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from pathlib import Path
|
||||
|
|
@ -13,7 +13,7 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from celery import Celery
|
||||
from app.component.environment import env_or_fail, env
|
||||
from app.core.environment import env_or_fail, env
|
||||
|
||||
celery = Celery(
|
||||
__name__,
|
||||
|
|
@ -23,7 +23,7 @@ celery = Celery(
|
|||
|
||||
# Configure Celery to autodiscover tasks
|
||||
celery.conf.imports = [
|
||||
"app.schedule.trigger_schedule_task",
|
||||
"app.domains.trigger.service.trigger_schedule_task",
|
||||
]
|
||||
|
||||
# Configure Celery Beat schedule
|
||||
|
|
@ -37,14 +37,14 @@ celery.conf.beat_schedule = {}
|
|||
|
||||
if ENABLE_TRIGGER_SCHEDULE_POLLER:
|
||||
celery.conf.beat_schedule["poll-trigger-schedules"] = {
|
||||
"task": "app.schedule.trigger_schedule_task.poll_trigger_schedules",
|
||||
"task": "app.domains.trigger.service.trigger_schedule_task.poll_trigger_schedules",
|
||||
"schedule": TRIGGER_SCHEDULE_POLLER_INTERVAL * 60.0, # Convert minutes to seconds
|
||||
"options": {"queue": "poll_trigger_schedules"},
|
||||
}
|
||||
|
||||
if ENABLE_EXECUTION_TIMEOUT_CHECKER:
|
||||
celery.conf.beat_schedule["check-execution-timeouts"] = {
|
||||
"task": "app.schedule.trigger_schedule_task.check_execution_timeouts",
|
||||
"task": "app.domains.trigger.service.trigger_schedule_task.check_execution_timeouts",
|
||||
"schedule": EXECUTION_TIMEOUT_CHECKER_INTERVAL * 60.0, # Convert minutes to seconds
|
||||
"options": {"queue": "check_execution_timeouts"},
|
||||
}
|
||||
|
|
@ -1,15 +1,15 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
success = 0 # success response code
|
||||
|
|
@ -1,22 +1,22 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
import logging
|
||||
|
||||
from sqlmodel import Session, create_engine
|
||||
|
||||
from app.component.environment import env, env_or_fail
|
||||
from app.core.environment import env, env_or_fail
|
||||
|
||||
logger = logging.getLogger("database")
|
||||
|
||||
|
|
@ -1,15 +1,15 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from passlib.context import CryptContext
|
||||
|
|
@ -1,15 +1,15 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
import importlib
|
||||
|
|
@ -20,7 +20,7 @@ from urllib.parse import urlencode
|
|||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.component.environment import env
|
||||
from app.core.environment import env
|
||||
|
||||
|
||||
class OAuthAdapter(ABC):
|
||||
|
|
@ -1,15 +1,15 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
import os
|
||||
|
|
@ -19,7 +19,7 @@ from pathlib import Path
|
|||
from fastapi_babel.middleware import LANGUAGES_PATTERN
|
||||
from pydantic_i18n import JsonLoader, PydanticI18n
|
||||
|
||||
from app.component.babel import babel, babel_configs
|
||||
from app.core.babel import babel, babel_configs
|
||||
|
||||
|
||||
def get_language(lang_code: str | None = None):
|
||||
|
|
@ -1,15 +1,15 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from sqids import Sqids
|
||||
|
|
@ -19,7 +19,7 @@ import logging
|
|||
from sqlmodel import select, and_
|
||||
|
||||
from app.model.trigger.trigger_execution import TriggerExecution
|
||||
from app.component.environment import env
|
||||
from app.core.environment import env
|
||||
|
||||
logger = logging.getLogger("server_trigger_utils")
|
||||
|
||||
|
|
@ -1,15 +1,15 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
|
||||
15
server/app/domains/__init__.py
Normal file
15
server/app/domains/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""v1 domain modules."""
|
||||
15
server/app/domains/chat/__init__.py
Normal file
15
server/app/domains/chat/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""Chat domain: history, steps, snapshots, files, share, logs."""
|
||||
158
server/app/domains/chat/api/history_controller.py
Normal file
158
server/app/domains/chat/api/history_controller.py
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""Chat History controller. Uses ChatService for grouping."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response
|
||||
from fastapi_pagination import Page
|
||||
from fastapi_pagination.ext.sqlmodel import paginate
|
||||
from loguru import logger
|
||||
from sqlmodel import Session, case, delete, desc, func, select
|
||||
from fastapi_babel import _
|
||||
|
||||
from app.core.database import session
|
||||
from app.model.chat.chat_history import ChatHistory, ChatHistoryIn, ChatHistoryOut, ChatHistoryUpdate, ChatStatus
|
||||
from app.model.chat.chat_history_grouped import GroupedHistoryResponse, ProjectGroup
|
||||
from app.model.trigger.trigger import Trigger
|
||||
from app.model.trigger.trigger_execution import TriggerExecution
|
||||
from app.model.user.key import Key
|
||||
from app.shared.auth import auth_must
|
||||
from app.shared.auth.user_auth import V1UserAuth
|
||||
from app.domains.chat.service.chat_service import ChatService
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["Chat History"])
|
||||
|
||||
|
||||
@router.post("/history", name="save chat history", response_model=ChatHistoryOut)
|
||||
def create_chat_history(data: ChatHistoryIn, db_session: Session = Depends(session), auth: V1UserAuth = Depends(auth_must)):
|
||||
data.user_id = auth.id
|
||||
chat_history = ChatHistory(**data.model_dump())
|
||||
db_session.add(chat_history)
|
||||
db_session.commit()
|
||||
db_session.refresh(chat_history)
|
||||
return chat_history
|
||||
|
||||
|
||||
@router.get("/histories", name="get chat history")
|
||||
def list_chat_history(db_session: Session = Depends(session), auth: V1UserAuth = Depends(auth_must)) -> Page[ChatHistoryOut]:
|
||||
stmt = (
|
||||
select(ChatHistory)
|
||||
.where(ChatHistory.user_id == auth.id)
|
||||
.order_by(
|
||||
desc(case((ChatHistory.created_at.is_(None), 0), else_=1)),
|
||||
desc(ChatHistory.created_at),
|
||||
desc(ChatHistory.id),
|
||||
)
|
||||
)
|
||||
return paginate(db_session, stmt)
|
||||
|
||||
|
||||
@router.get("/histories/grouped", name="get grouped chat history")
|
||||
def list_grouped_chat_history(
|
||||
include_tasks: Optional[bool] = Query(True, description="Whether to include individual tasks in groups"),
|
||||
db_session: Session = Depends(session),
|
||||
auth: V1UserAuth = Depends(auth_must),
|
||||
) -> GroupedHistoryResponse:
|
||||
return ChatService.get_grouped_histories(auth.id, include_tasks, db_session)
|
||||
|
||||
|
||||
@router.get("/histories/grouped/{project_id}", name="get single grouped project")
|
||||
def get_grouped_project(
|
||||
project_id: str,
|
||||
include_tasks: Optional[bool] = Query(True, description="Whether to include individual tasks in the project"),
|
||||
db_session: Session = Depends(session),
|
||||
auth: V1UserAuth = Depends(auth_must),
|
||||
) -> ProjectGroup:
|
||||
result = ChatService.get_grouped_project(auth.id, project_id, include_tasks, db_session)
|
||||
if result is None:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
return result
|
||||
|
||||
|
||||
@router.delete("/history/{history_id}", name="delete chat history")
|
||||
def delete_chat_history(history_id: int, db_session: Session = Depends(session), auth: V1UserAuth = Depends(auth_must)):
|
||||
history = db_session.exec(select(ChatHistory).where(ChatHistory.id == history_id)).first()
|
||||
if not history:
|
||||
raise HTTPException(status_code=404, detail="Chat History not found")
|
||||
if history.user_id != auth.id:
|
||||
raise HTTPException(status_code=403, detail="You are not allowed to delete this chat history")
|
||||
|
||||
project_id = history.project_id if history.project_id else history.task_id
|
||||
|
||||
sibling_count = (
|
||||
db_session.exec(
|
||||
select(func.count(ChatHistory.id)).where(
|
||||
ChatHistory.id != history_id,
|
||||
ChatHistory.project_id == project_id if history.project_id else ChatHistory.task_id == project_id,
|
||||
)
|
||||
).first()
|
||||
or 0
|
||||
)
|
||||
|
||||
db_session.delete(history)
|
||||
|
||||
if sibling_count == 0:
|
||||
triggers = db_session.exec(select(Trigger).where(Trigger.project_id == project_id)).all()
|
||||
for trigger in triggers:
|
||||
db_session.exec(delete(TriggerExecution).where(TriggerExecution.trigger_id == trigger.id))
|
||||
db_session.delete(trigger)
|
||||
logger.info(
|
||||
"Deleted triggers for removed project", extra={"project_id": project_id, "trigger_count": len(triggers)}
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.put("/history/{history_id}", name="update chat history", response_model=ChatHistoryOut)
|
||||
async def update_chat_history(
|
||||
history_id: int, data: ChatHistoryUpdate, db_session: Session = Depends(session), auth: V1UserAuth = Depends(auth_must)
|
||||
):
|
||||
history = db_session.exec(select(ChatHistory).where(ChatHistory.id == history_id)).first()
|
||||
if not history:
|
||||
raise HTTPException(status_code=404, detail="Chat History not found")
|
||||
if history.user_id != auth.id:
|
||||
raise HTTPException(status_code=403, detail="You are not allowed to update this chat history")
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
history.update_fields(update_data)
|
||||
history.save(db_session)
|
||||
|
||||
db_session.refresh(history)
|
||||
return history
|
||||
|
||||
|
||||
@router.put("/project/{project_id}/name", name="update project name")
|
||||
def update_project_name(
|
||||
project_id: str, new_name: str, db_session: Session = Depends(session), auth: V1UserAuth = Depends(auth_must)
|
||||
):
|
||||
user_id = auth.id
|
||||
stmt = select(ChatHistory).where(ChatHistory.project_id == project_id).where(ChatHistory.user_id == user_id)
|
||||
histories = db_session.exec(stmt).all()
|
||||
|
||||
if not histories:
|
||||
raise HTTPException(status_code=404, detail="Project not found or access denied")
|
||||
|
||||
try:
|
||||
for history in histories:
|
||||
history.project_name = new_name
|
||||
db_session.add(history)
|
||||
db_session.commit()
|
||||
return Response(status_code=200)
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
logger.error("Project name update failed", extra={"user_id": user_id, "project_id": project_id, "error": str(e)})
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
91
server/app/domains/chat/api/share_controller.py
Normal file
91
server/app/domains/chat/api/share_controller.py
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""Chat Share controller with auth and task ownership on create."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlmodel import Session, select
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from app.core.database import session
|
||||
from app.model.chat.chat_share import ChatHistoryShareOut, ChatShare, ChatShareIn
|
||||
from app.model.chat.chat_step import ChatStep
|
||||
from app.model.chat.chat_history import ChatHistory
|
||||
|
||||
from app.shared.auth import auth_must
|
||||
from app.domains.chat.service import ChatService
|
||||
from app.domains.chat.schema import TaskOwnershipCheckReq
|
||||
from itsdangerous import BadTimeSignature, SignatureExpired
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["V1 Chat Share"])
|
||||
|
||||
|
||||
@router.get("/share/info/{token}", name="Get shared chat info", response_model=ChatHistoryShareOut)
|
||||
def get_share_info(token: str, db_session: Session = Depends(session)):
|
||||
try:
|
||||
task_id = ChatShare.verify_token(token, False)
|
||||
except (SignatureExpired, BadTimeSignature):
|
||||
raise HTTPException(status_code=400, detail="Share link is invalid or has expired.")
|
||||
stmt = select(ChatHistory).where(ChatHistory.task_id == task_id)
|
||||
history = db_session.exec(stmt).one_or_none()
|
||||
if not history:
|
||||
raise HTTPException(status_code=404, detail="Chat history not found.")
|
||||
return history
|
||||
|
||||
|
||||
@router.get("/share/playback/{token}", name="Playback shared chat via SSE")
|
||||
async def share_playback(token: str, db_session: Session = Depends(session), delay_time: float = 0):
|
||||
if delay_time > 5:
|
||||
delay_time = 5
|
||||
try:
|
||||
task_id = ChatShare.verify_token(token, False)
|
||||
except SignatureExpired:
|
||||
raise HTTPException(status_code=400, detail="Share link has expired.")
|
||||
except BadTimeSignature:
|
||||
raise HTTPException(status_code=400, detail="Share link is invalid.")
|
||||
|
||||
async def event_generator():
|
||||
stmt = select(ChatStep).where(ChatStep.task_id == task_id).order_by(ChatStep.id)
|
||||
steps = db_session.exec(stmt).all()
|
||||
if not steps:
|
||||
yield f"data: {json.dumps({'error': 'No steps found for this task.'})}\n\n"
|
||||
return
|
||||
for s in steps:
|
||||
step_data = {
|
||||
"id": s.id,
|
||||
"task_id": s.task_id,
|
||||
"step": s.step,
|
||||
"data": s.data,
|
||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||
}
|
||||
yield f"data: {json.dumps(step_data)}\n\n"
|
||||
if delay_time > 0 and s.step != "create_agent":
|
||||
await asyncio.sleep(delay_time)
|
||||
|
||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/share", name="Generate sharable link for a task(1 day expiration)")
|
||||
def create_share_link(
|
||||
data: ChatShareIn,
|
||||
db_session: Session = Depends(session),
|
||||
auth=Depends(auth_must),
|
||||
):
|
||||
if not ChatService.verify_task_ownership(TaskOwnershipCheckReq(task_id=data.task_id, user_id=auth.user.id)):
|
||||
raise HTTPException(status_code=403, detail="Task not found or access denied.")
|
||||
share_token = ChatShare.generate_token(data.task_id)
|
||||
return {"share_token": share_token}
|
||||
122
server/app/domains/chat/api/snapshot_controller.py
Normal file
122
server/app/domains/chat/api/snapshot_controller.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""v1 Chat Snapshot - H3 auth, H4 ownership, H19 path traversal, P2 Update model.
|
||||
STATUS: full-rewrite (security: H3, H4, H19, P2 Update model)
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response
|
||||
from fastapi_babel import _
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from app.core.database import session
|
||||
from app.model.chat.chat_snpshot import ChatSnapshot, ChatSnapshotIn, ChatSnapshotUpdate
|
||||
|
||||
from app.shared.auth import auth_must
|
||||
from app.shared.auth.ownership import require_owner
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["V1 Chat Snapshot"])
|
||||
|
||||
# H19: api_task_id must be safe for path - only alphanumeric, dash, underscore
|
||||
API_TASK_ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,128}$")
|
||||
|
||||
|
||||
def _validate_api_task_id(value: str) -> None:
|
||||
"""Reject path traversal: api_task_id must match safe charset."""
|
||||
if not value or not API_TASK_ID_PATTERN.match(value):
|
||||
raise HTTPException(status_code=400, detail=_("Invalid api_task_id: only letters, numbers, - and _ allowed"))
|
||||
|
||||
@router.get("/snapshots", name="list chat snapshots", response_model=List[ChatSnapshot])
|
||||
async def list_chat_snapshots(
|
||||
api_task_id: Optional[str] = None,
|
||||
camel_task_id: Optional[str] = None,
|
||||
browser_url: Optional[str] = None,
|
||||
db_session: Session = Depends(session),
|
||||
auth=Depends(auth_must),
|
||||
):
|
||||
query = select(ChatSnapshot).where(ChatSnapshot.user_id == auth.user.id)
|
||||
if api_task_id is not None:
|
||||
query = query.where(ChatSnapshot.api_task_id == api_task_id)
|
||||
if camel_task_id is not None:
|
||||
query = query.where(ChatSnapshot.camel_task_id == camel_task_id)
|
||||
if browser_url is not None:
|
||||
query = query.where(ChatSnapshot.browser_url == browser_url)
|
||||
return list(db_session.exec(query).all())
|
||||
|
||||
|
||||
@router.get("/snapshots/{snapshot_id}", name="get chat snapshot", response_model=ChatSnapshot)
|
||||
async def get_chat_snapshot(
|
||||
snapshot_id: int,
|
||||
db_session: Session = Depends(session),
|
||||
auth=Depends(auth_must),
|
||||
):
|
||||
snapshot = db_session.get(ChatSnapshot, snapshot_id)
|
||||
require_owner(snapshot, auth.user.id)
|
||||
return snapshot
|
||||
|
||||
|
||||
@router.post("/snapshots", name="create chat snapshot", response_model=ChatSnapshot)
|
||||
async def create_chat_snapshot(
|
||||
snapshot: ChatSnapshotIn,
|
||||
db_session: Session = Depends(session),
|
||||
auth=Depends(auth_must),
|
||||
):
|
||||
_validate_api_task_id(snapshot.api_task_id)
|
||||
image_path = ChatSnapshotIn.save_image(auth.user.id, snapshot.api_task_id, snapshot.image_base64)
|
||||
chat_snapshot = ChatSnapshot(
|
||||
user_id=auth.user.id,
|
||||
api_task_id=snapshot.api_task_id,
|
||||
camel_task_id=snapshot.camel_task_id,
|
||||
browser_url=snapshot.browser_url,
|
||||
image_path=image_path,
|
||||
)
|
||||
db_session.add(chat_snapshot)
|
||||
db_session.commit()
|
||||
db_session.refresh(chat_snapshot)
|
||||
return chat_snapshot
|
||||
|
||||
|
||||
@router.put("/snapshots/{snapshot_id}", name="update chat snapshot", response_model=ChatSnapshot)
|
||||
async def update_chat_snapshot(
|
||||
snapshot_id: int,
|
||||
snapshot_update: ChatSnapshotUpdate,
|
||||
db_session: Session = Depends(session),
|
||||
auth=Depends(auth_must),
|
||||
):
|
||||
db_snapshot = db_session.get(ChatSnapshot, snapshot_id)
|
||||
require_owner(db_snapshot, auth.user.id)
|
||||
for key, value in snapshot_update.model_dump(exclude_unset=True).items():
|
||||
if key == "api_task_id" and value is not None:
|
||||
_validate_api_task_id(str(value))
|
||||
setattr(db_snapshot, key, value)
|
||||
db_session.add(db_snapshot)
|
||||
db_session.commit()
|
||||
db_session.refresh(db_snapshot)
|
||||
return db_snapshot
|
||||
|
||||
|
||||
@router.delete("/snapshots/{snapshot_id}", name="delete chat snapshot")
|
||||
async def delete_chat_snapshot(
|
||||
snapshot_id: int,
|
||||
db_session: Session = Depends(session),
|
||||
auth=Depends(auth_must),
|
||||
):
|
||||
db_snapshot = db_session.get(ChatSnapshot, snapshot_id)
|
||||
require_owner(db_snapshot, auth.user.id)
|
||||
db_session.delete(db_snapshot)
|
||||
db_session.commit()
|
||||
return Response(status_code=204)
|
||||
162
server/app/domains/chat/api/step_controller.py
Normal file
162
server/app/domains/chat/api/step_controller.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""v1 Chat Step - H1 auth, H2 ownership, P2 dedicated Update model.
|
||||
STATUS: full-rewrite (security: H1, H2, P2 Update model)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.sql.expression import case
|
||||
from sqlmodel import Session, asc, select
|
||||
|
||||
from app.core.database import session
|
||||
from app.model.chat.chat_step import ChatStep, ChatStepOut, ChatStepIn, ChatStepUpdate
|
||||
from app.model.chat.chat_history import ChatHistory
|
||||
|
||||
from app.shared.auth import auth_must
|
||||
from app.domains.chat.service import ChatService
|
||||
from app.domains.chat.schema import TaskOwnershipCheckReq
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["V1 Chat Step"])
|
||||
|
||||
|
||||
def _task_owned_by_user(db: Session, task_id: str, user_id: int) -> bool:
|
||||
return ChatService.verify_task_ownership(TaskOwnershipCheckReq(task_id=task_id, user_id=user_id))
|
||||
|
||||
|
||||
@router.get("/steps", name="list chat steps", response_model=List[ChatStepOut])
|
||||
async def list_chat_steps(
|
||||
task_id: str,
|
||||
step: Optional[str] = None,
|
||||
db_session: Session = Depends(session),
|
||||
auth=Depends(auth_must),
|
||||
):
|
||||
if not _task_owned_by_user(db_session, task_id, auth.user.id):
|
||||
return []
|
||||
query = select(ChatStep).where(ChatStep.task_id == task_id)
|
||||
if step is not None:
|
||||
query = query.where(ChatStep.step == step)
|
||||
return list(db_session.exec(query).all())
|
||||
|
||||
|
||||
@router.get("/steps/playback/{task_id}", name="Playback Chat Step via SSE")
|
||||
async def share_playback(
|
||||
task_id: str,
|
||||
delay_time: float = 0,
|
||||
db_session: Session = Depends(session),
|
||||
auth=Depends(auth_must),
|
||||
):
|
||||
if delay_time > 5:
|
||||
delay_time = 5
|
||||
if not _task_owned_by_user(db_session, task_id, auth.user.id):
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
async def event_generator():
|
||||
stmt = select(ChatStep).where(ChatStep.task_id == task_id).order_by(
|
||||
asc(case((ChatStep.timestamp.is_(None), 1), else_=0)),
|
||||
asc(ChatStep.timestamp),
|
||||
asc(ChatStep.id),
|
||||
)
|
||||
steps = db_session.exec(stmt).all()
|
||||
if not steps:
|
||||
yield f"data: {json.dumps({'error': 'No steps found for this task.'})}\n\n"
|
||||
return
|
||||
for s in steps:
|
||||
step_data = {
|
||||
"id": s.id,
|
||||
"task_id": s.task_id,
|
||||
"step": s.step,
|
||||
"data": s.data,
|
||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||
}
|
||||
yield f"data: {json.dumps(step_data)}\n\n"
|
||||
if delay_time > 0:
|
||||
await asyncio.sleep(delay_time)
|
||||
|
||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.get("/steps/{step_id}", name="get chat step", response_model=ChatStepOut)
|
||||
async def get_chat_step(
|
||||
step_id: int,
|
||||
db_session: Session = Depends(session),
|
||||
auth=Depends(auth_must),
|
||||
):
|
||||
chat_step = db_session.get(ChatStep, step_id)
|
||||
if not chat_step:
|
||||
raise HTTPException(status_code=404, detail="Chat step not found")
|
||||
if not _task_owned_by_user(db_session, chat_step.task_id, auth.user.id):
|
||||
raise HTTPException(status_code=404, detail="Chat step not found")
|
||||
return chat_step
|
||||
|
||||
|
||||
@router.post("/steps", name="create chat step")
|
||||
async def create_chat_step(
|
||||
step: ChatStepIn,
|
||||
db_session: Session = Depends(session),
|
||||
auth=Depends(auth_must),
|
||||
):
|
||||
if not _task_owned_by_user(db_session, step.task_id, auth.user.id):
|
||||
raise HTTPException(status_code=403, detail="Task not found or access denied")
|
||||
chat_step = ChatStep(
|
||||
task_id=step.task_id,
|
||||
step=step.step,
|
||||
data=step.data,
|
||||
timestamp=step.timestamp,
|
||||
)
|
||||
db_session.add(chat_step)
|
||||
db_session.commit()
|
||||
db_session.refresh(chat_step)
|
||||
return {"code": 200, "msg": "success"}
|
||||
|
||||
|
||||
@router.put("/steps/{step_id}", name="update chat step", response_model=ChatStepOut)
|
||||
async def update_chat_step(
|
||||
step_id: int,
|
||||
chat_step_update: ChatStepUpdate,
|
||||
db_session: Session = Depends(session),
|
||||
auth=Depends(auth_must),
|
||||
):
|
||||
db_chat_step = db_session.get(ChatStep, step_id)
|
||||
if not db_chat_step:
|
||||
raise HTTPException(status_code=404, detail="Chat step not found")
|
||||
if not _task_owned_by_user(db_session, db_chat_step.task_id, auth.user.id):
|
||||
raise HTTPException(status_code=404, detail="Chat step not found")
|
||||
for key, value in chat_step_update.model_dump(exclude_unset=True).items():
|
||||
setattr(db_chat_step, key, value)
|
||||
db_session.add(db_chat_step)
|
||||
db_session.commit()
|
||||
db_session.refresh(db_chat_step)
|
||||
return db_chat_step
|
||||
|
||||
|
||||
@router.delete("/steps/{step_id}", name="delete chat step")
|
||||
async def delete_chat_step(
|
||||
step_id: int,
|
||||
db_session: Session = Depends(session),
|
||||
auth=Depends(auth_must),
|
||||
):
|
||||
db_chat_step = db_session.get(ChatStep, step_id)
|
||||
if not db_chat_step:
|
||||
raise HTTPException(status_code=404, detail="Chat step not found")
|
||||
if not _task_owned_by_user(db_session, db_chat_step.task_id, auth.user.id):
|
||||
raise HTTPException(status_code=404, detail="Chat step not found")
|
||||
db_session.delete(db_chat_step)
|
||||
db_session.commit()
|
||||
return Response(status_code=204)
|
||||
27
server/app/domains/chat/schema/__init__.py
Normal file
27
server/app/domains/chat/schema/__init__.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""Chat domain schemas."""
|
||||
|
||||
from app.domains.chat.schema.schemas import (
|
||||
TaskOwnershipCheckReq,
|
||||
FileValidationReq,
|
||||
FileValidationResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"TaskOwnershipCheckReq",
|
||||
"FileValidationReq",
|
||||
"FileValidationResult",
|
||||
]
|
||||
|
|
@ -12,18 +12,21 @@
|
|||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from fastapi import APIRouter
|
||||
"""v1 ChatService request/response schemas."""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter(tags=["Health"])
|
||||
|
||||
class TaskOwnershipCheckReq(BaseModel):
|
||||
task_id: str
|
||||
user_id: int
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
status: str
|
||||
service: str
|
||||
class FileValidationReq(BaseModel):
|
||||
filename: str
|
||||
file_size: int
|
||||
|
||||
|
||||
@router.get("/health", name="health check", response_model=HealthResponse)
|
||||
async def health_check():
|
||||
"""Health check endpoint for monitoring and container orchestration."""
|
||||
return HealthResponse(status="ok", service="eigent-server")
|
||||
class FileValidationResult(BaseModel):
|
||||
valid: bool
|
||||
error: str | None = None
|
||||
|
|
@ -1,20 +1,19 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from fastapi_babel import BabelMiddleware
|
||||
"""Chat domain services."""
|
||||
|
||||
from app import api
|
||||
from app.component.babel import babel_configs
|
||||
from app.domains.chat.service.chat_service import ChatService
|
||||
|
||||
api.add_middleware(BabelMiddleware, babel_configs=babel_configs)
|
||||
__all__ = ["ChatService"]
|
||||
242
server/app/domains/chat/service/chat_service.py
Normal file
242
server/app/domains/chat/service/chat_service.py
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""ChatService: task ownership, file validation, history grouping. No billing in eigent."""
|
||||
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Dict, List
|
||||
|
||||
from loguru import logger
|
||||
from sqlmodel import Session, case, desc, func, select
|
||||
|
||||
from app.core.database import session_make
|
||||
from app.model.chat.chat_history import ChatHistory, ChatHistoryOut, ChatStatus
|
||||
from app.model.chat.chat_history_grouped import GroupedHistoryResponse, ProjectGroup
|
||||
from app.model.trigger.trigger import Trigger
|
||||
from app.domains.chat.schema import TaskOwnershipCheckReq, FileValidationReq, FileValidationResult
|
||||
|
||||
ALLOWED_EXTENSIONS = {
|
||||
"jpg", "jpeg", "png", "gif", "webp",
|
||||
"pdf", "txt", "md", "csv",
|
||||
"json", "xml", "yaml", "yml",
|
||||
"doc", "docx", "xls", "xlsx",
|
||||
"zip",
|
||||
}
|
||||
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB
|
||||
|
||||
|
||||
class ChatService:
|
||||
"""Chat domain business logic - static methods, self-managed session."""
|
||||
|
||||
@staticmethod
|
||||
def verify_task_ownership(req: TaskOwnershipCheckReq) -> bool:
|
||||
"""Check if task_id belongs to user_id."""
|
||||
with session_make() as s:
|
||||
h = s.exec(
|
||||
select(ChatHistory)
|
||||
.where(ChatHistory.task_id == req.task_id, ChatHistory.user_id == req.user_id)
|
||||
).first()
|
||||
return h is not None
|
||||
|
||||
@staticmethod
|
||||
def validate_file(req: FileValidationReq) -> FileValidationResult:
|
||||
"""Validate filename extension and file size."""
|
||||
if not req.filename:
|
||||
return FileValidationResult(valid=False, error="Filename is required")
|
||||
ext = req.filename.rsplit(".", 1)[-1].lower() if "." in req.filename else ""
|
||||
if ext not in ALLOWED_EXTENSIONS:
|
||||
return FileValidationResult(
|
||||
valid=False,
|
||||
error=f"File type '.{ext}' is not allowed. Allowed: {', '.join(sorted(ALLOWED_EXTENSIONS))}",
|
||||
)
|
||||
if req.file_size > MAX_FILE_SIZE:
|
||||
return FileValidationResult(
|
||||
valid=False,
|
||||
error=f"File size ({req.file_size} bytes) exceeds max {MAX_FILE_SIZE // (1024 * 1024)} MB",
|
||||
)
|
||||
return FileValidationResult(valid=True)
|
||||
|
||||
@staticmethod
|
||||
async def reconcile_if_needed(
|
||||
history: ChatHistory, status_changed_to_done: bool, trace_id: str | None = None
|
||||
) -> None:
|
||||
"""No-op: eigent does not have billing/credits."""
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def upload_file(
|
||||
user_id: int, task_id: str, filename: str, file_content: bytes, file_type: str | None, s: Session
|
||||
) -> "ChatFile":
|
||||
"""Validate, upload to S3, and create ChatFile record. Caller must commit."""
|
||||
from app.model.chat.chat_file import ChatFile, ChatFileIn
|
||||
|
||||
validation = ChatService.validate_file(FileValidationReq(filename=filename, file_size=len(file_content)))
|
||||
if not validation.valid:
|
||||
raise ValueError(validation.error or "Invalid file")
|
||||
|
||||
file_info = ChatFileIn.save_file_to_s3(
|
||||
user_id=user_id,
|
||||
task_id=task_id,
|
||||
filename=filename,
|
||||
file_content=file_content,
|
||||
file_type=file_type,
|
||||
)
|
||||
chat_file = ChatFile(
|
||||
user_id=user_id,
|
||||
task_id=task_id,
|
||||
filename=file_info["filename"],
|
||||
file_size=file_info["file_size"],
|
||||
file_type=file_info["file_type"],
|
||||
s3_key=file_info["s3_key"],
|
||||
s3_bucket=file_info["s3_bucket"],
|
||||
)
|
||||
s.add(chat_file)
|
||||
s.commit()
|
||||
s.refresh(chat_file)
|
||||
return chat_file
|
||||
|
||||
@staticmethod
|
||||
def is_real_task(history: ChatHistory) -> bool:
|
||||
"""Check if a task is a real task vs a placeholder/trigger-created task."""
|
||||
if history.spend and history.spend > 0:
|
||||
return True
|
||||
if history.tokens and history.tokens > 0:
|
||||
return True
|
||||
if (
|
||||
history.model_platform
|
||||
and history.model_platform != "none"
|
||||
and history.model_type
|
||||
and history.model_type != "none"
|
||||
and history.installed_mcp
|
||||
and history.installed_mcp != "none"
|
||||
):
|
||||
return True
|
||||
if history.question and history.question.startswith("Project created via trigger:"):
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _build_project_data(
|
||||
histories: list[ChatHistory],
|
||||
trigger_count_map: dict[str, int],
|
||||
include_tasks: bool,
|
||||
project_id_override: str | None = None,
|
||||
) -> list[ProjectGroup]:
|
||||
"""Build ProjectGroup list from histories. Shared by grouped list and single project endpoints."""
|
||||
project_map: Dict[str, Dict] = defaultdict(
|
||||
lambda: {
|
||||
"project_id": "",
|
||||
"project_name": None,
|
||||
"total_tokens": 0,
|
||||
"task_count": 0,
|
||||
"latest_task_date": "",
|
||||
"last_prompt": None,
|
||||
"tasks": [],
|
||||
"total_completed_tasks": 0,
|
||||
"total_ongoing_tasks": 0,
|
||||
"average_tokens_per_task": 0,
|
||||
"total_triggers": 0,
|
||||
}
|
||||
)
|
||||
|
||||
for history in histories:
|
||||
project_id = project_id_override or (history.project_id if history.project_id else history.task_id)
|
||||
project_data = project_map[project_id]
|
||||
|
||||
if not project_data["project_id"]:
|
||||
project_data["project_id"] = project_id
|
||||
project_data["project_name"] = history.project_name or f"Project {project_id}"
|
||||
project_data["latest_task_date"] = history.created_at.isoformat() if history.created_at else ""
|
||||
project_data["last_prompt"] = history.question
|
||||
|
||||
if include_tasks and ChatService.is_real_task(history):
|
||||
project_data["tasks"].append(ChatHistoryOut(**history.model_dump()))
|
||||
|
||||
if ChatService.is_real_task(history):
|
||||
project_data["task_count"] += 1
|
||||
project_data["total_tokens"] += history.tokens or 0
|
||||
|
||||
if history.status == ChatStatus.done:
|
||||
project_data["total_completed_tasks"] += 1
|
||||
elif history.status == ChatStatus.ongoing:
|
||||
project_data["total_ongoing_tasks"] += 1
|
||||
|
||||
if history.created_at:
|
||||
task_date = history.created_at.isoformat()
|
||||
if not project_data["latest_task_date"] or task_date > project_data["latest_task_date"]:
|
||||
project_data["latest_task_date"] = task_date
|
||||
project_data["last_prompt"] = history.question
|
||||
|
||||
projects = []
|
||||
for project_data in project_map.values():
|
||||
if include_tasks:
|
||||
project_data["tasks"].sort(key=lambda x: (x.created_at is None, x.created_at or ""), reverse=False)
|
||||
pid = project_data["project_id"]
|
||||
project_data["total_triggers"] = trigger_count_map.get(pid, 0)
|
||||
projects.append(ProjectGroup(**project_data))
|
||||
|
||||
projects.sort(key=lambda x: x.latest_task_date, reverse=True)
|
||||
return projects
|
||||
|
||||
@staticmethod
|
||||
def get_grouped_histories(user_id: int, include_tasks: bool, s: Session) -> GroupedHistoryResponse:
|
||||
"""Get all chat histories grouped by project for a user."""
|
||||
stmt = (
|
||||
select(ChatHistory)
|
||||
.where(ChatHistory.user_id == user_id)
|
||||
.order_by(
|
||||
desc(case((ChatHistory.created_at.is_(None), 0), else_=1)),
|
||||
desc(ChatHistory.created_at),
|
||||
desc(ChatHistory.id),
|
||||
)
|
||||
)
|
||||
histories = s.exec(stmt).all()
|
||||
|
||||
trigger_count_stmt = (
|
||||
select(Trigger.project_id, func.count(Trigger.id).label("count"))
|
||||
.where(Trigger.user_id == str(user_id))
|
||||
.group_by(Trigger.project_id)
|
||||
)
|
||||
trigger_counts = s.exec(trigger_count_stmt).all()
|
||||
trigger_count_map = {project_id: count for project_id, count in trigger_counts}
|
||||
|
||||
projects = ChatService._build_project_data(histories, trigger_count_map, include_tasks)
|
||||
return GroupedHistoryResponse(projects=projects)
|
||||
|
||||
@staticmethod
|
||||
def get_grouped_project(user_id: int, project_id: str, include_tasks: bool, s: Session) -> ProjectGroup | None:
|
||||
"""Get a single project group by project_id."""
|
||||
stmt = (
|
||||
select(ChatHistory)
|
||||
.where(ChatHistory.user_id == user_id)
|
||||
.where(ChatHistory.project_id == project_id)
|
||||
.order_by(
|
||||
desc(case((ChatHistory.created_at.is_(None), 0), else_=1)),
|
||||
desc(ChatHistory.created_at),
|
||||
desc(ChatHistory.id),
|
||||
)
|
||||
)
|
||||
histories = s.exec(stmt).all()
|
||||
if not histories:
|
||||
return None
|
||||
|
||||
trigger_count_stmt = (
|
||||
select(func.count(Trigger.id)).where(Trigger.user_id == str(user_id)).where(Trigger.project_id == project_id)
|
||||
)
|
||||
trigger_count = s.exec(trigger_count_stmt).first() or 0
|
||||
trigger_count_map = {project_id: trigger_count}
|
||||
|
||||
projects = ChatService._build_project_data(histories, trigger_count_map, include_tasks, project_id_override=project_id)
|
||||
return projects[0] if projects else None
|
||||
15
server/app/domains/config/__init__.py
Normal file
15
server/app/domains/config/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""Config domain: plans, config, providers."""
|
||||
|
|
@ -11,3 +11,4 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
93
server/app/domains/config/api/config_controller.py
Normal file
93
server/app/domains/config/api/config_controller.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
# STATUS: full-rewrite (uses ConfigService, self-managed session)
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response
|
||||
from fastapi_babel import _
|
||||
|
||||
from app.model.config.config import ConfigCreate, ConfigUpdate, ConfigOut
|
||||
from app.shared.auth import auth_must
|
||||
from app.shared.auth.user_auth import V1UserAuth
|
||||
from app.domains.config.service.config_service import ConfigService
|
||||
|
||||
router = APIRouter(tags=["Config Management"])
|
||||
|
||||
|
||||
@router.get("/configs", name="list configs", response_model=list[ConfigOut])
|
||||
async def list_configs(
|
||||
config_group: Optional[str] = None,
|
||||
auth: V1UserAuth = Depends(auth_must),
|
||||
):
|
||||
return ConfigService.list_for_user(auth.id, config_group)
|
||||
|
||||
|
||||
@router.get("/configs/{config_id}", name="get config", response_model=ConfigOut)
|
||||
async def get_config(config_id: int, auth: V1UserAuth = Depends(auth_must)):
|
||||
config = ConfigService.get(config_id, auth.id)
|
||||
if not config:
|
||||
raise HTTPException(status_code=404, detail=_("Configuration not found"))
|
||||
return config
|
||||
|
||||
|
||||
@router.post("/configs", name="create config", response_model=ConfigOut)
|
||||
async def create_config(config: ConfigCreate, auth: V1UserAuth = Depends(auth_must)):
|
||||
result = ConfigService.create(
|
||||
user_id=auth.id,
|
||||
config_name=config.config_name,
|
||||
config_value=config.config_value,
|
||||
config_group=config.config_group,
|
||||
)
|
||||
if not result["success"]:
|
||||
error_map = {
|
||||
"CONFIG_INVALID_NAME": (400, _("Config Name is invalid")),
|
||||
"CONFIG_DUPLICATE": (400, _("Configuration already exists for this user")),
|
||||
}
|
||||
status, detail = error_map.get(result["error_code"], (500, "Config error"))
|
||||
raise HTTPException(status_code=status, detail=detail)
|
||||
return result["config"]
|
||||
|
||||
|
||||
@router.put("/configs/{config_id}", name="update config", response_model=ConfigOut)
|
||||
async def update_config(config_id: int, config_update: ConfigUpdate, auth: V1UserAuth = Depends(auth_must)):
|
||||
result = ConfigService.update(
|
||||
config_id=config_id,
|
||||
user_id=auth.id,
|
||||
config_name=config_update.config_name,
|
||||
config_value=config_update.config_value,
|
||||
config_group=config_update.config_group,
|
||||
)
|
||||
if not result["success"]:
|
||||
error_map = {
|
||||
"CONFIG_NOT_FOUND": (404, _("Configuration not found")),
|
||||
"CONFIG_INVALID_NAME": (400, _("Invalid configuration group")),
|
||||
"CONFIG_DUPLICATE": (400, _("Configuration already exists for this user")),
|
||||
}
|
||||
status, detail = error_map.get(result["error_code"], (500, "Config error"))
|
||||
raise HTTPException(status_code=status, detail=detail)
|
||||
return result["config"]
|
||||
|
||||
|
||||
@router.delete("/configs/{config_id}", name="delete config")
|
||||
async def delete_config(config_id: int, auth: V1UserAuth = Depends(auth_must)):
|
||||
if not ConfigService.delete(config_id, auth.id):
|
||||
raise HTTPException(status_code=404, detail=_("Configuration not found"))
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.get("/config/info", name="get config info")
|
||||
async def get_config_info(
|
||||
show_all: bool = Query(False, description="Show all config info, including those with empty env_vars"),
|
||||
):
|
||||
return ConfigService.get_config_info(show_all)
|
||||
14
server/app/domains/config/schema/__init__.py
Normal file
14
server/app/domains/config/schema/__init__.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
17
server/app/domains/config/service/__init__.py
Normal file
17
server/app/domains/config/service/__init__.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from .config_service import ConfigService
|
||||
|
||||
__all__ = ["ConfigService"]
|
||||
120
server/app/domains/config/service/config_service.py
Normal file
120
server/app/domains/config/service/config_service.py
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""ConfigService: user config CRUD with validation. Follows CreditsService pattern."""
|
||||
|
||||
from sqlmodel import select
|
||||
from loguru import logger
|
||||
|
||||
from app.core.database import session_make
|
||||
from app.model.config.config import Config, ConfigInfo
|
||||
|
||||
|
||||
class ConfigService:
|
||||
"""User configuration management - static methods, self-managed session."""
|
||||
|
||||
@staticmethod
|
||||
def list_for_user(user_id: int, config_group: str | None = None) -> list[Config]:
|
||||
"""List user configs, optionally filtered by group."""
|
||||
with session_make() as s:
|
||||
query = select(Config).where(Config.user_id == user_id)
|
||||
if config_group is not None:
|
||||
query = query.where(Config.config_group == config_group)
|
||||
return list(s.exec(query).all())
|
||||
|
||||
@staticmethod
|
||||
def get(config_id: int, user_id: int) -> Config | None:
|
||||
"""Get a single config by id, scoped to user."""
|
||||
with session_make() as s:
|
||||
return s.exec(
|
||||
select(Config).where(Config.id == config_id, Config.user_id == user_id)
|
||||
).first()
|
||||
|
||||
@staticmethod
|
||||
def create(user_id: int, config_name: str, config_value: str, config_group: str | None = None) -> dict:
|
||||
"""Create config: env var name validation + duplicate check.
|
||||
Returns {"success": True, "config": Config} or {"success": False, "error_code": str}.
|
||||
"""
|
||||
if not ConfigInfo.is_valid_env_var(config_group, config_name):
|
||||
return {"success": False, "error_code": "CONFIG_INVALID_NAME"}
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
with session_make() as s:
|
||||
db_config = Config(
|
||||
user_id=user_id,
|
||||
config_name=config_name,
|
||||
config_value=config_value,
|
||||
config_group=config_group,
|
||||
)
|
||||
s.add(db_config)
|
||||
try:
|
||||
s.commit()
|
||||
except IntegrityError:
|
||||
s.rollback()
|
||||
return {"success": False, "error_code": "CONFIG_DUPLICATE"}
|
||||
s.refresh(db_config)
|
||||
return {"success": True, "config": db_config}
|
||||
|
||||
@staticmethod
|
||||
def update(config_id: int, user_id: int, config_name: str, config_value: str, config_group: str | None = None) -> dict:
|
||||
"""Update config: ownership + validation + duplicate check."""
|
||||
if not ConfigInfo.is_valid_env_var(config_group, config_name):
|
||||
return {"success": False, "error_code": "CONFIG_INVALID_NAME"}
|
||||
|
||||
with session_make() as s:
|
||||
db_config = s.exec(
|
||||
select(Config).where(Config.id == config_id, Config.user_id == user_id)
|
||||
).first()
|
||||
if not db_config:
|
||||
return {"success": False, "error_code": "CONFIG_NOT_FOUND"}
|
||||
|
||||
# Check for name conflict with other configs
|
||||
conflict = s.exec(
|
||||
select(Config).where(
|
||||
Config.user_id == user_id,
|
||||
Config.config_name == config_name,
|
||||
Config.id != config_id,
|
||||
)
|
||||
).first()
|
||||
if conflict:
|
||||
return {"success": False, "error_code": "CONFIG_DUPLICATE"}
|
||||
|
||||
db_config.config_name = config_name
|
||||
db_config.config_value = config_value
|
||||
s.add(db_config)
|
||||
s.commit()
|
||||
s.refresh(db_config)
|
||||
return {"success": True, "config": db_config}
|
||||
|
||||
@staticmethod
|
||||
def delete(config_id: int, user_id: int) -> bool:
|
||||
"""Delete config: ownership check."""
|
||||
with session_make() as s:
|
||||
db_config = s.exec(
|
||||
select(Config).where(Config.id == config_id, Config.user_id == user_id)
|
||||
).first()
|
||||
if not db_config:
|
||||
return False
|
||||
s.delete(db_config)
|
||||
s.commit()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_config_info(show_all: bool = False) -> dict:
|
||||
"""Return available config metadata."""
|
||||
configs = ConfigInfo.getinfo()
|
||||
if show_all:
|
||||
return configs
|
||||
return {k: v for k, v in configs.items() if v.get("env_vars") and len(v["env_vars"]) > 0}
|
||||
15
server/app/domains/mcp/__init__.py
Normal file
15
server/app/domains/mcp/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""MCP domain: MCP components, categories, proxy, user mappings."""
|
||||
14
server/app/domains/mcp/api/__init__.py
Normal file
14
server/app/domains/mcp/api/__init__.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
|
|
@ -1,27 +1,27 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlmodel import Session, col, select
|
||||
|
||||
from app.component.database import session
|
||||
from app.core.database import session
|
||||
from app.model.mcp.category import Category, CategoryOut
|
||||
|
||||
router = APIRouter(prefix="/mcp", tags=["Mcp Category"])
|
||||
|
||||
|
||||
@router.get("/categories", name="category list", response_model=list[CategoryOut])
|
||||
def gets(session: Session = Depends(session)):
|
||||
def gets(db_session: Session = Depends(session)):
|
||||
stmt = select(Category).where(Category.no_delete()).order_by(col(Category.priority).asc())
|
||||
return session.exec(stmt)
|
||||
return db_session.exec(stmt)
|
||||
101
server/app/domains/mcp/api/mcp_controller.py
Normal file
101
server/app/domains/mcp/api/mcp_controller.py
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""MCP controller. Uses McpUserService for install and import."""
|
||||
|
||||
from fastapi import Depends, HTTPException, APIRouter
|
||||
from fastapi_babel import _
|
||||
from fastapi_pagination import Page
|
||||
from fastapi_pagination.ext.sqlmodel import paginate
|
||||
from sqlmodel import Session, col, select
|
||||
from sqlalchemy.orm import selectinload, with_loader_criteria
|
||||
|
||||
from app.core.database import session
|
||||
from app.model.mcp.mcp import Mcp, McpOut
|
||||
from app.model.mcp.mcp_env import McpEnv, Status as McpEnvStatus
|
||||
from app.model.mcp.mcp_user import McpImportType, McpUser
|
||||
from app.shared.auth import auth_must
|
||||
from app.shared.auth.user_auth import V1UserAuth
|
||||
from app.shared.middleware.rate_limit import install_rate_limiter
|
||||
from app.domains.mcp.service.mcp_user_service import McpUserService
|
||||
|
||||
router = APIRouter(tags=["Mcp Servers"])
|
||||
|
||||
_INSTALL_ERROR_MAP = {
|
||||
"MCP_NOT_FOUND": (404, "Mcp not found"),
|
||||
"MCP_ALREADY_INSTALLED": (400, "mcp is installed"),
|
||||
"MCP_INVALID_INSTALL_COMMAND": (400, "Install command is not valid json"),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/mcps", name="mcp list")
|
||||
async def gets(
|
||||
keyword: str | None = None,
|
||||
category_id: int | None = None,
|
||||
mine: int | None = None,
|
||||
db_session: Session = Depends(session),
|
||||
auth: V1UserAuth = Depends(auth_must),
|
||||
) -> Page[McpOut]:
|
||||
stmt = (
|
||||
select(Mcp)
|
||||
.where(Mcp.no_delete())
|
||||
.options(
|
||||
selectinload(Mcp.category),
|
||||
selectinload(Mcp.envs),
|
||||
with_loader_criteria(McpEnv, col(McpEnv.status) == McpEnvStatus.in_use),
|
||||
)
|
||||
)
|
||||
if keyword:
|
||||
stmt = stmt.where(col(Mcp.key).like(f"%{keyword.lower()}%"))
|
||||
if category_id:
|
||||
stmt = stmt.where(Mcp.category_id == category_id)
|
||||
if mine and auth:
|
||||
stmt = (
|
||||
stmt.join(McpUser)
|
||||
.where(McpUser.user_id == auth.id)
|
||||
.options(
|
||||
selectinload(Mcp.mcp_user),
|
||||
with_loader_criteria(McpUser, col(McpUser.user_id) == auth.id),
|
||||
)
|
||||
)
|
||||
return paginate(db_session, stmt)
|
||||
|
||||
|
||||
@router.get("/mcp", name="mcp detail", response_model=McpOut)
|
||||
async def get(id: int, db_session: Session = Depends(session), auth: V1UserAuth = Depends(auth_must)):
|
||||
stmt = select(Mcp).where(Mcp.no_delete(), Mcp.id == id).options(selectinload(Mcp.category), selectinload(Mcp.envs))
|
||||
model = db_session.exec(stmt).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail=_("Mcp not found"))
|
||||
return model
|
||||
|
||||
|
||||
@router.post("/mcp/install", name="mcp install", dependencies=[install_rate_limiter])
|
||||
async def install(mcp_id: int, db_session: Session = Depends(session), auth: V1UserAuth = Depends(auth_must)):
|
||||
result = McpUserService.install(mcp_id, auth.id, db_session)
|
||||
if not result["success"]:
|
||||
status, msg = _INSTALL_ERROR_MAP.get(result["error_code"], (400, "Install failed"))
|
||||
raise HTTPException(status_code=status, detail=_(msg))
|
||||
return result["mcp_user"]
|
||||
|
||||
|
||||
@router.post("/mcp/import/{mcp_type}", name="mcp import", dependencies=[install_rate_limiter])
|
||||
async def import_mcp(
|
||||
mcp_type: McpImportType, mcp_data: dict, auth: V1UserAuth = Depends(auth_must)
|
||||
):
|
||||
result = McpUserService.import_mcp(mcp_type, mcp_data, auth.id)
|
||||
if not result["success"]:
|
||||
detail = result.get("detail", "Import failed")
|
||||
raise HTTPException(status_code=400, detail=detail)
|
||||
return result
|
||||
|
|
@ -12,19 +12,14 @@
|
|||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
import requests
|
||||
from fastapi import APIRouter, Depends
|
||||
from exa_py import Exa
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from app.component.auth import key_must
|
||||
from app.component.environment import env_not_empty
|
||||
from loguru import logger
|
||||
from app.shared.auth.user_auth import key_must
|
||||
from app.core.environment import env_not_empty
|
||||
from app.model.mcp.proxy import ExaSearch
|
||||
|
||||
logger = logging.getLogger("server_proxy_controller")
|
||||
from typing import Any, cast
|
||||
import requests
|
||||
|
||||
from app.model.user.key import Key
|
||||
|
||||
|
|
@ -33,41 +28,25 @@ router = APIRouter(prefix="/proxy", tags=["Mcp Servers"])
|
|||
|
||||
@router.post("/exa")
|
||||
def exa_search(search: ExaSearch, key: Key = Depends(key_must)):
|
||||
"""Search using Exa API."""
|
||||
EXA_API_KEY = env_not_empty("EXA_API_KEY")
|
||||
try:
|
||||
# Validate input parameters
|
||||
if search.num_results is not None and not 0 < search.num_results <= 100:
|
||||
logger.warning("Invalid exa search parameter", extra={"param": "num_results", "value": search.num_results})
|
||||
raise ValueError("num_results must be between 1 and 100")
|
||||
|
||||
if search.include_text is not None and len(search.include_text) > 0:
|
||||
if len(search.include_text) > 1:
|
||||
logger.warning(
|
||||
"Invalid exa search parameter", extra={"param": "include_text", "reason": "more than 1 string"}
|
||||
)
|
||||
raise ValueError("include_text can only contain 1 string")
|
||||
if len(search.include_text[0].split()) > 5:
|
||||
logger.warning(
|
||||
"Invalid exa search parameter", extra={"param": "include_text", "reason": "exceeds 5 words"}
|
||||
)
|
||||
raise ValueError("include_text string cannot be longer than 5 words")
|
||||
|
||||
if search.exclude_text is not None and len(search.exclude_text) > 0:
|
||||
if len(search.exclude_text) > 1:
|
||||
logger.warning(
|
||||
"Invalid exa search parameter", extra={"param": "exclude_text", "reason": "more than 1 string"}
|
||||
)
|
||||
raise ValueError("exclude_text can only contain 1 string")
|
||||
if len(search.exclude_text[0].split()) > 5:
|
||||
logger.warning(
|
||||
"Invalid exa search parameter", extra={"param": "exclude_text", "reason": "exceeds 5 words"}
|
||||
)
|
||||
raise ValueError("exclude_text string cannot be longer than 5 words")
|
||||
|
||||
exa = Exa(EXA_API_KEY)
|
||||
|
||||
# Call Exa API with direct parameters
|
||||
if search.num_results is not None and not 0 < search.num_results <= 100:
|
||||
raise ValueError("num_results must be between 1 and 100")
|
||||
|
||||
if search.include_text is not None:
|
||||
if len(search.include_text) > 1:
|
||||
raise ValueError("include_text can only contain 1 string")
|
||||
if len(search.include_text[0].split()) > 5:
|
||||
raise ValueError("include_text string cannot be longer than 5 words")
|
||||
|
||||
if search.exclude_text is not None:
|
||||
if len(search.exclude_text) > 1:
|
||||
raise ValueError("exclude_text can only contain 1 string")
|
||||
if len(search.exclude_text[0].split()) > 5:
|
||||
raise ValueError("exclude_text string cannot be longer than 5 words")
|
||||
|
||||
if search.text:
|
||||
results = cast(
|
||||
dict[str, Any],
|
||||
|
|
@ -96,41 +75,23 @@ def exa_search(search: ExaSearch, key: Key = Depends(key_must)):
|
|||
),
|
||||
)
|
||||
|
||||
result_count = len(results.get("results", [])) if "results" in results else 0
|
||||
logger.info(
|
||||
"Exa search completed",
|
||||
extra={"query": search.query, "search_type": search.search_type, "result_count": result_count},
|
||||
)
|
||||
return results
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning("Exa search validation error", extra={"error": str(e)})
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
except Exception as e:
|
||||
logger.error("Exa search failed", extra={"query": search.query, "error": str(e)}, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
return {"error": f"Exa search failed: {e!s}"}
|
||||
|
||||
|
||||
@router.get("/google")
|
||||
def google_search(query: str, search_type: str = "web", key: Key = Depends(key_must)):
|
||||
"""Search using Google Custom Search API."""
|
||||
# https://developers.google.com/custom-search/v1/overview
|
||||
GOOGLE_API_KEY = env_not_empty("GOOGLE_API_KEY")
|
||||
# https://cse.google.com/cse/all
|
||||
SEARCH_ENGINE_ID = env_not_empty("SEARCH_ENGINE_ID")
|
||||
|
||||
# Using the first page
|
||||
start_page_idx = 1
|
||||
# Different language may get different result
|
||||
search_language = "en"
|
||||
# How many pages to return
|
||||
num_result_pages = 10
|
||||
|
||||
# Constructing the URL
|
||||
# Doc: https://developers.google.com/custom-search/v1/using_rest
|
||||
base_url = (
|
||||
f"https://www.googleapis.com/customsearch/v1?"
|
||||
f"key={GOOGLE_API_KEY}&cx={SEARCH_ENGINE_ID}&q={quote_plus(query)}&start="
|
||||
f"key={GOOGLE_API_KEY}&cx={SEARCH_ENGINE_ID}&q={query}&start="
|
||||
f"{start_page_idx}&lr={search_language}&num={num_result_pages}"
|
||||
)
|
||||
|
||||
|
|
@ -140,32 +101,21 @@ def google_search(query: str, search_type: str = "web", key: Key = Depends(key_m
|
|||
url = base_url
|
||||
|
||||
responses = []
|
||||
|
||||
try:
|
||||
# Make the GET request
|
||||
result = requests.get(url)
|
||||
data = result.json()
|
||||
|
||||
# Get the result items
|
||||
if "items" in data:
|
||||
search_items = data.get("items")
|
||||
|
||||
# Iterate over results found
|
||||
for i, search_item in enumerate(search_items, start=1):
|
||||
if search_type == "image":
|
||||
# Process image search results
|
||||
title = search_item.get("title")
|
||||
image_url = search_item.get("link")
|
||||
display_link = search_item.get("displayLink")
|
||||
|
||||
# Get context URL (page containing the image)
|
||||
image_info = search_item.get("image", {})
|
||||
context_url = image_info.get("contextLink", "")
|
||||
|
||||
# Get image dimensions if available
|
||||
width = image_info.get("width")
|
||||
height = image_info.get("height")
|
||||
|
||||
response = {
|
||||
"result_id": i,
|
||||
"title": title,
|
||||
|
|
@ -173,17 +123,12 @@ def google_search(query: str, search_type: str = "web", key: Key = Depends(key_m
|
|||
"display_link": display_link,
|
||||
"context_url": context_url,
|
||||
}
|
||||
|
||||
# Add dimensions if available
|
||||
if width:
|
||||
response["width"] = int(width)
|
||||
if height:
|
||||
response["height"] = int(height)
|
||||
|
||||
responses.append(response)
|
||||
else:
|
||||
# Process web search results
|
||||
# Check metatags are present
|
||||
if "pagemap" not in search_item:
|
||||
continue
|
||||
if "metatags" not in search_item["pagemap"]:
|
||||
|
|
@ -192,12 +137,8 @@ def google_search(query: str, search_type: str = "web", key: Key = Depends(key_m
|
|||
long_description = search_item["pagemap"]["metatags"][0]["og:description"]
|
||||
else:
|
||||
long_description = "N/A"
|
||||
# Get the page title
|
||||
title = search_item.get("title")
|
||||
# Page snippet
|
||||
snippet = search_item.get("snippet")
|
||||
|
||||
# Extract the page url
|
||||
link = search_item.get("link")
|
||||
response = {
|
||||
"result_id": i,
|
||||
|
|
@ -207,20 +148,11 @@ def google_search(query: str, search_type: str = "web", key: Key = Depends(key_m
|
|||
"url": link,
|
||||
}
|
||||
responses.append(response)
|
||||
|
||||
logger.info(
|
||||
"Google search completed",
|
||||
extra={"query": query, "search_type": search_type, "result_count": len(responses)},
|
||||
)
|
||||
else:
|
||||
error_info = data.get("error", {})
|
||||
logger.error("Google search API error", extra={"query": query, "api_error": error_info})
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
logger.error(f"Google search failed - API response: {error_info}")
|
||||
responses.append({"error": f"Google search failed - API response: {error_info}"})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Google search failed", extra={"query": query, "search_type": search_type, "error": str(e)}, exc_info=True
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
responses.append({"error": f"google search failed: {e!s}"})
|
||||
return responses
|
||||
99
server/app/domains/mcp/api/user_controller.py
Normal file
99
server/app/domains/mcp/api/user_controller.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""MCP User controller - H15 ownership on GET/DELETE."""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response
|
||||
from fastapi_babel import _
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from app.core.database import session
|
||||
from app.model.mcp.mcp_user import McpUser, McpUserIn, McpUserOut, McpUserUpdate
|
||||
|
||||
from app.shared.auth import auth_must
|
||||
from app.shared.auth.ownership import require_owner
|
||||
|
||||
router = APIRouter(tags=["V1 McpUser Management"])
|
||||
|
||||
|
||||
@router.get("/mcp/users", name="list mcp users", response_model=List[McpUserOut])
|
||||
async def list_mcp_users(
|
||||
mcp_id: Optional[int] = None,
|
||||
db_session: Session = Depends(session),
|
||||
auth=Depends(auth_must),
|
||||
):
|
||||
query = select(McpUser).where(McpUser.user_id == auth.user.id)
|
||||
if mcp_id is not None:
|
||||
query = query.where(McpUser.mcp_id == mcp_id)
|
||||
return list(db_session.exec(query).all())
|
||||
|
||||
|
||||
@router.get("/mcp/users/{mcp_user_id}", name="get mcp user", response_model=McpUserOut)
|
||||
async def get_mcp_user(
|
||||
mcp_user_id: int,
|
||||
db_session: Session = Depends(session),
|
||||
auth=Depends(auth_must),
|
||||
):
|
||||
mcp_user = db_session.get(McpUser, mcp_user_id)
|
||||
require_owner(mcp_user, auth.user.id)
|
||||
return mcp_user
|
||||
|
||||
|
||||
@router.post("/mcp/users", name="create mcp user", response_model=McpUserOut)
|
||||
async def create_mcp_user(
|
||||
mcp_user: McpUserIn,
|
||||
db_session: Session = Depends(session),
|
||||
auth=Depends(auth_must),
|
||||
):
|
||||
exists = db_session.exec(
|
||||
select(McpUser).where(McpUser.mcp_id == mcp_user.mcp_id, McpUser.user_id == auth.user.id)
|
||||
).first()
|
||||
if exists:
|
||||
raise HTTPException(status_code=400, detail=_("mcp is installed"))
|
||||
db_mcp_user = McpUser(mcp_id=mcp_user.mcp_id, user_id=auth.user.id, env=mcp_user.env)
|
||||
db_session.add(db_mcp_user)
|
||||
db_session.commit()
|
||||
db_session.refresh(db_mcp_user)
|
||||
return db_mcp_user
|
||||
|
||||
|
||||
@router.put("/mcp/users/{id}", name="update mcp user")
|
||||
async def update_mcp_user(
|
||||
id: int,
|
||||
update_item: McpUserUpdate,
|
||||
db_session: Session = Depends(session),
|
||||
auth=Depends(auth_must),
|
||||
):
|
||||
model = db_session.get(McpUser, id)
|
||||
require_owner(model, auth.user.id)
|
||||
update_data = update_item.model_dump(exclude_unset=True)
|
||||
model.update_fields(update_data)
|
||||
model.save(db_session)
|
||||
db_session.refresh(model)
|
||||
return model
|
||||
|
||||
|
||||
@router.delete("/mcp/users/{mcp_user_id}", name="delete mcp user")
|
||||
async def delete_mcp_user(
|
||||
mcp_user_id: int,
|
||||
db_session: Session = Depends(session),
|
||||
auth=Depends(auth_must),
|
||||
):
|
||||
db_mcp_user = db_session.get(McpUser, mcp_user_id)
|
||||
require_owner(db_mcp_user, auth.user.id)
|
||||
db_session.delete(db_mcp_user)
|
||||
db_session.commit()
|
||||
return Response(status_code=204)
|
||||
14
server/app/domains/mcp/schema/__init__.py
Normal file
14
server/app/domains/mcp/schema/__init__.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
17
server/app/domains/mcp/service/__init__.py
Normal file
17
server/app/domains/mcp/service/__init__.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from .mcp_user_service import McpUserService
|
||||
|
||||
__all__ = ["McpUserService"]
|
||||
175
server/app/domains/mcp/service/mcp_user_service.py
Normal file
175
server/app/domains/mcp/service/mcp_user_service.py
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""McpUserService: MCP install/uninstall/import with dedup."""
|
||||
|
||||
import json
|
||||
|
||||
from sqlmodel import select
|
||||
from loguru import logger
|
||||
|
||||
from app.core.database import session_make
|
||||
from app.model.mcp.mcp import Mcp, McpType
|
||||
from app.model.mcp.mcp_user import McpImportType, McpUser, Status
|
||||
from app.core.validator.McpServer import (
|
||||
McpRemoteServer,
|
||||
validate_mcp_remote_servers,
|
||||
validate_mcp_servers,
|
||||
)
|
||||
|
||||
|
||||
class McpUserService:
|
||||
"""MCP user installation management - static methods, self-managed session."""
|
||||
|
||||
@staticmethod
|
||||
def list_for_user(user_id: int, mcp_id: int | None = None) -> list[McpUser]:
|
||||
"""List user's MCP installations, optionally filtered by mcp_id."""
|
||||
with session_make() as s:
|
||||
query = select(McpUser).where(McpUser.user_id == user_id)
|
||||
if mcp_id is not None:
|
||||
query = query.where(McpUser.mcp_id == mcp_id)
|
||||
return list(s.exec(query).all())
|
||||
|
||||
@staticmethod
|
||||
def get(mcp_user_id: int, user_id: int) -> McpUser | None:
|
||||
"""Get a single MCP user installation, scoped to user."""
|
||||
with session_make() as s:
|
||||
model = s.get(McpUser, mcp_user_id)
|
||||
if model and model.user_id == user_id:
|
||||
return model
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def install(mcp_id: int, user_id: int, s) -> dict:
|
||||
"""Install MCP from store: dedup check, parse install_command, create McpUser.
|
||||
Returns {"success": True, "mcp_user": McpUser} or {"success": False, "error_code": str}.
|
||||
"""
|
||||
mcp = s.get(Mcp, mcp_id)
|
||||
if not mcp:
|
||||
return {"success": False, "error_code": "MCP_NOT_FOUND"}
|
||||
|
||||
exists = s.exec(
|
||||
select(McpUser).where(McpUser.mcp_id == mcp.id, McpUser.user_id == user_id)
|
||||
).first()
|
||||
if exists:
|
||||
return {"success": False, "error_code": "MCP_ALREADY_INSTALLED"}
|
||||
|
||||
install_command_raw = mcp.install_command or "{}"
|
||||
try:
|
||||
install_command = (
|
||||
json.loads(install_command_raw) if isinstance(install_command_raw, str) else install_command_raw or {}
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
return {"success": False, "error_code": "MCP_INVALID_INSTALL_COMMAND"}
|
||||
if not isinstance(install_command, dict) or "command" not in install_command:
|
||||
return {"success": False, "error_code": "MCP_INVALID_INSTALL_COMMAND"}
|
||||
|
||||
mcp_user = McpUser(
|
||||
mcp_id=mcp.id,
|
||||
user_id=user_id,
|
||||
mcp_name=mcp.name,
|
||||
mcp_key=mcp.key,
|
||||
mcp_desc=mcp.description,
|
||||
type=mcp.type,
|
||||
status=Status.enable,
|
||||
command=install_command["command"],
|
||||
args=install_command.get("args"),
|
||||
env=install_command.get("env"),
|
||||
server_url=None,
|
||||
)
|
||||
mcp_user.save()
|
||||
return {"success": True, "mcp_user": mcp_user}
|
||||
|
||||
@staticmethod
|
||||
def import_mcp(mcp_type: McpImportType, mcp_data: dict, user_id: int) -> dict:
|
||||
"""Import MCP servers from user config. Returns result dict with imported/failed lists."""
|
||||
if mcp_type == McpImportType.Local:
|
||||
is_valid, res = validate_mcp_servers(mcp_data)
|
||||
if not is_valid:
|
||||
return {"success": False, "error_code": "MCP_INVALID_DATA", "detail": res}
|
||||
|
||||
mcp_data_parsed = getattr(res, "mcpServers", mcp_data)
|
||||
imported_names = []
|
||||
failed_names = []
|
||||
for name, data in mcp_data_parsed.items():
|
||||
command = data.command if hasattr(data, "command") else data.get("command")
|
||||
args = data.args if hasattr(data, "args") else data.get("args")
|
||||
env = data.env if hasattr(data, "env") else data.get("env")
|
||||
try:
|
||||
mcp_user = McpUser(
|
||||
mcp_id=0,
|
||||
user_id=user_id,
|
||||
mcp_name=name,
|
||||
mcp_key=name,
|
||||
mcp_desc=name,
|
||||
type=McpType.Local,
|
||||
status=Status.enable,
|
||||
command=command,
|
||||
args=args,
|
||||
env=env,
|
||||
server_url=None,
|
||||
)
|
||||
mcp_user.save()
|
||||
imported_names.append(name)
|
||||
except Exception as e:
|
||||
logger.error("Failed to import local MCP", extra={"mcp_name": name, "error": str(e)})
|
||||
failed_names.append(name)
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Local MCP servers imported successfully",
|
||||
"count": len(imported_names),
|
||||
"imported": imported_names,
|
||||
"failed": failed_names,
|
||||
}
|
||||
|
||||
elif mcp_type == McpImportType.Remote:
|
||||
is_valid, res = validate_mcp_remote_servers(mcp_data)
|
||||
if not is_valid:
|
||||
return {"success": False, "error_code": "MCP_INVALID_DATA", "detail": res}
|
||||
data: McpRemoteServer = res
|
||||
mcp_user = McpUser(
|
||||
mcp_id=0,
|
||||
user_id=user_id,
|
||||
type=McpType.Remote,
|
||||
status=Status.enable,
|
||||
mcp_name=data.server_name,
|
||||
server_url=data.server_url,
|
||||
)
|
||||
mcp_user.save()
|
||||
return {"success": True, "mcp_user": mcp_user}
|
||||
|
||||
return {"success": False, "error_code": "MCP_INVALID_TYPE"}
|
||||
|
||||
@staticmethod
|
||||
def update(mcp_user_id: int, user_id: int, data: dict) -> dict:
|
||||
"""Update MCP user: ownership check."""
|
||||
with session_make() as s:
|
||||
model = s.get(McpUser, mcp_user_id)
|
||||
if not model or model.user_id != user_id:
|
||||
return {"success": False, "error_code": "MCP_USER_NOT_FOUND"}
|
||||
model.update_fields(data)
|
||||
model.save(s)
|
||||
s.refresh(model)
|
||||
return {"success": True, "mcp_user": model}
|
||||
|
||||
@staticmethod
|
||||
def uninstall(mcp_user_id: int, user_id: int) -> bool:
|
||||
"""Uninstall MCP: ownership check + delete."""
|
||||
with session_make() as s:
|
||||
model = s.get(McpUser, mcp_user_id)
|
||||
if not model or model.user_id != user_id:
|
||||
return False
|
||||
s.delete(model)
|
||||
s.commit()
|
||||
return True
|
||||
15
server/app/domains/model_provider/__init__.py
Normal file
15
server/app/domains/model_provider/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""Model Provider domain - API key / model supplier management."""
|
||||
14
server/app/domains/model_provider/api/__init__.py
Normal file
14
server/app/domains/model_provider/api/__init__.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
91
server/app/domains/model_provider/api/provider_controller.py
Normal file
91
server/app/domains/model_provider/api/provider_controller.py
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
# STATUS: full-rewrite (uses ProviderService, self-managed session)
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response
|
||||
from fastapi_babel import _
|
||||
from fastapi_pagination import Page
|
||||
from fastapi_pagination.ext.sqlmodel import paginate
|
||||
from sqlmodel import Session, select, col
|
||||
|
||||
from app.core.database import session
|
||||
from app.model.provider.provider import Provider, ProviderIn, ProviderOut, ProviderPreferIn
|
||||
from app.shared.auth import auth_must
|
||||
from app.shared.auth.user_auth import V1UserAuth
|
||||
from app.domains.model_provider.service.provider_service import ProviderService
|
||||
|
||||
router = APIRouter(tags=["Provider Management"])
|
||||
|
||||
|
||||
@router.get("/providers", name="list providers", response_model=Page[ProviderOut])
|
||||
async def gets(
|
||||
keyword: str | None = None,
|
||||
prefer: Optional[bool] = Query(None, description="Filter by prefer status"),
|
||||
db_session: Session = Depends(session),
|
||||
auth: V1UserAuth = Depends(auth_must),
|
||||
) -> Page[ProviderOut]:
|
||||
# Pagination still needs session for paginate() — use session directly for this read-only query
|
||||
stmt = select(Provider).where(Provider.user_id == auth.id, Provider.no_delete())
|
||||
if keyword:
|
||||
stmt = stmt.where(col(Provider.provider_name).like(f"%{keyword}%"))
|
||||
if prefer is not None:
|
||||
stmt = stmt.where(Provider.prefer == prefer)
|
||||
stmt = stmt.order_by(col(Provider.created_at).desc(), col(Provider.id).desc())
|
||||
return paginate(db_session, stmt)
|
||||
|
||||
|
||||
@router.get("/provider", name="get provider detail", response_model=ProviderOut)
|
||||
async def get(id: int, auth: V1UserAuth = Depends(auth_must)):
|
||||
model = ProviderService.get(id, auth.id)
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail=_("Provider not found"))
|
||||
return model
|
||||
|
||||
|
||||
@router.post("/provider", name="create provider", response_model=ProviderOut)
|
||||
async def post(data: ProviderIn, auth: V1UserAuth = Depends(auth_must)):
|
||||
result = ProviderService.create(auth.id, data.model_dump())
|
||||
return result["provider"]
|
||||
|
||||
|
||||
@router.put("/provider/{id}", name="update provider", response_model=ProviderOut)
|
||||
async def put(id: int, data: ProviderIn, auth: V1UserAuth = Depends(auth_must)):
|
||||
result = ProviderService.update(id, auth.id, data.model_dump())
|
||||
if not result["success"]:
|
||||
raise HTTPException(status_code=404, detail=_("Provider not found"))
|
||||
return result["provider"]
|
||||
|
||||
|
||||
@router.delete("/provider/{id}", name="delete provider")
|
||||
async def delete(id: int, auth: V1UserAuth = Depends(auth_must)):
|
||||
if not ProviderService.delete(id, auth.id):
|
||||
raise HTTPException(status_code=404, detail=_("Provider not found"))
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.patch("/provider/{id}/invalidate", name="invalidate provider")
|
||||
async def invalidate(id: int, auth: V1UserAuth = Depends(auth_must)):
|
||||
result = ProviderService.invalidate(id, auth.id)
|
||||
if not result["success"]:
|
||||
raise HTTPException(status_code=404, detail=_("Provider not found"))
|
||||
return {"success": True}
|
||||
|
||||
|
||||
@router.post("/provider/prefer", name="set provider prefer")
|
||||
async def set_prefer(data: ProviderPreferIn, auth: V1UserAuth = Depends(auth_must)):
|
||||
result = ProviderService.set_prefer(data.provider_id, auth.id)
|
||||
if not result["success"]:
|
||||
raise HTTPException(status_code=500, detail="Failed to set prefer")
|
||||
return {"success": True}
|
||||
14
server/app/domains/model_provider/schema/__init__.py
Normal file
14
server/app/domains/model_provider/schema/__init__.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
17
server/app/domains/model_provider/service/__init__.py
Normal file
17
server/app/domains/model_provider/service/__init__.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from .provider_service import ProviderService
|
||||
|
||||
__all__ = ["ProviderService"]
|
||||
121
server/app/domains/model_provider/service/provider_service.py
Normal file
121
server/app/domains/model_provider/service/provider_service.py
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""ProviderService: model provider CRUD with prefer/invalidate. Follows CreditsService pattern."""
|
||||
|
||||
from sqlalchemy import update
|
||||
from sqlmodel import select, col
|
||||
from loguru import logger
|
||||
|
||||
from app.core.database import session_make
|
||||
from app.model.provider.provider import Provider, VaildStatus
|
||||
|
||||
|
||||
class ProviderService:
|
||||
"""Model provider management - static methods, self-managed session."""
|
||||
|
||||
@staticmethod
|
||||
def list_for_user(user_id: int, keyword: str | None = None, prefer: bool | None = None) -> list[Provider]:
|
||||
"""List user providers, supports keyword search and prefer filter."""
|
||||
with session_make() as s:
|
||||
stmt = select(Provider).where(Provider.user_id == user_id, Provider.no_delete())
|
||||
if keyword:
|
||||
safe_keyword = keyword.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
stmt = stmt.where(col(Provider.provider_name).like(f"%{safe_keyword}%"))
|
||||
if prefer is not None:
|
||||
stmt = stmt.where(Provider.prefer == prefer)
|
||||
stmt = stmt.order_by(col(Provider.created_at).desc(), col(Provider.id).desc())
|
||||
return list(s.exec(stmt).all())
|
||||
|
||||
@staticmethod
|
||||
def get(provider_id: int, user_id: int) -> Provider | None:
|
||||
"""Get a single provider by id, scoped to user."""
|
||||
with session_make() as s:
|
||||
model = s.exec(
|
||||
select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == provider_id)
|
||||
).first()
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def create(user_id: int, data: dict) -> dict:
|
||||
"""Create provider. Returns {"success": True, "provider": Provider}."""
|
||||
with session_make() as s:
|
||||
model = Provider(**data, user_id=user_id)
|
||||
model.save(s)
|
||||
s.refresh(model)
|
||||
return {"success": True, "provider": model}
|
||||
|
||||
@staticmethod
|
||||
def update(provider_id: int, user_id: int, data: dict) -> dict:
|
||||
"""Update provider: ownership check."""
|
||||
with session_make() as s:
|
||||
model = s.exec(
|
||||
select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == provider_id)
|
||||
).first()
|
||||
if not model:
|
||||
return {"success": False, "error_code": "PROVIDER_NOT_FOUND"}
|
||||
# H10: only allow updating safe fields
|
||||
_UPDATABLE_FIELDS = {"provider_name", "api_key", "api_base", "extra_config", "prefer", "is_vaild"}
|
||||
for key, value in data.items():
|
||||
if key in _UPDATABLE_FIELDS:
|
||||
setattr(model, key, value)
|
||||
model.save(s)
|
||||
s.refresh(model)
|
||||
return {"success": True, "provider": model}
|
||||
|
||||
@staticmethod
|
||||
def delete(provider_id: int, user_id: int) -> bool:
|
||||
"""Soft-delete provider."""
|
||||
with session_make() as s:
|
||||
model = s.exec(
|
||||
select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == provider_id)
|
||||
).first()
|
||||
if not model:
|
||||
return False
|
||||
model.delete(s)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def invalidate(provider_id: int, user_id: int) -> dict:
|
||||
"""Mark provider as not_valid."""
|
||||
with session_make() as s:
|
||||
model = s.exec(
|
||||
select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == provider_id)
|
||||
).first()
|
||||
if not model:
|
||||
return {"success": False, "error_code": "PROVIDER_NOT_FOUND"}
|
||||
model.is_vaild = VaildStatus.not_valid
|
||||
model.save(s)
|
||||
s.refresh(model)
|
||||
logger.info("Provider invalidated", extra={"user_id": user_id, "provider_id": provider_id})
|
||||
return {"success": True, "provider": model}
|
||||
|
||||
@staticmethod
|
||||
def set_prefer(provider_id: int, user_id: int) -> dict:
|
||||
"""Set preferred provider: lock user rows, clear all prefer flags, then set the specified one."""
|
||||
with session_make() as s:
|
||||
# Lock all provider rows for this user to prevent concurrent prefer changes
|
||||
s.exec(
|
||||
select(Provider).where(Provider.user_id == user_id, Provider.no_delete()).with_for_update()
|
||||
).all()
|
||||
# Clear all prefer flags for this user
|
||||
s.exec(update(Provider).where(Provider.user_id == user_id, Provider.no_delete()).values(prefer=False))
|
||||
# Set the specified provider as preferred
|
||||
s.exec(
|
||||
update(Provider)
|
||||
.where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == provider_id)
|
||||
.values(prefer=True)
|
||||
)
|
||||
s.commit()
|
||||
return {"success": True}
|
||||
14
server/app/domains/oauth/__init__.py
Normal file
14
server/app/domains/oauth/__init__.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
14
server/app/domains/oauth/api/__init__.py
Normal file
14
server/app/domains/oauth/api/__init__.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
97
server/app/domains/oauth/api/oauth_controller.py
Normal file
97
server/app/domains/oauth/api/oauth_controller.py
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
# STATUS: full-rewrite (uses OAuthService, H12 code/state XSS validation)
|
||||
import json
|
||||
from typing import Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
|
||||
|
||||
from app.core.oauth_adapter import OauthCallbackPayload
|
||||
from app.domains.oauth.schema import OAuthAuthorizeReq, OAuthTokenReq
|
||||
from app.domains.oauth.service.oauth_service import OAuthService
|
||||
|
||||
router = APIRouter(prefix="/oauth", tags=["Oauth Servers"])
|
||||
|
||||
_OAUTH_CODE_MAX_LEN = 2048
|
||||
_OAUTH_STATE_MAX_LEN = 512
|
||||
_OAUTH_SAFE_CHARS = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.")
|
||||
|
||||
|
||||
def _validate_oauth_param(value: Optional[str], name: str, max_len: int) -> str:
|
||||
"""Validate and sanitize OAuth callback params to prevent XSS injection (H12)."""
|
||||
if value is None:
|
||||
return ""
|
||||
s = str(value).strip()
|
||||
if len(s) > max_len:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid {name}: too long")
|
||||
if s and not all(c in _OAUTH_SAFE_CHARS or c in " /+=" for c in s):
|
||||
raise HTTPException(status_code=400, detail=f"Invalid {name}: invalid characters")
|
||||
return s
|
||||
|
||||
|
||||
@router.get("/{app}/login", name="OAuth Login Redirect")
|
||||
def oauth_login(app: str, request: Request, state: Optional[str] = None):
|
||||
callback_url = str(request.url_for("OAuth Callback", app=app))
|
||||
if callback_url.startswith("http://"):
|
||||
callback_url = "https://" + callback_url[len("http://"):]
|
||||
|
||||
result = OAuthService.get_authorize_url(
|
||||
OAuthAuthorizeReq(provider=app, redirect_uri=callback_url, state=state)
|
||||
)
|
||||
if not result.success:
|
||||
raise HTTPException(status_code=400, detail="Failed to generate authorization URL")
|
||||
return RedirectResponse(str(result.authorize_url))
|
||||
|
||||
|
||||
@router.get("/{app}/callback", name="OAuth Callback")
|
||||
def oauth_callback(app: str, request: Request, code: Optional[str] = None, state: Optional[str] = None):
|
||||
if not code:
|
||||
raise HTTPException(status_code=400, detail="Missing code parameter")
|
||||
safe_code = _validate_oauth_param(code, "code", _OAUTH_CODE_MAX_LEN)
|
||||
safe_state = _validate_oauth_param(state, "state", _OAUTH_STATE_MAX_LEN)
|
||||
safe_app = _validate_oauth_param(app, "provider", 64) or ""
|
||||
query = f"provider={quote(safe_app, safe='')}&code={quote(safe_code, safe='')}&state={quote(safe_state, safe='')}"
|
||||
redirect_url = f"eigent://callback/oauth?{query}"
|
||||
html_content = f"""
|
||||
<html>
|
||||
<head>
|
||||
<title>OAuth Callback</title>
|
||||
</head>
|
||||
<body>
|
||||
<script type='text/javascript'>
|
||||
window.location.href = {json.dumps(redirect_url)};
|
||||
</script>
|
||||
<p>Redirecting, please wait...</p>
|
||||
<button onclick='window.close()'>Close this window</button>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
return HTMLResponse(content=html_content)
|
||||
|
||||
|
||||
@router.post("/{app}/token", name="OAuth Fetch Token")
|
||||
def fetch_token(app: str, request: Request, data: OauthCallbackPayload):
|
||||
callback_url = str(request.url_for("OAuth Callback", app=app))
|
||||
if callback_url.startswith("http://"):
|
||||
callback_url = "https://" + callback_url[len("http://"):]
|
||||
|
||||
result = OAuthService.exchange_token(
|
||||
OAuthTokenReq(provider=app, code=data.code, redirect_uri=callback_url)
|
||||
)
|
||||
if not result.success:
|
||||
raise HTTPException(status_code=500, detail="Token exchange failed")
|
||||
return JSONResponse(result.token_data)
|
||||
27
server/app/domains/oauth/schema/__init__.py
Normal file
27
server/app/domains/oauth/schema/__init__.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""OAuth domain models - re-exports v1 schemas."""
|
||||
|
||||
from app.domains.oauth.schema.schemas import (
|
||||
OAuthAuthorizeReq,
|
||||
OAuthTokenReq,
|
||||
OAuthResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OAuthAuthorizeReq",
|
||||
"OAuthTokenReq",
|
||||
"OAuthResult",
|
||||
]
|
||||
36
server/app/domains/oauth/schema/schemas.py
Normal file
36
server/app/domains/oauth/schema/schemas.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""v1 OAuth request/response schemas."""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class OAuthAuthorizeReq(BaseModel):
|
||||
provider: str
|
||||
redirect_uri: str | None = None
|
||||
state: str | None = None
|
||||
|
||||
|
||||
class OAuthTokenReq(BaseModel):
|
||||
provider: str
|
||||
code: str
|
||||
redirect_uri: str | None = None
|
||||
|
||||
|
||||
class OAuthResult(BaseModel):
|
||||
success: bool
|
||||
authorize_url: str | None = None
|
||||
token_data: dict | None = None
|
||||
error_code: str | None = None
|
||||
17
server/app/domains/oauth/service/__init__.py
Normal file
17
server/app/domains/oauth/service/__init__.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from .oauth_service import OAuthService
|
||||
|
||||
__all__ = ["OAuthService"]
|
||||
55
server/app/domains/oauth/service/oauth_service.py
Normal file
55
server/app/domains/oauth/service/oauth_service.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""OAuthService: adapter factory wrapper for unified OAuth flow. Follows CreditsService pattern."""
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from app.core.oauth_adapter import get_oauth_adapter, OAUTH_ADAPTERS
|
||||
from app.domains.oauth.schema import OAuthAuthorizeReq, OAuthTokenReq, OAuthResult
|
||||
|
||||
|
||||
class OAuthService:
|
||||
"""OAuth operations - static methods, wraps adapter factory."""
|
||||
|
||||
@staticmethod
|
||||
def get_authorize_url(req: OAuthAuthorizeReq) -> OAuthResult:
|
||||
"""Get OAuth authorization URL via adapter factory."""
|
||||
try:
|
||||
adapter = get_oauth_adapter(req.provider, req.redirect_uri)
|
||||
url = adapter.get_authorize_url(req.state)
|
||||
if not url:
|
||||
return OAuthResult(success=False, error_code="OAUTH_URL_FAILED")
|
||||
return OAuthResult(success=True, authorize_url=url)
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth authorize failed provider={req.provider}: {e}")
|
||||
return OAuthResult(success=False, error_code="OAUTH_PROVIDER_ERROR")
|
||||
|
||||
@staticmethod
|
||||
def exchange_token(req: OAuthTokenReq) -> OAuthResult:
|
||||
"""Exchange authorization code for token via adapter factory."""
|
||||
try:
|
||||
adapter = get_oauth_adapter(req.provider, req.redirect_uri)
|
||||
token_data = adapter.fetch_token(req.code)
|
||||
if not token_data:
|
||||
return OAuthResult(success=False, error_code="OAUTH_TOKEN_FAILED")
|
||||
return OAuthResult(success=True, token_data=token_data)
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth token exchange failed provider={req.provider}: {e}")
|
||||
return OAuthResult(success=False, error_code="OAUTH_PROVIDER_ERROR")
|
||||
|
||||
@staticmethod
|
||||
def list_providers() -> list[str]:
|
||||
"""Return list of supported OAuth providers."""
|
||||
return list(set(OAUTH_ADAPTERS.keys()))
|
||||
15
server/app/domains/trigger/__init__.py
Normal file
15
server/app/domains/trigger/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""Trigger domain: triggers, executions, webhooks, schedules, slack."""
|
||||
14
server/app/domains/trigger/api/__init__.py
Normal file
14
server/app/domains/trigger/api/__init__.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
57
server/app/domains/trigger/api/slack_controller.py
Normal file
57
server/app/domains/trigger/api/slack_controller.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""Slack integration controller. Uses SlackService."""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.shared.auth import auth_must
|
||||
from app.shared.auth.user_auth import V1UserAuth
|
||||
from app.domains.trigger.service.slack_service import SlackService
|
||||
|
||||
|
||||
class SlackChannelOut(BaseModel):
|
||||
"""Output model for Slack channels."""
|
||||
id: str
|
||||
name: str
|
||||
is_private: bool = False
|
||||
is_member: bool = False
|
||||
num_members: Optional[int] = None
|
||||
|
||||
|
||||
class SlackChannelsResponse(BaseModel):
|
||||
"""Response model for Slack channels list."""
|
||||
channels: List[SlackChannelOut]
|
||||
has_credentials: bool
|
||||
|
||||
|
||||
router = APIRouter(prefix="/trigger/slack", tags=["Slack Integration"])
|
||||
|
||||
|
||||
@router.get("/channels", name="get slack channels")
|
||||
def get_slack_channels(auth: V1UserAuth = Depends(auth_must)) -> SlackChannelsResponse:
|
||||
result = SlackService.get_channels(auth.id)
|
||||
if not result["success"]:
|
||||
error_map = {
|
||||
"SLACK_SDK_NOT_INSTALLED": (500, "Slack SDK not installed on server"),
|
||||
"SLACK_API_ERROR": (400, f"Slack API error: {result.get('detail', 'Unknown error')}"),
|
||||
}
|
||||
status, msg = error_map.get(result["error_code"], (500, "Failed to fetch Slack channels"))
|
||||
raise HTTPException(status_code=status, detail=msg)
|
||||
return SlackChannelsResponse(
|
||||
channels=[SlackChannelOut(**ch) for ch in result["channels"]],
|
||||
has_credentials=result["has_credentials"],
|
||||
)
|
||||
225
server/app/domains/trigger/api/trigger_controller.py
Normal file
225
server/app/domains/trigger/api/trigger_controller.py
Normal file
|
|
@ -0,0 +1,225 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""Trigger controller. Uses TriggerCrudService for business logic."""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, Query
|
||||
from fastapi_pagination import Page
|
||||
from fastapi_pagination.ext.sqlmodel import paginate
|
||||
from sqlmodel import Session, select, desc, and_, delete
|
||||
from typing import Optional
|
||||
from loguru import logger
|
||||
|
||||
from app.model.trigger.trigger import Trigger, TriggerIn, TriggerOut, TriggerUpdate, TriggerConfigSchemaOut
|
||||
from app.model.trigger.trigger_execution import TriggerExecution, TriggerExecutionOut
|
||||
from app.model.trigger.app_configs import get_config_schema, has_config
|
||||
from app.shared.types.trigger_types import TriggerType, TriggerStatus
|
||||
from app.shared.auth import auth_must
|
||||
from app.shared.auth.user_auth import V1UserAuth
|
||||
from app.core.database import session
|
||||
from app.domains.trigger.service.trigger_crud_service import TriggerCrudService
|
||||
|
||||
router = APIRouter(prefix="/trigger", tags=["Triggers"])
|
||||
|
||||
|
||||
def _raise_on_error(result: dict) -> None:
|
||||
"""Convert service error dict to HTTPException."""
|
||||
if result["success"]:
|
||||
return
|
||||
status_code = result.get("status_code", 500)
|
||||
error = result.get("error", "Internal server error")
|
||||
raise HTTPException(status_code=status_code, detail=error)
|
||||
|
||||
|
||||
@router.post("/", name="create trigger", response_model=TriggerOut)
|
||||
def create_trigger(
|
||||
data: TriggerIn,
|
||||
db_session: Session = Depends(session),
|
||||
auth: V1UserAuth = Depends(auth_must),
|
||||
):
|
||||
"""Create a new trigger."""
|
||||
try:
|
||||
result = TriggerCrudService.create(data, auth.id, db_session)
|
||||
_raise_on_error(result)
|
||||
return result["trigger_out"]
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
logger.error("Trigger creation failed", extra={"user_id": auth.id, "error": str(e)}, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/", name="list triggers")
|
||||
def list_triggers(
|
||||
trigger_type: Optional[TriggerType] = Query(None, description="Filter by trigger type"),
|
||||
status: Optional[TriggerStatus] = Query(None, description="Filter by status"),
|
||||
project_id: Optional[str] = Query(None, description="Filter by project ID"),
|
||||
db_session: Session = Depends(session),
|
||||
auth: V1UserAuth = Depends(auth_must),
|
||||
) -> Page[TriggerOut]:
|
||||
"""List triggers for current user."""
|
||||
user_id = auth.id
|
||||
conditions = [Trigger.user_id == str(user_id)]
|
||||
if trigger_type:
|
||||
conditions.append(Trigger.trigger_type == trigger_type)
|
||||
if status is not None:
|
||||
conditions.append(Trigger.status == status)
|
||||
if project_id:
|
||||
conditions.append(Trigger.project_id == project_id)
|
||||
|
||||
stmt = select(Trigger).where(and_(*conditions)).order_by(desc(Trigger.created_at))
|
||||
result = paginate(db_session, stmt)
|
||||
|
||||
# Enrich with execution counts
|
||||
trigger_ids = [t.id for t in result.items]
|
||||
counts = TriggerCrudService.get_execution_counts(db_session, trigger_ids)
|
||||
result.items = [TriggerCrudService.trigger_to_out(t, counts.get(t.id, 0)) for t in result.items]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/{trigger_id}", name="get trigger", response_model=TriggerOut)
|
||||
def get_trigger(
|
||||
trigger_id: int,
|
||||
db_session: Session = Depends(session),
|
||||
auth: V1UserAuth = Depends(auth_must),
|
||||
):
|
||||
"""Get a specific trigger by ID."""
|
||||
trigger = db_session.exec(
|
||||
select(Trigger).where(and_(Trigger.id == trigger_id, Trigger.user_id == str(auth.id)))
|
||||
).first()
|
||||
if not trigger:
|
||||
raise HTTPException(status_code=404, detail="Trigger not found")
|
||||
|
||||
counts = TriggerCrudService.get_execution_counts(db_session, [trigger_id])
|
||||
return TriggerCrudService.trigger_to_out(trigger, counts.get(trigger_id, 0))
|
||||
|
||||
|
||||
@router.put("/{trigger_id}", name="update trigger", response_model=TriggerOut)
|
||||
def update_trigger(
|
||||
trigger_id: int,
|
||||
data: TriggerUpdate,
|
||||
db_session: Session = Depends(session),
|
||||
auth: V1UserAuth = Depends(auth_must),
|
||||
):
|
||||
"""Update a trigger."""
|
||||
try:
|
||||
result = TriggerCrudService.update(trigger_id, data, auth.id, db_session)
|
||||
_raise_on_error(result)
|
||||
return result["trigger_out"]
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
logger.error("Trigger update failed", extra={"user_id": auth.id, "trigger_id": trigger_id, "error": str(e)}, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.delete("/{trigger_id}", name="delete trigger")
|
||||
def delete_trigger(
|
||||
trigger_id: int,
|
||||
db_session: Session = Depends(session),
|
||||
auth: V1UserAuth = Depends(auth_must),
|
||||
):
|
||||
"""Delete a trigger and its executions."""
|
||||
trigger = db_session.exec(
|
||||
select(Trigger).where(and_(Trigger.id == trigger_id, Trigger.user_id == str(auth.id)))
|
||||
).first()
|
||||
if not trigger:
|
||||
raise HTTPException(status_code=404, detail="Trigger not found")
|
||||
|
||||
try:
|
||||
db_session.exec(delete(TriggerExecution).where(TriggerExecution.trigger_id == trigger_id))
|
||||
db_session.delete(trigger)
|
||||
db_session.commit()
|
||||
logger.info("Trigger deleted", extra={"user_id": auth.id, "trigger_id": trigger_id})
|
||||
return Response(status_code=204)
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
logger.error("Trigger deletion failed", extra={"user_id": auth.id, "trigger_id": trigger_id, "error": str(e)}, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.post("/{trigger_id}/activate", name="activate trigger", response_model=TriggerOut)
|
||||
def activate_trigger(
|
||||
trigger_id: int,
|
||||
db_session: Session = Depends(session),
|
||||
auth: V1UserAuth = Depends(auth_must),
|
||||
):
|
||||
"""Activate a trigger."""
|
||||
try:
|
||||
result = TriggerCrudService.activate(trigger_id, auth.id, db_session)
|
||||
_raise_on_error(result)
|
||||
return result["trigger_out"]
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
logger.error("Trigger activation failed", extra={"user_id": auth.id, "trigger_id": trigger_id, "error": str(e)}, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.post("/{trigger_id}/deactivate", name="deactivate trigger", response_model=TriggerOut)
|
||||
def deactivate_trigger(
|
||||
trigger_id: int,
|
||||
db_session: Session = Depends(session),
|
||||
auth: V1UserAuth = Depends(auth_must),
|
||||
):
|
||||
"""Deactivate a trigger."""
|
||||
try:
|
||||
result = TriggerCrudService.deactivate(trigger_id, auth.id, db_session)
|
||||
_raise_on_error(result)
|
||||
return result["trigger_out"]
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
logger.error("Trigger deactivation failed", extra={"user_id": auth.id, "trigger_id": trigger_id, "error": str(e)}, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/{trigger_id}/executions", name="list trigger executions")
|
||||
def list_trigger_executions(
|
||||
trigger_id: int,
|
||||
db_session: Session = Depends(session),
|
||||
auth: V1UserAuth = Depends(auth_must),
|
||||
) -> Page[TriggerExecutionOut]:
|
||||
"""List executions for a specific trigger."""
|
||||
trigger = db_session.exec(
|
||||
select(Trigger).where(and_(Trigger.id == trigger_id, Trigger.user_id == str(auth.id)))
|
||||
).first()
|
||||
if not trigger:
|
||||
raise HTTPException(status_code=404, detail="Trigger not found")
|
||||
|
||||
stmt = (
|
||||
select(TriggerExecution)
|
||||
.where(TriggerExecution.trigger_id == trigger_id)
|
||||
.order_by(desc(TriggerExecution.created_at))
|
||||
)
|
||||
return paginate(db_session, stmt)
|
||||
|
||||
|
||||
@router.get("/{trigger_type}/config", name="get trigger type config schema")
|
||||
def get_trigger_type_config(
|
||||
trigger_type: TriggerType,
|
||||
auth: V1UserAuth = Depends(auth_must),
|
||||
) -> TriggerConfigSchemaOut:
|
||||
"""Get the configuration schema for a specific trigger type."""
|
||||
schema = get_config_schema(trigger_type)
|
||||
return TriggerConfigSchemaOut(
|
||||
trigger_type=trigger_type.value,
|
||||
has_config=has_config(trigger_type),
|
||||
schema_=schema,
|
||||
)
|
||||
|
|
@ -12,111 +12,62 @@
|
|||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""Trigger Execution controller. Uses TriggerCrudService for REST, WebSocket handled locally."""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, WebSocket, WebSocketDisconnect
|
||||
from fastapi_pagination import Page
|
||||
from fastapi_pagination.ext.sqlmodel import paginate
|
||||
from sqlmodel import Session, select, desc, and_
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
import logging
|
||||
from loguru import logger
|
||||
import asyncio
|
||||
|
||||
from app.model.trigger.trigger_execution import (
|
||||
TriggerExecution,
|
||||
TriggerExecutionIn,
|
||||
TriggerExecutionOut,
|
||||
TriggerExecutionUpdate
|
||||
TriggerExecution,
|
||||
TriggerExecutionIn,
|
||||
TriggerExecutionOut,
|
||||
TriggerExecutionUpdate,
|
||||
)
|
||||
from app.model.trigger.trigger import Trigger
|
||||
from app.model.user.user import User
|
||||
from app.type.trigger_types import ExecutionStatus, ExecutionType
|
||||
from app.component.auth import Auth, auth_must
|
||||
from app.component.database import session
|
||||
from app.component.redis_utils import get_redis_manager
|
||||
from app.service.trigger.trigger_service import TriggerService
|
||||
|
||||
logger = logging.getLogger("server_trigger_execution_controller")
|
||||
from app.shared.types.trigger_types import ExecutionStatus, ExecutionType
|
||||
from app.shared.auth import auth_must
|
||||
from app.shared.auth.user_auth import V1UserAuth
|
||||
from app.core.database import session
|
||||
from app.core.redis_utils import get_redis_manager
|
||||
from app.domains.trigger.service.trigger_crud_service import TriggerCrudService
|
||||
|
||||
# Store active WebSocket connections per session (WebSocket objects only, metadata in Redis)
|
||||
# Format: {session_id: WebSocket}
|
||||
# This is per-worker, and Redis pub/sub is used to broadcast across workers
|
||||
active_websockets: Dict[str, WebSocket] = {}
|
||||
|
||||
# Background task for Redis pub/sub
|
||||
_pubsub_task = None
|
||||
|
||||
router = APIRouter(prefix="/execution", tags=["Trigger Executions"])
|
||||
|
||||
|
||||
def _raise_on_error(result: dict) -> None:
|
||||
"""Convert service error dict to HTTPException."""
|
||||
if result["success"]:
|
||||
return
|
||||
raise HTTPException(status_code=result.get("status_code", 500), detail=result.get("error", "Internal server error"))
|
||||
|
||||
|
||||
@router.post("/", name="create trigger execution", response_model=TriggerExecutionOut)
|
||||
async def create_trigger_execution(
|
||||
data: TriggerExecutionIn,
|
||||
session: Session = Depends(session),
|
||||
auth: Auth = Depends(auth_must)
|
||||
db_session: Session = Depends(session),
|
||||
auth: V1UserAuth = Depends(auth_must),
|
||||
):
|
||||
"""Create a new trigger execution."""
|
||||
user_id = auth.user.id
|
||||
|
||||
# Verify the trigger exists and belongs to the user
|
||||
trigger = session.exec(
|
||||
select(Trigger).where(
|
||||
and_(Trigger.id == data.trigger_id, Trigger.user_id == str(user_id))
|
||||
)
|
||||
).first()
|
||||
|
||||
if not trigger:
|
||||
logger.warning("Trigger not found for execution creation", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": data.trigger_id
|
||||
})
|
||||
raise HTTPException(status_code=404, detail="Trigger not found")
|
||||
|
||||
try:
|
||||
execution_data = data.model_dump()
|
||||
execution = TriggerExecution(**execution_data)
|
||||
|
||||
session.add(execution)
|
||||
session.commit()
|
||||
session.refresh(execution)
|
||||
|
||||
# Update trigger last executed timestamp
|
||||
trigger.last_executed_at = datetime.now(timezone.utc)
|
||||
session.add(trigger)
|
||||
session.commit()
|
||||
|
||||
logger.info("Trigger execution created", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": data.trigger_id,
|
||||
"execution_id": execution.execution_id,
|
||||
"execution_type": data.execution_type.value
|
||||
})
|
||||
|
||||
# Publish to Redis pub/sub (broadcasts to all workers)
|
||||
redis_manager = get_redis_manager()
|
||||
redis_manager.publish_execution_event({
|
||||
"type": "execution_created",
|
||||
"execution_id": execution.execution_id,
|
||||
"trigger_id": trigger.id,
|
||||
"trigger_type": trigger.trigger_type.value if trigger.trigger_type else "unknown",
|
||||
"task_prompt": trigger.task_prompt,
|
||||
"status": execution.status.value,
|
||||
"input_data": execution.input_data,
|
||||
"execution_type": data.execution_type.value,
|
||||
"user_id": str(user_id),
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"project_id": str(trigger.project_id)
|
||||
})
|
||||
|
||||
return execution
|
||||
|
||||
result = TriggerCrudService.create_execution(data, auth.id, db_session)
|
||||
_raise_on_error(result)
|
||||
return result["execution"]
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error("Trigger execution creation failed", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": data.trigger_id,
|
||||
"error": str(e)
|
||||
}, exc_info=True)
|
||||
db_session.rollback()
|
||||
logger.error("Trigger execution creation failed", extra={"user_id": auth.id, "trigger_id": data.trigger_id, "error": str(e)}, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
|
|
@ -125,90 +76,43 @@ def list_executions(
|
|||
trigger_id: Optional[int] = None,
|
||||
status: Optional[ExecutionStatus] = None,
|
||||
execution_type: Optional[ExecutionType] = None,
|
||||
session: Session = Depends(session),
|
||||
auth: Auth = Depends(auth_must)
|
||||
db_session: Session = Depends(session),
|
||||
auth: V1UserAuth = Depends(auth_must),
|
||||
) -> Page[TriggerExecutionOut]:
|
||||
"""List trigger executions for current user."""
|
||||
user_id = auth.user.id
|
||||
|
||||
# Get all trigger IDs that belong to the user
|
||||
user_trigger_ids = session.exec(
|
||||
select(Trigger.id).where(Trigger.user_id == str(user_id))
|
||||
).all()
|
||||
|
||||
user_id = auth.id
|
||||
user_trigger_ids = db_session.exec(select(Trigger.id).where(Trigger.user_id == str(user_id))).all()
|
||||
if not user_trigger_ids:
|
||||
# User has no triggers, return empty result
|
||||
return Page(items=[], total=0, page=1, size=50, pages=0)
|
||||
|
||||
# Build conditions
|
||||
|
||||
conditions = [TriggerExecution.trigger_id.in_(user_trigger_ids)]
|
||||
|
||||
if trigger_id:
|
||||
if trigger_id not in user_trigger_ids:
|
||||
raise HTTPException(status_code=404, detail="Trigger not found")
|
||||
conditions.append(TriggerExecution.trigger_id == trigger_id)
|
||||
|
||||
if status is not None:
|
||||
conditions.append(TriggerExecution.status == status)
|
||||
|
||||
if execution_type:
|
||||
conditions.append(TriggerExecution.execution_type == execution_type)
|
||||
|
||||
stmt = (
|
||||
select(TriggerExecution)
|
||||
.where(and_(*conditions))
|
||||
.order_by(desc(TriggerExecution.created_at))
|
||||
)
|
||||
|
||||
result = paginate(session, stmt)
|
||||
total = result.total if hasattr(result, 'total') else 0
|
||||
|
||||
logger.debug("Executions listed", extra={
|
||||
"user_id": user_id,
|
||||
"total": total,
|
||||
"filters": {
|
||||
"trigger_id": trigger_id,
|
||||
"status": status.value if status is not None else None,
|
||||
"execution_type": execution_type.value if execution_type else None
|
||||
}
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
stmt = select(TriggerExecution).where(and_(*conditions)).order_by(desc(TriggerExecution.created_at))
|
||||
return paginate(db_session, stmt)
|
||||
|
||||
|
||||
@router.get("/{execution_id}", name="get execution", response_model=TriggerExecutionOut)
|
||||
def get_execution(
|
||||
execution_id: str,
|
||||
session: Session = Depends(session),
|
||||
auth: Auth = Depends(auth_must)
|
||||
db_session: Session = Depends(session),
|
||||
auth: V1UserAuth = Depends(auth_must),
|
||||
):
|
||||
"""Get a specific execution by execution ID."""
|
||||
user_id = auth.user.id
|
||||
|
||||
# Get the execution and verify ownership through trigger
|
||||
execution = session.exec(
|
||||
select(TriggerExecution)
|
||||
.join(Trigger)
|
||||
.where(
|
||||
and_(
|
||||
TriggerExecution.execution_id == execution_id,
|
||||
Trigger.user_id == str(user_id)
|
||||
)
|
||||
execution = db_session.exec(
|
||||
select(TriggerExecution).join(Trigger).where(
|
||||
and_(TriggerExecution.execution_id == execution_id, Trigger.user_id == str(auth.id))
|
||||
)
|
||||
).first()
|
||||
|
||||
if not execution:
|
||||
logger.warning("Execution not found", extra={
|
||||
"user_id": user_id,
|
||||
"execution_id": execution_id
|
||||
})
|
||||
raise HTTPException(status_code=404, detail="Execution not found")
|
||||
|
||||
logger.debug("Execution retrieved", extra={
|
||||
"user_id": user_id,
|
||||
"execution_id": execution_id
|
||||
})
|
||||
|
||||
return execution
|
||||
|
||||
|
||||
|
|
@ -216,245 +120,67 @@ def get_execution(
|
|||
async def update_execution(
|
||||
execution_id: str,
|
||||
data: TriggerExecutionUpdate,
|
||||
session: Session = Depends(session),
|
||||
auth: Auth = Depends(auth_must)
|
||||
db_session: Session = Depends(session),
|
||||
auth: V1UserAuth = Depends(auth_must),
|
||||
):
|
||||
"""Update a trigger execution."""
|
||||
user_id = auth.user.id
|
||||
|
||||
# Get the execution and verify ownership through trigger
|
||||
execution = session.exec(
|
||||
select(TriggerExecution)
|
||||
.join(Trigger)
|
||||
.where(
|
||||
and_(
|
||||
TriggerExecution.execution_id == execution_id,
|
||||
Trigger.user_id == str(user_id)
|
||||
)
|
||||
)
|
||||
).first()
|
||||
|
||||
if not execution:
|
||||
logger.warning("Execution not found for update", extra={
|
||||
"user_id": user_id,
|
||||
"execution_id": execution_id
|
||||
})
|
||||
raise HTTPException(status_code=404, detail="Execution not found")
|
||||
|
||||
try:
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
|
||||
# Check if status is being updated - use TriggerService for proper failure tracking
|
||||
if "status" in update_data:
|
||||
trigger_service = TriggerService(session)
|
||||
# Convert status string back to enum for TriggerService
|
||||
status_value = ExecutionStatus(update_data["status"]) if isinstance(update_data["status"], str) else update_data["status"]
|
||||
trigger_service.update_execution_status(
|
||||
execution=execution,
|
||||
status=status_value,
|
||||
output_data=update_data.get("output_data"),
|
||||
error_message=update_data.get("error_message"),
|
||||
tokens_used=update_data.get("tokens_used"),
|
||||
tools_executed=update_data.get("tools_executed")
|
||||
)
|
||||
# Remove status-related fields from update_data since TriggerService handled them
|
||||
for key in ["status", "output_data", "error_message", "tokens_used", "tools_executed"]:
|
||||
update_data.pop(key, None)
|
||||
|
||||
# Update remaining fields
|
||||
if update_data:
|
||||
# Auto-calculate duration if both started_at and completed_at are set
|
||||
if ("started_at" in update_data or "completed_at" in update_data) and execution.started_at:
|
||||
completed_at = update_data.get("completed_at") or execution.completed_at
|
||||
if completed_at:
|
||||
# Ensure both datetimes are timezone-aware for subtraction
|
||||
started_at = execution.started_at
|
||||
if started_at.tzinfo is None:
|
||||
started_at = started_at.replace(tzinfo=timezone.utc)
|
||||
if completed_at.tzinfo is None:
|
||||
completed_at = completed_at.replace(tzinfo=timezone.utc)
|
||||
duration = (completed_at - started_at).total_seconds()
|
||||
update_data["duration_seconds"] = duration
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(execution, key, value)
|
||||
|
||||
session.add(execution)
|
||||
session.commit()
|
||||
|
||||
session.refresh(execution)
|
||||
|
||||
# Get trigger for event publishing
|
||||
trigger = session.get(Trigger, execution.trigger_id)
|
||||
|
||||
logger.info("Execution updated", extra={
|
||||
"user_id": user_id,
|
||||
"execution_id": execution_id,
|
||||
"fields_updated": list(data.model_dump(exclude_unset=True).keys())
|
||||
})
|
||||
|
||||
# Publish to Redis pub/sub (broadcasts to all workers)
|
||||
redis_manager = get_redis_manager()
|
||||
redis_manager.publish_execution_event({
|
||||
"type": "execution_updated",
|
||||
"execution_id": execution_id,
|
||||
"trigger_id": execution.trigger_id,
|
||||
"status": execution.status.value,
|
||||
"updated_fields": list(update_data.keys()),
|
||||
"user_id": str(user_id),
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"project_id": str(trigger.project_id) if trigger else None
|
||||
})
|
||||
|
||||
return execution
|
||||
|
||||
result = TriggerCrudService.update_execution(execution_id, data, auth.id, db_session)
|
||||
_raise_on_error(result)
|
||||
return result["execution"]
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error("Execution update failed", extra={
|
||||
"user_id": user_id,
|
||||
"execution_id": execution_id,
|
||||
"error": str(e)
|
||||
}, exc_info=True)
|
||||
db_session.rollback()
|
||||
logger.error("Execution update failed", extra={"user_id": auth.id, "execution_id": execution_id, "error": str(e)}, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.delete("/{execution_id}", name="delete execution")
|
||||
def delete_execution(
|
||||
execution_id: str,
|
||||
session: Session = Depends(session),
|
||||
auth: Auth = Depends(auth_must)
|
||||
db_session: Session = Depends(session),
|
||||
auth: V1UserAuth = Depends(auth_must),
|
||||
):
|
||||
"""Delete a trigger execution."""
|
||||
user_id = auth.user.id
|
||||
|
||||
# Get the execution and verify ownership through trigger
|
||||
execution = session.exec(
|
||||
select(TriggerExecution)
|
||||
.join(Trigger)
|
||||
.where(
|
||||
and_(
|
||||
TriggerExecution.execution_id == execution_id,
|
||||
Trigger.user_id == str(user_id)
|
||||
)
|
||||
execution = db_session.exec(
|
||||
select(TriggerExecution).join(Trigger).where(
|
||||
and_(TriggerExecution.execution_id == execution_id, Trigger.user_id == str(auth.id))
|
||||
)
|
||||
).first()
|
||||
|
||||
if not execution:
|
||||
logger.warning("Execution not found for deletion", extra={
|
||||
"user_id": user_id,
|
||||
"execution_id": execution_id
|
||||
})
|
||||
raise HTTPException(status_code=404, detail="Execution not found")
|
||||
|
||||
try:
|
||||
session.delete(execution)
|
||||
session.commit()
|
||||
|
||||
logger.info("Execution deleted", extra={
|
||||
"user_id": user_id,
|
||||
"execution_id": execution_id
|
||||
})
|
||||
|
||||
db_session.delete(execution)
|
||||
db_session.commit()
|
||||
return Response(status_code=204)
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error("Execution deletion failed", extra={
|
||||
"user_id": user_id,
|
||||
"execution_id": execution_id,
|
||||
"error": str(e)
|
||||
}, exc_info=True)
|
||||
db_session.rollback()
|
||||
logger.error("Execution deletion failed", extra={"user_id": auth.id, "execution_id": execution_id, "error": str(e)}, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.post("/{execution_id}/retry", name="retry execution", response_model=TriggerExecutionOut)
|
||||
def retry_execution(
|
||||
execution_id: str,
|
||||
session: Session = Depends(session),
|
||||
auth: Auth = Depends(auth_must)
|
||||
db_session: Session = Depends(session),
|
||||
auth: V1UserAuth = Depends(auth_must),
|
||||
):
|
||||
"""Retry a failed execution."""
|
||||
user_id = auth.user.id
|
||||
|
||||
# Get the execution and verify ownership through trigger
|
||||
execution = session.exec(
|
||||
select(TriggerExecution)
|
||||
.join(Trigger)
|
||||
.where(
|
||||
and_(
|
||||
TriggerExecution.execution_id == execution_id,
|
||||
Trigger.user_id == str(user_id)
|
||||
)
|
||||
)
|
||||
).first()
|
||||
|
||||
if not execution:
|
||||
logger.warning("Execution not found for retry", extra={
|
||||
"user_id": user_id,
|
||||
"execution_id": execution_id
|
||||
})
|
||||
raise HTTPException(status_code=404, detail="Execution not found")
|
||||
|
||||
if execution.status != ExecutionStatus.failed:
|
||||
raise HTTPException(status_code=400, detail="Only failed executions can be retried")
|
||||
|
||||
if execution.attempts >= execution.max_retries:
|
||||
raise HTTPException(status_code=400, detail="Maximum retry attempts exceeded")
|
||||
|
||||
try:
|
||||
# Create a new execution for the retry
|
||||
new_execution_id = str(uuid4())
|
||||
new_execution = TriggerExecution(
|
||||
trigger_id=execution.trigger_id,
|
||||
execution_id=new_execution_id,
|
||||
execution_type=execution.execution_type,
|
||||
input_data=execution.input_data,
|
||||
attempts=execution.attempts + 1,
|
||||
max_retries=execution.max_retries
|
||||
)
|
||||
|
||||
session.add(new_execution)
|
||||
session.commit()
|
||||
session.refresh(new_execution)
|
||||
|
||||
# Get trigger for event publishing
|
||||
trigger = session.get(Trigger, execution.trigger_id)
|
||||
|
||||
logger.info("Execution retry created", extra={
|
||||
"user_id": user_id,
|
||||
"original_execution_id": execution_id,
|
||||
"new_execution_id": new_execution_id,
|
||||
"attempts": new_execution.attempts
|
||||
})
|
||||
|
||||
# Publish to Redis pub/sub (broadcasts to all workers)
|
||||
redis_manager = get_redis_manager()
|
||||
redis_manager.publish_execution_event({
|
||||
"type": "execution_created",
|
||||
"execution_id": new_execution.execution_id,
|
||||
"trigger_id": trigger.id if trigger else execution.trigger_id,
|
||||
"trigger_type": trigger.trigger_type.value if trigger and trigger.trigger_type else "unknown",
|
||||
"task_prompt": trigger.task_prompt if trigger else None,
|
||||
"status": new_execution.status.value,
|
||||
"input_data": new_execution.input_data,
|
||||
"execution_type": new_execution.execution_type.value,
|
||||
"user_id": str(user_id),
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"project_id": str(trigger.project_id) if trigger else None
|
||||
})
|
||||
|
||||
return new_execution
|
||||
|
||||
result = TriggerCrudService.retry_execution(execution_id, auth.id, db_session)
|
||||
_raise_on_error(result)
|
||||
return result["execution"]
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error("Execution retry failed", extra={
|
||||
"user_id": user_id,
|
||||
"execution_id": execution_id,
|
||||
"error": str(e)
|
||||
}, exc_info=True)
|
||||
db_session.rollback()
|
||||
logger.error("Execution retry failed", extra={"user_id": auth.id, "execution_id": execution_id, "error": str(e)}, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
# ---- WebSocket (kept in controller due to process-level state) ----
|
||||
|
||||
@router.websocket("/subscribe")
|
||||
async def subscribe_executions(websocket: WebSocket):
|
||||
"""Subscribe to trigger execution events via WebSocket.
|
||||
|
|
@ -471,12 +197,12 @@ async def subscribe_executions(websocket: WebSocket):
|
|||
await websocket.accept()
|
||||
session_id = None
|
||||
user_id = None
|
||||
db_session = None
|
||||
|
||||
ws_db_session = None
|
||||
|
||||
try:
|
||||
# Create database session manually for WebSocket
|
||||
from app.component.database import session_make
|
||||
db_session = session_make()
|
||||
from app.core.database import session_make
|
||||
ws_db_session = session_make()
|
||||
# Wait for subscription message
|
||||
data = await websocket.receive_json()
|
||||
|
||||
|
|
@ -501,14 +227,18 @@ async def subscribe_executions(websocket: WebSocket):
|
|||
return
|
||||
|
||||
try:
|
||||
from app.component.auth import Auth
|
||||
# Decode token and fetch user
|
||||
auth = Auth.decode_token(auth_token)
|
||||
user = db_session.get(User, auth.id)
|
||||
from app.shared.auth.user_auth import V1UserAuth, _get_jti
|
||||
from app.shared.auth.token_blacklist import is_blacklisted as _is_blacklisted
|
||||
# Decode token and check blacklist
|
||||
auth = V1UserAuth.decode_token(auth_token)
|
||||
jti = _get_jti(auth_token)
|
||||
if jti and await _is_blacklisted(jti):
|
||||
raise Exception("Token has been revoked")
|
||||
user = ws_db_session.get(User, auth.id)
|
||||
if not user:
|
||||
raise Exception("User not found")
|
||||
auth._user = user
|
||||
user_id = auth.user.id
|
||||
user_id = auth.id
|
||||
logger.info(f"User authenticated for WebSocket {user_id} and {session_id}", extra={
|
||||
"user_id": user_id,
|
||||
"session_id": session_id
|
||||
|
|
@ -562,7 +292,7 @@ async def subscribe_executions(websocket: WebSocket):
|
|||
redis_manager.remove_pending_execution(session_id, execution_id)
|
||||
|
||||
# Update execution status to running
|
||||
execution = db_session.exec(
|
||||
execution = ws_db_session.exec(
|
||||
select(TriggerExecution).where(
|
||||
TriggerExecution.execution_id == execution_id
|
||||
)
|
||||
|
|
@ -571,8 +301,8 @@ async def subscribe_executions(websocket: WebSocket):
|
|||
if execution and execution.status == ExecutionStatus.pending:
|
||||
execution.status = ExecutionStatus.running
|
||||
execution.started_at = datetime.now(timezone.utc)
|
||||
db_session.add(execution)
|
||||
db_session.commit()
|
||||
ws_db_session.add(execution)
|
||||
ws_db_session.commit()
|
||||
|
||||
logger.info("Execution acknowledged and started", extra={
|
||||
"session_id": session_id,
|
||||
|
|
@ -636,8 +366,8 @@ async def subscribe_executions(websocket: WebSocket):
|
|||
logger.info("Session cleaned up", extra={"session_id": session_id})
|
||||
|
||||
# Close database session
|
||||
if db_session:
|
||||
db_session.close()
|
||||
if ws_db_session:
|
||||
ws_db_session.close()
|
||||
|
||||
|
||||
async def handle_pubsub_message(event_data: Dict[str, Any]):
|
||||
|
|
@ -23,17 +23,15 @@ from sqlmodel import Session, select, and_, or_
|
|||
from uuid import uuid4
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
import logging
|
||||
from loguru import logger
|
||||
from fastapi_limiter.depends import RateLimiter
|
||||
|
||||
from app.model.trigger.trigger import Trigger
|
||||
from app.model.trigger.trigger_execution import TriggerExecution
|
||||
from app.type.trigger_types import TriggerType, TriggerStatus, ExecutionType, ExecutionStatus
|
||||
from app.component.database import session
|
||||
from app.component.trigger_utils import check_rate_limits
|
||||
from app.service.trigger.app_handler_service import get_app_handler
|
||||
|
||||
logger = logging.getLogger("server_webhook_controller")
|
||||
from app.shared.types.trigger_types import TriggerType, TriggerStatus, ExecutionType, ExecutionStatus
|
||||
from app.core.database import session
|
||||
from app.core.trigger_utils import check_rate_limits
|
||||
from app.domains.trigger.service.app_handler_service import get_app_handler
|
||||
|
||||
router = APIRouter(prefix="/webhook", tags=["Webhook"])
|
||||
|
||||
|
|
@ -58,13 +56,14 @@ async def webhook_trigger(
|
|||
input_data = {"raw_body": body.decode()}
|
||||
|
||||
headers = dict(request.headers)
|
||||
webhook_url = f"/webhook/trigger/{webhook_uuid}"
|
||||
|
||||
# Find the trigger (allow active and pending_verification for verification flows)
|
||||
webhook_url_new = f"/v1/webhook/trigger/{webhook_uuid}"
|
||||
webhook_url_old = f"/webhook/trigger/{webhook_uuid}"
|
||||
|
||||
# Find the trigger (match both old and new URL formats)
|
||||
trigger = db_session.exec(
|
||||
select(Trigger).where(
|
||||
and_(
|
||||
Trigger.webhook_url == webhook_url,
|
||||
or_(Trigger.webhook_url == webhook_url_new, Trigger.webhook_url == webhook_url_old),
|
||||
Trigger.trigger_type.in_(WEBHOOK_TRIGGER_TYPES),
|
||||
Trigger.status.in_([TriggerStatus.active, TriggerStatus.pending_verification])
|
||||
)
|
||||
|
|
@ -111,7 +110,7 @@ async def webhook_trigger(
|
|||
|
||||
# Notify Redis subscribers of successful activation
|
||||
try:
|
||||
from app.component.redis_utils import get_redis_manager
|
||||
from app.core.redis_utils import get_redis_manager
|
||||
redis_manager = get_redis_manager()
|
||||
redis_manager.publish_execution_event({
|
||||
"type": "trigger_activated",
|
||||
|
|
@ -233,7 +232,7 @@ async def webhook_trigger(
|
|||
|
||||
# Notify WebSocket subscribers and wait for delivery confirmation
|
||||
try:
|
||||
from app.component.redis_utils import get_redis_manager
|
||||
from app.core.redis_utils import get_redis_manager
|
||||
redis_manager = get_redis_manager()
|
||||
|
||||
# Check if user has any active WebSocket sessions
|
||||
|
|
@ -302,7 +301,7 @@ async def webhook_trigger(
|
|||
"execution_id": execution_id,
|
||||
"message": "Webhook trigger processed but WebSocket notification failed",
|
||||
"delivered": False,
|
||||
"reason": "websocket_notification_error"
|
||||
"reason": str(e)
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
|
|
@ -321,12 +320,13 @@ def get_webhook_info(
|
|||
db_session: Session = Depends(session)
|
||||
):
|
||||
"""Get information about a webhook trigger (public endpoint)."""
|
||||
webhook_url = f"/webhook/trigger/{webhook_uuid}"
|
||||
|
||||
webhook_url_new = f"/v1/webhook/trigger/{webhook_uuid}"
|
||||
webhook_url_old = f"/webhook/trigger/{webhook_uuid}"
|
||||
|
||||
trigger = db_session.exec(
|
||||
select(Trigger).where(
|
||||
and_(
|
||||
Trigger.webhook_url == webhook_url,
|
||||
or_(Trigger.webhook_url == webhook_url_new, Trigger.webhook_url == webhook_url_old),
|
||||
Trigger.trigger_type.in_(WEBHOOK_TRIGGER_TYPES)
|
||||
)
|
||||
)
|
||||
|
|
@ -335,14 +335,11 @@ def get_webhook_info(
|
|||
if not trigger:
|
||||
raise HTTPException(status_code=404, detail="Webhook not found")
|
||||
|
||||
# Only expose minimal info on public endpoint to avoid information disclosure
|
||||
return {
|
||||
"name": trigger.name,
|
||||
"description": trigger.description,
|
||||
"status": trigger.status.value,
|
||||
"trigger_type": trigger.trigger_type.value,
|
||||
"is_active": trigger.status == TriggerStatus.active,
|
||||
"trigger_type": trigger.trigger_type.value,
|
||||
"webhook_method": trigger.webhook_method.value if trigger.webhook_method else None,
|
||||
"last_executed_at": trigger.last_executed_at.isoformat() if trigger.last_executed_at else None,
|
||||
}
|
||||
|
||||
|
||||
14
server/app/domains/trigger/schema/__init__.py
Normal file
14
server/app/domains/trigger/schema/__init__.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
|
|
@ -12,18 +12,11 @@
|
|||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""
|
||||
Trigger Service Package
|
||||
"""Trigger domain services."""
|
||||
|
||||
Contains services for managing triggers including:
|
||||
- TriggerService: Main service for trigger operations
|
||||
- TriggerScheduleService: Service for scheduled trigger operations
|
||||
- App Handlers: Handlers for different trigger types (Slack, Webhook, Schedule)
|
||||
"""
|
||||
|
||||
from app.service.trigger.trigger_service import TriggerService, get_trigger_service
|
||||
from app.service.trigger.trigger_schedule_service import TriggerScheduleService
|
||||
from app.service.trigger.app_handler_service import (
|
||||
from app.domains.trigger.service.trigger_service import TriggerService, get_trigger_service
|
||||
from app.domains.trigger.service.trigger_schedule_service import TriggerScheduleService
|
||||
from app.domains.trigger.service.app_handler_service import (
|
||||
BaseAppHandler,
|
||||
SlackAppHandler,
|
||||
DefaultWebhookHandler,
|
||||
|
|
@ -36,19 +29,16 @@ from app.service.trigger.app_handler_service import (
|
|||
)
|
||||
|
||||
__all__ = [
|
||||
# Services
|
||||
"TriggerService",
|
||||
"get_trigger_service",
|
||||
"TriggerScheduleService",
|
||||
# Handlers
|
||||
"BaseAppHandler",
|
||||
"SlackAppHandler",
|
||||
"DefaultWebhookHandler",
|
||||
"ScheduleAppHandler",
|
||||
"AppHandlerResult",
|
||||
# Handler functions
|
||||
"get_app_handler",
|
||||
"get_schedule_handler",
|
||||
"register_app_handler",
|
||||
"get_supported_trigger_types",
|
||||
]
|
||||
]
|
||||
|
|
@ -24,13 +24,13 @@ from typing import Optional
|
|||
from dataclasses import dataclass
|
||||
from fastapi import Request
|
||||
from sqlmodel import Session, select, and_
|
||||
import logging
|
||||
from loguru import logger
|
||||
|
||||
from app.model.trigger.trigger import Trigger
|
||||
from app.model.config.config import Config
|
||||
from app.model.trigger.app_configs import SlackTriggerConfig, WebhookTriggerConfig, ScheduleTriggerConfig
|
||||
from app.type.trigger_types import TriggerType, ExecutionType, TriggerStatus
|
||||
from app.type.config_group import ConfigGroup
|
||||
from app.shared.types.trigger_types import TriggerType, ExecutionType, TriggerStatus
|
||||
from app.shared.types.config_group import ConfigGroup
|
||||
|
||||
|
||||
@dataclass
|
||||
83
server/app/domains/trigger/service/slack_service.py
Normal file
83
server/app/domains/trigger/service/slack_service.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""SlackService: Slack integration helpers (channel listing)."""
|
||||
|
||||
from loguru import logger
|
||||
from sqlmodel import select, and_
|
||||
|
||||
from app.core.database import session_make
|
||||
from app.model.config.config import Config
|
||||
from app.shared.types.config_group import ConfigGroup
|
||||
|
||||
|
||||
class SlackService:
|
||||
"""Slack integration - static methods."""
|
||||
|
||||
@staticmethod
|
||||
def get_channels(user_id: int) -> dict:
|
||||
"""Fetch Slack channels for user. Returns {"success": True, "channels": [...], "has_credentials": True} or error."""
|
||||
with session_make() as s:
|
||||
configs = s.exec(
|
||||
select(Config).where(
|
||||
and_(
|
||||
Config.user_id == int(user_id),
|
||||
Config.config_group == ConfigGroup.SLACK.value,
|
||||
)
|
||||
)
|
||||
).all()
|
||||
|
||||
credentials = {config.config_name: config.config_value for config in configs}
|
||||
bot_token = credentials.get("SLACK_BOT_TOKEN")
|
||||
|
||||
if not bot_token:
|
||||
logger.warning("Slack credentials not found", extra={"user_id": user_id})
|
||||
return {"success": True, "channels": [], "has_credentials": False}
|
||||
|
||||
try:
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
client = WebClient(token=bot_token)
|
||||
channels = []
|
||||
cursor = None
|
||||
|
||||
while True:
|
||||
response = client.conversations_list(
|
||||
types="public_channel,private_channel",
|
||||
cursor=cursor,
|
||||
limit=200,
|
||||
)
|
||||
for channel in response.get("channels", []):
|
||||
channels.append({
|
||||
"id": channel.get("id"),
|
||||
"name": channel.get("name"),
|
||||
"is_private": channel.get("is_private", False),
|
||||
"is_member": channel.get("is_member", False),
|
||||
"num_members": channel.get("num_members"),
|
||||
})
|
||||
cursor = response.get("response_metadata", {}).get("next_cursor")
|
||||
if not cursor:
|
||||
break
|
||||
|
||||
logger.info("Slack channels fetched", extra={"user_id": user_id, "channel_count": len(channels)})
|
||||
return {"success": True, "channels": channels, "has_credentials": True}
|
||||
|
||||
except ImportError:
|
||||
logger.error("slack_sdk not installed")
|
||||
return {"success": False, "error_code": "SLACK_SDK_NOT_INSTALLED"}
|
||||
except Exception as e:
|
||||
error_detail = getattr(e, "response", {}).get("error", str(e)) if hasattr(e, "response") else str(e)
|
||||
logger.error("Slack API error", extra={"user_id": user_id, "error": error_detail})
|
||||
return {"success": False, "error_code": "SLACK_API_ERROR", "detail": error_detail}
|
||||
594
server/app/domains/trigger/service/trigger_crud_service.py
Normal file
594
server/app/domains/trigger/service/trigger_crud_service.py
Normal file
|
|
@ -0,0 +1,594 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""TriggerCrudService: trigger CRUD + execution business logic."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
from sqlmodel import Session, select, and_
|
||||
from sqlalchemy import func
|
||||
|
||||
from app.model.trigger.trigger import Trigger, TriggerIn, TriggerOut, TriggerUpdate
|
||||
from app.model.trigger.trigger_execution import TriggerExecution, TriggerExecutionIn, TriggerExecutionUpdate
|
||||
from app.shared.types.trigger_types import ExecutionStatus, ExecutionType
|
||||
from app.model.trigger.app_configs import (
|
||||
get_config_schema,
|
||||
validate_config,
|
||||
has_config,
|
||||
validate_activation,
|
||||
ActivationError,
|
||||
)
|
||||
from app.model.trigger.app_configs.config_registry import requires_authentication
|
||||
from app.model.chat.chat_history import ChatHistory
|
||||
from app.shared.types.trigger_types import TriggerType, TriggerStatus
|
||||
from app.core.redis_utils import get_redis_manager
|
||||
from app.domains.trigger.service.trigger_schedule_service import TriggerScheduleService
|
||||
from app.domains.trigger.service.trigger_service import TriggerService
|
||||
|
||||
|
||||
ACTIVE_STATUSES = (TriggerStatus.active, TriggerStatus.pending_verification)
|
||||
MAX_ACTIVE_PER_USER = 25
|
||||
MAX_ACTIVE_PER_PROJECT = 5
|
||||
|
||||
|
||||
class TriggerCrudService:
|
||||
"""Trigger CRUD business logic - static methods, caller-managed session."""
|
||||
|
||||
@staticmethod
|
||||
def get_active_trigger_counts(s: Session, user_id: str, project_id: str | None = None) -> tuple[int, int]:
|
||||
"""Return (user_active_count, project_active_count) for active/pending triggers."""
|
||||
user_count = s.exec(
|
||||
select(func.count(Trigger.id)).where(
|
||||
and_(
|
||||
Trigger.user_id == user_id,
|
||||
Trigger.status.in_(ACTIVE_STATUSES),
|
||||
)
|
||||
)
|
||||
).one()
|
||||
|
||||
project_count = 0
|
||||
if project_id:
|
||||
project_count = s.exec(
|
||||
select(func.count(Trigger.id)).where(
|
||||
and_(
|
||||
Trigger.user_id == user_id,
|
||||
Trigger.project_id == project_id,
|
||||
Trigger.status.in_(ACTIVE_STATUSES),
|
||||
)
|
||||
)
|
||||
).one()
|
||||
|
||||
return user_count, project_count
|
||||
|
||||
@staticmethod
|
||||
def get_execution_counts(s: Session, trigger_ids: list[int]) -> dict[int, int]:
|
||||
"""Get execution counts for multiple triggers in a single query."""
|
||||
if not trigger_ids:
|
||||
return {}
|
||||
result = s.exec(
|
||||
select(TriggerExecution.trigger_id, func.count(TriggerExecution.id))
|
||||
.where(TriggerExecution.trigger_id.in_(trigger_ids))
|
||||
.group_by(TriggerExecution.trigger_id)
|
||||
).all()
|
||||
return {trigger_id: count for trigger_id, count in result}
|
||||
|
||||
@staticmethod
|
||||
def trigger_to_out(trigger: Trigger, execution_count: int = 0) -> TriggerOut:
|
||||
"""Convert Trigger model to TriggerOut with execution count."""
|
||||
return TriggerOut(
|
||||
id=trigger.id,
|
||||
user_id=trigger.user_id,
|
||||
project_id=trigger.project_id,
|
||||
name=trigger.name,
|
||||
description=trigger.description,
|
||||
trigger_type=trigger.trigger_type,
|
||||
status=trigger.status,
|
||||
execution_count=execution_count,
|
||||
webhook_url=trigger.webhook_url,
|
||||
webhook_method=trigger.webhook_method,
|
||||
custom_cron_expression=trigger.custom_cron_expression,
|
||||
listener_type=trigger.listener_type,
|
||||
agent_model=trigger.agent_model,
|
||||
task_prompt=trigger.task_prompt,
|
||||
config=trigger.config,
|
||||
max_executions_per_hour=trigger.max_executions_per_hour,
|
||||
max_executions_per_day=trigger.max_executions_per_day,
|
||||
is_single_execution=trigger.is_single_execution,
|
||||
last_executed_at=trigger.last_executed_at,
|
||||
next_run_at=trigger.next_run_at,
|
||||
last_execution_status=trigger.last_execution_status,
|
||||
created_at=trigger.created_at,
|
||||
updated_at=trigger.updated_at,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_project_chat_history(data: TriggerIn, user_id: int, s: Session) -> None:
|
||||
"""Create placeholder ChatHistory for a new project if it doesn't exist, and notify via WebSocket."""
|
||||
if not data.project_id:
|
||||
return
|
||||
|
||||
existing_chat = s.exec(
|
||||
select(ChatHistory).where(ChatHistory.project_id == data.project_id)
|
||||
).first()
|
||||
if existing_chat:
|
||||
return
|
||||
|
||||
chat_history = ChatHistory(
|
||||
user_id=user_id,
|
||||
task_id=data.project_id,
|
||||
project_id=data.project_id,
|
||||
question=f"Project created via trigger: {data.name}",
|
||||
language="en",
|
||||
model_platform=data.agent_model or "none",
|
||||
model_type=data.agent_model or "none",
|
||||
installed_mcp="none",
|
||||
api_key="",
|
||||
api_url="",
|
||||
max_retries=3,
|
||||
project_name=data.name,
|
||||
summary=data.description or "",
|
||||
tokens=0,
|
||||
spend=0,
|
||||
status=2,
|
||||
)
|
||||
s.add(chat_history)
|
||||
s.commit()
|
||||
s.refresh(chat_history)
|
||||
|
||||
logger.info("Chat history created for new project", extra={
|
||||
"user_id": user_id,
|
||||
"project_id": data.project_id,
|
||||
"chat_history_id": chat_history.id,
|
||||
})
|
||||
|
||||
# WebSocket notification (best effort)
|
||||
try:
|
||||
redis_manager = get_redis_manager()
|
||||
redis_manager.publish_execution_event({
|
||||
"type": "project_created",
|
||||
"project_id": data.project_id,
|
||||
"project_name": data.name,
|
||||
"chat_history_id": chat_history.id,
|
||||
"trigger_name": data.name,
|
||||
"user_id": str(user_id),
|
||||
"created_at": chat_history.created_at.isoformat() if chat_history.created_at else None,
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning("Failed to send WebSocket notification for new project", extra={
|
||||
"user_id": user_id,
|
||||
"project_id": data.project_id,
|
||||
"error": str(e),
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def _validate_trigger_config(trigger_type: TriggerType, config: dict | None) -> None:
|
||||
"""Validate trigger-type specific config. Raises HTTPException-compatible ValueError on failure."""
|
||||
if config and has_config(trigger_type):
|
||||
try:
|
||||
validate_config(trigger_type, config)
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"Invalid config for {trigger_type.value}: {e.errors()}")
|
||||
|
||||
@staticmethod
|
||||
def _determine_initial_status(
|
||||
data: TriggerIn, user_id: int, s: Session
|
||||
) -> TriggerStatus:
|
||||
"""Determine initial trigger status based on auth requirements and concurrency limits."""
|
||||
# Desired status from auth requirements
|
||||
if has_config(data.trigger_type) and data.config and requires_authentication(data.trigger_type, data.config):
|
||||
desired_status = TriggerStatus.pending_verification
|
||||
else:
|
||||
desired_status = TriggerStatus.active
|
||||
|
||||
# Check concurrency limits
|
||||
user_active, project_active = TriggerCrudService.get_active_trigger_counts(
|
||||
s, str(user_id), data.project_id
|
||||
)
|
||||
if user_active >= MAX_ACTIVE_PER_USER or (
|
||||
data.project_id and project_active >= MAX_ACTIVE_PER_PROJECT
|
||||
):
|
||||
logger.info(
|
||||
"Active trigger limit reached — new trigger created as inactive",
|
||||
extra={
|
||||
"user_id": user_id,
|
||||
"project_id": data.project_id,
|
||||
"user_active": user_active,
|
||||
"project_active": project_active,
|
||||
},
|
||||
)
|
||||
return TriggerStatus.inactive
|
||||
|
||||
return desired_status
|
||||
|
||||
@staticmethod
|
||||
def create(data: TriggerIn, user_id: int, s: Session) -> dict:
|
||||
"""
|
||||
Create a new trigger with all business rules.
|
||||
Returns {"success": True, "trigger_out": TriggerOut} or {"success": False, "error": str, "status_code": int}.
|
||||
"""
|
||||
# 1. Ensure project ChatHistory exists
|
||||
TriggerCrudService._ensure_project_chat_history(data, user_id, s)
|
||||
|
||||
# 2. Generate webhook URL
|
||||
webhook_url = None
|
||||
if data.trigger_type in (TriggerType.webhook, TriggerType.slack_trigger):
|
||||
webhook_url = f"/v1/webhook/trigger/{uuid4()}"
|
||||
|
||||
# 3. Validate config
|
||||
try:
|
||||
TriggerCrudService._validate_trigger_config(data.trigger_type, data.config)
|
||||
except ValueError as e:
|
||||
return {"success": False, "error": str(e), "status_code": 400}
|
||||
|
||||
# 4. Determine initial status
|
||||
initial_status = TriggerCrudService._determine_initial_status(data, user_id, s)
|
||||
|
||||
# 5. Create trigger
|
||||
trigger_data = data.model_dump()
|
||||
trigger_data["user_id"] = str(user_id)
|
||||
trigger_data["webhook_url"] = webhook_url
|
||||
trigger_data["status"] = initial_status
|
||||
|
||||
trigger = Trigger(**trigger_data)
|
||||
s.add(trigger)
|
||||
s.commit()
|
||||
s.refresh(trigger)
|
||||
|
||||
# 6. Calculate next_run_at for scheduled triggers
|
||||
if trigger.trigger_type == TriggerType.schedule and trigger.custom_cron_expression:
|
||||
schedule_service = TriggerScheduleService(s)
|
||||
trigger.next_run_at = schedule_service.calculate_next_run_at(trigger)
|
||||
s.add(trigger)
|
||||
s.commit()
|
||||
s.refresh(trigger)
|
||||
|
||||
logger.info("Trigger created", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger.id,
|
||||
"trigger_type": data.trigger_type.value,
|
||||
"next_run_at": trigger.next_run_at.isoformat() if trigger.next_run_at else None,
|
||||
})
|
||||
|
||||
return {"success": True, "trigger_out": TriggerCrudService.trigger_to_out(trigger, 0)}
|
||||
|
||||
@staticmethod
|
||||
def update(trigger_id: int, data: TriggerUpdate, user_id: int, s: Session) -> dict:
|
||||
"""
|
||||
Update a trigger with config validation and schedule recalculation.
|
||||
Returns {"success": True, "trigger_out": TriggerOut} or {"success": False, "error": str, "status_code": int}.
|
||||
"""
|
||||
trigger = s.exec(
|
||||
select(Trigger).where(and_(Trigger.id == trigger_id, Trigger.user_id == str(user_id)))
|
||||
).first()
|
||||
if not trigger:
|
||||
return {"success": False, "error": "Trigger not found", "status_code": 404}
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
|
||||
# Validate config if being updated
|
||||
if "config" in update_data and update_data["config"] is not None:
|
||||
try:
|
||||
TriggerCrudService._validate_trigger_config(trigger.trigger_type, update_data["config"])
|
||||
except ValueError as e:
|
||||
return {"success": False, "error": str(e), "status_code": 400}
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(trigger, key, value)
|
||||
|
||||
# Recalculate next_run_at if cron expression or status changed for scheduled triggers
|
||||
if trigger.trigger_type == TriggerType.schedule:
|
||||
if "custom_cron_expression" in update_data or "status" in update_data:
|
||||
if trigger.status == TriggerStatus.active and trigger.custom_cron_expression:
|
||||
schedule_service = TriggerScheduleService(s)
|
||||
trigger.next_run_at = schedule_service.calculate_next_run_at(trigger)
|
||||
|
||||
s.add(trigger)
|
||||
s.commit()
|
||||
s.refresh(trigger)
|
||||
|
||||
counts = TriggerCrudService.get_execution_counts(s, [trigger_id])
|
||||
execution_count = counts.get(trigger_id, 0)
|
||||
|
||||
logger.info("Trigger updated", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger_id,
|
||||
"fields_updated": list(update_data.keys()),
|
||||
"next_run_at": trigger.next_run_at.isoformat() if trigger.next_run_at else None,
|
||||
})
|
||||
|
||||
return {"success": True, "trigger_out": TriggerCrudService.trigger_to_out(trigger, execution_count)}
|
||||
|
||||
@staticmethod
|
||||
def activate(trigger_id: int, user_id: int, s: Session) -> dict:
|
||||
"""
|
||||
Activate a trigger with limit checks, auth requirements, and activation validation.
|
||||
Returns {"success": True, "trigger_out": TriggerOut}
|
||||
or {"success": False, "error": str/dict, "status_code": int}.
|
||||
"""
|
||||
trigger = s.exec(
|
||||
select(Trigger).where(and_(Trigger.id == trigger_id, Trigger.user_id == str(user_id)))
|
||||
).first()
|
||||
if not trigger:
|
||||
return {"success": False, "error": "Trigger not found", "status_code": 404}
|
||||
|
||||
# 1. Check concurrency limits
|
||||
user_active, project_active = TriggerCrudService.get_active_trigger_counts(
|
||||
s, str(user_id), trigger.project_id
|
||||
)
|
||||
if user_active >= MAX_ACTIVE_PER_USER:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Maximum number of concurrent active triggers ({MAX_ACTIVE_PER_USER}) reached for this user",
|
||||
"status_code": 400,
|
||||
}
|
||||
if trigger.project_id and project_active >= MAX_ACTIVE_PER_PROJECT:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Maximum number of concurrent active triggers ({MAX_ACTIVE_PER_PROJECT}) reached for this project",
|
||||
"status_code": 400,
|
||||
}
|
||||
|
||||
# 2. Check if authentication is required — go straight to pending_verification
|
||||
if has_config(trigger.trigger_type) and requires_authentication(trigger.trigger_type, trigger.config):
|
||||
trigger.status = TriggerStatus.pending_verification
|
||||
s.add(trigger)
|
||||
s.commit()
|
||||
s.refresh(trigger)
|
||||
return {
|
||||
"success": False,
|
||||
"error": {
|
||||
"message": "Authentication required for this trigger type",
|
||||
"missing_requirements": ["authentication"],
|
||||
"trigger_type": trigger.trigger_type.value,
|
||||
},
|
||||
"status_code": 401,
|
||||
}
|
||||
|
||||
# 3. Validate activation requirements (non-auth triggers)
|
||||
if has_config(trigger.trigger_type):
|
||||
try:
|
||||
validate_activation(
|
||||
trigger_type=trigger.trigger_type,
|
||||
config_data=trigger.config,
|
||||
user_id=int(user_id),
|
||||
session=s,
|
||||
)
|
||||
except ActivationError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": {
|
||||
"message": e.message,
|
||||
"missing_requirements": e.missing_requirements,
|
||||
"trigger_type": trigger.trigger_type.value,
|
||||
},
|
||||
"status_code": 400,
|
||||
}
|
||||
|
||||
# 4. Activate
|
||||
trigger.status = TriggerStatus.active
|
||||
s.add(trigger)
|
||||
s.commit()
|
||||
s.refresh(trigger)
|
||||
|
||||
counts = TriggerCrudService.get_execution_counts(s, [trigger_id])
|
||||
execution_count = counts.get(trigger_id, 0)
|
||||
|
||||
logger.info("Trigger activated", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger_id,
|
||||
"status": trigger.status.value,
|
||||
})
|
||||
|
||||
return {"success": True, "trigger_out": TriggerCrudService.trigger_to_out(trigger, execution_count)}
|
||||
|
||||
@staticmethod
|
||||
def deactivate(trigger_id: int, user_id: int, s: Session) -> dict:
|
||||
"""
|
||||
Deactivate a trigger.
|
||||
Returns {"success": True, "trigger_out": TriggerOut} or {"success": False, ...}.
|
||||
"""
|
||||
trigger = s.exec(
|
||||
select(Trigger).where(and_(Trigger.id == trigger_id, Trigger.user_id == str(user_id)))
|
||||
).first()
|
||||
if not trigger:
|
||||
return {"success": False, "error": "Trigger not found", "status_code": 404}
|
||||
|
||||
trigger.status = TriggerStatus.inactive
|
||||
s.add(trigger)
|
||||
s.commit()
|
||||
s.refresh(trigger)
|
||||
|
||||
counts = TriggerCrudService.get_execution_counts(s, [trigger_id])
|
||||
execution_count = counts.get(trigger_id, 0)
|
||||
|
||||
logger.info("Trigger deactivated", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": trigger_id,
|
||||
})
|
||||
|
||||
return {"success": True, "trigger_out": TriggerCrudService.trigger_to_out(trigger, execution_count)}
|
||||
|
||||
# ---- Execution CRUD ----
|
||||
|
||||
@staticmethod
|
||||
def _publish_execution_event(event_type: str, execution: TriggerExecution, trigger: Trigger, user_id: int, extra: dict | None = None) -> None:
|
||||
"""Publish execution event to Redis pub/sub (best effort)."""
|
||||
try:
|
||||
payload = {
|
||||
"type": event_type,
|
||||
"execution_id": execution.execution_id,
|
||||
"trigger_id": trigger.id,
|
||||
"trigger_type": trigger.trigger_type.value if trigger.trigger_type else "unknown",
|
||||
"task_prompt": trigger.task_prompt,
|
||||
"status": execution.status.value,
|
||||
"input_data": execution.input_data,
|
||||
"execution_type": execution.execution_type.value,
|
||||
"user_id": str(user_id),
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"project_id": str(trigger.project_id),
|
||||
}
|
||||
if extra:
|
||||
payload.update(extra)
|
||||
get_redis_manager().publish_execution_event(payload)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish execution event", extra={"execution_id": execution.execution_id, "error": str(e)})
|
||||
|
||||
@staticmethod
|
||||
def create_execution(data: TriggerExecutionIn, user_id: int, s: Session) -> dict:
|
||||
"""
|
||||
Create a trigger execution: verify ownership, create record, update trigger timestamp, publish event.
|
||||
Returns {"success": True, "execution": TriggerExecution} or {"success": False, ...}.
|
||||
"""
|
||||
trigger = s.exec(
|
||||
select(Trigger).where(and_(Trigger.id == data.trigger_id, Trigger.user_id == str(user_id)))
|
||||
).first()
|
||||
if not trigger:
|
||||
return {"success": False, "error": "Trigger not found", "status_code": 404}
|
||||
|
||||
execution = TriggerExecution(**data.model_dump())
|
||||
s.add(execution)
|
||||
s.commit()
|
||||
s.refresh(execution)
|
||||
|
||||
# Update trigger timestamp
|
||||
trigger.last_executed_at = datetime.now(timezone.utc)
|
||||
s.add(trigger)
|
||||
s.commit()
|
||||
|
||||
logger.info("Trigger execution created", extra={
|
||||
"user_id": user_id,
|
||||
"trigger_id": data.trigger_id,
|
||||
"execution_id": execution.execution_id,
|
||||
"execution_type": data.execution_type.value,
|
||||
})
|
||||
|
||||
TriggerCrudService._publish_execution_event("execution_created", execution, trigger, user_id)
|
||||
|
||||
return {"success": True, "execution": execution}
|
||||
|
||||
@staticmethod
|
||||
def update_execution(execution_id: str, data: TriggerExecutionUpdate, user_id: int, s: Session) -> dict:
|
||||
"""
|
||||
Update a trigger execution: ownership check, status via TriggerService, duration calc, publish event.
|
||||
Returns {"success": True, "execution": TriggerExecution} or {"success": False, ...}.
|
||||
"""
|
||||
execution = s.exec(
|
||||
select(TriggerExecution)
|
||||
.join(Trigger)
|
||||
.where(and_(TriggerExecution.execution_id == execution_id, Trigger.user_id == str(user_id)))
|
||||
).first()
|
||||
if not execution:
|
||||
return {"success": False, "error": "Execution not found", "status_code": 404}
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
|
||||
# Delegate status update to TriggerService for proper failure tracking
|
||||
if "status" in update_data:
|
||||
trigger_service = TriggerService(s)
|
||||
status_value = ExecutionStatus(update_data["status"]) if isinstance(update_data["status"], str) else update_data["status"]
|
||||
trigger_service.update_execution_status(
|
||||
execution=execution,
|
||||
status=status_value,
|
||||
output_data=update_data.get("output_data"),
|
||||
error_message=update_data.get("error_message"),
|
||||
tokens_used=update_data.get("tokens_used"),
|
||||
tools_executed=update_data.get("tools_executed"),
|
||||
)
|
||||
for key in ["status", "output_data", "error_message", "tokens_used", "tools_executed"]:
|
||||
update_data.pop(key, None)
|
||||
|
||||
# Update remaining fields + auto-calculate duration
|
||||
if update_data:
|
||||
if ("started_at" in update_data or "completed_at" in update_data) and execution.started_at:
|
||||
completed_at = update_data.get("completed_at") or execution.completed_at
|
||||
if completed_at:
|
||||
started_at = execution.started_at
|
||||
if started_at.tzinfo is None:
|
||||
started_at = started_at.replace(tzinfo=timezone.utc)
|
||||
if completed_at.tzinfo is None:
|
||||
completed_at = completed_at.replace(tzinfo=timezone.utc)
|
||||
update_data["duration_seconds"] = (completed_at - started_at).total_seconds()
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(execution, key, value)
|
||||
s.add(execution)
|
||||
s.commit()
|
||||
|
||||
s.refresh(execution)
|
||||
|
||||
# Publish event
|
||||
trigger = s.get(Trigger, execution.trigger_id)
|
||||
logger.info("Execution updated", extra={
|
||||
"user_id": user_id,
|
||||
"execution_id": execution_id,
|
||||
"fields_updated": list(data.model_dump(exclude_unset=True).keys()),
|
||||
})
|
||||
|
||||
if trigger:
|
||||
TriggerCrudService._publish_execution_event(
|
||||
"execution_updated", execution, trigger, user_id,
|
||||
extra={"updated_fields": list(update_data.keys())},
|
||||
)
|
||||
|
||||
return {"success": True, "execution": execution}
|
||||
|
||||
@staticmethod
|
||||
def retry_execution(execution_id: str, user_id: int, s: Session) -> dict:
|
||||
"""
|
||||
Retry a failed execution: validate status, create new execution, publish event.
|
||||
Returns {"success": True, "execution": TriggerExecution} or {"success": False, ...}.
|
||||
"""
|
||||
execution = s.exec(
|
||||
select(TriggerExecution)
|
||||
.join(Trigger)
|
||||
.where(and_(TriggerExecution.execution_id == execution_id, Trigger.user_id == str(user_id)))
|
||||
).first()
|
||||
if not execution:
|
||||
return {"success": False, "error": "Execution not found", "status_code": 404}
|
||||
|
||||
if execution.status != ExecutionStatus.failed:
|
||||
return {"success": False, "error": "Only failed executions can be retried", "status_code": 400}
|
||||
|
||||
if execution.attempts >= execution.max_retries:
|
||||
return {"success": False, "error": "Maximum retry attempts exceeded", "status_code": 400}
|
||||
|
||||
new_execution = TriggerExecution(
|
||||
trigger_id=execution.trigger_id,
|
||||
execution_id=str(uuid4()),
|
||||
execution_type=execution.execution_type,
|
||||
input_data=execution.input_data,
|
||||
attempts=execution.attempts + 1,
|
||||
max_retries=execution.max_retries,
|
||||
)
|
||||
s.add(new_execution)
|
||||
s.commit()
|
||||
s.refresh(new_execution)
|
||||
|
||||
trigger = s.get(Trigger, execution.trigger_id)
|
||||
|
||||
logger.info("Execution retry created", extra={
|
||||
"user_id": user_id,
|
||||
"original_execution_id": execution_id,
|
||||
"new_execution_id": new_execution.execution_id,
|
||||
"attempts": new_execution.attempts,
|
||||
})
|
||||
|
||||
if trigger:
|
||||
TriggerCrudService._publish_execution_event("execution_created", new_execution, trigger, user_id)
|
||||
|
||||
return {"success": True, "execution": new_execution}
|
||||
|
|
@ -14,7 +14,7 @@
|
|||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List, Tuple, Optional
|
||||
import logging
|
||||
from loguru import logger
|
||||
from croniter import croniter
|
||||
from uuid import uuid4
|
||||
import asyncio
|
||||
|
|
@ -22,8 +22,8 @@ from sqlmodel import select
|
|||
|
||||
from app.model.trigger.trigger import Trigger
|
||||
from app.model.trigger.trigger_execution import TriggerExecution
|
||||
from app.type.trigger_types import TriggerStatus, ExecutionType, ExecutionStatus, TriggerType
|
||||
from app.component.trigger_utils import check_rate_limits, MAX_DISPATCH_PER_TICK
|
||||
from app.shared.types.trigger_types import TriggerStatus, ExecutionType, ExecutionStatus, TriggerType
|
||||
from app.core.trigger_utils import check_rate_limits, MAX_DISPATCH_PER_TICK
|
||||
from app.model.trigger.app_configs import ScheduleTriggerConfig
|
||||
|
||||
|
||||
|
|
@ -206,7 +206,7 @@ class TriggerScheduleService:
|
|||
# Using asyncio.run() to run async code from sync Celery worker context
|
||||
try:
|
||||
# Notify WebSocket subscribers via Redis pub/sub
|
||||
from app.component.redis_utils import get_redis_manager
|
||||
from app.core.redis_utils import get_redis_manager
|
||||
redis_manager = get_redis_manager()
|
||||
redis_manager.publish_execution_event({
|
||||
"type": "execution_created",
|
||||
|
|
@ -12,44 +12,40 @@
|
|||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
from celery import shared_task
|
||||
"""
|
||||
Celery tasks for trigger scheduling: poll due triggers and check execution timeouts.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from celery import shared_task
|
||||
from sqlmodel import select, or_
|
||||
|
||||
from app.component.database import session_make
|
||||
from app.component.environment import env
|
||||
from app.service.trigger.trigger_schedule_service import TriggerScheduleService
|
||||
from app.service.trigger.trigger_service import TriggerService
|
||||
from app.component.trigger_utils import MAX_DISPATCH_PER_TICK
|
||||
from app.component.redis_utils import get_redis_manager
|
||||
from app.core.database import session_make
|
||||
from app.core.environment import env
|
||||
from app.core.trigger_utils import MAX_DISPATCH_PER_TICK
|
||||
from app.core.redis_utils import get_redis_manager
|
||||
from app.model.trigger.trigger_execution import TriggerExecution
|
||||
from app.model.trigger.trigger import Trigger
|
||||
from app.type.trigger_types import ExecutionStatus
|
||||
from app.shared.types.trigger_types import ExecutionStatus
|
||||
from app.domains.trigger.service.trigger_schedule_service import TriggerScheduleService
|
||||
from app.domains.trigger.service.trigger_service import TriggerService
|
||||
|
||||
# Timeout configuration from environment variables
|
||||
EXECUTION_PENDING_TIMEOUT_SECONDS = int(env("EXECUTION_PENDING_TIMEOUT_SECONDS", "60"))
|
||||
EXECUTION_RUNNING_TIMEOUT_SECONDS = int(env("EXECUTION_RUNNING_TIMEOUT_SECONDS", "600")) # 10 minutes
|
||||
EXECUTION_RUNNING_TIMEOUT_SECONDS = int(env("EXECUTION_RUNNING_TIMEOUT_SECONDS", "600"))
|
||||
|
||||
logger = logging.getLogger("server_trigger_schedule_task")
|
||||
|
||||
|
||||
@shared_task(queue="poll_trigger_schedules")
|
||||
def poll_trigger_schedules() -> None:
|
||||
"""
|
||||
Celery task to poll and execute scheduled triggers.
|
||||
This runs periodically to check for triggers that are due for execution.
|
||||
|
||||
This is a lightweight wrapper around TriggerScheduleService that handles
|
||||
session management and delegates all business logic to the service layer.
|
||||
"""
|
||||
"""Poll and execute scheduled triggers."""
|
||||
logger.info("Starting poll_trigger_schedules task")
|
||||
|
||||
|
||||
session = session_make()
|
||||
try:
|
||||
# Create service instance with session
|
||||
schedule_service = TriggerScheduleService(session)
|
||||
|
||||
# Delegate all logic to the service
|
||||
schedule_service.poll_and_execute_due_triggers(
|
||||
max_dispatch_per_tick=MAX_DISPATCH_PER_TICK
|
||||
)
|
||||
|
|
@ -59,28 +55,19 @@ def poll_trigger_schedules() -> None:
|
|||
|
||||
@shared_task(queue="check_execution_timeouts")
|
||||
def check_execution_timeouts() -> None:
|
||||
"""
|
||||
Celery task to check for timed-out pending and running executions.
|
||||
|
||||
This runs periodically to find:
|
||||
- Pending executions that haven't been acknowledged within EXECUTION_PENDING_TIMEOUT_SECONDS
|
||||
- Running executions that have been stuck for more than EXECUTION_RUNNING_TIMEOUT_SECONDS
|
||||
|
||||
These are marked as missed/failed respectively.
|
||||
"""
|
||||
"""Check for timed-out pending and running executions."""
|
||||
logger.info("Starting check_execution_timeouts task", extra={
|
||||
"pending_timeout": EXECUTION_PENDING_TIMEOUT_SECONDS,
|
||||
"running_timeout": EXECUTION_RUNNING_TIMEOUT_SECONDS
|
||||
})
|
||||
|
||||
|
||||
session = session_make()
|
||||
redis_manager = get_redis_manager()
|
||||
trigger_service = TriggerService(session)
|
||||
|
||||
|
||||
try:
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Find all pending and running executions
|
||||
|
||||
executions = session.exec(
|
||||
select(TriggerExecution).where(
|
||||
or_(
|
||||
|
|
@ -89,28 +76,26 @@ def check_execution_timeouts() -> None:
|
|||
)
|
||||
)
|
||||
).all()
|
||||
|
||||
|
||||
timed_out_pending_count = 0
|
||||
timed_out_running_count = 0
|
||||
|
||||
|
||||
for execution in executions:
|
||||
is_pending = execution.status == ExecutionStatus.pending
|
||||
is_running = execution.status == ExecutionStatus.running
|
||||
|
||||
# Determine the reference time and timeout based on status
|
||||
|
||||
if is_pending:
|
||||
reference_time = execution.created_at
|
||||
timeout_seconds = EXECUTION_PENDING_TIMEOUT_SECONDS
|
||||
else: # running
|
||||
else:
|
||||
reference_time = execution.started_at or execution.created_at
|
||||
timeout_seconds = EXECUTION_RUNNING_TIMEOUT_SECONDS
|
||||
|
||||
|
||||
if reference_time.tzinfo is None:
|
||||
reference_time = reference_time.replace(tzinfo=timezone.utc)
|
||||
time_elapsed = (now - reference_time).total_seconds()
|
||||
|
||||
|
||||
if time_elapsed > timeout_seconds:
|
||||
# Determine the new status and error message
|
||||
if is_pending:
|
||||
new_status = ExecutionStatus.missed
|
||||
error_message = f"Execution acknowledgment timeout ({timeout_seconds} seconds)"
|
||||
|
|
@ -119,34 +104,26 @@ def check_execution_timeouts() -> None:
|
|||
new_status = ExecutionStatus.failed
|
||||
error_message = f"Execution running timeout ({timeout_seconds} seconds) - no completion received"
|
||||
timed_out_running_count += 1
|
||||
|
||||
# Use TriggerService.update_execution_status for proper failure tracking
|
||||
|
||||
trigger_service.update_execution_status(
|
||||
execution=execution,
|
||||
status=new_status,
|
||||
error_message=error_message
|
||||
)
|
||||
|
||||
# Remove from Redis pending list (best effort, may not exist)
|
||||
|
||||
try:
|
||||
# Get all sessions for this execution's user
|
||||
trigger = session.get(Trigger, execution.trigger_id)
|
||||
if trigger and trigger.user_id:
|
||||
user_session_ids = redis_manager.get_user_sessions(trigger.user_id)
|
||||
for session_id in user_session_ids:
|
||||
redis_manager.remove_pending_execution(session_id, execution.execution_id)
|
||||
elif not trigger:
|
||||
logger.warning("Trigger not found for execution", extra={
|
||||
"execution_id": execution.execution_id,
|
||||
"trigger_id": execution.trigger_id
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning("Failed to remove execution from Redis", extra={
|
||||
"execution_id": execution.execution_id,
|
||||
"trigger_id": execution.trigger_id,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
|
||||
logger.info("Execution timed out", extra={
|
||||
"execution_id": execution.execution_id,
|
||||
"trigger_id": execution.trigger_id,
|
||||
|
|
@ -154,7 +131,7 @@ def check_execution_timeouts() -> None:
|
|||
"new_status": new_status.value,
|
||||
"time_elapsed": time_elapsed
|
||||
})
|
||||
|
||||
|
||||
total_timed_out = timed_out_pending_count + timed_out_running_count
|
||||
if total_timed_out > 0:
|
||||
logger.info("Marked executions as timed out", extra={
|
||||
|
|
@ -162,15 +139,13 @@ def check_execution_timeouts() -> None:
|
|||
"timed_out_running_count": timed_out_running_count,
|
||||
"total_timed_out": total_timed_out
|
||||
})
|
||||
else:
|
||||
logger.debug("No timed-out executions found")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error checking execution timeouts", extra={
|
||||
"error": str(e),
|
||||
"error_type": type(e).__name__
|
||||
}, exc_info=True)
|
||||
session.rollback()
|
||||
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
session.close()
|
||||
|
|
@ -16,18 +16,17 @@ from datetime import datetime, timedelta, timezone
|
|||
from typing import Optional, List, Dict, Any
|
||||
from sqlmodel import select, and_, or_
|
||||
from uuid import uuid4
|
||||
import logging
|
||||
from loguru import logger
|
||||
|
||||
from app.model.trigger.trigger import Trigger
|
||||
from app.model.trigger.trigger_execution import TriggerExecution
|
||||
from app.type.trigger_types import TriggerType, TriggerStatus, ExecutionType, ExecutionStatus
|
||||
from app.component.database import session_make
|
||||
from app.service.trigger.trigger_schedule_service import TriggerScheduleService
|
||||
from app.component.trigger_utils import SCHEDULED_FETCH_BATCH_SIZE, check_rate_limits
|
||||
from app.shared.types.trigger_types import TriggerType, TriggerStatus, ExecutionType, ExecutionStatus
|
||||
from app.core.database import session_make
|
||||
from app.domains.trigger.service.trigger_schedule_service import TriggerScheduleService
|
||||
from app.core.trigger_utils import SCHEDULED_FETCH_BATCH_SIZE, check_rate_limits
|
||||
from app.model.trigger.app_configs import ScheduleTriggerConfig, WebhookTriggerConfig
|
||||
from app.model.trigger.app_configs.base_config import BaseTriggerConfig
|
||||
|
||||
logger = logging.getLogger("server_trigger_service")
|
||||
|
||||
|
||||
class TriggerService:
|
||||
15
server/app/domains/user/__init__.py
Normal file
15
server/app/domains/user/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""User domain: users, admins, auth, invites, keys."""
|
||||
15
server/app/domains/user/api/__init__.py
Normal file
15
server/app/domains/user/api/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""User domain API."""
|
||||
129
server/app/domains/user/api/login_controller.py
Normal file
129
server/app/domains/user/api/login_controller.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""v1 Login - 1h access token, refresh token, rate limit."""
|
||||
|
||||
import time
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException
|
||||
from fastapi_babel import _
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import Session
|
||||
|
||||
from app.core import code
|
||||
from app.core.database import session
|
||||
from app.core.encrypt import password_verify
|
||||
from app.core.environment import env
|
||||
from app.model.user.user import LoginByPasswordIn, LoginResponse, Status, User
|
||||
from app.shared.auth import create_access_token, create_refresh_token
|
||||
from app.shared.auth.token_blacklist import blacklist_token
|
||||
from app.shared.auth.user_auth import decode_refresh_token
|
||||
from app.shared.exception import TokenException, UserException
|
||||
from app.shared.middleware.rate_limit import login_rate_limiter
|
||||
|
||||
router = APIRouter(prefix="/user", tags=["V1 Login"])
|
||||
|
||||
|
||||
@router.post("/dev_login", name="dev login (Swagger only)", include_in_schema=True)
|
||||
async def dev_login(username: str | None = Form(default=None), password: str | None = Form(default=None)):
|
||||
"""Debug-only login for Swagger Authorize. Accepts OAuth2 password form."""
|
||||
if env("debug", "") != "on":
|
||||
raise HTTPException(status_code=404)
|
||||
return {"access_token": create_access_token(1), "token_type": "bearer"}
|
||||
|
||||
|
||||
@router.post("/auto-login", name="auto login for local mode")
|
||||
async def auto_login(db_session: Session = Depends(session)) -> LoginResponse:
|
||||
"""Auto login for fully local mode. Returns most recently active user or creates default."""
|
||||
user = User.by(
|
||||
User.status == Status.Normal,
|
||||
order_by=User.updated_at.desc(),
|
||||
limit=1,
|
||||
s=db_session,
|
||||
).one_or_none()
|
||||
|
||||
if not user:
|
||||
with db_session as s:
|
||||
try:
|
||||
user = User(
|
||||
email="admin@local.eigent.ai",
|
||||
username="admin",
|
||||
nickname="Admin",
|
||||
avatar="",
|
||||
fullname="",
|
||||
work_desc="",
|
||||
)
|
||||
s.add(user)
|
||||
s.commit()
|
||||
s.refresh(user)
|
||||
logger.info("Default admin user created", extra={"user_id": user.id})
|
||||
except Exception as e:
|
||||
s.rollback()
|
||||
logger.error("Failed to create default admin user", extra={"error": str(e)}, exc_info=True)
|
||||
raise UserException(code.error, _("Failed to create default user"))
|
||||
|
||||
logger.info("Auto login successful", extra={"user_id": user.id, "email": user.email})
|
||||
return LoginResponse(token=create_access_token(user.id), email=user.email)
|
||||
|
||||
|
||||
class RefreshTokenIn(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
@router.post("/login", name="login by email or password", dependencies=[login_rate_limiter])
|
||||
async def by_password(data: LoginByPasswordIn, db_session: Session = Depends(session)) -> dict:
|
||||
"""User login with email and password. Returns access_token (1h) and refresh_token (30d)."""
|
||||
user = User.by(User.email == data.email, s=db_session).one_or_none()
|
||||
if not user or not password_verify(data.password, user.password):
|
||||
raise UserException(code.password, _("Account or password error"))
|
||||
if user.status == Status.Block:
|
||||
raise UserException(code.error, _("Your account has been blocked."))
|
||||
if not user.is_active:
|
||||
raise UserException(code.error, _("Please activate your account via the email link."))
|
||||
|
||||
access_token = create_access_token(user.id)
|
||||
refresh_token = create_refresh_token(user.id)
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"token_type": "bearer",
|
||||
"email": user.email,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/refresh", name="refresh tokens", dependencies=[login_rate_limiter])
|
||||
async def refresh(data: RefreshTokenIn, db_session: Session = Depends(session)) -> dict:
|
||||
"""Exchange valid refresh_token for new access_token and refresh_token."""
|
||||
if not data.refresh_token:
|
||||
raise TokenException(code.token_need, _("Refresh token required"))
|
||||
user_id, jti, exp_ts = await decode_refresh_token(data.refresh_token)
|
||||
user = db_session.get(User, user_id)
|
||||
if not user:
|
||||
raise TokenException(code.token_invalid, _("User not found"))
|
||||
if user.status == Status.Block:
|
||||
raise UserException(code.error, _("Your account has been blocked."))
|
||||
if not user.is_active:
|
||||
raise UserException(code.error, _("Please activate your account via the email link."))
|
||||
if jti:
|
||||
ttl = max(0, exp_ts - int(time.time()))
|
||||
await blacklist_token(jti, ttl)
|
||||
access_token = create_access_token(user.id)
|
||||
refresh_token = create_refresh_token(user.id)
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"token_type": "bearer",
|
||||
"email": user.email,
|
||||
}
|
||||
51
server/app/domains/user/api/logout_controller.py
Normal file
51
server/app/domains/user/api/logout_controller.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
"""v1 Logout - token blacklist, audit log."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import jwt
|
||||
from fastapi import APIRouter, Depends
|
||||
from loguru import logger
|
||||
|
||||
from app.core.environment import env_not_empty
|
||||
from app.shared.auth.token_blacklist import blacklist_token
|
||||
from app.shared.auth.user_auth import _get_jti, oauth2_scheme
|
||||
|
||||
router = APIRouter(prefix="/user", tags=["V1 Auth"])
|
||||
|
||||
|
||||
@router.post("/logout", name="logout")
|
||||
async def logout(token: str = Depends(oauth2_scheme)):
|
||||
"""Revoke current token. Requires Bearer token."""
|
||||
if not token:
|
||||
logger.info("logout: no token provided")
|
||||
return {"success": True, "message": "No token to revoke"}
|
||||
jti = _get_jti(token)
|
||||
user_id = None
|
||||
if jti:
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, env_not_empty("secret_key"), algorithms=["HS256"], options={"verify_exp": False}
|
||||
)
|
||||
user_id = payload.get("id")
|
||||
exp = payload.get("exp")
|
||||
ttl = max(0, int(exp) - int(datetime.utcnow().timestamp())) if exp else 3600
|
||||
await blacklist_token(jti, ttl)
|
||||
except Exception as e:
|
||||
logger.warning("logout: token decode/blacklist failed", extra={"error": str(e)})
|
||||
jti_safe = (jti[:8] + "...") if jti and len(jti) >= 8 else (jti or None)
|
||||
logger.info("logout", extra={"user_id": user_id, "jti_preview": jti_safe})
|
||||
return {"success": True, "message": "Logged out"}
|
||||
|
|
@ -12,39 +12,30 @@
|
|||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi_babel import _
|
||||
from sqlmodel import Session
|
||||
|
||||
from app.component import code
|
||||
from app.component.auth import Auth, auth_must
|
||||
from app.component.database import session
|
||||
from app.component.encrypt import password_hash, password_verify
|
||||
from app.exception.exception import UserException
|
||||
from app.core import code
|
||||
from app.core.database import session
|
||||
from app.core.encrypt import password_hash, password_verify
|
||||
from app.model.user.user import UpdatePassword, UserOut
|
||||
|
||||
logger = logging.getLogger("server_password_controller")
|
||||
from app.shared.auth import auth_must
|
||||
from app.shared.auth.user_auth import V1UserAuth
|
||||
from app.shared.exception import UserException
|
||||
|
||||
router = APIRouter(tags=["User"])
|
||||
|
||||
|
||||
@router.put("/user/update-password", name="update password", response_model=UserOut)
|
||||
def update_password(data: UpdatePassword, auth: Auth = Depends(auth_must), session: Session = Depends(session)):
|
||||
"""Update user password after verifying current password."""
|
||||
user_id = auth.user.id
|
||||
def update_password(
|
||||
data: UpdatePassword, auth: V1UserAuth = Depends(auth_must), db_session: Session = Depends(session)
|
||||
):
|
||||
model = auth.user
|
||||
|
||||
if not password_verify(data.password, model.password):
|
||||
logger.warning("Password update failed: incorrect current password", extra={"user_id": user_id})
|
||||
raise UserException(code.error, _("Password is incorrect"))
|
||||
|
||||
if data.new_password != data.re_new_password:
|
||||
logger.warning("Password update failed: new passwords do not match", extra={"user_id": user_id})
|
||||
raise UserException(code.error, _("The two passwords do not match"))
|
||||
|
||||
model.password = password_hash(data.new_password)
|
||||
model.save(session)
|
||||
logger.info("Password updated successfully", extra={"user_id": user_id})
|
||||
model.save(db_session)
|
||||
return model
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue