server enhance

This commit is contained in:
Wendong-Fan 2025-08-21 00:43:48 +08:00
parent e7a5c5536f
commit c108179fbb
13 changed files with 191 additions and 31 deletions

View file

@ -1,7 +1,35 @@
# Environment Configuration Example
# Copy this file to .env and update with your own values
# Application Settings
debug=false
url_prefix=/api
secret_key=put-your-secret-key-here
database_url=postgresql://postgres:postgres@localhost:5432/postgres
# Chat Share Secret Key
CHAT_SHARE_SECRET_KEY=put-your-secret-key-here
CHAT_SHARE_SALT=put-your-encode-salt-here
# Security Configuration
# Generate with: openssl rand -hex 32
secret_key=CHANGE_THIS_TO_A_RANDOM_SECRET_KEY_USE_OPENSSL_RAND_HEX_32
# Database Configuration
# Use a strong password in production
database_url=postgresql://postgres:CHANGE_THIS_STRONG_PASSWORD@localhost:5432/eigent
# Docker Compose Database Settings (if using docker-compose)
POSTGRES_PASSWORD=CHANGE_THIS_STRONG_PASSWORD
POSTGRES_USER=postgres
POSTGRES_DB=eigent
# JWT Configuration
# Token expiration in seconds (3600 = 1 hour, recommended for production)
JWT_EXPIRATION=3600
# Chat Share Security
# Generate with: openssl rand -hex 32
CHAT_SHARE_SECRET_KEY=CHANGE_THIS_TO_A_RANDOM_SECRET_KEY
# Generate with: openssl rand -hex 16
CHAT_SHARE_SALT=CHANGE_THIS_TO_A_RANDOM_SALT
# Stack Auth Configuration (Optional)
# Leave empty if not using Stack Auth
STACK_AUTH_PROJECT_ID=
STACK_AUTH_API_KEY=
STACK_AUTH_BASE_URL=

View file

@ -52,6 +52,8 @@ def upgrade() -> None:
"admin_role",
sa.Column("admin_id", sa.Integer(), nullable=False),
sa.Column("role_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(["admin_id"], ["admin.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["role_id"], ["role.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("admin_id", "role_id"),
)
op.create_table(
@ -283,7 +285,7 @@ def upgrade() -> None:
sa.Column("updated_at", sa.TIMESTAMP(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=True),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("pricacy_setting", sa.JSON(), nullable=True),
sa.Column("privacy_setting", sa.JSON(), nullable=True),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],

View file

@ -39,17 +39,36 @@ class Auth:
id = payload["id"]
if payload["exp"] < int(datetime.now().timestamp()):
raise TokenException(code.token_expired, _("Validate credentials expired"))
# Accept both old tokens (without type) and new tokens (with type)
# Old tokens are treated as access tokens for backward compatibility
token_type = payload.get("type", "access")
if token_type not in ["access", "refresh"]:
raise TokenException(code.token_invalid, _("Invalid token type"))
except InvalidTokenError:
raise TokenException(code.token_invalid, _("Could not validate credentials"))
return Auth(id, payload["exp"])
@classmethod
def create_access_token(cls, user_id: int, expires_delta: timedelta | None = None):
to_encode: dict = {"id": user_id}
to_encode: dict = {"id": user_id, "type": "access"}
if expires_delta:
expire = datetime.now() + expires_delta
else:
expire = datetime.now() + timedelta(days=30)
# Get expiration from environment or default to 1 hour
expiration_seconds = int(env("JWT_EXPIRATION", "3600"))
expire = datetime.now() + timedelta(seconds=expiration_seconds)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, Auth.SECRET_KEY, algorithm="HS256")
return encoded_jwt
@classmethod
def create_refresh_token(cls, user_id: int, expires_delta: timedelta | None = None):
to_encode: dict = {"id": user_id, "type": "refresh"}
if expires_delta:
expire = datetime.now() + expires_delta
else:
# Refresh tokens last 7 days by default
expire = datetime.now() + timedelta(days=7)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, Auth.SECRET_KEY, algorithm="HS256")
return encoded_jwt

View file

@ -1,11 +1,35 @@
from pydantic import BaseModel, ValidationError, field_validator
from pydantic import BaseModel, ValidationError, field_validator, validator
from typing import Dict, List, Optional
import re
import os
class McpServerItem(BaseModel):
command: str
args: List[str]
env: Optional[Dict[str, str]] = None
@validator('command')
def validate_command(cls, v):
# Only allow alphanumeric, dash, underscore, forward slash, and dot
if not re.match(r'^[a-zA-Z0-9_\-./]+$', v):
raise ValueError('Command contains invalid characters')
# Prevent directory traversal
if '..' in v:
raise ValueError('Directory traversal not allowed')
# Check if it's an absolute path or a command name
if '/' in v and not os.path.isabs(v):
raise ValueError('Relative paths not allowed')
return v
@validator('args', each_item=True)
def validate_args(cls, v):
# Prevent shell metacharacters that could lead to command injection
dangerous_chars = ['&', '|', ';', '$', '`', '(', ')', '<', '>', '\n', '\r']
for char in dangerous_chars:
if char in v:
raise ValueError(f'Argument contains dangerous character: {char}')
return v
class McpServersModel(BaseModel):
@ -15,6 +39,21 @@ class McpServersModel(BaseModel):
class McpRemoteServer(BaseModel):
server_name: str
server_url: str
@validator('server_url')
def validate_server_url(cls, v):
# Only allow http/https URLs
if not v.startswith(('http://', 'https://')):
raise ValueError('Only HTTP/HTTPS URLs are allowed')
# Basic URL validation to prevent SSRF
# In production, you should use a proper URL validation library
# and implement domain allowlisting
forbidden_hosts = ['localhost', '127.0.0.1', '0.0.0.0', '169.254.169.254']
from urllib.parse import urlparse
parsed = urlparse(v)
if parsed.hostname in forbidden_hosts:
raise ValueError('Access to this host is forbidden')
return v
def validate_mcp_servers(data: dict):

View file

@ -67,8 +67,7 @@ async def get_chat_step(step_id: int, session: Session = Depends(session), auth:
@router.post("/steps", name="create chat step")
# TODO Limit request sources
async def create_chat_step(step: ChatStepIn, session: Session = Depends(session)):
async def create_chat_step(step: ChatStepIn, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
chat_step = ChatStep(
task_id=step.task_id,
step=step.step,

View file

@ -63,7 +63,7 @@ async def put(id: int, data: ProviderIn, session: Session = Depends(session), au
model.api_key = data.api_key
model.endpoint_url = data.endpoint_url
model.encrypted_config = data.encrypted_config
model.is_vaild = data.is_vaild
model.is_valid = data.is_valid
model.save(session)
session.refresh(model)
return model

View file

@ -8,13 +8,20 @@ from app.component.encrypt import password_verify
from app.component.stack_auth import StackAuth
from app.exception.exception import UserException
from app.model.user.user import LoginByPasswordIn, LoginResponse, Status, User, RegisterIn
from pydantic import BaseModel
from loguru import logger
from app.component.environment import env
from datetime import datetime
import jwt
router = APIRouter(tags=["Login/Registration"])
class RefreshTokenRequest(BaseModel):
refresh_token: str
@router.post("/login", name="login by email or password")
async def by_password(data: LoginByPasswordIn, session: Session = Depends(session)) -> LoginResponse:
"""
@ -23,7 +30,11 @@ async def by_password(data: LoginByPasswordIn, session: Session = Depends(sessio
user = User.by(User.email == data.email, s=session).one_or_none()
if not user or not password_verify(data.password, user.password):
raise UserException(code.password, _("Account or password error"))
return LoginResponse(token=Auth.create_access_token(user.id), email=user.email)
return LoginResponse(
access_token=Auth.create_access_token(user.id),
refresh_token=Auth.create_refresh_token(user.id),
email=user.email
)
@router.post("/login-by_stack", name="login by stack")
@ -57,7 +68,11 @@ async def by_stack_auth(
s.add(user)
s.commit()
session.refresh(user)
return LoginResponse(token=Auth.create_access_token(user.id), email=user.email)
return LoginResponse(
access_token=Auth.create_access_token(user.id),
refresh_token=Auth.create_refresh_token(user.id),
email=user.email
)
except Exception as e:
s.rollback()
logger.error(f"Failed to register: {e}")
@ -65,7 +80,11 @@ async def by_stack_auth(
else:
if user.status == Status.Block:
raise UserException(code.error, _("Your account has been blocked."))
return LoginResponse(token=Auth.create_access_token(user.id), email=user.email)
return LoginResponse(
access_token=Auth.create_access_token(user.id),
refresh_token=Auth.create_refresh_token(user.id),
email=user.email
)
@router.post("/register", name="register by email/password")
@ -88,3 +107,40 @@ async def register(data: RegisterIn, session: Session = Depends(session)):
logger.error(f"Failed to register: {e}")
raise UserException(code.error, _("Failed to register"))
return {"status": "success"}
@router.post("/refresh", name="refresh access token")
async def refresh_token(data: RefreshTokenRequest, session: Session = Depends(session)) -> LoginResponse:
"""
Refresh the access token using a valid refresh token.
"""
try:
# Decode the refresh token
payload = jwt.decode(data.refresh_token, Auth.SECRET_KEY, algorithms=["HS256"])
# Verify it's a refresh token
if payload.get("type") != "refresh":
raise HTTPException(status_code=401, detail="Invalid token type")
# Check if expired
if payload["exp"] < int(datetime.now().timestamp()):
raise HTTPException(status_code=401, detail="Refresh token expired")
# Get the user
user_id = payload["id"]
user = session.get(User, user_id)
if not user:
raise HTTPException(status_code=401, detail="User not found")
# Check if user is blocked
if user.status == Status.Block:
raise HTTPException(status_code=401, detail="User account is blocked")
# Generate new tokens
return LoginResponse(
access_token=Auth.create_access_token(user.id),
refresh_token=Auth.create_refresh_token(user.id),
email=user.email
)
except jwt.InvalidTokenError:
raise HTTPException(status_code=401, detail="Invalid refresh token")

View file

@ -50,7 +50,7 @@ def get_privacy(session: Session = Depends(session), auth: Auth = Depends(auth_m
if not model:
return UserPrivacySettings.default_settings()
return model.pricacy_setting
return model.privacy_setting
@router.put("/user/privacy", name="update user privacy")
@ -61,13 +61,13 @@ def put_privacy(data: UserPrivacySettings, session: Session = Depends(session),
default_settings = UserPrivacySettings.default_settings()
if model:
model.pricacy_setting = {**model.pricacy_setting, **data.model_dump()}
model.privacy_setting = {**model.privacy_setting, **data.model_dump()}
model.save(session)
else:
model = UserPrivacy(user_id=user_id, pricacy_setting={**default_settings, **data.model_dump()})
model = UserPrivacy(user_id=user_id, privacy_setting={**default_settings, **data.model_dump()})
model.save(session)
return model.pricacy_setting
return model.privacy_setting
@router.get("/user/current_credits", name="get user current credits")

View file

@ -9,7 +9,7 @@ from sqlalchemy import text
from app.model.abstract.model import AbstractModel, DefaultTimes
class VaildStatus(IntEnum):
class ValidStatus(IntEnum):
not_valid = 1
is_valid = 2
@ -23,9 +23,9 @@ class Provider(AbstractModel, DefaultTimes, table=True):
endpoint_url: str = ""
encrypted_config: dict | None = Field(default=None, sa_column=Column(JSON))
prefer: bool = Field(default=False, sa_column=Column(Boolean, server_default=text("false")))
is_vaild: VaildStatus = Field(
default=VaildStatus.not_valid,
sa_column=Column(ChoiceType(VaildStatus, SmallInteger()), server_default=text("1")),
is_valid: ValidStatus = Field(
default=ValidStatus.not_valid,
sa_column=Column(ChoiceType(ValidStatus, SmallInteger()), server_default=text("1")),
)
@ -35,7 +35,7 @@ class ProviderIn(BaseModel):
api_key: str
endpoint_url: str
encrypted_config: dict | None = None
is_vaild: VaildStatus = VaildStatus.not_valid
is_valid: ValidStatus = ValidStatus.not_valid
prefer: bool = False

View file

@ -10,7 +10,7 @@ from app.model.abstract.model import AbstractModel, DefaultTimes
class UserPrivacy(AbstractModel, DefaultTimes, table=True):
id: int = Field(default=None, primary_key=True)
user_id: int = Field(unique=True, foreign_key="user.id")
pricacy_setting: dict = Field(default="{}", sa_column=Column(JSON))
privacy_setting: dict = Field(default="{}", sa_column=Column(JSON))
class UserPrivacySettings(BaseModel):

View file

@ -43,8 +43,14 @@ class LoginByPasswordIn(BaseModel):
class LoginResponse(BaseModel):
token: str
access_token: str
refresh_token: str
token_type: str = "Bearer"
email: EmailStr
# Backward compatibility
@property
def token(self) -> str:
return self.access_token
class UserIn(BaseModel):

View file

@ -7,9 +7,9 @@ services:
container_name: eigent_postgres
restart: unless-stopped
environment:
POSTGRES_DB: eigent
POSTGRES_USER: postgres
POSTGRES_PASSWORD: 123456
POSTGRES_DB: ${POSTGRES_DB:-eigent}
POSTGRES_USER: ${POSTGRES_USER:-postgres}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
POSTGRES_INITDB_ARGS: "--encoding=UTF-8 --lc-collate=C --lc-ctype=C"
ports:
- "5432:5432"
@ -30,13 +30,13 @@ services:
context: .
dockerfile: Dockerfile
args:
database_url: postgresql://postgres:123456@postgres:5432/eigent
database_url: ${DATABASE_URL:-postgresql://postgres:postgres@postgres:5432/eigent}
container_name: eigent_api
restart: unless-stopped
ports:
- "3001:5678"
environment:
- DATABASE_URL=postgresql://postgres:123456@postgres:5432/eigent
- DATABASE_URL=${DATABASE_URL:-postgresql://postgres:postgres@postgres:5432/eigent}
- ENVIRONMENT=production
- DEBUG=false
# volumes:

View file

@ -3,6 +3,17 @@ from app.component.environment import auto_include_routers, env
from loguru import logger
import os
from fastapi.staticfiles import StaticFiles
from fastapi import status
from fastapi.responses import JSONResponse
# Health check endpoint
@api.get("/health", tags=["Health"])
async def health_check():
"""Health check endpoint for monitoring."""
return JSONResponse(
status_code=status.HTTP_200_OK,
content={"status": "healthy", "service": "eigent-api"}
)
prefix = env("url_prefix", "")
auto_include_routers(api, prefix, "app/controller")