diff --git a/README.md b/README.md index f96ffbab7..fb0cfa86c 100644 --- a/README.md +++ b/README.md @@ -118,6 +118,9 @@ npm install npm run dev ``` +#### 3. Local Development (Use the version that is completely separated from the cloud service) +[server/README_EN.md](./server/README_EN.md) + ### 🏢 Enterprise For organizations requiring maximum security, customization, and control: diff --git a/README_CN.md b/README_CN.md index 9269770b1..374a5e7f3 100644 --- a/README_CN.md +++ b/README_CN.md @@ -118,6 +118,9 @@ npm install npm run dev ``` +#### 3. 本地开发(使用完全和云端服务分离的版本) +[server/README_CN.md](./server/README_CN.md) + ### 🏢 企业版 适合需要最高安全性、定制化和控制的组织: diff --git a/server/.env.example b/server/.env.example new file mode 100644 index 000000000..f066c5385 --- /dev/null +++ b/server/.env.example @@ -0,0 +1,7 @@ +debug=false +url_prefix=/api +secret_key=put-your-secret-key-here +database_url=postgresql://postgres:postgres@localhost:5432/postgres +# Chat Share Secret Key +CHAT_SHARE_SECRET_KEY=put-your-secret-key-here +CHAT_SHARE_SALT=put-your-encode-salt-here diff --git a/server/.gitignore b/server/.gitignore new file mode 100644 index 000000000..bb0a4bf44 --- /dev/null +++ b/server/.gitignore @@ -0,0 +1,18 @@ +# Python-generated files +__pycache__/ +*.py[oc] +build/ +dist/ +wheels/ +*.egg-info +*.mo + +# Virtual environments +.venv +.env + +runtime + +app/public/upload/ + +uv.lock \ No newline at end of file diff --git a/server/Dockerfile b/server/Dockerfile new file mode 100644 index 000000000..1b4eff273 --- /dev/null +++ b/server/Dockerfile @@ -0,0 +1,53 @@ +# Use a Python image with uv pre-installed +FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim + +# Install the project into `/app` +WORKDIR /app + +# Enable bytecode compilation +ENV UV_COMPILE_BYTECODE=1 + +# Copy from the cache instead of linking since it's a mounted volume +ENV UV_LINK_MODE=copy + +ENV UV_PYTHON_INSTALL_MIRROR=https://registry.npmmirror.com/-/binary/python-build-standalone + +ARG database_url +ENV database_url=$database_url + + +# Copy dependency files first +COPY pyproject.toml uv.lock ./ + +# Install the project's dependencies +RUN --mount=type=cache,target=/root/.cache/uv \ + uv sync --no-install-project --no-dev + +# Then, add the rest of the project source code and install it +# Installing separately from its dependencies allows optimal layer caching +COPY . /app +RUN --mount=type=cache,target=/root/.cache/uv \ + uv sync --no-dev + +RUN uv run pybabel extract -F babel.cfg -o messages.pot . && \ + uv run pybabel init -i messages.pot -d lang -l zh_CN && \ + uv run pybabel compile -d lang -l zh_CN + + +# Install netcat for database connectivity check +RUN apt-get update && apt-get install -y curl netcat-openbsd && rm -rf /var/lib/apt/lists/* + +# Place executables in the environment at the front of the path +ENV PATH="/app/.venv/bin:$PATH" + +# Copy and make the start script executable +COPY start.sh /app/start.sh +RUN chmod +x /app/start.sh + +# Reset the entrypoint, don't invoke `uv` +ENTRYPOINT [] + +EXPOSE 5678 + +# Use the start script +CMD ["/app/start.sh"] \ No newline at end of file diff --git a/server/README_CN.md b/server/README_CN.md new file mode 100644 index 000000000..bf939f812 --- /dev/null +++ b/server/README_CN.md @@ -0,0 +1,106 @@ +### 背景与目标 +本目录 `server/` 是在客户端本地下放的后端服务(FastAPI + PostgreSQL)。目标:实现本地与云端完全数据分离。部署该服务后,用户的注册信息、模型提供商配置、工具配置、聊天历史等敏感数据均保存在本机的数据库中,不会上传到我们的云端,除非你主动配置了外部服务(如云端模型提供商或远程 MCP 服务器)。 + +### 本地下放的服务范围(主要模块) +- 用户与账号 + - `POST /register`:邮箱 + 密码注册(仅本地 DB) + - `POST /login`:邮箱 + 密码登录,返回本地签发的 Token + - `GET/PUT /user`、`/user/profile`、`/user/privacy`、`/user/current_credits`、`/user/stat` 等 +- 模型提供商 Provider(保存你的模型访问配置) + - `GET /providers`、`POST /provider`、`PUT /provider/{id}`、`DELETE /provider/{id}` + - `POST /provider/prefer`:设置首选 Provider(前端与后端将优先使用) +- 配置中心 Config(保存各类工具/能力所需的密钥或参数) + - `GET /configs`、`POST /configs`、`PUT /configs/{id}`、`DELETE /configs/{id}`、`GET /config/info` +- 聊天与数据 + - 历史、快照、分享等接口位于 `app/controller/chat/`,数据全部落在本地数据库 +- MCP 服务管理(导入本地/远程 MCP 服务器) + - `GET /mcps`、`POST /mcp/install`、`POST /mcp/import/{Local|Remote}` 等 + +说明:上述数据均保存在 Docker 中的本地 PostgreSQL 卷中(见“数据持久化”),不经我们云端。若你配置了外部模型或远程 MCP,则相应请求会发往你指定的第三方服务。 + +--- + +### 快速开始(Docker 推荐) +前置要求:已安装 Docker Desktop。 + +1) 启动服务 +```bash +cd server +docker compose up -d +``` + +2) 启动前端(本地模式) +- 在项目根目录创建或修改 `.env.development`,开启本地模式并指向本地后端: +```bash +VITE_USE_LOCAL_PROXY=true +VITE_PROXY_URL=http://localhost:3001 +``` +- 启动前端应用: +```bash +npm install +npm run dev +``` + +### 访问 API 文档 +- 浏览器打开 `http://localhost:3001/docs`(Swagger UI) + +### 容器与端口 +- API 服务:本机 `3001` → 容器 `5678` +- PostgreSQL:本机 `5432` → 容器 `5432` + +### 数据持久化 +- 数据库数据存放在 Docker 卷 `server_postgres_data`,容器路径 `/var/lib/postgresql/data` +- 容器启动时会自动执行数据库迁移(见 `start.sh` 中的 `alembic upgrade head`) + +### 常用命令 +```bash +# 查看运行中的容器 +docker ps + +# 停止/启动 API 容器(保留数据库) +docker stop eigent_api +docker start eigent_api + +# 停止/启动全部(API + DB) +docker compose stop +docker compose start + +# 查看日志 +docker logs -f eigent_api | cat +docker logs -f eigent_postgres | cat +``` +提示:若拉取镜像缓慢,可在 Docker Desktop 配置国内镜像加速后重试。 + +--- + +### 开发模式(可选) +如果希望在本地以热重载方式开发 API(数据库仍用 Docker 中的 Postgres): +```bash +# 1) 停止容器中的 API 服务,仅保留数据库 + docker stop eigent_api + +# 2) 本地启动(需提供数据库连接串) + cd server + # 方式 A:在当前 shell 导出环境变量 + export database_url=postgresql://postgres:123456@localhost:5432/eigent + uv run uvicorn main:api --reload --port 3001 --host 0.0.0.0 + + # 方式 B:在 server/.env 中写入(示例) + # database_url=postgresql://postgres:123456@localhost:5432/eigent + # 然后直接运行同样的 uvicorn 命令 +uv run uvicorn main:api --reload --port 3001 --host 0.0.0.0 +``` + +--- + +### 其它 +- API 文档:`http://localhost:3001/docs` +- 运行时日志:容器内 `/app/runtime/log/app.log` +- i18n 相关(仅开发者使用) +```bash +uv run pybabel extract -F babel.cfg -o messages.pot . +uv run pybabel init -i messages.pot -d lang -l zh_CN +uv run pybabel compile -d lang -l zh_CN +``` + +如需完全离线环境,请仅使用本地模型与本地 MCP 服务器,并避免配置任何外部 Provider 与远程 MCP 地址。 \ No newline at end of file diff --git a/server/README_EN.md b/server/README_EN.md new file mode 100644 index 000000000..dbfdad24b --- /dev/null +++ b/server/README_EN.md @@ -0,0 +1,101 @@ +### Purpose +`server/` provides a local backend (FastAPI + PostgreSQL) to achieve complete separation between local and cloud environments. After deploying this service, sensitive data such as user registration, model provider configurations, tool settings, and chat history are stored on your machine and are not uploaded to our cloud unless you explicitly configure external services (e.g., cloud model providers or remote MCP servers). + +### Services Provided (Main Modules) +- Users & Accounts + - `POST /register`: Email + password registration (local DB only) + - `POST /login`: Email + password login; returns a locally issued token + - `GET/PUT /user`, `/user/profile`, `/user/privacy`, `/user/current_credits`, `/user/stat`, etc. +- Model Providers (store local/cloud model access configurations) + - `GET /providers`, `POST /provider`, `PUT /provider/{id}`, `DELETE /provider/{id}` + - `POST /provider/prefer`: Set a preferred provider (frontend/backend will prioritize it) +- Config Center (store secrets/params required by tools/capabilities) + - `GET /configs`, `POST /configs`, `PUT /configs/{id}`, `DELETE /configs/{id}`, `GET /config/info` +- Chat & Data + - History, snapshots, sharing, etc. in `app/controller/chat/`, all persisted to local DB +- MCP Management (import local/remote MCP servers) + - `GET /mcps`, `POST /mcp/install`, `POST /mcp/import/{Local|Remote}`, etc. + +Note: All the above data is stored in the local PostgreSQL volume in Docker (see “Data Persistence” below). If you configure external models or remote MCP, requests go to the third-party services you specify. + +--- + +### Quick Start (Docker) +Prerequisite: Docker Desktop installed. + +1) Start services +```bash +cd server +docker compose up -d +``` + +2) Start Frontend (Local Mode) +- In the project root directory, create or modify `.env.development` to enable local mode and point to the local backend: +```bash +VITE_USE_LOCAL_PROXY=true +VITE_PROXY_URL=http://localhost:3001 +``` +- Start the frontend application: +```bash +npm install +npm run dev +``` + +### Open API docs +- `http://localhost:3001/docs` (Swagger UI) + +### Ports +- API: Host `3001` → Container `5678` +- PostgreSQL: Host `5432` → Container `5432` + +### Data Persistence +- DB data is stored in Docker volume `server_postgres_data` at `/var/lib/postgresql/data` inside the container +- Database migrations run automatically on container startup (see `start.sh` → `alembic upgrade head`) + +### Common Commands +```bash +# List running containers +docker ps + +# Stop/Start API container (keep DB) +docker stop eigent_api +docker start eigent_api + +# Stop/Start all (API + DB) +docker compose stop +docker compose start + +# View logs +docker logs -f eigent_api | cat +docker logs -f eigent_postgres | cat +``` + +--- + +### Developer Mode (Optional) +You can run the API locally with hot-reload while keeping the database in Docker: +```bash +# Stop API in container, keep DB +docker stop eigent_api + +# Run locally (provide DB connection string) +cd server +export database_url=postgresql://postgres:123456@localhost:5432/eigent +uv run uvicorn main:api --reload --port 3001 --host 0.0.0.0 +``` + +--- + +### Others +- API docs: `http://localhost:3001/docs` +- Runtime logs: `/app/runtime/log/app.log` in the container +- i18n (for developers) +```bash +uv run pybabel extract -F babel.cfg -o messages.pot . +uv run pybabel init -i messages.pot -d lang -l zh_CN +uv run pybabel compile -d lang -l zh_CN +``` + +For a fully offline environment, only use local models and local MCP servers, and avoid configuring any external Providers or remote MCP addresses. + + diff --git a/server/alembic.ini b/server/alembic.ini new file mode 100644 index 000000000..25dcac1b2 --- /dev/null +++ b/server/alembic.ini @@ -0,0 +1,119 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +# Use forward slashes (/) also on windows to provide an os agnostic path +script_location = alembic + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library. +# Any required deps can installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to alembic/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +# version_path_separator = newline +# +# Use os.pathsep. Default configuration used for new projects. +version_path_separator = os + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = postgresql://postgres:123456@localhost:5432/eigent + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARNING +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARNING +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/server/alembic/env.py b/server/alembic/env.py new file mode 100644 index 000000000..4dccfeb6a --- /dev/null +++ b/server/alembic/env.py @@ -0,0 +1,129 @@ +from logging.config import fileConfig + +from sqlalchemy import engine_from_config, pool +from alembic import context +from sqlmodel import SQLModel +from app.component.environment import auto_import, env_not_empty + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +auto_import("app.model.mcp") +auto_import("app.model.user") +auto_import("app.model.config") +auto_import("app.model.chat") +auto_import("app.model.provider") + + +# target_metadata = mymodel.Base.metadata +target_metadata = SQLModel.metadata + + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +# https://alembic.sqlalchemy.org/en/latest/autogenerate.html#affecting-the-rendering-of-types-themselves +def render_item(type_, obj, autogen_context): + """Apply rendering for custom sqlalchemy types""" + if type_ == "type": + module_name = obj.__class__.__module__ + if module_name.startswith("sqlalchemy_utils."): + return render_sqlalchemy_utils_type(obj, autogen_context) + + # render default + return False + + +def render_sqlalchemy_utils_type(obj, autogen_context): + class_name = obj.__class__.__name__ + import_statement = f"from sqlalchemy_utils.types import {class_name}" + autogen_context.imports.add(import_statement) + if class_name == "ChoiceType": + return render_choice_type(obj, autogen_context) + return f"{class_name}()" + + +def render_choice_type(obj, autogen_context): + choices = obj.choices + if obj.type_impl.__class__.__name__ == "EnumTypeImpl": + choices = obj.type_impl.enum_class.__name__ + import_statement = f"from {obj.type_impl.enum_class.__module__} import {choices}" + autogen_context.imports.add(import_statement) + impl_stmt = f"sa.{obj.impl.__class__.__name__}()" + return f"{obj.__class__.__name__}(choices={choices}, impl={impl_stmt})" + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def include_object(object, name, type_, reflected, compare_to): + # ignore all foreign key constraints + if type_ == "foreign_key_constraint": + return False + return True + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + options = config.get_section(config.config_ini_section, {}) + options["sqlalchemy.url"] = env_not_empty("database_url") + connectable = engine_from_config( + options, + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, + render_item=render_item, + include_object=include_object, + target_metadata=target_metadata, + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/server/alembic/script.py.mako b/server/alembic/script.py.mako new file mode 100644 index 000000000..8c3031b81 --- /dev/null +++ b/server/alembic/script.py.mako @@ -0,0 +1,29 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/server/alembic/versions/2025_08_11_2107-0001_init_initial_schema.py b/server/alembic/versions/2025_08_11_2107-0001_init_initial_schema.py new file mode 100644 index 000000000..31c88c262 --- /dev/null +++ b/server/alembic/versions/2025_08_11_2107-0001_init_initial_schema.py @@ -0,0 +1,400 @@ +"""initial schema + +Revision ID: 0001_init +Revises: +Create Date: 2025-08-11 21:07:03.701363 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from app.model.chat.chat_history import ChatStatus +from app.model.mcp.mcp import McpType +from app.model.mcp.mcp import Status +from app.model.mcp.mcp_env import Status +from app.model.mcp.mcp_user import McpType +from app.model.mcp.mcp_user import Status +from app.model.provider.provider import VaildStatus +from app.model.user.admin import Status +from app.model.user.key import KeyStatus +from app.model.user.role import RoleType +from app.model.user.user import Status +from app.model.user.user_credits_record import CreditsChannel +from sqlalchemy_utils.types import ChoiceType + +# revision identifiers, used by Alembic. +revision: str = "0001_init" +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "admin", + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.Column("created_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("updated_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("email", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("password", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("status", ChoiceType(choices=Status, impl=sa.SmallInteger()), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "admin_role", + sa.Column("admin_id", sa.Integer(), nullable=False), + sa.Column("role_id", sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint("admin_id", "role_id"), + ) + op.create_table( + "category", + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.Column("created_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("updated_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=64), nullable=False), + sa.Column("description", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False), + sa.Column("priority", sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "chat_history", + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.Column("created_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("updated_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("task_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("question", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("language", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("model_platform", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("model_type", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("api_key", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("api_url", sa.String(length=500), nullable=True), + sa.Column("max_retries", sa.Integer(), nullable=False), + sa.Column("file_save_path", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("installed_mcp", sa.JSON(), nullable=False), + sa.Column("project_name", sa.String(length=128), nullable=True), + sa.Column("summary", sa.String(length=1024), nullable=True), + sa.Column("tokens", sa.Integer(), server_default="0", nullable=True), + sa.Column("spend", sa.Float(), server_default="0", nullable=True), + sa.Column("status", ChoiceType(choices=ChatStatus, impl=sa.SmallInteger()), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_chat_history_task_id"), "chat_history", ["task_id"], unique=True) + op.create_index(op.f("ix_chat_history_user_id"), "chat_history", ["user_id"], unique=False) + op.create_table( + "chat_snapshot", + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.Column("created_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("updated_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), server_default=sa.text("0"), nullable=True), + sa.Column("api_task_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("camel_task_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("browser_url", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("image_path", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_chat_snapshot_api_task_id"), "chat_snapshot", ["api_task_id"], unique=False) + op.create_index(op.f("ix_chat_snapshot_camel_task_id"), "chat_snapshot", ["camel_task_id"], unique=False) + op.create_table( + "chat_step", + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.Column("created_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("updated_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("task_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("step", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("data", sa.JSON(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_chat_step_task_id"), "chat_step", ["task_id"], unique=False) + op.create_table( + "config", + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.Column("created_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("updated_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("config_name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("config_value", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("config_group", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("user_id", "config_name", name="uix_user_id_config_name"), + ) + op.create_index(op.f("ix_config_config_group"), "config", ["config_group"], unique=False) + op.create_index(op.f("ix_config_user_id"), "config", ["user_id"], unique=False) + op.create_table( + "plan", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("plan_key", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("price_month", sa.Float(), nullable=False), + sa.Column("price_year", sa.Float(), nullable=False), + sa.Column("daily_credits", sa.Integer(), nullable=False), + sa.Column("monthly_credits", sa.Integer(), nullable=False), + sa.Column("storage_limit", sa.Integer(), nullable=False), + sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("is_active", sa.Boolean(), nullable=False), + sa.Column("extra_config", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_plan_name"), "plan", ["name"], unique=True) + op.create_index(op.f("ix_plan_plan_key"), "plan", ["plan_key"], unique=True) + op.create_table( + "provider", + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.Column("created_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("updated_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("provider_name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("model_type", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("api_key", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("endpoint_url", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("encrypted_config", sa.JSON(), nullable=True), + sa.Column("prefer", sa.Boolean(), server_default=sa.text("false"), nullable=True), + sa.Column( + "is_vaild", + ChoiceType(choices=VaildStatus, impl=sa.SmallInteger()), + server_default=sa.text("1"), + nullable=True, + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_provider_user_id"), "provider", ["user_id"], unique=False) + op.create_table( + "role", + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.Column("created_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("updated_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("type", ChoiceType(choices=RoleType, impl=sa.SmallInteger()), nullable=True), + sa.Column("permissions", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "user", + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.Column("created_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("updated_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("stack_id", sqlmodel.sql.sqltypes.AutoString(length=255), nullable=True), + sa.Column("username", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=True), + sa.Column("email", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False), + sa.Column("password", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=True), + sa.Column("avatar", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False), + sa.Column("nickname", sqlmodel.sql.sqltypes.AutoString(length=64), nullable=False), + sa.Column("fullname", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False), + sa.Column("work_desc", sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False), + sa.Column("credits", sa.Integer(), server_default=sa.text("0"), nullable=True), + sa.Column("last_daily_credit_date", sa.Date(), nullable=True), + sa.Column("last_monthly_credit_date", sa.Date(), nullable=True), + sa.Column("inviter_user_id", sa.Integer(), nullable=True), + sa.Column("status", ChoiceType(choices=Status, impl=sa.SmallInteger()), nullable=True), + sa.ForeignKeyConstraint( + ["inviter_user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("email"), + sa.UniqueConstraint("stack_id"), + sa.UniqueConstraint("username"), + ) + op.create_table( + "key", + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.Column("created_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("updated_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("value", sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False), + sa.Column("inner_key", sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False), + sa.Column("status", ChoiceType(choices=KeyStatus, impl=sa.SmallInteger()), nullable=True), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_key_user_id"), "key", ["user_id"], unique=False) + op.create_index(op.f("ix_key_value"), "key", ["value"], unique=False) + op.create_table( + "mcp", + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.Column("created_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("updated_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("category_id", sa.Integer(), nullable=False), + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("key", sa.String(length=128), nullable=True), + sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("home_page", sa.String(length=1024), nullable=True), + sa.Column("type", ChoiceType(choices=McpType, impl=sa.SmallInteger()), nullable=True), + sa.Column("status", ChoiceType(choices=Status, impl=sa.SmallInteger()), nullable=True), + sa.Column("sort", sa.SmallInteger(), nullable=True), + sa.Column("server_name", sa.String(length=128), nullable=True), + sa.Column("install_command", sa.JSON(), nullable=True), + sa.ForeignKeyConstraint( + ["category_id"], + ["category.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "user_credits_record", + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.Column("created_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("updated_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("invite_by", sa.Integer(), nullable=True), + sa.Column("invite_code", sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False), + sa.Column("amount", sa.Integer(), nullable=False), + sa.Column("balance", sa.Integer(), nullable=False), + sa.Column("channel", ChoiceType(choices=CreditsChannel, impl=sa.SmallInteger()), nullable=True), + sa.Column("source_id", sa.Integer(), nullable=False), + sa.Column("remark", sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False), + sa.Column("expire_at", sa.DateTime(), nullable=True), + sa.Column("used", sa.Boolean(), server_default=sa.text("false"), nullable=True), + sa.Column("used_at", sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_user_credits_record_user_id"), "user_credits_record", ["user_id"], unique=False) + op.create_table( + "user_privacy", + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.Column("created_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("updated_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("pricacy_setting", sa.JSON(), nullable=True), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("user_id"), + ) + op.create_table( + "user_stat", + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.Column("created_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("updated_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("model_type", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("download_count", sa.Integer(), nullable=False), + sa.Column("register_count", sa.Integer(), nullable=False), + sa.Column("task_complete_count", sa.Integer(), nullable=False), + sa.Column("task_failed_count", sa.Integer(), nullable=False), + sa.Column("file_download_count", sa.Integer(), nullable=False), + sa.Column("file_generate_count", sa.Integer(), nullable=False), + sa.Column("paid_amount_on_avg_task", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_user_stat_user_id"), "user_stat", ["user_id"], unique=False) + op.create_table( + "mcp_env", + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.Column("created_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("updated_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("mcp_id", sa.Integer(), nullable=False), + sa.Column("env_name", sa.String(length=128), nullable=True), + sa.Column("env_description", sa.TEXT(), nullable=True), + sa.Column("env_key", sa.String(length=128), nullable=True), + sa.Column("env_default_value", sa.String(length=1024), nullable=True), + sa.Column("env_required", sa.SmallInteger(), nullable=True), + sa.Column("status", ChoiceType(choices=Status, impl=sa.SmallInteger()), nullable=True), + sa.ForeignKeyConstraint( + ["mcp_id"], + ["mcp.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "mcp_user", + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.Column("created_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("updated_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("mcp_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("mcp_name", sa.String(length=128), nullable=True), + sa.Column("mcp_key", sa.String(length=128), nullable=True), + sa.Column("mcp_desc", sa.String(length=1024), nullable=True), + sa.Column("command", sa.String(length=1024), nullable=True), + sa.Column("args", sa.String(length=1024), nullable=True), + sa.Column("env", sa.JSON(), nullable=True), + sa.Column("type", ChoiceType(choices=McpType, impl=sa.SmallInteger()), nullable=True), + sa.Column("status", ChoiceType(choices=Status, impl=sa.SmallInteger()), nullable=True), + sa.Column("server_url", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.ForeignKeyConstraint( + ["mcp_id"], + ["mcp.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("mcp_user") + op.drop_table("mcp_env") + op.drop_index(op.f("ix_user_stat_user_id"), table_name="user_stat") + op.drop_table("user_stat") + op.drop_table("user_privacy") + op.drop_index(op.f("ix_user_credits_record_user_id"), table_name="user_credits_record") + op.drop_table("user_credits_record") + op.drop_table("mcp") + op.drop_index(op.f("ix_key_value"), table_name="key") + op.drop_index(op.f("ix_key_user_id"), table_name="key") + op.drop_table("key") + op.drop_table("user") + op.drop_table("role") + op.drop_index(op.f("ix_provider_user_id"), table_name="provider") + op.drop_table("provider") + op.drop_index(op.f("ix_plan_plan_key"), table_name="plan") + op.drop_index(op.f("ix_plan_name"), table_name="plan") + op.drop_table("plan") + op.drop_index(op.f("ix_config_user_id"), table_name="config") + op.drop_index(op.f("ix_config_config_group"), table_name="config") + op.drop_table("config") + op.drop_index(op.f("ix_chat_step_task_id"), table_name="chat_step") + op.drop_table("chat_step") + op.drop_index(op.f("ix_chat_snapshot_camel_task_id"), table_name="chat_snapshot") + op.drop_index(op.f("ix_chat_snapshot_api_task_id"), table_name="chat_snapshot") + op.drop_table("chat_snapshot") + op.drop_index(op.f("ix_chat_history_user_id"), table_name="chat_history") + op.drop_index(op.f("ix_chat_history_task_id"), table_name="chat_history") + op.drop_table("chat_history") + op.drop_table("category") + op.drop_table("admin_role") + op.drop_table("admin") + # ### end Alembic commands ### diff --git a/server/app/__init__.py b/server/app/__init__.py new file mode 100644 index 000000000..101f7685b --- /dev/null +++ b/server/app/__init__.py @@ -0,0 +1,5 @@ +from fastapi import FastAPI +from fastapi_pagination import add_pagination + +api = FastAPI(swagger_ui_parameters={"persistAuthorization": True}) +add_pagination(api) diff --git a/server/app/component/auth.py b/server/app/component/auth.py new file mode 100644 index 000000000..4a74d410f --- /dev/null +++ b/server/app/component/auth.py @@ -0,0 +1,90 @@ +from fastapi import Depends, Header +from fastapi_babel import _ +from sqlmodel import Session, select +from app.component import code +from fastapi.security import OAuth2PasswordBearer +from app.component.database import session +from app.component.environment import env, env_not_empty +from datetime import timedelta, datetime +import jwt +from jwt.exceptions import InvalidTokenError +from app.model.mcp.proxy import ApiKey +from app.model.user.key import Key +from app.model.user.user import User + +from app.exception.exception import ( + NoPermissionException, + TokenException, +) + + +class Auth: + SECRET_KEY = env_not_empty("secret_key") + + def __init__(self, id: int, expired_at: datetime): + self.id = id + self.expired_at = expired_at + self._user: User | None = None + + @property + def user(self): + if self._user is None: + raise NoPermissionException("未查询到登录用户") + return self._user + + @classmethod + def decode_token(cls, token: str): + try: + payload = jwt.decode(token, Auth.SECRET_KEY, algorithms=["HS256"]) + id = payload["id"] + if payload["exp"] < int(datetime.now().timestamp()): + raise TokenException(code.token_expired, _("Validate credentials expired")) + except InvalidTokenError: + raise TokenException(code.token_invalid, _("Could not validate credentials")) + return Auth(id, payload["exp"]) + + @classmethod + def create_access_token(cls, user_id: int, expires_delta: timedelta | None = None): + to_encode: dict = {"id": user_id} + if expires_delta: + expire = datetime.now() + expires_delta + else: + expire = datetime.now() + timedelta(days=30) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, Auth.SECRET_KEY, algorithm="HS256") + return encoded_jwt + + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{env('url_prefix', '')}/dev_login", auto_error=False) + + +async def auth( + token: str | None = Depends(oauth2_scheme), + session: Session = Depends(session), +) -> Auth | None: + if token is None: + return None + try: + model = Auth.decode_token(token) + user = session.get(User, model.id) + model._user = user + return model + except Exception: + return None + + +async def auth_must( + token: str = Depends(oauth2_scheme), + session: Session = Depends(session), +) -> Auth: + model = Auth.decode_token(token) + user = session.get(User, model.id) + model._user = user + return model + + +async def key_must(headers: ApiKey = Header(), session: Session = Depends(session)): + model = session.exec(select(Key).where(Key.value == headers.api_key)).one_or_none() + if model is None: + raise TokenException(code.token_invalid, _(f"Could not validate key credentials: {headers.api_key}")) + return model diff --git a/server/app/component/babel.py b/server/app/component/babel.py new file mode 100644 index 000000000..83382a269 --- /dev/null +++ b/server/app/component/babel.py @@ -0,0 +1,10 @@ +from fastapi_babel import BabelConfigs, Babel +from pathlib import Path + +babel_configs = BabelConfigs( + ROOT_DIR=Path(__file__).parent.parent, + BABEL_DEFAULT_LOCALE="en_US", + BABEL_TRANSLATION_DIRECTORY="lang", +) + +babel = Babel(configs=babel_configs) diff --git a/server/app/component/code.py b/server/app/component/code.py new file mode 100644 index 000000000..f62936656 --- /dev/null +++ b/server/app/component/code.py @@ -0,0 +1,13 @@ +success = 0 # success response code +error = 1 # common error response code +not_found = 4 # can't find route or data + +password = 10 # account or password error +token_need = 11 +token_expired = 12 +token_invalid = 13 +token_blocked = 14 + +form_error = 100 # form pydantic validate error + +no_permission_error = 300 # admin no permission error diff --git a/server/app/component/database.py b/server/app/component/database.py new file mode 100644 index 000000000..1a9cfb9c9 --- /dev/null +++ b/server/app/component/database.py @@ -0,0 +1,18 @@ +from sqlmodel import Session, create_engine +from app.component.environment import env, env_or_fail + + +engine = create_engine( + env_or_fail("database_url"), + echo=True if env("debug") == "on" else False, + pool_size=36, +) + + +def session_make(): + return Session(engine) + + +def session(): + with Session(engine) as session: + yield session diff --git a/server/app/component/encrypt.py b/server/app/component/encrypt.py new file mode 100644 index 000000000..a6ceee781 --- /dev/null +++ b/server/app/component/encrypt.py @@ -0,0 +1,13 @@ +from passlib.context import CryptContext + +password = CryptContext(schemes=["bcrypt"], deprecated="auto") + + +def password_hash(password_value: str): + return password.hash(password_value) + + +def password_verify(password_value: str, password_hash: str | None): + if not password_hash: + return False + return password.verify(password_value, password_hash) diff --git a/server/app/component/environment.py b/server/app/component/environment.py new file mode 100644 index 000000000..626a79f89 --- /dev/null +++ b/server/app/component/environment.py @@ -0,0 +1,97 @@ +import importlib.util +import os +from pathlib import Path +from fastapi import APIRouter, FastAPI +from dotenv import load_dotenv +import importlib +from typing import Any, overload + + +load_dotenv() + + +@overload +def env(key: str) -> str | None: ... + + +@overload +def env(key: str, default: str) -> str: ... + + +@overload +def env(key: str, default: Any) -> Any: ... + + +def env(key: str, default=None): + return os.getenv(key, default) + + +def env_or_fail(key: str): + value = env(key) + if value is None: + raise Exception("can't get env config value.") + return value + + +def env_not_empty(key: str): + value = env(key) + if not value: + raise Exception("env config value can't be empty.") + return value + + +def base_path(): + return Path(__file__).parent.parent.parent + + +def to_path(path: str): + return base_path() / path + + +def auto_import(package: str): + """ + 自动导入指定目录下的全部py文件 + """ + # 获取文件夹下的所有文件名 + folder = package.replace(".", "/") + files = os.listdir(folder) + + # 导入文件夹下的所有.py文件 + for file in files: + if file.endswith(".py") and not file.startswith("__"): + module_name = file[:-3] # 去掉文件名的扩展名.py + importlib.import_module(package + "." + module_name) + + +def auto_include_routers(api: FastAPI, prefix: str, directory: str): + """ + 自动扫描指定目录下的所有模块并注册路由 + + :param api: FastAPI 实例 + :param prefix: 路由前缀 + :param directory: 要扫描的目录路径 + """ + # 将目录转换为绝对路径 + dir_path = Path(directory).resolve() + + # 遍历目录下所有.py文件 + for root, _, files in os.walk(dir_path): + for file_name in files: + if file_name.endswith("_controller.py") and not file_name.startswith("__"): + # 构造完整文件路径 + file_path = Path(root) / file_name + + # 生成模块名称 + module_name = file_path.stem + + # 使用importlib加载模块 + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None or spec.loader is None: + continue + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # 检查模块中是否存在router属性且是APIRouter实例 + router = getattr(module, "router", None) + if isinstance(router, APIRouter): + api.include_router(router, prefix=prefix) diff --git a/server/app/component/oauth_adapter.py b/server/app/component/oauth_adapter.py new file mode 100644 index 000000000..5264327fe --- /dev/null +++ b/server/app/component/oauth_adapter.py @@ -0,0 +1,205 @@ +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional +import os +import httpx +from pydantic import BaseModel +import base64 +import json +from app.component.environment import env + + +class OAuthAdapter(ABC): + @abstractmethod + def get_authorize_url(self, state: Optional[str] = None) -> Optional[str]: + pass + + @abstractmethod + def fetch_token(self, code: Optional[str]) -> Optional[Dict[str, Any]]: + pass + + +class SlackOAuthAdapter(OAuthAdapter): + def __init__(self, redirect_uri: Optional[str] = None): + self.client_id = env("SLACK_CLIENT_ID", "your_client_id") + self.client_secret = env("SLACK_CLIENT_SECRET", "your_client_secret") + self.redirect_uri = redirect_uri or env("SLACK_REDIRECT_URI", "https://localhost/api/oauth/slack/callback") + self.scope = env("SLACK_SCOPE", "chat:write,channels:read,channels:join,groups:read,im:write") + + def get_authorize_url(self, state: Optional[str] = None) -> Optional[str]: + url = ( + f"https://slack.com/oauth/v2/authorize?client_id={self.client_id}" + f"&scope={self.scope}" + f"&redirect_uri={self.redirect_uri}" + ) + if state: + url += f"&state={state}" + return url + + def fetch_token(self, code: Optional[str]) -> Optional[Dict[str, Any]]: + if not code: + return None + token_url = "https://slack.com/api/oauth.v2.access" + data = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "redirect_uri": self.redirect_uri, + } + with httpx.Client() as client: + resp = client.post(token_url, data=data) + return resp.json() + + +class NotionOAuthAdapter(OAuthAdapter): + def __init__(self, redirect_uri: Optional[str] = None): + self.client_id = env("NOTION_CLIENT_ID", "your_notion_client_id") + self.client_secret = env("NOTION_CLIENT_SECRET", "your_notion_client_secret") + self.redirect_uri = redirect_uri or env("NOTION_REDIRECT_URI", "https://localhost/api/oauth/notion/callback") + self.scope = env("NOTION_SCOPE", "") # Notion目前scope可为空 + + def get_authorize_url(self, state: Optional[str] = None) -> Optional[str]: + url = ( + f"https://api.notion.com/v1/oauth/authorize?client_id={self.client_id}" + f"&owner=user" + f"&response_type=code" + f"&redirect_uri={self.redirect_uri}" + ) + if state: + url += f"&state={state}" + return url + + def fetch_token(self, code: Optional[str]) -> Optional[Dict[str, Any]]: + if not code: + return None + token_url = "https://api.notion.com/v1/oauth/token" + + basic_auth = base64.b64encode(f"{self.client_id}:{self.client_secret}".encode()).decode() + headers = { + "Authorization": f"Basic {basic_auth}", + "Content-Type": "application/json", + "Accept": "application/json", + } + data = {"grant_type": "authorization_code", "code": code, "redirect_uri": self.redirect_uri} + with httpx.Client() as client: + resp = client.post(token_url, headers=headers, json=data) + return resp.json() + + +class XOAuthAdapter(OAuthAdapter): + def __init__(self, redirect_uri: Optional[str] = None): + self.client_id = env("X_CLIENT_ID", "your_x_client_id") + self.client_secret = env("X_CLIENT_SECRET", "your_x_client_secret") + self.redirect_uri = redirect_uri or env("X_REDIRECT_URI", "https://localhost/api/oauth/x/callback") + self.scope = env("X_SCOPE", "tweet.read users.read offline.access") + + def get_authorize_url( + self, state: Optional[str] = None, code_challenge: Optional[str] = None, code_challenge_method: str = "plain" + ) -> Optional[str]: + # code_challenge建议由外部生成并传入,PKCE安全 + url = ( + f"https://twitter.com/i/oauth2/authorize?response_type=code" + f"&client_id={self.client_id}" + f"&redirect_uri={self.redirect_uri}" + f"&scope={self.scope.replace(' ', '%20')}" + f"&state={state or ''}" + ) + if code_challenge: + url += f"&code_challenge={code_challenge}&code_challenge_method={code_challenge_method}" + return url + + def fetch_token(self, code: Optional[str], code_verifier: Optional[str] = None) -> Optional[Dict[str, Any]]: + if not code: + return None + token_url = "https://api.twitter.com/2/oauth2/token" + headers = {"Content-Type": "application/x-www-form-urlencoded"} + data = { + "grant_type": "authorization_code", + "code": code, + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + } + if code_verifier: + data["code_verifier"] = code_verifier + with httpx.Client() as client: + resp = client.post(token_url, headers=headers, data=data) + return resp.json() + + +class GoogleSuiteOAuthAdapter(OAuthAdapter): + def __init__(self, redirect_uri: Optional[str] = None): + self.client_id = env("GOOGLE_SUITE_CLIENT_ID", "your_google_suite_client_id") + self.client_secret = env("GOOGLE_SUITE_CLIENT_SECRET", "your_google_suite_client_secret") + self.redirect_uri = redirect_uri or env( + "GOOGLE_SUITE_REDIRECT_URI", "https://localhost/api/oauth/google_suite/callback" + ) + self.scope = env( + "GOOGLE_SUITE_SCOPE", "openid email profile https://www.googleapis.com/auth/drive.metadata.readonly" + ) + + def get_authorize_url(self, state: Optional[str] = None) -> Optional[str]: + url = ( + f"https://accounts.google.com/o/oauth2/v2/auth?" + f"client_id={self.client_id}" + f"&redirect_uri={self.redirect_uri}" + f"&response_type=code" + f"&scope={self.scope.replace(' ', '%20')}" + f"&access_type=offline" + f"&include_granted_scopes=true" + ) + if state: + url += f"&state={state}" + return url + + def fetch_token(self, code: Optional[str]) -> Optional[Dict[str, Any]]: + if not code: + return None + token_url = "https://oauth2.googleapis.com/token" + headers = {"Content-Type": "application/x-www-form-urlencoded"} + data = { + "code": code, + "client_id": self.client_id, + "client_secret": self.client_secret, + "redirect_uri": self.redirect_uri, + "grant_type": "authorization_code", + } + with httpx.Client() as client: + resp = client.post(token_url, headers=headers, data=data) + return resp.json() + + +class EXAOAuthAdapter(OAuthAdapter): + def get_authorize_url(self, state: Optional[str] = None) -> Optional[str]: + # TODO: 实现EXA search授权URL生成 + return None + + def fetch_token(self, code: Optional[str]) -> Optional[Dict[str, Any]]: + # TODO: 实现EXA search用code换token + return None + + +# 工厂方法 +OAUTH_ADAPTERS = { + "slack": SlackOAuthAdapter, + "notion": NotionOAuthAdapter, + "x": XOAuthAdapter, + "twitter": XOAuthAdapter, + "googlesuite": GoogleSuiteOAuthAdapter, +} + + +def get_oauth_adapter(app_name: str, redirect_uri: Optional[str] = None) -> OAuthAdapter: + adapter_cls = OAUTH_ADAPTERS.get(app_name.lower()) + if not adapter_cls: + raise ValueError(f"不支持的OAuth应用: {app_name}") + if app_name.lower() == "slack": + return adapter_cls(redirect_uri=redirect_uri) + if app_name.lower() == "notion": + return adapter_cls(redirect_uri=redirect_uri) + if app_name.lower() == "x" or app_name.lower() == "twitter": + return adapter_cls(redirect_uri=redirect_uri) + return adapter_cls() + + +class OauthCallbackPayload(BaseModel): + code: str + state: Optional[str] = None diff --git a/server/app/component/permission.py b/server/app/component/permission.py new file mode 100644 index 000000000..1cae0db80 --- /dev/null +++ b/server/app/component/permission.py @@ -0,0 +1,75 @@ +from fastapi_babel import _ + +""" +权限定义: +当存在子权限的时候,父权限则不生效,应该全部放至子权限中定义处理 +""" + + +def permissions(): + return [ + { + "name": _("User"), + "description": _("User manger"), + "children": [ + { + "identity": "user:view", + "name": _("User Manage"), + "description": _("View users"), + }, + { + "identity": "user:edit", + "name": _("User Edit"), + "description": _("Manage users"), # 修改用户信息,邀请用户(限本组织下) + }, + ], + }, + { + "name": _("Admin"), + "description": _("Admin manger"), + "children": [ + { + "identity": "admin:view", + "name": _("Admin View"), + "description": _("View admins"), # 修改项目,工作区,角色,用户 + }, + { + "identity": "admin:edit", + "name": _("Admin Edit"), + "description": _("Edit admins"), + }, + ], + }, + { + "name": _("Role"), + "description": _("Role manger"), + "children": [ + { + "identity": "role:view", + "name": _("Role View"), + "description": _("View roles"), # 修改项目和工作区中的角色,创建新的角色 + }, + { + "identity": "role:edit", + "name": _("Role Edit"), + "description": _("Edit roles"), # 修改角色 + }, + ], + }, + { + "name": _("Mcp"), + "description": _("Mcp manger"), + "children": [ + { + "identity": "mcp:edit", + "name": _("Mcp Edit"), + "description": _("Edit mcp service"), + }, + { + "identity": "mcp-category:edit", + "name": _("Mcp Category Edit"), + "description": _("Edit mcp category"), + }, + ], + }, + ] diff --git a/server/app/component/pydantic/i18n.py b/server/app/component/pydantic/i18n.py new file mode 100644 index 000000000..67d42d487 --- /dev/null +++ b/server/app/component/pydantic/i18n.py @@ -0,0 +1,48 @@ +from pathlib import Path +from app.component.babel import babel_configs, babel +import re +import os +from fastapi_babel.middleware import LANGUAGES_PATTERN +from pydantic_i18n import JsonLoader, PydanticI18n + + +def get_language(lang_code: str | None = None): + """移植于fastapi_babel.middleware.BabelMiddleware.get_language + Applies an available language. + + To apply an available language it will be searched in the language folder for an available one + and will also priotize the one with the highest quality value. The Fallback language will be the + taken from the BABEL_DEFAULT_LOCALE var. + + Args: + babel (Babel): Request scoped Babel instance + lang_code (str): The Value of the Accept-Language Header. + + Returns: + str: The language that should be used. + """ + + if not lang_code: + return babel.config.BABEL_DEFAULT_LOCALE + + matches = re.finditer(LANGUAGES_PATTERN, lang_code) + languages = [(f"{m.group(1)}{f'_{m.group(2)}' if m.group(2) else ''}", m.group(3) or "") for m in matches] + languages = sorted(languages, key=lambda x: x[1], reverse=True) # sort the priority, no priority comes last + translation_directory = Path(babel.config.BABEL_TRANSLATION_DIRECTORY) + translation_files = [i.name for i in translation_directory.iterdir()] + explicit_priority = None + + for lang, quality in languages: + if lang in translation_files: + if not quality: # languages without quality value having the highest priority 1 + return lang + + elif not explicit_priority: # set language with explicit priority <= priority 1 + explicit_priority = lang + + # Return language with explicit priority or default value + return explicit_priority if explicit_priority else babel_configs.BABEL_DEFAULT_LOCALE + + +loader = JsonLoader(os.path.dirname(__file__) + "/translations") +trans = PydanticI18n(loader) diff --git a/server/app/component/pydantic/translations/en_US.json b/server/app/component/pydantic/translations/en_US.json new file mode 100644 index 000000000..f2f5f0d1b --- /dev/null +++ b/server/app/component/pydantic/translations/en_US.json @@ -0,0 +1,98 @@ +{ + "Object has no attribute '{}'": "Object has no attribute '{}'", + "Invalid JSON: {}": "Invalid JSON: {}", + "JSON input should be string, bytes or bytearray": "JSON input should be string, bytes or bytearray", + "Cannot check `{}` when validating from json, use a JsonOrPython validator instead": "Cannot check `{}` when validating from json, use a JsonOrPython validator instead", + "Recursion error - cyclic reference detected": "Recursion error - cyclic reference detected", + "Field required": "Field required", + "Field is frozen": "Field is frozen", + "Instance is frozen": "Instance is frozen", + "Extra inputs are not permitted": "Extra inputs are not permitted", + "Keys should be strings": "Keys should be strings", + "Error extracting attribute: {}": "Error extracting attribute: {}", + "Input should be a valid dictionary or instance of {}": "Input should be a valid dictionary or instance of {}", + "Input should be a valid dictionary or object to extract fields from": "Input should be a valid dictionary or object to extract fields from", + "Input should be a dictionary or an instance of {}": "Input should be a dictionary or an instance of {}", + "Input should be an instance of {}": "Input should be an instance of {}", + "Input should be None": "Input should be None", + "Input should be greater than {}": "Input should be greater than {}", + "Input should be greater than or equal to {}": "Input should be greater than or equal to {}", + "Input should be less than {}": "Input should be less than {}", + "Input should be less than or equal to {}": "Input should be less than or equal to {}", + "Input should be a multiple of {}": "Input should be a multiple of {}", + "Input should be a finite number": "Input should be a finite number", + "Input should be iterable": "Input should be iterable", + "Error iterating over object, error: {}": "Error iterating over object, error: {}", + "Input should be a valid string": "Input should be a valid string", + "Input should be a string, not an instance of a subclass of str": "Input should be a string, not an instance of a subclass of str", + "Input should be a valid string, unable to parse raw data as a unicode string": "Input should be a valid string, unable to parse raw data as a unicode string", + "String should have at least {}": "String should have at least {}", + "String should have at most {}": "String should have at most {}", + "String should match pattern '{}'": "String should match pattern '{}'", + "Input should be {}": "Input should be {}", + "Input should be a valid dictionary": "Input should be a valid dictionary", + "Input should be a valid mapping, error: {}": "Input should be a valid mapping, error: {}", + "Input should be a valid list": "Input should be a valid list", + "Input should be a valid tuple": "Input should be a valid tuple", + "Input should be a valid set": "Input should be a valid set", + "Input should be a valid boolean": "Input should be a valid boolean", + "Input should be a valid boolean, unable to interpret input": "Input should be a valid boolean, unable to interpret input", + "Input should be a valid integer": "Input should be a valid integer", + "Input should be a valid integer, unable to parse string as an integer": "Input should be a valid integer, unable to parse string as an integer", + "Unable to parse input string as an integer, exceeded maximum size": "Unable to parse input string as an integer, exceeded maximum size", + "Input should be a valid integer, got a number with a fractional part": "Input should be a valid integer, got a number with a fractional part", + "Input should be a valid number": "Input should be a valid number", + "Input should be a valid number, unable to parse string as a number": "Input should be a valid number, unable to parse string as a number", + "Input should be a valid bytes": "Input should be a valid bytes", + "Data should have at least {}": "Data should have at least {}", + "Data should have at most {}": "Data should have at most {}", + "Data should be valid {}": "Data should be valid {}", + "Value error, {}": "Value error, {}", + "Assertion failed, {}": "Assertion failed, {}", + "Input should be a valid date": "Input should be a valid date", + "Input should be a valid date in the format YYYY-MM-DD, {}": "Input should be a valid date in the format YYYY-MM-DD, {}", + "Input should be a valid date or datetime, {}": "Input should be a valid date or datetime, {}", + "Datetimes provided to dates should have zero time - e.g. be exact dates": "Datetimes provided to dates should have zero time - e.g. be exact dates", + "Date should be in the past": "Date should be in the past", + "Date should be in the future": "Date should be in the future", + "Input should be a valid time": "Input should be a valid time", + "Input should be in a valid time format, {}": "Input should be in a valid time format, {}", + "Input should be a valid datetime": "Input should be a valid datetime", + "Input should be a valid datetime, {}": "Input should be a valid datetime, {}", + "Invalid datetime object, got {}": "Invalid datetime object, got {}", + "Input should be a valid datetime or date, {}": "Input should be a valid datetime or date, {}", + "Input should be in the past": "Input should be in the past", + "Input should be in the future": "Input should be in the future", + "Input should not have timezone info": "Input should not have timezone info", + "Input should have timezone info": "Input should have timezone info", + "Timezone offset of {}": "Timezone offset of {}", + "Input should be a valid timedelta": "Input should be a valid timedelta", + "Input should be a valid timedelta, {}": "Input should be a valid timedelta, {}", + "Input should be a valid frozenset": "Input should be a valid frozenset", + "Input should be a subclass of {}": "Input should be a subclass of {}", + "Input should be callable": "Input should be callable", + "Input tag '{}": "Input tag '{}", + "Unable to extract tag using discriminator {}": "Unable to extract tag using discriminator {}", + "Arguments must be a tuple, list or a dictionary": "Arguments must be a tuple, list or a dictionary", + "Missing required argument": "Missing required argument", + "Unexpected keyword argument": "Unexpected keyword argument", + "Missing required keyword only argument": "Missing required keyword only argument", + "Unexpected positional argument": "Unexpected positional argument", + "Missing required positional only argument": "Missing required positional only argument", + "Got multiple values for argument": "Got multiple values for argument", + "URL input should be a string or URL": "URL input should be a string or URL", + "Input should be a valid URL, {}": "Input should be a valid URL, {}", + "Input violated strict URL syntax rules, {}": "Input violated strict URL syntax rules, {}", + "URL should have at most {}": "URL should have at most {}", + "URL scheme should be {}": "URL scheme should be {}", + "UUID input should be a string, bytes or UUID object": "UUID input should be a string, bytes or UUID object", + "Input should be a valid UUID, {}": "Input should be a valid UUID, {}", + "UUID version {} expected": "UUID version {} expected", + "Decimal input should be an integer, float, string or Decimal object": "Decimal input should be an integer, float, string or Decimal object", + "Input should be a valid decimal": "Input should be a valid decimal", + "Decimal input should have no more than {} in total": "Decimal input should have no more than {} in total", + "Decimal input should have no more than {}": "Decimal input should have no more than {}", + "Decimal input should have no more than {} before the decimal point": "Decimal input should have no more than {} before the decimal point", + "Input should be a valid python complex object, a number, or a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex": "Input should be a valid python complex object, a number, or a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex", + "Input should be a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex": "Input should be a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex" +} diff --git a/server/app/component/pydantic/translations/zh_CN.json b/server/app/component/pydantic/translations/zh_CN.json new file mode 100644 index 000000000..b06af4928 --- /dev/null +++ b/server/app/component/pydantic/translations/zh_CN.json @@ -0,0 +1,98 @@ +{ + "Object has no attribute '{}'": "对象没有属性'{}'", + "Invalid JSON: {}": "无效的 JSON:{}", + "JSON input should be string, bytes or bytearray": "JSON 输入应为字符串、字节或字节数组", + "Cannot check `{}` when validating from json, use a JsonOrPython validator instead": "在从 JSON 验证时,无法检查`{}`,请改用 JsonOrPython 验证器", + "Recursion error - cyclic reference detected": "递归错误 - 检测到循环引用", + "Field required": "字段必填", + "Field is frozen": "字段已冻结", + "Instance is frozen": "实例已冻结", + "Extra inputs are not permitted": "不允许额外输入", + "Keys should be strings": "键应为字符串", + "Error extracting attribute: {}": "提取属性时出错:{}", + "Input should be a valid dictionary or instance of {}": "输入应为有效的字典或{}的实例", + "Input should be a valid dictionary or object to extract fields from": "输入应为有效的字典或可用于提取字段的对象", + "Input should be a dictionary or an instance of {}": "输入应为字典或{}的实例", + "Input should be an instance of {}": "输入应为{}的实例", + "Input should be None": "输入应为 None", + "Input should be greater than {}": "输入应大于{}", + "Input should be greater than or equal to {}": "输入应大于或等于{}", + "Input should be less than {}": "输入应小于{}", + "Input should be less than or equal to {}": "输入应小于或等于{}", + "Input should be a multiple of {}": "输入应为{}的倍数", + "Input should be a finite number": "输入应为有限数字", + "Input should be iterable": "输入应为可迭代对象", + "Error iterating over object, error: {}": "迭代对象时出错,错误:{}", + "Input should be a valid string": "输入应为有效字符串", + "Input should be a string, not an instance of a subclass of str": "输入应为字符串,而不是 str 的子类实例", + "Input should be a valid string, unable to parse raw data as a unicode string": "输入应为有效字符串,无法将原始数据解析为 Unicode 字符串", + "String should have at least {}": "字符串应至少有{}", + "String should have at most {}": "字符串应最多有{}", + "String should match pattern '{}'": "字符串应匹配模式'{}'", + "Input should be {}": "输入应为{}", + "Input should be a valid dictionary": "输入应为有效的字典", + "Input should be a valid mapping, error: {}": "输入应为有效的映射,错误:{}", + "Input should be a valid list": "输入应为有效的列表", + "Input should be a valid tuple": "输入应为有效的元组", + "Input should be a valid set": "输入应为有效的集合", + "Input should be a valid boolean": "输入应为有效的布尔值", + "Input should be a valid boolean, unable to interpret input": "输入应为有效的布尔值,无法解析输入", + "Input should be a valid integer": "输入应为有效的整数", + "Input should be a valid integer, unable to parse string as an integer": "输入应为有效的整数,无法将字符串解析为整数", + "Unable to parse input string as an integer, exceeded maximum size": "无法将输入字符串解析为整数,超出最大尺寸", + "Input should be a valid integer, got a number with a fractional part": "输入应为有效的整数,但输入的数字有小数部分", + "Input should be a valid number": "输入应为有效的数字", + "Input should be a valid number, unable to parse string as a number": "输入应为有效的数字,无法将字符串解析为数字", + "Input should be a valid bytes": "输入应为有效的字节", + "Data should have at least {}": "数据应至少有{}", + "Data should have at most {}": "数据应最多有{}", + "Data should be valid {}": "数据应为有效的{}", + "Value error, {}": "值错误,{}", + "Assertion failed, {}": "断言失败,{}", + "Input should be a valid date": "输入应为有效的日期", + "Input should be a valid date in the format YYYY-MM-DD, {}": "输入应为有效的日期,格式为 YYYY-MM-DD,{}", + "Input should be a valid date or datetime, {}": "输入应为有效的日期或日期时间,{}", + "Datetimes provided to dates should have zero time - e.g. be exact dates": "提供给日期的日期时间应为零时间,即为精确日期", + "Date should be in the past": "日期应为过去的日期", + "Date should be in the future": "日期应为未来的日期", + "Input should be a valid time": "输入应为有效的时间", + "Input should be in a valid time format, {}": "输入应为有效的时间格式,{}", + "Input should be a valid datetime": "输入应为有效的日期时间", + "Input should be a valid datetime, {}": "输入应为有效的日期时间,{}", + "Invalid datetime object, got {}": "无效的日期时间对象,得到{}", + "Input should be a valid datetime or date, {}": "输入应为有效的日期时间或日期,{}", + "Input should be in the past": "输入应为过去的日期/时间", + "Input should be in the future": "输入应为未来的日期/时间", + "Input should not have timezone info": "输入不应包含时区信息", + "Input should have timezone info": "输入应包含时区信息", + "Timezone offset of {}": "时区偏移量为{}", + "Input should be a valid timedelta": "输入应为有效的 timedelta", + "Input should be a valid timedelta, {}": "输入应为有效的 timedelta,{}", + "Input should be a valid frozenset": "输入应为有效的 frozenset", + "Input should be a subclass of {}": "输入应为{}的子类", + "Input should be callable": "输入应为可调用对象", + "Input tag '{}": "输入标签'{}'", + "Unable to extract tag using discriminator {}": "无法使用区分符{}提取标签", + "Arguments must be a tuple, list or a dictionary": "参数必须是元组、列表或字典", + "Missing required argument": "缺少必需的参数", + "Unexpected keyword argument": "意外的关键字参数", + "Missing required keyword only argument": "缺少必需的关键字参数", + "Unexpected positional argument": "意外的位置参数", + "Missing required positional only argument": "缺少必需的位置参数", + "Got multiple values for argument": "为参数提供了多个值", + "URL input should be a string or URL": "URL 输入应为字符串或 URL", + "Input should be a valid URL, {}": "输入应为有效的 URL,{}", + "Input violated strict URL syntax rules, {}": "输入违反了严格的 URL 语法规则,{}", + "URL should have at most {}": "URL 应最多有{}", + "URL scheme should be {}": "URL 的协议应为{}", + "UUID input should be a string, bytes or UUID object": "UUID 输入应为字符串、字节或 UUID 对象", + "Input should be a valid UUID, {}": "输入应为有效的 UUID,{}", + "UUID version {} expected": "期望的 UUID 版本为{}", + "Decimal input should be an integer, float, string or Decimal object": "十进制输入应为整数、浮点数、字符串或 Decimal 对象", + "Input should be a valid decimal": "输入应为有效的十进制数", + "Decimal input should have no more than {} in total": "十进制输入的总位数不应超过{}", + "Decimal input should have no more than {}": "十进制输入不应超过{}", + "Decimal input should have no more than {} before the decimal point": "十进制输入的小数点前不应超过{}位", + "Input should be a valid python complex object, a number, or a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex": "输入应为有效的 Python 复杂对象、数字,或遵循 https://docs.python.org/3/library/functions.html#complex 规则的有效复杂字符串", + "Input should be a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex": "输入应为遵循 https://docs.python.org/3/library/functions.html#complex 规则的有效复杂字符串" +} diff --git a/server/app/component/sqids.py b/server/app/component/sqids.py new file mode 100644 index 000000000..7b7e02433 --- /dev/null +++ b/server/app/component/sqids.py @@ -0,0 +1,11 @@ +from sqids import Sqids + +sqids = Sqids(min_length=10) + + +def encode_user_id(user_id: int) -> str: + return sqids.encode([user_id]) + + +def decode_user_id(user_id: str) -> int: + return sqids.decode(user_id) diff --git a/server/app/component/stack_auth.py b/server/app/component/stack_auth.py new file mode 100644 index 000000000..086b0f36f --- /dev/null +++ b/server/app/component/stack_auth.py @@ -0,0 +1,58 @@ +import asyncio +import httpx +from app.component.environment import env_not_empty +import jwt +from app.exception.exception import UserException +from app.component import code + + +class StackAuth: + _signing_key_cache = {} + + @staticmethod + async def user_id(token: str): + header = jwt.get_unverified_header(token) + kid = header.get("kid") + if not kid: + raise jwt.InvalidTokenError("Token is missing 'kid' in header") + + signed = await StackAuth.stack_signing_key(kid) + payload = jwt.decode( + token, + signed.key, + algorithms=["ES256"], + audience=env_not_empty("stack_project_id"), + # issuer="https://access-token.jwt-signature.stack-auth.com", + ) + return payload["sub"] + + @staticmethod + async def user_info(token: str): + headers = { + "X-Stack-Access-Type": "server", + "X-Stack-Project-Id": env_not_empty("stack_project_id"), + "X-Stack-Secret-Server-Key": env_not_empty("stack_secret_server_key"), + "X-Stack-Access-Token": token, + } + url = "https://api.stack-auth.com/api/v1/users/me" + async with httpx.AsyncClient() as client: + response = await client.get(url, headers=headers) + return response.json() + + @staticmethod + async def stack_signing_key(kid: str): + if kid in StackAuth._signing_key_cache: + return StackAuth._signing_key_cache[kid] + + jwks_endpoint = ( + f"https://api.stack-auth.com/api/v1/projects/{env_not_empty('stack_project_id')}/.well-known/jwks.json" + ) + loop = asyncio.get_running_loop() + jwks_client = jwt.PyJWKClient(jwks_endpoint) + + try: + signing_key = await loop.run_in_executor(None, jwks_client.get_signing_key, kid) + StackAuth._signing_key_cache[kid] = signing_key + return signing_key + except jwt.exceptions.PyJWKClientError as e: + raise UserException(code.token_invalid, str(e)) diff --git a/server/app/component/time_friendly.py b/server/app/component/time_friendly.py new file mode 100644 index 000000000..3be88729b --- /dev/null +++ b/server/app/component/time_friendly.py @@ -0,0 +1,24 @@ +from datetime import datetime, timedelta + +import arrow + + +def to_date(time: str, format: str | None = None): + try: + if format: + return arrow.get(time, format).date() + else: + return arrow.get(time).date() + except Exception as e: + return None + + +def monday_start_time() -> datetime: + # 获取当前时间 + now = datetime.now() + # 计算今天是本周的第几天(星期一是0,星期天是6) + weekday = now.weekday() + # 计算本周一的日期 + monday = now - timedelta(days=weekday) + # 设置时间为 0 点 + return monday.replace(hour=0, minute=0, second=0, microsecond=0) diff --git a/server/app/component/validator/McpServer.py b/server/app/component/validator/McpServer.py new file mode 100644 index 000000000..d23beb020 --- /dev/null +++ b/server/app/component/validator/McpServer.py @@ -0,0 +1,33 @@ +from pydantic import BaseModel, ValidationError, field_validator +from typing import Dict, List, Optional + + +class McpServerItem(BaseModel): + command: str + args: List[str] + env: Optional[Dict[str, str]] = None + + +class McpServersModel(BaseModel): + mcpServers: Dict[str, McpServerItem] + + +class McpRemoteServer(BaseModel): + server_name: str + server_url: str + + +def validate_mcp_servers(data: dict): + try: + model = McpServersModel.model_validate(data) + return True, model + except ValidationError as e: + return False, e.errors() + + +def validate_mcp_remote_servers(data: dict): + try: + model = McpRemoteServer.model_validate(data) + return True, model + except ValidationError as e: + return False, e.errors() diff --git a/server/app/controller/__init__.py b/server/app/controller/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/app/controller/chat/history_controller.py b/server/app/controller/chat/history_controller.py new file mode 100644 index 000000000..a3ab6a220 --- /dev/null +++ b/server/app/controller/chat/history_controller.py @@ -0,0 +1,52 @@ +from fastapi import APIRouter, Depends, HTTPException, Response +from fastapi_pagination import Page +from fastapi_pagination.ext.sqlmodel import paginate +from app.model.chat.chat_history import ChatHistoryOut, ChatHistoryIn, ChatHistory, ChatHistoryUpdate +from fastapi_babel import _ +from sqlmodel import Session, select, desc +from app.component.auth import Auth, auth_must +from app.component.database import session + +router = APIRouter(prefix="/chat", tags=["Chat History"]) + + +@router.post("/history", name="save chat history", response_model=ChatHistoryOut) +def create_chat_history(data: ChatHistoryIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)): + data.user_id = auth.user.id + chat_history = ChatHistory(**data.model_dump()) + session.add(chat_history) + session.commit() + session.refresh(chat_history) + return chat_history + + +@router.get("/histories", name="get chat history") +def list_chat_history(session: Session = Depends(session), auth: Auth = Depends(auth_must)) -> Page[ChatHistoryOut]: + stmt = select(ChatHistory).where(ChatHistory.user_id == auth.user.id).order_by(desc(ChatHistory.created_at)) + return paginate(session, stmt) + + +@router.delete("/history/{history_id}", name="delete chat history") +def delete_chat_history(history_id: str, session: Session = Depends(session)): + history = session.exec(select(ChatHistory).where(ChatHistory.id == history_id)).first() + if not history: + raise HTTPException(status_code=404, detail="Caht History not found") + session.delete(history) + session.commit() + return Response(status_code=204) + + +@router.put("/history/{history_id}", name="update chat history", response_model=ChatHistoryOut) +def update_chat_history( + history_id: int, data: ChatHistoryUpdate, session: Session = Depends(session), auth: Auth = Depends(auth_must) +): + history = session.exec(select(ChatHistory).where(ChatHistory.id == history_id)).first() + if not history: + raise HTTPException(status_code=404, detail="Chat History not found") + if history.user_id != auth.user.id: + raise HTTPException(status_code=403, detail="You are not allowed to update this chat history") + update_data = data.model_dump(exclude_unset=True) + history.update_fields(update_data) + history.save(session) + session.refresh(history) + return history diff --git a/server/app/controller/chat/share_controller.py b/server/app/controller/chat/share_controller.py new file mode 100644 index 000000000..17a41ae96 --- /dev/null +++ b/server/app/controller/chat/share_controller.py @@ -0,0 +1,78 @@ +from fastapi import APIRouter, Depends, HTTPException, Response +from sqlmodel import Session, asc, select +from app.component.database import session +import json +import asyncio +from itsdangerous import SignatureExpired, BadTimeSignature +from starlette.responses import StreamingResponse +from app.model.chat.chat_share import ChatHistoryShareOut, ChatShare, ChatShareIn +from app.model.chat.chat_step import ChatStep +from app.model.chat.chat_history import ChatHistory + +router = APIRouter(prefix="/chat", tags=["Chat Share"]) + + +@router.get("/share/info/{token}", name="Get shared chat info", response_model=ChatHistoryShareOut) +def get_share_info(token: str, session: Session = Depends(session)): + """ + Get shared chat history info by token, excluding sensitive data. + """ + try: + task_id = ChatShare.verify_token(token, False) + except (SignatureExpired, BadTimeSignature): + raise HTTPException(status_code=400, detail="Share link is invalid or has expired.") + + stmt = select(ChatHistory).where(ChatHistory.task_id == task_id) + history = session.exec(stmt).one_or_none() + + if not history: + raise HTTPException(status_code=404, detail="Chat history not found.") + + return history + + +@router.get("/share/playback/{token}", name="Playback shared chat via SSE") +async def share_playback(token: str, session: Session = Depends(session), delay_time: float = 0): + """ + Playbacks the chat history via a sharing token (SSE). + delay_time: control sse interval, max 5 seconds + """ + if delay_time > 5: + delay_time = 5 + try: + task_id = ChatShare.verify_token(token, False) + except SignatureExpired: + raise HTTPException(status_code=400, detail="Share link has expired.") + except BadTimeSignature: + raise HTTPException(status_code=400, detail="Share link is invalid.") + + async def event_generator(): + stmt = select(ChatStep).where(ChatStep.task_id == task_id).order_by(asc(ChatStep.id)) + steps = session.exec(stmt).all() + + if not steps: + yield f"data: {json.dumps({'error': 'No steps found for this task.'})}\n\n" + return + + for step in steps: + step_data = { + "id": step.id, + "task_id": step.task_id, + "step": step.step, + "data": step.data, + "created_at": step.created_at.isoformat() if step.created_at else None, + } + yield f"data: {json.dumps(step_data)}\n\n" + if delay_time > 0 and step.step != "create_agent": + await asyncio.sleep(delay_time) + + return StreamingResponse(event_generator(), media_type="text/event-stream") + + +@router.post("/share", name="Generate sharable link for a task(1 day expiration)") +def create_share_link(data: ChatShareIn): + """ + Generates a sharing token with an expiration time for the specified task_id. + """ + share_token = ChatShare.generate_token(data.task_id) + return {"share_token": share_token} diff --git a/server/app/controller/chat/snapshot_controller.py b/server/app/controller/chat/snapshot_controller.py new file mode 100644 index 000000000..5767746bd --- /dev/null +++ b/server/app/controller/chat/snapshot_controller.py @@ -0,0 +1,81 @@ +from app.model.chat.chat_snpshot import ChatSnapshot, ChatSnapshotIn +from typing import List, Optional +from fastapi import Depends, HTTPException, Response, APIRouter +from sqlmodel import Session, select +from app.component.database import session +from app.component.auth import Auth, auth_must +from fastapi_babel import _ + +router = APIRouter(prefix="/chat", tags=["Chat Snapshot Management"]) + + +@router.get("/snapshots", name="list chat snapshots", response_model=List[ChatSnapshot]) +async def list_chat_snapshots( + api_task_id: Optional[str] = None, + camel_task_id: Optional[str] = None, + browser_url: Optional[str] = None, + session: Session = Depends(session), +): + query = select(ChatSnapshot) + if api_task_id is not None: + query = query.where(ChatSnapshot.api_task_id == api_task_id) + if camel_task_id is not None: + query = query.where(ChatSnapshot.camel_task_id == camel_task_id) + if browser_url is not None: + query = query.where(ChatSnapshot.browser_url == browser_url) + snapshots = session.exec(query).all() + return snapshots + + +@router.get("/snapshots/{snapshot_id}", name="get chat snapshot", response_model=ChatSnapshot) +async def get_chat_snapshot(snapshot_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)): + snapshot = session.get(ChatSnapshot, snapshot_id) + if not snapshot: + raise HTTPException(status_code=404, detail=_("Chat snapshot not found")) + return snapshot + + +@router.post("/snapshots", name="create chat snapshot", response_model=ChatSnapshot) +async def create_chat_snapshot( + snapshot: ChatSnapshotIn, auth: Auth = Depends(auth_must), session: Session = Depends(session) +): + image_path = ChatSnapshotIn.save_image(auth.user.id, snapshot.api_task_id, snapshot.image_base64) + chat_snapshot = ChatSnapshot( + user_id=auth.user.id, + api_task_id=snapshot.api_task_id, + camel_task_id=snapshot.camel_task_id, + browser_url=snapshot.browser_url, + image_path=image_path, + ) + session.add(chat_snapshot) + session.commit() + session.refresh(chat_snapshot) + return Response(status_code=200) + + +@router.put("/snapshots/{snapshot_id}", name="update chat snapshot", response_model=ChatSnapshot) +async def update_chat_snapshot( + snapshot_id: int, + snapshot_update: ChatSnapshot, + session: Session = Depends(session), + auth: Auth = Depends(auth_must), +): + db_snapshot = session.get(ChatSnapshot, snapshot_id) + if not db_snapshot: + raise HTTPException(status_code=404, detail=_("Chat snapshot not found")) + for key, value in snapshot_update.dict(exclude_unset=True).items(): + setattr(db_snapshot, key, value) + session.add(db_snapshot) + session.commit() + session.refresh(db_snapshot) + return db_snapshot + + +@router.delete("/snapshots/{snapshot_id}", name="delete chat snapshot") +async def delete_chat_snapshot(snapshot_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)): + db_snapshot = session.get(ChatSnapshot, snapshot_id) + if not db_snapshot: + raise HTTPException(status_code=404, detail=_("Chat snapshot not found")) + session.delete(db_snapshot) + session.commit() + return Response(status_code=204) diff --git a/server/app/controller/chat/step_controller.py b/server/app/controller/chat/step_controller.py new file mode 100644 index 000000000..c33639112 --- /dev/null +++ b/server/app/controller/chat/step_controller.py @@ -0,0 +1,105 @@ +import asyncio +import json +from typing import List, Optional +from fastapi import Depends, HTTPException, Query, Response, APIRouter +from fastapi.responses import StreamingResponse +from sqlmodel import Session, asc, select +from app.component.database import session +from app.component.auth import Auth, auth_must +from fastapi_babel import _ +from app.model.chat.chat_step import ChatStep, ChatStepOut, ChatStepIn + +router = APIRouter(prefix="/chat", tags=["Chat Step Management"]) + + +@router.get("/steps", name="list chat steps", response_model=List[ChatStepOut]) +async def list_chat_steps( + task_id: str, step: Optional[str] = None, session: Session = Depends(session), auth: Auth = Depends(auth_must) +): + query = select(ChatStep) + if task_id is not None: + query = query.where(ChatStep.task_id == task_id) + if step is not None: + query = query.where(ChatStep.step == step) + chat_steps = session.exec(query).all() + return chat_steps + + +@router.get("/steps/playback/{task_id}", name="Playback Chat Step via SSE") +async def share_playback( + task_id: str, delay_time: float = 0, session: Session = Depends(session), auth: Auth = Depends(auth_must) +): + """ + Playbacks the chat steps (SSE). + """ + if delay_time > 5: + delay_time = 5 + + async def event_generator(): + stmt = select(ChatStep).where(ChatStep.task_id == task_id).order_by(asc(ChatStep.id)) + steps = session.exec(stmt).all() + + if not steps: + yield f"data: {json.dumps({'error': 'No steps found for this task.'})}\n\n" + return + + for step in steps: + step_data = { + "id": step.id, + "task_id": step.task_id, + "step": step.step, + "data": step.data, + "created_at": step.created_at.isoformat() if step.created_at else None, + } + yield f"data: {json.dumps(step_data)}\n\n" + if delay_time > 0: + await asyncio.sleep(delay_time) + + return StreamingResponse(event_generator(), media_type="text/event-stream") + + +@router.get("/steps/{step_id}", name="get chat step", response_model=ChatStepOut) +async def get_chat_step(step_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)): + chat_step = session.get(ChatStep, step_id) + if not chat_step: + raise HTTPException(status_code=404, detail=_("Chat step not found")) + return chat_step + + +@router.post("/steps", name="create chat step") +# TODO Limit request sources +async def create_chat_step(step: ChatStepIn, session: Session = Depends(session)): + chat_step = ChatStep( + task_id=step.task_id, + step=step.step, + data=step.data, + ) + session.add(chat_step) + session.commit() + session.refresh(chat_step) + return {"code": 200, "msg": "success"} + + +@router.put("/steps/{step_id}", name="update chat step", response_model=ChatStepOut) +async def update_chat_step( + step_id: int, chat_step_update: ChatStep, session: Session = Depends(session), auth: Auth = Depends(auth_must) +): + db_chat_step = session.get(ChatStep, step_id) + if not db_chat_step: + raise HTTPException(status_code=404, detail=_("Chat step not found")) + for key, value in chat_step_update.dict(exclude_unset=True).items(): + setattr(db_chat_step, key, value) + session.add(db_chat_step) + session.commit() + session.refresh(db_chat_step) + return db_chat_step + + +@router.delete("/steps/{step_id}", name="delete chat step") +async def delete_chat_step(step_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)): + db_chat_step = session.get(ChatStep, step_id) + if not db_chat_step: + raise HTTPException(status_code=404, detail=_("Chat step not found")) + session.delete(db_chat_step) + session.commit() + return Response(status_code=204) diff --git a/server/app/controller/config/config_controller.py b/server/app/controller/config/config_controller.py new file mode 100644 index 000000000..51ee87d8d --- /dev/null +++ b/server/app/controller/config/config_controller.py @@ -0,0 +1,121 @@ +from typing import List, Optional +from fastapi import Depends, HTTPException, Query, Response, APIRouter +from sqlmodel import Session, select, or_ +from app.component.database import session +from app.component.auth import Auth, auth_must +from fastapi_babel import _ +from app.model.config.config import Config, ConfigCreate, ConfigUpdate, ConfigInfo, ConfigOut + +router = APIRouter(tags=["Config Management"]) + + +@router.get("/configs", name="list configs", response_model=list[ConfigOut]) +async def list_configs( + config_group: Optional[str] = None, session: Session = Depends(session), auth: Auth = Depends(auth_must) +): + query = select(Config) + user_id = auth.user.id + if user_id is not None: + query = query.where(Config.user_id == user_id) + if config_group is not None: + query = query.where(Config.config_group == config_group) + configs = session.exec(query).all() + return configs + + +@router.get("/configs/{config_id}", name="get config", response_model=ConfigOut) +async def get_config( + config_id: int, + session: Session = Depends(session), + auth: Auth = Depends(auth_must), +): + query = select(Config).where(Config.user_id == auth.user.id) + + if config_id is not None: + query = query.where(Config.id == config_id) + + config = session.exec(query).first() + + if not config: + raise HTTPException(status_code=404, detail=_("Configuration not found")) + return config + + +@router.post("/configs", name="create config", response_model=ConfigOut) +async def create_config(config: ConfigCreate, session: Session = Depends(session), auth: Auth = Depends(auth_must)): + if not ConfigInfo.is_valid_env_var(config.config_group, config.config_name): + raise HTTPException(status_code=400, detail=_("Config Name is valid")) + + # Check if configuration already exists + existing_config = session.exec( + select(Config).where(Config.user_id == auth.user.id, Config.config_name == config.config_name) + ).first() + + if existing_config: + raise HTTPException(status_code=400, detail=_("Configuration already exists for this user")) + + db_config = Config( + user_id=auth.user.id, + config_name=config.config_name, + config_value=config.config_value, + config_group=config.config_group, + ) + session.add(db_config) + session.commit() + session.refresh(db_config) + return db_config + + +@router.put("/configs/{config_id}", name="update config", response_model=ConfigOut) +async def update_config( + config_id: int, config_update: ConfigUpdate, session: Session = Depends(session), auth: Auth = Depends(auth_must) +): + db_config = session.exec(select(Config).where(Config.id == config_id, Config.user_id == auth.user.id)).first() + + if not db_config: + raise HTTPException(status_code=404, detail=_("Configuration not found")) + + # Check if configuration group is valid + if not ConfigInfo.is_valid_env_var(config_update.config_group, config_update.config_name): + raise HTTPException(status_code=400, detail=_("Invalid configuration group")) + + # Check for conflicts with other configurations + existing_config = session.exec( + select(Config).where( + Config.user_id == auth.user.id, + Config.config_name == config_update.config_name, + Config.id != config_id, + ) + ).first() + + if existing_config: + raise HTTPException(status_code=400, detail=_("Configuration already exists for this user")) + + db_config.config_name = config_update.config_name + db_config.config_value = config_update.config_value + + session.add(db_config) + session.commit() + session.refresh(db_config) + return db_config + + +@router.delete("/configs/{config_id}", name="delete config") +async def delete_config(config_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)): + db_config = session.exec(select(Config).where(Config.id == config_id, Config.user_id == auth.user.id)).first() + + if not db_config: + raise HTTPException(status_code=404, detail=_("Configuration not found")) + session.delete(db_config) + session.commit() + return Response(status_code=204) + + +@router.get("/config/info", name="get config info") +async def get_config_info( + show_all: bool = Query(False, description="Show all config info, including those with empty env_vars"), +): + configs = ConfigInfo.getinfo() + if show_all: + return configs + return {k: v for k, v in configs.items() if v.get("env_vars") and len(v["env_vars"]) > 0} diff --git a/server/app/controller/mcp/category_controller.py b/server/app/controller/mcp/category_controller.py new file mode 100644 index 000000000..8029f84c3 --- /dev/null +++ b/server/app/controller/mcp/category_controller.py @@ -0,0 +1,15 @@ +from typing import Annotated +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session, col, select + +from app.component.database import session +from app.model.mcp.category import Category, CategoryOut + + +router = APIRouter(prefix="/mcp", tags=["Mcp Category"]) + + +@router.get("/categories", name="category list", response_model=list[CategoryOut]) +def gets(session: Session = Depends(session)): + stmt = select(Category).where(Category.no_delete()).order_by(col(Category.priority).asc()) + return session.exec(stmt) diff --git a/server/app/controller/mcp/mcp_controller.py b/server/app/controller/mcp/mcp_controller.py new file mode 100644 index 000000000..1dd5ac5ef --- /dev/null +++ b/server/app/controller/mcp/mcp_controller.py @@ -0,0 +1,132 @@ +from typing import Dict +from fastapi import Depends, HTTPException, APIRouter +from fastapi_babel import _ +from fastapi_pagination import Page +from fastapi_pagination.ext.sqlmodel import paginate +from sqlmodel import Session, col, select +from sqlalchemy.orm import selectinload, with_loader_criteria +from app.component.auth import Auth, auth_must +from app.component.database import session +from app.model.mcp.mcp import Mcp, McpOut, McpType +from app.model.mcp.mcp_env import McpEnv, Status as McpEnvStatus +from app.model.mcp.mcp_user import McpImportType, McpUser, Status +from loguru import logger + +from app.component.validator.McpServer import ( + McpRemoteServer, + McpServerItem, + validate_mcp_remote_servers, + validate_mcp_servers, +) + +router = APIRouter(tags=["Mcp Servers"]) + + +@router.get("/mcps", name="mcp list") +async def gets( + keyword: str | None = None, + category_id: int | None = None, + mine: int | None = None, + session: Session = Depends(session), + auth: Auth = Depends(auth_must), +) -> Page[McpOut]: + stmt = ( + select(Mcp) + .where(Mcp.no_delete()) + .options( + selectinload(Mcp.category), + selectinload(Mcp.envs), + with_loader_criteria(McpEnv, col(McpEnv.status) == McpEnvStatus.in_use), + ) + # .order_by(col(Mcp.sort).desc()) + ) + if keyword: + stmt = stmt.where(col(Mcp.key).like(f"%{keyword.lower()}%")) + if category_id: + stmt = stmt.where(Mcp.category_id == category_id) + if mine and auth: + stmt = ( + stmt.join(McpUser) + .where(McpUser.user_id == auth.user.id) + .options( + selectinload(Mcp.mcp_user), + with_loader_criteria(McpUser, col(McpUser.user_id) == auth.user.id), + ) + ) + return paginate(session, stmt) + + +@router.get("/mcp", name="mcp detail", response_model=McpOut) +async def get(id: int, session: Session = Depends(session)): + stmt = select(Mcp).where(Mcp.no_delete(), Mcp.id == id).options(selectinload(Mcp.category), selectinload(Mcp.envs)) + model = session.exec(stmt).one() + return model + + +@router.post("/mcp/install", name="mcp install") +async def install(mcp_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)): + mcp = session.get_one(Mcp, mcp_id) + if not mcp: + raise HTTPException(status_code=404, detail=_("Mcp not found")) + exists = session.exec(select(McpUser).where(McpUser.mcp_id == mcp.id, McpUser.user_id == auth.user.id)).first() + if exists: + raise HTTPException(status_code=400, detail=_("mcp is installed")) + install_command: dict = mcp.install_command + mcp_user = McpUser( + mcp_id=mcp.id, + user_id=auth.user.id, + mcp_name=mcp.name, + mcp_key=mcp.key, + mcp_desc=mcp.description, + type=mcp.type, + status=Status.enable, + command=install_command["command"], + args=install_command["args"], + env=install_command["env"], + server_url=None, + ) + mcp_user.save() + return mcp_user + + +@router.post("/mcp/import/{mcp_type}", name="mcp import") +async def import_mcp( + mcp_type: McpImportType, mcp_data: dict, session: Session = Depends(session), auth: Auth = Depends(auth_must) +): + logger.debug(mcp_type, mcp_type.value) + + if mcp_type == McpImportType.Local: + is_valid, res = validate_mcp_servers(mcp_data) + if not is_valid: + raise HTTPException(status_code=400, detail=res) + mcp_data: Dict[str, McpServerItem] = res.mcpServers + for name, data in mcp_data.items(): + mcp_user = McpUser( + mcp_id=0, + user_id=auth.user.id, + mcp_name=name, + mcp_key=name, + mcp_desc=name, + type=McpType.Local, + status=Status.enable, + command=data.command, + args=data.args, + env=data.env, + server_url=None, + ) + break + elif mcp_type == McpImportType.Remote: + is_valid, res = validate_mcp_remote_servers(mcp_data) + if not is_valid: + raise HTTPException(status_code=400, detail=res) + data: McpRemoteServer = res + mcp_user = McpUser( + mcp_id=0, + user_id=auth.user.id, + type=McpType.Remote, + status=Status.enable, + mcp_name=data.server_name, + server_url=data.server_url, + ) + mcp_user.save() + return mcp_user diff --git a/server/app/controller/mcp/proxy_controller.py b/server/app/controller/mcp/proxy_controller.py new file mode 100644 index 000000000..aa008229a --- /dev/null +++ b/server/app/controller/mcp/proxy_controller.py @@ -0,0 +1,173 @@ +from fastapi import APIRouter, Depends +from exa_py import Exa +from loguru import logger +from app.component.auth import key_must +from app.component.environment import env_not_empty +from app.model.mcp.proxy import ExaSearch +from typing import Any, cast +import requests + +from app.model.user.key import Key + + +router = APIRouter(prefix="/proxy", tags=["Mcp Servers"]) + + +@router.post("/exa") +def exa_search(search: ExaSearch, key: Key = Depends(key_must)): + EXA_API_KEY = env_not_empty("EXA_API_KEY") + try: + exa = Exa(EXA_API_KEY) + + if search.num_results is not None and not 0 < search.num_results <= 100: + raise ValueError("num_results must be between 1 and 100") + + if search.include_text is not None: + if len(search.include_text) > 1: + raise ValueError("include_text can only contain 1 string") + if len(search.include_text[0].split()) > 5: + raise ValueError("include_text string cannot be longer than 5 words") + + if search.exclude_text is not None: + if len(search.exclude_text) > 1: + raise ValueError("exclude_text can only contain 1 string") + if len(search.exclude_text[0].split()) > 5: + raise ValueError("exclude_text string cannot be longer than 5 words") + + # Call Exa API with direct parameters + if search.text: + results = cast( + dict[str, Any], + exa.search_and_contents( + query=search.query, + type=search.search_type, + category=search.category, + num_results=search.num_results, + include_text=search.include_text, + exclude_text=search.exclude_text, + use_autoprompt=search.use_autoprompt, + text=True, + ), + ) + else: + results = cast( + dict[str, Any], + exa.search( + query=search.query, + type=search.search_type, + category=search.category, + num_results=search.num_results, + include_text=search.include_text, + exclude_text=search.exclude_text, + use_autoprompt=search.use_autoprompt, + ), + ) + + return results + + except Exception as e: + return {"error": f"Exa search failed: {e!s}"} + + +@router.get("/google") +def google_search(query: str, search_type: str = "web", key: Key = Depends(key_must)): + # https://developers.google.com/custom-search/v1/overview + GOOGLE_API_KEY = env_not_empty("GOOGLE_API_KEY") + # https://cse.google.com/cse/all + SEARCH_ENGINE_ID = env_not_empty("SEARCH_ENGINE_ID") + + # Using the first page + start_page_idx = 1 + # Different language may get different result + search_language = "en" + # How many pages to return + num_result_pages = 10 + # Constructing the URL + # Doc: https://developers.google.com/custom-search/v1/using_rest + base_url = ( + f"https://www.googleapis.com/customsearch/v1?" + f"key={GOOGLE_API_KEY}&cx={SEARCH_ENGINE_ID}&q={query}&start=" + f"{start_page_idx}&lr={search_language}&num={num_result_pages}" + ) + + if search_type == "image": + url = base_url + "&searchType=image" + else: + url = base_url + + responses = [] + # Fetch the results given the URL + try: + # Make the get + result = requests.get(url) + data = result.json() + + # Get the result items + if "items" in data: + search_items = data.get("items") + + # Iterate over results found + for i, search_item in enumerate(search_items, start=1): + if search_type == "image": + # Process image search results + title = search_item.get("title") + image_url = search_item.get("link") + display_link = search_item.get("displayLink") + + # Get context URL (page containing the image) + image_info = search_item.get("image", {}) + context_url = image_info.get("contextLink", "") + + # Get image dimensions if available + width = image_info.get("width") + height = image_info.get("height") + + response = { + "result_id": i, + "title": title, + "image_url": image_url, + "display_link": display_link, + "context_url": context_url, + } + + # Add dimensions if available + if width: + response["width"] = int(width) + if height: + response["height"] = int(height) + + responses.append(response) + else: + # Process web search results (existing logic) + # Check metatags are present + if "pagemap" not in search_item: + continue + if "metatags" not in search_item["pagemap"]: + continue + if "og:description" in search_item["pagemap"]["metatags"][0]: + long_description = search_item["pagemap"]["metatags"][0]["og:description"] + else: + long_description = "N/A" + # Get the page title + title = search_item.get("title") + # Page snippet + snippet = search_item.get("snippet") + + # Extract the page url + link = search_item.get("link") + response = { + "result_id": i, + "title": title, + "description": snippet, + "long_description": long_description, + "url": link, + } + responses.append(response) + else: + error_info = data.get("error", {}) + logger.error(f"Google search failed - API response: {error_info}") + responses.append({"error": f"Google search failed - API response: {error_info}"}) + + except Exception as e: + responses.append({"error": f"google search failed: {e!s}"}) + return responses diff --git a/server/app/controller/mcp/user_controller.py b/server/app/controller/mcp/user_controller.py new file mode 100644 index 000000000..2beda01f7 --- /dev/null +++ b/server/app/controller/mcp/user_controller.py @@ -0,0 +1,78 @@ +from typing import List, Optional +from fastapi import Depends, HTTPException, Query, Response, APIRouter +from sqlmodel import Session, select +from app.component.database import session +from app.component.auth import Auth, auth_must +from fastapi_babel import _ +from app.model.mcp.mcp_user import McpUser, McpUserIn, McpUserOut, McpUserUpdate, Status +from loguru import logger + +router = APIRouter(tags=["McpUser Management"]) + + +@router.get("/mcp/users", name="list mcp users", response_model=List[McpUserOut]) +async def list_mcp_users( + mcp_id: Optional[int] = None, + session: Session = Depends(session), + auth: Auth = Depends(auth_must), +): + user_id = auth.user.id + query = select(McpUser) + if mcp_id is not None: + query = query.where(McpUser.mcp_id == mcp_id) + if user_id is not None: + query = query.where(McpUser.user_id == user_id) + mcp_users = session.exec(query).all() + return mcp_users + + +@router.get("/mcp/users/{mcp_user_id}", name="get mcp user", response_model=McpUserOut) +async def get_mcp_user(mcp_user_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)): + query = select(McpUser).where(McpUser.id == mcp_user_id) + mcp_user = session.exec(query).first() + if not mcp_user: + raise HTTPException(status_code=404, detail=_("McpUser not found")) + return mcp_user + + +@router.post("/mcp/users", name="create mcp user", response_model=McpUserOut) +async def create_mcp_user(mcp_user: McpUserIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)): + exists = session.exec( + select(McpUser).where(McpUser.mcp_id == mcp_user.mcp_id, McpUser.user_id == auth.user.id) + ).first() + if exists: + raise HTTPException(status_code=400, detail=_("mcp is installed")) + db_mcp_user = McpUser(mcp_id=mcp_user.mcp_id, user_id=auth.user.id, env=mcp_user.env) + session.add(db_mcp_user) + session.commit() + session.refresh(db_mcp_user) + return db_mcp_user + + +@router.put("/mcp/users/{id}", name="update mcp user") +async def update_mcp_user( + id: int, + update_item: McpUserUpdate, + session: Session = Depends(session), + auth: Auth = Depends(auth_must), +): + model = session.get(McpUser, id) + if not model: + raise HTTPException(status_code=404, detail=_("Mcp Info not found")) + if model.user_id != auth.user.id: + raise HTTPException(status_code=400, detail=_("current user have no permission to modify")) + update_data = update_item.model_dump(exclude_unset=True) + model.update_fields(update_data) + model.save(session) + session.refresh(model) + return model + + +@router.delete("/mcp/users/{mcp_user_id}", name="delete mcp user") +async def delete_mcp_user(mcp_user_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)): + db_mcp_user = session.get(McpUser, mcp_user_id) + if not db_mcp_user: + raise HTTPException(status_code=404, detail=_("Mcp Info not found")) + session.delete(db_mcp_user) + session.commit() + return Response(status_code=204) diff --git a/server/app/controller/oauth/oauth_controller.py b/server/app/controller/oauth/oauth_controller.py new file mode 100644 index 000000000..be14905ce --- /dev/null +++ b/server/app/controller/oauth/oauth_controller.py @@ -0,0 +1,58 @@ +from fastapi import APIRouter, Request, HTTPException +from fastapi.responses import RedirectResponse, JSONResponse, HTMLResponse +from app.component.environment import env +from app.component.oauth_adapter import OauthCallbackPayload, get_oauth_adapter +from typing import Optional + +router = APIRouter(prefix="/oauth", tags=["Oauth Servers"]) + + +@router.get("/{app}/login", name="OAuth Login Redirect") +def oauth_login(app: str, request: Request, state: Optional[str] = None): + try: + callback_url = str(request.url_for("OAuth Callback", app=app)) + if callback_url.startswith("http://"): + callback_url = "https://" + callback_url[len("http://") :] + adapter = get_oauth_adapter(app, callback_url) + url = adapter.get_authorize_url(state) + if not url: + raise HTTPException(status_code=400, detail="Failed to generate authorization URL") + return RedirectResponse(str(url)) + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@router.get("/{app}/callback", name="OAuth Callback") +def oauth_callback(app: str, request: Request, code: Optional[str] = None, state: Optional[str] = None): + if not code: + raise HTTPException(status_code=400, detail="Missing code parameter") + redirect_url = f"eigent://callback/oauth?provider={app}&code={code}&state={state}" + html_content = f""" + + + OAuth Callback + + + +

Redirecting, please wait...

+ + + + """ + return HTMLResponse(content=html_content) + + +@router.post("/{app}/token", name="OAuth Fetch Token") +def fetch_token(app: str, request: Request, data: OauthCallbackPayload): + try: + callback_url = str(request.url_for("OAuth Callback", app=app)) + if callback_url.startswith("http://"): + callback_url = "https://" + callback_url[len("http://") :] + + adapter = get_oauth_adapter(app, callback_url) + token_data = adapter.fetch_token(data.code) + return JSONResponse(token_data) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/server/app/controller/provider/provider_controller.py b/server/app/controller/provider/provider_controller.py new file mode 100644 index 000000000..410f246fc --- /dev/null +++ b/server/app/controller/provider/provider_controller.py @@ -0,0 +1,100 @@ +from typing import List, Optional +from fastapi import Depends, HTTPException, Query, Response, APIRouter +from fastapi_babel import _ +from fastapi_pagination import Page +from fastapi_pagination.ext.sqlmodel import paginate +from sqlalchemy import update +from sqlmodel import Session, select, col +from sqlalchemy.exc import SQLAlchemyError + +from app.component.database import session +from app.component.auth import Auth, auth_must +from app.model.provider.provider import Provider, ProviderIn, ProviderOut, ProviderPreferIn + + +router = APIRouter(tags=["Provider Management"]) + + +@router.get("/providers", name="list providers", response_model=Page[ProviderOut]) +async def gets( + keyword: str | None = None, + prefer: Optional[bool] = Query(None, description="Filter by prefer status"), + session: Session = Depends(session), + auth: Auth = Depends(auth_must), +) -> Page[ProviderOut]: + user_id = auth.user.id + stmt = select(Provider).where(Provider.user_id == user_id, Provider.no_delete()) + if keyword: + stmt = stmt.where(col(Provider.provider_name).like(f"%{keyword}%")) + if prefer is not None: + stmt = stmt.where(Provider.prefer == prefer) + stmt = stmt.order_by(col(Provider.created_at).desc(), col(Provider.id).desc()) # Added for consistent pagination + return paginate(session, stmt) + + +@router.get("/provider", name="get provider detail", response_model=ProviderOut) +async def get(id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)): + user_id = auth.user.id + stmt = select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == id) + model = session.exec(stmt).one_or_none() + if not model: + raise HTTPException(status_code=404, detail=_("Provider not found")) + return model + + +@router.post("/provider", name="create provider", response_model=ProviderOut) +async def post(data: ProviderIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)): + user_id = auth.user.id + model = Provider(**data.model_dump(), user_id=user_id) + model.save(session) + return model + + +@router.put("/provider/{id}", name="update provider", response_model=ProviderOut) +async def put(id: int, data: ProviderIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)): + user_id = auth.user.id + model = session.exec( + select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == id) + ).one_or_none() + if not model: + raise HTTPException(status_code=404, detail=_("Provider not found")) + model.model_type = data.model_type + model.provider_name = data.provider_name + model.api_key = data.api_key + model.endpoint_url = data.endpoint_url + model.encrypted_config = data.encrypted_config + model.is_vaild = data.is_vaild + model.save(session) + session.refresh(model) + return model + + +@router.delete("/provider/{id}", name="delete provider") +async def delete(id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)): + user_id = auth.user.id + model = session.exec( + select(Provider).where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == id) + ).one_or_none() + if not model: + raise HTTPException(status_code=404, detail=_("Provider not found")) + model.delete(session) + return Response(status_code=204) + + +@router.post("/provider/prefer", name="set provider prefer") +async def set_prefer(data: ProviderPreferIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)): + user_id = auth.user.id + try: + # 1. current user's all provider prefer set to false + session.exec(update(Provider).where(Provider.user_id == user_id, Provider.no_delete()).values(prefer=False)) + # 2. set the prefer of the specified provider_id to true + session.exec( + update(Provider) + .where(Provider.user_id == user_id, Provider.no_delete(), Provider.id == data.provider_id) + .values(prefer=True) + ) + session.commit() + return {"success": True} + except SQLAlchemyError as e: + session.rollback() + raise HTTPException(status_code=500, detail=str(e)) diff --git a/server/app/controller/redirect_controller.py b/server/app/controller/redirect_controller.py new file mode 100644 index 000000000..3695a8fb4 --- /dev/null +++ b/server/app/controller/redirect_controller.py @@ -0,0 +1,72 @@ +import json +from fastapi import APIRouter, Depends, Request +from fastapi_babel import _ +from fastapi.responses import HTMLResponse + + +router = APIRouter(tags=["Redirect"]) + + +@router.get("/redirect/callback") +def redirect_callback(code: str, request: Request): + cookies = request.cookies + cookies_json = json.dumps(cookies) + + html_content = f""" + + + + + + Authorization successful + + + +
+

Authorization Successful

+

Redirecting to application...

+
Please wait...
+
+ + + + """ + return HTMLResponse(content=html_content) diff --git a/server/app/controller/user/login_controller.py b/server/app/controller/user/login_controller.py new file mode 100644 index 000000000..908e63da9 --- /dev/null +++ b/server/app/controller/user/login_controller.py @@ -0,0 +1,90 @@ +from fastapi import APIRouter, Depends, HTTPException +from fastapi_babel import _ +from sqlmodel import Session +from app.component import code +from app.component.auth import Auth +from app.component.database import session +from app.component.encrypt import password_verify +from app.component.stack_auth import StackAuth +from app.exception.exception import UserException +from app.model.user.user import LoginByPasswordIn, LoginResponse, Status, User, RegisterIn +from loguru import logger +from app.component.environment import env + + +router = APIRouter(tags=["Login/Registration"]) + + +@router.post("/login", name="login by email or password") +async def by_password(data: LoginByPasswordIn, session: Session = Depends(session)) -> LoginResponse: + """ + User login with email and password + """ + user = User.by(User.email == data.email, s=session).one_or_none() + if not user or not password_verify(data.password, user.password): + raise UserException(code.password, _("Account or password error")) + return LoginResponse(token=Auth.create_access_token(user.id), email=user.email) + + +@router.post("/login-by_stack", name="login by stack") +async def by_stack_auth( + token: str, + type: str = "signup", + invite_code: str | None = None, + session: Session = Depends(session), +): + try: + stack_id = await StackAuth.user_id(token) + info = await StackAuth.user_info(token) + except Exception as e: + logger.error(e) + raise HTTPException(500, detail=_(f"{e}")) + user = User.by(User.stack_id == stack_id, s=session).one_or_none() + + if not user: + # Only signup can create user + if type != "signup": + raise UserException(code.error, _("User not found")) + with session as s: + try: + user = User( + username=info["username"] if "username" in info else None, + nickname=info["display_name"], + email=info["primary_email"], + avatar=info["profile_image_url"], + stack_id=stack_id, + ) + s.add(user) + s.commit() + session.refresh(user) + return LoginResponse(token=Auth.create_access_token(user.id), email=user.email) + except Exception as e: + s.rollback() + logger.error(f"Failed to register: {e}") + raise UserException(code.error, _("Failed to register")) + else: + if user.status == Status.Block: + raise UserException(code.error, _("Your account has been blocked.")) + return LoginResponse(token=Auth.create_access_token(user.id), email=user.email) + + +@router.post("/register", name="register by email/password") +async def register(data: RegisterIn, session: Session = Depends(session)): + # Check if email is already registered + if User.by(User.email == data.email, s=session).one_or_none(): + raise UserException(code.error, _("Email already registered")) + + with session as s: + try: + user = User( + email=data.email, + password=data.password, + ) + s.add(user) + s.commit() + s.refresh(user) + except Exception as e: + s.rollback() + logger.error(f"Failed to register: {e}") + raise UserException(code.error, _("Failed to register")) + return {"status": "success"} diff --git a/server/app/controller/user/user_controller.py b/server/app/controller/user/user_controller.py new file mode 100644 index 000000000..cd8ecc3b9 --- /dev/null +++ b/server/app/controller/user/user_controller.py @@ -0,0 +1,115 @@ +from fastapi import APIRouter, Depends +from sqlalchemy import func +from sqlmodel import Session, select +from app.component.auth import Auth, auth_must +from app.component.database import session +from app.model.user.privacy import UserPrivacy, UserPrivacySettings +from app.model.user.user import User, UserIn, UserOut, UserProfile +from app.model.user.user_stat import UserStat, UserStatActionIn, UserStatOut +from app.model.chat.chat_history import ChatHistory +from app.model.mcp.mcp_user import McpUser +from app.model.config.config import Config +from app.model.chat.chat_snpshot import ChatSnapshot +from app.model.user.user_credits_record import UserCreditsRecord + + +router = APIRouter(tags=["User"]) + + +@router.get("/user", name="user info", response_model=UserOut) +def get(auth: Auth = Depends(auth_must), session: Session = Depends(session)): + # 获取用户信息时触发积分刷新 + user: User = auth.user + user.refresh_credits_on_active(session) + return user + + +@router.put("/user", name="update user info", response_model=UserOut) +def put(data: UserIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)): + model = auth.user + model.username = data.username + model.save(session) + return model + + +@router.put("/user/profile", name="update user profile", response_model=UserProfile) +def put_profile(data: UserProfile, session: Session = Depends(session), auth: Auth = Depends(auth_must)): + model = auth.user + model.nickname = data.nickname + model.fullname = data.fullname + model.work_desc = data.work_desc + model.save(session) + return model + + +@router.get("/user/privacy", name="get user privacy") +def get_privacy(session: Session = Depends(session), auth: Auth = Depends(auth_must)): + user_id = auth.user.id + stmt = select(UserPrivacy).where(UserPrivacy.user_id == user_id) + model = session.exec(stmt).one_or_none() + + if not model: + return UserPrivacySettings.default_settings() + return model.pricacy_setting + + +@router.put("/user/privacy", name="update user privacy") +def put_privacy(data: UserPrivacySettings, session: Session = Depends(session), auth: Auth = Depends(auth_must)): + user_id = auth.user.id + stmt = select(UserPrivacy).where(UserPrivacy.user_id == user_id) + model = session.exec(stmt).one_or_none() + default_settings = UserPrivacySettings.default_settings() + + if model: + model.pricacy_setting = {**model.pricacy_setting, **data.model_dump()} + model.save(session) + else: + model = UserPrivacy(user_id=user_id, pricacy_setting={**default_settings, **data.model_dump()}) + model.save(session) + + return model.pricacy_setting + + +@router.get("/user/current_credits", name="get user current credits") +def get_user_credits(auth: Auth = Depends(auth_must), session: Session = Depends(session)): + user = auth.user + user.refresh_credits_on_active(session) + credits = user.credits + daily_credits: UserCreditsRecord | None = UserCreditsRecord.get_daily_balance(user.id) + current_daily_credits = 0 + if daily_credits: + current_daily_credits = daily_credits.amount - daily_credits.balance + credits += current_daily_credits if current_daily_credits > 0 else 0 + return {"credits": credits, "daily_credits": current_daily_credits} + + +@router.get("/user/stat", name="get user stat", response_model=UserStatOut) +def get_user_stat(auth: Auth = Depends(auth_must), session: Session = Depends(session)): + """Get current user's operation statistics.""" + stat = session.exec(select(UserStat).where(UserStat.user_id == auth.user.id)).first() + data = UserStatOut() + if stat: + data = UserStatOut(**stat.model_dump()) + else: + data = UserStatOut(user_id=auth.user.id) + data.task_queries = ChatHistory.count(ChatHistory.user_id == auth.user.id, s=session) + mcp = McpUser.count(McpUser.user_id == auth.user.id, s=session) + tool: list = session.exec( + select(func.count("*")).where(Config.user_id == auth.user.id).group_by(Config.config_group) + ).all() + tool = tool.__len__() + data.mcp_install_count = mcp + tool + data.storage_used = ChatSnapshot.caclDir(ChatSnapshot.get_user_dir(auth.user.id)) + return data + + +@router.post("/user/stat", name="record user stat") +def record_user_stat( + data: UserStatActionIn, + auth: Auth = Depends(auth_must), + session: Session = Depends(session), +): + """Record or update current user's operation statistics.""" + data.user_id = auth.user.id + stat = UserStat.record_action(session, data) + return stat diff --git a/server/app/controller/user/user_password_controller.py b/server/app/controller/user/user_password_controller.py new file mode 100644 index 000000000..3efd19866 --- /dev/null +++ b/server/app/controller/user/user_password_controller.py @@ -0,0 +1,24 @@ +from fastapi import APIRouter, Depends +from sqlmodel import Session + +from app.component import code +from app.component.auth import Auth, auth_must +from app.component.database import session +from app.component.encrypt import password_hash, password_verify +from app.exception.exception import UserException +from app.model.user.user import UpdatePassword, UserOut +from fastapi_babel import _ + +router = APIRouter(tags=["User"]) + + +@router.put("/user/update-password", name="update password", response_model=UserOut) +def update_password(data: UpdatePassword, auth: Auth = Depends(auth_must), session: Session = Depends(session)): + model = auth.user + if not password_verify(data.password, model.password): + raise UserException(code.error, _("Password is incorrect")) + if data.new_password != data.re_new_password: + raise UserException(code.error, _("The two passwords do not match")) + model.password = password_hash(data.new_password) + model.save(session) + return model diff --git a/server/app/exception/exception.py b/server/app/exception/exception.py new file mode 100644 index 000000000..a67378a4f --- /dev/null +++ b/server/app/exception/exception.py @@ -0,0 +1,20 @@ +class UserException(Exception): + def __init__(self, code: int, description: str): + self.code = code + self.description = description + + +class TokenException(Exception): + def __init__(self, code: int, text: str): + self.code = code + self.text = text + + +class NoPermissionException(Exception): + def __init__(self, text: str): + self.text = text + + +class ProgramException(Exception): + def __init__(self, text: str): + self.text = text diff --git a/server/app/exception/handler.py b/server/app/exception/handler.py new file mode 100644 index 000000000..696ac7a98 --- /dev/null +++ b/server/app/exception/handler.py @@ -0,0 +1,49 @@ +import json +from fastapi import Request +from fastapi.encoders import jsonable_encoder +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse +from app import api +from app.component import code +from app.exception.exception import NoPermissionException, TokenException +from app.component.pydantic.i18n import trans, get_language +from app.exception.exception import UserException +from sqlalchemy.exc import NoResultFound + + +@api.exception_handler(RequestValidationError) +async def request_exception(request: Request, e: RequestValidationError): + + if (lang := get_language(request.headers.get("Accept-Language"))) is None: + lang = "en_US" + return JSONResponse( + content={ + "code": code.form_error, + "error": jsonable_encoder(trans.translate(list(e.errors()), locale=lang)), + } + ) + + +@api.exception_handler(TokenException) +async def token_exception(request: Request, e: TokenException): + return JSONResponse(content={"code": e.code, "text": e.text}) + + +@api.exception_handler(UserException) +async def user_exception(request: Request, e: UserException): + return JSONResponse(content={"code": e.code, "text": e.description}) + + +@api.exception_handler(NoPermissionException) +async def no_permission(request: Request, exception: NoPermissionException): + return JSONResponse( + status_code=200, + content={"code": code.no_permission_error, "text": exception.text}, + ) + + +async def no_results(request: Request, exception: NoResultFound): + return JSONResponse( + status_code=200, + content={"code": code.not_found, "text": exception._message()}, + ) diff --git a/server/app/middleware/__init__.py b/server/app/middleware/__init__.py new file mode 100644 index 000000000..8476e5e2a --- /dev/null +++ b/server/app/middleware/__init__.py @@ -0,0 +1,6 @@ +from app import api +from app.component.babel import babel_configs +from fastapi_babel import BabelMiddleware + + +api.add_middleware(BabelMiddleware, babel_configs=babel_configs) diff --git a/server/app/model/abstract/model.py b/server/app/model/abstract/model.py new file mode 100644 index 000000000..1e63fa7e7 --- /dev/null +++ b/server/app/model/abstract/model.py @@ -0,0 +1,119 @@ +from datetime import datetime +from typing import Any +from sqlalchemy import delete +from sqlmodel import Field, SQLModel, Session, col, func, TIMESTAMP, select, text +from app.component import code +from sqlalchemy.sql.expression import ColumnExpressionArgument +from sqlalchemy.sql.base import ExecutableOption +from sqlalchemy.orm import declared_attr +from fastapi_babel import _ +from app.exception.exception import UserException +from app.component.database import engine +from convert_case import snake_case + + +class AbstractModel(SQLModel): + @declared_attr # type: ignore + def __tablename__(cls) -> str: + return snake_case(cls.__name__) + + @classmethod + def by( + cls, + *whereclause: ColumnExpressionArgument[bool] | bool, + order_by: Any | None = None, + limit: int | None = None, + offset: int | None = None, + options: ExecutableOption | list[ExecutableOption] | None = None, + s: Session, + ): + stmt = select(cls).where(*whereclause) + if order_by is not None: + stmt = stmt.order_by(order_by) + if limit is not None: + stmt = stmt.limit(limit) + if offset is not None: + stmt = stmt.offset(offset) + if options is not None: + stmt = stmt.options(*(options if isinstance(options, list) else [options])) + return s.exec(stmt, execution_options={"prebuffer_rows": True}) + + @classmethod + def exists( + cls, + *whereclause: ColumnExpressionArgument[bool] | bool, + s: Session, + ) -> bool: + res = s.exec(select(func.count("*")).where(*whereclause)).first() + return res is not None and res > 0 + + @classmethod + def count( + cls, + *whereclause: ColumnExpressionArgument[bool] | bool, + s: Session, + ) -> int: + res = s.exec(select(func.count("*")).where(*whereclause)).first() + return res if res is not None else 0 + + @classmethod + def exists_must( + cls, + *whereclause: ColumnExpressionArgument[bool] | bool, + s: Session, + ): + if not cls.exists(*whereclause, s=s): + raise UserException(code.not_found, _("There is no data that meets the conditions")) + + @classmethod + def delete_by( + cls, + *whereclause: ColumnExpressionArgument[bool], + s: Session, + ): + stmt = delete(cls).where(*whereclause) + s.connection().execute(stmt) + s.commit() + + def save(self, s: Session | None = None): + if s is None: + with Session(engine, expire_on_commit=False) as s: + s.add(self) + s.commit() + else: + s.add(self) + s.commit() + + def delete(self, s: Session): + if isinstance(self, DefaultTimes): + self.deleted_at = datetime.now() + self.save(s) + else: + s.delete(self) + s.commit() + + def update_fields(self, update_dict: dict): + for k, v in update_dict.items(): + setattr(self, k, v) + + +class DefaultTimes: + deleted_at: datetime | None = Field(default=None) + created_at: datetime | None = Field( + # 兼容mysql,如果只有数据库的保存的话,保存后,created_at为None,无法立即调用 + default_factory=datetime.now, + sa_type=TIMESTAMP, + sa_column_kwargs={"server_default": text("CURRENT_TIMESTAMP")}, + ) + updated_at: datetime | None = Field( + default_factory=datetime.now, + sa_type=TIMESTAMP, + sa_column_kwargs={ + "server_default": text("CURRENT_TIMESTAMP"), + "onupdate": func.now(), + }, + ) + + @classmethod + def no_delete(cls): + return col(cls.deleted_at).is_(None) diff --git a/server/app/model/chat/chat_history.py b/server/app/model/chat/chat_history.py new file mode 100644 index 000000000..e351e34ae --- /dev/null +++ b/server/app/model/chat/chat_history.py @@ -0,0 +1,76 @@ +from sqlalchemy import Float, Integer +from sqlmodel import Field, SmallInteger, Column, JSON, String +from typing import Optional +from enum import IntEnum +from sqlalchemy_utils import ChoiceType +from app.model.abstract.model import AbstractModel, DefaultTimes +from pydantic import BaseModel + + +class ChatStatus(IntEnum): + ongoing = 1 + done = 2 + + +class ChatHistory(AbstractModel, DefaultTimes, table=True): + id: int = Field(default=None, primary_key=True) + user_id: int = Field(index=True) + task_id: str = Field(index=True, unique=True) + question: str + language: str + model_platform: str + model_type: str + api_key: str + api_url: str = Field(sa_column=Column(String(500))) + max_retries: int = Field(default=3) + file_save_path: Optional[str] = None + installed_mcp: str = Field(sa_type=JSON, default={}) + project_name: str = Field(default="", sa_column=Column(String(128))) + summary: str = Field(default="", sa_column=Column(String(1024))) + tokens: int = Field(default=0, sa_column=(Column(Integer, server_default="0"))) + spend: float = Field(default=0, sa_column=(Column(Float, server_default="0"))) + status: int = Field(default=1, sa_column=Column(ChoiceType(ChatStatus, SmallInteger()))) + + +class ChatHistoryIn(BaseModel): + task_id: str + user_id: int | None = None + question: str + language: str + model_platform: str + model_type: str + api_key: str | None = "" + api_url: str | None = None + max_retries: int = 3 + file_save_path: Optional[str] = None + installed_mcp: Optional[str] = None + project_name: str | None = None + summary: str | None = None + tokens: int = 0 + spend: float = 0 + status: int = ChatStatus.ongoing.value + + +class ChatHistoryOut(BaseModel): + id: int + task_id: str + question: str + language: str + model_platform: str + model_type: str + api_key: Optional[str] = None + api_url: Optional[str] = None + max_retries: int + file_save_path: Optional[str] = None + installed_mcp: Optional[str] = None + project_name: str | None = None + summary: str | None = None + tokens: int + status: int + + +class ChatHistoryUpdate(BaseModel): + project_name: str | None = None + summary: str | None = None + tokens: int | None = None + status: int | None = None diff --git a/server/app/model/chat/chat_share.py b/server/app/model/chat/chat_share.py new file mode 100644 index 000000000..55170245c --- /dev/null +++ b/server/app/model/chat/chat_share.py @@ -0,0 +1,56 @@ +import os +from itsdangerous import URLSafeTimedSerializer +from pydantic import BaseModel + + +class ChatShare: + SECRET_KEY = os.getenv("CHAT_SHARE_SECRET_KEY", "EGB1WRC9xMUVgNoIPH8tLw") + SALT = os.getenv("CHAT_SHARE_SALT", "r4U2M") + # Set expiration to 1 day + EXPIRATION_SECONDS = int(os.getenv("CHAT_SHARE_EXPIRATION_SECONDS", 60 * 60 * 24)) + + @classmethod + def generate_token(cls, task_id: str) -> str: + serializer = URLSafeTimedSerializer(cls.SECRET_KEY) + return serializer.dumps(task_id, salt=cls.SALT) + + @classmethod + def verify_token(cls, token: str, check_expiration: bool = True) -> str: + """ + Verify token and return task_id + + Args: + token: The token to verify + check_expiration: Whether to check token expiration (default: True) + + Returns: + str: The task_id from the token + + Raises: + Exception: If token is invalid or expired (when check_expiration=True) + """ + serializer = URLSafeTimedSerializer(cls.SECRET_KEY) + + if check_expiration: + # Check expiration time + return serializer.loads(token, salt=cls.SALT, max_age=cls.EXPIRATION_SECONDS) + else: + # Don't check expiration time + return serializer.loads(token, salt=cls.SALT) + + +class ChatShareIn(BaseModel): + task_id: str + + +class ChatHistoryShareOut(BaseModel): + question: str + language: str + model_platform: str + model_type: str + max_retries: int + project_name: str | None = None + summary: str | None = None + + class Config: + from_attributes = True diff --git a/server/app/model/chat/chat_snpshot.py b/server/app/model/chat/chat_snpshot.py new file mode 100644 index 000000000..a1cb3a98a --- /dev/null +++ b/server/app/model/chat/chat_snpshot.py @@ -0,0 +1,56 @@ +from typing import Optional +from sqlalchemy import Column, Integer, text +from sqlmodel import Field +from app.model.abstract.model import AbstractModel, DefaultTimes +from pydantic import BaseModel +import os +import base64 +import time + +from app.component.sqids import encode_user_id + + +class ChatSnapshot(AbstractModel, DefaultTimes, table=True): + id: int = Field(default=None, primary_key=True) + user_id: int = Field(sa_column=(Column(Integer, server_default=text("0")))) + api_task_id: str = Field(index=True) + camel_task_id: str = Field(index=True) + browser_url: str + image_path: str + + @classmethod + def get_user_dir(cls, user_id: int) -> str: + return os.path.join("app", "public", "upload", encode_user_id(user_id)) + + @classmethod + def caclDir(cls, path: str) -> float: + """Return disk usage of path directory (in MB, rounded to 2 decimal places)""" + total_size = 0 + for dirpath, dirnames, filenames in os.walk(path): + for f in filenames: + fp = os.path.join(dirpath, f) + if os.path.isfile(fp): + total_size += os.path.getsize(fp) + size_mb = total_size / (1024 * 1024) + return round(size_mb, 2) + + +class ChatSnapshotIn(BaseModel): + api_task_id: str + user_id: Optional[int] = None + camel_task_id: str + browser_url: str + image_base64: str + + @staticmethod + def save_image(user_id: int, api_task_id: str, image_base64: str) -> str: + if "," in image_base64: + image_base64 = image_base64.split(",", 1)[1] + user_dir = encode_user_id(user_id) + folder = os.path.join("app", "public", "upload", user_dir, api_task_id) + os.makedirs(folder, exist_ok=True) + filename = f"{int(time.time() * 1000)}.jpg" + file_path = os.path.join(folder, filename) + with open(file_path, "wb") as f: + f.write(base64.b64decode(image_base64)) + return f"/public/upload/{user_dir}/{api_task_id}/{filename}" diff --git a/server/app/model/chat/chat_step.py b/server/app/model/chat/chat_step.py new file mode 100644 index 000000000..7b567d924 --- /dev/null +++ b/server/app/model/chat/chat_step.py @@ -0,0 +1,43 @@ +from sqlmodel import SQLModel, Field, JSON +from app.model.abstract.model import AbstractModel, DefaultTimes +from pydantic import BaseModel +from typing import Any +from pydantic import field_validator +import json + + +class ChatStep(AbstractModel, DefaultTimes, table=True): + id: int = Field(default=None, primary_key=True) + task_id: str = Field(index=True) + step: str + data: str = Field(sa_type=JSON) + + @field_validator("data", mode="before") + @classmethod + def serialize_data(cls, v): + if isinstance(v, (dict, list)): + return json.dumps(v, ensure_ascii=False) + return v + + @field_validator("data", mode="after") + @classmethod + def deserialize_data(cls, v): + if isinstance(v, str): + try: + return json.loads(v) + except Exception: + return v + return v + + +class ChatStepIn(BaseModel): + task_id: str + step: str + data: Any + + +class ChatStepOut(BaseModel): + id: int + task_id: str + step: str + data: Any diff --git a/server/app/model/config/config.py b/server/app/model/config/config.py new file mode 100644 index 000000000..72e213ee7 --- /dev/null +++ b/server/app/model/config/config.py @@ -0,0 +1,181 @@ +from sqlmodel import Field, SQLModel, UniqueConstraint +from app.model.abstract.model import AbstractModel, DefaultTimes +from app.type.config_group import ConfigGroup + + +class Config(AbstractModel, DefaultTimes, table=True): + __table_args__ = (UniqueConstraint("user_id", "config_name", name="uix_user_id_config_name"),) + id: int | None = Field(default=None, primary_key=True) + user_id: int = Field(nullable=False, index=True) + config_name: str = Field(nullable=False) + config_value: str = Field(nullable=False, default=None) + config_group: str = Field(nullable=False, index=True) + + +class ConfigCreate(SQLModel): + config_name: str + config_value: str + config_group: ConfigGroup + + +class ConfigUpdate(SQLModel): + config_name: str + config_value: str + config_group: ConfigGroup + + +class ConfigOut(SQLModel): + id: int + user_id: int + config_name: str + config_value: str + config_group: ConfigGroup + + +class ConfigInfo: + configs: dict = { + # "model_platform": {"env_vars": ["api_key"]}, + # ConfigGroup.GOOGLE_SUITE.value: { + # "env_vars": [], + # "toolkit": "", + # }, + ConfigGroup.SLACK.value: { + "env_vars": ["SLACK_BOT_TOKEN"], + "toolkit": "slack_toolkit", + }, + ConfigGroup.NOTION.value: { + "env_vars": ["MCP_REMOTE_CONFIG_DIR"], + "toolkit": "notion_toolkit", + }, + ConfigGroup.TWITTER.value: { + "env_vars": [ + "TWITTER_CONSUMER_KEY", + "TWITTER_CONSUMER_SECRET", + "TWITTER_ACCESS_TOKEN", + "TWITTER_ACCESS_TOKEN_SECRET", + ], + "toolkit": "twitter_toolkit", + }, + ConfigGroup.WHATSAPP.value: { + "env_vars": ["WHATSAPP_ACCESS_TOKEN", "WHATSAPP_PHONE_NUMBER_ID"], + "toolkit": "whatsapp_toolkit", + }, + ConfigGroup.LINKEDIN.value: { + "env_vars": ["LINKEDIN_ACCESS_TOKEN"], + "toolkit": "linkedin_toolkit", + }, + ConfigGroup.REDDIT.value: { + "env_vars": [ + "REDDIT_CLIENT_ID", + "REDDIT_CLIENT_SECRET", + "REDDIT_USER_AGENT", + ], + "toolkit": "reddit_toolkit", + }, + # ConfigGroup.DISCORD.value: { + # "env_vars": ["DISCORD_BOT_TOKEN"], + # "toolkit": "discord_toolkit", + # }, + ConfigGroup.SEARCH.value: { + "env_vars": ["GOOGLE_API_KEY", "SEARCH_ENGINE_ID", "EXA_API_KEY"], + "toolkit": "search_toolkit", + }, + ConfigGroup.AUDIO_ANALYSIS.value: { + "env_vars": [], + "toolkit": "audio_analysis_toolkit", + }, + ConfigGroup.CODE_EXECUTION.value: { + "env_vars": [], + "toolkit": "code_execution_toolkit", + }, + ConfigGroup.CRAW4AI.value: { + "env_vars": [], + "toolkit": "craw4ai_toolkit", + }, + ConfigGroup.DALLE.value: { + "env_vars": [], + "toolkit": "dalle_toolkit", + }, + ConfigGroup.EDGEONE_PAGES_MCP.value: { + "env_vars": [], + "toolkit": "edgeone_pages_mcp_toolkit", + }, + ConfigGroup.EXCEL.value: { + "env_vars": [], + "toolkit": "excel_toolkit", + }, + ConfigGroup.FILE_WRITE.value: { + "env_vars": [], + "toolkit": "file_write_toolkit", + }, + ConfigGroup.GITHUB.value: { + "env_vars": ["GITHUB_TOKEN"], + "toolkit": "github_toolkit", + }, + ConfigGroup.GOOGLE_CALENDAR.value: { + "env_vars": [ + "GOOGLE_CLIENT_ID", + "GOOGLE_CLIENT_SECRET", + "GOOGLE_REFRESH_TOKEN", + ], + "toolkit": "google_calendar_toolkit", + }, + ConfigGroup.GOOGLE_DRIVE_MCP.value: { + "env_vars": [], + "toolkit": "google_drive_mcp_toolkit", + }, + # ConfigGroup.GOOGLE_GMAIL_MCP.value: { + # "env_vars": [], + # "toolkit": "google_gmail_mcp_toolkit", + # }, + ConfigGroup.IMAGE_ANALYSIS.value: { + "env_vars": [], + "toolkit": "image_analysis_toolkit", + }, + ConfigGroup.MCP_SEARCH.value: { + "env_vars": [], + "toolkit": "mcp_search_toolkit", + }, + ConfigGroup.PPTX.value: { + "env_vars": [], + "toolkit": "pptx_toolkit", + }, + ConfigGroup.REDDIT.value: { + "env_vars": [ + "REDDIT_CLIENT_ID", + "REDDIT_CLIENT_SECRET", + "REDDIT_USER_AGENT", + ], + "toolkit": "reddit_toolkit", + }, + } + + @classmethod + def getinfo(cls): + return cls.configs + + @classmethod + def is_valid_group(cls, group: str) -> bool: + return group in cls.configs + + @classmethod + def get_group_env_vars(cls, group: str) -> list[str]: + if not cls.is_valid_group(group): + raise KeyError(f"Invalid group: {group}") + return cls.configs[group]["env_vars"] + + @classmethod + def is_valid_env_var(cls, group: str, env_var: str) -> bool: + if not cls.is_valid_group(group): + return False + return env_var in cls.configs[group]["env_vars"] + + @classmethod + def validate_env_vars(cls, group: str, env_vars: list[str]) -> tuple[bool, list[str]]: + if not cls.is_valid_group(group): + return False, env_vars + + valid_vars = cls.configs[group]["env_vars"] + invalid_vars = [var for var in env_vars if var not in valid_vars] + + return len(invalid_vars) == 0, invalid_vars diff --git a/server/app/model/config/plan.py b/server/app/model/config/plan.py new file mode 100644 index 000000000..7d181687d --- /dev/null +++ b/server/app/model/config/plan.py @@ -0,0 +1,16 @@ +from sqlmodel import SQLModel, Field, Column +from sqlalchemy import JSON + + +class Plan(SQLModel, table=True): + id: int = Field(default=None, primary_key=True) + plan_key: str = Field(index=True, unique=True, description="Unique plan key") + name: str = Field(index=True, unique=True, description="Plan name") + price_month: float = Field(default=0, description="Monthly price") + price_year: float = Field(default=0, description="Yearly price") + daily_credits: int = Field(default=0, description="Daily credits") + monthly_credits: int = Field(default=0, description="Monthly credits") + storage_limit: int = Field(default=0, description="Cloud storage space (MB)") + description: str = Field(default="", description="Plan description") + is_active: bool = Field(default=True, description="Is the plan active") + extra_config: dict = Field(default_factory=dict, sa_column=Column(JSON), description="Flexible extra config") diff --git a/server/app/model/mcp/category.py b/server/app/model/mcp/category.py new file mode 100644 index 000000000..75662fd89 --- /dev/null +++ b/server/app/model/mcp/category.py @@ -0,0 +1,36 @@ +from typing import ClassVar +from pydantic import BaseModel +from sqlalchemy import func +from sqlmodel import Field, select +from sqlalchemy.orm import Mapped, query_expression +from app.model.abstract.model import AbstractModel, DefaultTimes + + +class Category(AbstractModel, DefaultTimes, table=True): + id: int = Field(default=None, primary_key=True) + name: str = Field(default="", max_length=64) + description: str = Field(default="", max_length=128) + priority: int = Field(default=100) + + mcp_num: ClassVar[Mapped[int | None]] = query_expression() + + @staticmethod + def expr_mcp_num(): + from app.model.mcp.mcp import Mcp + + return select(func.count("*")).where(Category.id == Mcp.category_id).scalar_subquery() + + +class CategoryOut(BaseModel): + id: int + name: str + description: str + priority: int + + mcp_num: int | None + + +class CategoryIn(BaseModel): + name: str + description: str + priority: int diff --git a/server/app/model/mcp/mcp.py b/server/app/model/mcp/mcp.py new file mode 100644 index 000000000..fa4c69d93 --- /dev/null +++ b/server/app/model/mcp/mcp.py @@ -0,0 +1,75 @@ +from enum import IntEnum +from typing import List +from pydantic import BaseModel +from sqlalchemy import Column, SmallInteger, String +from sqlalchemy.orm import Mapped +from sqlmodel import Field, Relationship, JSON +from sqlalchemy_utils import ChoiceType +from app.model.abstract.model import AbstractModel, DefaultTimes +from app.model.mcp.mcp_env import McpEnv, McpEnvOut +from app.type.pydantic import HttpUrlStr +from app.model.mcp.category import Category, CategoryOut +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from app.model.mcp.mcp_user import McpUser + + +class Status(IntEnum): + Online = 1 + Offline = -1 + + +class McpType(IntEnum): + Local = 1 + Remote = 2 + + +class Mcp(AbstractModel, DefaultTimes, table=True): + id: int = Field(default=None, primary_key=True) + category_id: int = Field(foreign_key="category.id") + name: str + key: str = Field(sa_column=Column(String(128))) + description: str = "" + home_page: str = Field(default="", sa_column=Column(String(1024))) + type: McpType = Field(default=McpType.Local, sa_column=Column(ChoiceType(McpType, SmallInteger()))) + status: Status = Field(default=Status.Online, sa_column=Column(ChoiceType(Status, SmallInteger()))) + sort: int = Field(default=0, sa_column=Column(SmallInteger)) + server_name: str = Field(default="", sa_column=Column(String(128))) + install_command: dict = Field(default="{}", sa_column=Column(JSON)) + """{ + "command": "uvx", + "args": ["mcp-server-everything-search"], + "env": { + "EVERYTHING_SDK_PATH": "path/to/Everything-SDK/dll/Everything64.dll" + } + }""" + + category: Mapped[Category] = Relationship() + envs: Mapped[list[McpEnv]] = Relationship() + # user_env: Mapped[McpUser] = Relationship() + + mcp_user: List["McpUser"] = Relationship(back_populates="mcp") + + +class McpIn(BaseModel): + category_id: int + name: str + key: str + description: str + home_page: HttpUrlStr + type: McpType + status: Status + install_command: dict + + +class McpOut(McpIn): + id: int + category: CategoryOut | None = None + # envs: list[McpEnvOut] = [] + + +class McpInfo(BaseModel): + id: int + name: str + key: str diff --git a/server/app/model/mcp/mcp_env.py b/server/app/model/mcp/mcp_env.py new file mode 100644 index 000000000..c7a10ef2e --- /dev/null +++ b/server/app/model/mcp/mcp_env.py @@ -0,0 +1,33 @@ +from enum import IntEnum + +from pydantic import BaseModel +from app.model.abstract.model import AbstractModel, DefaultTimes +from sqlalchemy_utils import ChoiceType +from sqlmodel import Field, Column, String, TEXT, SmallInteger + + +class Status(IntEnum): + in_use = 1 + deprecated = 2 + no_use = 3 + + +class McpEnv(AbstractModel, DefaultTimes, table=True): + id: int = Field(default=None, primary_key=True) + mcp_id: int = Field(foreign_key="mcp.id") + env_name: str = Field(default="", sa_column=Column(String(128))) + env_description: str = Field(default="", sa_column=Column(TEXT)) + env_key: str = Field(sa_column=Column(String(128))) + env_default_value: str = Field(default="", sa_column=Column(String(1024))) + env_required: int = Field(default=1, sa_column=Column(SmallInteger)) + status: Status = Field(default=Status.in_use, sa_column=Column(ChoiceType(Status, SmallInteger()))) + + +class McpEnvOut(BaseModel): + id: int + mcp_id: int + env_name: str + env_description: str + env_key: str + env_default_value: str + env_required: int diff --git a/server/app/model/mcp/mcp_user.py b/server/app/model/mcp/mcp_user.py new file mode 100644 index 000000000..04e202a92 --- /dev/null +++ b/server/app/model/mcp/mcp_user.py @@ -0,0 +1,100 @@ +from enum import Enum, IntEnum +from pydantic import BaseModel +from sqlalchemy import String +from sqlmodel import Field, Column, JSON, SQLModel, UniqueConstraint, Relationship, SmallInteger +from app.model.abstract.model import DefaultTimes, AbstractModel +from sqlalchemy.orm import Mapped +from typing import Optional +from sqlalchemy_utils import ChoiceType +from app.model.mcp.mcp import McpInfo, Mcp + + +class Status(IntEnum): + enable = 1 + disable = 2 + + +class McpType(IntEnum): + Local = 1 + Remote = 2 + + +class McpImportType(str, Enum): + Local = "local" + Remote = "remote" + + +class McpUser(AbstractModel, DefaultTimes, table=True): + id: int | None = Field(default=None, primary_key=True) + mcp_id: int = Field(default=0, foreign_key="mcp.id") + user_id: int = Field(foreign_key="user.id") + mcp_name: str = Field(sa_column=Column(String(128))) + mcp_key: str = Field(sa_column=Column(String(128))) + mcp_desc: str | None = Field(default=None, sa_column=Column(String(1024))) + command: str | None = Field(default=None, sa_column=Column(String(1024))) + args: str | None = Field(default=None, sa_column=Column(String(1024))) + env: dict | None = Field(default=None, sa_column=Column(JSON)) + type: McpType = Field(default=McpType.Local, sa_column=Column(ChoiceType(McpType, SmallInteger()))) + status: Status = Field(default=Status.enable, sa_column=Column(ChoiceType(Status, SmallInteger()))) + server_url: str | None + + mcp: Mcp = Relationship(back_populates="mcp_user") + + +class McpUserIn(SQLModel): + mcp_id: int + env: Optional[dict] = None + status: Status = Status.enable + mcp_key: Optional[str] = None + + +class McpUserOut(SQLModel): + id: int + mcp_id: int + mcp_name: str | None = None + mcp_desc: str | None = None + command: Optional[str] = None + args: Optional[str] = None + env: Optional[dict] = None + status: int + type: int + server_url: Optional[str] = None + mcp_key: str + + +class McpUserUpdate(BaseModel): + mcp_name: str | None = None + mcp_desc: str | None = None + status: Optional[int] = None + type: McpType | None = None + env: Optional[dict] = None + server_url: Optional[str] = None + command: Optional[str] = None + args: Optional[str] = None + env: Optional[dict] = None + mcp_key: Optional[str] = None + + +class McpUserImport(BaseModel): + mcp_id: int = 0 + command: Optional[str] = None + args: Optional[str] = None + env: Optional[dict] = None + status: int = Status.enable + type: int = McpType.Local + server_url: Optional[str] = None + mcp_key: Optional[str] = None + + +class McpLocalImport(BaseModel): + type: int = McpType.Local + status: int = Status.enable + command: Optional[str] = None + args: Optional[str] = None + env: Optional[dict] = None + + +class McpRemoteImport(BaseModel): + type: int = McpType.Remote + status: int = Status.enable + server_url: Optional[str] = None diff --git a/server/app/model/mcp/proxy.py b/server/app/model/mcp/proxy.py new file mode 100644 index 000000000..a42eedb47 --- /dev/null +++ b/server/app/model/mcp/proxy.py @@ -0,0 +1,30 @@ +from typing import Literal +from pydantic import BaseModel + + +class ApiKey(BaseModel): + api_key: str + + +class ExaSearch(BaseModel): + query: str + search_type: Literal["auto", "neural", "keyword"] = "auto" + category: ( + Literal[ + "company", + "research paper", + "news", + "pdf", + "github", + "tweet", + "personal site", + "linkedin profile", + "financial report", + ] + | None + ) = None + num_results: int = 10 + include_text: list[str] | None = None + exclude_text: list[str] | None = None + use_autoprompt: bool | None = True + text: bool | None = False diff --git a/server/app/model/pay/order.py b/server/app/model/pay/order.py new file mode 100644 index 000000000..00ff0b0d8 --- /dev/null +++ b/server/app/model/pay/order.py @@ -0,0 +1,131 @@ +from enum import Enum +from sqlmodel import Field, Column, SQLModel, SmallInteger, Session +from sqlalchemy import JSON +from app.model.abstract.model import AbstractModel, DefaultTimes +from pydantic import BaseModel + + +class OrderType(int, Enum): + single = 1 # 单次/加量包 + plan = 2 # 套餐订阅 + addon = 3 # 加量包 + other = 99 # 其他 + + +class OrderAddonPrice(str, Enum): + addon_200: 20 + addon_500: 50 + + +class OrderStatus(int, Enum): + pending = 1 # 等待支付 + success = 2 # 支付成功 + cancel = -1 # 放弃支付 + refund = 3 + + +class Order(AbstractModel, DefaultTimes, table=True): + id: int | None = Field(default=None, primary_key=True) + user_id: int = Field(index=True, foreign_key="user.id") + order_type: OrderType = Field(sa_column=Column(SmallInteger)) + price: int = 0 + status: OrderStatus = Field(OrderStatus.pending, sa_column=Column(SmallInteger)) + payment_method: str = Field(default="", max_length=32) + stripe_id: str = Field(max_length=1024) + third_party_id: str | None = Field(default=None, max_length=1024) + plan_id: int | None = Field(default=None, foreign_key="plan.id") + period: str | None = Field(default=None, max_length=16) + buy_type: int | None = Field(default=None) # 仅加量包/次数包 + use_num: int | None = Field(default=None) + left_num: int | None = Field(default=None) + extra: dict = Field(default_factory=dict, sa_column=Column(JSON)) + + def mark_success(self, third_party_id: str | None = None, s: Session = None): + self.status = OrderStatus.success + if third_party_id: + self.third_party_id = third_party_id + if s: + s.add(self) + s.commit() + + def mark_failed(self, s: Session = None): + self.status = OrderStatus.cancel + if s: + s.add(self) + s.commit() + + def mark_pending(self, s: Session = None): + self.status = OrderStatus.pending + if s: + s.add(self) + s.commit() + + @classmethod + def create_addon_order( + cls, user_id: int, buy_type: int, price: int, payment_method: str, stripe_id: str, s: Session + ): + order = cls( + user_id=user_id, + order_type=OrderType.addon, + buy_type=buy_type, + use_num=buy_type, + left_num=buy_type, + price=price, + status=OrderStatus.pending, + payment_method=payment_method, + stripe_id=stripe_id, + ) + s.add(order) + s.commit() + return order + + @classmethod + def create_plan_order( + cls, + user_id: int, + plan_id: int, + period: str, + price: int, + payment_method: str, + stripe_id: str, + plan_name: str, + s: Session, + ): + order = cls( + user_id=user_id, + order_type=OrderType.plan, + plan_id=plan_id, + period=period, + price=price, + status=OrderStatus.pending, + payment_method=payment_method, + stripe_id=stripe_id, + extra={"plan_name": plan_name}, + ) + s.add(order) + s.commit() + return order + + +class PlanPeriod(str, Enum): + monthly = "monthly" + yearly = "yearly" + + +class PlanKey(str, Enum): + plus = "plus" + pro = "pro" + + +class PlanOrderCreate(BaseModel): + plan_key: PlanKey + period: PlanPeriod + + +class AddonPlanKey(str, Enum): + addon_200 = "addon_29.90" + addon_500 = "addon_69.90" + + +class OrderAddonCreate(BaseModel): + plan_key: AddonPlanKey diff --git a/server/app/model/provider/provider.py b/server/app/model/provider/provider.py new file mode 100644 index 000000000..2a1da3cb5 --- /dev/null +++ b/server/app/model/provider/provider.py @@ -0,0 +1,50 @@ +from enum import IntEnum +from typing import Optional +from pydantic import BaseModel +from sqlalchemy import Boolean, Column, SmallInteger, String +from sqlalchemy.orm import Mapped +from sqlmodel import Field, JSON +from sqlalchemy_utils import ChoiceType +from sqlalchemy import text +from app.model.abstract.model import AbstractModel, DefaultTimes + + +class VaildStatus(IntEnum): + not_valid = 1 + is_valid = 2 + + +class Provider(AbstractModel, DefaultTimes, table=True): + id: int = Field(default=None, primary_key=True) + user_id: int = Field(index=True) + provider_name: str + model_type: str + api_key: str + endpoint_url: str = "" + encrypted_config: dict | None = Field(default=None, sa_column=Column(JSON)) + prefer: bool = Field(default=False, sa_column=Column(Boolean, server_default=text("false"))) + is_vaild: VaildStatus = Field( + default=VaildStatus.not_valid, + sa_column=Column(ChoiceType(VaildStatus, SmallInteger()), server_default=text("1")), + ) + + +class ProviderIn(BaseModel): + provider_name: str + model_type: str + api_key: str + endpoint_url: str + encrypted_config: dict | None = None + is_vaild: VaildStatus = VaildStatus.not_valid + prefer: bool = False + + +class ProviderPreferIn(BaseModel): + provider_id: int + + +class ProviderOut(ProviderIn): + id: int + user_id: int + prefer: bool + model_type: Optional[str] = None diff --git a/server/app/model/user/admin.py b/server/app/model/user/admin.py new file mode 100644 index 000000000..06b22720d --- /dev/null +++ b/server/app/model/user/admin.py @@ -0,0 +1,77 @@ +from datetime import datetime +from pydantic import BaseModel, EmailStr, computed_field +from pydash import chain +from sqlmodel import Field, Column, Relationship, SmallInteger +from sqlalchemy.orm import Mapped +from sqlalchemy_utils import ChoiceType +from app.model.abstract.model import AbstractModel, DefaultTimes +from enum import IntEnum +from app.model.user.role import Role, RoleOut +from app.model.user.admin_role import AdminRole + + +class Status(IntEnum): + Normal = 1 + Disable = -1 + + +class Admin(AbstractModel, DefaultTimes, table=True): + id: int = Field(default=None, primary_key=True) + email: EmailStr + password: str + name: str + user_id: int = 0 + status: int = Field(default=1, sa_column=Column(ChoiceType(Status, SmallInteger()))) + + roles: Mapped[list[Role]] = Relationship( + link_model=AdminRole, + sa_relationship_kwargs={ + "primaryjoin": "Admin.id == AdminRole.admin_id", + "secondaryjoin": "AdminRole.role_id == Role.id", + # "collection_class": Collection, + }, + ) + + +class LoginByPasswordIn(BaseModel): + email: EmailStr + password: str + + +class LoginResponse(BaseModel): + token: str + user_id: int + permissions: list[str] + + +class AdminIn(BaseModel): + email: EmailStr + name: str + status: Status + + +class AdminCreate(AdminIn): + password: str + + +class AdminOut(BaseModel): + id: int + email: EmailStr + name: str + status: Status + created_at: datetime + roles: list[RoleOut] + + @computed_field(return_type=list[str]) + def permissions(self): + return chain(self.roles).flat_map(lambda role: role.permissions).value() + + +class UpdatePassword(BaseModel): + password: str + new_password: str + re_new_password: str + + +class SetPassword(BaseModel): + password: str diff --git a/server/app/model/user/admin_role.py b/server/app/model/user/admin_role.py new file mode 100644 index 000000000..90960776c --- /dev/null +++ b/server/app/model/user/admin_role.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel +from sqlmodel import Field +from app.model.abstract.model import AbstractModel + + +class AdminRole(AbstractModel, table=True): + admin_id: int = Field(primary_key=True) + role_id: int = Field(primary_key=True) + + +class AdminRoleIn(BaseModel): + admin_id: int + role_ids: list[int] diff --git a/server/app/model/user/key.py b/server/app/model/user/key.py new file mode 100644 index 000000000..55f9f8323 --- /dev/null +++ b/server/app/model/user/key.py @@ -0,0 +1,37 @@ +from enum import IntEnum, StrEnum +from typing import Optional +from pydantic import BaseModel, computed_field +from sqlmodel import Field, Column, SmallInteger +from sqlalchemy_utils import ChoiceType +from app.component.environment import env_not_empty +from app.model.abstract.model import AbstractModel, DefaultTimes + + +class ModelType(StrEnum): + gpt4_1 = "gpt-4.1" + gpt4_mini = "gpt-4.1-mini" + gemini_2_5_pro = "gemini/gemini-2.5-pro" + gemini_2_5_flash = "gemini-2.5-flash" + + +class KeyStatus(IntEnum): + active = 1 + disabled = -1 + + +class Key(AbstractModel, DefaultTimes, table=True): + id: int | None = Field(default=None, primary_key=True) + user_id: int = Field(foreign_key="user.id", index=True) + value: str = Field(max_length=255, index=True) + inner_key: str = Field(default="", max_length=255) # litellm内部存储的key + status: KeyStatus = Field(sa_column=Column(ChoiceType(KeyStatus, SmallInteger()))) + + +class KeyOut(BaseModel): + warning_code: Optional[str] = None + warning_text: Optional[str] = None + value: str + + @computed_field(return_type=str) + def api_url(self): + return env_not_empty("litellm_url") diff --git a/server/app/model/user/privacy.py b/server/app/model/user/privacy.py new file mode 100644 index 000000000..c749bc526 --- /dev/null +++ b/server/app/model/user/privacy.py @@ -0,0 +1,24 @@ +from datetime import datetime +from enum import IntEnum +from sqlalchemy import JSON, SmallInteger +from sqlalchemy_utils import ChoiceType +from pydantic import BaseModel, EmailStr +from sqlmodel import Field, Column +from app.model.abstract.model import AbstractModel, DefaultTimes + + +class UserPrivacy(AbstractModel, DefaultTimes, table=True): + id: int = Field(default=None, primary_key=True) + user_id: int = Field(unique=True, foreign_key="user.id") + pricacy_setting: dict = Field(default="{}", sa_column=Column(JSON)) + + +class UserPrivacySettings(BaseModel): + take_screenshot: bool | None = False + access_local_software: bool | None = False + access_your_address: bool | None = False + password_storage: bool | None = False + + @classmethod + def default_settings(cls) -> dict: + return cls().model_dump() diff --git a/server/app/model/user/role.py b/server/app/model/user/role.py new file mode 100644 index 000000000..2088a65ed --- /dev/null +++ b/server/app/model/user/role.py @@ -0,0 +1,34 @@ +from datetime import datetime +from enum import Enum +from pydantic import BaseModel +from sqlmodel import Field, Column, SmallInteger, JSON +from app.model.abstract.model import AbstractModel, DefaultTimes +from sqlalchemy_utils import ChoiceType + + +class RoleType(int, Enum): + System = 1 + Custom = 2 + + +class Role(AbstractModel, DefaultTimes, table=True): + id: int = Field(default=None, primary_key=True) + name: str + description: str = "" + type: RoleType = Field(default=RoleType.Custom.value, sa_column=Column(ChoiceType(RoleType, SmallInteger()))) + permissions: list[str] = Field(sa_column=Column(JSON)) + + +class RoleIn(BaseModel): + name: str + description: str = "" + permissions: list[str] + + +class RoleOut(BaseModel): + id: int + name: str + description: str + type: RoleType + permissions: list[str] + created_at: datetime diff --git a/server/app/model/user/user.py b/server/app/model/user/user.py new file mode 100644 index 000000000..3764c31ea --- /dev/null +++ b/server/app/model/user/user.py @@ -0,0 +1,92 @@ +from datetime import datetime, date +from enum import IntEnum +from sqlalchemy import Integer, SmallInteger, text +from sqlalchemy_utils import ChoiceType +from pydantic import BaseModel, EmailStr, field_validator +from sqlmodel import Field, Column +from app.model.abstract.model import AbstractModel, DefaultTimes +from typing import Optional +from app.component.encrypt import password_hash + + +class Status(IntEnum): + Normal = 1 + Block = -1 + + +class User(AbstractModel, DefaultTimes, table=True): + id: int = Field(default=None, primary_key=True) + stack_id: str | None = Field(default=None, unique=True, max_length=255) + username: str | None = Field(default=None, unique=True, max_length=128) + email: EmailStr = Field(unique=True, max_length=128) + password: str | None = Field(default=None, max_length=256) + avatar: str = Field(default="", max_length=256) + nickname: str = Field(default="", max_length=64) + fullname: str = Field(default="", max_length=128) + work_desc: str = Field(default="", max_length=255) + credits: int = Field(default=0, description="credits", sa_column=Column(Integer, server_default=text("0"))) + last_daily_credit_date: date | None = Field(default=None, description="Last date daily credits were granted") + last_monthly_credit_date: date | None = Field(default=None, description="Last month monthly credits were granted") + inviter_user_id: int | None = Field(default=None, foreign_key="user.id", description="Inviter user ID") + status: Status = Field(default=Status.Normal.value, sa_column=Column(ChoiceType(Status, SmallInteger()))) + + +class UserProfile(BaseModel): + fullname: str = "" + nickname: str = "" + work_desc: str = "" + + +class LoginByPasswordIn(BaseModel): + email: EmailStr + password: str + + +class LoginResponse(BaseModel): + token: str + email: EmailStr + + +class UserIn(BaseModel): + username: str + + +class UserCreate(UserIn): + password: str + + +class UserOut(BaseModel): + email: EmailStr + avatar: Optional[str] = "" + username: Optional[str] = "" + nickname: Optional[str] = "" + fullname: Optional[str] = "" + work_desc: Optional[str] = "" + credits: int + status: Status + created_at: datetime + + +class UpdatePassword(BaseModel): + password: str + new_password: str + re_new_password: str + + +class RegisterIn(BaseModel): + email: EmailStr + password: str + invite_code: Optional[str] = None + + @field_validator("password", mode="before") + def password_strength(cls, v): + # At least 8 chars, must contain letters and numbers + if len(v) < 8: + raise ValueError("Password must be at least 8 characters long") + if not any(c.isdigit() for c in v) or not any(c.isalpha() for c in v): + raise ValueError("Password must contain both letters and numbers") + return v + + @field_validator("password", mode="after") + def password_hash(cls, v): + return password_hash(v) diff --git a/server/app/model/user/user_credits.py b/server/app/model/user/user_credits.py new file mode 100644 index 000000000..16cd80dcb --- /dev/null +++ b/server/app/model/user/user_credits.py @@ -0,0 +1,8 @@ +from enum import IntEnum + + +class Channel(IntEnum): + free = 1 + paid = 2 + gift = 3 + top_up = 4 diff --git a/server/app/model/user/user_credits_record.py b/server/app/model/user/user_credits_record.py new file mode 100644 index 000000000..bcdb7750d --- /dev/null +++ b/server/app/model/user/user_credits_record.py @@ -0,0 +1,381 @@ +from enum import IntEnum +from typing import Optional +from pydantic import BaseModel +from sqlmodel import Relationship, SQLModel, Field, Column, col, select, Session +from sqlalchemy_utils import ChoiceType +from sqlalchemy import Boolean, SmallInteger, text +from app.model.abstract.model import AbstractModel, DefaultTimes +from datetime import date, datetime, timedelta +from app.model.user.key import ModelType +from app.component.database import session_make +from loguru import logger + + +class CreditsChannel(IntEnum): + register = 1 # 注册赠送 + invite = 2 # 邀请赠送 + daily = 3 # 每日刷新积分 + monthly = 4 # 每月刷新积分 + paid = 5 # 付费积分 + addon = 6 # 加量包 + consume = 7 # 任务消费 + + +class CreditsPriority(IntEnum): + daily = 1 # 每日刷新积分 + monthly = 2 # 每月刷新积分 + paid = 3 # 付费积分 + addon = 4 # 加量包 + + +class CreditsPoint(IntEnum): + register = 1000 + invite = 500 + special_register = 1500 # 1000 register + 500 invite credit + + +class UserCreditsRecord(AbstractModel, DefaultTimes, table=True): + id: int = Field(default=None, primary_key=True) + user_id: int = Field(index=True, foreign_key="user.id") + invite_by: int = Field(default=None, nullable=True, description="invite by user id") + invite_code: str = Field(default="", max_length=255) + amount: int = Field(default=0) + balance: int = Field(default=0) + channel: CreditsChannel = Field( + default=CreditsChannel.register.value, sa_column=Column(ChoiceType(CreditsChannel, SmallInteger())) + ) + source_id: int = Field(default=0, description="source id") + remark: str = Field(default="", max_length=255) + expire_at: datetime = Field(default=None, nullable=True, description="Expiration time") + used: bool = Field( + default=False, + sa_column=Column(Boolean, server_default=text("false")), + description="Is this record used/expired", + ) + used_at: datetime = Field(default=None, nullable=True, description="Time when this record was used/expired") + + @classmethod + def get_permanent_credits(cls, user_id: int) -> int: + """ + 获取可用的token总量,直接用SQL聚合sum + Returns: + int: 可用的token总量 + """ + session = session_make() + from sqlalchemy import func + + statement = ( + select(func.sum(UserCreditsRecord.amount)) + .where(UserCreditsRecord.user_id == user_id) + .where( + UserCreditsRecord.channel.in_( + [ + CreditsChannel.register, + CreditsChannel.invite, + CreditsChannel.paid, + CreditsChannel.addon, + CreditsChannel.monthly, + ] + ) + ) + .where(UserCreditsRecord.used == False) + .where((UserCreditsRecord.expire_at.is_(None)) | (col(UserCreditsRecord.expire_at) > datetime.now())) + ) + result = session.exec(statement).first() + return result or 0 + + @classmethod + def get_temp_credits(cls, user_id: int) -> tuple[int, date]: + """ + 1. 获取可用的临时token总量,需要通过credits 然后根据model_type来计算 + 2. 每天只允许赠送一次临时的量 + + Returns: + int: 可用的临时token总量 + """ + session = session_make() + statement = ( + select(UserCreditsRecord) + .where(UserCreditsRecord.user_id == user_id) + .where(UserCreditsRecord.channel == CreditsChannel.daily) + .where(UserCreditsRecord.used == False) + .where(UserCreditsRecord.expire_at.is_not(None)) + .where(col(UserCreditsRecord.expire_at) > datetime.now()) + ) + record: UserCreditsRecord = session.exec(statement).first() + if record is None: + return 0, None + return record.amount - record.balance, record.expire_at + + @classmethod + def consume_credits(cls, user_id: int, amount: int, session: Session, source_id: int = 0, remark: str = ""): + """ + 消耗积分,优先消耗每日积分(daily),再消耗monthly、paid、addon等。 + 消耗时更新UserCreditsRecord的balance字段,记录已消耗积分数。 + 同时生成积分消耗记录,更新用户积分credits字段(不包括每日积分)。 + 避免重复生成积分消耗记录和重复扣减积分。 + """ + + # 检查是否已有积分消耗记录 + existing_consume_record = None + if source_id > 0: + existing_consume_record = session.exec( + select(UserCreditsRecord) + .where(UserCreditsRecord.user_id == user_id) + .where(UserCreditsRecord.channel == CreditsChannel.consume) + .where(UserCreditsRecord.source_id == source_id) + ).first() + + if existing_consume_record: + # 如果新amount更大,需要额外消耗积分 + if amount > 0: + existing_consume_record.amount -= amount + session.add(existing_consume_record) + # 直接处理额外的积分消耗,不生成新的消耗记录 + cls._consume_credits_internal_update(user_id, amount, session, source_id, remark) + # 如果新amount更小,需要退还积分(这里可以根据业务需求决定是否实现) + else: + # 暂时不实现退还逻辑,可以根据需要添加 + pass + + session.commit() + return + + # 没有现有记录,执行正常的积分消耗流程 + cls._consume_credits_internal(user_id, amount, session, source_id, remark) + + @classmethod + def _consume_credits_internal( + cls, user_id: int, amount: int, session: Session, source_id: int = 0, remark: str = "" + ): + """ + 内部积分消耗逻辑,处理实际的积分扣减 + """ + from app.model.user.user import User + + remain = amount + now = datetime.now() + consumed_from_daily = 0 + consumed_from_other = 0 + + # 优先消耗daily + statement = ( + select(UserCreditsRecord) + .where(UserCreditsRecord.user_id == user_id) + .where(UserCreditsRecord.channel == CreditsChannel.daily) + .where(UserCreditsRecord.used == False) + .where(UserCreditsRecord.expire_at.is_not(None)) + .where(col(UserCreditsRecord.expire_at) > now) + .order_by(UserCreditsRecord.expire_at) + ) + daily_records = session.exec(statement).first() + if daily_records: + can_consume = daily_records.amount - daily_records.balance + use = min(remain, can_consume) + daily_records.balance += use + session.add(daily_records) + remain -= use + consumed_from_daily = use + if remain == 0: + # 生成积分消耗记录 + consume_record = UserCreditsRecord( + user_id=user_id, + amount=-amount, + channel=CreditsChannel.consume, + source_id=source_id, + remark=remark or f"Consumed {amount} credits (daily: {consumed_from_daily})", + ) + session.add(consume_record) + session.commit() + return + + # 若daily不够,继续消耗monthly/paid/addon + if remain > 0: + statement = ( + select(UserCreditsRecord) + .where(UserCreditsRecord.user_id == user_id) + .where( + UserCreditsRecord.channel.in_( + [ + CreditsChannel.monthly, + CreditsChannel.paid, + CreditsChannel.addon, + CreditsChannel.register, + CreditsChannel.invite, + ] + ) + ) + .where(UserCreditsRecord.used == False) + .where((UserCreditsRecord.expire_at.is_(None)) | (col(UserCreditsRecord.expire_at) > now)) + .order_by(UserCreditsRecord.expire_at) + ) + other_records = session.exec(statement).all() + for record in other_records: + can_consume = record.amount - record.balance + if can_consume <= 0: + continue + use = min(remain, can_consume) + record.balance += use + session.add(record) + remain -= use + consumed_from_other += use + if remain == 0: + break + + # 更新用户积分字段(只扣除非每日积分消耗的部分) + if consumed_from_other > 0: + user = session.exec(select(User).where(User.id == user_id)).first() + if user: + user.credits -= consumed_from_other + session.add(user) + + # 生成积分消耗记录 + consume_record = UserCreditsRecord( + user_id=user_id, + amount=-amount, + channel=CreditsChannel.consume, + source_id=source_id, + remark=remark or f"Consumed {amount} credits (daily: {consumed_from_daily}, other: {consumed_from_other})", + ) + session.add(consume_record) + session.commit() + + if remain > 0: + raise Exception(f"Insufficient credits: need {amount}, remain {remain}") + + @classmethod + def _consume_credits_internal_update( + cls, user_id: int, amount: int, session: Session, source_id: int = 0, remark: str = "" + ): + """ + 内部积分消耗逻辑(更新模式),处理实际的积分扣减但不生成新的消耗记录 + 用于更新现有消耗记录时的额外积分消耗 + """ + from app.model.user.user import User + + remain = amount + now = datetime.now() + consumed_from_daily = 0 + consumed_from_other = 0 + + # 优先消耗daily + statement = ( + select(UserCreditsRecord) + .where(UserCreditsRecord.user_id == user_id) + .where(UserCreditsRecord.channel == CreditsChannel.daily) + .where(UserCreditsRecord.used == False) + .where(UserCreditsRecord.expire_at.is_not(None)) + .where(col(UserCreditsRecord.expire_at) > now) + .order_by(UserCreditsRecord.expire_at) + ) + daily_records = session.exec(statement).first() + if daily_records: + can_consume = daily_records.amount - daily_records.balance + use = min(remain, can_consume) + daily_records.balance += use + session.add(daily_records) + remain -= use + consumed_from_daily = use + if remain == 0: + # 不生成新的消耗记录,只更新现有记录 + return + + # 若daily不够,继续消耗monthly/paid/addon + if remain > 0: + statement = ( + select(UserCreditsRecord) + .where(UserCreditsRecord.user_id == user_id) + .where( + UserCreditsRecord.channel.in_( + [ + CreditsChannel.monthly, + CreditsChannel.paid, + CreditsChannel.addon, + CreditsChannel.register, + CreditsChannel.invite, + ] + ) + ) + .where(UserCreditsRecord.used == False) + .where((UserCreditsRecord.expire_at.is_(None)) | (col(UserCreditsRecord.expire_at) > now)) + .order_by(UserCreditsRecord.expire_at) + ) + other_records = session.exec(statement).all() + for record in other_records: + can_consume = record.amount - record.balance + if can_consume <= 0: + continue + use = min(remain, can_consume) + record.balance += use + session.add(record) + remain -= use + consumed_from_other += use + if remain == 0: + break + logger.info(f"consumed_from_other: {consumed_from_other}") + # 更新用户积分字段(只扣除非每日积分消耗的部分) + if consumed_from_other > 0: + user = session.exec(select(User).where(User.id == user_id)).first() + if user: + user.credits -= consumed_from_other + session.add(user) + + # 不生成新的消耗记录,因为现有记录已经在主函数中更新了 + + if remain > 0: + raise Exception(f"Insufficient credits: need {amount}, remain {remain}") + + @classmethod + def get_daily_balance_sum(cls, user_id: int) -> int: + """ + 获取用户所有每日积分(daily channel)的balance字段之和 + """ + session = session_make() + statement = ( + select(UserCreditsRecord.balance) + .where(UserCreditsRecord.user_id == user_id) + .where(UserCreditsRecord.channel == CreditsChannel.daily) + ) + balances = session.exec(statement).all() + return sum(balances) if balances else 0 + + @classmethod + def get_daily_balance(cls, user_id: int) -> int: + """ + 获取用户当前的每日积分数据 + """ + session = session_make() + statement = ( + select(UserCreditsRecord) + .where(UserCreditsRecord.user_id == user_id) + .where(UserCreditsRecord.channel == CreditsChannel.daily) + .where(UserCreditsRecord.used == False) + ) + record = session.exec(statement).first() + return record + + +class UserCreditsRecordWithChatOut(BaseModel): + """扩展的积分记录输出模型,包含聊天历史信息""" + + amount: int + balance: int + channel: CreditsChannel + source_id: int + expire_at: Optional[datetime] = None + created_at: datetime + updated_at: Optional[datetime] = None + # 聊天历史相关字段(当channel为consume且source_id有效时) + chat_project_name: Optional[str] = None + chat_tokens: Optional[int] = None + + +class UserCreditsRecordOut(BaseModel): + amount: int + balance: int + channel: CreditsChannel + source_id: int + remark: str + expire_at: datetime | None + created_at: datetime + updated_at: datetime | None diff --git a/server/app/model/user/user_stat.py b/server/app/model/user/user_stat.py new file mode 100644 index 000000000..2ec5f7794 --- /dev/null +++ b/server/app/model/user/user_stat.py @@ -0,0 +1,87 @@ +from datetime import datetime +from typing import Optional +from sqlmodel import SQLModel, Field, Column, select +from pydantic import BaseModel +from enum import Enum + +from app.model.abstract.model import AbstractModel, DefaultTimes + + +class UserStatActionEnum(str, Enum): + download_count = "download_count" + register_count = "register_count" + task_complete_count = "task_complete_count" + task_failed_count = "task_failed_count" + file_download_count = "file_download_count" + file_generate_count = "file_generate_count" + paid_amount_on_avg_task = "paid_amount_on_avg_task" + + +class UserStatActionIn(BaseModel): + user_id: int | None = None + action: UserStatActionEnum + value: int = 1 + model_type: str | None = None + + +class UserStat(AbstractModel, DefaultTimes, table=True): + id: int = Field(default=None, primary_key=True) + user_id: int = Field(foreign_key="user.id", index=True, description="User ID") + # Model usage type: 'cloud' or 'local' + model_type: str = Field(default="unused", description="Model usage type: 'cloud' or 'local'") + # Product page statistics + download_count: int = Field(default=0, description="Number of downloads by the user") + register_count: int = Field(default=0, description="Number of registrations (for product page)") + task_complete_count: int = Field(default=0, description="Number of tasks completed by the user") + task_failed_count: int = Field(default=0, description="Number of tasks failed by the user") + file_download_count: int = Field(default=0, description="Number of files downloaded by the user") + file_generate_count: int = Field(default=0, description="Number of files generated by the user") + # Payment statistics + paid_amount_on_avg_task: int = Field(default=0, description="Total paid amount on average task completion") + + @classmethod + def record_action(cls, session, action_in: UserStatActionIn): + """ + Record or update user operation statistics using a Pydantic model. + If no record exists for the user, create one. Otherwise, update the corresponding field. + Supported actions: download_count, register_count, task_complete_count, task_failed_count, file_download_count, file_generate_count, paid_amount_on_avg_task. + If model_type is provided, update it as well. + """ + stat = session.exec(select(cls).where(cls.user_id == action_in.user_id)).first() + if not stat: + stat = cls(user_id=action_in.user_id) + session.add(stat) + if action_in.action in [ + UserStatActionEnum.download_count, + UserStatActionEnum.register_count, + UserStatActionEnum.task_complete_count, + UserStatActionEnum.task_failed_count, + UserStatActionEnum.file_download_count, + UserStatActionEnum.file_generate_count, + ]: + setattr(stat, action_in.action.value, getattr(stat, action_in.action.value, 0) + action_in.value) + elif action_in.action == UserStatActionEnum.paid_amount_on_avg_task: + stat.paid_amount_on_avg_task += action_in.value + else: + raise ValueError(f"Unsupported action: {action_in.action}") + if action_in.model_type is not None: + stat.model_type = action_in.model_type + session.add(stat) + session.commit() + session.refresh(stat) + return stat + + +class UserStatOut(BaseModel): + model_type: str | None = None + download_count: int = 0 + register_count: int = 0 + task_complete_count: int = 0 + task_failed_count: int = 0 + file_download_count: int = 0 + file_generate_count: int = 0 + paid_amount_on_avg_task: int = 0 + # cusotmer + task_queries: int = 0 + mcp_install_count: int = 0 + storage_used: float = 0 diff --git a/server/app/type/config_group.py b/server/app/type/config_group.py new file mode 100644 index 000000000..ba7b66050 --- /dev/null +++ b/server/app/type/config_group.py @@ -0,0 +1,40 @@ +from enum import Enum + + +class ConfigGroup(str, Enum): + WHATSAPP = "WhatsApp" + TWITTER = "X(Twitter)" + LINKEDIN = "LinkedIn" + REDDIT = "Reddit" + SLACK = "Slack" + NOTION = "Notion" + GOOGLE_SUITE = "GoogleSuite" + DISCORD = "Discord" + SEARCH = "Search" + AUDIO_ANALYSIS = "Audio Analysis" + CODE_EXECUTION = "Code Execution" + CRAW4AI = "Craw4ai" + DALLE = "Dalle" + EDGEONE_PAGES_MCP = "Edgeone Pages MCP" + EXCEL = "Excel" + FILE_WRITE = "File Write" + GITHUB = "Github" + GOOGLE_CALENDAR = "Google Calendar" + GOOGLE_DRIVE_MCP = "Google Drive MCP" + GOOGLE_GMAIL_MCP = "Google Gmail MCP" + IMAGE_ANALYSIS = "Image Analysis" + MCP_SEARCH = "MCP Search" + PPTX = "PPTX" + TERMINAL = "Terminal" + + @classmethod + def get_all_values(cls) -> list[str]: + return [group.value for group in cls] + + @classmethod + def is_valid_group(cls, group: str) -> bool: + try: + cls(group) + return True + except ValueError: + return False diff --git a/server/app/type/model_providers.py b/server/app/type/model_providers.py new file mode 100644 index 000000000..2dbd424b4 --- /dev/null +++ b/server/app/type/model_providers.py @@ -0,0 +1,63 @@ +from enum import Enum +from typing import List + + +class ModelProviders(Enum): + OPENAI = "openai" + AWS_BEDROCK = "aws-bedrock" + AZURE = "azure" + ANTHROPIC = "anthropic" + GROQ = "groq" + OPENROUTER = "openrouter" + OLLAMA = "ollama" + LITELLM = "litellm" + LMSTUDIO = "lmstudio" + ZHIPU = "zhipuai" + GEMINI = "gemini" + VLLM = "vllm" + MISTRAL = "mistral" + REKA = "reka" + TOGETHER = "together" + STUB = "stub" + OPENAI_COMPATIBLE_MODEL = "openai-compatible-model" + SAMBA = "samba-nova" + COHERE = "cohere" + YI = "lingyiwanwu" + QWEN = "tongyi-qianwen" + NVIDIA = "nvidia" + DEEPSEEK = "deepseek" + PPIO = "ppio" + SGLANG = "sglang" + INTERNLM = "internlm" + MOONSHOT = "moonshot" + MODELSCOPE = "modelscope" + SILICONFLOW = "siliconflow" + AIML = "aiml" + VOLCANO = "volcano" + NETMIND = "netmind" + NOVITA = "novita" + WATSONX = "watsonx" + + @classmethod + def get_all_values(cls) -> List[str]: + return [platform.value for platform in cls] + + @classmethod + def get_all_names(cls) -> List[str]: + return [platform.name for platform in cls] + + @classmethod + def get_all_items(cls) -> List[dict]: + return [{"name": platform.name, "value": platform.value} for platform in cls] + + @classmethod + def is_valid_platform(cls, platform_name: str) -> bool: + try: + cls(platform_name) + return True + except ValueError: + return False + + @classmethod + def get_platform_by_name(cls, platform_name: str) -> "ModelPlatformType": + return cls(platform_name) diff --git a/server/app/type/model_type.py b/server/app/type/model_type.py new file mode 100644 index 000000000..775661a7d --- /dev/null +++ b/server/app/type/model_type.py @@ -0,0 +1,343 @@ +from enum import Enum + + +class ModelType(Enum): + GPT_3_5_TURBO = "gpt-3.5-turbo" + GPT_4 = "gpt-4" + GPT_4_TURBO = "gpt-4-turbo" + GPT_4O = "gpt-4o" + GPT_4O_MINI = "gpt-4o-mini" + GPT_4_5_PREVIEW = "gpt-4.5-preview" + O1 = "o1" + O1_PREVIEW = "o1-preview" + O1_MINI = "o1-mini" + O3_MINI = "o3-mini" + GPT_4_1 = "gpt-4.1-2025-04-14" + GPT_4_1_MINI = "gpt-4.1-mini-2025-04-14" + GPT_4_1_NANO = "gpt-4.1-nano-2025-04-14" + O4_MINI = "o4-mini" + O3 = "o3" + O3_PRO = "o3-pro" + + AWS_CLAUDE_3_7_SONNET = "anthropic.claude-3-7-sonnet-20250219-v1:0" + AWS_CLAUDE_3_5_SONNET = "anthropic.claude-3-5-sonnet-20241022-v2:0" + AWS_CLAUDE_3_HAIKU = "anthropic.claude-3-haiku-20240307-v1:0" + AWS_CLAUDE_3_SONNET = "anthropic.claude-3-sonnet-20240229-v1:0" + AWS_DEEPSEEK_R1 = "us.deepseek.r1-v1:0" + AWS_LLAMA_3_3_70B_INSTRUCT = "us.meta.llama3-3-70b-instruct-v1:0" + AWS_LLAMA_3_2_90B_INSTRUCT = "us.meta.llama3-2-90b-instruct-v1:0" + AWS_LLAMA_3_2_11B_INSTRUCT = "us.meta.llama3-2-11b-instruct-v1:0" + AWS_CLAUDE_SONNET_4 = "anthropic.claude-sonnet-4-20250514-v1:0" + AWS_CLAUDE_OPUS_4 = "anthropic.claude-opus-4-20250514-v1:0" + + GLM_4 = "glm-4" + GLM_4V = "glm-4v" + GLM_4V_FLASH = "glm-4v-flash" + GLM_4V_PLUS_0111 = "glm-4v-plus-0111" + GLM_4_PLUS = "glm-4-plus" + GLM_4_AIR = "glm-4-air" + GLM_4_AIR_0111 = "glm-4-air-0111" + GLM_4_AIRX = "glm-4-airx" + GLM_4_LONG = "glm-4-long" + GLM_4_FLASHX = "glm-4-flashx" + GLM_4_FLASH = "glm-4-flash" + GLM_ZERO_PREVIEW = "glm-zero-preview" + GLM_3_TURBO = "glm-3-turbo" + + # Groq platform models + GROQ_LLAMA_3_1_8B = "llama-3.1-8b-instant" + GROQ_LLAMA_3_3_70B = "llama-3.3-70b-versatile" + GROQ_LLAMA_3_3_70B_PREVIEW = "llama-3.3-70b-specdec" + GROQ_LLAMA_3_8B = "llama3-8b-8192" + GROQ_LLAMA_3_70B = "llama3-70b-8192" + GROQ_MIXTRAL_8_7B = "mixtral-8x7b-32768" + GROQ_GEMMA_2_9B_IT = "gemma2-9b-it" + + # OpenRouter models + OPENROUTER_LLAMA_3_1_405B = "meta-llama/llama-3.1-405b-instruct" + OPENROUTER_LLAMA_3_1_70B = "meta-llama/llama-3.1-70b-instruct" + OPENROUTER_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick" + OPENROUTER_LLAMA_4_MAVERICK_FREE = "meta-llama/llama-4-maverick:free" + OPENROUTER_LLAMA_4_SCOUT = "meta-llama/llama-4-scout" + OPENROUTER_LLAMA_4_SCOUT_FREE = "meta-llama/llama-4-scout:free" + OPENROUTER_OLYMPICODER_7B = "open-r1/olympiccoder-7b:free" + + # LMStudio models + LMSTUDIO_GEMMA_3_1B = "gemma-3-1b" + LMSTUDIO_GEMMA_3_4B = "gemma-3-4b" + LMSTUDIO_GEMMA_3_12B = "gemma-3-12b" + LMSTUDIO_GEMMA_3_27B = "gemma-3-27b" + + # TogetherAI platform models support tool calling + TOGETHER_LLAMA_3_1_8B = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo" + TOGETHER_LLAMA_3_1_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" + TOGETHER_LLAMA_3_1_405B = "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo" + TOGETHER_LLAMA_3_3_70B = "meta-llama/Llama-3.3-70B-Instruct-Turbo" + TOGETHER_MIXTRAL_8_7B = "mistralai/Mixtral-8x7B-Instruct-v0.1" + TOGETHER_MISTRAL_7B = "mistralai/Mistral-7B-Instruct-v0.1" + TOGETHER_LLAMA_4_MAVERICK = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8" + TOGETHER_LLAMA_4_SCOUT = "meta-llama/Llama-4-Scout-17B-16E-Instruct" + + # PPIO platform models support tool calling + PPIO_DEEPSEEK_PROVER_V2_671B = "deepseek/deepseek-prover-v2-671b" + PPIO_DEEPSEEK_R1_TURBO = "deepseek/deepseek-r1-turbo" + PPIO_DEEPSEEK_V3_TURBO = "deepseek/deepseek-v3-turbo" + PPIO_DEEPSEEK_R1_COMMUNITY = "deepseek/deepseek-r1/community" + PPIO_DEEPSEEK_V3_COMMUNITY = "deepseek/deepseek-v3/community" + PPIO_DEEPSEEK_R1 = "deepseek/deepseek-r1" + PPIO_DEEPSEEK_V3 = "deepseek/deepseek-v3" + PPIO_QWEN_2_5_72B = "qwen/qwen-2.5-72b-instruct" + PPIO_BAICHUAN_2_13B_CHAT = "baichuan/baichuan2-13b-chat" + PPIO_LLAMA_3_3_70B = "meta-llama/llama-3.3-70b-instruct" + PPIO_LLAMA_3_1_70B = "meta-llama/llama-3.1-70b-instruct" + PPIO_YI_1_5_34B_CHAT = "01-ai/yi-1.5-34b-chat" + + # SambaNova Cloud platform models support tool calling + SAMBA_LLAMA_3_1_8B = "Meta-Llama-3.1-8B-Instruct" + SAMBA_LLAMA_3_1_70B = "Meta-Llama-3.1-70B-Instruct" + SAMBA_LLAMA_3_1_405B = "Meta-Llama-3.1-405B-Instruct" + + # SGLang models support tool calling + SGLANG_LLAMA_3_1_8B = "meta-llama/Meta-Llama-3.1-8B-Instruct" + SGLANG_LLAMA_3_1_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct" + SGLANG_LLAMA_3_1_405B = "meta-llama/Meta-Llama-3.1-405B-Instruct" + SGLANG_LLAMA_3_2_1B = "meta-llama/Llama-3.2-1B-Instruct" + SGLANG_MIXTRAL_NEMO = "mistralai/Mistral-Nemo-Instruct-2407" + SGLANG_MISTRAL_7B = "mistralai/Mistral-7B-Instruct-v0.3" + SGLANG_QWEN_2_5_7B = "Qwen/Qwen2.5-7B-Instruct" + SGLANG_QWEN_2_5_32B = "Qwen/Qwen2.5-32B-Instruct" + SGLANG_QWEN_2_5_72B = "Qwen/Qwen2.5-72B-Instruct" + + STUB = "stub" + + # Legacy anthropic models + # NOTE: anthropic legacy models only Claude 2.1 has system prompt support + CLAUDE_2_1 = "claude-2.1" + CLAUDE_2_0 = "claude-2.0" + CLAUDE_INSTANT_1_2 = "claude-instant-1.2" + + # Claude models + CLAUDE_3_OPUS = "claude-3-opus-latest" + CLAUDE_3_SONNET = "claude-3-sonnet-20240229" + CLAUDE_3_HAIKU = "claude-3-haiku-20240307" + CLAUDE_3_5_SONNET = "claude-3-5-sonnet-latest" + CLAUDE_3_5_HAIKU = "claude-3-5-haiku-latest" + CLAUDE_3_7_SONNET = "claude-3-7-sonnet-latest" + CLAUDE_SONNET_4 = "claude-sonnet-4-20250514" + CLAUDE_OPUS_4 = "claude-opus-4-20250514" + + # Netmind models + NETMIND_LLAMA_4_MAVERICK_17B_128E_INSTRUCT = "meta-llama/Llama-4-Maverick-17B-128E-Instruct" + NETMIND_LLAMA_4_SCOUT_17B_16E_INSTRUCT = "meta-llama/Llama-4-Scout-17B-16E-Instruct" + NETMIND_DEEPSEEK_R1 = "deepseek-ai/DeepSeek-R1" + NETMIND_DEEPSEEK_V3 = "deepseek-ai/DeepSeek-V3-0324" + NETMIND_DOUBAO_1_5_PRO = "doubao/Doubao-1.5-pro" + NETMIND_QWQ_32B = "Qwen/QwQ-32B" + + # Nvidia models + NVIDIA_NEMOTRON_340B_INSTRUCT = "nvidia/nemotron-4-340b-instruct" + NVIDIA_NEMOTRON_340B_REWARD = "nvidia/nemotron-4-340b-reward" + NVIDIA_YI_LARGE = "01-ai/yi-large" + NVIDIA_MISTRAL_LARGE = "mistralai/mistral-large" + NVIDIA_MIXTRAL_8X7B = "mistralai/mixtral-8x7b-instruct" + NVIDIA_LLAMA3_70B = "meta/llama3-70b" + NVIDIA_LLAMA3_1_8B_INSTRUCT = "meta/llama-3.1-8b-instruct" + NVIDIA_LLAMA3_1_70B_INSTRUCT = "meta/llama-3.1-70b-instruct" + NVIDIA_LLAMA3_1_405B_INSTRUCT = "meta/llama-3.1-405b-instruct" + NVIDIA_LLAMA3_2_1B_INSTRUCT = "meta/llama-3.2-1b-instruct" + NVIDIA_LLAMA3_2_3B_INSTRUCT = "meta/llama-3.2-3b-instruct" + NVIDIA_LLAMA3_3_70B_INSTRUCT = "meta/llama-3.3-70b-instruct" + + # Gemini models + GEMINI_2_5_FLASH_PREVIEW = "gemini-2.5-flash-preview-04-17" + GEMINI_2_5_PRO_PREVIEW = "gemini-2.5-pro-preview-06-05" + GEMINI_2_0_FLASH = "gemini-2.0-flash" + GEMINI_2_0_FLASH_EXP = "gemini-2.0-flash-exp" + GEMINI_2_0_FLASH_THINKING = "gemini-2.0-flash-thinking-exp" + GEMINI_2_0_PRO_EXP = "gemini-2.0-pro-exp-02-05" + GEMINI_2_0_FLASH_LITE = "gemini-2.0-flash-lite" + GEMINI_2_0_FLASH_LITE_PREVIEW = "gemini-2.0-flash-lite-preview-02-05" + GEMINI_1_5_FLASH = "gemini-1.5-flash" + GEMINI_1_5_PRO = "gemini-1.5-pro" + + # Mistral AI models + MISTRAL_3B = "ministral-3b-latest" + MISTRAL_7B = "open-mistral-7b" + MISTRAL_8B = "ministral-8b-latest" + MISTRAL_CODESTRAL = "codestral-latest" + MISTRAL_CODESTRAL_MAMBA = "open-codestral-mamba" + MISTRAL_LARGE = "mistral-large-latest" + MISTRAL_MIXTRAL_8x7B = "open-mixtral-8x7b" + MISTRAL_MIXTRAL_8x22B = "open-mixtral-8x22b" + MISTRAL_NEMO = "open-mistral-nemo" + MISTRAL_PIXTRAL_12B = "pixtral-12b-2409" + MISTRAL_MEDIUM_3 = "mistral-medium-latest" + MAGISTRAL_MEDIUM = "magistral-medium-2506" + + # Reka models + REKA_CORE = "reka-core" + REKA_FLASH = "reka-flash" + REKA_EDGE = "reka-edge" + + # Cohere models + COHERE_COMMAND_R_PLUS = "command-r-plus" + COHERE_COMMAND_R = "command-r" + COHERE_COMMAND_LIGHT = "command-light" + COHERE_COMMAND = "command" + COHERE_COMMAND_NIGHTLY = "command-nightly" + + # Qwen models (Aliyun) + QWEN_MAX = "qwen-max" + QWEN_PLUS = "qwen-plus" + QWEN_TURBO = "qwen-turbo" + QWEN_PLUS_LATEST = "qwen-plus-latest" + QWEN_PLUS_2025_04_28 = "qwen-plus-2025-04-28" + QWEN_TURBO_LATEST = "qwen-turbo-latest" + QWEN_TURBO_2025_04_28 = "qwen-turbo-2025-04-28" + QWEN_LONG = "qwen-long" + QWEN_VL_MAX = "qwen-vl-max" + QWEN_VL_PLUS = "qwen-vl-plus" + QWEN_MATH_PLUS = "qwen-math-plus" + QWEN_MATH_TURBO = "qwen-math-turbo" + QWEN_CODER_TURBO = "qwen-coder-turbo" + QWEN_2_5_CODER_32B = "qwen2.5-coder-32b-instruct" + QWEN_2_5_VL_72B = "qwen2.5-vl-72b-instruct" + QWEN_2_5_72B = "qwen2.5-72b-instruct" + QWEN_2_5_32B = "qwen2.5-32b-instruct" + QWEN_2_5_14B = "qwen2.5-14b-instruct" + QWEN_QWQ_32B = "qwq-32b-preview" + QWEN_QVQ_72B = "qvq-72b-preview" + QWEN_QWQ_PLUS = "qwq-plus" + + # Yi models (01-ai) + YI_LIGHTNING = "yi-lightning" + YI_LARGE = "yi-large" + YI_MEDIUM = "yi-medium" + YI_LARGE_TURBO = "yi-large-turbo" + YI_VISION = "yi-vision" + YI_MEDIUM_200K = "yi-medium-200k" + YI_SPARK = "yi-spark" + YI_LARGE_RAG = "yi-large-rag" + YI_LARGE_FC = "yi-large-fc" + + # DeepSeek models + DEEPSEEK_CHAT = "deepseek-chat" + DEEPSEEK_REASONER = "deepseek-reasoner" + # InternLM models + INTERNLM3_LATEST = "internlm3-latest" + INTERNLM3_8B_INSTRUCT = "internlm3-8b-instruct" + INTERNLM2_5_LATEST = "internlm2.5-latest" + INTERNLM2_PRO_CHAT = "internlm2-pro-chat" + + # Moonshot models + MOONSHOT_V1_8K = "moonshot-v1-8k" + MOONSHOT_V1_32K = "moonshot-v1-32k" + MOONSHOT_V1_128K = "moonshot-v1-128k" + + # SiliconFlow models support tool calling + SILICONFLOW_DEEPSEEK_V2_5 = "deepseek-ai/DeepSeek-V2.5" + SILICONFLOW_DEEPSEEK_V3 = "deepseek-ai/DeepSeek-V3" + SILICONFLOW_INTERN_LM2_5_20B_CHAT = "internlm/internlm2_5-20b-chat" + SILICONFLOW_INTERN_LM2_5_7B_CHAT = "internlm/internlm2_5-7b-chat" + SILICONFLOW_PRO_INTERN_LM2_5_7B_CHAT = "Pro/internlm/internlm2_5-7b-chat" + SILICONFLOW_QWEN2_5_72B_INSTRUCT = "Qwen/Qwen2.5-72B-Instruct" + SILICONFLOW_QWEN2_5_32B_INSTRUCT = "Qwen/Qwen2.5-32B-Instruct" + SILICONFLOW_QWEN2_5_14B_INSTRUCT = "Qwen/Qwen2.5-14B-Instruct" + SILICONFLOW_QWEN2_5_7B_INSTRUCT = "Qwen/Qwen2.5-7B-Instruct" + SILICONFLOW_PRO_QWEN2_5_7B_INSTRUCT = "Pro/Qwen/Qwen2.5-7B-Instruct" + SILICONFLOW_THUDM_GLM_4_9B_CHAT = "THUDM/glm-4-9b-chat" + SILICONFLOW_PRO_THUDM_GLM_4_9B_CHAT = "Pro/THUDM/glm-4-9b-chat" + + # AIML models support tool calling + AIML_MIXTRAL_8X7B = "mistralai/Mixtral-8x7B-Instruct-v0.1" + AIML_MISTRAL_7B_INSTRUCT = "mistralai/Mistral-7B-Instruct-v0.1" + + # Novita platform models support tool calling + NOVITA_LLAMA_4_MAVERICK_17B = "meta-llama/llama-4-maverick-17b-128e-instruct-fp8" + NOVITA_LLAMA_4_SCOUT_17B = "meta-llama/llama-4-scout-17b-16e-instruct" + NOVITA_DEEPSEEK_V3_0324 = "deepseek/deepseek-v3-0324" + NOVITA_QWEN_2_5_V1_72B = "qwen/qwen2.5-vl-72b-instruct" + NOVITA_DEEPSEEK_V3_TURBO = "deepseek/deepseek-v3-turbo" + NOVITA_DEEPSEEK_R1_TURBO = "deepseek/deepseek-r1-turbo" + NOVITA_GEMMA_3_27B_IT = "google/gemma-3-27b-it" + NOVITA_QWEN_32B = "qwen/qwq-32b" + NOVITA_L3_8B_STHENO_V3_2 = "Sao10K/L3-8B-Stheno-v3.2" + NOVITA_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b" + NOVITA_DEEPSEEK_R1_DISTILL_LLAMA_8B = "deepseek/deepseek-r1-distill-llama-8b" + NOVITA_DEEPSEEK_V3 = "deepseek/deepseek_v3" + NOVITA_LLAMA_3_1_8B = "meta-llama/llama-3.1-8b-instruct" + NOVITA_DEEPSEEK_R1_DISTILL_QWEN_14B = "deepseek/deepseek-r1-distill-qwen-14b" + NOVITA_LLAMA_3_3_70B = "meta-llama/llama-3.3-70b-instruct" + NOVITA_QWEN_2_5_72B = "qwen/qwen-2.5-72b-instruct" + NOVITA_MISTRAL_NEMO = "mistralai/mistral-nemo" + NOVITA_DEEPSEEK_R1_DISTILL_QWEN_32B = "deepseek/deepseek-r1-distill-qwen-32b" + NOVITA_LLAMA_3_8B = "meta-llama/llama-3-8b-instruct" + NOVITA_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b" + NOVITA_DEEPSEEK_R1_DISTILL_LLAMA_70B = "deepseek/deepseek-r1-distill-llama-70b" + NOVITA_LLAMA_3_1_70B = "meta-llama/llama-3.1-70b-instruct" + NOVITA_GEMMA_2_9B_IT = "google/gemma-2-9b-it" + NOVITA_MISTRAL_7B = "mistralai/mistral-7b-instruct" + NOVITA_LLAMA_3_70B = "meta-llama/llama-3-70b-instruct" + NOVITA_DEEPSEEK_R1 = "deepseek/deepseek-r1" + NOVITA_HERMES_2_PRO_LLAMA_3_8B = "nousresearch/hermes-2-pro-llama-3-8b" + NOVITA_L3_70B_EURYALE_V2_1 = "sao10k/l3-70b-euryale-v2.1" + NOVITA_DOLPHIN_MIXTRAL_8X22B = "cognitivecomputations/dolphin-mixtral-8x22b" + NOVITA_AIROBOROS_L2_70B = "jondurbin/airoboros-l2-70b" + NOVITA_MIDNIGHT_ROSE_70B = "sophosympatheia/midnight-rose-70b" + NOVITA_L3_8B_LUNARIS = "sao10k/l3-8b-lunaris" + NOVITA_GLM_4_9B_0414 = "thudm/glm-4-9b-0414" + NOVITA_GLM_Z1_9B_0414 = "thudm/glm-z1-9b-0414" + NOVITA_GLM_Z1_32B_0414 = "thudm/glm-z1-32b-0414" + NOVITA_GLM_4_32B_0414 = "thudm/glm-4-32b-0414" + NOVITA_GLM_Z1_RUMINATION_32B_0414 = "thudm/glm-z1-rumination-32b-0414" + NOVITA_QWEN_2_5_7B = "qwen/qwen2.5-7b-instruct" + NOVITA_LLAMA_3_2_1B = "meta-llama/llama-3.2-1b-instruct" + NOVITA_LLAMA_3_2_11B_VISION = "meta-llama/llama-3.2-11b-vision-instruct" + NOVITA_LLAMA_3_2_3B = "meta-llama/llama-3.2-3b-instruct" + NOVITA_LLAMA_3_1_8B_BF16 = "meta-llama/llama-3.1-8b-instruct-bf16" + NOVITA_L31_70B_EURYALE_V2_2 = "sao10k/l31-70b-euryale-v2.2" + + # ModelScope models support tool calling + MODELSCOPE_QWEN_2_5_7B_INSTRUCT = "Qwen/Qwen2.5-7B-Instruct" + MODELSCOPE_QWEN_2_5_14B_INSTRUCT = "Qwen/Qwen2.5-14B-Instruct" + MODELSCOPE_QWEN_2_5_32B_INSTRUCT = "Qwen/Qwen2.5-32B-Instruct" + MODELSCOPE_QWEN_2_5_72B_INSTRUCT = "Qwen/Qwen2.5-72B-Instruct" + MODELSCOPE_QWEN_2_5_CODER_7B_INSTRUCT = "Qwen/Qwen2.5-Coder-7B-Instruct" + MODELSCOPE_QWEN_2_5_CODER_14B_INSTRUCT = "Qwen/Qwen2.5-Coder-14B-Instruct" + MODELSCOPE_QWEN_2_5_CODER_32B_INSTRUCT = "Qwen/Qwen2.5-Coder-32B-Instruct" + MODELSCOPE_QWEN_3_235B_A22B = "Qwen/Qwen3-235B-A22B" + MODELSCOPE_QWEN_3_32B = "Qwen/Qwen3-32B" + MODELSCOPE_QWQ_32B = "Qwen/QwQ-32B" + MODELSCOPE_QWQ_32B_PREVIEW = "Qwen/QwQ-32B-Preview" + MODELSCOPE_LLAMA_3_1_8B_INSTRUCT = "LLM-Research/Meta-Llama-3.1-8B-Instruct" + MODELSCOPE_LLAMA_3_1_70B_INSTRUCT = "LLM-Research/Meta-Llama-3.1-70B-Instruct" + MODELSCOPE_LLAMA_3_1_405B_INSTRUCT = "LLM-Research/Meta-Llama-3.1-405B-Instruct" + MODELSCOPE_LLAMA_3_3_70B_INSTRUCT = "LLM-Research/Llama-3.3-70B-Instruct" + MODELSCOPE_MINISTRAL_8B_INSTRUCT = "mistralai/Ministral-8B-Instruct-2410" + MODELSCOPE_DEEPSEEK_V3_0324 = "deepseek-ai/DeepSeek-V3-0324" + + # WatsonX models + WATSONX_GRANITE_3_8B_INSTRUCT = "ibm/granite-3-8b-instruct" + WATSONX_LLAMA_3_3_70B_INSTRUCT = "meta-llama/llama-3-3-70b-instruct" + WATSONX_LLAMA_3_2_1B_INSTRUCT = "meta-llama/llama-3-2-1b-instruct" + WATSONX_LLAMA_3_2_3B_INSTRUCT = "meta-llama/llama-3-2-3b-instruct" + WATSONX_LLAMA_3_2_11B_VISION_INSTRUCT = "meta-llama/llama-3-2-11b-vision-instruct" + WATSONX_LLAMA_3_2_90B_VISION_INSTRUCT = "meta-llama/llama-3-2-90b-vision-instruct" + WATSONX_LLAMA_GUARD_3_11B_VISION_INSTRUCT = "meta-llama/llama-guard-3-11b-vision-instruct" + WATSONX_MISTRAL_LARGE = "mistralai/mistral-large" + + # Crynux models + CRYNUX_DEEPSEEK_R1_DISTILL_QWEN_1_5B = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" + CRYNUX_DEEPSEEK_R1_DISTILL_QWEN_7B = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B" + CRYNUX_DEEPSEEK_R1_DISTILL_LLAMA_8B = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" + + CRYNUX_QWEN_3_4_B = "Qwen/Qwen3-4B" + CRYNUX_QWEN_3_8_B = "Qwen/Qwen3-8B" + CRYNUX_QWEN_2_5_7B = "Qwen/Qwen2.5-7B" + CRYNUX_QWEN_2_5_7B_INSTRUCT = "Qwen/Qwen2.5-7B-Instruct" + + CRYNUX_NOUS_HERMES_3_LLAMA_3_1_8B = "NousResearch/Hermes-3-Llama-3.1-8B" + CRYNUX_NOUS_HERMES_3_LLAMA_3_2_3B = "NousResearch/Hermes-3-Llama-3.2-3B" + + def __str__(self): + return self.value diff --git a/server/app/type/pydantic.py b/server/app/type/pydantic.py new file mode 100644 index 000000000..b1023c8c8 --- /dev/null +++ b/server/app/type/pydantic.py @@ -0,0 +1,5 @@ +from typing import Annotated, Literal +from pydantic import HttpUrl +from pydantic.functional_serializers import PlainSerializer + +HttpUrlStr = Annotated[HttpUrl | Literal[""], PlainSerializer(str)] diff --git a/server/babel.cfg b/server/babel.cfg new file mode 100644 index 000000000..1d15bb366 --- /dev/null +++ b/server/babel.cfg @@ -0,0 +1 @@ +[python: **.py] \ No newline at end of file diff --git a/server/cli.py b/server/cli.py new file mode 100644 index 000000000..63255ffda --- /dev/null +++ b/server/cli.py @@ -0,0 +1,9 @@ +from app.component.environment import auto_import +from app.command import cli +from app.model.mcp.mcp_user import McpUser + +auto_import("app.command") + + +if __name__ == "__main__": + cli() diff --git a/server/docker-compose.yml b/server/docker-compose.yml new file mode 100644 index 000000000..124d3d93f --- /dev/null +++ b/server/docker-compose.yml @@ -0,0 +1,65 @@ +version: '3.8' + +services: + # PostgreSQL Database + postgres: + image: postgres:15 + container_name: eigent_postgres + restart: unless-stopped + environment: + POSTGRES_DB: eigent + POSTGRES_USER: postgres + POSTGRES_PASSWORD: 123456 + POSTGRES_INITDB_ARGS: "--encoding=UTF-8 --lc-collate=C --lc-ctype=C" + ports: + - "5432:5432" + volumes: + - postgres_data:/var/lib/postgresql/data + #- ./init-db.sql:/docker-entrypoint-initdb.d/init-db.sql:ro + networks: + - eigent_network + healthcheck: + test: [ "CMD-SHELL", "pg_isready -U postgres -d eigent" ] + interval: 10s + timeout: 5s + retries: 5 + + # FastAPI Application + api: + build: + context: . + dockerfile: Dockerfile + args: + database_url: postgresql://postgres:123456@postgres:5432/eigent + container_name: eigent_api + restart: unless-stopped + ports: + - "3001:5678" + environment: + - DATABASE_URL=postgresql://postgres:123456@postgres:5432/eigent + - ENVIRONMENT=production + - DEBUG=false + # volumes: + # - ./app:/app/app + # - ./alembic:/app/alembic + # - ./lang:/app/lang + # - ./public:/app/public + depends_on: + postgres: + condition: service_healthy + networks: + - eigent_network + healthcheck: + test: [ "CMD", "curl", "-f", "http://localhost:5678/health" ] + interval: 30s + timeout: 10s + retries: 3 + start_period: 40s + +volumes: + postgres_data: + driver: local + +networks: + eigent_network: + driver: bridge diff --git a/server/lang/zh_CN/LC_MESSAGES/messages.po b/server/lang/zh_CN/LC_MESSAGES/messages.po new file mode 100644 index 000000000..9c5f8ba14 --- /dev/null +++ b/server/lang/zh_CN/LC_MESSAGES/messages.po @@ -0,0 +1,216 @@ +# Chinese (Simplified, China) translations for PROJECT. +# Copyright (C) 2025 ORGANIZATION +# This file is distributed under the same license as the PROJECT project. +# FIRST AUTHOR , 2025. +# +msgid "" +msgstr "" +"Project-Id-Version: PROJECT VERSION\n" +"Report-Msgid-Bugs-To: EMAIL@ADDRESS\n" +"POT-Creation-Date: 2025-08-06 09:56+0800\n" +"PO-Revision-Date: 2025-08-06 09:56+0800\n" +"Last-Translator: FULL NAME \n" +"Language: zh_Hans_CN\n" +"Language-Team: zh_Hans_CN \n" +"Plural-Forms: nplurals=1; plural=0;\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.17.0\n" + +#: app/component/auth.py:41 +msgid "Validate credentials expired" +msgstr "" + +#: app/component/auth.py:43 +msgid "Could not validate credentials" +msgstr "" + +#: app/component/permission.py:12 +msgid "User" +msgstr "" + +#: app/component/permission.py:13 +msgid "User manger" +msgstr "" + +#: app/component/permission.py:17 +msgid "User Manage" +msgstr "" + +#: app/component/permission.py:18 +msgid "View users" +msgstr "" + +#: app/component/permission.py:22 +msgid "User Edit" +msgstr "" + +#: app/component/permission.py:23 +msgid "Manage users" +msgstr "" + +#: app/component/permission.py:28 +msgid "Admin" +msgstr "" + +#: app/component/permission.py:29 +msgid "Admin manger" +msgstr "" + +#: app/component/permission.py:33 +msgid "Admin View" +msgstr "" + +#: app/component/permission.py:34 +msgid "View admins" +msgstr "" + +#: app/component/permission.py:38 +msgid "Admin Edit" +msgstr "" + +#: app/component/permission.py:39 +msgid "Edit admins" +msgstr "" + +#: app/component/permission.py:44 +msgid "Role" +msgstr "" + +#: app/component/permission.py:45 +msgid "Role manger" +msgstr "" + +#: app/component/permission.py:49 +msgid "Role View" +msgstr "" + +#: app/component/permission.py:50 +msgid "View roles" +msgstr "" + +#: app/component/permission.py:54 +msgid "Role Edit" +msgstr "" + +#: app/component/permission.py:55 +msgid "Edit roles" +msgstr "" + +#: app/component/permission.py:60 +msgid "Mcp" +msgstr "" + +#: app/component/permission.py:61 +msgid "Mcp manger" +msgstr "" + +#: app/component/permission.py:65 +msgid "Mcp Edit" +msgstr "" + +#: app/component/permission.py:66 +msgid "Edit mcp service" +msgstr "" + +#: app/component/permission.py:70 +msgid "Mcp Category Edit" +msgstr "" + +#: app/component/permission.py:71 +msgid "Edit mcp category" +msgstr "" + +#: app/controller/chat/snapshot_controller.py:34 +#: app/controller/chat/snapshot_controller.py:68 +#: app/controller/chat/snapshot_controller.py:81 +msgid "Chat snapshot not found" +msgstr "" + +#: app/controller/chat/step_controller.py:65 +#: app/controller/chat/step_controller.py:89 +#: app/controller/chat/step_controller.py:102 +msgid "Chat step not found" +msgstr "" + +#: app/controller/config/config_controller.py:40 +#: app/controller/config/config_controller.py:76 +#: app/controller/config/config_controller.py:108 +msgid "Configuration not found" +msgstr "" + +#: app/controller/config/config_controller.py:47 +msgid "Config Name is valid" +msgstr "" + +#: app/controller/config/config_controller.py:55 +#: app/controller/config/config_controller.py:92 +msgid "Configuration already exists for this user" +msgstr "" + +#: app/controller/config/config_controller.py:80 +msgid "Invalid configuration group" +msgstr "" + +#: app/controller/mcp/mcp_controller.py:70 +msgid "Mcp not found" +msgstr "" + +#: app/controller/mcp/mcp_controller.py:73 +#: app/controller/mcp/user_controller.py:44 +msgid "mcp is installed" +msgstr "" + +#: app/controller/mcp/user_controller.py:34 +msgid "McpUser not found" +msgstr "" + +#: app/controller/mcp/user_controller.py:61 +#: app/controller/mcp/user_controller.py:75 +msgid "Mcp Info not found" +msgstr "" + +#: app/controller/mcp/user_controller.py:63 +msgid "current user have no permission to modify" +msgstr "" + +#: app/controller/provider/provider_controller.py:41 +#: app/controller/provider/provider_controller.py:60 +#: app/controller/provider/provider_controller.py:79 +msgid "Provider not found" +msgstr "" + +#: app/controller/user/login_controller.py:25 +msgid "Account or password error" +msgstr "" + +#: app/controller/user/login_controller.py:47 +msgid "User not found" +msgstr "" + +#: app/controller/user/login_controller.py:64 +#: app/controller/user/login_controller.py:89 +msgid "Failed to register" +msgstr "" + +#: app/controller/user/login_controller.py:67 +msgid "Your account has been blocked." +msgstr "" + +#: app/controller/user/login_controller.py:75 +msgid "Email already registered" +msgstr "" + +#: app/controller/user/user_password_controller.py:19 +msgid "Password is incorrect" +msgstr "" + +#: app/controller/user/user_password_controller.py:21 +msgid "The two passwords do not match" +msgstr "" + +#: app/model/abstract/model.py:66 +msgid "There is no data that meets the conditions" +msgstr "" + diff --git a/server/main.py b/server/main.py new file mode 100644 index 000000000..7a0c1ceba --- /dev/null +++ b/server/main.py @@ -0,0 +1,18 @@ +from app import api +from app.component.environment import auto_include_routers, env +from loguru import logger +import os +from fastapi.staticfiles import StaticFiles + +prefix = env("url_prefix", "") +auto_include_routers(api, prefix, "app/controller") +public_dir = os.environ.get("PUBLIC_DIR") or os.path.join(os.path.dirname(__file__), "app", "public") +api.mount("/public", StaticFiles(directory=public_dir), name="public") + +logger.add( + "runtime/log/app.log", + rotation="10 MB", + retention="10 days", + level="DEBUG", + enqueue=True, +) diff --git a/server/messages.pot b/server/messages.pot new file mode 100644 index 000000000..d86ee0bed --- /dev/null +++ b/server/messages.pot @@ -0,0 +1,215 @@ +# Translations template for PROJECT. +# Copyright (C) 2025 ORGANIZATION +# This file is distributed under the same license as the PROJECT project. +# FIRST AUTHOR , 2025. +# +#, fuzzy +msgid "" +msgstr "" +"Project-Id-Version: PROJECT VERSION\n" +"Report-Msgid-Bugs-To: EMAIL@ADDRESS\n" +"POT-Creation-Date: 2025-08-06 09:56+0800\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language-Team: LANGUAGE \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.17.0\n" + +#: app/component/auth.py:41 +msgid "Validate credentials expired" +msgstr "" + +#: app/component/auth.py:43 +msgid "Could not validate credentials" +msgstr "" + +#: app/component/permission.py:12 +msgid "User" +msgstr "" + +#: app/component/permission.py:13 +msgid "User manger" +msgstr "" + +#: app/component/permission.py:17 +msgid "User Manage" +msgstr "" + +#: app/component/permission.py:18 +msgid "View users" +msgstr "" + +#: app/component/permission.py:22 +msgid "User Edit" +msgstr "" + +#: app/component/permission.py:23 +msgid "Manage users" +msgstr "" + +#: app/component/permission.py:28 +msgid "Admin" +msgstr "" + +#: app/component/permission.py:29 +msgid "Admin manger" +msgstr "" + +#: app/component/permission.py:33 +msgid "Admin View" +msgstr "" + +#: app/component/permission.py:34 +msgid "View admins" +msgstr "" + +#: app/component/permission.py:38 +msgid "Admin Edit" +msgstr "" + +#: app/component/permission.py:39 +msgid "Edit admins" +msgstr "" + +#: app/component/permission.py:44 +msgid "Role" +msgstr "" + +#: app/component/permission.py:45 +msgid "Role manger" +msgstr "" + +#: app/component/permission.py:49 +msgid "Role View" +msgstr "" + +#: app/component/permission.py:50 +msgid "View roles" +msgstr "" + +#: app/component/permission.py:54 +msgid "Role Edit" +msgstr "" + +#: app/component/permission.py:55 +msgid "Edit roles" +msgstr "" + +#: app/component/permission.py:60 +msgid "Mcp" +msgstr "" + +#: app/component/permission.py:61 +msgid "Mcp manger" +msgstr "" + +#: app/component/permission.py:65 +msgid "Mcp Edit" +msgstr "" + +#: app/component/permission.py:66 +msgid "Edit mcp service" +msgstr "" + +#: app/component/permission.py:70 +msgid "Mcp Category Edit" +msgstr "" + +#: app/component/permission.py:71 +msgid "Edit mcp category" +msgstr "" + +#: app/controller/chat/snapshot_controller.py:34 +#: app/controller/chat/snapshot_controller.py:68 +#: app/controller/chat/snapshot_controller.py:81 +msgid "Chat snapshot not found" +msgstr "" + +#: app/controller/chat/step_controller.py:65 +#: app/controller/chat/step_controller.py:89 +#: app/controller/chat/step_controller.py:102 +msgid "Chat step not found" +msgstr "" + +#: app/controller/config/config_controller.py:40 +#: app/controller/config/config_controller.py:76 +#: app/controller/config/config_controller.py:108 +msgid "Configuration not found" +msgstr "" + +#: app/controller/config/config_controller.py:47 +msgid "Config Name is valid" +msgstr "" + +#: app/controller/config/config_controller.py:55 +#: app/controller/config/config_controller.py:92 +msgid "Configuration already exists for this user" +msgstr "" + +#: app/controller/config/config_controller.py:80 +msgid "Invalid configuration group" +msgstr "" + +#: app/controller/mcp/mcp_controller.py:70 +msgid "Mcp not found" +msgstr "" + +#: app/controller/mcp/mcp_controller.py:73 +#: app/controller/mcp/user_controller.py:44 +msgid "mcp is installed" +msgstr "" + +#: app/controller/mcp/user_controller.py:34 +msgid "McpUser not found" +msgstr "" + +#: app/controller/mcp/user_controller.py:61 +#: app/controller/mcp/user_controller.py:75 +msgid "Mcp Info not found" +msgstr "" + +#: app/controller/mcp/user_controller.py:63 +msgid "current user have no permission to modify" +msgstr "" + +#: app/controller/provider/provider_controller.py:41 +#: app/controller/provider/provider_controller.py:60 +#: app/controller/provider/provider_controller.py:79 +msgid "Provider not found" +msgstr "" + +#: app/controller/user/login_controller.py:25 +msgid "Account or password error" +msgstr "" + +#: app/controller/user/login_controller.py:47 +msgid "User not found" +msgstr "" + +#: app/controller/user/login_controller.py:64 +#: app/controller/user/login_controller.py:89 +msgid "Failed to register" +msgstr "" + +#: app/controller/user/login_controller.py:67 +msgid "Your account has been blocked." +msgstr "" + +#: app/controller/user/login_controller.py:75 +msgid "Email already registered" +msgstr "" + +#: app/controller/user/user_password_controller.py:19 +msgid "Password is incorrect" +msgstr "" + +#: app/controller/user/user_password_controller.py:21 +msgid "The two passwords do not match" +msgstr "" + +#: app/model/abstract/model.py:66 +msgid "There is no data that meets the conditions" +msgstr "" + diff --git a/server/pyproject.toml b/server/pyproject.toml new file mode 100644 index 000000000..9f6ee2358 --- /dev/null +++ b/server/pyproject.toml @@ -0,0 +1,40 @@ +[project] +name = "Eigent" +version = "0.1.0" +description = "Eigent" +readme = "README.md" +requires-python = ">=3.13" +dependencies = [ + "alembic>=1.15.2", + "click>=8.1.8", + "fastapi>=0.115.12", + "fastapi-babel>=1.0.0", + "fastapi-pagination>=0.12.34", + "passlib[bcrypt]>=1.7.4", + "bcrypt==4.0.1", + "pydantic-i18n>=0.4.5", + "pydantic[email]>=2.11.1", + "pyjwt>=2.10.1", + "python-dotenv>=1.1.0", + "sqlalchemy-utils>=0.41.2", + "sqlmodel>=0.0.24", + "pandas>=2.2.3", + "openpyxl>=3.1.5", + "pandas>=2.2.3", + "arrow>=1.3.0", + "fastapi-filter>=2.0.1", + "psycopg2-binary>=2.9.10", + "convert-case>=1.2.3", + "python-multipart>=0.0.20", + "loguru>=0.7.3", + "httpx>=0.28.1", + "pydash>=8.0.5", + "requests>=2.32.4", + "itsdangerous>=2.2.0", + "cryptography>=45.0.4", + "sqids>=0.5.2", + "exa-py>=1.14.16", +] + +[tool.ruff] +line-length = 120 diff --git a/server/start.sh b/server/start.sh new file mode 100644 index 000000000..31d4795e4 --- /dev/null +++ b/server/start.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +# 等待数据库启动 +echo "Waiting for database to be ready..." +while ! nc -z postgres 5432; do + sleep 1 +done +echo "Database is ready!" + +# 运行数据库迁移 +echo "Running database migrations..." +uv run alembic upgrade head + +# 启动应用 +echo "Starting application..." +exec uv run uvicorn main:api --host 0.0.0.0 --port 5678 \ No newline at end of file diff --git a/server/start_server.bat b/server/start_server.bat new file mode 100644 index 000000000..61e9eedf1 --- /dev/null +++ b/server/start_server.bat @@ -0,0 +1,113 @@ +@echo off +chcp 65001 >nul +setlocal enabledelayedexpansion + +echo ======================================== +echo Auto Start Server Service +echo ======================================== + +:: Check if running in correct directory +if not exist "pyproject.toml" ( + echo Error: Please run this script in the server directory + pause + exit /b 1 +) + +:: Check if uv is installed +echo [1/5] Checking if uv is installed... +uv --version >nul 2>&1 +if %errorlevel% neq 0 ( + echo uv is not installed, attempting to install... + echo Downloading and installing uv... + + :: Try to install uv using PowerShell + powershell -Command "irm https://astral.sh/uv/install.ps1 | iex" 2>nul + if %errorlevel% neq 0 ( + echo Auto installation failed, please install uv manually: + echo 1. Visit https://docs.astral.sh/uv/getting-started/installation/ + echo 2. Or run: curl -LsSf https://astral.sh/uv/install.sh | sh + pause + exit /b 1 + ) + + :: Refresh environment variables + call refreshenv 2>nul + if %errorlevel% neq 0 ( + echo Please reopen command prompt or manually refresh environment variables + echo Then run this script again + pause + exit /b 1 + ) + + echo uv installation completed +) else ( + echo uv is already installed +) + +:: Install project dependencies +echo [2/5] Installing project dependencies... +uv sync +if %errorlevel% neq 0 ( + echo Dependency installation failed + pause + exit /b 1 +) +echo Dependencies installed successfully + +:: Execute babel internationalization +echo [3/5] Executing babel internationalization... +uv run pybabel extract -F babel.cfg -o messages.pot . +if %errorlevel% neq 0 ( + echo babel extract failed + pause + exit /b 1 +) + +:: Check if Chinese translation files exist +if not exist "lang\zh_CN\LC_MESSAGES\messages.po" ( + echo Initializing Chinese translation files... + uv run pybabel init -i messages.pot -d lang -l zh_CN + if %errorlevel% neq 0 ( + echo babel init failed + pause + exit /b 1 + ) +) else ( + echo Updating Chinese translation files... + uv run pybabel update -i messages.pot -d lang -l zh_CN + if %errorlevel% neq 0 ( + echo babel update failed + pause + exit /b 1 + ) +) + +:: Compile translation files +uv run pybabel compile -d lang -l zh_CN +if %errorlevel% neq 0 ( + echo babel compile failed + pause + exit /b 1 +) +echo babel processing completed + +:: Execute alembic database migration +echo [4/5] Executing alembic database migration... +uv run alembic upgrade head +if %errorlevel% neq 0 ( + echo alembic migration failed + echo Please check database connection configuration + pause + exit /b 1 +) +echo alembic migration completed + +:: Start service +echo [5/5] Starting FastAPI service... +echo Service will start at http://localhost:3001 +echo Press Ctrl+C to stop the service +echo ======================================== + +uv run uvicorn main:api --reload --port 3001 --host 0.0.0.0 + +pause \ No newline at end of file diff --git a/server/start_server.ps1 b/server/start_server.ps1 new file mode 100644 index 000000000..08f20834c --- /dev/null +++ b/server/start_server.ps1 @@ -0,0 +1,117 @@ +# Set console encoding to UTF-8 +[Console]::OutputEncoding = [System.Text.Encoding]::UTF8 + +Write-Host "========================================" -ForegroundColor Green +Write-Host "Auto Start Server Service" -ForegroundColor Green +Write-Host "========================================" -ForegroundColor Green + +# Check if running in correct directory +if (-not (Test-Path "pyproject.toml")) { + Write-Host "Error: Please run this script in the server directory" -ForegroundColor Red + Read-Host "Press Enter to exit" + exit 1 +} + +# Check if uv is installed +Write-Host "[1/5] Checking if uv is installed..." -ForegroundColor Yellow +try { + $uvVersion = uv --version 2>$null + if ($LASTEXITCODE -eq 0) { + Write-Host "uv is installed: $uvVersion" -ForegroundColor Green + } else { + throw "uv not found" + } +} catch { + Write-Host "uv is not installed, attempting to install..." -ForegroundColor Yellow + Write-Host "Downloading and installing uv..." -ForegroundColor Yellow + + try { + Invoke-RestMethod -Uri "https://astral.sh/uv/install.ps1" | Invoke-Expression + if ($LASTEXITCODE -eq 0) { + Write-Host "uv installation completed" -ForegroundColor Green + # Refresh environment variables + $env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User") + } else { + throw "Installation failed" + } + } catch { + Write-Host "Auto installation failed, please install uv manually:" -ForegroundColor Red + Write-Host "1. Visit https://docs.astral.sh/uv/getting-started/installation/" -ForegroundColor Cyan + Write-Host "2. Or run: curl -LsSf https://astral.sh/uv/install.sh | sh" -ForegroundColor Cyan + Read-Host "Press Enter to exit" + exit 1 + } +} + +# Install project dependencies +Write-Host "[2/5] Installing project dependencies..." -ForegroundColor Yellow +try { + uv sync + if ($LASTEXITCODE -eq 0) { + Write-Host "Dependencies installed successfully" -ForegroundColor Green + } else { + throw "Dependency installation failed" + } +} catch { + Write-Host "Dependency installation failed" -ForegroundColor Red + Read-Host "Press Enter to exit" + exit 1 +} + +# Execute babel internationalization +Write-Host "[3/5] Executing babel internationalization..." -ForegroundColor Yellow +try { + uv run pybabel extract -F babel.cfg -o messages.pot . + if ($LASTEXITCODE -ne 0) { throw "babel extract failed" } + + # Check if Chinese translation files exist + if (-not (Test-Path "lang\zh_CN\LC_MESSAGES\messages.po")) { + Write-Host "Initializing Chinese translation files..." -ForegroundColor Yellow + uv run pybabel init -i messages.pot -d lang -l zh_CN + if ($LASTEXITCODE -ne 0) { throw "babel init failed" } + } else { + Write-Host "Updating Chinese translation files..." -ForegroundColor Yellow + uv run pybabel update -i messages.pot -d lang -l zh_CN + if ($LASTEXITCODE -ne 0) { throw "babel update failed" } + } + + # Compile translation files + uv run pybabel compile -d lang -l zh_CN + if ($LASTEXITCODE -ne 0) { throw "babel compile failed" } + + Write-Host "babel processing completed" -ForegroundColor Green +} catch { + Write-Host "babel processing failed: $($_.Exception.Message)" -ForegroundColor Red + Read-Host "Press Enter to exit" + exit 1 +} + +# Execute alembic database migration +Write-Host "[4/5] Executing alembic database migration..." -ForegroundColor Yellow +try { + uv run alembic upgrade head + if ($LASTEXITCODE -eq 0) { + Write-Host "alembic migration completed" -ForegroundColor Green + } else { + throw "alembic migration failed" + } +} catch { + Write-Host "alembic migration failed" -ForegroundColor Red + Write-Host "Please check database connection configuration" -ForegroundColor Yellow + Read-Host "Press Enter to exit" + exit 1 +} + +# Start service +Write-Host "[5/5] Starting FastAPI service..." -ForegroundColor Yellow +Write-Host "Service will start at http://localhost:3001" -ForegroundColor Cyan +Write-Host "Press Ctrl+C to stop the service" -ForegroundColor Cyan +Write-Host "========================================" -ForegroundColor Green + +try { + uv run uvicorn main:api --reload --port 3001 --host 0.0.0.0 +} catch { + Write-Host "Service startup failed: $($_.Exception.Message)" -ForegroundColor Red + Read-Host "Press Enter to exit" + exit 1 +} \ No newline at end of file diff --git a/server/start_server.sh b/server/start_server.sh new file mode 100644 index 000000000..af834bfdd --- /dev/null +++ b/server/start_server.sh @@ -0,0 +1,109 @@ +#!/bin/bash + +# Set text colors +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +CYAN='\033[0;36m' +NC='\033[0m' # No Color + +echo -e "${GREEN}========================================" +echo -e "Auto Start Server Service" +echo -e "========================================${NC}" + +# Check if running in correct directory +if [ ! -f "pyproject.toml" ]; then + echo -e "${RED}Error: Please run this script in the server directory${NC}" + read -p "Press Enter to exit" + exit 1 +fi + +# Check if uv is installed +echo -e "${YELLOW}[1/5] Checking if uv is installed...${NC}" +if ! command -v uv &> /dev/null; then + echo -e "${YELLOW}uv is not installed, attempting to install...${NC}" + echo -e "${YELLOW}Downloading and installing uv...${NC}" + + # Try to install uv + if curl -LsSf https://astral.sh/uv/install.sh | sh; then + echo -e "${GREEN}uv installation completed${NC}" + # Refresh shell environment + export PATH="$HOME/.cargo/bin:$PATH" + source ~/.bashrc 2>/dev/null || source ~/.zshrc 2>/dev/null || true + else + echo -e "${RED}Auto installation failed, please install uv manually:${NC}" + echo -e "${CYAN}1. Visit https://docs.astral.sh/uv/getting-started/installation/${NC}" + echo -e "${CYAN}2. Or run: curl -LsSf https://astral.sh/uv/install.sh | sh${NC}" + read -p "Press Enter to exit" + exit 1 + fi +else + echo -e "${GREEN}uv is already installed${NC}" +fi + +# Install project dependencies +echo -e "${YELLOW}[2/5] Installing project dependencies...${NC}" +if uv sync; then + echo -e "${GREEN}Dependencies installed successfully${NC}" +else + echo -e "${RED}Dependency installation failed${NC}" + read -p "Press Enter to exit" + exit 1 +fi + +# Execute babel internationalization +echo -e "${YELLOW}[3/5] Executing babel internationalization...${NC}" +if ! uv run pybabel extract -F babel.cfg -o messages.pot .; then + echo -e "${RED}babel extract failed${NC}" + read -p "Press Enter to exit" + exit 1 +fi + +# Check if Chinese translation files exist +if [ ! -f "lang/zh_CN/LC_MESSAGES/messages.po" ]; then + echo -e "${YELLOW}Initializing Chinese translation files...${NC}" + if ! uv run pybabel init -i messages.pot -d lang -l zh_CN; then + echo -e "${RED}babel init failed${NC}" + read -p "Press Enter to exit" + exit 1 + fi +else + echo -e "${YELLOW}Updating Chinese translation files...${NC}" + if ! uv run pybabel update -i messages.pot -d lang -l zh_CN; then + echo -e "${RED}babel update failed${NC}" + read -p "Press Enter to exit" + exit 1 + fi +fi + +# Compile translation files +if ! uv run pybabel compile -d lang -l zh_CN; then + echo -e "${RED}babel compile failed${NC}" + read -p "Press Enter to exit" + exit 1 +fi + +echo -e "${GREEN}babel processing completed${NC}" + +# Execute alembic database migration +echo -e "${YELLOW}[4/5] Executing alembic database migration...${NC}" +if uv run alembic upgrade head; then + echo -e "${GREEN}alembic migration completed${NC}" +else + echo -e "${RED}alembic migration failed${NC}" + echo -e "${YELLOW}Please check database connection configuration${NC}" + read -p "Press Enter to exit" + exit 1 +fi + +# Start service +echo -e "${YELLOW}[5/5] Starting FastAPI service...${NC}" +echo -e "${CYAN}Service will start at http://localhost:3001${NC}" +echo -e "${CYAN}Press Ctrl+C to stop the service${NC}" +echo -e "${GREEN}========================================${NC}" + +if ! uv run uvicorn main:api --reload --port 3001 --host 0.0.0.0; then + echo -e "${RED}Service startup failed${NC}" + read -p "Press Enter to exit" + exit 1 +fi \ No newline at end of file diff --git a/src/components/ChatBox/index.tsx b/src/components/ChatBox/index.tsx index b205e1c34..634e71e75 100644 --- a/src/components/ChatBox/index.tsx +++ b/src/components/ChatBox/index.tsx @@ -1,17 +1,10 @@ import { useState, useRef, useEffect, useCallback } from "react"; import { fetchPost } from "@/api/http"; -import { Button } from "@/components/ui/button"; import { BottomInput } from "./BottomInput"; import { TaskCard } from "./TaskCard"; import { MessageCard } from "./MessageCard"; import { TypeCardSkeleton } from "./TypeCardSkeleton"; -import { - Smartphone, - Workflow, - CircleDollarSign, - FileText, - TriangleAlert, -} from "lucide-react"; +import { FileText, TriangleAlert } from "lucide-react"; import { generateUniqueId } from "@/lib"; import { useChatStore } from "@/store/chatStore"; import { proxyFetchGet } from "@/api/http"; @@ -31,7 +24,6 @@ export default function ChatBox(): JSX.Element { const [privacyDialogOpen, setPrivacyDialogOpen] = useState(false); const { modelType } = useAuthStore(); const [useCloudModelInDev, setUseCloudModelInDev] = useState(false); - useEffect(() => { if ( import.meta.env.VITE_USE_LOCAL_PROXY === "true" && @@ -200,26 +192,23 @@ export default function ChatBox(): JSX.Element { }, [scrollContainerRef.current?.scrollHeight]); const [loading, setLoading] = useState(false); - const handleConfirmTask = async () => { - const taskId = chatStore.activeTaskId; - if (!taskId) return; + const handleConfirmTask = async (taskId?: string) => { + const _taskId = taskId || chatStore.activeTaskId; + if (!_taskId) return; setLoading(true); - await chatStore.handleConfirmTask(taskId); + await chatStore.handleConfirmTask(_taskId); setLoading(false); }; const [hasSubTask, setHasSubTask] = useState(false); useEffect(() => { - setHasSubTask( - chatStore.tasks[chatStore.activeTaskId as string]?.messages?.find( - (message) => { - return message.step === "to_sub_tasks"; - } - ) - ? true - : false - ); + const _hasSubTask = chatStore.tasks[ + chatStore.activeTaskId as string + ]?.messages?.find((message) => message.step === "to_sub_tasks") + ? true + : false; + setHasSubTask(_hasSubTask); }, [chatStore?.tasks[chatStore.activeTaskId as string]?.messages]); useEffect(() => { @@ -437,13 +426,18 @@ export default function ChatBox(): JSX.Element { chatStore.tasks[chatStore.activeTaskId].summaryTask || "" } - onAddTask={() => chatStore.addTaskInfo()} - onUpdateTask={(taskIndex, content) => - chatStore.updateTaskInfo(taskIndex, content) - } - onDeleteTask={(taskIndex) => - chatStore.deleteTaskInfo(taskIndex) - } + onAddTask={() => { + chatStore.setIsTaskEdit(chatStore.activeTaskId as string, true); + chatStore.addTaskInfo(); + }} + onUpdateTask={(taskIndex, content) => { + chatStore.setIsTaskEdit(chatStore.activeTaskId as string, true); + chatStore.updateTaskInfo(taskIndex, content); + }} + onDeleteTask={(taskIndex) => { + chatStore.setIsTaskEdit(chatStore.activeTaskId as string, true); + chatStore.deleteTaskInfo(taskIndex); + }} /> ); } @@ -582,37 +576,39 @@ export default function ChatBox(): JSX.Element { ) )} - {!useCloudModelInDev && privacy && (hasSearchKey || modelType === "cloud") && ( -
- {[ - { - label: "Palm Springs Tennis Trip Planner", - message: - "I am two tennis fans and want to go see the tennis tournament in palm springs. l live in SF - please prepare a detailed itinerary with flights, hotels, things to do for 3 days - around the time semifinal/finals are happening. We like hiking, vegan food and spas. Our budget is $5K. The itinerary should be a detailed timeline of time, activity, cost, other details and if applicable a link to buy tickets/make reservations etc. for the item. Some preferences 1.Spa access would be nice but not necessary 2. When you finnish this task, please generate a html report about this trip.", - }, - { - label: "Bank Transfer CSV Analysis and Visualization", - message: - "Create a mock bank transfer CSV file include 10 columns and 10 rows. Read the generated CSV file and summarize the data, generate a chart to visualize relevant trends or insights from the data.", - }, - { - label: "Find Duplicate Files in Downloads Folder", - message: - "Help me find duplicate files by content, size, and format in my downloads folder.", - }, - ].map(({ label, message }) => ( -
{ - setMessage(message); - }} - > - {label} -
- ))} -
- )} + {!useCloudModelInDev && + privacy && + (hasSearchKey || modelType === "cloud") && ( +
+ {[ + { + label: "Palm Springs Tennis Trip Planner", + message: + "I am two tennis fans and want to go see the tennis tournament in palm springs. l live in SF - please prepare a detailed itinerary with flights, hotels, things to do for 3 days - around the time semifinal/finals are happening. We like hiking, vegan food and spas. Our budget is $5K. The itinerary should be a detailed timeline of time, activity, cost, other details and if applicable a link to buy tickets/make reservations etc. for the item. Some preferences 1.Spa access would be nice but not necessary 2. When you finnish this task, please generate a html report about this trip.", + }, + { + label: "Bank Transfer CSV Analysis and Visualization", + message: + "Create a mock bank transfer CSV file include 10 columns and 10 rows. Read the generated CSV file and summarize the data, generate a chart to visualize relevant trends or insights from the data.", + }, + { + label: "Find Duplicate Files in Downloads Folder", + message: + "Help me find duplicate files by content, size, and format in my downloads folder.", + }, + ].map(({ label, message }) => ( +
{ + setMessage(message); + }} + > + {label} +
+ ))} +
+ )} diff --git a/src/pages/Setting.tsx b/src/pages/Setting.tsx index 95fd365ed..0378dc04f 100644 --- a/src/pages/Setting.tsx +++ b/src/pages/Setting.tsx @@ -2,6 +2,7 @@ import { useState, useEffect } from "react"; import { useNavigate, useLocation, Outlet } from "react-router-dom"; import { Button } from "@/components/ui/button"; import useAppVersion from "@/hooks/use-app-version"; +import vsersionLogo from "@/assets/version-logo.png" import { X, CircleCheck, @@ -118,9 +119,9 @@ export default function Setting() { })} -
-
- Eigent +
+
+ version-logo
diff --git a/src/store/chatStore.ts b/src/store/chatStore.ts index cf1e62ba8..4939b3ca3 100644 --- a/src/store/chatStore.ts +++ b/src/store/chatStore.ts @@ -38,6 +38,7 @@ interface Task { snapshots: any[]; snapshotsTemp: any[]; isTakeControl: boolean; + isTaskEdit: boolean; } interface ChatStore { @@ -91,6 +92,7 @@ interface ChatStore { setSnapshots: (taskId: string, snapshots: any[]) => void, setIsTakeControl: (taskId: string, isTakeControl: boolean) => void, setSnapshotsTemp: (taskId: string, snapshot: any) => void, + setIsTaskEdit: (taskId: string, isTaskEdit: boolean) => void, } @@ -137,7 +139,8 @@ const chatStore = create()( selectedFile: null, snapshots: [], snapshotsTemp: [], - isTakeControl: false + isTakeControl: false, + isTaskEdit: false }, } })) @@ -326,9 +329,24 @@ const chatStore = create()( // if (tasks[taskId].status === 'finished') return if (agentMessages.step === "to_sub_tasks") { + + const messages = [...tasks[taskId].messages] const toSubTaskIndex = messages.findLastIndex((message: Message) => message.step === 'to_sub_tasks') if (toSubTaskIndex === -1) { + // 30 seconds auto confirm + setTimeout(() => { + const { tasks, handleConfirmTask, setIsTaskEdit } = get(); + const message = tasks[taskId].messages.findLast((item) => item.step === "to_sub_tasks"); + const isConfirm = message?.isConfirm || false; + const isTakeControl = + tasks[taskId].isTakeControl; + if (!isConfirm && !isTakeControl && !tasks[taskId].isTaskEdit) { + handleConfirmTask(taskId); + } + setIsTaskEdit(taskId, false); + }, 30000); + const newNoticeMessage: Message = { id: generateUniqueId(), role: "agent", @@ -1586,6 +1604,18 @@ const chatStore = create()( } }) }, + setIsTaskEdit(taskId: string, isTaskEdit: boolean) { + set((state) => ({ + ...state, + tasks: { + ...state.tasks, + [taskId]: { + ...state.tasks[taskId], + isTaskEdit + }, + }, + })) + }, }) );