mirror of
https://github.com/lfnovo/open-notebook.git
synced 2026-05-05 23:37:58 +00:00
Version 1 (#160)
New front-end Launch Chat API Manage Sources Enable re-embedding of all contents Sources can be added without a notebook now Improved settings Enable model selector on all chats Background processing for better experience Dark mode Improved Notes Improved Docs: - Remove all Streamlit references from documentation - Update deployment guides with React frontend setup - Fix Docker environment variables format (SURREAL_URL, SURREAL_PASSWORD) - Update docker image tag from :latest to :v1-latest - Change navigation references (Settings → Models to just Models) - Update development setup to include frontend npm commands - Add MIGRATION.md guide for users upgrading from Streamlit - Update quick-start guide with correct environment variables - Add port 5055 documentation for API access - Update project structure to reflect frontend/ directory - Remove outdated source-chat documentation files
This commit is contained in:
parent
124d7d110c
commit
b7e656a319
319 changed files with 46747 additions and 7408 deletions
446
api/routers/source_chat.py
Normal file
446
api/routers/source_chat.py
Normal file
|
|
@ -0,0 +1,446 @@
|
|||
import asyncio
|
||||
import json
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Path
|
||||
from fastapi.responses import StreamingResponse
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from open_notebook.database.repository import ensure_record_id, repo_query
|
||||
from open_notebook.domain.notebook import ChatSession, Source
|
||||
from open_notebook.exceptions import (
|
||||
NotFoundError,
|
||||
)
|
||||
from open_notebook.graphs.source_chat import source_chat_graph as source_chat_graph
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Request/Response models
|
||||
class CreateSourceChatSessionRequest(BaseModel):
|
||||
source_id: str = Field(..., description="Source ID to create chat session for")
|
||||
title: Optional[str] = Field(None, description="Optional session title")
|
||||
model_override: Optional[str] = Field(None, description="Optional model override for this session")
|
||||
|
||||
class UpdateSourceChatSessionRequest(BaseModel):
|
||||
title: Optional[str] = Field(None, description="New session title")
|
||||
model_override: Optional[str] = Field(None, description="Model override for this session")
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
id: str = Field(..., description="Message ID")
|
||||
type: str = Field(..., description="Message type (human|ai)")
|
||||
content: str = Field(..., description="Message content")
|
||||
timestamp: Optional[str] = Field(None, description="Message timestamp")
|
||||
|
||||
class ContextIndicator(BaseModel):
|
||||
sources: List[str] = Field(default_factory=list, description="Source IDs used in context")
|
||||
insights: List[str] = Field(default_factory=list, description="Insight IDs used in context")
|
||||
notes: List[str] = Field(default_factory=list, description="Note IDs used in context")
|
||||
|
||||
class SourceChatSessionResponse(BaseModel):
|
||||
id: str = Field(..., description="Session ID")
|
||||
title: str = Field(..., description="Session title")
|
||||
source_id: str = Field(..., description="Source ID")
|
||||
model_override: Optional[str] = Field(None, description="Model override for this session")
|
||||
created: str = Field(..., description="Creation timestamp")
|
||||
updated: str = Field(..., description="Last update timestamp")
|
||||
message_count: Optional[int] = Field(None, description="Number of messages in session")
|
||||
|
||||
class SourceChatSessionWithMessagesResponse(SourceChatSessionResponse):
|
||||
messages: List[ChatMessage] = Field(default_factory=list, description="Session messages")
|
||||
context_indicators: Optional[ContextIndicator] = Field(None, description="Context indicators from last response")
|
||||
|
||||
class SendMessageRequest(BaseModel):
|
||||
message: str = Field(..., description="User message content")
|
||||
model_override: Optional[str] = Field(None, description="Optional model override for this message")
|
||||
|
||||
class SuccessResponse(BaseModel):
|
||||
success: bool = Field(True, description="Operation success status")
|
||||
message: str = Field(..., description="Success message")
|
||||
|
||||
|
||||
@router.post("/sources/{source_id}/chat/sessions", response_model=SourceChatSessionResponse)
|
||||
async def create_source_chat_session(
|
||||
request: CreateSourceChatSessionRequest,
|
||||
source_id: str = Path(..., description="Source ID")
|
||||
):
|
||||
"""Create a new chat session for a source."""
|
||||
try:
|
||||
# Verify source exists
|
||||
full_source_id = source_id if source_id.startswith("source:") else f"source:{source_id}"
|
||||
source = await Source.get(full_source_id)
|
||||
if not source:
|
||||
raise HTTPException(status_code=404, detail="Source not found")
|
||||
|
||||
# Create new session with model_override support
|
||||
session = ChatSession(
|
||||
title=request.title or f"Source Chat {asyncio.get_event_loop().time():.0f}",
|
||||
model_override=request.model_override
|
||||
)
|
||||
await session.save()
|
||||
|
||||
# Relate session to source using "refers_to" relation
|
||||
await session.relate("refers_to", full_source_id)
|
||||
|
||||
return SourceChatSessionResponse(
|
||||
id=session.id or "",
|
||||
title=session.title or "Untitled Session",
|
||||
source_id=source_id,
|
||||
model_override=session.model_override,
|
||||
created=str(session.created),
|
||||
updated=str(session.updated),
|
||||
message_count=0
|
||||
)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Source not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating source chat session: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Error creating source chat session: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/sources/{source_id}/chat/sessions", response_model=List[SourceChatSessionResponse])
|
||||
async def get_source_chat_sessions(
|
||||
source_id: str = Path(..., description="Source ID")
|
||||
):
|
||||
"""Get all chat sessions for a source."""
|
||||
try:
|
||||
# Verify source exists
|
||||
full_source_id = source_id if source_id.startswith("source:") else f"source:{source_id}"
|
||||
source = await Source.get(full_source_id)
|
||||
if not source:
|
||||
raise HTTPException(status_code=404, detail="Source not found")
|
||||
|
||||
# Get sessions that refer to this source - first get relations, then sessions
|
||||
relations = await repo_query(
|
||||
"SELECT in FROM refers_to WHERE out = $source_id",
|
||||
{"source_id": ensure_record_id(full_source_id)}
|
||||
)
|
||||
|
||||
sessions = []
|
||||
for relation in relations:
|
||||
session_id = relation.get("in")
|
||||
if session_id:
|
||||
session_result = await repo_query(f"SELECT * FROM {session_id}")
|
||||
if session_result and len(session_result) > 0:
|
||||
session_data = session_result[0]
|
||||
sessions.append(SourceChatSessionResponse(
|
||||
id=session_data.get("id") or "",
|
||||
title=session_data.get("title") or "Untitled Session",
|
||||
source_id=source_id,
|
||||
model_override=session_data.get("model_override"),
|
||||
created=str(session_data.get("created")),
|
||||
updated=str(session_data.get("updated")),
|
||||
message_count=0 # TODO: Add message count if needed
|
||||
))
|
||||
|
||||
# Sort sessions by created date (newest first)
|
||||
sessions.sort(key=lambda x: x.created, reverse=True)
|
||||
return sessions
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Source not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching source chat sessions: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Error fetching source chat sessions: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/sources/{source_id}/chat/sessions/{session_id}", response_model=SourceChatSessionWithMessagesResponse)
|
||||
async def get_source_chat_session(
|
||||
source_id: str = Path(..., description="Source ID"),
|
||||
session_id: str = Path(..., description="Session ID")
|
||||
):
|
||||
"""Get a specific source chat session with its messages."""
|
||||
try:
|
||||
# Verify source exists
|
||||
full_source_id = source_id if source_id.startswith("source:") else f"source:{source_id}"
|
||||
source = await Source.get(full_source_id)
|
||||
if not source:
|
||||
raise HTTPException(status_code=404, detail="Source not found")
|
||||
|
||||
# Get session
|
||||
full_session_id = session_id if session_id.startswith("chat_session:") else f"chat_session:{session_id}"
|
||||
session = await ChatSession.get(full_session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# Verify session is related to this source
|
||||
relation_query = await repo_query(
|
||||
"SELECT * FROM refers_to WHERE in = $session_id AND out = $source_id",
|
||||
{"session_id": ensure_record_id(full_session_id), "source_id": ensure_record_id(full_source_id)}
|
||||
)
|
||||
|
||||
if not relation_query:
|
||||
raise HTTPException(status_code=404, detail="Session not found for this source")
|
||||
|
||||
# Get session state from LangGraph to retrieve messages
|
||||
thread_state = source_chat_graph.get_state(
|
||||
config=RunnableConfig(configurable={"thread_id": session_id})
|
||||
)
|
||||
|
||||
# Extract messages from state
|
||||
messages: list[ChatMessage] = []
|
||||
context_indicators = None
|
||||
|
||||
if thread_state and thread_state.values:
|
||||
# Extract messages
|
||||
if "messages" in thread_state.values:
|
||||
for msg in thread_state.values["messages"]:
|
||||
messages.append(ChatMessage(
|
||||
id=getattr(msg, 'id', f"msg_{len(messages)}"),
|
||||
type=msg.type if hasattr(msg, 'type') else 'unknown',
|
||||
content=msg.content if hasattr(msg, 'content') else str(msg),
|
||||
timestamp=None # LangChain messages don't have timestamps by default
|
||||
))
|
||||
|
||||
# Extract context indicators from the last state
|
||||
if "context_indicators" in thread_state.values:
|
||||
context_data = thread_state.values["context_indicators"]
|
||||
context_indicators = ContextIndicator(
|
||||
sources=context_data.get("sources", []),
|
||||
insights=context_data.get("insights", []),
|
||||
notes=context_data.get("notes", [])
|
||||
)
|
||||
|
||||
return SourceChatSessionWithMessagesResponse(
|
||||
id=session.id or "",
|
||||
title=session.title or "Untitled Session",
|
||||
source_id=source_id,
|
||||
model_override=getattr(session, 'model_override', None),
|
||||
created=str(session.created),
|
||||
updated=str(session.updated),
|
||||
message_count=len(messages),
|
||||
messages=messages,
|
||||
context_indicators=context_indicators
|
||||
)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Source or session not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching source chat session: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Error fetching source chat session: {str(e)}")
|
||||
|
||||
|
||||
@router.put("/sources/{source_id}/chat/sessions/{session_id}", response_model=SourceChatSessionResponse)
|
||||
async def update_source_chat_session(
|
||||
request: UpdateSourceChatSessionRequest,
|
||||
source_id: str = Path(..., description="Source ID"),
|
||||
session_id: str = Path(..., description="Session ID")
|
||||
):
|
||||
"""Update source chat session title and/or model override."""
|
||||
try:
|
||||
# Verify source exists
|
||||
full_source_id = source_id if source_id.startswith("source:") else f"source:{source_id}"
|
||||
source = await Source.get(full_source_id)
|
||||
if not source:
|
||||
raise HTTPException(status_code=404, detail="Source not found")
|
||||
|
||||
# Get session
|
||||
full_session_id = session_id if session_id.startswith("chat_session:") else f"chat_session:{session_id}"
|
||||
session = await ChatSession.get(full_session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# Verify session is related to this source
|
||||
relation_query = await repo_query(
|
||||
"SELECT * FROM refers_to WHERE in = $session_id AND out = $source_id",
|
||||
{"session_id": ensure_record_id(full_session_id), "source_id": ensure_record_id(full_source_id)}
|
||||
)
|
||||
|
||||
if not relation_query:
|
||||
raise HTTPException(status_code=404, detail="Session not found for this source")
|
||||
|
||||
# Update session fields
|
||||
if request.title is not None:
|
||||
session.title = request.title
|
||||
if request.model_override is not None:
|
||||
session.model_override = request.model_override
|
||||
|
||||
await session.save()
|
||||
|
||||
return SourceChatSessionResponse(
|
||||
id=session.id or "",
|
||||
title=session.title or "Untitled Session",
|
||||
source_id=source_id,
|
||||
model_override=getattr(session, 'model_override', None),
|
||||
created=str(session.created),
|
||||
updated=str(session.updated),
|
||||
message_count=0
|
||||
)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Source or session not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating source chat session: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Error updating source chat session: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/sources/{source_id}/chat/sessions/{session_id}", response_model=SuccessResponse)
|
||||
async def delete_source_chat_session(
|
||||
source_id: str = Path(..., description="Source ID"),
|
||||
session_id: str = Path(..., description="Session ID")
|
||||
):
|
||||
"""Delete a source chat session."""
|
||||
try:
|
||||
# Verify source exists
|
||||
full_source_id = source_id if source_id.startswith("source:") else f"source:{source_id}"
|
||||
source = await Source.get(full_source_id)
|
||||
if not source:
|
||||
raise HTTPException(status_code=404, detail="Source not found")
|
||||
|
||||
# Get session
|
||||
full_session_id = session_id if session_id.startswith("chat_session:") else f"chat_session:{session_id}"
|
||||
session = await ChatSession.get(full_session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# Verify session is related to this source
|
||||
relation_query = await repo_query(
|
||||
"SELECT * FROM refers_to WHERE in = $session_id AND out = $source_id",
|
||||
{"session_id": ensure_record_id(full_session_id), "source_id": ensure_record_id(full_source_id)}
|
||||
)
|
||||
|
||||
if not relation_query:
|
||||
raise HTTPException(status_code=404, detail="Session not found for this source")
|
||||
|
||||
await session.delete()
|
||||
|
||||
return SuccessResponse(
|
||||
success=True,
|
||||
message="Source chat session deleted successfully"
|
||||
)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Source or session not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting source chat session: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Error deleting source chat session: {str(e)}")
|
||||
|
||||
|
||||
async def stream_source_chat_response(
|
||||
session_id: str,
|
||||
source_id: str,
|
||||
message: str,
|
||||
model_override: Optional[str] = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream the source chat response as Server-Sent Events."""
|
||||
try:
|
||||
# Get current state
|
||||
current_state = source_chat_graph.get_state(
|
||||
config=RunnableConfig(configurable={"thread_id": session_id})
|
||||
)
|
||||
|
||||
# Prepare state for execution
|
||||
state_values = current_state.values if current_state else {}
|
||||
state_values["messages"] = state_values.get("messages", [])
|
||||
state_values["source_id"] = source_id
|
||||
state_values["model_override"] = model_override
|
||||
|
||||
# Add user message to state
|
||||
user_message = HumanMessage(content=message)
|
||||
state_values["messages"].append(user_message)
|
||||
|
||||
# Send user message event
|
||||
user_event = {
|
||||
"type": "user_message",
|
||||
"content": message,
|
||||
"timestamp": None
|
||||
}
|
||||
yield f"data: {json.dumps(user_event)}\n\n"
|
||||
|
||||
# Execute source chat graph synchronously (like notebook chat does)
|
||||
result = source_chat_graph.invoke(
|
||||
input=state_values, # type: ignore[arg-type]
|
||||
config=RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": session_id,
|
||||
"model_id": model_override
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Stream the complete AI response
|
||||
if "messages" in result:
|
||||
for msg in result["messages"]:
|
||||
if hasattr(msg, 'type') and msg.type == 'ai':
|
||||
ai_event = {
|
||||
"type": "ai_message",
|
||||
"content": msg.content if hasattr(msg, 'content') else str(msg),
|
||||
"timestamp": None
|
||||
}
|
||||
yield f"data: {json.dumps(ai_event)}\n\n"
|
||||
|
||||
# Stream context indicators
|
||||
if "context_indicators" in result:
|
||||
context_event = {
|
||||
"type": "context_indicators",
|
||||
"data": result["context_indicators"]
|
||||
}
|
||||
yield f"data: {json.dumps(context_event)}\n\n"
|
||||
|
||||
# Send completion signal
|
||||
completion_event = {"type": "complete"}
|
||||
yield f"data: {json.dumps(completion_event)}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in source chat streaming: {str(e)}")
|
||||
error_event = {"type": "error", "message": str(e)}
|
||||
yield f"data: {json.dumps(error_event)}\n\n"
|
||||
|
||||
|
||||
@router.post("/sources/{source_id}/chat/sessions/{session_id}/messages")
|
||||
async def send_message_to_source_chat(
|
||||
request: SendMessageRequest,
|
||||
source_id: str = Path(..., description="Source ID"),
|
||||
session_id: str = Path(..., description="Session ID")
|
||||
):
|
||||
"""Send a message to source chat session with SSE streaming response."""
|
||||
try:
|
||||
# Verify source exists
|
||||
full_source_id = source_id if source_id.startswith("source:") else f"source:{source_id}"
|
||||
source = await Source.get(full_source_id)
|
||||
if not source:
|
||||
raise HTTPException(status_code=404, detail="Source not found")
|
||||
|
||||
# Verify session exists and is related to source
|
||||
full_session_id = session_id if session_id.startswith("chat_session:") else f"chat_session:{session_id}"
|
||||
session = await ChatSession.get(full_session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# Verify session is related to this source
|
||||
relation_query = await repo_query(
|
||||
"SELECT * FROM refers_to WHERE in = $session_id AND out = $source_id",
|
||||
{"session_id": ensure_record_id(full_session_id), "source_id": ensure_record_id(full_source_id)}
|
||||
)
|
||||
|
||||
if not relation_query:
|
||||
raise HTTPException(status_code=404, detail="Session not found for this source")
|
||||
|
||||
if not request.message:
|
||||
raise HTTPException(status_code=400, detail="Message content is required")
|
||||
|
||||
# Determine model override (request override takes precedence over session override)
|
||||
model_override = request.model_override or getattr(session, 'model_override', None)
|
||||
|
||||
# Update session timestamp
|
||||
await session.save()
|
||||
|
||||
# Return streaming response
|
||||
return StreamingResponse(
|
||||
stream_source_chat_response(
|
||||
session_id=session_id,
|
||||
source_id=full_source_id,
|
||||
message=request.message,
|
||||
model_override=model_override
|
||||
),
|
||||
media_type="text/plain",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Content-Type": "text/plain; charset=utf-8"
|
||||
}
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending message to source chat: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Error sending message: {str(e)}")
|
||||
Loading…
Add table
Add a link
Reference in a new issue