feat: Removed GPT-Researcher in favour of our own SurfSense LangGraph Agent

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2025-04-20 19:19:35 -07:00
parent 94c94e6898
commit 130f43a0fa
14 changed files with 439 additions and 918 deletions

View file

@ -120,7 +120,6 @@ This is the core of SurfSense. Before we begin let's look at `.env` variables' t
| RERANKERS_MODEL_NAME| Name of the reranker model for search result reranking. Eg. `ms-marco-MiniLM-L-12-v2`|
| RERANKERS_MODEL_TYPE| Type of reranker model being used. Eg. `flashrank`|
| FAST_LLM| Smaller, faster LiteLLM-routed LLM for quick responses. Eg. `litellm:openai/gpt-4o`|
| SMART_LLM| Balanced LiteLLM-routed LLM for general use. Eg. `litellm:openai/gpt-4o`|
| STRATEGIC_LLM| Advanced LiteLLM-routed LLM for complex reasoning tasks. Eg. `litellm:openai/gpt-4o`|
| LONG_CONTEXT_LLM| LiteLLM-routed LLM capable of handling longer context windows. Eg. `litellm:gemini/gemini-2.0-flash`|
| UNSTRUCTURED_API_KEY| API key for Unstructured.io service for document parsing|
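For orientation, the `litellm:<provider>/<model>` convention in the examples above can be split like this; a minimal sketch assuming that prefix convention (SurfSense's actual parsing may differ):

# Hypothetical helper illustrating the "litellm:<provider>/<model>" value format.
def parse_llm_route(value: str) -> tuple[str, str]:
    _route, _, model_path = value.partition(":")   # "litellm", "openai/gpt-4o"
    provider, _, model = model_path.partition("/")
    return provider, model

provider, model = parse_llm_route("litellm:openai/gpt-4o")
# provider == "openai", model == "gpt-4o"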
@ -221,15 +220,15 @@ After filling in your SurfSense API key you should be able to use extension now.
- **Alembic**: A database migrations tool for SQLAlchemy.
- **FastAPI Users**: Authentication and user management with JWT and OAuth support
- **LangChain**: Framework for developing AI-powered applications
- **GPT Integration**: Integration with LLM models through LiteLLM
- **LangGraph**: Framework for developing AI agents.
- **LangChain**: Framework for developing AI-powered applications.
- **LLM Integration**: Integration with LLM models through LiteLLM
- **Rerankers**: Advanced result ranking for improved search relevance
- **GPT-Researcher**: Advanced research capabilities
- **Hybrid Search**: Combines vector similarity and full-text search for optimal results using Reciprocal Rank Fusion (RRF)
- **Vector Embeddings**: Document and text embeddings for semantic search
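Since Reciprocal Rank Fusion is named above, here is a minimal sketch of the scoring it implies; the constant k=60 is the conventional RRF default and an assumption here, not SurfSense's verified value:

from collections import defaultdict

# Fuse multiple ranked lists of document ids into a single ranking.
def reciprocal_rank_fusion(rankings: list[list[str]], k: int = 60) -> list[str]:
    scores: dict[str, float] = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)

# Example: ids ranked by vector similarity vs. full-text search.
print(reciprocal_rank_fusion([["d3", "d1", "d7"], ["d1", "d9", "d3"]]))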

View file

@ -10,9 +10,8 @@ RERANKERS_MODEL_NAME="ms-marco-MiniLM-L-12-v2"
RERANKERS_MODEL_TYPE="flashrank"
FAST_LLM="litellm:openai/gpt-4o-mini"
SMART_LLM="litellm:openai/gpt-4o-mini"
STRATEGIC_LLM="litellm:openai/gpt-4o-mini"
LONG_CONTEXT_LLM="litellm:gemini/gemini-2.0-flash-thinking-exp-01-21"
STRATEGIC_LLM="litellm:openai/gpt-4o"
LONG_CONTEXT_LLM="litellm:gemini/gemini-2.0-flash"
OPENAI_API_KEY="sk-proj-iA"
GEMINI_API_KEY="AIzaSyB6-1641124124124124124124124124124"

View file

@ -1,17 +1,23 @@
from .configuration import Configuration
from langchain_core.runnables import RunnableConfig
from .state import State
from typing import Any, Dict, List
from app.config import config as app_config
from .prompts import get_answer_outline_system_prompt
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field
import json
import asyncio
from .sub_section_writer.graph import graph as sub_section_writer_graph
import json
from typing import Any, Dict, List
from app.config import config as app_config
from app.db import async_session_maker
from app.utils.connector_service import ConnectorService
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from .configuration import Configuration
from .prompts import get_answer_outline_system_prompt
from .state import State
from .sub_section_writer.graph import graph as sub_section_writer_graph
from langgraph.types import StreamWriter
class Section(BaseModel):
"""A section in the answer outline."""
section_id: int = Field(..., description="The zero-based index of the section")
@ -22,7 +28,7 @@ class AnswerOutline(BaseModel):
"""The complete answer outline with all sections."""
answer_outline: List[Section] = Field(..., description="List of sections in the answer outline")
async def write_answer_outline(state: State, config: RunnableConfig) -> Dict[str, Any]:
async def write_answer_outline(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]:
"""
Create a structured answer outline based on the user query.
@ -33,12 +39,18 @@ async def write_answer_outline(state: State, config: RunnableConfig) -> Dict[str
Returns:
Dict containing the answer outline in the "answer_outline" key for state update.
"""
streaming_service = state.streaming_service
streaming_service.only_update_terminal("Generating answer outline...")
writer({"yeild_value": streaming_service._format_annotations()})
# Get configuration from runnable config
configuration = Configuration.from_runnable_config(config)
user_query = configuration.user_query
num_sections = configuration.num_sections
streaming_service.only_update_terminal(f"Planning research approach for query: {user_query[:100]}...")
writer({"yeild_value": streaming_service._format_annotations()})
# Initialize LLM
llm = app_config.strategic_llm_instance
@ -66,6 +78,9 @@ async def write_answer_outline(state: State, config: RunnableConfig) -> Dict[str
Your output MUST be valid JSON in exactly this format. Do not include any other text or explanation.
"""
streaming_service.only_update_terminal("Designing structured outline with AI...")
writer({"yeild_value": streaming_service._format_annotations()})
# Create messages for the LLM
messages = [
SystemMessage(content=get_answer_outline_system_prompt()),
@ -73,6 +88,9 @@ async def write_answer_outline(state: State, config: RunnableConfig) -> Dict[str
]
# Call the LLM directly without using structured output
streaming_service.only_update_terminal("Processing answer structure...")
writer({"yeild_value": streaming_service._format_annotations()})
response = await llm.ainvoke(messages)
# Parse the JSON response manually
@ -92,16 +110,27 @@ async def write_answer_outline(state: State, config: RunnableConfig) -> Dict[str
# Convert to Pydantic model
answer_outline = AnswerOutline(**parsed_data)
total_questions = sum(len(section.questions) for section in answer_outline.answer_outline)
streaming_service.only_update_terminal(f"Successfully generated outline with {len(answer_outline.answer_outline)} sections and {total_questions} research questions")
writer({"yeild_value": streaming_service._format_annotations()})
print(f"Successfully generated answer outline with {len(answer_outline.answer_outline)} sections")
# Return state update
return {"answer_outline": answer_outline}
else:
# If JSON structure not found, raise a clear error
raise ValueError(f"Could not find valid JSON in LLM response. Raw response: {content}")
error_message = f"Could not find valid JSON in LLM response. Raw response: {content}"
streaming_service.only_update_terminal(error_message, "error")
writer({"yeild_value": streaming_service._format_annotations()})
raise ValueError(error_message)
except (json.JSONDecodeError, ValueError) as e:
# Log the error and re-raise it
error_message = f"Error parsing LLM response: {str(e)}"
streaming_service.only_update_terminal(error_message, "error")
writer({"yeild_value": streaming_service._format_annotations()})
print(f"Error parsing LLM response: {str(e)}")
print(f"Raw response: {response.content}")
raise
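The extraction step elided by the hunk above amounts to locating the JSON object inside the raw completion before validating it with Pydantic; a sketch under the assumption that a first-'{' / last-'}' slice is the heuristic used:

import json

# Hedged sketch: pull the JSON object out of an LLM completion.
def extract_json_block(content: str) -> dict:
    start, end = content.find("{"), content.rfind("}")
    if start == -1 or end == -1 or end <= start:
        raise ValueError("Could not find valid JSON in LLM response")
    return json.loads(content[start:end + 1])  # then: AnswerOutline(**parsed)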
@ -112,18 +141,21 @@ async def fetch_relevant_documents(
search_space_id: int,
db_session: AsyncSession,
connectors_to_search: List[str],
top_k: int = 5
writer: StreamWriter = None,
state: State = None,
top_k: int = 20
) -> List[Dict[str, Any]]:
"""
Fetch relevant documents for research questions using the provided connectors.
Args:
section_title: The title of the section being researched
research_questions: List of research questions to find documents for
user_id: The user ID
search_space_id: The search space ID
db_session: The database session
connectors_to_search: List of connectors to search
writer: StreamWriter for sending progress updates
state: The current state containing the streaming service
top_k: Number of top results to retrieve per connector per question
Returns:
@ -131,83 +163,237 @@ async def fetch_relevant_documents(
"""
# Initialize services
connector_service = ConnectorService(db_session)
all_raw_documents = [] # Store all raw documents before reranking
for user_query in research_questions:
# Only use streaming if both writer and state are provided
streaming_service = state.streaming_service if state is not None else None
# Stream initial status update
if streaming_service and writer:
streaming_service.only_update_terminal(f"Starting research on {len(research_questions)} questions using {len(connectors_to_search)} connectors...")
writer({"yeild_value": streaming_service._format_annotations()})
all_raw_documents = [] # Store all raw documents
all_sources = [] # Store all sources
for i, user_query in enumerate(research_questions):
# Stream question being researched
if streaming_service and writer:
streaming_service.only_update_terminal(f"Researching question {i+1}/{len(research_questions)}: {user_query[:100]}...")
writer({"yeild_value": streaming_service._format_annotations()})
# Use original research question as the query
reformulated_query = user_query
# Process each selected connector
for connector in connectors_to_search:
# Stream connector being searched
if streaming_service and writer:
streaming_service.only_update_terminal(f"Searching {connector} for relevant information...")
writer({"yeild_value": streaming_service._format_annotations()})
try:
if connector == "YOUTUBE_VIDEO":
_, youtube_chunks = await connector_service.search_youtube(
source_object, youtube_chunks = await connector_service.search_youtube(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
)
# Add to sources and raw documents
if source_object:
all_sources.append(source_object)
all_raw_documents.extend(youtube_chunks)
# Stream found document count
if streaming_service and writer:
streaming_service.only_update_terminal(f"Found {len(youtube_chunks)} YouTube chunks relevant to the query")
writer({"yeild_value": streaming_service._format_annotations()})
elif connector == "EXTENSION":
_, extension_chunks = await connector_service.search_extension(
source_object, extension_chunks = await connector_service.search_extension(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
)
# Add to sources and raw documents
if source_object:
all_sources.append(source_object)
all_raw_documents.extend(extension_chunks)
# Stream found document count
if streaming_service and writer:
streaming_service.only_update_terminal(f"Found {len(extension_chunks)} extension chunks relevant to the query")
writer({"yeild_value": streaming_service._format_annotations()})
elif connector == "CRAWLED_URL":
_, crawled_urls_chunks = await connector_service.search_crawled_urls(
source_object, crawled_urls_chunks = await connector_service.search_crawled_urls(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
)
# Add to sources and raw documents
if source_object:
all_sources.append(source_object)
all_raw_documents.extend(crawled_urls_chunks)
# Stream found document count
if streaming_service and writer:
streaming_service.only_update_terminal(f"Found {len(crawled_urls_chunks)} crawled URL chunks relevant to the query")
writer({"yeild_value": streaming_service._format_annotations()})
elif connector == "FILE":
_, files_chunks = await connector_service.search_files(
source_object, files_chunks = await connector_service.search_files(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
)
# Add to sources and raw documents
if source_object:
all_sources.append(source_object)
all_raw_documents.extend(files_chunks)
# Stream found document count
if streaming_service and writer:
streaming_service.only_update_terminal(f"Found {len(files_chunks)} file chunks relevant to the query")
writer({"yeild_value": streaming_service._format_annotations()})
elif connector == "TAVILY_API":
_, tavily_chunks = await connector_service.search_tavily(
source_object, tavily_chunks = await connector_service.search_tavily(
user_query=reformulated_query,
user_id=user_id,
top_k=top_k
)
# Add to sources and raw documents
if source_object:
all_sources.append(source_object)
all_raw_documents.extend(tavily_chunks)
# Stream found document count
if streaming_service and writer:
streaming_service.only_update_terminal(f"Found {len(tavily_chunks)} web search results relevant to the query")
writer({"yeild_value": streaming_service._format_annotations()})
elif connector == "SLACK_CONNECTOR":
_, slack_chunks = await connector_service.search_slack(
source_object, slack_chunks = await connector_service.search_slack(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
)
# Add to sources and raw documents
if source_object:
all_sources.append(source_object)
all_raw_documents.extend(slack_chunks)
# Stream found document count
if streaming_service and writer:
streaming_service.only_update_terminal(f"Found {len(slack_chunks)} Slack messages relevant to the query")
writer({"yeild_value": streaming_service._format_annotations()})
elif connector == "NOTION_CONNECTOR":
_, notion_chunks = await connector_service.search_notion(
source_object, notion_chunks = await connector_service.search_notion(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
)
# Add to sources and raw documents
if source_object:
all_sources.append(source_object)
all_raw_documents.extend(notion_chunks)
# Stream found document count
if streaming_service and writer:
streaming_service.only_update_terminal(f"Found {len(notion_chunks)} Notion pages/blocks relevant to the query")
writer({"yeild_value": streaming_service._format_annotations()})
elif connector == "GITHUB_CONNECTOR":
source_object, github_chunks = await connector_service.search_github(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
)
# Add to sources and raw documents
if source_object:
all_sources.append(source_object)
all_raw_documents.extend(github_chunks)
# Stream found document count
if streaming_service and writer:
streaming_service.only_update_terminal(f"Found {len(github_chunks)} GitHub files/issues relevant to the query")
writer({"yeild_value": streaming_service._format_annotations()})
elif connector == "LINEAR_CONNECTOR":
source_object, linear_chunks = await connector_service.search_linear(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
)
# Add to sources and raw documents
if source_object:
all_sources.append(source_object)
all_raw_documents.extend(linear_chunks)
# Stream found document count
if streaming_service and writer:
streaming_service.only_update_terminal(f"Found {len(linear_chunks)} Linear issues relevant to the query")
writer({"yeild_value": streaming_service._format_annotations()})
except Exception as e:
print(f"Error searching connector {connector}: {str(e)}")
error_message = f"Error searching connector {connector}: {str(e)}"
print(error_message)
# Stream error message
if streaming_service and writer:
streaming_service.only_update_terminal(error_message, "error")
writer({"yeild_value": streaming_service._format_annotations()})
# Continue with other connectors on error
continue
# Deduplicate documents based on chunk_id or content
# Deduplicate source objects by ID before streaming
deduplicated_sources = []
seen_source_keys = set()
for source_obj in all_sources:
# Use combination of source ID and type as a unique identifier
# This ensures we don't accidentally deduplicate sources from different connectors
source_id = source_obj.get('id')
source_type = source_obj.get('type')
if source_id and source_type:
source_key = f"{source_type}_{source_id}"
if source_key not in seen_source_keys:
seen_source_keys.add(source_key)
deduplicated_sources.append(source_obj)
else:
# If there's no ID or type, just add it to be safe
deduplicated_sources.append(source_obj)
# Stream info about deduplicated sources
if streaming_service and writer:
streaming_service.only_update_terminal(f"Collected {len(deduplicated_sources)} unique sources across all connectors")
writer({"yeild_value": streaming_service._format_annotations()})
# After all sources are collected and deduplicated, stream them
if streaming_service and writer:
streaming_service.only_update_sources(deduplicated_sources)
writer({"yeild_value": streaming_service._format_annotations()})
# Deduplicate raw documents based on chunk_id or content
seen_chunk_ids = set()
seen_content_hashes = set()
deduplicated_docs = []
@ -227,11 +413,15 @@ async def fetch_relevant_documents(
seen_content_hashes.add(content_hash)
deduplicated_docs.append(doc)
# Stream info about deduplicated documents
if streaming_service and writer:
streaming_service.only_update_terminal(f"Found {len(deduplicated_docs)} unique document chunks after deduplication")
writer({"yeild_value": streaming_service._format_annotations()})
# Return deduplicated documents
return deduplicated_docs
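The middle of the chunk dedup pass is elided by the hunk; a sketch consistent with the seen_chunk_ids / seen_content_hashes sets shown above (the md5-of-content fallback is an assumption):

import hashlib

def dedup_chunks(docs: list[dict]) -> list[dict]:
    seen_chunk_ids, seen_content_hashes, out = set(), set(), []
    for doc in docs:
        chunk_id = doc.get("chunk_id")
        if chunk_id is not None:
            if chunk_id in seen_chunk_ids:
                continue  # duplicate chunk id
            seen_chunk_ids.add(chunk_id)
        else:
            digest = hashlib.md5(doc.get("content", "").encode()).hexdigest()
            if digest in seen_content_hashes:
                continue  # duplicate content
            seen_content_hashes.add(digest)
        out.append(doc)
    return out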
async def process_sections(state: State, config: RunnableConfig) -> Dict[str, Any]:
async def process_sections(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]:
"""
Process all sections in parallel and combine the results.
@ -245,89 +435,97 @@ async def process_sections(state: State, config: RunnableConfig) -> Dict[str, An
# Get configuration and answer outline from state
configuration = Configuration.from_runnable_config(config)
answer_outline = state.answer_outline
streaming_service = state.streaming_service
streaming_service.only_update_terminal("Starting to process research sections...")
writer({"yeild_value": streaming_service._format_annotations()})
print(f"Processing sections from outline: {answer_outline is not None}")
if not answer_outline:
streaming_service.only_update_terminal("Error: No answer outline was provided. Cannot generate report.", "error")
writer({"yeild_value": streaming_service._format_annotations()})
return {
"final_written_report": "No answer outline was provided. Cannot generate final report."
}
# Create session maker from the engine or directly use the session
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import sessionmaker
# Use the engine if available, otherwise create a new session for each task
if state.engine:
session_maker = sessionmaker(
state.engine, class_=AsyncSession, expire_on_commit=False
)
else:
# Fallback to using the same session (less optimal but will work)
print("Warning: No engine available. Using same session for all tasks.")
# Create a mock session maker that returns the same session
async def mock_session_maker():
class ContextManager:
async def __aenter__(self):
return state.db_session
async def __aexit__(self, exc_type, exc_val, exc_tb):
pass
return ContextManager()
session_maker = mock_session_maker
# Collect all questions from all sections
all_questions = []
for section in answer_outline.answer_outline:
all_questions.extend(section.questions)
print(f"Collected {len(all_questions)} questions from all sections")
streaming_service.only_update_terminal(f"Found {len(all_questions)} research questions across {len(answer_outline.answer_outline)} sections")
writer({"yeild_value": streaming_service._format_annotations()})
# Fetch relevant documents once for all questions
streaming_service.only_update_terminal("Searching for relevant information across all connectors...")
writer({"yeild_value": streaming_service._format_annotations()})
relevant_documents = []
async with session_maker() as db_session:
async with async_session_maker() as db_session:
try:
relevant_documents = await fetch_relevant_documents(
research_questions=all_questions,
user_id=configuration.user_id,
search_space_id=configuration.search_space_id,
db_session=db_session,
connectors_to_search=configuration.connectors_to_search
connectors_to_search=configuration.connectors_to_search,
writer=writer,
state=state
)
except Exception as e:
print(f"Error fetching relevant documents: {str(e)}")
error_message = f"Error fetching relevant documents: {str(e)}"
print(error_message)
streaming_service.only_update_terminal(error_message, "error")
writer({"yeild_value": streaming_service._format_annotations()})
# Log the error and continue with an empty list of documents
# This allows the process to continue, but the report might lack information
relevant_documents = []
# Consider adding more robust error handling or reporting if needed
print(f"Fetched {len(relevant_documents)} relevant documents for all sections")
streaming_service.only_update_terminal(f"Starting to draft {len(answer_outline.answer_outline)} sections using {len(relevant_documents)} relevant document chunks")
writer({"yeild_value": streaming_service._format_annotations()})
# Create tasks to process each section in parallel with the same document set
section_tasks = []
streaming_service.only_update_terminal("Creating processing tasks for each section...")
writer({"yeild_value": streaming_service._format_annotations()})
for section in answer_outline.answer_outline:
section_tasks.append(
process_section_with_documents(
section_title=section.section_title,
section_questions=section.questions,
user_query=configuration.user_query,
user_id=configuration.user_id,
search_space_id=configuration.search_space_id,
session_maker=session_maker,
relevant_documents=relevant_documents
relevant_documents=relevant_documents,
state=state,
writer=writer
)
)
# Run all section processing tasks in parallel
print(f"Running {len(section_tasks)} section processing tasks in parallel")
streaming_service.only_update_terminal(f"Processing {len(section_tasks)} sections simultaneously...")
writer({"yeild_value": streaming_service._format_annotations()})
section_results = await asyncio.gather(*section_tasks, return_exceptions=True)
# Handle any exceptions in the results
streaming_service.only_update_terminal("Combining section results into final report...")
writer({"yeild_value": streaming_service._format_annotations()})
processed_results = []
for i, result in enumerate(section_results):
if isinstance(result, Exception):
section_title = answer_outline.answer_outline[i].section_title
error_message = f"Error processing section '{section_title}': {str(result)}"
print(error_message)
streaming_service.only_update_terminal(error_message, "error")
writer({"yeild_value": streaming_service._format_annotations()})
processed_results.append(error_message)
else:
processed_results.append(result)
@ -337,12 +535,33 @@ async def process_sections(state: State, config: RunnableConfig) -> Dict[str, An
for i, (section, content) in enumerate(zip(answer_outline.answer_outline, processed_results)):
# Skip adding the section header since the content already contains the title
final_report.append(content)
final_report.append("\n") # Add spacing between sections
final_report.append("\n")
# Join all sections with newlines
final_written_report = "\n".join(final_report)
print(f"Generated final report with {len(final_report)} parts")
streaming_service.only_update_terminal("Final research report generated successfully!")
writer({"yeild_value": streaming_service._format_annotations()})
if hasattr(state, 'streaming_service') and state.streaming_service:
# Convert the final report to the expected format for UI:
# A list of strings where empty strings represent line breaks
formatted_report = []
for section in final_report:
if section == "\n":
# Add an empty string for line breaks
formatted_report.append("")
else:
# Split any multiline content by newlines and add each line
section_lines = section.split("\n")
formatted_report.extend(section_lines)
state.streaming_service.only_update_answer(formatted_report)
writer({"yeild_value": state.streaming_service._format_annotations()})
return {
"final_written_report": final_written_report
}
@ -352,8 +571,10 @@ async def process_section_with_documents(
section_questions: List[str],
user_id: str,
search_space_id: int,
session_maker,
relevant_documents: List[Dict[str, Any]]
relevant_documents: List[Dict[str, Any]],
user_query: str,
state: State = None,
writer: StreamWriter = None
) -> str:
"""
Process a single section using pre-fetched documents.
@ -363,31 +584,42 @@ async def process_section_with_documents(
section_questions: List of research questions for this section
user_id: The user ID
search_space_id: The search space ID
session_maker: Factory for creating new database sessions
relevant_documents: Pre-fetched documents to use for this section
state: The current state
writer: StreamWriter for sending progress updates
Returns:
The written section content
"""
try:
# Use the provided documents
documents_to_use = relevant_documents
# Send status update via streaming if available
if state and state.streaming_service and writer:
state.streaming_service.only_update_terminal(f"Writing section: {section_title} with {len(section_questions)} research questions")
writer({"yeild_value": state.streaming_service._format_annotations()})
# Fallback if no documents found
if not documents_to_use:
print(f"No relevant documents found for section: {section_title}")
if state and state.streaming_service and writer:
state.streaming_service.only_update_terminal(f"Warning: No relevant documents found for section: {section_title}", "warning")
writer({"yeild_value": state.streaming_service._format_annotations()})
documents_to_use = [
{"content": f"No specific information was found for: {question}"}
for question in section_questions
]
# Create a new database session for this section
async with session_maker() as db_session:
# Use the provided documents
documents_to_use = relevant_documents
# Fallback if no documents found
if not documents_to_use:
print(f"No relevant documents found for section: {section_title}")
documents_to_use = [
{"content": f"No specific information was found for: {question}"}
for question in section_questions
]
async with async_session_maker() as db_session:
# Call the sub_section_writer graph with the appropriate config
config = {
"configurable": {
"sub_section_title": section_title,
"sub_section_questions": section_questions,
"user_query": user_query,
"relevant_documents": documents_to_use,
"user_id": user_id,
"search_space_id": search_space_id
@ -395,16 +627,32 @@ async def process_section_with_documents(
}
# Create the initial state with db_session
state = {"db_session": db_session}
sub_state = {"db_session": db_session}
# Invoke the sub-section writer graph
print(f"Invoking sub_section_writer for: {section_title}")
result = await sub_section_writer_graph.ainvoke(state, config)
if state and state.streaming_service and writer:
state.streaming_service.only_update_terminal(f"Analyzing information and drafting content for section: {section_title}")
writer({"yeild_value": state.streaming_service._format_annotations()})
result = await sub_section_writer_graph.ainvoke(sub_state, config)
# Return the final answer from the sub_section_writer
final_answer = result.get("final_answer", "No content was generated for this section.")
# Send section content update via streaming if available
if state and state.streaming_service and writer:
state.streaming_service.only_update_terminal(f"Completed writing section: {section_title}")
writer({"yeild_value": state.streaming_service._format_annotations()})
return final_answer
except Exception as e:
print(f"Error processing section '{section_title}': {str(e)}")
# Send error update via streaming if available
if state and state.streaming_service and writer:
state.streaming_service.only_update_terminal(f"Error processing section '{section_title}': {str(e)}", "error")
writer({"yeild_value": state.streaming_service._format_annotations()})
return f"Error processing section: {section_title}. Details: {str(e)}"
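The writer-based progress updates added throughout this file rely on LangGraph's custom stream mode: any node whose signature includes a StreamWriter parameter gets one injected, and whatever it writes surfaces to consumers iterating with stream_mode="custom". A self-contained sketch of that mechanism (the graph shape is illustrative, not the researcher graph itself):

import asyncio
from typing import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.types import StreamWriter

class DemoState(TypedDict):
    value: str

async def worker(state: DemoState, writer: StreamWriter) -> DemoState:
    # Same key (and typo) as used by this commit's nodes.
    writer({"yeild_value": "progress: working..."})
    return {"value": "done"}

builder = StateGraph(DemoState)
builder.add_node("worker", worker)
builder.add_edge(START, "worker")
builder.add_edge("worker", END)
demo_graph = builder.compile()

async def main():
    async for chunk in demo_graph.astream({"value": ""}, stream_mode="custom"):
        print(chunk)  # {'yeild_value': 'progress: working...'}

asyncio.run(main())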

View file

@ -3,10 +3,9 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import List, Optional, Any, Dict, Annotated
from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine
from langchain_core.messages import BaseMessage, HumanMessage
from pydantic import BaseModel
from typing import Optional, Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.utils.streaming_service import StreamingService
@dataclass
class State:
@ -18,7 +17,9 @@ class State:
"""
# Runtime context (not part of actual graph state)
db_session: AsyncSession
engine: Optional[AsyncEngine] = None
# Streaming service
streaming_service: StreamingService
# Intermediate state - populated during workflow
# Using field to explicitly mark as part of state

View file

@ -15,6 +15,7 @@ class Configuration:
# Input parameters provided at invocation
sub_section_title: str
sub_section_questions: List[str]
user_query: str
relevant_documents: List[Any] # Documents provided directly to the agent
user_id: str
search_space_id: int

View file

@ -119,6 +119,7 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
# Create the query that uses the section title and questions
section_title = configuration.sub_section_title
sub_section_questions = configuration.sub_section_questions
user_query = configuration.user_query # Get the original user query
documents_text = "\n".join(formatted_documents)
# Format the questions as bullet points for clarity
@ -126,17 +127,16 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
# Construct a clear, structured query for the LLM
human_message_content = f"""
Please write a comprehensive answer for the title:
Now user's query is:
<user_query>
{user_query}
</user_query>
<title>
The sub-section title is:
<sub_section_title>
{section_title}
</title>
</sub_section_title>
Focus on answering these specific questions related to the title:
<questions>
{questions_text}
</questions>
Use the provided documents as your source material and cite them properly using the IEEE citation format [X] where X is the source_id.
<documents>
{documents_text}

View file

@ -1,126 +0,0 @@
#!/usr/bin/env python3
"""
Test script for the Researcher LangGraph agent.
This script demonstrates how to invoke the researcher agent with a sample query.
Run this script directly from VSCode using the "Run Python File" button or
right-click and select "Run Python File in Terminal".
Before running:
1. Make sure your Python environment has all required dependencies
2. Create a .env file with any required API keys
3. Ensure database connection is properly configured
"""
import asyncio
import os
import sys
from pathlib import Path
# Add project root to Python path so that 'app' can be found as a module
# Get the absolute path to the surfsense_backend directory which contains the app module
project_root = str(Path(__file__).resolve().parents[3]) # Go up 3 levels from the script: app/agents/researcher -> app/agents -> app -> surfsense_backend
print(f"Adding to Python path: {project_root}")
sys.path.insert(0, project_root)
# Now import the modules after fixing the path
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from dotenv import load_dotenv
# These imports should now work with the correct path
from app.agents.researcher.graph import graph
from app.agents.researcher.state import State
# Load environment variables
load_dotenv()
# Database connection string - use a test database or mock
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense"
# Create async engine and session
engine = create_async_engine(DATABASE_URL)
async_session_maker = sessionmaker(
engine, class_=AsyncSession, expire_on_commit=False
)
async def run_test():
"""Run a test of the researcher agent."""
print("Starting researcher agent test...")
# Create a database session
async with async_session_maker() as db_session:
# Sample configuration
config = {
"configurable": {
"user_query": "What are the best clash royale decks recommended by Surgical Goblin?",
"num_sections": 1,
"connectors_to_search": [
"YOUTUBE_VIDEO",
],
"user_id": "d6ac2187-7407-4664-8734-af09926d161e",
"search_space_id": 2
}
}
try:
# Initialize state with database session and engine
initial_state = State(db_session=db_session, engine=engine)
# Run the graph directly
print("\nRunning the complete researcher workflow...")
result = await graph.ainvoke(initial_state, config)
# Extract the answer outline for display
if "answer_outline" in result and result["answer_outline"]:
print(f"\nGenerated answer outline with {len(result['answer_outline'].answer_outline)} sections")
# Print the outline
print("\nGenerated Answer Outline:")
for section in result["answer_outline"].answer_outline:
print(f"\nSection {section.section_id}: {section.section_title}")
print("Research Questions:")
for q in section.questions:
print(f" - {q}")
# Check if we got a final report
if "final_written_report" in result and result["final_written_report"]:
final_report = result["final_written_report"]
print("\nFinal Research Report generated successfully!")
print(f"Report length: {len(final_report)} characters")
# Display the final report
print("\n==== FINAL RESEARCH REPORT ====\n")
print(final_report)
else:
print("\nNo final report was generated.")
print(f"Available result keys: {list(result.keys())}")
return result
except Exception as e:
print(f"Error running researcher agent: {str(e)}")
import traceback
traceback.print_exc()
raise
async def main():
"""Main entry point for the test script."""
try:
result = await run_test()
print("\nTest completed successfully.")
return result
except Exception as e:
print(f"\nTest failed with error: {e}")
import traceback
traceback.print_exc()
return None
if __name__ == "__main__":
# Run the async test
result = asyncio.run(main())
# Keep terminal open if run directly in VSCode
if 'VSCODE_PID' in os.environ:
input("\nPress Enter to close this window...")

View file

@ -1,14 +1,15 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.exc import IntegrityError, OperationalError
from typing import List
from app.db import get_async_session, User, SearchSpace, Chat
from app.schemas import ChatCreate, ChatUpdate, ChatRead, AISDKChatRequest
from app.db import Chat, SearchSpace, User, get_async_session
from app.schemas import AISDKChatRequest, ChatCreate, ChatRead, ChatUpdate
from app.tasks.stream_connector_search_results import stream_connector_search_results
from app.users import current_active_user
from app.utils.check_ownership import check_ownership
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
router = APIRouter()
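The endpoint that consumes stream_connector_search_results sits outside this hunk; a hedged sketch of how such an async generator is typically wired into a StreamingResponse (route path, request fields, and media type are assumptions, not the file's verified contents):

# Hypothetical wiring; the real handler in this router may differ.
@router.post("/chat")
async def chat(
    request: AISDKChatRequest,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    generator = stream_connector_search_results(
        user_query=request.messages[-1].content,   # assumed request shape
        user_id=user.id,
        search_space_id=request.search_space_id,   # assumed field
        session=session,
        research_mode=request.research_mode,       # assumed field
        selected_connectors=request.connectors,    # assumed field
    )
    return StreamingResponse(generator, media_type="text/event-stream")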

View file

@ -1,20 +1,15 @@
import json
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, AsyncGenerator, Dict, Any
import asyncio
import re
from typing import AsyncGenerator, List, Union
from uuid import UUID
from app.utils.connector_service import ConnectorService
from app.utils.research_service import ResearchService
from app.agents.researcher.graph import graph as researcher_graph
from app.agents.researcher.state import State
from app.utils.streaming_service import StreamingService
from app.utils.reranker_service import RerankerService
from app.utils.query_service import QueryService
from app.config import config
from app.utils.document_converters import convert_chunks_to_langchain_documents
from sqlalchemy.ext.asyncio import AsyncSession
async def stream_connector_search_results(
user_query: str,
user_id: str,
user_id: Union[str, UUID],
search_space_id: int,
session: AsyncSession,
research_mode: str,
@ -25,7 +20,7 @@ async def stream_connector_search_results(
Args:
user_query: The user's query
user_id: The user's ID
user_id: The user's ID (can be UUID object or string)
search_space_id: The search space ID
session: The database session
research_mode: The research mode
@ -34,418 +29,45 @@ async def stream_connector_search_results(
Yields:
str: Formatted response strings
"""
# Initialize services
connector_service = ConnectorService(session)
streaming_service = StreamingService()
# Reformulate the user query using the strategic LLM
yield streaming_service.add_terminal_message("Reformulating your query for better results...", "info")
reformulated_query = await QueryService.reformulate_query(user_query)
yield streaming_service.add_terminal_message(f"Searching for: {reformulated_query}", "success")
reranker_service = RerankerService.get_reranker_instance(config)
all_raw_documents = [] # Store all raw documents before reranking
all_sources = []
TOP_K = 20
if research_mode == "GENERAL":
TOP_K = 20
NUM_SECTIONS = 1
elif research_mode == "DEEP":
TOP_K = 40
NUM_SECTIONS = 3
elif research_mode == "DEEPER":
TOP_K = 60
NUM_SECTIONS = 6
# Process each selected connector
for connector in selected_connectors:
if connector == "YOUTUBE_VIDEO":
# Send terminal message about starting search
yield streaming_service.add_terminal_message("Starting to search for youtube videos...")
# Search for YouTube videos using reformulated query
result_object, youtube_chunks = await connector_service.search_youtube(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=TOP_K
)
# Send terminal message about search results
yield streaming_service.add_terminal_message(
f"Found {len(result_object['sources'])} relevant YouTube videos",
"success"
)
# Update sources
all_sources.append(result_object)
yield streaming_service.update_sources(all_sources)
# Add documents to collection
all_raw_documents.extend(youtube_chunks)
# Extension Docs
if connector == "EXTENSION":
# Send terminal message about starting search
yield streaming_service.add_terminal_message("Starting to search for extension...")
# Search for crawled URLs using reformulated query
result_object, extension_chunks = await connector_service.search_extension(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=TOP_K
)
# Send terminal message about search results
yield streaming_service.add_terminal_message(
f"Found {len(result_object['sources'])} relevant extension documents",
"success"
)
# Update sources
all_sources.append(result_object)
yield streaming_service.update_sources(all_sources)
# Add documents to collection
all_raw_documents.extend(extension_chunks)
# Crawled URLs
if connector == "CRAWLED_URL":
# Send terminal message about starting search
yield streaming_service.add_terminal_message("Starting to search for crawled URLs...")
# Search for crawled URLs using reformulated query
result_object, crawled_urls_chunks = await connector_service.search_crawled_urls(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=TOP_K
)
# Send terminal message about search results
yield streaming_service.add_terminal_message(
f"Found {len(result_object['sources'])} relevant crawled URLs",
"success"
)
# Update sources
all_sources.append(result_object)
yield streaming_service.update_sources(all_sources)
# Add documents to collection
all_raw_documents.extend(crawled_urls_chunks)
# Files
if connector == "FILE":
# Send terminal message about starting search
yield streaming_service.add_terminal_message("Starting to search for files...")
# Search for files using reformulated query
result_object, files_chunks = await connector_service.search_files(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=TOP_K
)
# Send terminal message about search results
yield streaming_service.add_terminal_message(
f"Found {len(result_object['sources'])} relevant files",
"success"
)
# Update sources
all_sources.append(result_object)
yield streaming_service.update_sources(all_sources)
# Add documents to collection
all_raw_documents.extend(files_chunks)
# Tavily Connector
if connector == "TAVILY_API":
# Send terminal message about starting search
yield streaming_service.add_terminal_message("Starting to search with Tavily API...")
# Search using Tavily API with reformulated query
result_object, tavily_chunks = await connector_service.search_tavily(
user_query=reformulated_query,
user_id=user_id,
top_k=TOP_K
)
# Send terminal message about search results
yield streaming_service.add_terminal_message(
f"Found {len(result_object['sources'])} relevant results from Tavily",
"success"
)
# Update sources
all_sources.append(result_object)
yield streaming_service.update_sources(all_sources)
# Add documents to collection
all_raw_documents.extend(tavily_chunks)
# Slack Connector
if connector == "SLACK_CONNECTOR":
# Send terminal message about starting search
yield streaming_service.add_terminal_message("Starting to search for slack connector...")
# Search using Slack API with reformulated query
result_object, slack_chunks = await connector_service.search_slack(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=TOP_K
)
# Send terminal message about search results
yield streaming_service.add_terminal_message(
f"Found {len(result_object['sources'])} relevant results from Slack",
"success"
)
# Update sources
all_sources.append(result_object)
yield streaming_service.update_sources(all_sources)
# Add documents to collection
all_raw_documents.extend(slack_chunks)
# Notion Connector
if connector == "NOTION_CONNECTOR":
# Send terminal message about starting search
yield streaming_service.add_terminal_message("Starting to search for notion connector...")
# Search using Notion API with reformulated query
result_object, notion_chunks = await connector_service.search_notion(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=TOP_K
)
# Send terminal message about search results
yield streaming_service.add_terminal_message(
f"Found {len(result_object['sources'])} relevant results from Notion",
"success"
)
# Update sources
all_sources.append(result_object)
yield streaming_service.update_sources(all_sources)
# Add documents to collection
all_raw_documents.extend(notion_chunks)
# Github Connector
if connector == "GITHUB_CONNECTOR":
# Send terminal message about starting search
yield streaming_service.add_terminal_message("Starting to search for GitHub connector...")
print("Starting to search for GitHub connector...")
# Search using Github API with reformulated query
result_object, github_chunks = await connector_service.search_github(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=TOP_K
)
# Send terminal message about search results
yield streaming_service.add_terminal_message(
f"Found {len(result_object['sources'])} relevant results from Github",
"success"
)
# Update sources
all_sources.append(result_object)
yield streaming_service.update_sources(all_sources)
# Add documents to collection
all_raw_documents.extend(github_chunks)
# Linear Connector
if connector == "LINEAR_CONNECTOR":
# Send terminal message about starting search
yield streaming_service.add_terminal_message("Starting to search for Linear issues...")
# Search using Linear API with reformulated query
result_object, linear_chunks = await connector_service.search_linear(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=TOP_K
)
# Send terminal message about search results
yield streaming_service.add_terminal_message(
f"Found {len(result_object['sources'])} relevant results from Linear",
"success"
)
# Update sources
all_sources.append(result_object)
yield streaming_service.update_sources(all_sources)
# Add documents to collection
all_raw_documents.extend(linear_chunks)
# Convert UUID to string if needed
user_id_str = str(user_id) if isinstance(user_id, UUID) else user_id
# If we have documents to research
if all_raw_documents:
# Rerank all documents if reranker is available
if reranker_service:
yield streaming_service.add_terminal_message("Reranking documents for better relevance...", "info")
# Convert documents to format expected by reranker
reranker_input_docs = [
{
"chunk_id": doc.get("chunk_id", f"chunk_{i}"),
"content": doc.get("content", ""),
"score": doc.get("score", 0.0),
"document": {
"id": doc.get("document", {}).get("id", ""),
"title": doc.get("document", {}).get("title", ""),
"document_type": doc.get("document", {}).get("document_type", ""),
"metadata": doc.get("document", {}).get("metadata", {})
}
} for i, doc in enumerate(all_raw_documents)
]
# Rerank documents using the reformulated query
reranked_docs = reranker_service.rerank_documents(reformulated_query, reranker_input_docs)
# Sort by score in descending order
reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
# Convert back to langchain documents format
from langchain.schema import Document as LangchainDocument
all_langchain_documents_to_research = [
LangchainDocument(
page_content= f"""<document><metadata><source_id>{doc.get("document", {}).get("id", "")}</source_id></metadata><content>{doc.get("content", "")}</content></document>""",
metadata={
# **doc.get("document", {}).get("metadata", {}),
# "score": doc.get("score", 0.0),
# "rank": doc.get("rank", 0),
# "document_id": doc.get("document", {}).get("id", ""),
# "document_title": doc.get("document", {}).get("title", ""),
# "document_type": doc.get("document", {}).get("document_type", ""),
# # Explicitly set source_id for citation purposes
"source_id": str(doc.get("document", {}).get("id", ""))
}
) for doc in reranked_docs
]
yield streaming_service.add_terminal_message(f"Reranked {len(all_langchain_documents_to_research)} documents", "success")
else:
# Use raw documents if no reranker is available
all_langchain_documents_to_research = convert_chunks_to_langchain_documents(all_raw_documents)
# Send terminal message about starting research
yield streaming_service.add_terminal_message("Starting to research...", "info")
# Create a buffer to collect report content
report_buffer = []
# Use the streaming research method
yield streaming_service.add_terminal_message("Generating report...", "info")
# Create a wrapper to handle the streaming
class StreamHandler:
def __init__(self):
self.queue = asyncio.Queue()
async def handle_progress(self, data):
result = None
if data.get("type") == "logs":
# Handle log messages
result = streaming_service.add_terminal_message(data.get("output", ""), "info")
elif data.get("type") == "report":
# Handle report content
content = data.get("output", "")
# Fix incorrect citation formats using regex
# More specific pattern to match only numeric citations in markdown-style links
# This matches patterns like ([1](https://github.com/...)) but not general links like ([Click here](https://...))
pattern = r'\(\[(\d+)\]\((https?://[^\)]+)\)\)'
# Replace with just [X] where X is the number
content = re.sub(pattern, r'[\1]', content)
# Also match other incorrect formats like ([1]) and convert to [1]
# Only match if the content inside brackets is a number
content = re.sub(r'\(\[(\d+)\]\)', r'[\1]', content)
report_buffer.append(content)
# Update the answer with the accumulated content
result = streaming_service.update_answer(report_buffer)
if result:
await self.queue.put(result)
return result
async def get_next(self):
try:
return await self.queue.get()
except Exception as e:
print(f"Error getting next item from queue: {e}")
return None
def task_done(self):
self.queue.task_done()
# Create the stream handler
stream_handler = StreamHandler()
# Start the research process in a separate task
research_task = asyncio.create_task(
ResearchService.stream_research(
user_query=reformulated_query,
documents=all_langchain_documents_to_research,
on_progress=stream_handler.handle_progress,
research_mode=research_mode
)
)
# Stream results as they become available
while not research_task.done() or not stream_handler.queue.empty():
try:
# Get the next result with a timeout
result = await asyncio.wait_for(stream_handler.get_next(), timeout=0.1)
stream_handler.task_done()
yield result
except asyncio.TimeoutError:
# No result available yet, check if the research task is done
if research_task.done():
# If the queue is empty and the task is done, we're finished
if stream_handler.queue.empty():
break
# Get the final report
try:
final_report = await research_task
# Send terminal message about research completion
yield streaming_service.add_terminal_message("Research completed", "success")
# Update the answer with the final report
final_report_lines = final_report.split('\n')
yield streaming_service.update_answer(final_report_lines)
except Exception as e:
# Handle any exceptions
yield streaming_service.add_terminal_message(f"Error during research: {str(e)}", "error")
# Send completion message
yield streaming_service.format_completion()
# Sample configuration
config = {
"configurable": {
"user_query": user_query,
"num_sections": NUM_SECTIONS,
"connectors_to_search": selected_connectors,
"user_id": user_id_str,
"search_space_id": search_space_id
}
}
# Initialize state with database session and streaming service
initial_state = State(
db_session=session,
streaming_service=streaming_service
)
# Run the graph directly
print("\nRunning the complete researcher workflow...")
# Use streaming with config parameter
async for chunk in researcher_graph.astream(
initial_state,
config=config,
stream_mode="custom",
):
# If the chunk contains a 'yeild_value' key, print its value
# Note: there's a typo in 'yeild_value' in the code, but we need to match it
if isinstance(chunk, dict) and 'yeild_value' in chunk:
yield chunk['yeild_value']
yield streaming_service.format_completion()

View file

@ -1,5 +1,7 @@
from typing import Dict, Any
from langchain.schema import LLMResult, HumanMessage, SystemMessage
"""
NOTE: This is not used anymore. Might be removed in the future.
"""
from langchain.schema import HumanMessage, SystemMessage
from app.config import config
class QueryService:

View file

@ -1,211 +0,0 @@
import asyncio
import re
from typing import List, Dict, Any, AsyncGenerator, Callable, Optional
from langchain.schema import Document
from gpt_researcher.agent import GPTResearcher
from gpt_researcher.utils.enum import ReportType, Tone, ReportSource
from dotenv import load_dotenv
load_dotenv()
class ResearchService:
@staticmethod
async def create_custom_prompt(user_query: str) -> str:
citation_prompt = f"""
You are a research assistant tasked with analyzing documents and providing comprehensive answers with proper citations in IEEE format.
<instructions>
1. Carefully analyze all provided documents in the <document> sections.
2. Extract relevant information that addresses the user's query.
3. Synthesize a comprehensive, well-structured answer using information from these documents.
4. For EVERY piece of information you include from the documents, add an IEEE-style citation in square brackets [X] where X is the source_id from the document's metadata.
5. Make sure ALL factual statements from the documents have proper citations.
6. If multiple documents support the same point, include all relevant citations [X], [Y].
7. Present information in a logical, coherent flow.
8. Use your own words to connect ideas, but cite ALL information from the documents.
9. If documents contain conflicting information, acknowledge this and present both perspectives with appropriate citations.
10. Do not make up or include information not found in the provided documents.
11. CRITICAL: You MUST use the exact source_id value from each document's metadata for citations. Do not create your own citation numbers.
12. CRITICAL: Every citation MUST be in the IEEE format [X] where X is the exact source_id value.
13. CRITICAL: Never renumber or reorder citations - always use the original source_id values.
14. CRITICAL: Do not return citations as clickable links.
15. CRITICAL: Never format citations as markdown links like "([1](https://example.com))". Always use plain square brackets only.
16. CRITICAL: Citations must ONLY appear as [X] or [X], [Y], [Z] format - never with parentheses, hyperlinks, or other formatting.
17. CRITICAL: Never make up citation numbers. Only use source_id values that are explicitly provided in the document metadata.
18. CRITICAL: If you are unsure about a source_id, do not include a citation rather than guessing or making one up.
</instructions>
<format>
- Write in clear, professional language suitable for academic or technical audiences
- Organize your response with appropriate paragraphs, headings, and structure
- Every fact from the documents must have an IEEE-style citation in square brackets [X] where X is the EXACT source_id from the document's metadata
- Citations should appear at the end of the sentence containing the information they support
- Multiple citations should be separated by commas: [X], [Y], [Z]
- No need to return references section. Just citation numbers in answer.
- NEVER create your own citation numbering system - use the exact source_id values from the documents.
- NEVER format citations as clickable links or as markdown links like "([1](https://example.com))". Always use plain square brackets only.
- NEVER make up citation numbers if you are unsure about the source_id. It is better to omit the citation than to guess.
</format>
<input_example>
<document>
<metadata>
<source_id>1</source_id>
</metadata>
<content>
<text>
The Great Barrier Reef is the world's largest coral reef system, stretching over 2,300 kilometers along the coast of Queensland, Australia. It comprises over 2,900 individual reefs and 900 islands.
</text>
</content>
</document>
<document>
<metadata>
<source_id>13</source_id>
</metadata>
<content>
<text>
Climate change poses a significant threat to coral reefs worldwide. Rising ocean temperatures have led to mass coral bleaching events in the Great Barrier Reef in 2016, 2017, and 2020.
</text>
</content>
</document>
<document>
<metadata>
<source_id>21</source_id>
</metadata>
<content>
<text>
The Great Barrier Reef was designated a UNESCO World Heritage Site in 1981 due to its outstanding universal value and biological diversity. It is home to over 1,500 species of fish and 400 types of coral.
</text>
</content>
</document>
</input_example>
<output_example>
The Great Barrier Reef is the world's largest coral reef system, stretching over 2,300 kilometers along the coast of Queensland, Australia [1]. It was designated a UNESCO World Heritage Site in 1981 due to its outstanding universal value and biological diversity [21]. The reef is home to over 1,500 species of fish and 400 types of coral [21]. Unfortunately, climate change poses a significant threat to coral reefs worldwide, with rising ocean temperatures leading to mass coral bleaching events in the Great Barrier Reef in 2016, 2017, and 2020 [13]. The reef system comprises over 2,900 individual reefs and 900 islands [1], making it an ecological treasure that requires protection from multiple threats [1], [13].
</output_example>
<incorrect_citation_formats>
DO NOT use any of these incorrect citation formats:
- Using parentheses and markdown links: ([1](https://github.com/MODSetter/SurfSense))
- Using parentheses around brackets: ([1])
- Using hyperlinked text: [link to source 1](https://example.com)
- Using footnote style: ... reef system¹
- Making up citation numbers when source_id is unknown
ONLY use plain square brackets [1] or multiple citations [1], [2], [3]
</incorrect_citation_formats>
Note that the citation numbers match exactly with the source_id values (1, 13, and 21) and are not renumbered sequentially. Citations follow IEEE style with square brackets and appear at the end of sentences.
Now, please research the following query:
<user_query_to_research>
{user_query}
</user_query_to_research>
"""
return citation_prompt
@staticmethod
async def stream_research(
user_query: str,
documents: List[Document] = None,
on_progress: Optional[Callable] = None,
research_mode: str = "GENERAL"
) -> str:
"""
Stream the research process using GPTResearcher
Args:
user_query: The user's query
documents: List of Document objects to use for research
on_progress: Optional callback for progress updates
research_mode: Research mode to use
Returns:
str: The final research report
"""
# Create a custom websocket-like object to capture streaming output
class StreamingWebsocket:
async def send_json(self, data):
if on_progress:
try:
# Filter out excessive logging of the prompt
if data.get("type") == "logs":
output = data.get("output", "")
# Check if this is a verbose prompt log
if "You are a research assistant tasked with analyzing documents" in output and len(output) > 500:
# Replace with a shorter message
data["output"] = f"Processing research for query: {user_query}"
result = await on_progress(data)
return result
except Exception as e:
print(f"Error in on_progress callback: {e}")
return None
streaming_websocket = StreamingWebsocket()
custom_prompt_for_ieee_citations = await ResearchService.create_custom_prompt(user_query)
if(research_mode == "GENERAL"):
research_report_type = ReportType.CustomReport.value
elif(research_mode == "DEEP"):
research_report_type = ReportType.ResearchReport.value
elif(research_mode == "DEEPER"):
research_report_type = ReportType.DetailedReport.value
# elif(research_mode == "DEEPEST"):
# research_report_type = ReportType.DeepResearch.value
# Initialize GPTResearcher with the streaming websocket
researcher = GPTResearcher(
query=custom_prompt_for_ieee_citations,
report_type=research_report_type,
report_format="IEEE",
report_source=ReportSource.LangChainDocuments.value,
tone=Tone.Formal,
documents=documents,
verbose=True,
websocket=streaming_websocket
)
# Conduct research
await researcher.conduct_research()
# Generate report with streaming
report = await researcher.write_report()
# Fix citation format
report = ResearchService.fix_citation_format(report)
return report
@staticmethod
def fix_citation_format(text: str) -> str:
"""
Fix any incorrectly formatted citations in the text.
Args:
text: The text to fix
Returns:
str: The text with fixed citations
"""
if not text:
return text
# More specific pattern to match only numeric citations in markdown-style links
# This matches patterns like ([1](https://github.com/...)) but not general links like ([Click here](https://...))
pattern = r'\(\[(\d+)\]\((https?://[^\)]+)\)\)'
# Replace with just [X] where X is the number
text = re.sub(pattern, r'[\1]', text)
# Also match other incorrect formats like ([1]) and convert to [1]
# Only match if the content inside brackets is a number
text = re.sub(r'\(\[(\d+)\]\)', r'[\1]', text)
return text

View file

@ -1,5 +1,6 @@
import json
from typing import List, Dict, Any, Generator
from typing import Any, Dict, List
class StreamingService:
def __init__(self):
@ -18,55 +19,7 @@ class StreamingService:
"content": []
}
]
def add_terminal_message(self, text: str, message_type: str = "info") -> str:
"""
Add a terminal message to the annotations and return the formatted response
Args:
text: The message text
message_type: The message type (info, success, error)
Returns:
str: The formatted response string
"""
self.message_annotations[0]["content"].append({
"id": self.terminal_idx,
"text": text,
"type": message_type
})
self.terminal_idx += 1
return self._format_annotations()
def update_sources(self, sources: List[Dict[str, Any]]) -> str:
"""
Update the sources in the annotations and return the formatted response
Args:
sources: List of source objects
Returns:
str: The formatted response string
"""
self.message_annotations[1]["content"] = sources
return self._format_annotations()
def update_answer(self, answer_content: List[str]) -> str:
"""
Update the answer in the annotations and return the formatted response
Args:
answer_content: The answer content as a list of strings
Returns:
str: The formatted response string
"""
self.message_annotations[2] = {
"type": "ANSWER",
"content": answer_content
}
return self._format_annotations()
# Used to send annotations to the frontend
def _format_annotations(self) -> str:
"""
Format the annotations as a string
@ -76,6 +29,7 @@ class StreamingService:
"""
return f'8:{json.dumps(self.message_annotations)}\n'
# Used to signal the end of streaming
def format_completion(self, prompt_tokens: int = 156, completion_tokens: int = 204) -> str:
"""
Format a completion message
@ -96,4 +50,23 @@ class StreamingService:
"totalTokens": total_tokens
}
}
return f'd:{json.dumps(completion_data)}\n'
return f'd:{json.dumps(completion_data)}\n'
def only_update_terminal(self, text: str, message_type: str = "info") -> List[Dict[str, Any]]:
self.message_annotations[0]["content"].append({
"id": self.terminal_idx,
"text": text,
"type": message_type
})
self.terminal_idx += 1
return self.message_annotations
def only_update_sources(self, sources: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
self.message_annotations[1]["content"] = sources
return self.message_annotations
def only_update_answer(self, answer: List[str]) -> List[Dict[str, Any]]:
self.message_annotations[2]["content"] = answer
return self.message_annotations
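For context, these methods feed the prefixed line framing defined earlier in this file: _format_annotations emits "8:<json>\n" lines carrying annotations and format_completion emits a closing "d:<json>\n" line, which matches the Vercel AI SDK data-stream framing (an inference from the code, not documented here). A quick usage sketch:

# Minimal illustration of the wire format produced by StreamingService.
svc = StreamingService()
svc.only_update_terminal("Searching connectors...", "info")
print(svc._format_annotations(), end="")  # 8:[ ...annotations json... ]
print(svc.format_completion(), end="")    # d:{"finishReason": "stop", ...}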

View file

@ -11,7 +11,6 @@ dependencies = [
"fastapi>=0.115.8",
"fastapi-users[oauth,sqlalchemy]>=14.0.1",
"firecrawl-py>=1.12.0",
"gpt-researcher>=0.12.12",
"github3.py==4.0.1",
"langchain-community>=0.3.17",
"langchain-unstructured>=0.1.6",

View file

@ -117,13 +117,21 @@ function processCitationsInReactChildren(children: React.ReactNode, getCitationS
// Process citation references in text content
function processCitationsInText(text: string, getCitationSource: (id: number) => Source | null): React.ReactNode[] {
// Use improved regex to catch citation numbers more reliably
// This will match patterns like [1], [42], etc. including when they appear at the end of a line or sentence
const citationRegex = /\[(\d+)\]/g;
const parts: React.ReactNode[] = [];
let lastIndex = 0;
let match;
let position = 0;
// Debug log for troubleshooting
console.log("Processing citations in text:", text);
while ((match = citationRegex.exec(text)) !== null) {
// Log each match for debugging
console.log("Citation match found:", match[0], "at index", match.index);
// Add text before the citation
if (match.index > lastIndex) {
parts.push(text.substring(lastIndex, match.index));
@ -131,13 +139,18 @@ function processCitationsInText(text: string, getCitationSource: (id: number) =>
// Add the citation component
const citationId = parseInt(match[1], 10);
const source = getCitationSource(citationId);
// Log the citation details
console.log("Citation ID:", citationId, "Source:", source ? "found" : "not found");
parts.push(
<Citation
key={`citation-${citationId}-${position}`}
citationId={citationId}
citationText={match[0]}
position={position}
source={getCitationSource(citationId)}
source={source}
/>
);