feat: message counting for chat sessions (#430)

This commit is contained in:
Fauzira Alpiandi 2026-01-30 09:00:22 +07:00 committed by GitHub
parent 5621066123
commit 9adf70d18d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 69 additions and 33 deletions

View file

@ -13,6 +13,7 @@ from open_notebook.exceptions import (
NotFoundError,
)
from open_notebook.graphs.chat import graph as chat_graph
from open_notebook.utils.graph_utils import get_session_message_count
router = APIRouter()
@ -102,20 +103,28 @@ async def get_sessions(notebook_id: str = Query(..., description="Notebook ID"))
raise HTTPException(status_code=404, detail="Notebook not found")
# Get sessions for this notebook
sessions = await notebook.get_chat_sessions()
sessions_list = await notebook.get_chat_sessions()
return [
ChatSessionResponse(
id=session.id or "",
title=session.title or "Untitled Session",
notebook_id=notebook_id,
created=str(session.created),
updated=str(session.updated),
message_count=0, # TODO: Add message count if needed
model_override=getattr(session, "model_override", None),
results = []
for session in sessions_list:
session_id = str(session.id)
# Get message count from LangGraph state
msg_count = await get_session_message_count(chat_graph, session_id)
results.append(
ChatSessionResponse(
id=session.id or "",
title=session.title or "Untitled Session",
notebook_id=notebook_id,
created=str(session.created),
updated=str(session.updated),
message_count=msg_count,
model_override=getattr(session, "model_override", None),
)
)
for session in sessions
]
return results
except NotFoundError:
raise HTTPException(status_code=404, detail="Notebook not found")
except Exception as e:
@ -181,8 +190,8 @@ async def get_session(session_id: str):
raise HTTPException(status_code=404, detail="Session not found")
# Get session state from LangGraph to retrieve messages
thread_state = chat_graph.get_state(
config=RunnableConfig(configurable={"thread_id": session_id})
thread_state = await chat_graph.aget_state(
config=RunnableConfig(configurable={"thread_id": full_session_id})
)
# Extract messages from state
@ -273,13 +282,16 @@ async def update_session(session_id: str, request: UpdateSessionRequest):
)
notebook_id = notebook_query[0]["out"] if notebook_query else None
# Get message count from LangGraph state
msg_count = await get_session_message_count(chat_graph, full_session_id)
return ChatSessionResponse(
id=session.id or "",
title=session.title or "",
notebook_id=notebook_id,
created=str(session.created),
updated=str(session.updated),
message_count=0,
message_count=msg_count,
model_override=session.model_override,
)
except NotFoundError:
@ -336,8 +348,8 @@ async def execute_chat(request: ExecuteChatRequest):
)
# Get current state
current_state = chat_graph.get_state(
config=RunnableConfig(configurable={"thread_id": request.session_id})
current_state = await chat_graph.aget_state(
config=RunnableConfig(configurable={"thread_id": full_session_id})
)
# Prepare state for execution
@ -357,7 +369,7 @@ async def execute_chat(request: ExecuteChatRequest):
input=state_values, # type: ignore[arg-type]
config=RunnableConfig(
configurable={
"thread_id": request.session_id,
"thread_id": full_session_id,
"model_id": model_override,
}
),