diff --git a/api/routers/chat.py b/api/routers/chat.py index c67041d..f4523be 100644 --- a/api/routers/chat.py +++ b/api/routers/chat.py @@ -13,6 +13,7 @@ from open_notebook.exceptions import ( NotFoundError, ) from open_notebook.graphs.chat import graph as chat_graph +from open_notebook.utils.graph_utils import get_session_message_count router = APIRouter() @@ -102,20 +103,28 @@ async def get_sessions(notebook_id: str = Query(..., description="Notebook ID")) raise HTTPException(status_code=404, detail="Notebook not found") # Get sessions for this notebook - sessions = await notebook.get_chat_sessions() + sessions_list = await notebook.get_chat_sessions() - return [ - ChatSessionResponse( - id=session.id or "", - title=session.title or "Untitled Session", - notebook_id=notebook_id, - created=str(session.created), - updated=str(session.updated), - message_count=0, # TODO: Add message count if needed - model_override=getattr(session, "model_override", None), + results = [] + for session in sessions_list: + session_id = str(session.id) + + # Get message count from LangGraph state + msg_count = await get_session_message_count(chat_graph, session_id) + + results.append( + ChatSessionResponse( + id=session.id or "", + title=session.title or "Untitled Session", + notebook_id=notebook_id, + created=str(session.created), + updated=str(session.updated), + message_count=msg_count, + model_override=getattr(session, "model_override", None), + ) ) - for session in sessions - ] + + return results except NotFoundError: raise HTTPException(status_code=404, detail="Notebook not found") except Exception as e: @@ -181,8 +190,8 @@ async def get_session(session_id: str): raise HTTPException(status_code=404, detail="Session not found") # Get session state from LangGraph to retrieve messages - thread_state = chat_graph.get_state( - config=RunnableConfig(configurable={"thread_id": session_id}) + thread_state = await chat_graph.aget_state( + config=RunnableConfig(configurable={"thread_id": full_session_id}) ) # Extract messages from state @@ -273,13 +282,16 @@ async def update_session(session_id: str, request: UpdateSessionRequest): ) notebook_id = notebook_query[0]["out"] if notebook_query else None + # Get message count from LangGraph state + msg_count = await get_session_message_count(chat_graph, full_session_id) + return ChatSessionResponse( id=session.id or "", title=session.title or "", notebook_id=notebook_id, created=str(session.created), updated=str(session.updated), - message_count=0, + message_count=msg_count, model_override=session.model_override, ) except NotFoundError: @@ -336,8 +348,8 @@ async def execute_chat(request: ExecuteChatRequest): ) # Get current state - current_state = chat_graph.get_state( - config=RunnableConfig(configurable={"thread_id": request.session_id}) + current_state = await chat_graph.aget_state( + config=RunnableConfig(configurable={"thread_id": full_session_id}) ) # Prepare state for execution @@ -357,7 +369,7 @@ async def execute_chat(request: ExecuteChatRequest): input=state_values, # type: ignore[arg-type] config=RunnableConfig( configurable={ - "thread_id": request.session_id, + "thread_id": full_session_id, "model_id": model_override, } ), diff --git a/api/routers/source_chat.py b/api/routers/source_chat.py index ddda4e1..779657b 100644 --- a/api/routers/source_chat.py +++ b/api/routers/source_chat.py @@ -15,6 +15,7 @@ from open_notebook.exceptions import ( NotFoundError, ) from open_notebook.graphs.source_chat import source_chat_graph as source_chat_graph +from open_notebook.utils.graph_utils import get_session_message_count router = APIRouter() @@ -27,14 +28,12 @@ class CreateSourceChatSessionRequest(BaseModel): None, description="Optional model override for this session" ) - class UpdateSourceChatSessionRequest(BaseModel): title: Optional[str] = Field(None, description="New session title") model_override: Optional[str] = Field( None, description="Model override for this session" ) - class ChatMessage(BaseModel): id: str = Field(..., description="Message ID") type: str = Field(..., description="Message type (human|ai)") @@ -53,7 +52,6 @@ class ContextIndicator(BaseModel): default_factory=list, description="Note IDs used in context" ) - class SourceChatSessionResponse(BaseModel): id: str = Field(..., description="Session ID") title: str = Field(..., description="Session title") @@ -67,7 +65,6 @@ class SourceChatSessionResponse(BaseModel): None, description="Number of messages in session" ) - class SourceChatSessionWithMessagesResponse(SourceChatSessionResponse): messages: List[ChatMessage] = Field( default_factory=list, description="Session messages" @@ -76,14 +73,12 @@ class SourceChatSessionWithMessagesResponse(SourceChatSessionResponse): None, description="Context indicators from last response" ) - class SendMessageRequest(BaseModel): message: str = Field(..., description="User message content") model_override: Optional[str] = Field( None, description="Optional model override for this message" ) - class SuccessResponse(BaseModel): success: bool = Field(True, description="Operation success status") message: str = Field(..., description="Success message") @@ -156,11 +151,19 @@ async def get_source_chat_sessions(source_id: str = Path(..., description="Sourc sessions = [] for relation in relations: - session_id = relation.get("in") - if session_id: - session_result = await repo_query(f"SELECT * FROM {session_id}") + session_id_raw = relation.get("in") + if session_id_raw: + session_id = str(session_id_raw) + + session_result = await repo_query(f"SELECT * FROM {session_id_raw}") if session_result and len(session_result) > 0: session_data = session_result[0] + + # Get message count from LangGraph state + msg_count = await get_session_message_count( + source_chat_graph, session_id + ) + sessions.append( SourceChatSessionResponse( id=session_data.get("id") or "", @@ -169,7 +172,7 @@ async def get_source_chat_sessions(source_id: str = Path(..., description="Sourc model_override=session_data.get("model_override"), created=str(session_data.get("created")), updated=str(session_data.get("updated")), - message_count=0, # TODO: Add message count if needed + message_count=msg_count, ) ) @@ -228,8 +231,8 @@ async def get_source_chat_session( ) # Get session state from LangGraph to retrieve messages - thread_state = source_chat_graph.get_state( - config=RunnableConfig(configurable={"thread_id": session_id}) + thread_state = await source_chat_graph.aget_state( + config=RunnableConfig(configurable={"thread_id": full_session_id}) ) # Extract messages from state @@ -331,6 +334,9 @@ async def update_source_chat_session( await session.save() + # Get message count from LangGraph state + msg_count = await get_session_message_count(source_chat_graph, full_session_id) + return SourceChatSessionResponse( id=session.id or "", title=session.title or "Untitled Session", @@ -338,7 +344,7 @@ async def update_source_chat_session( model_override=getattr(session, "model_override", None), created=str(session.created), updated=str(session.updated), - message_count=0, + message_count=msg_count, ) except NotFoundError: raise HTTPException(status_code=404, detail="Source or session not found") @@ -410,7 +416,7 @@ async def stream_source_chat_response( """Stream the source chat response as Server-Sent Events.""" try: # Get current state - current_state = source_chat_graph.get_state( + current_state = await source_chat_graph.aget_state( config=RunnableConfig(configurable={"thread_id": session_id}) ) @@ -519,7 +525,7 @@ async def send_message_to_source_chat( # Return streaming response return StreamingResponse( stream_source_chat_response( - session_id=session_id, + session_id=full_session_id, source_id=full_source_id, message=request.message, model_override=model_override, diff --git a/open_notebook/utils/graph_utils.py b/open_notebook/utils/graph_utils.py new file mode 100644 index 0000000..02aae2e --- /dev/null +++ b/open_notebook/utils/graph_utils.py @@ -0,0 +1,18 @@ +from langchain_core.runnables import RunnableConfig +from loguru import logger + +async def get_session_message_count(graph, session_id: str) -> int: + """Get message count from LangGraph state, returns 0 on error.""" + try: + thread_state = await graph.aget_state( # async version + config=RunnableConfig(configurable={"thread_id": session_id}) + ) + if ( + thread_state + and thread_state.values + and "messages" in thread_state.values + ): + return len(thread_state.values["messages"]) + except Exception as e: + logger.warning(f"Could not fetch message count for session {session_id}: {e}") + return 0