# Mirror of https://github.com/lfnovo/open-notebook.git
# Synced 2026-05-03 13:50:31 +00:00 — 545 lines, 20 KiB, Python.
import asyncio
|
|
import json
|
|
from typing import AsyncGenerator, List, Optional
|
|
|
|
from fastapi import APIRouter, HTTPException, Path
|
|
from fastapi.responses import StreamingResponse
|
|
from langchain_core.messages import HumanMessage
|
|
from langchain_core.runnables import RunnableConfig
|
|
from loguru import logger
|
|
from pydantic import BaseModel, Field
|
|
|
|
from open_notebook.database.repository import ensure_record_id, repo_query
|
|
from open_notebook.domain.notebook import ChatSession, Source
|
|
from open_notebook.exceptions import (
|
|
NotFoundError,
|
|
)
|
|
from open_notebook.graphs.source_chat import source_chat_graph as source_chat_graph
|
|
from open_notebook.utils.graph_utils import get_session_message_count
|
|
|
|
# Router on which every source-chat endpoint in this module is registered.
router = APIRouter()
|
|
|
|
|
|
# Request/Response models
|
|
class CreateSourceChatSessionRequest(BaseModel):
    # Body of POST /sources/{source_id}/chat/sessions.
    # `source_id` is required by the schema; `title` and `model_override`
    # are optional and default to None.
    source_id: str = Field(description="Source ID to create chat session for")
    title: Optional[str] = Field(default=None, description="Optional session title")
    model_override: Optional[str] = Field(
        default=None,
        description="Optional model override for this session",
    )
|
|
|
|
class UpdateSourceChatSessionRequest(BaseModel):
    # Partial update of a session; a field left as None means "not provided".
    title: Optional[str] = Field(default=None, description="New session title")
    model_override: Optional[str] = Field(
        default=None,
        description="Model override for this session",
    )
|
|
|
|
class ChatMessage(BaseModel):
    # One message of a session's history as returned to the client.
    id: str = Field(description="Message ID")
    type: str = Field(description="Message type (human|ai)")
    content: str = Field(description="Message content")
    timestamp: Optional[str] = Field(default=None, description="Message timestamp")
|
|
|
|
|
|
class ContextIndicator(BaseModel):
    # IDs of the records that were placed in the model's context.
    # Each list defaults to a fresh empty list per instance.
    sources: List[str] = Field(
        default_factory=list,
        description="Source IDs used in context",
    )
    insights: List[str] = Field(
        default_factory=list,
        description="Insight IDs used in context",
    )
    notes: List[str] = Field(
        default_factory=list,
        description="Note IDs used in context",
    )
|
|
|
|
class SourceChatSessionResponse(BaseModel):
    # Summary of one source chat session (without its messages).
    id: str = Field(description="Session ID")
    title: str = Field(description="Session title")
    source_id: str = Field(description="Source ID")
    model_override: Optional[str] = Field(
        default=None,
        description="Model override for this session",
    )
    # Timestamps are carried as strings.
    created: str = Field(description="Creation timestamp")
    updated: str = Field(description="Last update timestamp")
    message_count: Optional[int] = Field(
        default=None,
        description="Number of messages in session",
    )
|
|
|
|
class SourceChatSessionWithMessagesResponse(SourceChatSessionResponse):
    # Session summary extended with its message history and the context
    # indicators recorded for the most recent response.
    messages: List[ChatMessage] = Field(
        default_factory=list,
        description="Session messages",
    )
    context_indicators: Optional[ContextIndicator] = Field(
        default=None,
        description="Context indicators from last response",
    )
|
|
|
|
class SendMessageRequest(BaseModel):
    # Body of POST .../messages: the user's text plus an optional
    # per-message model override.
    message: str = Field(description="User message content")
    model_override: Optional[str] = Field(
        default=None,
        description="Optional model override for this message",
    )
|
|
|
|
class SuccessResponse(BaseModel):
    # Generic acknowledgement payload; `success` defaults to True.
    success: bool = Field(default=True, description="Operation success status")
    message: str = Field(description="Success message")
|
|
|
|
|
|
@router.post(
    "/sources/{source_id}/chat/sessions", response_model=SourceChatSessionResponse
)
async def create_source_chat_session(
    request: CreateSourceChatSessionRequest,
    source_id: str = Path(..., description="Source ID"),
):
    """Create a new chat session for a source.

    The new ``ChatSession`` is saved and then linked to the source with a
    ``refers_to`` relation.

    Args:
        request: Optional title and model override for the session. Its
            ``source_id`` field is not read; the path parameter is used instead.
        source_id: Source ID, with or without the ``source:`` table prefix.

    Returns:
        SourceChatSessionResponse: The created session, with ``message_count``
        set to 0.

    Raises:
        HTTPException: 404 if the source does not exist; 500 for any other
            failure.
    """
    try:
        # Accept both "abc" and "source:abc" forms of the ID.
        full_source_id = (
            source_id if source_id.startswith("source:") else f"source:{source_id}"
        )
        source = await Source.get(full_source_id)
        if not source:
            raise HTTPException(status_code=404, detail="Source not found")

        # Default title: the running loop's monotonic clock only serves as a
        # distinguishing number here, it is not a wall-clock timestamp.
        session = ChatSession(
            title=request.title
            or f"Source Chat {asyncio.get_running_loop().time():.0f}",
            model_override=request.model_override,
        )
        await session.save()

        # Relate session to source using "refers_to" relation
        await session.relate("refers_to", full_source_id)

        return SourceChatSessionResponse(
            id=session.id or "",
            title=session.title or "Untitled Session",
            source_id=source_id,
            model_override=session.model_override,
            created=str(session.created),
            updated=str(session.updated),
            message_count=0,
        )
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 404 above) must pass through
        # unchanged instead of being re-wrapped as a 500 below.
        raise
    except NotFoundError:
        raise HTTPException(status_code=404, detail="Source not found")
    except Exception as e:
        logger.error(f"Error creating source chat session: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error creating source chat session: {str(e)}"
        ) from e
|
|
|
|
|
|
@router.get(
    "/sources/{source_id}/chat/sessions", response_model=List[SourceChatSessionResponse]
)
async def get_source_chat_sessions(source_id: str = Path(..., description="Source ID")):
    """List the chat sessions linked to a source, newest first.

    Sessions are found through ``refers_to`` relations whose ``out`` is the
    source. Each session's message count is read from the source chat graph's
    saved state.

    Args:
        source_id: Source ID, with or without the ``source:`` table prefix.

    Returns:
        List[SourceChatSessionResponse]: One entry per linked session, sorted by
        the string form of ``created`` in descending order.

    Raises:
        HTTPException: 404 if the source does not exist; 500 for any other
            failure.
    """
    try:
        # Accept both "abc" and "source:abc" forms of the ID.
        full_source_id = (
            source_id if source_id.startswith("source:") else f"source:{source_id}"
        )
        source = await Source.get(full_source_id)
        if not source:
            raise HTTPException(status_code=404, detail="Source not found")

        # Get sessions that refer to this source - first get relations, then sessions
        relations = await repo_query(
            "SELECT in FROM refers_to WHERE out = $source_id",
            {"source_id": ensure_record_id(full_source_id)},
        )

        sessions = []
        for relation in relations:
            session_id_raw = relation.get("in")
            if not session_id_raw:
                continue
            session_id = str(session_id_raw)

            # session_id_raw is a record ID returned by the relation query
            # above, not raw client input, so it is interpolated directly.
            session_result = await repo_query(f"SELECT * FROM {session_id_raw}")
            if not session_result:
                continue
            session_data = session_result[0]

            # Get message count from LangGraph state
            msg_count = await get_session_message_count(source_chat_graph, session_id)

            sessions.append(
                SourceChatSessionResponse(
                    id=session_data.get("id") or "",
                    title=session_data.get("title") or "Untitled Session",
                    source_id=source_id,
                    model_override=session_data.get("model_override"),
                    created=str(session_data.get("created")),
                    updated=str(session_data.get("updated")),
                    message_count=msg_count,
                )
            )

        # Sort sessions by created date (newest first)
        sessions.sort(key=lambda x: x.created, reverse=True)
        return sessions
    except HTTPException:
        # Keep the deliberate 404 above instead of converting it to a 500.
        raise
    except NotFoundError:
        raise HTTPException(status_code=404, detail="Source not found")
    except Exception as e:
        logger.error(f"Error fetching source chat sessions: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error fetching source chat sessions: {str(e)}"
        ) from e
|
|
|
|
|
|
@router.get(
    "/sources/{source_id}/chat/sessions/{session_id}",
    response_model=SourceChatSessionWithMessagesResponse,
)
async def get_source_chat_session(
    source_id: str = Path(..., description="Source ID"),
    session_id: str = Path(..., description="Session ID"),
):
    """Get one source chat session together with its message history.

    Messages and context indicators are read from the source chat graph's saved
    state for the session's thread.

    Args:
        source_id: Source ID, with or without the ``source:`` table prefix.
        session_id: Session ID, with or without the ``chat_session:`` prefix.

    Returns:
        SourceChatSessionWithMessagesResponse: Session metadata, its messages,
        and the context indicators stored in the graph state (if any).

    Raises:
        HTTPException: 404 if the source or session does not exist or the
            session is not linked to the source; 500 for any other failure.
    """
    try:
        # Accept both prefixed and bare IDs.
        full_source_id = (
            source_id if source_id.startswith("source:") else f"source:{source_id}"
        )
        source = await Source.get(full_source_id)
        if not source:
            raise HTTPException(status_code=404, detail="Source not found")

        full_session_id = (
            session_id
            if session_id.startswith("chat_session:")
            else f"chat_session:{session_id}"
        )
        session = await ChatSession.get(full_session_id)
        if not session:
            raise HTTPException(status_code=404, detail="Session not found")

        # Verify session is related to this source
        relation_query = await repo_query(
            "SELECT * FROM refers_to WHERE in = $session_id AND out = $source_id",
            {
                "session_id": ensure_record_id(full_session_id),
                "source_id": ensure_record_id(full_source_id),
            },
        )
        if not relation_query:
            raise HTTPException(
                status_code=404, detail="Session not found for this source"
            )

        # The session ID doubles as the LangGraph thread ID.
        thread_state = await source_chat_graph.aget_state(
            config=RunnableConfig(configurable={"thread_id": full_session_id})
        )

        messages: list[ChatMessage] = []
        context_indicators = None

        if thread_state and thread_state.values:
            for msg in thread_state.values.get("messages") or []:
                content = getattr(msg, "content", None)
                if content is None:
                    content = str(msg)
                elif not isinstance(content, str):
                    # LangChain message content may be a list of content
                    # blocks; stringify it so the str-typed field validates.
                    content = str(content)

                messages.append(
                    ChatMessage(
                        # LangChain messages expose an ``id`` attribute that can
                        # be None; fall back to a positional ID in that case.
                        id=getattr(msg, "id", None) or f"msg_{len(messages)}",
                        type=getattr(msg, "type", "unknown"),
                        content=content,
                        timestamp=None,  # LangChain messages don't have timestamps by default
                    )
                )

            # A stored value of None is treated like a missing entry.
            context_data = thread_state.values.get("context_indicators")
            if context_data is not None:
                context_indicators = ContextIndicator(
                    sources=context_data.get("sources", []),
                    insights=context_data.get("insights", []),
                    notes=context_data.get("notes", []),
                )

        return SourceChatSessionWithMessagesResponse(
            id=session.id or "",
            title=session.title or "Untitled Session",
            source_id=source_id,
            model_override=getattr(session, "model_override", None),
            created=str(session.created),
            updated=str(session.updated),
            message_count=len(messages),
            messages=messages,
            context_indicators=context_indicators,
        )
    except HTTPException:
        # Keep the deliberate 404s above instead of converting them to a 500.
        raise
    except NotFoundError:
        raise HTTPException(status_code=404, detail="Source or session not found")
    except Exception as e:
        logger.error(f"Error fetching source chat session: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error fetching source chat session: {str(e)}"
        ) from e
|
|
|
|
|
|
@router.put(
    "/sources/{source_id}/chat/sessions/{session_id}",
    response_model=SourceChatSessionResponse,
)
async def update_source_chat_session(
    request: UpdateSourceChatSessionRequest,
    source_id: str = Path(..., description="Source ID"),
    session_id: str = Path(..., description="Session ID"),
):
    """Update a source chat session's title and/or model override.

    Only fields that are not None in ``request`` are changed, so a field cannot
    be cleared through this endpoint.

    Args:
        request: New title and/or model override.
        source_id: Source ID, with or without the ``source:`` table prefix.
        session_id: Session ID, with or without the ``chat_session:`` prefix.

    Returns:
        SourceChatSessionResponse: The updated session with its current message
        count.

    Raises:
        HTTPException: 404 if the source or session does not exist or the
            session is not linked to the source; 500 for any other failure.
    """
    try:
        # Accept both prefixed and bare IDs.
        full_source_id = (
            source_id if source_id.startswith("source:") else f"source:{source_id}"
        )
        source = await Source.get(full_source_id)
        if not source:
            raise HTTPException(status_code=404, detail="Source not found")

        full_session_id = (
            session_id
            if session_id.startswith("chat_session:")
            else f"chat_session:{session_id}"
        )
        session = await ChatSession.get(full_session_id)
        if not session:
            raise HTTPException(status_code=404, detail="Session not found")

        # Verify session is related to this source
        relation_query = await repo_query(
            "SELECT * FROM refers_to WHERE in = $session_id AND out = $source_id",
            {
                "session_id": ensure_record_id(full_session_id),
                "source_id": ensure_record_id(full_source_id),
            },
        )
        if not relation_query:
            raise HTTPException(
                status_code=404, detail="Session not found for this source"
            )

        # Update session fields
        if request.title is not None:
            session.title = request.title
        if request.model_override is not None:
            session.model_override = request.model_override

        await session.save()

        # Get message count from LangGraph state
        msg_count = await get_session_message_count(source_chat_graph, full_session_id)

        return SourceChatSessionResponse(
            id=session.id or "",
            title=session.title or "Untitled Session",
            source_id=source_id,
            model_override=getattr(session, "model_override", None),
            created=str(session.created),
            updated=str(session.updated),
            message_count=msg_count,
        )
    except HTTPException:
        # Keep the deliberate 404s above instead of converting them to a 500.
        raise
    except NotFoundError:
        raise HTTPException(status_code=404, detail="Source or session not found")
    except Exception as e:
        logger.error(f"Error updating source chat session: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error updating source chat session: {str(e)}"
        ) from e
|
|
|
|
|
|
@router.delete(
    "/sources/{source_id}/chat/sessions/{session_id}", response_model=SuccessResponse
)
async def delete_source_chat_session(
    source_id: str = Path(..., description="Source ID"),
    session_id: str = Path(..., description="Session ID"),
):
    """Delete a source chat session.

    The session is deleted only after confirming it is linked to the source
    through a ``refers_to`` relation.

    Args:
        source_id: Source ID, with or without the ``source:`` table prefix.
        session_id: Session ID, with or without the ``chat_session:`` prefix.

    Returns:
        SuccessResponse: Confirmation message.

    Raises:
        HTTPException: 404 if the source or session does not exist or the
            session is not linked to the source; 500 for any other failure.
    """
    try:
        # Accept both prefixed and bare IDs.
        full_source_id = (
            source_id if source_id.startswith("source:") else f"source:{source_id}"
        )
        source = await Source.get(full_source_id)
        if not source:
            raise HTTPException(status_code=404, detail="Source not found")

        full_session_id = (
            session_id
            if session_id.startswith("chat_session:")
            else f"chat_session:{session_id}"
        )
        session = await ChatSession.get(full_session_id)
        if not session:
            raise HTTPException(status_code=404, detail="Session not found")

        # Verify session is related to this source
        relation_query = await repo_query(
            "SELECT * FROM refers_to WHERE in = $session_id AND out = $source_id",
            {
                "session_id": ensure_record_id(full_session_id),
                "source_id": ensure_record_id(full_source_id),
            },
        )
        if not relation_query:
            raise HTTPException(
                status_code=404, detail="Session not found for this source"
            )

        await session.delete()

        return SuccessResponse(
            success=True, message="Source chat session deleted successfully"
        )
    except HTTPException:
        # Keep the deliberate 404s above instead of converting them to a 500.
        raise
    except NotFoundError:
        raise HTTPException(status_code=404, detail="Source or session not found")
    except Exception as e:
        logger.error(f"Error deleting source chat session: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error deleting source chat session: {str(e)}"
        ) from e
|
|
|
|
|
|
async def stream_source_chat_response(
    session_id: str, source_id: str, message: str, model_override: Optional[str] = None
) -> AsyncGenerator[str, None]:
    """Run one source-chat turn and yield its events as SSE ``data:`` frames.

    Events are yielded in this order:

    - ``user_message``: an echo of ``message``.
    - ``ai_message``: one per AI message produced in this turn.
    - ``context_indicators``: only when the graph result contains them.
    - ``complete``: the final event.

    If anything fails, an ``error`` event is yielded and the stream ends
    without raising.

    Args:
        session_id: Chat session ID, used as the graph's ``thread_id``.
        source_id: Source ID, stored in the graph state.
        message: The user's message text.
        model_override: Optional model ID. It is passed to the graph config as
            ``model_id`` and stored in the state as ``model_override``.

    Yields:
        str: ``"data: <json>"`` frames, each followed by a blank line.
    """
    try:
        # Load the persisted state (conversation so far) for this thread.
        current_state = await source_chat_graph.aget_state(
            config=RunnableConfig(configurable={"thread_id": session_id})
        )

        # Prepare state for execution
        state_values = current_state.values if current_state else {}
        state_values["messages"] = state_values.get("messages") or []
        state_values["source_id"] = source_id
        state_values["model_override"] = model_override

        # Add user message to state
        user_message = HumanMessage(content=message)
        state_values["messages"].append(user_message)

        # Send user message event
        user_event = {"type": "user_message", "content": message, "timestamp": None}
        yield f"data: {json.dumps(user_event)}\n\n"

        # ``invoke`` is synchronous and blocks for the whole model call; run it
        # in a worker thread so the event loop keeps serving other requests.
        result = await asyncio.to_thread(
            source_chat_graph.invoke,
            input=state_values,  # type: ignore[arg-type]
            config=RunnableConfig(
                configurable={"thread_id": session_id, "model_id": model_override}
            ),
        )

        if "messages" in result:
            result_messages = result["messages"]
            # The state passed in already carried earlier turns, so only the
            # AI messages after the latest human message belong to this turn.
            # Without this, every previous AI reply would be re-sent each time.
            turn_start = 0
            for index, msg in enumerate(result_messages):
                if getattr(msg, "type", None) == "human":
                    turn_start = index + 1
            for msg in result_messages[turn_start:]:
                if getattr(msg, "type", None) == "ai":
                    ai_event = {
                        "type": "ai_message",
                        "content": msg.content if hasattr(msg, "content") else str(msg),
                        "timestamp": None,
                    }
                    yield f"data: {json.dumps(ai_event)}\n\n"

        # Stream context indicators
        if "context_indicators" in result:
            context_event = {
                "type": "context_indicators",
                "data": result["context_indicators"],
            }
            yield f"data: {json.dumps(context_event)}\n\n"

        # Send completion signal
        completion_event = {"type": "complete"}
        yield f"data: {json.dumps(completion_event)}\n\n"

    except Exception as e:
        logger.error(f"Error in source chat streaming: {str(e)}")
        error_event = {"type": "error", "message": str(e)}
        yield f"data: {json.dumps(error_event)}\n\n"
|
|
|
|
|
|
@router.post("/sources/{source_id}/chat/sessions/{session_id}/messages")
async def send_message_to_source_chat(
    request: SendMessageRequest,
    source_id: str = Path(..., description="Source ID"),
    session_id: str = Path(..., description="Session ID"),
):
    """Post a user message to a source chat session and stream the reply.

    The source, the session, and the ``refers_to`` link between them are all
    checked first. The session is then saved, and the reply is returned as a
    streaming response whose frames come from ``stream_source_chat_response``.

    Args:
        request: Message text plus an optional per-message model override. The
            override takes precedence over the session's own override.
        source_id: Source ID, with or without the ``source:`` table prefix.
        session_id: Session ID, with or without the ``chat_session:`` prefix.

    Raises:
        HTTPException: 404 for a missing or unlinked source/session, 400 for an
            empty message, 500 for any other failure before streaming starts.
    """
    try:
        qualified_source = (
            f"source:{source_id}" if not source_id.startswith("source:") else source_id
        )
        if not await Source.get(qualified_source):
            raise HTTPException(status_code=404, detail="Source not found")

        qualified_session = (
            f"chat_session:{session_id}"
            if not session_id.startswith("chat_session:")
            else session_id
        )
        chat_session = await ChatSession.get(qualified_session)
        if not chat_session:
            raise HTTPException(status_code=404, detail="Session not found")

        link_rows = await repo_query(
            "SELECT * FROM refers_to WHERE in = $session_id AND out = $source_id",
            {
                "session_id": ensure_record_id(qualified_session),
                "source_id": ensure_record_id(qualified_source),
            },
        )
        if not link_rows:
            raise HTTPException(
                status_code=404, detail="Session not found for this source"
            )

        if not request.message:
            raise HTTPException(status_code=400, detail="Message content is required")

        effective_model = request.model_override or getattr(
            chat_session, "model_override", None
        )

        # Re-saving the session refreshes its update timestamp.
        await chat_session.save()

        event_stream = stream_source_chat_response(
            session_id=qualified_session,
            source_id=qualified_source,
            message=request.message,
            model_override=effective_model,
        )
        stream_headers = {
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Content-Type": "text/plain; charset=utf-8",
        }
        return StreamingResponse(
            event_stream, media_type="text/plain", headers=stream_headers
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error sending message to source chat: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error sending message: {str(e)}")
|