mirror of https://github.com/MODSetter/SurfSense.git (synced 2025-09-01 10:09:08 +00:00)
feat: Removed GPT-Researcher in favour of own SurfSense LangGraph Agent
This commit is contained in:
parent 94c94e6898
commit 130f43a0fa
14 changed files with 439 additions and 918 deletions
README.md
@@ -120,7 +120,6 @@ This is the core of SurfSense. Before we begin let's look at `.env` variables' t
| RERANKERS_MODEL_NAME | Name of the reranker model for search result reranking. Eg. `ms-marco-MiniLM-L-12-v2` |
| RERANKERS_MODEL_TYPE | Type of reranker model being used. Eg. `flashrank` |
| FAST_LLM | LiteLLM-routed smaller, faster LLM for quick responses. Eg. `litellm:openai/gpt-4o` |
| SMART_LLM | LiteLLM-routed balanced LLM for general use. Eg. `litellm:openai/gpt-4o` |
| STRATEGIC_LLM | LiteLLM-routed advanced LLM for complex reasoning tasks. Eg. `litellm:openai/gpt-4o` |
| LONG_CONTEXT_LLM | LiteLLM-routed LLM capable of handling longer context windows. Eg. `litellm:gemini/gemini-2.0-flash` |
| UNSTRUCTURED_API_KEY | API key for the Unstructured.io document parsing service |
@@ -221,15 +220,15 @@ After filling in your SurfSense API key you should be able to use extension now.
- **Alembic**: A database migrations tool for SQLAlchemy.
- **FastAPI Users**: Authentication and user management with JWT and OAuth support
- **LangChain**: Framework for developing AI-powered applications
- **GPT Integration**: Integration with LLM models through LiteLLM
- **LangGraph**: Framework for developing AI agents.
- **LangChain**: Framework for developing AI-powered applications.
- **LLM Integration**: Integration with LLM models through LiteLLM
- **Rerankers**: Advanced result ranking for improved search relevance
- **GPT-Researcher**: Advanced research capabilities
- **Hybrid Search**: Combines vector similarity and full-text search for optimal results using Reciprocal Rank Fusion (RRF)
- **Vector Embeddings**: Document and text embeddings for semantic search
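The Reciprocal Rank Fusion named in the **Hybrid Search** bullet can be summarized with a short sketch. This is an illustration of the general RRF formula only, not SurfSense's actual implementation (which lives in the hybrid-search layer this commit does not touch); the constant `k=60` is the value commonly used in the RRF literature and is an assumption here.

```python
def reciprocal_rank_fusion(result_lists, k=60):
    """Fuse ranked ID lists (e.g. vector hits + full-text hits) into one ranking."""
    scores = {}
    for results in result_lists:
        for rank, doc_id in enumerate(results, start=1):
            # Each list contributes 1/(k + rank); documents ranked well
            # in several lists accumulate the highest fused score.
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)
```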
@@ -10,9 +10,8 @@ RERANKERS_MODEL_NAME="ms-marco-MiniLM-L-12-v2"
RERANKERS_MODEL_TYPE="flashrank"

FAST_LLM="litellm:openai/gpt-4o-mini"
SMART_LLM="litellm:openai/gpt-4o-mini"
STRATEGIC_LLM="litellm:openai/gpt-4o-mini"
LONG_CONTEXT_LLM="litellm:gemini/gemini-2.0-flash-thinking-exp-01-21"
STRATEGIC_LLM="litellm:openai/gpt-4o"
LONG_CONTEXT_LLM="litellm:gemini/gemini-2.0-flash"

OPENAI_API_KEY="sk-proj-iA"
GEMINI_API_KEY="AIzaSyB6-1641124124124124124124124124124"
@@ -1,17 +1,23 @@
from .configuration import Configuration
from langchain_core.runnables import RunnableConfig
from .state import State
from typing import Any, Dict, List
from app.config import config as app_config
from .prompts import get_answer_outline_system_prompt
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field
import json
import asyncio
from .sub_section_writer.graph import graph as sub_section_writer_graph

import json
from typing import Any, Dict, List

from app.config import config as app_config
from app.db import async_session_maker
from app.utils.connector_service import ConnectorService
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession

from .configuration import Configuration
from .prompts import get_answer_outline_system_prompt
from .state import State
from .sub_section_writer.graph import graph as sub_section_writer_graph

from langgraph.types import StreamWriter


class Section(BaseModel):
    """A section in the answer outline."""
    section_id: int = Field(..., description="The zero-based index of the section")
@@ -22,7 +28,7 @@ class AnswerOutline(BaseModel):
    """The complete answer outline with all sections."""
    answer_outline: List[Section] = Field(..., description="List of sections in the answer outline")

async def write_answer_outline(state: State, config: RunnableConfig) -> Dict[str, Any]:
async def write_answer_outline(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]:
    """
    Create a structured answer outline based on the user query.
@@ -33,12 +39,18 @@ async def write_answer_outline(state: State, config: RunnableConfig) -> Dict[str
    Returns:
        Dict containing the answer outline in the "answer_outline" key for state update.
    """
    streaming_service = state.streaming_service

    streaming_service.only_update_terminal("Generating answer outline...")
    writer({"yeild_value": streaming_service._format_annotations()})

    # Get configuration from runnable config
    configuration = Configuration.from_runnable_config(config)
    user_query = configuration.user_query
    num_sections = configuration.num_sections

    streaming_service.only_update_terminal(f"Planning research approach for query: {user_query[:100]}...")
    writer({"yeild_value": streaming_service._format_annotations()})

    # Initialize LLM
    llm = app_config.strategic_llm_instance
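The `writer(...)` calls threaded through this file use LangGraph's custom streaming channel. A minimal sketch of that mechanism, for orientation (node and payload names here mirror this diff; the `"yeild_value"` key reproduces the typo the code deliberately matches elsewhere):

```python
from langgraph.types import StreamWriter

def some_node(state: dict, writer: StreamWriter) -> dict:
    # LangGraph injects `writer` based on the StreamWriter annotation;
    # anything passed to it surfaces on graph.astream(..., stream_mode="custom").
    writer({"yeild_value": "progress update"})
    return {}
```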
@@ -66,6 +78,9 @@ async def write_answer_outline(state: State, config: RunnableConfig) -> Dict[str
    Your output MUST be valid JSON in exactly this format. Do not include any other text or explanation.
    """

    streaming_service.only_update_terminal("Designing structured outline with AI...")
    writer({"yeild_value": streaming_service._format_annotations()})

    # Create messages for the LLM
    messages = [
        SystemMessage(content=get_answer_outline_system_prompt()),
@@ -73,6 +88,9 @@ async def write_answer_outline(state: State, config: RunnableConfig) -> Dict[str
    ]

    # Call the LLM directly without using structured output
    streaming_service.only_update_terminal("Processing answer structure...")
    writer({"yeild_value": streaming_service._format_annotations()})

    response = await llm.ainvoke(messages)

    # Parse the JSON response manually
@@ -92,16 +110,27 @@ async def write_answer_outline(state: State, config: RunnableConfig) -> Dict[str
            # Convert to Pydantic model
            answer_outline = AnswerOutline(**parsed_data)

            total_questions = sum(len(section.questions) for section in answer_outline.answer_outline)
            streaming_service.only_update_terminal(f"Successfully generated outline with {len(answer_outline.answer_outline)} sections and {total_questions} research questions")
            writer({"yeild_value": streaming_service._format_annotations()})

            print(f"Successfully generated answer outline with {len(answer_outline.answer_outline)} sections")

            # Return state update
            return {"answer_outline": answer_outline}
        else:
            # If JSON structure not found, raise a clear error
            raise ValueError(f"Could not find valid JSON in LLM response. Raw response: {content}")
            error_message = f"Could not find valid JSON in LLM response. Raw response: {content}"
            streaming_service.only_update_terminal(error_message, "error")
            writer({"yeild_value": streaming_service._format_annotations()})
            raise ValueError(error_message)

    except (json.JSONDecodeError, ValueError) as e:
        # Log the error and re-raise it
        error_message = f"Error parsing LLM response: {str(e)}"
        streaming_service.only_update_terminal(error_message, "error")
        writer({"yeild_value": streaming_service._format_annotations()})

        print(f"Error parsing LLM response: {str(e)}")
        print(f"Raw response: {response.content}")
        raise
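The parsing branch above refers to extraction code the diff truncates. A hypothetical reconstruction of the "find the JSON in the raw LLM text" step, for orientation only; the real helper in the repository may differ:

```python
import json
import re

def extract_outline_json(content: str) -> dict:
    match = re.search(r"\{.*\}", content, re.DOTALL)  # grab the outermost brace pair
    if not match:
        raise ValueError(f"Could not find valid JSON in LLM response. Raw response: {content}")
    return json.loads(match.group(0))
```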
@@ -112,18 +141,21 @@ async def fetch_relevant_documents(
    search_space_id: int,
    db_session: AsyncSession,
    connectors_to_search: List[str],
    top_k: int = 5
    writer: StreamWriter = None,
    state: State = None,
    top_k: int = 20
) -> List[Dict[str, Any]]:
    """
    Fetch relevant documents for research questions using the provided connectors.

    Args:
        section_title: The title of the section being researched
        research_questions: List of research questions to find documents for
        user_id: The user ID
        search_space_id: The search space ID
        db_session: The database session
        connectors_to_search: List of connectors to search
        writer: StreamWriter for sending progress updates
        state: The current state containing the streaming service
        top_k: Number of top results to retrieve per connector per question

    Returns:
@@ -131,83 +163,237 @@ async def fetch_relevant_documents(
    """
    # Initialize services
    connector_service = ConnectorService(db_session)

    all_raw_documents = []  # Store all raw documents before reranking

    for user_query in research_questions:
    # Only use streaming if both writer and state are provided
    streaming_service = state.streaming_service if state is not None else None

    # Stream initial status update
    if streaming_service and writer:
        streaming_service.only_update_terminal(f"Starting research on {len(research_questions)} questions using {len(connectors_to_search)} connectors...")
        writer({"yeild_value": streaming_service._format_annotations()})

    all_raw_documents = []  # Store all raw documents
    all_sources = []  # Store all sources

    for i, user_query in enumerate(research_questions):
        # Stream question being researched
        if streaming_service and writer:
            streaming_service.only_update_terminal(f"Researching question {i+1}/{len(research_questions)}: {user_query[:100]}...")
            writer({"yeild_value": streaming_service._format_annotations()})

        # Use original research question as the query
        reformulated_query = user_query

        # Process each selected connector
        for connector in connectors_to_search:
            # Stream connector being searched
            if streaming_service and writer:
                streaming_service.only_update_terminal(f"Searching {connector} for relevant information...")
                writer({"yeild_value": streaming_service._format_annotations()})

            try:
                if connector == "YOUTUBE_VIDEO":
                    _, youtube_chunks = await connector_service.search_youtube(
                    source_object, youtube_chunks = await connector_service.search_youtube(
                        user_query=reformulated_query,
                        user_id=user_id,
                        search_space_id=search_space_id,
                        top_k=top_k
                    )

                    # Add to sources and raw documents
                    if source_object:
                        all_sources.append(source_object)
                    all_raw_documents.extend(youtube_chunks)

                    # Stream found document count
                    if streaming_service and writer:
                        streaming_service.only_update_terminal(f"Found {len(youtube_chunks)} YouTube chunks relevant to the query")
                        writer({"yeild_value": streaming_service._format_annotations()})

                elif connector == "EXTENSION":
                    _, extension_chunks = await connector_service.search_extension(
                    source_object, extension_chunks = await connector_service.search_extension(
                        user_query=reformulated_query,
                        user_id=user_id,
                        search_space_id=search_space_id,
                        top_k=top_k
                    )

                    # Add to sources and raw documents
                    if source_object:
                        all_sources.append(source_object)
                    all_raw_documents.extend(extension_chunks)

                    # Stream found document count
                    if streaming_service and writer:
                        streaming_service.only_update_terminal(f"Found {len(extension_chunks)} extension chunks relevant to the query")
                        writer({"yeild_value": streaming_service._format_annotations()})

                elif connector == "CRAWLED_URL":
                    _, crawled_urls_chunks = await connector_service.search_crawled_urls(
                    source_object, crawled_urls_chunks = await connector_service.search_crawled_urls(
                        user_query=reformulated_query,
                        user_id=user_id,
                        search_space_id=search_space_id,
                        top_k=top_k
                    )

                    # Add to sources and raw documents
                    if source_object:
                        all_sources.append(source_object)
                    all_raw_documents.extend(crawled_urls_chunks)

                    # Stream found document count
                    if streaming_service and writer:
                        streaming_service.only_update_terminal(f"Found {len(crawled_urls_chunks)} crawled URL chunks relevant to the query")
                        writer({"yeild_value": streaming_service._format_annotations()})

                elif connector == "FILE":
                    _, files_chunks = await connector_service.search_files(
                    source_object, files_chunks = await connector_service.search_files(
                        user_query=reformulated_query,
                        user_id=user_id,
                        search_space_id=search_space_id,
                        top_k=top_k
                    )

                    # Add to sources and raw documents
                    if source_object:
                        all_sources.append(source_object)
                    all_raw_documents.extend(files_chunks)

                    # Stream found document count
                    if streaming_service and writer:
                        streaming_service.only_update_terminal(f"Found {len(files_chunks)} file chunks relevant to the query")
                        writer({"yeild_value": streaming_service._format_annotations()})

                elif connector == "TAVILY_API":
                    _, tavily_chunks = await connector_service.search_tavily(
                    source_object, tavily_chunks = await connector_service.search_tavily(
                        user_query=reformulated_query,
                        user_id=user_id,
                        top_k=top_k
                    )

                    # Add to sources and raw documents
                    if source_object:
                        all_sources.append(source_object)
                    all_raw_documents.extend(tavily_chunks)

                    # Stream found document count
                    if streaming_service and writer:
                        streaming_service.only_update_terminal(f"Found {len(tavily_chunks)} web search results relevant to the query")
                        writer({"yeild_value": streaming_service._format_annotations()})

                elif connector == "SLACK_CONNECTOR":
                    _, slack_chunks = await connector_service.search_slack(
                    source_object, slack_chunks = await connector_service.search_slack(
                        user_query=reformulated_query,
                        user_id=user_id,
                        search_space_id=search_space_id,
                        top_k=top_k
                    )

                    # Add to sources and raw documents
                    if source_object:
                        all_sources.append(source_object)
                    all_raw_documents.extend(slack_chunks)

                    # Stream found document count
                    if streaming_service and writer:
                        streaming_service.only_update_terminal(f"Found {len(slack_chunks)} Slack messages relevant to the query")
                        writer({"yeild_value": streaming_service._format_annotations()})

                elif connector == "NOTION_CONNECTOR":
                    _, notion_chunks = await connector_service.search_notion(
                    source_object, notion_chunks = await connector_service.search_notion(
                        user_query=reformulated_query,
                        user_id=user_id,
                        search_space_id=search_space_id,
                        top_k=top_k
                    )

                    # Add to sources and raw documents
                    if source_object:
                        all_sources.append(source_object)
                    all_raw_documents.extend(notion_chunks)

                    # Stream found document count
                    if streaming_service and writer:
                        streaming_service.only_update_terminal(f"Found {len(notion_chunks)} Notion pages/blocks relevant to the query")
                        writer({"yeild_value": streaming_service._format_annotations()})

                elif connector == "GITHUB_CONNECTOR":
                    source_object, github_chunks = await connector_service.search_github(
                        user_query=reformulated_query,
                        user_id=user_id,
                        search_space_id=search_space_id,
                        top_k=top_k
                    )

                    # Add to sources and raw documents
                    if source_object:
                        all_sources.append(source_object)
                    all_raw_documents.extend(github_chunks)

                    # Stream found document count
                    if streaming_service and writer:
                        streaming_service.only_update_terminal(f"Found {len(github_chunks)} GitHub files/issues relevant to the query")
                        writer({"yeild_value": streaming_service._format_annotations()})

                elif connector == "LINEAR_CONNECTOR":
                    source_object, linear_chunks = await connector_service.search_linear(
                        user_query=reformulated_query,
                        user_id=user_id,
                        search_space_id=search_space_id,
                        top_k=top_k
                    )

                    # Add to sources and raw documents
                    if source_object:
                        all_sources.append(source_object)
                    all_raw_documents.extend(linear_chunks)

                    # Stream found document count
                    if streaming_service and writer:
                        streaming_service.only_update_terminal(f"Found {len(linear_chunks)} Linear issues relevant to the query")
                        writer({"yeild_value": streaming_service._format_annotations()})
            except Exception as e:
                print(f"Error searching connector {connector}: {str(e)}")
                error_message = f"Error searching connector {connector}: {str(e)}"
                print(error_message)

                # Stream error message
                if streaming_service and writer:
                    streaming_service.only_update_terminal(error_message, "error")
                    writer({"yeild_value": streaming_service._format_annotations()})

                # Continue with other connectors on error
                continue

    # Deduplicate documents based on chunk_id or content
    # Deduplicate source objects by ID before streaming
    deduplicated_sources = []
    seen_source_keys = set()

    for source_obj in all_sources:
        # Use combination of source ID and type as a unique identifier
        # This ensures we don't accidentally deduplicate sources from different connectors
        source_id = source_obj.get('id')
        source_type = source_obj.get('type')

        if source_id and source_type:
            source_key = f"{source_type}_{source_id}"
            if source_key not in seen_source_keys:
                seen_source_keys.add(source_key)
                deduplicated_sources.append(source_obj)
        else:
            # If there's no ID or type, just add it to be safe
            deduplicated_sources.append(source_obj)

    # Stream info about deduplicated sources
    if streaming_service and writer:
        streaming_service.only_update_terminal(f"Collected {len(deduplicated_sources)} unique sources across all connectors")
        writer({"yeild_value": streaming_service._format_annotations()})

    # After all sources are collected and deduplicated, stream them
    if streaming_service and writer:
        streaming_service.only_update_sources(deduplicated_sources)
        writer({"yeild_value": streaming_service._format_annotations()})

    # Deduplicate raw documents based on chunk_id or content
    seen_chunk_ids = set()
    seen_content_hashes = set()
    deduplicated_docs = []
@@ -227,11 +413,15 @@ async def fetch_relevant_documents(
            seen_content_hashes.add(content_hash)
            deduplicated_docs.append(doc)

    # Stream info about deduplicated documents
    if streaming_service and writer:
        streaming_service.only_update_terminal(f"Found {len(deduplicated_docs)} unique document chunks after deduplication")
        writer({"yeild_value": streaming_service._format_annotations()})

    # Return deduplicated documents
    return deduplicated_docs


async def process_sections(state: State, config: RunnableConfig) -> Dict[str, Any]:
async def process_sections(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]:
    """
    Process all sections in parallel and combine the results.
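The hunk above shows only the tail of the chunk-level deduplication; the loop head sits in context lines the diff omits. A hypothetical reconstruction consistent with the visible variables (`seen_chunk_ids`, `seen_content_hashes`, `deduplicated_docs`); the hashing choice is an assumption, not confirmed by this diff:

```python
import hashlib

def dedupe_chunks(all_raw_documents):
    seen_chunk_ids = set()
    seen_content_hashes = set()
    deduplicated_docs = []
    for doc in all_raw_documents:
        chunk_id = doc.get("chunk_id")
        content_hash = hashlib.md5(doc.get("content", "").encode()).hexdigest()
        if chunk_id in seen_chunk_ids or content_hash in seen_content_hashes:
            continue  # duplicate by ID or by identical content
        if chunk_id:
            seen_chunk_ids.add(chunk_id)
        seen_content_hashes.add(content_hash)
        deduplicated_docs.append(doc)
    return deduplicated_docs
```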
@@ -245,89 +435,97 @@ async def process_sections(state: State, config: RunnableConfig) -> Dict[str, An
    # Get configuration and answer outline from state
    configuration = Configuration.from_runnable_config(config)
    answer_outline = state.answer_outline
    streaming_service = state.streaming_service

    streaming_service.only_update_terminal(f"Starting to process research sections...")
    writer({"yeild_value": streaming_service._format_annotations()})

    print(f"Processing sections from outline: {answer_outline is not None}")

    if not answer_outline:
        streaming_service.only_update_terminal("Error: No answer outline was provided. Cannot generate report.", "error")
        writer({"yeild_value": streaming_service._format_annotations()})
        return {
            "final_written_report": "No answer outline was provided. Cannot generate final report."
        }

    # Create session maker from the engine or directly use the session
    from sqlalchemy.ext.asyncio import AsyncSession
    from sqlalchemy.orm import sessionmaker

    # Use the engine if available, otherwise create a new session for each task
    if state.engine:
        session_maker = sessionmaker(
            state.engine, class_=AsyncSession, expire_on_commit=False
        )
    else:
        # Fallback to using the same session (less optimal but will work)
        print("Warning: No engine available. Using same session for all tasks.")
        # Create a mock session maker that returns the same session
        async def mock_session_maker():
            class ContextManager:
                async def __aenter__(self):
                    return state.db_session
                async def __aexit__(self, exc_type, exc_val, exc_tb):
                    pass
            return ContextManager()
        session_maker = mock_session_maker

    # Collect all questions from all sections
    all_questions = []
    for section in answer_outline.answer_outline:
        all_questions.extend(section.questions)

    print(f"Collected {len(all_questions)} questions from all sections")
    streaming_service.only_update_terminal(f"Found {len(all_questions)} research questions across {len(answer_outline.answer_outline)} sections")
    writer({"yeild_value": streaming_service._format_annotations()})

    # Fetch relevant documents once for all questions
    streaming_service.only_update_terminal("Searching for relevant information across all connectors...")
    writer({"yeild_value": streaming_service._format_annotations()})

    relevant_documents = []
    async with session_maker() as db_session:
    async with async_session_maker() as db_session:
        try:
            relevant_documents = await fetch_relevant_documents(
                research_questions=all_questions,
                user_id=configuration.user_id,
                search_space_id=configuration.search_space_id,
                db_session=db_session,
                connectors_to_search=configuration.connectors_to_search
                connectors_to_search=configuration.connectors_to_search,
                writer=writer,
                state=state
            )
        except Exception as e:
            print(f"Error fetching relevant documents: {str(e)}")
            error_message = f"Error fetching relevant documents: {str(e)}"
            print(error_message)
            streaming_service.only_update_terminal(error_message, "error")
            writer({"yeild_value": streaming_service._format_annotations()})
            # Log the error and continue with an empty list of documents
            # This allows the process to continue, but the report might lack information
            relevant_documents = []
            # Consider adding more robust error handling or reporting if needed

    print(f"Fetched {len(relevant_documents)} relevant documents for all sections")
    streaming_service.only_update_terminal(f"Starting to draft {len(answer_outline.answer_outline)} sections using {len(relevant_documents)} relevant document chunks")
    writer({"yeild_value": streaming_service._format_annotations()})

    # Create tasks to process each section in parallel with the same document set
    section_tasks = []
    streaming_service.only_update_terminal("Creating processing tasks for each section...")
    writer({"yeild_value": streaming_service._format_annotations()})

    for section in answer_outline.answer_outline:
        section_tasks.append(
            process_section_with_documents(
                section_title=section.section_title,
                section_questions=section.questions,
                user_query=configuration.user_query,
                user_id=configuration.user_id,
                search_space_id=configuration.search_space_id,
                session_maker=session_maker,
                relevant_documents=relevant_documents
                relevant_documents=relevant_documents,
                state=state,
                writer=writer
            )
        )

    # Run all section processing tasks in parallel
    print(f"Running {len(section_tasks)} section processing tasks in parallel")
    streaming_service.only_update_terminal(f"Processing {len(section_tasks)} sections simultaneously...")
    writer({"yeild_value": streaming_service._format_annotations()})

    section_results = await asyncio.gather(*section_tasks, return_exceptions=True)

    # Handle any exceptions in the results
    streaming_service.only_update_terminal("Combining section results into final report...")
    writer({"yeild_value": streaming_service._format_annotations()})

    processed_results = []
    for i, result in enumerate(section_results):
        if isinstance(result, Exception):
            section_title = answer_outline.answer_outline[i].section_title
            error_message = f"Error processing section '{section_title}': {str(result)}"
            print(error_message)
            streaming_service.only_update_terminal(error_message, "error")
            writer({"yeild_value": streaming_service._format_annotations()})
            processed_results.append(error_message)
        else:
            processed_results.append(result)
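The mock_session_maker fallback above hand-rolls an async context manager. For reference, contextlib offers a more compact equivalent; this is an alternative sketch, not what the commit ships:

```python
from contextlib import asynccontextmanager

@asynccontextmanager
async def reuse_session(session):
    # Hands back the one shared AsyncSession; exiting the block is a no-op,
    # matching the fallback's behavior of never closing the session.
    yield session

# session_maker = lambda: reuse_session(state.db_session)
```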
@@ -337,12 +535,33 @@ async def process_sections(state: State, config: RunnableConfig) -> Dict[str, An
    for i, (section, content) in enumerate(zip(answer_outline.answer_outline, processed_results)):
        # Skip adding the section header since the content already contains the title
        final_report.append(content)
        final_report.append("\n")  # Add spacing between sections
        final_report.append("\n")

    # Join all sections with newlines
    final_written_report = "\n".join(final_report)
    print(f"Generated final report with {len(final_report)} parts")

    streaming_service.only_update_terminal("Final research report generated successfully!")
    writer({"yeild_value": streaming_service._format_annotations()})

    if hasattr(state, 'streaming_service') and state.streaming_service:
        # Convert the final report to the expected format for UI:
        # A list of strings where empty strings represent line breaks
        formatted_report = []
        for section in final_report:
            if section == "\n":
                # Add an empty string for line breaks
                formatted_report.append("")
            else:
                # Split any multiline content by newlines and add each line
                section_lines = section.split("\n")
                formatted_report.extend(section_lines)

        state.streaming_service.only_update_answer(formatted_report)
        writer({"yeild_value": state.streaming_service._format_annotations()})

    return {
        "final_written_report": final_written_report
    }
@@ -352,8 +571,10 @@ async def process_section_with_documents(
    section_questions: List[str],
    user_id: str,
    search_space_id: int,
    session_maker,
    relevant_documents: List[Dict[str, Any]]
    relevant_documents: List[Dict[str, Any]],
    user_query: str,
    state: State = None,
    writer: StreamWriter = None
) -> str:
    """
    Process a single section using pre-fetched documents.
@@ -363,31 +584,42 @@ async def process_section_with_documents(
        section_questions: List of research questions for this section
        user_id: The user ID
        search_space_id: The search space ID
        session_maker: Factory for creating new database sessions
        relevant_documents: Pre-fetched documents to use for this section
        state: The current state
        writer: StreamWriter for sending progress updates

    Returns:
        The written section content
    """
    try:
        # Use the provided documents
        documents_to_use = relevant_documents

        # Send status update via streaming if available
        if state and state.streaming_service and writer:
            state.streaming_service.only_update_terminal(f"Writing section: {section_title} with {len(section_questions)} research questions")
            writer({"yeild_value": state.streaming_service._format_annotations()})

        # Fallback if no documents found
        if not documents_to_use:
            print(f"No relevant documents found for section: {section_title}")
            if state and state.streaming_service and writer:
                state.streaming_service.only_update_terminal(f"Warning: No relevant documents found for section: {section_title}", "warning")
                writer({"yeild_value": state.streaming_service._format_annotations()})

            documents_to_use = [
                {"content": f"No specific information was found for: {question}"}
                for question in section_questions
            ]

        # Create a new database session for this section
        async with session_maker() as db_session:
            # Use the provided documents
            documents_to_use = relevant_documents

            # Fallback if no documents found
            if not documents_to_use:
                print(f"No relevant documents found for section: {section_title}")
                documents_to_use = [
                    {"content": f"No specific information was found for: {question}"}
                    for question in section_questions
                ]

        async with async_session_maker() as db_session:
            # Call the sub_section_writer graph with the appropriate config
            config = {
                "configurable": {
                    "sub_section_title": section_title,
                    "sub_section_questions": section_questions,
                    "user_query": user_query,
                    "relevant_documents": documents_to_use,
                    "user_id": user_id,
                    "search_space_id": search_space_id
@@ -395,16 +627,32 @@ async def process_section_with_documents(
                }
            }

            # Create the initial state with db_session
            state = {"db_session": db_session}
            sub_state = {"db_session": db_session}

            # Invoke the sub-section writer graph
            print(f"Invoking sub_section_writer for: {section_title}")
            result = await sub_section_writer_graph.ainvoke(state, config)
            if state and state.streaming_service and writer:
                state.streaming_service.only_update_terminal(f"Analyzing information and drafting content for section: {section_title}")
                writer({"yeild_value": state.streaming_service._format_annotations()})

            result = await sub_section_writer_graph.ainvoke(sub_state, config)

            # Return the final answer from the sub_section_writer
            final_answer = result.get("final_answer", "No content was generated for this section.")

            # Send section content update via streaming if available
            if state and state.streaming_service and writer:
                state.streaming_service.only_update_terminal(f"Completed writing section: {section_title}")
                writer({"yeild_value": state.streaming_service._format_annotations()})

            return final_answer
    except Exception as e:
        print(f"Error processing section '{section_title}': {str(e)}")

        # Send error update via streaming if available
        if state and state.streaming_service and writer:
            state.streaming_service.only_update_terminal(f"Error processing section '{section_title}': {str(e)}", "error")
            writer({"yeild_value": state.streaming_service._format_annotations()})

        return f"Error processing section: {section_title}. Details: {str(e)}"
@@ -3,10 +3,9 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import List, Optional, Any, Dict, Annotated
from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine
from langchain_core.messages import BaseMessage, HumanMessage
from pydantic import BaseModel
from typing import Optional, Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.utils.streaming_service import StreamingService

@dataclass
class State:
@@ -18,7 +17,9 @@ class State:
    """
    # Runtime context (not part of actual graph state)
    db_session: AsyncSession
    engine: Optional[AsyncEngine] = None

    # Streaming service
    streaming_service: StreamingService

    # Intermediate state - populated during workflow
    # Using field to explicitly mark as part of state
@@ -15,6 +15,7 @@ class Configuration:
    # Input parameters provided at invocation
    sub_section_title: str
    sub_section_questions: List[str]
    user_query: str
    relevant_documents: List[Any]  # Documents provided directly to the agent
    user_id: str
    search_space_id: int
@@ -119,6 +119,7 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
    # Create the query that uses the section title and questions
    section_title = configuration.sub_section_title
    sub_section_questions = configuration.sub_section_questions
    user_query = configuration.user_query  # Get the original user query
    documents_text = "\n".join(formatted_documents)

    # Format the questions as bullet points for clarity
@@ -126,17 +127,16 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A

    # Construct a clear, structured query for the LLM
    human_message_content = f"""
    Please write a comprehensive answer for the title:
    Now user's query is:
    <user_query>
    {user_query}
    </user_query>

    <title>
    The sub-section title is:
    <sub_section_title>
    {section_title}
    </title>
    </sub_section_title>

    Focus on answering these specific questions related to the title:
    <questions>
    {questions_text}
    </questions>

    Use the provided documents as your source material and cite them properly using the IEEE citation format [X] where X is the source_id.
    <documents>
    {documents_text}
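For clarity, the IEEE-style citation the prompt asks for looks like this (illustrative sentence; the source_id 12 is made up here): "SurfSense replaced GPT-Researcher with its own LangGraph agent [12]." The bracketed number must match a source_id wrapped inside the <documents> block.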
@@ -1,126 +0,0 @@
#!/usr/bin/env python3
"""
Test script for the Researcher LangGraph agent.

This script demonstrates how to invoke the researcher agent with a sample query.
Run this script directly from VSCode using the "Run Python File" button or
right-click and select "Run Python File in Terminal".

Before running:
1. Make sure your Python environment has all required dependencies
2. Create a .env file with any required API keys
3. Ensure database connection is properly configured
"""

import asyncio
import os
import sys
from pathlib import Path

# Add project root to Python path so that 'app' can be found as a module
# Get the absolute path to the surfsense_backend directory which contains the app module
project_root = str(Path(__file__).resolve().parents[3])  # Go up 3 levels from the script: app/agents/researcher -> app/agents -> app -> surfsense_backend
print(f"Adding to Python path: {project_root}")
sys.path.insert(0, project_root)

# Now import the modules after fixing the path
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from dotenv import load_dotenv

# These imports should now work with the correct path
from app.agents.researcher.graph import graph
from app.agents.researcher.state import State


# Load environment variables
load_dotenv()

# Database connection string - use a test database or mock
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense"

# Create async engine and session
engine = create_async_engine(DATABASE_URL)
async_session_maker = sessionmaker(
    engine, class_=AsyncSession, expire_on_commit=False
)

async def run_test():
    """Run a test of the researcher agent."""
    print("Starting researcher agent test...")

    # Create a database session
    async with async_session_maker() as db_session:
        # Sample configuration
        config = {
            "configurable": {
                "user_query": "What are the best clash royale decks recommended by Surgical Goblin?",
                "num_sections": 1,
                "connectors_to_search": [
                    "YOUTUBE_VIDEO",
                ],
                "user_id": "d6ac2187-7407-4664-8734-af09926d161e",
                "search_space_id": 2
            }
        }

        try:
            # Initialize state with database session and engine
            initial_state = State(db_session=db_session, engine=engine)

            # Run the graph directly
            print("\nRunning the complete researcher workflow...")
            result = await graph.ainvoke(initial_state, config)

            # Extract the answer outline for display
            if "answer_outline" in result and result["answer_outline"]:
                print(f"\nGenerated answer outline with {len(result['answer_outline'].answer_outline)} sections")

                # Print the outline
                print("\nGenerated Answer Outline:")
                for section in result["answer_outline"].answer_outline:
                    print(f"\nSection {section.section_id}: {section.section_title}")
                    print("Research Questions:")
                    for q in section.questions:
                        print(f"  - {q}")

            # Check if we got a final report
            if "final_written_report" in result and result["final_written_report"]:
                final_report = result["final_written_report"]
                print("\nFinal Research Report generated successfully!")
                print(f"Report length: {len(final_report)} characters")

                # Display the final report
                print("\n==== FINAL RESEARCH REPORT ====\n")
                print(final_report)
            else:
                print("\nNo final report was generated.")
                print(f"Available result keys: {list(result.keys())}")

            return result

        except Exception as e:
            print(f"Error running researcher agent: {str(e)}")
            import traceback
            traceback.print_exc()
            raise

async def main():
    """Main entry point for the test script."""
    try:
        result = await run_test()
        print("\nTest completed successfully.")
        return result
    except Exception as e:
        print(f"\nTest failed with error: {e}")
        import traceback
        traceback.print_exc()
        return None

if __name__ == "__main__":
    # Run the async test
    result = asyncio.run(main())

    # Keep terminal open if run directly in VSCode
    if 'VSCODE_PID' in os.environ:
        input("\nPress Enter to close this window...")
@@ -1,14 +1,15 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.exc import IntegrityError, OperationalError
from typing import List
from app.db import get_async_session, User, SearchSpace, Chat
from app.schemas import ChatCreate, ChatUpdate, ChatRead, AISDKChatRequest

from app.db import Chat, SearchSpace, User, get_async_session
from app.schemas import AISDKChatRequest, ChatCreate, ChatRead, ChatUpdate
from app.tasks.stream_connector_search_results import stream_connector_search_results
from app.users import current_active_user
from app.utils.check_ownership import check_ownership
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

router = APIRouter()
@@ -1,20 +1,15 @@
import json
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, AsyncGenerator, Dict, Any
import asyncio
import re
from typing import AsyncGenerator, List, Union
from uuid import UUID

from app.utils.connector_service import ConnectorService
from app.utils.research_service import ResearchService
from app.agents.researcher.graph import graph as researcher_graph
from app.agents.researcher.state import State
from app.utils.streaming_service import StreamingService
from app.utils.reranker_service import RerankerService
from app.utils.query_service import QueryService
from app.config import config
from app.utils.document_converters import convert_chunks_to_langchain_documents
from sqlalchemy.ext.asyncio import AsyncSession


async def stream_connector_search_results(
    user_query: str,
    user_id: str,
    user_id: Union[str, UUID],
    search_space_id: int,
    session: AsyncSession,
    research_mode: str,
@@ -25,7 +20,7 @@ async def stream_connector_search_results(

    Args:
        user_query: The user's query
        user_id: The user's ID
        user_id: The user's ID (can be UUID object or string)
        search_space_id: The search space ID
        session: The database session
        research_mode: The research mode
@@ -34,418 +29,45 @@ async def stream_connector_search_results(
    Yields:
        str: Formatted response strings
    """
    # Initialize services
    connector_service = ConnectorService(session)
    streaming_service = StreamingService()

    # Reformulate the user query using the strategic LLM
    yield streaming_service.add_terminal_message("Reformulating your query for better results...", "info")
    reformulated_query = await QueryService.reformulate_query(user_query)
    yield streaming_service.add_terminal_message(f"Searching for: {reformulated_query}", "success")

    reranker_service = RerankerService.get_reranker_instance(config)

    all_raw_documents = []  # Store all raw documents before reranking
    all_sources = []
    TOP_K = 20

    if research_mode == "GENERAL":
        TOP_K = 20
        NUM_SECTIONS = 1
    elif research_mode == "DEEP":
        TOP_K = 40
        NUM_SECTIONS = 3
    elif research_mode == "DEEPER":
        TOP_K = 60
        NUM_SECTIONS = 6

    # Process each selected connector
    for connector in selected_connectors:
        if connector == "YOUTUBE_VIDEO":
            # Send terminal message about starting search
            yield streaming_service.add_terminal_message("Starting to search for youtube videos...")

            # Search for YouTube videos using reformulated query
            result_object, youtube_chunks = await connector_service.search_youtube(
                user_query=reformulated_query,
                user_id=user_id,
                search_space_id=search_space_id,
                top_k=TOP_K
            )

            # Send terminal message about search results
            yield streaming_service.add_terminal_message(
                f"Found {len(result_object['sources'])} relevant YouTube videos",
                "success"
            )

            # Update sources
            all_sources.append(result_object)
            yield streaming_service.update_sources(all_sources)

            # Add documents to collection
            all_raw_documents.extend(youtube_chunks)

        # Extension Docs
        if connector == "EXTENSION":
            # Send terminal message about starting search
            yield streaming_service.add_terminal_message("Starting to search for extension...")

            # Search for crawled URLs using reformulated query
            result_object, extension_chunks = await connector_service.search_extension(
                user_query=reformulated_query,
                user_id=user_id,
                search_space_id=search_space_id,
                top_k=TOP_K
            )

            # Send terminal message about search results
            yield streaming_service.add_terminal_message(
                f"Found {len(result_object['sources'])} relevant extension documents",
                "success"
            )

            # Update sources
            all_sources.append(result_object)
            yield streaming_service.update_sources(all_sources)

            # Add documents to collection
            all_raw_documents.extend(extension_chunks)

        # Crawled URLs
        if connector == "CRAWLED_URL":
            # Send terminal message about starting search
            yield streaming_service.add_terminal_message("Starting to search for crawled URLs...")

            # Search for crawled URLs using reformulated query
            result_object, crawled_urls_chunks = await connector_service.search_crawled_urls(
                user_query=reformulated_query,
                user_id=user_id,
                search_space_id=search_space_id,
                top_k=TOP_K
            )

            # Send terminal message about search results
            yield streaming_service.add_terminal_message(
                f"Found {len(result_object['sources'])} relevant crawled URLs",
                "success"
            )

            # Update sources
            all_sources.append(result_object)
            yield streaming_service.update_sources(all_sources)

            # Add documents to collection
            all_raw_documents.extend(crawled_urls_chunks)

        # Files
        if connector == "FILE":
            # Send terminal message about starting search
            yield streaming_service.add_terminal_message("Starting to search for files...")

            # Search for files using reformulated query
            result_object, files_chunks = await connector_service.search_files(
                user_query=reformulated_query,
                user_id=user_id,
                search_space_id=search_space_id,
                top_k=TOP_K
            )

            # Send terminal message about search results
            yield streaming_service.add_terminal_message(
                f"Found {len(result_object['sources'])} relevant files",
                "success"
            )

            # Update sources
            all_sources.append(result_object)
            yield streaming_service.update_sources(all_sources)

            # Add documents to collection
            all_raw_documents.extend(files_chunks)

        # Tavily Connector
        if connector == "TAVILY_API":
            # Send terminal message about starting search
            yield streaming_service.add_terminal_message("Starting to search with Tavily API...")

            # Search using Tavily API with reformulated query
            result_object, tavily_chunks = await connector_service.search_tavily(
                user_query=reformulated_query,
                user_id=user_id,
                top_k=TOP_K
            )

            # Send terminal message about search results
            yield streaming_service.add_terminal_message(
                f"Found {len(result_object['sources'])} relevant results from Tavily",
                "success"
            )

            # Update sources
            all_sources.append(result_object)
            yield streaming_service.update_sources(all_sources)

            # Add documents to collection
            all_raw_documents.extend(tavily_chunks)
        # Slack Connector
        if connector == "SLACK_CONNECTOR":
            # Send terminal message about starting search
            yield streaming_service.add_terminal_message("Starting to search for slack connector...")

            # Search using Slack API with reformulated query
            result_object, slack_chunks = await connector_service.search_slack(
                user_query=reformulated_query,
                user_id=user_id,
                search_space_id=search_space_id,
                top_k=TOP_K
            )

            # Send terminal message about search results
            yield streaming_service.add_terminal_message(
                f"Found {len(result_object['sources'])} relevant results from Slack",
                "success"
            )

            # Update sources
            all_sources.append(result_object)
            yield streaming_service.update_sources(all_sources)

            # Add documents to collection
            all_raw_documents.extend(slack_chunks)

        # Notion Connector
        if connector == "NOTION_CONNECTOR":
            # Send terminal message about starting search
            yield streaming_service.add_terminal_message("Starting to search for notion connector...")

            # Search using Notion API with reformulated query
            result_object, notion_chunks = await connector_service.search_notion(
                user_query=reformulated_query,
                user_id=user_id,
                search_space_id=search_space_id,
                top_k=TOP_K
            )

            # Send terminal message about search results
            yield streaming_service.add_terminal_message(
                f"Found {len(result_object['sources'])} relevant results from Notion",
                "success"
            )

            # Update sources
            all_sources.append(result_object)
            yield streaming_service.update_sources(all_sources)

            # Add documents to collection
            all_raw_documents.extend(notion_chunks)

        # Github Connector
        if connector == "GITHUB_CONNECTOR":
            # Send terminal message about starting search
            yield streaming_service.add_terminal_message("Starting to search for GitHub connector...")
            print("Starting to search for GitHub connector...")
            # Search using Github API with reformulated query
            result_object, github_chunks = await connector_service.search_github(
                user_query=reformulated_query,
                user_id=user_id,
                search_space_id=search_space_id,
                top_k=TOP_K
            )

            # Send terminal message about search results
            yield streaming_service.add_terminal_message(
                f"Found {len(result_object['sources'])} relevant results from Github",
                "success"
            )

            # Update sources
            all_sources.append(result_object)
            yield streaming_service.update_sources(all_sources)

            # Add documents to collection
            all_raw_documents.extend(github_chunks)

        # Linear Connector
        if connector == "LINEAR_CONNECTOR":
            # Send terminal message about starting search
            yield streaming_service.add_terminal_message("Starting to search for Linear issues...")

            # Search using Linear API with reformulated query
            result_object, linear_chunks = await connector_service.search_linear(
                user_query=reformulated_query,
                user_id=user_id,
                search_space_id=search_space_id,
                top_k=TOP_K
            )

            # Send terminal message about search results
            yield streaming_service.add_terminal_message(
                f"Found {len(result_object['sources'])} relevant results from Linear",
                "success"
            )

            # Update sources
            all_sources.append(result_object)
            yield streaming_service.update_sources(all_sources)

            # Add documents to collection
            all_raw_documents.extend(linear_chunks)

    # Convert UUID to string if needed
    user_id_str = str(user_id) if isinstance(user_id, UUID) else user_id
# If we have documents to research
|
||||
if all_raw_documents:
|
||||
# Rerank all documents if reranker is available
|
||||
if reranker_service:
|
||||
yield streaming_service.add_terminal_message("Reranking documents for better relevance...", "info")
|
||||
|
||||
# Convert documents to format expected by reranker
|
||||
reranker_input_docs = [
|
||||
{
|
||||
"chunk_id": doc.get("chunk_id", f"chunk_{i}"),
|
||||
"content": doc.get("content", ""),
|
||||
"score": doc.get("score", 0.0),
|
||||
"document": {
|
||||
"id": doc.get("document", {}).get("id", ""),
|
||||
"title": doc.get("document", {}).get("title", ""),
|
||||
"document_type": doc.get("document", {}).get("document_type", ""),
|
||||
"metadata": doc.get("document", {}).get("metadata", {})
|
||||
}
|
||||
} for i, doc in enumerate(all_raw_documents)
|
||||
]
|
||||
|
||||
# Rerank documents using the reformulated query
|
||||
reranked_docs = reranker_service.rerank_documents(reformulated_query, reranker_input_docs)
|
||||
|
||||
# Sort by score in descending order
|
||||
reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
|
||||
|
||||
|
||||
|
||||
# Convert back to langchain documents format
|
||||
from langchain.schema import Document as LangchainDocument
|
||||
all_langchain_documents_to_research = [
|
||||
LangchainDocument(
|
||||
page_content= f"""<document><metadata><source_id>{doc.get("document", {}).get("id", "")}</source_id></metadata><content>{doc.get("content", "")}</content></document>""",
|
||||
metadata={
|
||||
# **doc.get("document", {}).get("metadata", {}),
|
||||
# "score": doc.get("score", 0.0),
|
||||
# "rank": doc.get("rank", 0),
|
||||
# "document_id": doc.get("document", {}).get("id", ""),
|
||||
# "document_title": doc.get("document", {}).get("title", ""),
|
||||
# "document_type": doc.get("document", {}).get("document_type", ""),
|
||||
# # Explicitly set source_id for citation purposes
|
||||
"source_id": str(doc.get("document", {}).get("id", ""))
|
||||
}
|
||||
) for doc in reranked_docs
|
||||
]
|
||||
|
||||
yield streaming_service.add_terminal_message(f"Reranked {len(all_langchain_documents_to_research)} documents", "success")
|
||||
else:
|
||||
# Use raw documents if no reranker is available
|
||||
all_langchain_documents_to_research = convert_chunks_to_langchain_documents(all_raw_documents)
|
||||
|
||||
# Send terminal message about starting research
|
||||
yield streaming_service.add_terminal_message("Starting to research...", "info")

    # Create a buffer to collect report content
    report_buffer = []

    # Use the streaming research method
    yield streaming_service.add_terminal_message("Generating report...", "info")

    # Create a wrapper to handle the streaming
    class StreamHandler:
        def __init__(self):
            self.queue = asyncio.Queue()

        async def handle_progress(self, data):
            result = None
            if data.get("type") == "logs":
                # Handle log messages
                result = streaming_service.add_terminal_message(data.get("output", ""), "info")
            elif data.get("type") == "report":
                # Handle report content
                content = data.get("output", "")

                # Fix incorrect citation formats using regex.
                # This pattern matches only numeric citations in markdown-style links,
                # e.g. ([1](https://github.com/...)) but not general links like ([Click here](https://...)).
                pattern = r'\(\[(\d+)\]\((https?://[^\)]+)\)\)'

                # Replace with just [X] where X is the number
                content = re.sub(pattern, r'[\1]', content)

                # Also normalize formats like ([1]) to [1],
                # but only when the content inside the brackets is a number.
                content = re.sub(r'\(\[(\d+)\]\)', r'[\1]', content)

                report_buffer.append(content)
                # Update the answer with the accumulated content
                result = streaming_service.update_answer(report_buffer)

            if result:
                await self.queue.put(result)
            return result

        async def get_next(self):
            try:
                return await self.queue.get()
            except Exception as e:
                print(f"Error getting next item from queue: {e}")
                return None

        def task_done(self):
            self.queue.task_done()

    # Create the stream handler
    stream_handler = StreamHandler()

    # Start the research process in a separate task
    research_task = asyncio.create_task(
        ResearchService.stream_research(
            user_query=reformulated_query,
            documents=all_langchain_documents_to_research,
            on_progress=stream_handler.handle_progress,
            research_mode=research_mode
        )
    )

    # Stream results as they become available
    while not research_task.done() or not stream_handler.queue.empty():
        try:
            # Get the next result with a timeout
            result = await asyncio.wait_for(stream_handler.get_next(), timeout=0.1)
            stream_handler.task_done()
            yield result
        except asyncio.TimeoutError:
            # No result available yet; if the research task is done and
            # the queue is empty, we're finished.
            if research_task.done() and stream_handler.queue.empty():
                break

    # Get the final report
    try:
        final_report = await research_task

        # Send terminal message about research completion
        yield streaming_service.add_terminal_message("Research completed", "success")

        # Update the answer with the final report
        final_report_lines = final_report.split('\n')
        yield streaming_service.update_answer(final_report_lines)
    except Exception as e:
        yield streaming_service.add_terminal_message(f"Error during research: {str(e)}", "error")

    # Send completion message
    yield streaming_service.format_completion()
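The loop above bridges a background producer task and an async generator through an asyncio.Queue with timed polling. A minimal self-contained sketch of the same pattern, with all names hypothetical:

import asyncio

async def producer(queue: asyncio.Queue):
    # Simulates progress callbacks arriving over time.
    for i in range(3):
        await asyncio.sleep(0.05)
        await queue.put(f"chunk {i}")
    return "final report"

async def consume():
    queue: asyncio.Queue = asyncio.Queue()
    task = asyncio.create_task(producer(queue))
    while not task.done() or not queue.empty():
        try:
            item = await asyncio.wait_for(queue.get(), timeout=0.1)
            queue.task_done()
            print(item)
        except asyncio.TimeoutError:
            if task.done() and queue.empty():
                break
    print(await task)  # prints "final report"

asyncio.run(consume())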

    # Sample configuration
    config = {
        "configurable": {
            "user_query": user_query,
            "num_sections": NUM_SECTIONS,
            "connectors_to_search": selected_connectors,
            "user_id": user_id_str,
            "search_space_id": search_space_id
        }
    }

    # Initialize state with database session and streaming service
    initial_state = State(
        db_session=session,
        streaming_service=streaming_service
    )

    # Run the graph directly
    print("\nRunning the complete researcher workflow...")

    # Use streaming with the config parameter
    async for chunk in researcher_graph.astream(
        initial_state,
        config=config,
        stream_mode="custom",
    ):
        # If the chunk contains a 'yeild_value' key, yield its value.
        # Note: 'yeild_value' is a typo in the graph code, but we need to match it here.
        if isinstance(chunk, dict) and 'yeild_value' in chunk:
            yield chunk['yeild_value']

    yield streaming_service.format_completion()
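With stream_mode="custom", LangGraph surfaces whatever a node passes to its injected stream writer as chunks on the consumer side. A sketch of how a graph node presumably emits the 'yeild_value' chunks consumed above; example_node and its payload are hypothetical, not the actual graph code:

from langgraph.types import StreamWriter

async def example_node(state, writer: StreamWriter):
    # Each writer(...) call surfaces as one chunk in
    # astream(..., stream_mode="custom") on the consumer side.
    writer({"yeild_value": "terminal or answer frame goes here"})
    return state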

@@ -1,5 +1,7 @@
from typing import Dict, Any
from langchain.schema import LLMResult, HumanMessage, SystemMessage
"""
NOTE: This is not used anymore. Might be removed in the future.
"""
from langchain.schema import HumanMessage, SystemMessage
from app.config import config

class QueryService:

@@ -1,211 +0,0 @@
import asyncio
import re
from typing import List, Dict, Any, AsyncGenerator, Callable, Optional
from langchain.schema import Document
from gpt_researcher.agent import GPTResearcher
from gpt_researcher.utils.enum import ReportType, Tone, ReportSource
from dotenv import load_dotenv

load_dotenv()


class ResearchService:
    @staticmethod
    async def create_custom_prompt(user_query: str) -> str:
        citation_prompt = f"""
You are a research assistant tasked with analyzing documents and providing comprehensive answers with proper citations in IEEE format.

<instructions>
1. Carefully analyze all provided documents in the <document> sections.
2. Extract relevant information that addresses the user's query.
3. Synthesize a comprehensive, well-structured answer using information from these documents.
4. For EVERY piece of information you include from the documents, add an IEEE-style citation in square brackets [X] where X is the source_id from the document's metadata.
5. Make sure ALL factual statements from the documents have proper citations.
6. If multiple documents support the same point, include all relevant citations [X], [Y].
7. Present information in a logical, coherent flow.
8. Use your own words to connect ideas, but cite ALL information from the documents.
9. If documents contain conflicting information, acknowledge this and present both perspectives with appropriate citations.
10. Do not make up or include information not found in the provided documents.
11. CRITICAL: You MUST use the exact source_id value from each document's metadata for citations. Do not create your own citation numbers.
12. CRITICAL: Every citation MUST be in the IEEE format [X] where X is the exact source_id value.
13. CRITICAL: Never renumber or reorder citations - always use the original source_id values.
14. CRITICAL: Do not return citations as clickable links.
15. CRITICAL: Never format citations as markdown links like "([1](https://example.com))". Always use plain square brackets only.
16. CRITICAL: Citations must ONLY appear as [X] or [X], [Y], [Z] format - never with parentheses, hyperlinks, or other formatting.
17. CRITICAL: Never make up citation numbers. Only use source_id values that are explicitly provided in the document metadata.
18. CRITICAL: If you are unsure about a source_id, omit the citation rather than guessing or making one up.
</instructions>

<format>
- Write in clear, professional language suitable for academic or technical audiences
- Organize your response with appropriate paragraphs, headings, and structure
- Every fact from the documents must have an IEEE-style citation in square brackets [X] where X is the EXACT source_id from the document's metadata
- Citations should appear at the end of the sentence containing the information they support
- Multiple citations should be separated by commas: [X], [Y], [Z]
- No need to return a references section. Just citation numbers in the answer.
- NEVER create your own citation numbering system - use the exact source_id values from the documents.
- NEVER format citations as clickable links or as markdown links like "([1](https://example.com))". Always use plain square brackets only.
- NEVER make up citation numbers if you are unsure about the source_id. It is better to omit the citation than to guess.
</format>

<input_example>
<document>
<metadata>
<source_id>1</source_id>
</metadata>
<content>
<text>
The Great Barrier Reef is the world's largest coral reef system, stretching over 2,300 kilometers along the coast of Queensland, Australia. It comprises over 2,900 individual reefs and 900 islands.
</text>
</content>
</document>

<document>
<metadata>
<source_id>13</source_id>
</metadata>
<content>
<text>
Climate change poses a significant threat to coral reefs worldwide. Rising ocean temperatures have led to mass coral bleaching events in the Great Barrier Reef in 2016, 2017, and 2020.
</text>
</content>
</document>

<document>
<metadata>
<source_id>21</source_id>
</metadata>
<content>
<text>
The Great Barrier Reef was designated a UNESCO World Heritage Site in 1981 due to its outstanding universal value and biological diversity. It is home to over 1,500 species of fish and 400 types of coral.
</text>
</content>
</document>
</input_example>

<output_example>
The Great Barrier Reef is the world's largest coral reef system, stretching over 2,300 kilometers along the coast of Queensland, Australia [1]. It was designated a UNESCO World Heritage Site in 1981 due to its outstanding universal value and biological diversity [21]. The reef is home to over 1,500 species of fish and 400 types of coral [21]. Unfortunately, climate change poses a significant threat to coral reefs worldwide, with rising ocean temperatures leading to mass coral bleaching events in the Great Barrier Reef in 2016, 2017, and 2020 [13]. The reef system comprises over 2,900 individual reefs and 900 islands [1], making it an ecological treasure that requires protection from multiple threats [1], [13].
</output_example>

<incorrect_citation_formats>
DO NOT use any of these incorrect citation formats:
- Using parentheses and markdown links: ([1](https://github.com/MODSetter/SurfSense))
- Using parentheses around brackets: ([1])
- Using hyperlinked text: [link to source 1](https://example.com)
- Using footnote style: ... reef system¹
- Making up citation numbers when the source_id is unknown

ONLY use plain square brackets [1] or multiple citations [1], [2], [3]
</incorrect_citation_formats>

Note that the citation numbers match exactly with the source_id values (1, 13, and 21) and are not renumbered sequentially. Citations follow IEEE style with square brackets and appear at the end of sentences.

Now, please research the following query:

<user_query_to_research>
{user_query}
</user_query_to_research>
"""

        return citation_prompt
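Since prompt instructions alone cannot guarantee citation compliance, a small post-hoc check can flag citations that do not correspond to any known source_id. A sketch; the function name and sample values are illustrative:

import re

def find_invalid_citations(answer: str, valid_source_ids: set) -> list:
    """Return citation numbers in `answer` that match no known source_id."""
    cited = {int(m) for m in re.findall(r'\[(\d+)\]', answer)}
    return sorted(cited - valid_source_ids)

answer = "The reef spans 2,300 km [1]. Bleaching occurred in 2016 [13], [99]."
print(find_invalid_citations(answer, {1, 13, 21}))  # [99]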

    @staticmethod
    async def stream_research(
        user_query: str,
        documents: List[Document] = None,
        on_progress: Optional[Callable] = None,
        research_mode: str = "GENERAL"
    ) -> str:
        """
        Stream the research process using GPTResearcher.

        Args:
            user_query: The user's query
            documents: List of Document objects to use for research
            on_progress: Optional callback for progress updates
            research_mode: Research mode to use

        Returns:
            str: The final research report
        """
        # Create a custom websocket-like object to capture streaming output
        class StreamingWebsocket:
            async def send_json(self, data):
                if on_progress:
                    try:
                        # Filter out excessive logging of the prompt
                        if data.get("type") == "logs":
                            output = data.get("output", "")
                            # Check if this is a verbose prompt log
                            if "You are a research assistant tasked with analyzing documents" in output and len(output) > 500:
                                # Replace it with a shorter message
                                data["output"] = f"Processing research for query: {user_query}"

                        return await on_progress(data)
                    except Exception as e:
                        print(f"Error in on_progress callback: {e}")
                return None

        streaming_websocket = StreamingWebsocket()

        custom_prompt_for_ieee_citations = await ResearchService.create_custom_prompt(user_query)

        if research_mode == "GENERAL":
            research_report_type = ReportType.CustomReport.value
        elif research_mode == "DEEP":
            research_report_type = ReportType.ResearchReport.value
        elif research_mode == "DEEPER":
            research_report_type = ReportType.DetailedReport.value
        # elif research_mode == "DEEPEST":
        #     research_report_type = ReportType.DeepResearch.value

        # Initialize GPTResearcher with the streaming websocket
        researcher = GPTResearcher(
            query=custom_prompt_for_ieee_citations,
            report_type=research_report_type,
            report_format="IEEE",
            report_source=ReportSource.LangChainDocuments.value,
            tone=Tone.Formal,
            documents=documents,
            verbose=True,
            websocket=streaming_websocket
        )

        # Conduct research
        await researcher.conduct_research()

        # Generate the report with streaming
        report = await researcher.write_report()

        # Fix citation format
        report = ResearchService.fix_citation_format(report)

        return report
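For reference, a hypothetical driver for this method might look as follows; it assumes the gpt-researcher package and model API keys are configured, and all names besides stream_research are illustrative:

import asyncio

async def run_research_demo():
    events = []

    async def on_progress(data):
        # Collect every progress event the researcher emits.
        events.append(data)

    report = await ResearchService.stream_research(
        user_query="Impact of bleaching on the Great Barrier Reef",
        documents=[],  # normally the reranked LangChain documents
        on_progress=on_progress,
        research_mode="GENERAL",
    )
    print(f"{len(events)} progress events, report length {len(report)}")

asyncio.run(run_research_demo())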

    @staticmethod
    def fix_citation_format(text: str) -> str:
        """
        Fix any incorrectly formatted citations in the text.

        Args:
            text: The text to fix

        Returns:
            str: The text with fixed citations
        """
        if not text:
            return text

        # This pattern matches only numeric citations in markdown-style links,
        # e.g. ([1](https://github.com/...)) but not general links like ([Click here](https://...)).
        pattern = r'\(\[(\d+)\]\((https?://[^\)]+)\)\)'

        # Replace with just [X] where X is the number
        text = re.sub(pattern, r'[\1]', text)

        # Also normalize formats like ([1]) to [1],
        # but only when the content inside the brackets is a number.
        text = re.sub(r'\(\[(\d+)\]\)', r'[\1]', text)

        return text
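Concretely, the two substitutions behave like this (runnable as-is; the sample sentence is illustrative):

import re

text = "Bleaching events occurred in 2016 ([13](https://example.com/reef)) and again in 2020 ([1])."
text = re.sub(r'\(\[(\d+)\]\((https?://[^\)]+)\)\)', r'[\1]', text)  # strip markdown-link citations
text = re.sub(r'\(\[(\d+)\]\)', r'[\1]', text)                       # strip parenthesized citations
print(text)
# Bleaching events occurred in 2016 [13] and again in 2020 [1].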

@@ -1,5 +1,6 @@
import json
from typing import List, Dict, Any, Generator
from typing import Any, Dict, List


class StreamingService:
    def __init__(self):

@@ -18,55 +19,7 @@ class StreamingService:
                "content": []
            }
        ]

    def add_terminal_message(self, text: str, message_type: str = "info") -> str:
        """
        Add a terminal message to the annotations and return the formatted response.

        Args:
            text: The message text
            message_type: The message type (info, success, error)

        Returns:
            str: The formatted response string
        """
        self.message_annotations[0]["content"].append({
            "id": self.terminal_idx,
            "text": text,
            "type": message_type
        })
        self.terminal_idx += 1
        return self._format_annotations()

    def update_sources(self, sources: List[Dict[str, Any]]) -> str:
        """
        Update the sources in the annotations and return the formatted response.

        Args:
            sources: List of source objects

        Returns:
            str: The formatted response string
        """
        self.message_annotations[1]["content"] = sources
        return self._format_annotations()

    def update_answer(self, answer_content: List[str]) -> str:
        """
        Update the answer in the annotations and return the formatted response.

        Args:
            answer_content: The answer content as a list of strings

        Returns:
            str: The formatted response string
        """
        self.message_annotations[2] = {
            "type": "ANSWER",
            "content": answer_content
        }
        return self._format_annotations()

    # It is used to send annotations to the frontend
    def _format_annotations(self) -> str:
        """
        Format the annotations as a string

@@ -76,6 +29,7 @@ class StreamingService:
        """
        return f'8:{json.dumps(self.message_annotations)}\n'

    # It is used to end streaming
    def format_completion(self, prompt_tokens: int = 156, completion_tokens: int = 204) -> str:
        """
        Format a completion message

@@ -96,4 +50,23 @@ class StreamingService:
                "totalTokens": total_tokens
            }
        }
        return f'd:{json.dumps(completion_data)}\n'
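The '8:' and 'd:' prefixes appear to follow the Vercel AI SDK data-stream framing, where '8:' carries message annotations and 'd:' terminates the stream with finish and usage data. A sketch of the two frame types; the annotation payload shape here is assumed, not taken from the actual service:

import json

# Hypothetical annotation payload: one terminal message block.
annotations = [{"type": "TERMINAL_INFO", "content": [{"id": 0, "text": "Starting to research...", "type": "info"}]}]
completion = {"finishReason": "stop", "usage": {"promptTokens": 156, "completionTokens": 204, "totalTokens": 360}}

print(f"8:{json.dumps(annotations)}")  # annotation frame
print(f"d:{json.dumps(completion)}")   # completion frame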

    def only_update_terminal(self, text: str, message_type: str = "info") -> List[Dict[str, Any]]:
        # Returns the raw annotations list rather than a formatted frame.
        self.message_annotations[0]["content"].append({
            "id": self.terminal_idx,
            "text": text,
            "type": message_type
        })
        self.terminal_idx += 1
        return self.message_annotations

    def only_update_sources(self, sources: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        self.message_annotations[1]["content"] = sources
        return self.message_annotations

    def only_update_answer(self, answer: List[str]) -> List[Dict[str, Any]]:
        self.message_annotations[2]["content"] = answer
        return self.message_annotations

@@ -11,7 +11,6 @@ dependencies = [
    "fastapi>=0.115.8",
    "fastapi-users[oauth,sqlalchemy]>=14.0.1",
    "firecrawl-py>=1.12.0",
    "gpt-researcher>=0.12.12",
    "github3.py==4.0.1",
    "langchain-community>=0.3.17",
    "langchain-unstructured>=0.1.6",

@@ -117,13 +117,21 @@ function processCitationsInReactChildren(children: React.ReactNode, getCitationS

// Process citation references in text content
function processCitationsInText(text: string, getCitationSource: (id: number) => Source | null): React.ReactNode[] {
    // Improved regex to catch citation numbers more reliably.
    // Matches patterns like [1], [42], etc., including at the end of a line or sentence.
    const citationRegex = /\[(\d+)\]/g;
    const parts: React.ReactNode[] = [];
    let lastIndex = 0;
    let match;
    let position = 0;

    // Debug log for troubleshooting
    console.log("Processing citations in text:", text);

    while ((match = citationRegex.exec(text)) !== null) {
        // Log each match for debugging
        console.log("Citation match found:", match[0], "at index", match.index);

        // Add the text before the citation
        if (match.index > lastIndex) {
            parts.push(text.substring(lastIndex, match.index));

@@ -131,13 +139,18 @@ function processCitationsInText(text: string, getCitationSource: (id: number) =>

        // Add the citation component
        const citationId = parseInt(match[1], 10);
        const source = getCitationSource(citationId);

        // Log the citation details
        console.log("Citation ID:", citationId, "Source:", source ? "found" : "not found");

        parts.push(
            <Citation
                key={`citation-${citationId}-${position}`}
                citationId={citationId}
                citationText={match[0]}
                position={position}
                source={getCitationSource(citationId)}
                source={source}
            />
        );
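The splitting logic above is easy to verify outside React. Here is the same walk-the-matches approach sketched in Python for clarity (the component itself is TypeScript, and split_citations is a hypothetical stand-in):

import re

def split_citations(text: str) -> list:
    """Split text into plain segments and (citation_id, token) tuples."""
    parts, last = [], 0
    for m in re.finditer(r'\[(\d+)\]', text):
        if m.start() > last:
            parts.append(text[last:m.start()])        # text before the citation
        parts.append((int(m.group(1)), m.group(0)))   # the citation itself
        last = m.end()
    if last < len(text):
        parts.append(text[last:])                     # trailing text
    return parts

print(split_citations("Reef spans 2,300 km [1], bleaching in 2016 [13]."))
# ['Reef spans 2,300 km ', (1, '[1]'), ', bleaching in 2016 ', (13, '[13]'), '.']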