mirror of
https://github.com/lfnovo/open-notebook.git
synced 2026-05-03 05:40:36 +00:00
feat: message counting for chat sessions (#430)
This commit is contained in:
parent
5621066123
commit
9adf70d18d
3 changed files with 69 additions and 33 deletions
|
|
@ -13,6 +13,7 @@ from open_notebook.exceptions import (
|
|||
NotFoundError,
|
||||
)
|
||||
from open_notebook.graphs.chat import graph as chat_graph
|
||||
from open_notebook.utils.graph_utils import get_session_message_count
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -102,20 +103,28 @@ async def get_sessions(notebook_id: str = Query(..., description="Notebook ID"))
|
|||
raise HTTPException(status_code=404, detail="Notebook not found")
|
||||
|
||||
# Get sessions for this notebook
|
||||
sessions = await notebook.get_chat_sessions()
|
||||
sessions_list = await notebook.get_chat_sessions()
|
||||
|
||||
return [
|
||||
ChatSessionResponse(
|
||||
id=session.id or "",
|
||||
title=session.title or "Untitled Session",
|
||||
notebook_id=notebook_id,
|
||||
created=str(session.created),
|
||||
updated=str(session.updated),
|
||||
message_count=0, # TODO: Add message count if needed
|
||||
model_override=getattr(session, "model_override", None),
|
||||
results = []
|
||||
for session in sessions_list:
|
||||
session_id = str(session.id)
|
||||
|
||||
# Get message count from LangGraph state
|
||||
msg_count = await get_session_message_count(chat_graph, session_id)
|
||||
|
||||
results.append(
|
||||
ChatSessionResponse(
|
||||
id=session.id or "",
|
||||
title=session.title or "Untitled Session",
|
||||
notebook_id=notebook_id,
|
||||
created=str(session.created),
|
||||
updated=str(session.updated),
|
||||
message_count=msg_count,
|
||||
model_override=getattr(session, "model_override", None),
|
||||
)
|
||||
)
|
||||
for session in sessions
|
||||
]
|
||||
|
||||
return results
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Notebook not found")
|
||||
except Exception as e:
|
||||
|
|
@ -181,8 +190,8 @@ async def get_session(session_id: str):
|
|||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# Get session state from LangGraph to retrieve messages
|
||||
thread_state = chat_graph.get_state(
|
||||
config=RunnableConfig(configurable={"thread_id": session_id})
|
||||
thread_state = await chat_graph.aget_state(
|
||||
config=RunnableConfig(configurable={"thread_id": full_session_id})
|
||||
)
|
||||
|
||||
# Extract messages from state
|
||||
|
|
@ -273,13 +282,16 @@ async def update_session(session_id: str, request: UpdateSessionRequest):
|
|||
)
|
||||
notebook_id = notebook_query[0]["out"] if notebook_query else None
|
||||
|
||||
# Get message count from LangGraph state
|
||||
msg_count = await get_session_message_count(chat_graph, full_session_id)
|
||||
|
||||
return ChatSessionResponse(
|
||||
id=session.id or "",
|
||||
title=session.title or "",
|
||||
notebook_id=notebook_id,
|
||||
created=str(session.created),
|
||||
updated=str(session.updated),
|
||||
message_count=0,
|
||||
message_count=msg_count,
|
||||
model_override=session.model_override,
|
||||
)
|
||||
except NotFoundError:
|
||||
|
|
@ -336,8 +348,8 @@ async def execute_chat(request: ExecuteChatRequest):
|
|||
)
|
||||
|
||||
# Get current state
|
||||
current_state = chat_graph.get_state(
|
||||
config=RunnableConfig(configurable={"thread_id": request.session_id})
|
||||
current_state = await chat_graph.aget_state(
|
||||
config=RunnableConfig(configurable={"thread_id": full_session_id})
|
||||
)
|
||||
|
||||
# Prepare state for execution
|
||||
|
|
@ -357,7 +369,7 @@ async def execute_chat(request: ExecuteChatRequest):
|
|||
input=state_values, # type: ignore[arg-type]
|
||||
config=RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": request.session_id,
|
||||
"thread_id": full_session_id,
|
||||
"model_id": model_override,
|
||||
}
|
||||
),
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from open_notebook.exceptions import (
|
|||
NotFoundError,
|
||||
)
|
||||
from open_notebook.graphs.source_chat import source_chat_graph as source_chat_graph
|
||||
from open_notebook.utils.graph_utils import get_session_message_count
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -27,14 +28,12 @@ class CreateSourceChatSessionRequest(BaseModel):
|
|||
None, description="Optional model override for this session"
|
||||
)
|
||||
|
||||
|
||||
class UpdateSourceChatSessionRequest(BaseModel):
|
||||
title: Optional[str] = Field(None, description="New session title")
|
||||
model_override: Optional[str] = Field(
|
||||
None, description="Model override for this session"
|
||||
)
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
id: str = Field(..., description="Message ID")
|
||||
type: str = Field(..., description="Message type (human|ai)")
|
||||
|
|
@ -53,7 +52,6 @@ class ContextIndicator(BaseModel):
|
|||
default_factory=list, description="Note IDs used in context"
|
||||
)
|
||||
|
||||
|
||||
class SourceChatSessionResponse(BaseModel):
|
||||
id: str = Field(..., description="Session ID")
|
||||
title: str = Field(..., description="Session title")
|
||||
|
|
@ -67,7 +65,6 @@ class SourceChatSessionResponse(BaseModel):
|
|||
None, description="Number of messages in session"
|
||||
)
|
||||
|
||||
|
||||
class SourceChatSessionWithMessagesResponse(SourceChatSessionResponse):
|
||||
messages: List[ChatMessage] = Field(
|
||||
default_factory=list, description="Session messages"
|
||||
|
|
@ -76,14 +73,12 @@ class SourceChatSessionWithMessagesResponse(SourceChatSessionResponse):
|
|||
None, description="Context indicators from last response"
|
||||
)
|
||||
|
||||
|
||||
class SendMessageRequest(BaseModel):
|
||||
message: str = Field(..., description="User message content")
|
||||
model_override: Optional[str] = Field(
|
||||
None, description="Optional model override for this message"
|
||||
)
|
||||
|
||||
|
||||
class SuccessResponse(BaseModel):
|
||||
success: bool = Field(True, description="Operation success status")
|
||||
message: str = Field(..., description="Success message")
|
||||
|
|
@ -156,11 +151,19 @@ async def get_source_chat_sessions(source_id: str = Path(..., description="Sourc
|
|||
|
||||
sessions = []
|
||||
for relation in relations:
|
||||
session_id = relation.get("in")
|
||||
if session_id:
|
||||
session_result = await repo_query(f"SELECT * FROM {session_id}")
|
||||
session_id_raw = relation.get("in")
|
||||
if session_id_raw:
|
||||
session_id = str(session_id_raw)
|
||||
|
||||
session_result = await repo_query(f"SELECT * FROM {session_id_raw}")
|
||||
if session_result and len(session_result) > 0:
|
||||
session_data = session_result[0]
|
||||
|
||||
# Get message count from LangGraph state
|
||||
msg_count = await get_session_message_count(
|
||||
source_chat_graph, session_id
|
||||
)
|
||||
|
||||
sessions.append(
|
||||
SourceChatSessionResponse(
|
||||
id=session_data.get("id") or "",
|
||||
|
|
@ -169,7 +172,7 @@ async def get_source_chat_sessions(source_id: str = Path(..., description="Sourc
|
|||
model_override=session_data.get("model_override"),
|
||||
created=str(session_data.get("created")),
|
||||
updated=str(session_data.get("updated")),
|
||||
message_count=0, # TODO: Add message count if needed
|
||||
message_count=msg_count,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -228,8 +231,8 @@ async def get_source_chat_session(
|
|||
)
|
||||
|
||||
# Get session state from LangGraph to retrieve messages
|
||||
thread_state = source_chat_graph.get_state(
|
||||
config=RunnableConfig(configurable={"thread_id": session_id})
|
||||
thread_state = await source_chat_graph.aget_state(
|
||||
config=RunnableConfig(configurable={"thread_id": full_session_id})
|
||||
)
|
||||
|
||||
# Extract messages from state
|
||||
|
|
@ -331,6 +334,9 @@ async def update_source_chat_session(
|
|||
|
||||
await session.save()
|
||||
|
||||
# Get message count from LangGraph state
|
||||
msg_count = await get_session_message_count(source_chat_graph, full_session_id)
|
||||
|
||||
return SourceChatSessionResponse(
|
||||
id=session.id or "",
|
||||
title=session.title or "Untitled Session",
|
||||
|
|
@ -338,7 +344,7 @@ async def update_source_chat_session(
|
|||
model_override=getattr(session, "model_override", None),
|
||||
created=str(session.created),
|
||||
updated=str(session.updated),
|
||||
message_count=0,
|
||||
message_count=msg_count,
|
||||
)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Source or session not found")
|
||||
|
|
@ -410,7 +416,7 @@ async def stream_source_chat_response(
|
|||
"""Stream the source chat response as Server-Sent Events."""
|
||||
try:
|
||||
# Get current state
|
||||
current_state = source_chat_graph.get_state(
|
||||
current_state = await source_chat_graph.aget_state(
|
||||
config=RunnableConfig(configurable={"thread_id": session_id})
|
||||
)
|
||||
|
||||
|
|
@ -519,7 +525,7 @@ async def send_message_to_source_chat(
|
|||
# Return streaming response
|
||||
return StreamingResponse(
|
||||
stream_source_chat_response(
|
||||
session_id=session_id,
|
||||
session_id=full_session_id,
|
||||
source_id=full_source_id,
|
||||
message=request.message,
|
||||
model_override=model_override,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue