feat: message counting for chat sessions (#430)

This commit is contained in:
Fauzira Alpiandi 2026-01-30 09:00:22 +07:00 committed by GitHub
parent 5621066123
commit 9adf70d18d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 69 additions and 33 deletions

View file

@ -13,6 +13,7 @@ from open_notebook.exceptions import (
NotFoundError,
)
from open_notebook.graphs.chat import graph as chat_graph
from open_notebook.utils.graph_utils import get_session_message_count
router = APIRouter()
@ -102,20 +103,28 @@ async def get_sessions(notebook_id: str = Query(..., description="Notebook ID"))
raise HTTPException(status_code=404, detail="Notebook not found")
# Get sessions for this notebook
sessions = await notebook.get_chat_sessions()
sessions_list = await notebook.get_chat_sessions()
return [
ChatSessionResponse(
id=session.id or "",
title=session.title or "Untitled Session",
notebook_id=notebook_id,
created=str(session.created),
updated=str(session.updated),
message_count=0, # TODO: Add message count if needed
model_override=getattr(session, "model_override", None),
results = []
for session in sessions_list:
session_id = str(session.id)
# Get message count from LangGraph state
msg_count = await get_session_message_count(chat_graph, session_id)
results.append(
ChatSessionResponse(
id=session.id or "",
title=session.title or "Untitled Session",
notebook_id=notebook_id,
created=str(session.created),
updated=str(session.updated),
message_count=msg_count,
model_override=getattr(session, "model_override", None),
)
)
for session in sessions
]
return results
except NotFoundError:
raise HTTPException(status_code=404, detail="Notebook not found")
except Exception as e:
@ -181,8 +190,8 @@ async def get_session(session_id: str):
raise HTTPException(status_code=404, detail="Session not found")
# Get session state from LangGraph to retrieve messages
thread_state = chat_graph.get_state(
config=RunnableConfig(configurable={"thread_id": session_id})
thread_state = await chat_graph.aget_state(
config=RunnableConfig(configurable={"thread_id": full_session_id})
)
# Extract messages from state
@ -273,13 +282,16 @@ async def update_session(session_id: str, request: UpdateSessionRequest):
)
notebook_id = notebook_query[0]["out"] if notebook_query else None
# Get message count from LangGraph state
msg_count = await get_session_message_count(chat_graph, full_session_id)
return ChatSessionResponse(
id=session.id or "",
title=session.title or "",
notebook_id=notebook_id,
created=str(session.created),
updated=str(session.updated),
message_count=0,
message_count=msg_count,
model_override=session.model_override,
)
except NotFoundError:
@ -336,8 +348,8 @@ async def execute_chat(request: ExecuteChatRequest):
)
# Get current state
current_state = chat_graph.get_state(
config=RunnableConfig(configurable={"thread_id": request.session_id})
current_state = await chat_graph.aget_state(
config=RunnableConfig(configurable={"thread_id": full_session_id})
)
# Prepare state for execution
@ -357,7 +369,7 @@ async def execute_chat(request: ExecuteChatRequest):
input=state_values, # type: ignore[arg-type]
config=RunnableConfig(
configurable={
"thread_id": request.session_id,
"thread_id": full_session_id,
"model_id": model_override,
}
),

View file

@ -15,6 +15,7 @@ from open_notebook.exceptions import (
NotFoundError,
)
from open_notebook.graphs.source_chat import source_chat_graph as source_chat_graph
from open_notebook.utils.graph_utils import get_session_message_count
router = APIRouter()
@ -27,14 +28,12 @@ class CreateSourceChatSessionRequest(BaseModel):
None, description="Optional model override for this session"
)
class UpdateSourceChatSessionRequest(BaseModel):
title: Optional[str] = Field(None, description="New session title")
model_override: Optional[str] = Field(
None, description="Model override for this session"
)
class ChatMessage(BaseModel):
id: str = Field(..., description="Message ID")
type: str = Field(..., description="Message type (human|ai)")
@ -53,7 +52,6 @@ class ContextIndicator(BaseModel):
default_factory=list, description="Note IDs used in context"
)
class SourceChatSessionResponse(BaseModel):
id: str = Field(..., description="Session ID")
title: str = Field(..., description="Session title")
@ -67,7 +65,6 @@ class SourceChatSessionResponse(BaseModel):
None, description="Number of messages in session"
)
class SourceChatSessionWithMessagesResponse(SourceChatSessionResponse):
messages: List[ChatMessage] = Field(
default_factory=list, description="Session messages"
@ -76,14 +73,12 @@ class SourceChatSessionWithMessagesResponse(SourceChatSessionResponse):
None, description="Context indicators from last response"
)
class SendMessageRequest(BaseModel):
message: str = Field(..., description="User message content")
model_override: Optional[str] = Field(
None, description="Optional model override for this message"
)
class SuccessResponse(BaseModel):
success: bool = Field(True, description="Operation success status")
message: str = Field(..., description="Success message")
@ -156,11 +151,19 @@ async def get_source_chat_sessions(source_id: str = Path(..., description="Sourc
sessions = []
for relation in relations:
session_id = relation.get("in")
if session_id:
session_result = await repo_query(f"SELECT * FROM {session_id}")
session_id_raw = relation.get("in")
if session_id_raw:
session_id = str(session_id_raw)
session_result = await repo_query(f"SELECT * FROM {session_id_raw}")
if session_result and len(session_result) > 0:
session_data = session_result[0]
# Get message count from LangGraph state
msg_count = await get_session_message_count(
source_chat_graph, session_id
)
sessions.append(
SourceChatSessionResponse(
id=session_data.get("id") or "",
@ -169,7 +172,7 @@ async def get_source_chat_sessions(source_id: str = Path(..., description="Sourc
model_override=session_data.get("model_override"),
created=str(session_data.get("created")),
updated=str(session_data.get("updated")),
message_count=0, # TODO: Add message count if needed
message_count=msg_count,
)
)
@ -228,8 +231,8 @@ async def get_source_chat_session(
)
# Get session state from LangGraph to retrieve messages
thread_state = source_chat_graph.get_state(
config=RunnableConfig(configurable={"thread_id": session_id})
thread_state = await source_chat_graph.aget_state(
config=RunnableConfig(configurable={"thread_id": full_session_id})
)
# Extract messages from state
@ -331,6 +334,9 @@ async def update_source_chat_session(
await session.save()
# Get message count from LangGraph state
msg_count = await get_session_message_count(source_chat_graph, full_session_id)
return SourceChatSessionResponse(
id=session.id or "",
title=session.title or "Untitled Session",
@ -338,7 +344,7 @@ async def update_source_chat_session(
model_override=getattr(session, "model_override", None),
created=str(session.created),
updated=str(session.updated),
message_count=0,
message_count=msg_count,
)
except NotFoundError:
raise HTTPException(status_code=404, detail="Source or session not found")
@ -410,7 +416,7 @@ async def stream_source_chat_response(
"""Stream the source chat response as Server-Sent Events."""
try:
# Get current state
current_state = source_chat_graph.get_state(
current_state = await source_chat_graph.aget_state(
config=RunnableConfig(configurable={"thread_id": session_id})
)
@ -519,7 +525,7 @@ async def send_message_to_source_chat(
# Return streaming response
return StreamingResponse(
stream_source_chat_response(
session_id=session_id,
session_id=full_session_id,
source_id=full_source_id,
message=request.message,
model_override=model_override,