import asyncio import json from typing import AsyncGenerator, List, Optional from fastapi import APIRouter, HTTPException, Path from fastapi.responses import StreamingResponse from langchain_core.messages import HumanMessage from langchain_core.runnables import RunnableConfig from loguru import logger from pydantic import BaseModel, Field from open_notebook.database.repository import ensure_record_id, repo_query from open_notebook.domain.notebook import ChatSession, Source from open_notebook.exceptions import ( NotFoundError, ) from open_notebook.graphs.source_chat import source_chat_graph as source_chat_graph router = APIRouter() # Request/Response models class CreateSourceChatSessionRequest(BaseModel): source_id: str = Field(..., description="Source ID to create chat session for") title: Optional[str] = Field(None, description="Optional session title") model_override: Optional[str] = Field( None, description="Optional model override for this session" ) class UpdateSourceChatSessionRequest(BaseModel): title: Optional[str] = Field(None, description="New session title") model_override: Optional[str] = Field( None, description="Model override for this session" ) class ChatMessage(BaseModel): id: str = Field(..., description="Message ID") type: str = Field(..., description="Message type (human|ai)") content: str = Field(..., description="Message content") timestamp: Optional[str] = Field(None, description="Message timestamp") class ContextIndicator(BaseModel): sources: List[str] = Field( default_factory=list, description="Source IDs used in context" ) insights: List[str] = Field( default_factory=list, description="Insight IDs used in context" ) notes: List[str] = Field( default_factory=list, description="Note IDs used in context" ) class SourceChatSessionResponse(BaseModel): id: str = Field(..., description="Session ID") title: str = Field(..., description="Session title") source_id: str = Field(..., description="Source ID") model_override: Optional[str] = Field( None, description="Model override for this session" ) created: str = Field(..., description="Creation timestamp") updated: str = Field(..., description="Last update timestamp") message_count: Optional[int] = Field( None, description="Number of messages in session" ) class SourceChatSessionWithMessagesResponse(SourceChatSessionResponse): messages: List[ChatMessage] = Field( default_factory=list, description="Session messages" ) context_indicators: Optional[ContextIndicator] = Field( None, description="Context indicators from last response" ) class SendMessageRequest(BaseModel): message: str = Field(..., description="User message content") model_override: Optional[str] = Field( None, description="Optional model override for this message" ) class SuccessResponse(BaseModel): success: bool = Field(True, description="Operation success status") message: str = Field(..., description="Success message") @router.post( "/sources/{source_id}/chat/sessions", response_model=SourceChatSessionResponse ) async def create_source_chat_session( request: CreateSourceChatSessionRequest, source_id: str = Path(..., description="Source ID"), ): """Create a new chat session for a source.""" try: # Verify source exists full_source_id = ( source_id if source_id.startswith("source:") else f"source:{source_id}" ) source = await Source.get(full_source_id) if not source: raise HTTPException(status_code=404, detail="Source not found") # Create new session with model_override support session = ChatSession( title=request.title or f"Source Chat {asyncio.get_event_loop().time():.0f}", model_override=request.model_override, ) await session.save() # Relate session to source using "refers_to" relation await session.relate("refers_to", full_source_id) return SourceChatSessionResponse( id=session.id or "", title=session.title or "Untitled Session", source_id=source_id, model_override=session.model_override, created=str(session.created), updated=str(session.updated), message_count=0, ) except NotFoundError: raise HTTPException(status_code=404, detail="Source not found") except Exception as e: logger.error(f"Error creating source chat session: {str(e)}") raise HTTPException( status_code=500, detail=f"Error creating source chat session: {str(e)}" ) @router.get( "/sources/{source_id}/chat/sessions", response_model=List[SourceChatSessionResponse] ) async def get_source_chat_sessions(source_id: str = Path(..., description="Source ID")): """Get all chat sessions for a source.""" try: # Verify source exists full_source_id = ( source_id if source_id.startswith("source:") else f"source:{source_id}" ) source = await Source.get(full_source_id) if not source: raise HTTPException(status_code=404, detail="Source not found") # Get sessions that refer to this source - first get relations, then sessions relations = await repo_query( "SELECT in FROM refers_to WHERE out = $source_id", {"source_id": ensure_record_id(full_source_id)}, ) sessions = [] for relation in relations: session_id = relation.get("in") if session_id: session_result = await repo_query(f"SELECT * FROM {session_id}") if session_result and len(session_result) > 0: session_data = session_result[0] sessions.append( SourceChatSessionResponse( id=session_data.get("id") or "", title=session_data.get("title") or "Untitled Session", source_id=source_id, model_override=session_data.get("model_override"), created=str(session_data.get("created")), updated=str(session_data.get("updated")), message_count=0, # TODO: Add message count if needed ) ) # Sort sessions by created date (newest first) sessions.sort(key=lambda x: x.created, reverse=True) return sessions except NotFoundError: raise HTTPException(status_code=404, detail="Source not found") except Exception as e: logger.error(f"Error fetching source chat sessions: {str(e)}") raise HTTPException( status_code=500, detail=f"Error fetching source chat sessions: {str(e)}" ) @router.get( "/sources/{source_id}/chat/sessions/{session_id}", response_model=SourceChatSessionWithMessagesResponse, ) async def get_source_chat_session( source_id: str = Path(..., description="Source ID"), session_id: str = Path(..., description="Session ID"), ): """Get a specific source chat session with its messages.""" try: # Verify source exists full_source_id = ( source_id if source_id.startswith("source:") else f"source:{source_id}" ) source = await Source.get(full_source_id) if not source: raise HTTPException(status_code=404, detail="Source not found") # Get session full_session_id = ( session_id if session_id.startswith("chat_session:") else f"chat_session:{session_id}" ) session = await ChatSession.get(full_session_id) if not session: raise HTTPException(status_code=404, detail="Session not found") # Verify session is related to this source relation_query = await repo_query( "SELECT * FROM refers_to WHERE in = $session_id AND out = $source_id", { "session_id": ensure_record_id(full_session_id), "source_id": ensure_record_id(full_source_id), }, ) if not relation_query: raise HTTPException( status_code=404, detail="Session not found for this source" ) # Get session state from LangGraph to retrieve messages thread_state = source_chat_graph.get_state( config=RunnableConfig(configurable={"thread_id": session_id}) ) # Extract messages from state messages: list[ChatMessage] = [] context_indicators = None if thread_state and thread_state.values: # Extract messages if "messages" in thread_state.values: for msg in thread_state.values["messages"]: messages.append( ChatMessage( id=getattr(msg, "id", f"msg_{len(messages)}"), type=msg.type if hasattr(msg, "type") else "unknown", content=msg.content if hasattr(msg, "content") else str(msg), timestamp=None, # LangChain messages don't have timestamps by default ) ) # Extract context indicators from the last state if "context_indicators" in thread_state.values: context_data = thread_state.values["context_indicators"] context_indicators = ContextIndicator( sources=context_data.get("sources", []), insights=context_data.get("insights", []), notes=context_data.get("notes", []), ) return SourceChatSessionWithMessagesResponse( id=session.id or "", title=session.title or "Untitled Session", source_id=source_id, model_override=getattr(session, "model_override", None), created=str(session.created), updated=str(session.updated), message_count=len(messages), messages=messages, context_indicators=context_indicators, ) except NotFoundError: raise HTTPException(status_code=404, detail="Source or session not found") except Exception as e: logger.error(f"Error fetching source chat session: {str(e)}") raise HTTPException( status_code=500, detail=f"Error fetching source chat session: {str(e)}" ) @router.put( "/sources/{source_id}/chat/sessions/{session_id}", response_model=SourceChatSessionResponse, ) async def update_source_chat_session( request: UpdateSourceChatSessionRequest, source_id: str = Path(..., description="Source ID"), session_id: str = Path(..., description="Session ID"), ): """Update source chat session title and/or model override.""" try: # Verify source exists full_source_id = ( source_id if source_id.startswith("source:") else f"source:{source_id}" ) source = await Source.get(full_source_id) if not source: raise HTTPException(status_code=404, detail="Source not found") # Get session full_session_id = ( session_id if session_id.startswith("chat_session:") else f"chat_session:{session_id}" ) session = await ChatSession.get(full_session_id) if not session: raise HTTPException(status_code=404, detail="Session not found") # Verify session is related to this source relation_query = await repo_query( "SELECT * FROM refers_to WHERE in = $session_id AND out = $source_id", { "session_id": ensure_record_id(full_session_id), "source_id": ensure_record_id(full_source_id), }, ) if not relation_query: raise HTTPException( status_code=404, detail="Session not found for this source" ) # Update session fields if request.title is not None: session.title = request.title if request.model_override is not None: session.model_override = request.model_override await session.save() return SourceChatSessionResponse( id=session.id or "", title=session.title or "Untitled Session", source_id=source_id, model_override=getattr(session, "model_override", None), created=str(session.created), updated=str(session.updated), message_count=0, ) except NotFoundError: raise HTTPException(status_code=404, detail="Source or session not found") except Exception as e: logger.error(f"Error updating source chat session: {str(e)}") raise HTTPException( status_code=500, detail=f"Error updating source chat session: {str(e)}" ) @router.delete( "/sources/{source_id}/chat/sessions/{session_id}", response_model=SuccessResponse ) async def delete_source_chat_session( source_id: str = Path(..., description="Source ID"), session_id: str = Path(..., description="Session ID"), ): """Delete a source chat session.""" try: # Verify source exists full_source_id = ( source_id if source_id.startswith("source:") else f"source:{source_id}" ) source = await Source.get(full_source_id) if not source: raise HTTPException(status_code=404, detail="Source not found") # Get session full_session_id = ( session_id if session_id.startswith("chat_session:") else f"chat_session:{session_id}" ) session = await ChatSession.get(full_session_id) if not session: raise HTTPException(status_code=404, detail="Session not found") # Verify session is related to this source relation_query = await repo_query( "SELECT * FROM refers_to WHERE in = $session_id AND out = $source_id", { "session_id": ensure_record_id(full_session_id), "source_id": ensure_record_id(full_source_id), }, ) if not relation_query: raise HTTPException( status_code=404, detail="Session not found for this source" ) await session.delete() return SuccessResponse( success=True, message="Source chat session deleted successfully" ) except NotFoundError: raise HTTPException(status_code=404, detail="Source or session not found") except Exception as e: logger.error(f"Error deleting source chat session: {str(e)}") raise HTTPException( status_code=500, detail=f"Error deleting source chat session: {str(e)}" ) async def stream_source_chat_response( session_id: str, source_id: str, message: str, model_override: Optional[str] = None ) -> AsyncGenerator[str, None]: """Stream the source chat response as Server-Sent Events.""" try: # Get current state current_state = source_chat_graph.get_state( config=RunnableConfig(configurable={"thread_id": session_id}) ) # Prepare state for execution state_values = current_state.values if current_state else {} state_values["messages"] = state_values.get("messages", []) state_values["source_id"] = source_id state_values["model_override"] = model_override # Add user message to state user_message = HumanMessage(content=message) state_values["messages"].append(user_message) # Send user message event user_event = {"type": "user_message", "content": message, "timestamp": None} yield f"data: {json.dumps(user_event)}\n\n" # Execute source chat graph synchronously (like notebook chat does) result = source_chat_graph.invoke( input=state_values, # type: ignore[arg-type] config=RunnableConfig( configurable={"thread_id": session_id, "model_id": model_override} ), ) # Stream the complete AI response if "messages" in result: for msg in result["messages"]: if hasattr(msg, "type") and msg.type == "ai": ai_event = { "type": "ai_message", "content": msg.content if hasattr(msg, "content") else str(msg), "timestamp": None, } yield f"data: {json.dumps(ai_event)}\n\n" # Stream context indicators if "context_indicators" in result: context_event = { "type": "context_indicators", "data": result["context_indicators"], } yield f"data: {json.dumps(context_event)}\n\n" # Send completion signal completion_event = {"type": "complete"} yield f"data: {json.dumps(completion_event)}\n\n" except Exception as e: logger.error(f"Error in source chat streaming: {str(e)}") error_event = {"type": "error", "message": str(e)} yield f"data: {json.dumps(error_event)}\n\n" @router.post("/sources/{source_id}/chat/sessions/{session_id}/messages") async def send_message_to_source_chat( request: SendMessageRequest, source_id: str = Path(..., description="Source ID"), session_id: str = Path(..., description="Session ID"), ): """Send a message to source chat session with SSE streaming response.""" try: # Verify source exists full_source_id = ( source_id if source_id.startswith("source:") else f"source:{source_id}" ) source = await Source.get(full_source_id) if not source: raise HTTPException(status_code=404, detail="Source not found") # Verify session exists and is related to source full_session_id = ( session_id if session_id.startswith("chat_session:") else f"chat_session:{session_id}" ) session = await ChatSession.get(full_session_id) if not session: raise HTTPException(status_code=404, detail="Session not found") # Verify session is related to this source relation_query = await repo_query( "SELECT * FROM refers_to WHERE in = $session_id AND out = $source_id", { "session_id": ensure_record_id(full_session_id), "source_id": ensure_record_id(full_source_id), }, ) if not relation_query: raise HTTPException( status_code=404, detail="Session not found for this source" ) if not request.message: raise HTTPException(status_code=400, detail="Message content is required") # Determine model override (request override takes precedence over session override) model_override = request.model_override or getattr( session, "model_override", None ) # Update session timestamp await session.save() # Return streaming response return StreamingResponse( stream_source_chat_response( session_id=session_id, source_id=full_source_id, message=request.message, model_override=model_override, ), media_type="text/plain", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "Content-Type": "text/plain; charset=utf-8", }, ) except HTTPException: raise except Exception as e: logger.error(f"Error sending message to source chat: {str(e)}") raise HTTPException(status_code=500, detail=f"Error sending message: {str(e)}")