mirror of
https://github.com/lfnovo/open-notebook.git
synced 2026-05-02 21:30:38 +00:00
Replace generic "An unexpected error occurred" messages with descriptive, user-friendly error messages when LLM operations fail. Errors like invalid API keys, wrong model names, and rate limits now surface clearly in the UI. Adds error classification utility, global FastAPI exception handlers, and frontend getApiErrorMessage() helper. Bumps version to 1.7.2.
217 lines
8.4 KiB
Python
217 lines
8.4 KiB
Python
import json
|
|
from typing import AsyncGenerator
|
|
|
|
from fastapi import APIRouter, HTTPException
|
|
from fastapi.responses import StreamingResponse
|
|
from loguru import logger
|
|
|
|
from api.models import AskRequest, AskResponse, SearchRequest, SearchResponse
|
|
from open_notebook.ai.models import Model, model_manager
|
|
from open_notebook.domain.notebook import text_search, vector_search
|
|
from open_notebook.exceptions import DatabaseOperationError, InvalidInputError
|
|
from open_notebook.graphs.ask import graph as ask_graph
|
|
|
|
# Router that collects the search and ask endpoints declared in this module.
router = APIRouter()
|
|
|
|
|
|
@router.post("/search", response_model=SearchResponse)
async def search_knowledge_base(search_request: SearchRequest):
    """Search the knowledge base using text or vector search.

    When ``search_request.type`` is ``"vector"`` the request is served by
    ``vector_search`` (which additionally receives ``minimum_score``); any
    other type is served by ``text_search``.

    Args:
        search_request: Query, result limit, source/note toggles and search type.

    Returns:
        A ``SearchResponse`` with the matches (an empty list when the search
        returned nothing), their count, and the search type that was requested.

    Raises:
        HTTPException: 400 when vector search is requested without a configured
            embedding model, or when the search raises ``InvalidInputError``;
            500 when it raises ``DatabaseOperationError`` or any other
            unexpected exception.
    """
    try:
        if search_request.type == "vector":
            # Check if embedding model is available for vector search
            if not await model_manager.get_embedding_model():
                raise HTTPException(
                    status_code=400,
                    detail="Vector search requires an embedding model. Please configure one in the Models section.",
                )

            results = await vector_search(
                keyword=search_request.query,
                results=search_request.limit,
                source=search_request.search_sources,
                note=search_request.search_notes,
                minimum_score=search_request.minimum_score,
            )
        else:
            # Text search
            results = await text_search(
                keyword=search_request.query,
                results=search_request.limit,
                source=search_request.search_sources,
                note=search_request.search_notes,
            )

        return SearchResponse(
            results=results or [],
            total_count=len(results) if results else 0,
            search_type=search_request.type,
        )

    except HTTPException:
        # HTTPException is itself an Exception subclass: without this clause the
        # 400 raised above would fall into the generic handler below and reach
        # the client as a 500 "Search failed: ..." response.
        raise
    except InvalidInputError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except DatabaseOperationError as e:
        logger.error(f"Database error during search: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}") from e
    except Exception as e:
        logger.error(f"Unexpected error during search: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}") from e
|
|
|
|
|
|
async def stream_ask_response(
    question: str, strategy_model: Model, answer_model: Model, final_answer_model: Model
) -> AsyncGenerator[str, None]:
    """Run the ask graph and yield its progress as ``data: <json>`` event frames.

    Each graph update is translated into one or more JSON payloads tagged by
    ``type``: ``strategy`` (reasoning plus planned searches), ``answer`` (one
    frame per intermediate answer), ``final_answer``, and finally ``complete``.
    If anything raises while streaming, a single ``error`` frame carrying the
    classified user-facing message is emitted instead of propagating.

    Args:
        question: The question forwarded to the graph as its input.
        strategy_model: Model whose ``id`` is passed as ``strategy_model``.
        answer_model: Model whose ``id`` is passed as ``answer_model``.
        final_answer_model: Model whose ``id`` is passed as ``final_answer_model``.

    Yields:
        Strings of the form ``data: {...}`` terminated by a blank line.
    """

    def as_event(payload: dict) -> str:
        # One event frame: "data: " + JSON body + blank-line terminator.
        return f"data: {json.dumps(payload)}\n\n"

    try:
        final_answer = None

        async for update in ask_graph.astream(
            input={"question": question},  # type: ignore[arg-type]
            config={
                "configurable": {
                    "strategy_model": strategy_model.id,
                    "answer_model": answer_model.id,
                    "final_answer_model": final_answer_model.id,
                }
            },
            stream_mode="updates",
        ):
            if "agent" in update:
                # Planning step: expose the reasoning and every planned search.
                strategy = update["agent"]["strategy"]
                reasoning = strategy.reasoning
                planned_searches = [
                    {"term": item.term, "instructions": item.instructions}
                    for item in strategy.searches
                ]
                yield as_event(
                    {
                        "type": "strategy",
                        "reasoning": reasoning,
                        "searches": planned_searches,
                    }
                )

            elif "provide_answer" in update:
                # Intermediate answers are forwarded one frame at a time.
                for partial in update["provide_answer"]["answers"]:
                    yield as_event({"type": "answer", "content": partial})

            elif "write_final_answer" in update:
                # Remember the final answer so the completion frame can repeat it.
                final_answer = update["write_final_answer"]["final_answer"]
                yield as_event({"type": "final_answer", "content": final_answer})

        # Send completion signal
        yield as_event({"type": "complete", "final_answer": final_answer})

    except Exception as exc:
        from open_notebook.utils.error_classifier import classify_error

        # Only the user-facing message from the classifier is sent to the client;
        # the raw exception text goes to the log.
        _, user_message = classify_error(exc)
        logger.error(f"Error in ask streaming: {str(exc)}")
        yield as_event({"type": "error", "message": user_message})
|
|
|
|
|
|
@router.post("/search/ask")
async def ask_knowledge_base(ask_request: AskRequest):
    """Answer a question from the knowledge base as a streamed response.

    Fetches the strategy, answer and final-answer models named in the request,
    checks that an embedding model is configured, and then returns a
    ``StreamingResponse`` driven by ``stream_ask_response``.

    Raises:
        HTTPException: 400 when one of the three model lookups gives a falsy
            result or no embedding model is available; 500 for any other
            exception raised while preparing the response.
    """
    try:
        # All three lookups happen before any of them is validated.
        strategy = await Model.get(ask_request.strategy_model)
        answerer = await Model.get(ask_request.answer_model)
        finalizer = await Model.get(ask_request.final_answer_model)

        if not strategy:
            problem = f"Strategy model {ask_request.strategy_model} not found"
            raise HTTPException(status_code=400, detail=problem)
        if not answerer:
            problem = f"Answer model {ask_request.answer_model} not found"
            raise HTTPException(status_code=400, detail=problem)
        if not finalizer:
            problem = f"Final answer model {ask_request.final_answer_model} not found"
            raise HTTPException(status_code=400, detail=problem)

        # The ask flow cannot proceed without an embedding model.
        embedding_model = await model_manager.get_embedding_model()
        if not embedding_model:
            raise HTTPException(
                status_code=400,
                detail="Ask feature requires an embedding model. Please configure one in the Models section.",
            )

        # NOTE(review): the generator emits "data: ..." frames but is served as
        # text/plain; confirm with client code before changing the media type.
        event_stream = stream_ask_response(
            ask_request.question, strategy, answerer, finalizer
        )
        return StreamingResponse(event_stream, media_type="text/plain")

    except HTTPException:
        # Deliberate HTTP errors raised above pass through untouched.
        raise
    except Exception as exc:
        logger.error(f"Error in ask endpoint: {str(exc)}")
        raise HTTPException(status_code=500, detail=f"Ask operation failed: {str(exc)}")
|
|
|
|
|
|
@router.post("/search/ask/simple", response_model=AskResponse)
async def ask_knowledge_base_simple(ask_request: AskRequest):
    """Answer a question from the knowledge base in a single, non-streamed reply.

    Performs the model and embedding-model checks, runs the ask graph to
    completion, and returns only the final answer together with the question.

    Raises:
        HTTPException: 400 when one of the three model lookups gives a falsy
            result or no embedding model is available; 500 when the graph
            finishes without a final answer or any other exception occurs.
    """
    try:
        # All three lookups happen before any of them is validated.
        strategy = await Model.get(ask_request.strategy_model)
        answerer = await Model.get(ask_request.answer_model)
        finalizer = await Model.get(ask_request.final_answer_model)

        if not strategy:
            problem = f"Strategy model {ask_request.strategy_model} not found"
            raise HTTPException(status_code=400, detail=problem)
        if not answerer:
            problem = f"Answer model {ask_request.answer_model} not found"
            raise HTTPException(status_code=400, detail=problem)
        if not finalizer:
            problem = f"Final answer model {ask_request.final_answer_model} not found"
            raise HTTPException(status_code=400, detail=problem)

        # The ask flow cannot proceed without an embedding model.
        embedding_model = await model_manager.get_embedding_model()
        if not embedding_model:
            raise HTTPException(
                status_code=400,
                detail="Ask feature requires an embedding model. Please configure one in the Models section.",
            )

        # Drain the graph; only the "write_final_answer" update matters here,
        # and a later one replaces an earlier one.
        answer_text = None
        async for update in ask_graph.astream(
            input={"question": ask_request.question},  # type: ignore[arg-type]
            config={
                "configurable": {
                    "strategy_model": strategy.id,
                    "answer_model": answerer.id,
                    "final_answer_model": finalizer.id,
                }
            },
            stream_mode="updates",
        ):
            if "write_final_answer" in update:
                answer_text = update["write_final_answer"]["final_answer"]

        if not answer_text:
            raise HTTPException(status_code=500, detail="No answer generated")

        return AskResponse(answer=answer_text, question=ask_request.question)

    except HTTPException:
        # Deliberate HTTP errors raised above pass through untouched.
        raise
    except Exception as exc:
        logger.error(f"Error in ask simple endpoint: {str(exc)}")
        raise HTTPException(status_code=500, detail=f"Ask operation failed: {str(exc)}")
|