mirror of
https://github.com/lfnovo/open-notebook.git
synced 2026-05-02 21:30:38 +00:00
feat: message counting for chat sessions (#430)
This commit is contained in:
parent
5621066123
commit
9adf70d18d
3 changed files with 69 additions and 33 deletions
|
|
@ -13,6 +13,7 @@ from open_notebook.exceptions import (
|
|||
NotFoundError,
|
||||
)
|
||||
from open_notebook.graphs.chat import graph as chat_graph
|
||||
from open_notebook.utils.graph_utils import get_session_message_count
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -102,20 +103,28 @@ async def get_sessions(notebook_id: str = Query(..., description="Notebook ID"))
|
|||
raise HTTPException(status_code=404, detail="Notebook not found")
|
||||
|
||||
# Get sessions for this notebook
|
||||
sessions = await notebook.get_chat_sessions()
|
||||
sessions_list = await notebook.get_chat_sessions()
|
||||
|
||||
return [
|
||||
ChatSessionResponse(
|
||||
id=session.id or "",
|
||||
title=session.title or "Untitled Session",
|
||||
notebook_id=notebook_id,
|
||||
created=str(session.created),
|
||||
updated=str(session.updated),
|
||||
message_count=0, # TODO: Add message count if needed
|
||||
model_override=getattr(session, "model_override", None),
|
||||
results = []
|
||||
for session in sessions_list:
|
||||
session_id = str(session.id)
|
||||
|
||||
# Get message count from LangGraph state
|
||||
msg_count = await get_session_message_count(chat_graph, session_id)
|
||||
|
||||
results.append(
|
||||
ChatSessionResponse(
|
||||
id=session.id or "",
|
||||
title=session.title or "Untitled Session",
|
||||
notebook_id=notebook_id,
|
||||
created=str(session.created),
|
||||
updated=str(session.updated),
|
||||
message_count=msg_count,
|
||||
model_override=getattr(session, "model_override", None),
|
||||
)
|
||||
)
|
||||
for session in sessions
|
||||
]
|
||||
|
||||
return results
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Notebook not found")
|
||||
except Exception as e:
|
||||
|
|
@ -181,8 +190,8 @@ async def get_session(session_id: str):
|
|||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# Get session state from LangGraph to retrieve messages
|
||||
thread_state = chat_graph.get_state(
|
||||
config=RunnableConfig(configurable={"thread_id": session_id})
|
||||
thread_state = await chat_graph.aget_state(
|
||||
config=RunnableConfig(configurable={"thread_id": full_session_id})
|
||||
)
|
||||
|
||||
# Extract messages from state
|
||||
|
|
@ -273,13 +282,16 @@ async def update_session(session_id: str, request: UpdateSessionRequest):
|
|||
)
|
||||
notebook_id = notebook_query[0]["out"] if notebook_query else None
|
||||
|
||||
# Get message count from LangGraph state
|
||||
msg_count = await get_session_message_count(chat_graph, full_session_id)
|
||||
|
||||
return ChatSessionResponse(
|
||||
id=session.id or "",
|
||||
title=session.title or "",
|
||||
notebook_id=notebook_id,
|
||||
created=str(session.created),
|
||||
updated=str(session.updated),
|
||||
message_count=0,
|
||||
message_count=msg_count,
|
||||
model_override=session.model_override,
|
||||
)
|
||||
except NotFoundError:
|
||||
|
|
@ -336,8 +348,8 @@ async def execute_chat(request: ExecuteChatRequest):
|
|||
)
|
||||
|
||||
# Get current state
|
||||
current_state = chat_graph.get_state(
|
||||
config=RunnableConfig(configurable={"thread_id": request.session_id})
|
||||
current_state = await chat_graph.aget_state(
|
||||
config=RunnableConfig(configurable={"thread_id": full_session_id})
|
||||
)
|
||||
|
||||
# Prepare state for execution
|
||||
|
|
@ -357,7 +369,7 @@ async def execute_chat(request: ExecuteChatRequest):
|
|||
input=state_values, # type: ignore[arg-type]
|
||||
config=RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": request.session_id,
|
||||
"thread_id": full_session_id,
|
||||
"model_id": model_override,
|
||||
}
|
||||
),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue