diff --git a/README.md b/README.md
index 4b9ea5a..aa337c7 100644
--- a/README.md
+++ b/README.md
@@ -120,7 +120,6 @@ This is the core of SurfSense. Before we begin let's look at `.env` variables' t
| RERANKERS_MODEL_NAME| Name of the reranker model for search result reranking. Eg. `ms-marco-MiniLM-L-12-v2`|
| RERANKERS_MODEL_TYPE| Type of reranker model being used. Eg. `flashrank`|
| FAST_LLM| LiteLLM routed Smaller, faster LLM for quick responses. Eg. `litellm:openai/gpt-4o`|
-| SMART_LLM| LiteLLM routed Balanced LLM for general use. Eg. `litellm:openai/gpt-4o`|
| STRATEGIC_LLM| LiteLLM routed Advanced LLM for complex reasoning tasks. Eg. `litellm:openai/gpt-4o`|
| LONG_CONTEXT_LLM| LiteLLM routed LLM capable of handling longer context windows. Eg. `litellm:gemini/gemini-2.0-flash`|
| UNSTRUCTURED_API_KEY| API key for Unstructured.io service for document parsing|
@@ -221,15 +220,15 @@ After filling in your SurfSense API key you should be able to use extension now.
- **Alembic**: A database migrations tool for SQLAlchemy.
- **FastAPI Users**: Authentication and user management with JWT and OAuth support
-
-- **LangChain**: Framework for developing AI-powered applications
-- **GPT Integration**: Integration with LLM models through LiteLLM
+- **LangGraph**: Framework for building AI agents.
+
+- **LangChain**: Framework for developing AI-powered applications.
+
+- **LLM Integration**: Unified access to LLM providers through LiteLLM.
- **Rerankers**: Advanced result ranking for improved search relevance
-- **GPT-Researcher**: Advanced research capabilities
-
 - **Hybrid Search**: Combines vector similarity and full-text search for optimal results using Reciprocal Rank Fusion (RRF); see the sketch below
- **Vector Embeddings**: Document and text embeddings for semantic search
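+
+A minimal sketch of RRF-style fusion, for illustration only (the function below is an assumption, not SurfSense's actual implementation; `k = 60` is the conventional damping constant):
+
+```python
+def reciprocal_rank_fusion(rankings: list[list[str]], k: int = 60) -> list[str]:
+    """Merge ranked lists of document IDs into a single ranking.
+
+    Each document scores sum(1 / (k + rank)) across every list it appears in,
+    so items ranked highly by multiple retrievers rise to the top.
+    """
+    scores: dict[str, float] = {}
+    for ranking in rankings:
+        for rank, doc_id in enumerate(ranking, start=1):
+            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
+    return sorted(scores, key=scores.get, reverse=True)
+
+# Fuse a hypothetical vector-similarity ranking with a full-text ranking
+fused = reciprocal_rank_fusion([["d2", "d1", "d3"], ["d1", "d2", "d4"]])
+```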
diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example
index 380296e..bc3f5ce 100644
--- a/surfsense_backend/.env.example
+++ b/surfsense_backend/.env.example
@@ -10,9 +10,8 @@ RERANKERS_MODEL_NAME="ms-marco-MiniLM-L-12-v2"
RERANKERS_MODEL_TYPE="flashrank"
FAST_LLM="litellm:openai/gpt-4o-mini"
-SMART_LLM="litellm:openai/gpt-4o-mini"
-STRATEGIC_LLM="litellm:openai/gpt-4o-mini"
-LONG_CONTEXT_LLM="litellm:gemini/gemini-2.0-flash-thinking-exp-01-21"
+STRATEGIC_LLM="litellm:openai/gpt-4o"
+LONG_CONTEXT_LLM="litellm:gemini/gemini-2.0-flash"
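+# LLM values use LiteLLM routing syntax: litellm:<provider>/<model>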
OPENAI_API_KEY="sk-proj-iA"
GEMINI_API_KEY="AIzaSyB6-1641124124124124124124124124124"
diff --git a/surfsense_backend/app/agents/researcher/nodes.py b/surfsense_backend/app/agents/researcher/nodes.py
index c88ad33..15935f2 100644
--- a/surfsense_backend/app/agents/researcher/nodes.py
+++ b/surfsense_backend/app/agents/researcher/nodes.py
@@ -1,17 +1,23 @@
-from .configuration import Configuration
-from langchain_core.runnables import RunnableConfig
-from .state import State
-from typing import Any, Dict, List
-from app.config import config as app_config
-from .prompts import get_answer_outline_system_prompt
-from langchain_core.messages import HumanMessage, SystemMessage
-from pydantic import BaseModel, Field
-import json
import asyncio
-from .sub_section_writer.graph import graph as sub_section_writer_graph
+import json
+from typing import Any, Dict, List, Optional
+
+from app.config import config as app_config
+from app.db import async_session_maker
from app.utils.connector_service import ConnectorService
+from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.runnables import RunnableConfig
+from langgraph.types import StreamWriter
+from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
+from .configuration import Configuration
+from .prompts import get_answer_outline_system_prompt
+from .state import State
+from .sub_section_writer.graph import graph as sub_section_writer_graph
+
+
class Section(BaseModel):
"""A section in the answer outline."""
section_id: int = Field(..., description="The zero-based index of the section")
@@ -22,7 +28,7 @@ class AnswerOutline(BaseModel):
"""The complete answer outline with all sections."""
answer_outline: List[Section] = Field(..., description="List of sections in the answer outline")
-async def write_answer_outline(state: State, config: RunnableConfig) -> Dict[str, Any]:
+async def write_answer_outline(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]:
"""
Create a structured answer outline based on the user query.
@@ -33,12 +39,18 @@ async def write_answer_outline(state: State, config: RunnableConfig) -> Dict[str
Returns:
Dict containing the answer outline in the "answer_outline" key for state update.
"""
+ streaming_service = state.streaming_service
+ streaming_service.only_update_terminal("Generating answer outline...")
+ writer({"yeild_value": streaming_service._format_annotations()})
# Get configuration from runnable config
configuration = Configuration.from_runnable_config(config)
user_query = configuration.user_query
num_sections = configuration.num_sections
+ streaming_service.only_update_terminal(f"Planning research approach for query: {user_query[:100]}...")
+ writer({"yeild_value": streaming_service._format_annotations()})
+
# Initialize LLM
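+    # The strategic LLM (STRATEGIC_LLM in .env) is used here since outline
+    # planning is a complex-reasoning task.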
llm = app_config.strategic_llm_instance
@@ -66,6 +78,9 @@ async def write_answer_outline(state: State, config: RunnableConfig) -> Dict[str
Your output MUST be valid JSON in exactly this format. Do not include any other text or explanation.
"""
+ streaming_service.only_update_terminal("Designing structured outline with AI...")
+ writer({"yeild_value": streaming_service._format_annotations()})
+
# Create messages for the LLM
messages = [
SystemMessage(content=get_answer_outline_system_prompt()),
@@ -73,6 +88,9 @@ async def write_answer_outline(state: State, config: RunnableConfig) -> Dict[str
]
# Call the LLM directly without using structured output
+ streaming_service.only_update_terminal("Processing answer structure...")
+ writer({"yeild_value": streaming_service._format_annotations()})
+
response = await llm.ainvoke(messages)
# Parse the JSON response manually
@@ -92,16 +110,27 @@ async def write_answer_outline(state: State, config: RunnableConfig) -> Dict[str
# Convert to Pydantic model
answer_outline = AnswerOutline(**parsed_data)
+ total_questions = sum(len(section.questions) for section in answer_outline.answer_outline)
+ streaming_service.only_update_terminal(f"Successfully generated outline with {len(answer_outline.answer_outline)} sections and {total_questions} research questions")
+ writer({"yeild_value": streaming_service._format_annotations()})
+
print(f"Successfully generated answer outline with {len(answer_outline.answer_outline)} sections")
# Return state update
return {"answer_outline": answer_outline}
else:
# If JSON structure not found, raise a clear error
- raise ValueError(f"Could not find valid JSON in LLM response. Raw response: {content}")
+ error_message = f"Could not find valid JSON in LLM response. Raw response: {content}"
+ streaming_service.only_update_terminal(error_message, "error")
+ writer({"yeild_value": streaming_service._format_annotations()})
+ raise ValueError(error_message)
except (json.JSONDecodeError, ValueError) as e:
# Log the error and re-raise it
+ error_message = f"Error parsing LLM response: {str(e)}"
+ streaming_service.only_update_terminal(error_message, "error")
+ writer({"yeild_value": streaming_service._format_annotations()})
+
print(f"Error parsing LLM response: {str(e)}")
print(f"Raw response: {response.content}")
raise
@@ -112,18 +141,21 @@ async def fetch_relevant_documents(
search_space_id: int,
db_session: AsyncSession,
connectors_to_search: List[str],
-    top_k: int = 5
+    writer: Optional[StreamWriter] = None,
+    state: Optional[State] = None,
+    top_k: int = 20
) -> List[Dict[str, Any]]:
"""
Fetch relevant documents for research questions using the provided connectors.
Args:
- section_title: The title of the section being researched
research_questions: List of research questions to find documents for
user_id: The user ID
search_space_id: The search space ID
db_session: The database session
connectors_to_search: List of connectors to search
+ writer: StreamWriter for sending progress updates
+ state: The current state containing the streaming service
top_k: Number of top results to retrieve per connector per question
Returns:
@@ -131,83 +163,237 @@ async def fetch_relevant_documents(
"""
# Initialize services
connector_service = ConnectorService(db_session)
-
- all_raw_documents = [] # Store all raw documents before reranking
- for user_query in research_questions:
+ # Only use streaming if both writer and state are provided
+ streaming_service = state.streaming_service if state is not None else None
+
+ # Stream initial status update
+ if streaming_service and writer:
+ streaming_service.only_update_terminal(f"Starting research on {len(research_questions)} questions using {len(connectors_to_search)} connectors...")
+ writer({"yeild_value": streaming_service._format_annotations()})
+
+ all_raw_documents = [] # Store all raw documents
+ all_sources = [] # Store all sources
+
+ for i, user_query in enumerate(research_questions):
+ # Stream question being researched
+ if streaming_service and writer:
+ streaming_service.only_update_terminal(f"Researching question {i+1}/{len(research_questions)}: {user_query[:100]}...")
+ writer({"yeild_value": streaming_service._format_annotations()})
+
# Use original research question as the query
reformulated_query = user_query
# Process each selected connector
for connector in connectors_to_search:
+ # Stream connector being searched
+ if streaming_service and writer:
+ streaming_service.only_update_terminal(f"Searching {connector} for relevant information...")
+ writer({"yeild_value": streaming_service._format_annotations()})
+
try:
if connector == "YOUTUBE_VIDEO":
- _, youtube_chunks = await connector_service.search_youtube(
+ source_object, youtube_chunks = await connector_service.search_youtube(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
)
+
+ # Add to sources and raw documents
+ if source_object:
+ all_sources.append(source_object)
all_raw_documents.extend(youtube_chunks)
+ # Stream found document count
+ if streaming_service and writer:
+ streaming_service.only_update_terminal(f"Found {len(youtube_chunks)} YouTube chunks relevant to the query")
+ writer({"yeild_value": streaming_service._format_annotations()})
+
elif connector == "EXTENSION":
- _, extension_chunks = await connector_service.search_extension(
+ source_object, extension_chunks = await connector_service.search_extension(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
)
+
+ # Add to sources and raw documents
+ if source_object:
+ all_sources.append(source_object)
all_raw_documents.extend(extension_chunks)
+ # Stream found document count
+ if streaming_service and writer:
+ streaming_service.only_update_terminal(f"Found {len(extension_chunks)} extension chunks relevant to the query")
+ writer({"yeild_value": streaming_service._format_annotations()})
+
elif connector == "CRAWLED_URL":
- _, crawled_urls_chunks = await connector_service.search_crawled_urls(
+ source_object, crawled_urls_chunks = await connector_service.search_crawled_urls(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
)
+
+ # Add to sources and raw documents
+ if source_object:
+ all_sources.append(source_object)
all_raw_documents.extend(crawled_urls_chunks)
+ # Stream found document count
+ if streaming_service and writer:
+ streaming_service.only_update_terminal(f"Found {len(crawled_urls_chunks)} crawled URL chunks relevant to the query")
+ writer({"yeild_value": streaming_service._format_annotations()})
+
elif connector == "FILE":
- _, files_chunks = await connector_service.search_files(
+ source_object, files_chunks = await connector_service.search_files(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
)
+
+ # Add to sources and raw documents
+ if source_object:
+ all_sources.append(source_object)
all_raw_documents.extend(files_chunks)
+ # Stream found document count
+ if streaming_service and writer:
+ streaming_service.only_update_terminal(f"Found {len(files_chunks)} file chunks relevant to the query")
+ writer({"yeild_value": streaming_service._format_annotations()})
+
elif connector == "TAVILY_API":
- _, tavily_chunks = await connector_service.search_tavily(
+ source_object, tavily_chunks = await connector_service.search_tavily(
user_query=reformulated_query,
user_id=user_id,
top_k=top_k
)
+
+ # Add to sources and raw documents
+ if source_object:
+ all_sources.append(source_object)
all_raw_documents.extend(tavily_chunks)
+ # Stream found document count
+ if streaming_service and writer:
+ streaming_service.only_update_terminal(f"Found {len(tavily_chunks)} web search results relevant to the query")
+ writer({"yeild_value": streaming_service._format_annotations()})
+
elif connector == "SLACK_CONNECTOR":
- _, slack_chunks = await connector_service.search_slack(
+ source_object, slack_chunks = await connector_service.search_slack(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
)
+
+ # Add to sources and raw documents
+ if source_object:
+ all_sources.append(source_object)
all_raw_documents.extend(slack_chunks)
+ # Stream found document count
+ if streaming_service and writer:
+ streaming_service.only_update_terminal(f"Found {len(slack_chunks)} Slack messages relevant to the query")
+ writer({"yeild_value": streaming_service._format_annotations()})
+
elif connector == "NOTION_CONNECTOR":
- _, notion_chunks = await connector_service.search_notion(
+ source_object, notion_chunks = await connector_service.search_notion(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
)
+
+ # Add to sources and raw documents
+ if source_object:
+ all_sources.append(source_object)
all_raw_documents.extend(notion_chunks)
+
+ # Stream found document count
+ if streaming_service and writer:
+ streaming_service.only_update_terminal(f"Found {len(notion_chunks)} Notion pages/blocks relevant to the query")
+ writer({"yeild_value": streaming_service._format_annotations()})
+
+ elif connector == "GITHUB_CONNECTOR":
+ source_object, github_chunks = await connector_service.search_github(
+ user_query=reformulated_query,
+ user_id=user_id,
+ search_space_id=search_space_id,
+ top_k=top_k
+ )
+
+ # Add to sources and raw documents
+ if source_object:
+ all_sources.append(source_object)
+ all_raw_documents.extend(github_chunks)
+
+ # Stream found document count
+ if streaming_service and writer:
+ streaming_service.only_update_terminal(f"Found {len(github_chunks)} GitHub files/issues relevant to the query")
+ writer({"yeild_value": streaming_service._format_annotations()})
+
+ elif connector == "LINEAR_CONNECTOR":
+ source_object, linear_chunks = await connector_service.search_linear(
+ user_query=reformulated_query,
+ user_id=user_id,
+ search_space_id=search_space_id,
+ top_k=top_k
+ )
+
+ # Add to sources and raw documents
+ if source_object:
+ all_sources.append(source_object)
+ all_raw_documents.extend(linear_chunks)
+
+ # Stream found document count
+ if streaming_service and writer:
+ streaming_service.only_update_terminal(f"Found {len(linear_chunks)} Linear issues relevant to the query")
+ writer({"yeild_value": streaming_service._format_annotations()})
except Exception as e:
- print(f"Error searching connector {connector}: {str(e)}")
+ error_message = f"Error searching connector {connector}: {str(e)}"
+ print(error_message)
+
+ # Stream error message
+ if streaming_service and writer:
+ streaming_service.only_update_terminal(error_message, "error")
+ writer({"yeild_value": streaming_service._format_annotations()})
+
# Continue with other connectors on error
continue
- # Deduplicate documents based on chunk_id or content
+ # Deduplicate source objects by ID before streaming
+ deduplicated_sources = []
+ seen_source_keys = set()
+
+ for source_obj in all_sources:
+ # Use combination of source ID and type as a unique identifier
+ # This ensures we don't accidentally deduplicate sources from different connectors
+ source_id = source_obj.get('id')
+ source_type = source_obj.get('type')
+
+ if source_id and source_type:
+ source_key = f"{source_type}_{source_id}"
+ if source_key not in seen_source_keys:
+ seen_source_keys.add(source_key)
+ deduplicated_sources.append(source_obj)
+ else:
+ # If there's no ID or type, just add it to be safe
+ deduplicated_sources.append(source_obj)
+
+ # Stream info about deduplicated sources
+ if streaming_service and writer:
+ streaming_service.only_update_terminal(f"Collected {len(deduplicated_sources)} unique sources across all connectors")
+ writer({"yeild_value": streaming_service._format_annotations()})
+
+ # After all sources are collected and deduplicated, stream them
+ if streaming_service and writer:
+ streaming_service.only_update_sources(deduplicated_sources)
+ writer({"yeild_value": streaming_service._format_annotations()})
+
+ # Deduplicate raw documents based on chunk_id or content
seen_chunk_ids = set()
seen_content_hashes = set()
deduplicated_docs = []
@@ -227,11 +413,15 @@ async def fetch_relevant_documents(
seen_content_hashes.add(content_hash)
deduplicated_docs.append(doc)
+ # Stream info about deduplicated documents
+ if streaming_service and writer:
+ streaming_service.only_update_terminal(f"Found {len(deduplicated_docs)} unique document chunks after deduplication")
+ writer({"yeild_value": streaming_service._format_annotations()})
+
+ # Return deduplicated documents
return deduplicated_docs
-
-
-async def process_sections(state: State, config: RunnableConfig) -> Dict[str, Any]:
+async def process_sections(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]:
"""
Process all sections in parallel and combine the results.
@@ -245,89 +435,97 @@ async def process_sections(state: State, config: RunnableConfig) -> Dict[str, An
# Get configuration and answer outline from state
configuration = Configuration.from_runnable_config(config)
answer_outline = state.answer_outline
+ streaming_service = state.streaming_service
+
+ streaming_service.only_update_terminal(f"Starting to process research sections...")
+ writer({"yeild_value": streaming_service._format_annotations()})
print(f"Processing sections from outline: {answer_outline is not None}")
if not answer_outline:
+ streaming_service.only_update_terminal("Error: No answer outline was provided. Cannot generate report.", "error")
+ writer({"yeild_value": streaming_service._format_annotations()})
return {
"final_written_report": "No answer outline was provided. Cannot generate final report."
}
- # Create session maker from the engine or directly use the session
- from sqlalchemy.ext.asyncio import AsyncSession
- from sqlalchemy.orm import sessionmaker
-
- # Use the engine if available, otherwise create a new session for each task
- if state.engine:
- session_maker = sessionmaker(
- state.engine, class_=AsyncSession, expire_on_commit=False
- )
- else:
- # Fallback to using the same session (less optimal but will work)
- print("Warning: No engine available. Using same session for all tasks.")
- # Create a mock session maker that returns the same session
- async def mock_session_maker():
- class ContextManager:
- async def __aenter__(self):
- return state.db_session
- async def __aexit__(self, exc_type, exc_val, exc_tb):
- pass
- return ContextManager()
- session_maker = mock_session_maker
-
# Collect all questions from all sections
all_questions = []
for section in answer_outline.answer_outline:
all_questions.extend(section.questions)
print(f"Collected {len(all_questions)} questions from all sections")
+ streaming_service.only_update_terminal(f"Found {len(all_questions)} research questions across {len(answer_outline.answer_outline)} sections")
+ writer({"yeild_value": streaming_service._format_annotations()})
# Fetch relevant documents once for all questions
+ streaming_service.only_update_terminal("Searching for relevant information across all connectors...")
+ writer({"yeild_value": streaming_service._format_annotations()})
+
relevant_documents = []
- async with session_maker() as db_session:
-
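+    # Open a fresh DB session from the shared application-level factory.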
+ async with async_session_maker() as db_session:
try:
relevant_documents = await fetch_relevant_documents(
research_questions=all_questions,
user_id=configuration.user_id,
search_space_id=configuration.search_space_id,
db_session=db_session,
- connectors_to_search=configuration.connectors_to_search
+ connectors_to_search=configuration.connectors_to_search,
+ writer=writer,
+ state=state
)
except Exception as e:
- print(f"Error fetching relevant documents: {str(e)}")
+ error_message = f"Error fetching relevant documents: {str(e)}"
+ print(error_message)
+ streaming_service.only_update_terminal(error_message, "error")
+ writer({"yeild_value": streaming_service._format_annotations()})
# Log the error and continue with an empty list of documents
# This allows the process to continue, but the report might lack information
relevant_documents = []
# Consider adding more robust error handling or reporting if needed
print(f"Fetched {len(relevant_documents)} relevant documents for all sections")
+ streaming_service.only_update_terminal(f"Starting to draft {len(answer_outline.answer_outline)} sections using {len(relevant_documents)} relevant document chunks")
+ writer({"yeild_value": streaming_service._format_annotations()})
# Create tasks to process each section in parallel with the same document set
section_tasks = []
+ streaming_service.only_update_terminal("Creating processing tasks for each section...")
+ writer({"yeild_value": streaming_service._format_annotations()})
+
for section in answer_outline.answer_outline:
section_tasks.append(
process_section_with_documents(
section_title=section.section_title,
section_questions=section.questions,
+ user_query=configuration.user_query,
user_id=configuration.user_id,
search_space_id=configuration.search_space_id,
- session_maker=session_maker,
- relevant_documents=relevant_documents
+ relevant_documents=relevant_documents,
+ state=state,
+ writer=writer
)
)
# Run all section processing tasks in parallel
print(f"Running {len(section_tasks)} section processing tasks in parallel")
+ streaming_service.only_update_terminal(f"Processing {len(section_tasks)} sections simultaneously...")
+ writer({"yeild_value": streaming_service._format_annotations()})
+
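+    # return_exceptions=True keeps one failed section from cancelling the rest;
+    # exceptions are converted to error strings below.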
section_results = await asyncio.gather(*section_tasks, return_exceptions=True)
# Handle any exceptions in the results
+ streaming_service.only_update_terminal("Combining section results into final report...")
+ writer({"yeild_value": streaming_service._format_annotations()})
+
processed_results = []
for i, result in enumerate(section_results):
if isinstance(result, Exception):
section_title = answer_outline.answer_outline[i].section_title
error_message = f"Error processing section '{section_title}': {str(result)}"
print(error_message)
+ streaming_service.only_update_terminal(error_message, "error")
+ writer({"yeild_value": streaming_service._format_annotations()})
processed_results.append(error_message)
else:
processed_results.append(result)
@@ -337,12 +535,33 @@ async def process_sections(state: State, config: RunnableConfig) -> Dict[str, An
for i, (section, content) in enumerate(zip(answer_outline.answer_outline, processed_results)):
# Skip adding the section header since the content already contains the title
final_report.append(content)
- final_report.append("\n") # Add spacing between sections
+ final_report.append("\n")
+
# Join all sections with newlines
final_written_report = "\n".join(final_report)
print(f"Generated final report with {len(final_report)} parts")
+ streaming_service.only_update_terminal("Final research report generated successfully!")
+ writer({"yeild_value": streaming_service._format_annotations()})
+
+    if state.streaming_service:
+ # Convert the final report to the expected format for UI:
+ # A list of strings where empty strings represent line breaks
+ formatted_report = []
+ for section in final_report:
+ if section == "\n":
+ # Add an empty string for line breaks
+ formatted_report.append("")
+ else:
+ # Split any multiline content by newlines and add each line
+ section_lines = section.split("\n")
+ formatted_report.extend(section_lines)
+
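+        # Send the fully formatted answer to the client in a single update.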
+ state.streaming_service.only_update_answer(formatted_report)
+ writer({"yeild_value": state.streaming_service._format_annotations()})
+
+
return {
"final_written_report": final_written_report
}
@@ -352,8 +571,10 @@ async def process_section_with_documents(
section_questions: List[str],
user_id: str,
search_space_id: int,
- session_maker,
- relevant_documents: List[Dict[str, Any]]
+ relevant_documents: List[Dict[str, Any]],
+ user_query: str,
+    state: Optional[State] = None,
+    writer: Optional[StreamWriter] = None
) -> str:
"""
Process a single section using pre-fetched documents.
@@ -363,31 +584,42 @@ async def process_section_with_documents(
section_questions: List of research questions for this section
user_id: The user ID
search_space_id: The search space ID
- session_maker: Factory for creating new database sessions
relevant_documents: Pre-fetched documents to use for this section
+ state: The current state
+ writer: StreamWriter for sending progress updates
Returns:
The written section content
"""
try:
+ # Use the provided documents
+ documents_to_use = relevant_documents
+
+ # Send status update via streaming if available
+ if state and state.streaming_service and writer:
+ state.streaming_service.only_update_terminal(f"Writing section: {section_title} with {len(section_questions)} research questions")
+ writer({"yeild_value": state.streaming_service._format_annotations()})
+
+ # Fallback if no documents found
+ if not documents_to_use:
+ print(f"No relevant documents found for section: {section_title}")
+ if state and state.streaming_service and writer:
+ state.streaming_service.only_update_terminal(f"Warning: No relevant documents found for section: {section_title}", "warning")
+ writer({"yeild_value": state.streaming_service._format_annotations()})
+
+ documents_to_use = [
+ {"content": f"No specific information was found for: {question}"}
+ for question in section_questions
+ ]
+
# Create a new database session for this section
- async with session_maker() as db_session:
- # Use the provided documents
- documents_to_use = relevant_documents
-
- # Fallback if no documents found
- if not documents_to_use:
- print(f"No relevant documents found for section: {section_title}")
- documents_to_use = [
- {"content": f"No specific information was found for: {question}"}
- for question in section_questions
- ]
-
+ async with async_session_maker() as db_session:
# Call the sub_section_writer graph with the appropriate config
config = {
"configurable": {
"sub_section_title": section_title,
"sub_section_questions": section_questions,
+ "user_query": user_query,
"relevant_documents": documents_to_use,
"user_id": user_id,
"search_space_id": search_space_id
@@ -395,16 +627,32 @@ async def process_section_with_documents(
}
# Create the initial state with db_session
- state = {"db_session": db_session}
+ sub_state = {"db_session": db_session}
# Invoke the sub-section writer graph
print(f"Invoking sub_section_writer for: {section_title}")
- result = await sub_section_writer_graph.ainvoke(state, config)
+ if state and state.streaming_service and writer:
+ state.streaming_service.only_update_terminal(f"Analyzing information and drafting content for section: {section_title}")
+ writer({"yeild_value": state.streaming_service._format_annotations()})
+
+ result = await sub_section_writer_graph.ainvoke(sub_state, config)
# Return the final answer from the sub_section_writer
final_answer = result.get("final_answer", "No content was generated for this section.")
+
+ # Send section content update via streaming if available
+ if state and state.streaming_service and writer:
+ state.streaming_service.only_update_terminal(f"Completed writing section: {section_title}")
+ writer({"yeild_value": state.streaming_service._format_annotations()})
+
return final_answer
except Exception as e:
print(f"Error processing section '{section_title}': {str(e)}")
+
+ # Send error update via streaming if available
+ if state and state.streaming_service and writer:
+ state.streaming_service.only_update_terminal(f"Error processing section '{section_title}': {str(e)}", "error")
+ writer({"yeild_value": state.streaming_service._format_annotations()})
+
return f"Error processing section: {section_title}. Details: {str(e)}"
diff --git a/surfsense_backend/app/agents/researcher/state.py b/surfsense_backend/app/agents/researcher/state.py
index 483e96a..dd36163 100644
--- a/surfsense_backend/app/agents/researcher/state.py
+++ b/surfsense_backend/app/agents/researcher/state.py
@@ -3,10 +3,9 @@
from __future__ import annotations
from dataclasses import dataclass, field
-from typing import List, Optional, Any, Dict, Annotated
-from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine
-from langchain_core.messages import BaseMessage, HumanMessage
-from pydantic import BaseModel
+from typing import Optional, Any
+from sqlalchemy.ext.asyncio import AsyncSession
+from app.utils.streaming_service import StreamingService
@dataclass
class State:
@@ -18,7 +17,9 @@ class State:
"""
# Runtime context (not part of actual graph state)
db_session: AsyncSession
- engine: Optional[AsyncEngine] = None
+
+ # Streaming service
+ streaming_service: StreamingService
# Intermediate state - populated during workflow
# Using field to explicitly mark as part of state
diff --git a/surfsense_backend/app/agents/researcher/sub_section_writer/configuration.py b/surfsense_backend/app/agents/researcher/sub_section_writer/configuration.py
index fbde94d..9e1ca32 100644
--- a/surfsense_backend/app/agents/researcher/sub_section_writer/configuration.py
+++ b/surfsense_backend/app/agents/researcher/sub_section_writer/configuration.py
@@ -15,6 +15,7 @@ class Configuration:
# Input parameters provided at invocation
sub_section_title: str
sub_section_questions: List[str]
+ user_query: str
relevant_documents: List[Any] # Documents provided directly to the agent
user_id: str
search_space_id: int
diff --git a/surfsense_backend/app/agents/researcher/sub_section_writer/nodes.py b/surfsense_backend/app/agents/researcher/sub_section_writer/nodes.py
index af807e3..5b11141 100644
--- a/surfsense_backend/app/agents/researcher/sub_section_writer/nodes.py
+++ b/surfsense_backend/app/agents/researcher/sub_section_writer/nodes.py
@@ -119,6 +119,7 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
# Create the query that uses the section title and questions
section_title = configuration.sub_section_title
sub_section_questions = configuration.sub_section_questions
+ user_query = configuration.user_query # Get the original user query
documents_text = "\n".join(formatted_documents)
# Format the questions as bullet points for clarity
@@ -126,17 +127,16 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
# Construct a clear, structured query for the LLM
human_message_content = f"""
- Please write a comprehensive answer for the title:
+    The user's original query is:
+