open-notebook/api/routers/search.py
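
"""Search and ask endpoints for the Open Notebook API.

Routes defined in this module:
- POST /search: text or vector search over sources and notes
- POST /search/ask: question answering, streamed as SSE-formatted events
- POST /search/ask/simple: question answering returned as a single JSON response
"""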

import json
from typing import AsyncGenerator

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from loguru import logger

from api.models import AskRequest, AskResponse, SearchRequest, SearchResponse
from open_notebook.domain.models import Model, model_manager
from open_notebook.domain.notebook import text_search, vector_search
from open_notebook.exceptions import DatabaseOperationError, InvalidInputError
from open_notebook.graphs.ask import graph as ask_graph

router = APIRouter()
@router.post("/search", response_model=SearchResponse)
async def search_knowledge_base(search_request: SearchRequest):
"""Search the knowledge base using text or vector search."""
try:
if search_request.type == "vector":
# Check if embedding model is available for vector search
if not await model_manager.get_embedding_model():
raise HTTPException(
status_code=400,
detail="Vector search requires an embedding model. Please configure one in the Models section.",
)
results = await vector_search(
keyword=search_request.query,
results=search_request.limit,
source=search_request.search_sources,
note=search_request.search_notes,
minimum_score=search_request.minimum_score,
)
else:
# Text search
results = await text_search(
keyword=search_request.query,
results=search_request.limit,
source=search_request.search_sources,
note=search_request.search_notes,
)
return SearchResponse(
results=results or [],
total_count=len(results) if results else 0,
search_type=search_request.type,
)
except InvalidInputError as e:
raise HTTPException(status_code=400, detail=str(e))
except DatabaseOperationError as e:
logger.error(f"Database error during search: {str(e)}")
raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")
except Exception as e:
logger.error(f"Unexpected error during search: {str(e)}")
raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")
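

# Example request body for POST /search. Values are illustrative; the exact
# schema lives in api.models.SearchRequest, and the fields shown are the ones
# consumed by the handler above:
#
#     {
#         "query": "embedding models",
#         "type": "vector",            # "vector" or "text"
#         "limit": 10,
#         "search_sources": true,
#         "search_notes": true,
#         "minimum_score": 0.2         # only passed to vector search
#     }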


async def stream_ask_response(
    question: str, strategy_model: Model, answer_model: Model, final_answer_model: Model
) -> AsyncGenerator[str, None]:
    """Stream the ask response as Server-Sent Events."""
    try:
        final_answer = None
        async for chunk in ask_graph.astream(
            input=dict(question=question),  # type: ignore[arg-type]
            config=dict(
                configurable=dict(
                    strategy_model=strategy_model.id,
                    answer_model=answer_model.id,
                    final_answer_model=final_answer_model.id,
                )
            ),
            stream_mode="updates",
        ):
            if "agent" in chunk:
                strategy_data = {
                    "type": "strategy",
                    "reasoning": chunk["agent"]["strategy"].reasoning,
                    "searches": [
                        {"term": search.term, "instructions": search.instructions}
                        for search in chunk["agent"]["strategy"].searches
                    ],
                }
                yield f"data: {json.dumps(strategy_data)}\n\n"
            elif "provide_answer" in chunk:
                for answer in chunk["provide_answer"]["answers"]:
                    answer_data = {"type": "answer", "content": answer}
                    yield f"data: {json.dumps(answer_data)}\n\n"
            elif "write_final_answer" in chunk:
                final_answer = chunk["write_final_answer"]["final_answer"]
                final_data = {"type": "final_answer", "content": final_answer}
                yield f"data: {json.dumps(final_data)}\n\n"

        # Send completion signal
        completion_data = {"type": "complete", "final_answer": final_answer}
        yield f"data: {json.dumps(completion_data)}\n\n"
    except Exception as e:
        logger.error(f"Error in ask streaming: {str(e)}")
        error_data = {"type": "error", "message": str(e)}
        yield f"data: {json.dumps(error_data)}\n\n"
@router.post("/search/ask")
async def ask_knowledge_base(ask_request: AskRequest):
"""Ask the knowledge base a question using AI models."""
try:
# Validate models exist
strategy_model = await Model.get(ask_request.strategy_model)
answer_model = await Model.get(ask_request.answer_model)
final_answer_model = await Model.get(ask_request.final_answer_model)
if not strategy_model:
raise HTTPException(
status_code=400,
detail=f"Strategy model {ask_request.strategy_model} not found",
)
if not answer_model:
raise HTTPException(
status_code=400,
detail=f"Answer model {ask_request.answer_model} not found",
)
if not final_answer_model:
raise HTTPException(
status_code=400,
detail=f"Final answer model {ask_request.final_answer_model} not found",
)
# Check if embedding model is available
if not await model_manager.get_embedding_model():
raise HTTPException(
status_code=400,
detail="Ask feature requires an embedding model. Please configure one in the Models section.",
)
# For streaming response
return StreamingResponse(
stream_ask_response(
ask_request.question, strategy_model, answer_model, final_answer_model
),
media_type="text/plain",
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error in ask endpoint: {str(e)}")
raise HTTPException(status_code=500, detail=f"Ask operation failed: {str(e)}")
@router.post("/search/ask/simple", response_model=AskResponse)
async def ask_knowledge_base_simple(ask_request: AskRequest):
"""Ask the knowledge base a question and return a simple response (non-streaming)."""
try:
# Validate models exist
strategy_model = await Model.get(ask_request.strategy_model)
answer_model = await Model.get(ask_request.answer_model)
final_answer_model = await Model.get(ask_request.final_answer_model)
if not strategy_model:
raise HTTPException(
status_code=400,
detail=f"Strategy model {ask_request.strategy_model} not found",
)
if not answer_model:
raise HTTPException(
status_code=400,
detail=f"Answer model {ask_request.answer_model} not found",
)
if not final_answer_model:
raise HTTPException(
status_code=400,
detail=f"Final answer model {ask_request.final_answer_model} not found",
)
# Check if embedding model is available
if not await model_manager.get_embedding_model():
raise HTTPException(
status_code=400,
detail="Ask feature requires an embedding model. Please configure one in the Models section.",
)
# Run the ask graph and get final result
final_answer = None
async for chunk in ask_graph.astream(
input=dict(question=ask_request.question), # type: ignore[arg-type]
config=dict(
configurable=dict(
strategy_model=strategy_model.id,
answer_model=answer_model.id,
final_answer_model=final_answer_model.id,
)
),
stream_mode="updates",
):
if "write_final_answer" in chunk:
final_answer = chunk["write_final_answer"]["final_answer"]
if not final_answer:
raise HTTPException(status_code=500, detail="No answer generated")
return AskResponse(answer=final_answer, question=ask_request.question)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error in ask simple endpoint: {str(e)}")
raise HTTPException(status_code=500, detail=f"Ask operation failed: {str(e)}")
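

# --- Illustrative client sketch (not used by the app) ---
# A minimal consumer for POST /search/ask, assuming the API listens on
# localhost:5055 (the port the project docs use), the router is mounted with
# no path prefix, and httpx is installed. Model IDs are placeholders.
async def _example_ask_stream_client() -> None:
    import httpx  # assumption: available in the calling environment

    payload = {
        "question": "What do my sources say about embeddings?",
        # Hypothetical model IDs; use real Model records in practice.
        "strategy_model": "model:placeholder",
        "answer_model": "model:placeholder",
        "final_answer_model": "model:placeholder",
    }
    async with httpx.AsyncClient(base_url="http://localhost:5055", timeout=None) as client:
        async with client.stream("POST", "/search/ask", json=payload) as response:
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    event = json.loads(line[len("data: "):])
                    print(event["type"], event)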