# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= import asyncio import json import logging from fastapi import APIRouter, Depends, HTTPException, Response from fastapi.responses import StreamingResponse from fastapi_babel import _ from sqlalchemy.sql.expression import case from sqlmodel import Session, asc, select from app.component.auth import Auth, auth_must from app.component.database import session from app.model.chat.chat_step import ChatStep, ChatStepIn, ChatStepOut logger = logging.getLogger("server_chat_step") router = APIRouter(prefix="/chat", tags=["Chat Step Management"]) @router.get("/steps", name="list chat steps", response_model=list[ChatStepOut]) async def list_chat_steps( task_id: str, step: str | None = None, session: Session = Depends(session), auth: Auth = Depends(auth_must) ): """List chat steps for a task with optional step type filtering.""" user_id = auth.user.id query = select(ChatStep) if task_id is not None: query = query.where(ChatStep.task_id == task_id) if step is not None: query = query.where(ChatStep.step == step) chat_steps = session.exec(query).all() logger.debug( "Chat steps listed", extra={"user_id": user_id, "task_id": task_id, "step_type": step, "count": len(chat_steps)} ) return chat_steps @router.get("/steps/playback/{task_id}", name="Playback Chat Step via SSE") async def share_playback( task_id: str, delay_time: float = 0, session: Session = Depends(session), auth: Auth = Depends(auth_must) ): """Playback chat steps via SSE stream.""" user_id = auth.user.id if delay_time > 5: logger.debug( "Delay time capped", extra={"user_id": user_id, "task_id": task_id, "requested": delay_time, "capped": 5} ) delay_time = 5 async def event_generator(): try: stmt = ( select(ChatStep) .where(ChatStep.task_id == task_id) .order_by( asc(case((ChatStep.timestamp.is_(None), 1), else_=0)), asc(ChatStep.timestamp), asc(ChatStep.id) ) ) steps = session.exec(stmt).all() if not steps: logger.warning("No steps found for playback", extra={"user_id": user_id, "task_id": task_id}) yield f"data: {json.dumps({'error': 'No steps found for this task.'})}\n\n" return logger.info( "Chat step playback started", extra={"user_id": user_id, "task_id": task_id, "step_count": len(steps), "delay_time": delay_time}, ) for step in steps: step_data = { "id": step.id, "task_id": step.task_id, "step": step.step, "data": step.data, "created_at": step.created_at.isoformat() if step.created_at else None, } yield f"data: {json.dumps(step_data)}\n\n" if delay_time > 0: await asyncio.sleep(delay_time) logger.info( "Chat step playback completed", extra={"user_id": user_id, "task_id": task_id, "step_count": len(steps)} ) except Exception as e: logger.error( "Chat step playback error", extra={"user_id": user_id, "task_id": task_id, "error": str(e)}, exc_info=True, ) yield f"data: {json.dumps({'error': 'Playback error occurred.'})}\n\n" return StreamingResponse(event_generator(), media_type="text/event-stream") @router.get("/steps/{step_id}", name="get chat step", response_model=ChatStepOut) async def get_chat_step(step_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)): """Get specific chat step.""" user_id = auth.user.id chat_step = session.get(ChatStep, step_id) if not chat_step: logger.warning("Chat step not found", extra={"user_id": user_id, "step_id": step_id}) raise HTTPException(status_code=404, detail=_("Chat step not found")) logger.debug("Chat step retrieved", extra={"user_id": user_id, "step_id": step_id, "task_id": chat_step.task_id}) return chat_step @router.post("/steps", name="create chat step") async def create_chat_step(step: ChatStepIn, session: Session = Depends(session)): """Create new chat step. TODO: Implement request source validation.""" try: chat_step = ChatStep(task_id=step.task_id, step=step.step, data=step.data, timestamp=step.timestamp) session.add(chat_step) session.commit() session.refresh(chat_step) logger.info( "Chat step created", extra={"step_id": chat_step.id, "task_id": step.task_id, "step_type": step.step} ) return {"code": 200, "msg": "success"} except Exception as e: session.rollback() logger.error( "Chat step creation failed", extra={"task_id": step.task_id, "step_type": step.step, "error": str(e)}, exc_info=True, ) raise HTTPException(status_code=500, detail="Internal server error") @router.put("/steps/{step_id}", name="update chat step", response_model=ChatStepOut) async def update_chat_step( step_id: int, chat_step_update: ChatStep, session: Session = Depends(session), auth: Auth = Depends(auth_must) ): """Update chat step.""" user_id = auth.user.id db_chat_step = session.get(ChatStep, step_id) if not db_chat_step: logger.warning("Chat step not found for update", extra={"user_id": user_id, "step_id": step_id}) raise HTTPException(status_code=404, detail=_("Chat step not found")) try: update_data = chat_step_update.dict(exclude_unset=True) for key, value in update_data.items(): setattr(db_chat_step, key, value) session.add(db_chat_step) session.commit() session.refresh(db_chat_step) logger.info( "Chat step updated", extra={ "user_id": user_id, "step_id": step_id, "task_id": db_chat_step.task_id, "fields_updated": list(update_data.keys()), }, ) return db_chat_step except Exception as e: session.rollback() logger.error( "Chat step update failed", extra={"user_id": user_id, "step_id": step_id, "error": str(e)}, exc_info=True ) raise HTTPException(status_code=500, detail="Internal server error") @router.delete("/steps/{step_id}", name="delete chat step") async def delete_chat_step(step_id: int, session: Session = Depends(session), auth: Auth = Depends(auth_must)): """Delete chat step.""" user_id = auth.user.id db_chat_step = session.get(ChatStep, step_id) if not db_chat_step: logger.warning("Chat step not found for deletion", extra={"user_id": user_id, "step_id": step_id}) raise HTTPException(status_code=404, detail=_("Chat step not found")) try: session.delete(db_chat_step) session.commit() logger.info( "Chat step deleted", extra={"user_id": user_id, "step_id": step_id, "task_id": db_chat_step.task_id} ) return Response(status_code=204) except Exception as e: session.rollback() logger.error( "Chat step deletion failed", extra={"user_id": user_id, "step_id": step_id, "error": str(e)}, exc_info=True ) raise HTTPException(status_code=500, detail="Internal server error")