eigent/server/app/controller/chat/step_controller.py

207 lines
8.1 KiB
Python

# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
import asyncio
import json
import logging
from fastapi import APIRouter, Depends, HTTPException, Response
from fastapi.responses import StreamingResponse
from fastapi_babel import _
from sqlalchemy.sql.expression import case
from sqlmodel import Session, asc, select
from app.component.auth import Auth, auth_must
from app.component.database import session
from app.model.chat.chat_step import ChatStep, ChatStepIn, ChatStepOut
logger = logging.getLogger("server_chat_step")
router = APIRouter(prefix="/chat", tags=["Chat Step Management"])
@router.get("/steps", name="list chat steps", response_model=list[ChatStepOut])
async def list_chat_steps(
task_id: str, step: str | None = None, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
"""List chat steps for a task with optional step type filtering."""
user_id = auth.user.id
query = select(ChatStep)
if task_id is not None:
query = query.where(ChatStep.task_id == task_id)
if step is not None:
query = query.where(ChatStep.step == step)
chat_steps = session.exec(query).all()
logger.debug(
"Chat steps listed", extra={"user_id": user_id, "task_id": task_id, "step_type": step, "count": len(chat_steps)}
)
return chat_steps
@router.get("/steps/playback/{task_id}", name="Playback Chat Step via SSE")
async def share_playback(
task_id: str, delay_time: float = 0, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
"""Playback chat steps via SSE stream."""
user_id = auth.user.id
if delay_time > 5:
logger.debug(
"Delay time capped", extra={"user_id": user_id, "task_id": task_id, "requested": delay_time, "capped": 5}
)
delay_time = 5
async def event_generator():
try:
stmt = (
select(ChatStep)
.where(ChatStep.task_id == task_id)
.order_by(
asc(case((ChatStep.timestamp.is_(None), 1), else_=0)), asc(ChatStep.timestamp), asc(ChatStep.id)
)
)
steps = session.exec(stmt).all()
if not steps:
logger.warning("No steps found for playback", extra={"user_id": user_id, "task_id": task_id})
yield f"data: {json.dumps({'error': 'No steps found for this task.'})}\n\n"
return
logger.info(
"Chat step playback started",
extra={"user_id": user_id, "task_id": task_id, "step_count": len(steps), "delay_time": delay_time},
)
for step in steps:
step_data = {
"id": step.id,
"task_id": step.task_id,
"step": step.step,
"data": step.data,
"created_at": step.created_at.isoformat() if step.created_at else None,
}
yield f"data: {json.dumps(step_data)}\n\n"
if delay_time > 0:
await asyncio.sleep(delay_time)
logger.info(
"Chat step playback completed", extra={"user_id": user_id, "task_id": task_id, "step_count": len(steps)}
)
except Exception as e:
logger.error(
"Chat step playback error",
extra={"user_id": user_id, "task_id": task_id, "error": str(e)},
exc_info=True,
)
yield f"data: {json.dumps({'error': 'Playback error occurred.'})}\n\n"
return StreamingResponse(event_generator(), media_type="text/event-stream")
@router.get("/steps/{step_id}", name="get chat step", response_model=ChatStepOut)
async def get_chat_step(step_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
"""Get specific chat step."""
user_id = auth.user.id
chat_step = session.get(ChatStep, step_id)
if not chat_step:
logger.warning("Chat step not found", extra={"user_id": user_id, "step_id": step_id})
raise HTTPException(status_code=404, detail=_("Chat step not found"))
logger.debug("Chat step retrieved", extra={"user_id": user_id, "step_id": step_id, "task_id": chat_step.task_id})
return chat_step
@router.post("/steps", name="create chat step")
async def create_chat_step(step: ChatStepIn, session: Session = Depends(session)):
"""Create new chat step. TODO: Implement request source validation."""
try:
chat_step = ChatStep(task_id=step.task_id, step=step.step, data=step.data, timestamp=step.timestamp)
session.add(chat_step)
session.commit()
session.refresh(chat_step)
logger.info(
"Chat step created", extra={"step_id": chat_step.id, "task_id": step.task_id, "step_type": step.step}
)
return {"code": 200, "msg": "success"}
except Exception as e:
session.rollback()
logger.error(
"Chat step creation failed",
extra={"task_id": step.task_id, "step_type": step.step, "error": str(e)},
exc_info=True,
)
raise HTTPException(status_code=500, detail="Internal server error")
@router.put("/steps/{step_id}", name="update chat step", response_model=ChatStepOut)
async def update_chat_step(
step_id: int, chat_step_update: ChatStep, session: Session = Depends(session), auth: Auth = Depends(auth_must)
):
"""Update chat step."""
user_id = auth.user.id
db_chat_step = session.get(ChatStep, step_id)
if not db_chat_step:
logger.warning("Chat step not found for update", extra={"user_id": user_id, "step_id": step_id})
raise HTTPException(status_code=404, detail=_("Chat step not found"))
try:
update_data = chat_step_update.dict(exclude_unset=True)
for key, value in update_data.items():
setattr(db_chat_step, key, value)
session.add(db_chat_step)
session.commit()
session.refresh(db_chat_step)
logger.info(
"Chat step updated",
extra={
"user_id": user_id,
"step_id": step_id,
"task_id": db_chat_step.task_id,
"fields_updated": list(update_data.keys()),
},
)
return db_chat_step
except Exception as e:
session.rollback()
logger.error(
"Chat step update failed", extra={"user_id": user_id, "step_id": step_id, "error": str(e)}, exc_info=True
)
raise HTTPException(status_code=500, detail="Internal server error")
@router.delete("/steps/{step_id}", name="delete chat step")
async def delete_chat_step(step_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)):
"""Delete chat step."""
user_id = auth.user.id
db_chat_step = session.get(ChatStep, step_id)
if not db_chat_step:
logger.warning("Chat step not found for deletion", extra={"user_id": user_id, "step_id": step_id})
raise HTTPException(status_code=404, detail=_("Chat step not found"))
try:
session.delete(db_chat_step)
session.commit()
logger.info(
"Chat step deleted", extra={"user_id": user_id, "step_id": step_id, "task_id": db_chat_step.task_id}
)
return Response(status_code=204)
except Exception as e:
session.rollback()
logger.error(
"Chat step deletion failed", extra={"user_id": user_id, "step_id": step_id, "error": str(e)}, exc_info=True
)
raise HTTPException(status_code=500, detail="Internal server error")