diff --git a/surfsense_backend/app/agents/__init__.py b/surfsense_backend/app/agents/__init__.py
index e69de29..944afeb 100644
--- a/surfsense_backend/app/agents/__init__.py
+++ b/surfsense_backend/app/agents/__init__.py
@@ -0,0 +1 @@
+"""This is upcoming research agent. Work in progress."""
\ No newline at end of file
diff --git a/surfsense_backend/app/agents/researcher/configuration.py b/surfsense_backend/app/agents/researcher/configuration.py
new file mode 100644
index 0000000..8ba3849
--- /dev/null
+++ b/surfsense_backend/app/agents/researcher/configuration.py
@@ -0,0 +1,30 @@
+"""Define the configurable parameters for the agent."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, fields
+from typing import Optional, List
+
+from langchain_core.runnables import RunnableConfig
+
+
+@dataclass(kw_only=True)
+class Configuration:
+ """The configuration for the agent."""
+
+ # Input parameters provided at invocation
+ user_query: str
+ num_sections: int
+ connectors_to_search: List[str]
+ user_id: str
+ search_space_id: int
+
+
+ @classmethod
+ def from_runnable_config(
+ cls, config: Optional[RunnableConfig] = None
+ ) -> Configuration:
+ """Create a Configuration instance from a RunnableConfig object."""
+ configurable = (config.get("configurable") or {}) if config else {}
+ _fields = {f.name for f in fields(cls) if f.init}
+ return cls(**{k: v for k, v in configurable.items() if k in _fields})
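A minimal usage sketch for the configuration above (values are illustrative; note that `from_runnable_config` silently drops unknown keys rather than raising):

```python
from app.agents.researcher.configuration import Configuration

# Illustrative values only; the keys mirror the dataclass fields above.
runnable_config = {
    "configurable": {
        "user_query": "What are the health benefits of meditation?",
        "num_sections": 3,
        "connectors_to_search": ["FILE", "CRAWLED_URL"],
        "user_id": "user-123",
        "search_space_id": 1,
        "unknown_key": "dropped by the field filter",
    }
}

configuration = Configuration.from_runnable_config(runnable_config)
assert configuration.num_sections == 3
```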
diff --git a/surfsense_backend/app/agents/researcher/graph.py b/surfsense_backend/app/agents/researcher/graph.py
new file mode 100644
index 0000000..31835da
--- /dev/null
+++ b/surfsense_backend/app/agents/researcher/graph.py
@@ -0,0 +1,43 @@
+from langgraph.graph import StateGraph
+from .state import State
+from .nodes import write_answer_outline, process_sections
+from .configuration import Configuration
+
+def build_graph():
+ """
+ Build and return the LangGraph workflow.
+
+ This function constructs the researcher agent graph with proper state management
+ and node connections following LangGraph best practices.
+
+ Returns:
+ A compiled LangGraph workflow
+ """
+ # Define a new graph with state class
+ workflow = StateGraph(State, config_schema=Configuration)
+
+ # Add nodes to the graph
+ workflow.add_node("write_answer_outline", write_answer_outline)
+ workflow.add_node("process_sections", process_sections)
+
+ # Define the edges - create a linear flow
+ workflow.add_edge("__start__", "write_answer_outline")
+ workflow.add_edge("write_answer_outline", "process_sections")
+ workflow.add_edge("process_sections", "__end__")
+
+ # Compile the workflow into an executable graph
+ graph = workflow.compile()
+ graph.name = "Surfsense Researcher" # This defines the custom name in LangSmith
+
+ return graph
+
+# Compile the graph once when the module is loaded
+graph = build_graph()
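Beyond the plain `ainvoke` used in the test script later in this changeset, the compiled graph can be consumed incrementally; a sketch, assuming an already-initialized `AsyncSession` and the config shape used in `test_researcher.py` (`stream_mode="updates"` makes LangGraph yield one `{node_name: state_delta}` dict as each node finishes):

```python
from app.agents.researcher.graph import graph
from app.agents.researcher.state import State

async def run_streamed(db_session, engine, config: dict) -> None:
    # Each update is keyed by the node that just finished, so progress for
    # write_answer_outline and process_sections can be reported as it happens.
    async for update in graph.astream(
        State(db_session=db_session, engine=engine), config, stream_mode="updates"
    ):
        for node_name, delta in update.items():
            print(f"{node_name} finished with keys: {list(delta.keys())}")
```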
diff --git a/surfsense_backend/app/agents/researcher/nodes.py b/surfsense_backend/app/agents/researcher/nodes.py
new file mode 100644
index 0000000..c88ad33
--- /dev/null
+++ b/surfsense_backend/app/agents/researcher/nodes.py
@@ -0,0 +1,410 @@
+from .configuration import Configuration
+from langchain_core.runnables import RunnableConfig
+from .state import State
+from typing import Any, Dict, List
+from app.config import config as app_config
+from .prompts import get_answer_outline_system_prompt
+from langchain_core.messages import HumanMessage, SystemMessage
+from pydantic import BaseModel, Field
+import json
+import asyncio
+from .sub_section_writer.graph import graph as sub_section_writer_graph
+from app.utils.connector_service import ConnectorService
+from sqlalchemy.ext.asyncio import AsyncSession
+
+class Section(BaseModel):
+ """A section in the answer outline."""
+ section_id: int = Field(..., description="The zero-based index of the section")
+ section_title: str = Field(..., description="The title of the section")
+ questions: List[str] = Field(..., description="Questions to research for this section")
+
+class AnswerOutline(BaseModel):
+ """The complete answer outline with all sections."""
+ answer_outline: List[Section] = Field(..., description="List of sections in the answer outline")
+
+async def write_answer_outline(state: State, config: RunnableConfig) -> Dict[str, Any]:
+ """
+ Create a structured answer outline based on the user query.
+
+ This node takes the user query and number of sections from the configuration and uses
+ an LLM to generate a comprehensive outline with logical sections and research questions
+ for each section.
+
+ Returns:
+ Dict containing the answer outline in the "answer_outline" key for state update.
+ """
+
+ # Get configuration from runnable config
+ configuration = Configuration.from_runnable_config(config)
+ user_query = configuration.user_query
+ num_sections = configuration.num_sections
+
+ # Initialize LLM
+ llm = app_config.strategic_llm_instance
+
+ # Create the human message content
+ human_message_content = f"""
+    Now please create an answer outline for the following query:
+
+ User Query: {user_query}
+ Number of Sections: {num_sections}
+
+ Remember to format your response as valid JSON exactly matching this structure:
+ {{
+ "answer_outline": [
+ {{
+ "section_id": 0,
+ "section_title": "Section Title",
+ "questions": [
+ "Question 1 to research for this section",
+ "Question 2 to research for this section"
+ ]
+ }}
+ ]
+ }}
+
+ Your output MUST be valid JSON in exactly this format. Do not include any other text or explanation.
+ """
+
+ # Create messages for the LLM
+ messages = [
+ SystemMessage(content=get_answer_outline_system_prompt()),
+ HumanMessage(content=human_message_content)
+ ]
+
+ # Call the LLM directly without using structured output
+ response = await llm.ainvoke(messages)
+
+ # Parse the JSON response manually
+ try:
+ # Extract JSON content from the response
+ content = response.content
+
+ # Find the JSON in the content (handle case where LLM might add additional text)
+ json_start = content.find('{')
+ json_end = content.rfind('}') + 1
+ if json_start >= 0 and json_end > json_start:
+ json_str = content[json_start:json_end]
+
+ # Parse the JSON string
+ parsed_data = json.loads(json_str)
+
+ # Convert to Pydantic model
+ answer_outline = AnswerOutline(**parsed_data)
+
+ print(f"Successfully generated answer outline with {len(answer_outline.answer_outline)} sections")
+
+ # Return state update
+ return {"answer_outline": answer_outline}
+ else:
+ # If JSON structure not found, raise a clear error
+ raise ValueError(f"Could not find valid JSON in LLM response. Raw response: {content}")
+
+ except (json.JSONDecodeError, ValueError) as e:
+ # Log the error and re-raise it
+ print(f"Error parsing LLM response: {str(e)}")
+ print(f"Raw response: {response.content}")
+ raise
+
+async def fetch_relevant_documents(
+ research_questions: List[str],
+ user_id: str,
+ search_space_id: int,
+ db_session: AsyncSession,
+ connectors_to_search: List[str],
+ top_k: int = 5
+) -> List[Dict[str, Any]]:
+ """
+ Fetch relevant documents for research questions using the provided connectors.
+
+ Args:
+ research_questions: List of research questions to find documents for
+ user_id: The user ID
+ search_space_id: The search space ID
+ db_session: The database session
+ connectors_to_search: List of connectors to search
+ top_k: Number of top results to retrieve per connector per question
+
+ Returns:
+ List of relevant documents
+ """
+ # Initialize services
+ connector_service = ConnectorService(db_session)
+
+ all_raw_documents = [] # Store all raw documents before reranking
+
+ for user_query in research_questions:
+ # Use original research question as the query
+ reformulated_query = user_query
+
+ # Process each selected connector
+ for connector in connectors_to_search:
+ try:
+ if connector == "YOUTUBE_VIDEO":
+ _, youtube_chunks = await connector_service.search_youtube(
+ user_query=reformulated_query,
+ user_id=user_id,
+ search_space_id=search_space_id,
+ top_k=top_k
+ )
+ all_raw_documents.extend(youtube_chunks)
+
+ elif connector == "EXTENSION":
+ _, extension_chunks = await connector_service.search_extension(
+ user_query=reformulated_query,
+ user_id=user_id,
+ search_space_id=search_space_id,
+ top_k=top_k
+ )
+ all_raw_documents.extend(extension_chunks)
+
+ elif connector == "CRAWLED_URL":
+ _, crawled_urls_chunks = await connector_service.search_crawled_urls(
+ user_query=reformulated_query,
+ user_id=user_id,
+ search_space_id=search_space_id,
+ top_k=top_k
+ )
+ all_raw_documents.extend(crawled_urls_chunks)
+
+ elif connector == "FILE":
+ _, files_chunks = await connector_service.search_files(
+ user_query=reformulated_query,
+ user_id=user_id,
+ search_space_id=search_space_id,
+ top_k=top_k
+ )
+ all_raw_documents.extend(files_chunks)
+
+ elif connector == "TAVILY_API":
+ _, tavily_chunks = await connector_service.search_tavily(
+ user_query=reformulated_query,
+ user_id=user_id,
+ top_k=top_k
+ )
+ all_raw_documents.extend(tavily_chunks)
+
+ elif connector == "SLACK_CONNECTOR":
+ _, slack_chunks = await connector_service.search_slack(
+ user_query=reformulated_query,
+ user_id=user_id,
+ search_space_id=search_space_id,
+ top_k=top_k
+ )
+ all_raw_documents.extend(slack_chunks)
+
+ elif connector == "NOTION_CONNECTOR":
+ _, notion_chunks = await connector_service.search_notion(
+ user_query=reformulated_query,
+ user_id=user_id,
+ search_space_id=search_space_id,
+ top_k=top_k
+ )
+ all_raw_documents.extend(notion_chunks)
+ except Exception as e:
+ print(f"Error searching connector {connector}: {str(e)}")
+ # Continue with other connectors on error
+ continue
+
+ # Deduplicate documents based on chunk_id or content
+ seen_chunk_ids = set()
+ seen_content_hashes = set()
+ deduplicated_docs = []
+
+ for doc in all_raw_documents:
+ chunk_id = doc.get("chunk_id")
+ content = doc.get("content", "")
+ content_hash = hash(content)
+
+ # Skip if we've seen this chunk_id or content before
+ if (chunk_id and chunk_id in seen_chunk_ids) or content_hash in seen_content_hashes:
+ continue
+
+ # Add to our tracking sets and keep this document
+ if chunk_id:
+ seen_chunk_ids.add(chunk_id)
+ seen_content_hashes.add(content_hash)
+ deduplicated_docs.append(doc)
+
+ return deduplicated_docs
+
+
+async def process_sections(state: State, config: RunnableConfig) -> Dict[str, Any]:
+ """
+ Process all sections in parallel and combine the results.
+
+ This node takes the answer outline from the previous step, fetches relevant documents
+ for all questions across all sections once, and then processes each section in parallel
+ using the sub_section_writer graph with the shared document pool.
+
+ Returns:
+ Dict containing the final written report in the "final_written_report" key.
+ """
+ # Get configuration and answer outline from state
+ configuration = Configuration.from_runnable_config(config)
+ answer_outline = state.answer_outline
+
+ print(f"Processing sections from outline: {answer_outline is not None}")
+
+ if not answer_outline:
+ return {
+ "final_written_report": "No answer outline was provided. Cannot generate final report."
+ }
+
+    # Create a session maker so each parallel task can get its own session
+    from sqlalchemy.orm import sessionmaker
+
+ # Use the engine if available, otherwise create a new session for each task
+ if state.engine:
+ session_maker = sessionmaker(
+ state.engine, class_=AsyncSession, expire_on_commit=False
+ )
+ else:
+ # Fallback to using the same session (less optimal but will work)
+ print("Warning: No engine available. Using same session for all tasks.")
+ # Create a mock session maker that returns the same session
+        def mock_session_maker():  # plain def: async with needs the returned async context manager, not a coroutine
+ class ContextManager:
+ async def __aenter__(self):
+ return state.db_session
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ pass
+ return ContextManager()
+ session_maker = mock_session_maker
+
+ # Collect all questions from all sections
+ all_questions = []
+ for section in answer_outline.answer_outline:
+ all_questions.extend(section.questions)
+
+ print(f"Collected {len(all_questions)} questions from all sections")
+
+ # Fetch relevant documents once for all questions
+ relevant_documents = []
+ async with session_maker() as db_session:
+
+ try:
+ relevant_documents = await fetch_relevant_documents(
+ research_questions=all_questions,
+ user_id=configuration.user_id,
+ search_space_id=configuration.search_space_id,
+ db_session=db_session,
+ connectors_to_search=configuration.connectors_to_search
+ )
+ except Exception as e:
+ print(f"Error fetching relevant documents: {str(e)}")
+ # Log the error and continue with an empty list of documents
+ # This allows the process to continue, but the report might lack information
+ relevant_documents = []
+ # Consider adding more robust error handling or reporting if needed
+
+ print(f"Fetched {len(relevant_documents)} relevant documents for all sections")
+
+ # Create tasks to process each section in parallel with the same document set
+ section_tasks = []
+ for section in answer_outline.answer_outline:
+ section_tasks.append(
+ process_section_with_documents(
+ section_title=section.section_title,
+ section_questions=section.questions,
+ user_id=configuration.user_id,
+ search_space_id=configuration.search_space_id,
+ session_maker=session_maker,
+ relevant_documents=relevant_documents
+ )
+ )
+
+ # Run all section processing tasks in parallel
+ print(f"Running {len(section_tasks)} section processing tasks in parallel")
+ section_results = await asyncio.gather(*section_tasks, return_exceptions=True)
+
+ # Handle any exceptions in the results
+ processed_results = []
+ for i, result in enumerate(section_results):
+ if isinstance(result, Exception):
+ section_title = answer_outline.answer_outline[i].section_title
+ error_message = f"Error processing section '{section_title}': {str(result)}"
+ print(error_message)
+ processed_results.append(error_message)
+ else:
+ processed_results.append(result)
+
+    # Combine the results into the final report; each section's content
+    # already includes its own title, so no extra headers are added here
+    final_report = []
+    for content in processed_results:
+        final_report.append(content)
+        final_report.append("\n")  # Add spacing between sections
+
+ # Join all sections with newlines
+ final_written_report = "\n".join(final_report)
+ print(f"Generated final report with {len(final_report)} parts")
+
+ return {
+ "final_written_report": final_written_report
+ }
+
+async def process_section_with_documents(
+ section_title: str,
+ section_questions: List[str],
+ user_id: str,
+ search_space_id: int,
+ session_maker,
+ relevant_documents: List[Dict[str, Any]]
+) -> str:
+ """
+ Process a single section using pre-fetched documents.
+
+ Args:
+ section_title: The title of the section
+ section_questions: List of research questions for this section
+ user_id: The user ID
+ search_space_id: The search space ID
+ session_maker: Factory for creating new database sessions
+ relevant_documents: Pre-fetched documents to use for this section
+
+ Returns:
+ The written section content
+ """
+ try:
+ # Create a new database session for this section
+ async with session_maker() as db_session:
+ # Use the provided documents
+ documents_to_use = relevant_documents
+
+ # Fallback if no documents found
+ if not documents_to_use:
+ print(f"No relevant documents found for section: {section_title}")
+ documents_to_use = [
+ {"content": f"No specific information was found for: {question}"}
+ for question in section_questions
+ ]
+
+ # Call the sub_section_writer graph with the appropriate config
+ config = {
+ "configurable": {
+ "sub_section_title": section_title,
+ "sub_section_questions": section_questions,
+ "relevant_documents": documents_to_use,
+ "user_id": user_id,
+ "search_space_id": search_space_id
+ }
+ }
+
+ # Create the initial state with db_session
+ state = {"db_session": db_session}
+
+ # Invoke the sub-section writer graph
+ print(f"Invoking sub_section_writer for: {section_title}")
+ result = await sub_section_writer_graph.ainvoke(state, config)
+
+ # Return the final answer from the sub_section_writer
+ final_answer = result.get("final_answer", "No content was generated for this section.")
+ return final_answer
+ except Exception as e:
+ print(f"Error processing section '{section_title}': {str(e)}")
+ return f"Error processing section: {section_title}. Details: {str(e)}"
+
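The deduplication pass in `fetch_relevant_documents` keys on `chunk_id` first and falls back to a hash of the content; a standalone sketch of the same logic on toy documents (`hash()` on strings is salted per process, which is fine for a single in-process pass):

```python
docs = [
    {"chunk_id": "c1", "content": "alpha"},
    {"chunk_id": "c1", "content": "alpha, revised"},  # same chunk_id -> dropped
    {"chunk_id": None, "content": "alpha"},           # same content  -> dropped
    {"chunk_id": "c2", "content": "beta"},
]

seen_chunk_ids, seen_content_hashes, deduplicated = set(), set(), []
for doc in docs:
    chunk_id = doc.get("chunk_id")
    content_hash = hash(doc.get("content", ""))
    if (chunk_id and chunk_id in seen_chunk_ids) or content_hash in seen_content_hashes:
        continue
    if chunk_id:
        seen_chunk_ids.add(chunk_id)
    seen_content_hashes.add(content_hash)
    deduplicated.append(doc)

assert [d["chunk_id"] for d in deduplicated] == ["c1", "c2"]
```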
diff --git a/surfsense_backend/app/agents/researcher/prompts.py b/surfsense_backend/app/agents/researcher/prompts.py
new file mode 100644
index 0000000..3a2a3f7
--- /dev/null
+++ b/surfsense_backend/app/agents/researcher/prompts.py
@@ -0,0 +1,92 @@
+import datetime
+
+
+def get_answer_outline_system_prompt():
+ return f"""
+Today's date: {datetime.datetime.now().strftime("%Y-%m-%d")}
+
+You are an expert research assistant specializing in structuring information. Your task is to create a detailed and logical research outline based on the user's query. This outline will serve as the blueprint for generating a comprehensive research report.
+
+Inputs:
+- user_query (string): The main question or topic the user wants researched. This guides the entire outline creation process.
+- num_sections (integer): The target number of distinct sections the final research report should have. This helps control the granularity and structure of the outline.
+
+Output format:
+
+A JSON object with the following structure:
+{{
+ "answer_outline": [
+ {{
+ "section_id": 0,
+ "section_title": "Section Title",
+ "questions": [
+ "Question 1 to research for this section",
+ "Question 2 to research for this section"
+ ]
+ }}
+ ]
+}}
+
+Instructions:
+
+1. **Deconstruct the `user_query`:** Identify the key concepts, entities, and the core information requested by the user.
+2. **Determine Section Themes:** Based on the analysis and the requested `num_sections`, divide the topic into distinct, logical themes or sub-topics. Each theme will become a section. Ensure these themes collectively address the `user_query` comprehensively.
+3. **Develop Sections:** For *each* of the `num_sections`:
+ * **Assign `section_id`:** Start with 0 and increment sequentially for each section.
+ * **Craft `section_title`:** Write a concise, descriptive title that clearly defines the scope and focus of the section's theme.
+ * **Formulate Research `questions`:** Generate 2 to 5 specific, targeted research questions for this section. These questions must:
+ * Directly relate to the `section_title` and explore its key aspects.
+ * Be answerable through focused research (e.g., searching documents, databases, or knowledge bases).
+ * Be distinct from each other and from questions in other sections. Avoid redundancy.
+ * Collectively guide the gathering of information needed to fully address the section's theme.
+4. **Ensure Logical Flow:** Arrange the sections in a coherent and intuitive sequence. Consider structures like:
+ * General background -> Specific details -> Analysis/Comparison -> Applications/Implications
+ * Problem definition -> Proposed solutions -> Evaluation -> Conclusion
+ * Chronological progression
+5. **Verify Completeness and Cohesion:** Review the entire outline (`section_titles` and `questions`) to confirm that:
+ * All sections together provide a complete and well-structured answer to the original `user_query`.
+ * There are no significant overlaps or gaps in coverage between sections.
+6. **Adhere Strictly to Output Format:** Ensure the final output is a valid JSON object matching the specified structure exactly, including correct field names (`answer_outline`, `section_id`, `section_title`, `questions`) and data types.
+
+Example:
+
+User Query: "What are the health benefits of meditation?"
+Number of Sections: 3
+
+{{
+ "answer_outline": [
+ {{
+ "section_id": 0,
+ "section_title": "Physical Health Benefits of Meditation",
+ "questions": [
+ "What physiological changes occur in the body during meditation?",
+ "How does regular meditation affect blood pressure and heart health?",
+ "What impact does meditation have on inflammation and immune function?",
+ "Can meditation help with pain management, and if so, how?"
+ ]
+ }},
+ {{
+ "section_id": 1,
+ "section_title": "Mental Health Benefits of Meditation",
+ "questions": [
+ "How does meditation affect stress and anxiety levels?",
+ "What changes in brain structure or function have been observed in meditation practitioners?",
+ "Can meditation help with depression and mood disorders?",
+ "What is the relationship between meditation and cognitive function?"
+ ]
+ }},
+ {{
+ "section_id": 2,
+ "section_title": "Best Meditation Practices for Maximum Benefits",
+ "questions": [
+ "What are the most effective meditation techniques for beginners?",
+ "How long and how frequently should one meditate to see benefits?",
+ "Are there specific meditation approaches best suited for particular health goals?",
+ "What common obstacles prevent people from experiencing meditation benefits?"
+ ]
+ }}
+ ]
+}}
+
+
+"""
\ No newline at end of file
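As a sanity check, the example output above round-trips through the Pydantic models declared in nodes.py (redefined locally here for self-containment); a sketch using the Pydantic v2 API — whether the project pins v1 or v2 is an assumption, and under v1 the call would be `parse_raw`:

```python
from typing import List
from pydantic import BaseModel

class Section(BaseModel):
    section_id: int
    section_title: str
    questions: List[str]

class AnswerOutline(BaseModel):
    answer_outline: List[Section]

raw = """{"answer_outline": [{"section_id": 0,
  "section_title": "Physical Health Benefits of Meditation",
  "questions": ["What physiological changes occur in the body during meditation?"]}]}"""

outline = AnswerOutline.model_validate_json(raw)
assert outline.answer_outline[0].section_title.startswith("Physical")
```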
diff --git a/surfsense_backend/app/agents/researcher/state.py b/surfsense_backend/app/agents/researcher/state.py
new file mode 100644
index 0000000..483e96a
--- /dev/null
+++ b/surfsense_backend/app/agents/researcher/state.py
@@ -0,0 +1,30 @@
+"""Define the state structures for the agent."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Any, Optional
+from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine
+
+@dataclass
+class State:
+ """Defines the dynamic state for the agent during execution.
+
+ This state tracks the database session and the outputs generated by the agent's nodes.
+ See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
+ for more information.
+ """
+ # Runtime context (not part of actual graph state)
+ db_session: AsyncSession
+ engine: Optional[AsyncEngine] = None
+
+ # Intermediate state - populated during workflow
+ # Using field to explicitly mark as part of state
+ answer_outline: Optional[Any] = field(default=None)
+
+ # OUTPUT: Populated by agent nodes
+ # Using field to explicitly mark as part of state
+ final_written_report: Optional[str] = field(default=None)
+
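For orientation: LangGraph merges each node's returned dict into this state by overwriting the matching fields, so the effect is equivalent to `dataclasses.replace` (a sketch of the semantics only, not of LangGraph internals):

```python
from dataclasses import replace
from app.agents.researcher.state import State

state = State(db_session=None)          # placeholder session, for illustration only
update = {"answer_outline": "OUTLINE"}  # the shape write_answer_outline returns
state = replace(state, **update)
assert state.answer_outline == "OUTLINE" and state.final_written_report is None
```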
diff --git a/surfsense_backend/app/agents/researcher/sub_section_writer/configuration.py b/surfsense_backend/app/agents/researcher/sub_section_writer/configuration.py
index b34090e..fbde94d 100644
--- a/surfsense_backend/app/agents/researcher/sub_section_writer/configuration.py
+++ b/surfsense_backend/app/agents/researcher/sub_section_writer/configuration.py
@@ -3,7 +3,7 @@
from __future__ import annotations
from dataclasses import dataclass, fields
-from typing import Optional, List
+from typing import Optional, List, Any
from langchain_core.runnables import RunnableConfig
@@ -14,11 +14,10 @@ class Configuration:
# Input parameters provided at invocation
sub_section_title: str
- sub_questions: List[str]
- connectors_to_search: List[str]
+ sub_section_questions: List[str]
+ relevant_documents: List[Any] # Documents provided directly to the agent
user_id: str
search_space_id: int
- top_k: int = 20 # Default top_k value
@classmethod
diff --git a/surfsense_backend/app/agents/researcher/sub_section_writer/graph.py b/surfsense_backend/app/agents/researcher/sub_section_writer/graph.py
index e250cde..5a5a5ba 100644
--- a/surfsense_backend/app/agents/researcher/sub_section_writer/graph.py
+++ b/surfsense_backend/app/agents/researcher/sub_section_writer/graph.py
@@ -1,20 +1,18 @@
from langgraph.graph import StateGraph
from .state import State
-from .nodes import fetch_relevant_documents, write_sub_section
+from .nodes import write_sub_section, rerank_documents
from .configuration import Configuration
# Define a new graph
workflow = StateGraph(State, config_schema=Configuration)
# Add the nodes to the graph
-workflow.add_node("fetch_relevant_documents", fetch_relevant_documents)
+workflow.add_node("rerank_documents", rerank_documents)
workflow.add_node("write_sub_section", write_sub_section)
-# Entry point
-workflow.add_edge("__start__", "fetch_relevant_documents")
-# Connect fetch_relevant_documents to write_sub_section
-workflow.add_edge("fetch_relevant_documents", "write_sub_section")
-# Exit point
+# Connect the nodes
+workflow.add_edge("__start__", "rerank_documents")
+workflow.add_edge("rerank_documents", "write_sub_section")
workflow.add_edge("write_sub_section", "__end__")
# Compile the workflow into an executable graph
diff --git a/surfsense_backend/app/agents/researcher/sub_section_writer/nodes.py b/surfsense_backend/app/agents/researcher/sub_section_writer/nodes.py
index 52fa877..af807e3 100644
--- a/surfsense_backend/app/agents/researcher/sub_section_writer/nodes.py
+++ b/surfsense_backend/app/agents/researcher/sub_section_writer/nodes.py
@@ -2,171 +2,79 @@ from .configuration import Configuration
from langchain_core.runnables import RunnableConfig
from .state import State
from typing import Any, Dict
-from app.utils.connector_service import ConnectorService
-from app.utils.reranker_service import RerankerService
from app.config import config as app_config
-from .prompts import citation_system_prompt
+from .prompts import get_citation_system_prompt
from langchain_core.messages import HumanMessage, SystemMessage
-async def fetch_relevant_documents(state: State, config: RunnableConfig) -> Dict[str, Any]:
+async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, Any]:
"""
- Fetch relevant documents for the sub-section using specified connectors.
+ Rerank the documents based on relevance to the sub-section title.
- This node retrieves documents from various data sources based on the sub-questions
- derived from the sub-section title. It searches across all selected connectors
- (YouTube, Extension, Crawled URLs, Files, Tavily API, Slack, Notion) and reranks
- the results to provide the most relevant information for the agent workflow.
+ This node takes the relevant documents provided in the configuration,
+ reranks them using the reranker service based on the sub-section title,
+ and updates the state with the reranked documents.
Returns:
- Dict containing the reranked documents in the "relevant_documents_fetched" key.
+        Dict containing the reranked documents in the "reranked_documents" key.
"""
- # Get configuration
+ # Get configuration and relevant documents
configuration = Configuration.from_runnable_config(config)
-
- # Extract state parameters
- db_session = state.db_session
-
- # Extract config parameters
- user_id = configuration.user_id
- search_space_id = configuration.search_space_id
- TOP_K = configuration.top_k
-
- # Initialize services
- connector_service = ConnectorService(db_session)
- reranker_service = RerankerService.get_reranker_instance(app_config)
+ documents = configuration.relevant_documents
+ sub_section_questions = configuration.sub_section_questions
- all_raw_documents = [] # Store all raw documents before reranking
+ # If no documents were provided, return empty list
+ if not documents or len(documents) == 0:
+ return {
+ "reranked_documents": []
+ }
- for user_query in configuration.sub_questions:
- # Reformulate query (optional, consider if needed for each sub-question)
- # reformulated_query = await QueryService.reformulate_query(user_query)
- reformulated_query = user_query # Using original sub-question for now
-
- # Process each selected connector
- for connector in configuration.connectors_to_search:
- if connector == "YOUTUBE_VIDEO":
- _, youtube_chunks = await connector_service.search_youtube(
- user_query=reformulated_query,
- user_id=user_id,
- search_space_id=search_space_id,
- top_k=TOP_K
- )
- all_raw_documents.extend(youtube_chunks)
-
- elif connector == "EXTENSION":
- _, extension_chunks = await connector_service.search_extension(
- user_query=reformulated_query,
- user_id=user_id,
- search_space_id=search_space_id,
- top_k=TOP_K
- )
- all_raw_documents.extend(extension_chunks)
-
- elif connector == "CRAWLED_URL":
- _, crawled_urls_chunks = await connector_service.search_crawled_urls(
- user_query=reformulated_query,
- user_id=user_id,
- search_space_id=search_space_id,
- top_k=TOP_K
- )
- all_raw_documents.extend(crawled_urls_chunks)
-
- elif connector == "FILE":
- _, files_chunks = await connector_service.search_files(
- user_query=reformulated_query,
- user_id=user_id,
- search_space_id=search_space_id,
- top_k=TOP_K
- )
- all_raw_documents.extend(files_chunks)
-
- elif connector == "TAVILY_API":
- _, tavily_chunks = await connector_service.search_tavily(
- user_query=reformulated_query,
- user_id=user_id,
- top_k=TOP_K
- )
- all_raw_documents.extend(tavily_chunks)
-
- elif connector == "SLACK_CONNECTOR":
- _, slack_chunks = await connector_service.search_slack(
- user_query=reformulated_query,
- user_id=user_id,
- search_space_id=search_space_id,
- top_k=TOP_K
- )
- all_raw_documents.extend(slack_chunks)
-
- elif connector == "NOTION_CONNECTOR":
- _, notion_chunks = await connector_service.search_notion(
- user_query=reformulated_query,
- user_id=user_id,
- search_space_id=search_space_id,
- top_k=TOP_K
- )
- all_raw_documents.extend(notion_chunks)
+ # Get reranker service from app config
+ reranker_service = getattr(app_config, "reranker_service", None)
- # If we have documents and a reranker is available, rerank them
- # Deduplicate documents based on chunk_id or content to avoid processing duplicates
- seen_chunk_ids = set()
- seen_content_hashes = set()
- deduplicated_docs = []
+ # Use documents as is if no reranker service is available
+ reranked_docs = documents
- for doc in all_raw_documents:
- chunk_id = doc.get("chunk_id")
- content = doc.get("content", "")
- content_hash = hash(content)
-
- # Skip if we've seen this chunk_id or content before
- if (chunk_id and chunk_id in seen_chunk_ids) or content_hash in seen_content_hashes:
- continue
+ if reranker_service:
+ try:
+ # Use the sub-section questions for reranking context
+ rerank_query = "\n".join(sub_section_questions)
- # Add to our tracking sets and keep this document
- if chunk_id:
- seen_chunk_ids.add(chunk_id)
- seen_content_hashes.add(content_hash)
- deduplicated_docs.append(doc)
+ # Convert documents to format expected by reranker if needed
+ reranker_input_docs = [
+ {
+ "chunk_id": doc.get("chunk_id", f"chunk_{i}"),
+ "content": doc.get("content", ""),
+ "score": doc.get("score", 0.0),
+ "document": {
+ "id": doc.get("document", {}).get("id", ""),
+ "title": doc.get("document", {}).get("title", ""),
+ "document_type": doc.get("document", {}).get("document_type", ""),
+ "metadata": doc.get("document", {}).get("metadata", {})
+ }
+ } for i, doc in enumerate(documents)
+ ]
+
+            # Rerank documents against the combined question context
+ reranked_docs = reranker_service.rerank_documents(rerank_query, reranker_input_docs)
+
+ # Sort by score in descending order
+ reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
+
+ print(f"Reranked {len(reranked_docs)} documents for section: {configuration.sub_section_title}")
+ except Exception as e:
+ print(f"Error during reranking: {str(e)}")
+ # Use original docs if reranking fails
- # Use deduplicated documents for reranking
- reranked_docs = deduplicated_docs
- if deduplicated_docs and reranker_service:
- # Use the main sub_section_title for reranking context
- rerank_query = configuration.sub_section_title
-
- # Convert documents to format expected by reranker
- reranker_input_docs = [
- {
- "chunk_id": doc.get("chunk_id", f"chunk_{i}"),
- "content": doc.get("content", ""),
- "score": doc.get("score", 0.0),
- "document": {
- "id": doc.get("document", {}).get("id", ""),
- "title": doc.get("document", {}).get("title", ""),
- "document_type": doc.get("document", {}).get("document_type", ""),
- "metadata": doc.get("document", {}).get("metadata", {})
- }
- } for i, doc in enumerate(deduplicated_docs)
- ]
-
- # Rerank documents using the main title query
- reranked_docs = reranker_service.rerank_documents(rerank_query, reranker_input_docs)
-
- # Sort by score in descending order
- reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
-
- # Update state with fetched documents
return {
- "relevant_documents_fetched": reranked_docs
+ "reranked_documents": reranked_docs
}
-
-
async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, Any]:
"""
- Write the sub-section using the fetched documents.
+ Write the sub-section using the provided documents.
- This node takes the relevant documents fetched in the previous node and uses
- an LLM to generate a comprehensive answer to the sub-section questions with
+    This node takes the documents reranked by the previous node and uses
+    an LLM to generate a comprehensive answer to the sub-section questions with
proper citations. The citations follow IEEE format using source IDs from the
documents.
@@ -174,17 +82,17 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
Dict containing the final answer in the "final_answer" key.
"""
- # Get configuration and relevant documents
+    # Get configuration and the reranked documents from state
configuration = Configuration.from_runnable_config(config)
- documents = state.relevant_documents_fetched
+    documents = state.reranked_documents
# Initialize LLM
llm = app_config.fast_llm_instance
- # If no documents were found, return a message indicating this
+    # If no documents are available, return a message indicating this
if not documents or len(documents) == 0:
return {
- "final_answer": "No relevant documents were found to answer this question. Please try refining your search or providing more specific questions."
+ "final_answer": "No relevant documents were provided to answer this question. Please provide documents or try a different approach."
}
# Prepare documents for citation formatting
@@ -208,18 +116,25 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
"""
formatted_documents.append(formatted_doc)
- # Create the query that combines the section title and questions
- # section_title = configuration.sub_section_title
- questions = "\n".join([f"- {q}" for q in configuration.sub_questions])
+ # Create the query that uses the section title and questions
+ section_title = configuration.sub_section_title
+ sub_section_questions = configuration.sub_section_questions
documents_text = "\n".join(formatted_documents)
+ # Format the questions as bullet points for clarity
+ questions_text = "\n".join([f"- {question}" for question in sub_section_questions])
+
# Construct a clear, structured query for the LLM
human_message_content = f"""
Please write a comprehensive answer for the title:
- Address the following questions:
+
+ {section_title}
+
+
+ Focus on answering these specific questions related to the title:
- {questions}
+ {questions_text}
Use the provided documents as your source material and cite them properly using the IEEE citation format [X] where X is the source_id.
@@ -230,7 +145,7 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
# Create messages for the LLM
messages = [
- SystemMessage(content=citation_system_prompt),
+ SystemMessage(content=get_citation_system_prompt()),
HumanMessage(content=human_message_content)
]
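The rerank path above only assumes an object exposing `rerank_documents(query, documents)` that returns scored documents; a pass-through stand-in satisfying that contract (illustrative only, not the project's `RerankerService`) could be wired in when `app_config.reranker_service` is unset:

```python
from typing import Any, Dict, List

class NoOpReranker:
    """Pass-through stand-in: returns documents with their existing scores,
    so the descending sort in rerank_documents only reorders by those scores."""

    def rerank_documents(
        self, query: str, documents: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        return [dict(doc, score=doc.get("score", 0.0)) for doc in documents]
```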
diff --git a/surfsense_backend/app/agents/researcher/sub_section_writer/prompts.py b/surfsense_backend/app/agents/researcher/sub_section_writer/prompts.py
index cc3ad61..18a91eb 100644
--- a/surfsense_backend/app/agents/researcher/sub_section_writer/prompts.py
+++ b/surfsense_backend/app/agents/researcher/sub_section_writer/prompts.py
@@ -1,4 +1,9 @@
-citation_system_prompt = f"""
+import datetime
+
+
+def get_citation_system_prompt():
+ return f"""
+Today's date: {datetime.datetime.now().strftime("%Y-%m-%d")}
You are a research assistant tasked with analyzing documents and providing comprehensive answers with proper citations in IEEE format.
diff --git a/surfsense_backend/app/agents/researcher/sub_section_writer/state.py b/surfsense_backend/app/agents/researcher/sub_section_writer/state.py
index fb5b08e..b33abe6 100644
--- a/surfsense_backend/app/agents/researcher/sub_section_writer/state.py
+++ b/surfsense_backend/app/agents/researcher/sub_section_writer/state.py
@@ -18,6 +18,6 @@ class State:
db_session: AsyncSession
# OUTPUT: Populated by agent nodes
- relevant_documents_fetched: Optional[List[Any]] = None
+ reranked_documents: Optional[List[Any]] = None
final_answer: Optional[str] = None
diff --git a/surfsense_backend/app/agents/researcher/test_researcher.py b/surfsense_backend/app/agents/researcher/test_researcher.py
new file mode 100644
index 0000000..15c993e
--- /dev/null
+++ b/surfsense_backend/app/agents/researcher/test_researcher.py
@@ -0,0 +1,126 @@
+#!/usr/bin/env python3
+"""
+Test script for the Researcher LangGraph agent.
+
+This script demonstrates how to invoke the researcher agent with a sample query.
+Run this script directly from VSCode using the "Run Python File" button or
+right-click and select "Run Python File in Terminal".
+
+Before running:
+1. Make sure your Python environment has all required dependencies
+2. Create a .env file with any required API keys
+3. Ensure database connection is properly configured
+"""
+
+import asyncio
+import os
+import sys
+from pathlib import Path
+
+# Add project root to Python path so that 'app' can be found as a module
+# Get the absolute path to the surfsense_backend directory which contains the app module
+project_root = str(Path(__file__).resolve().parents[3])  # parents[0]=researcher, [1]=agents, [2]=app, [3]=surfsense_backend
+print(f"Adding to Python path: {project_root}")
+sys.path.insert(0, project_root)
+
+# Now import the modules after fixing the path
+from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
+from sqlalchemy.orm import sessionmaker
+from dotenv import load_dotenv
+
+# These imports should now work with the correct path
+from app.agents.researcher.graph import graph
+from app.agents.researcher.state import State
+
+
+# Load environment variables
+load_dotenv()
+
+# Database connection string - use a test database or mock
+DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense"
+
+# Create async engine and session
+engine = create_async_engine(DATABASE_URL)
+async_session_maker = sessionmaker(
+ engine, class_=AsyncSession, expire_on_commit=False
+)
+
+async def run_test():
+ """Run a test of the researcher agent."""
+ print("Starting researcher agent test...")
+
+ # Create a database session
+ async with async_session_maker() as db_session:
+ # Sample configuration
+ config = {
+ "configurable": {
+ "user_query": "What are the best clash royale decks recommended by Surgical Goblin?",
+ "num_sections": 1,
+ "connectors_to_search": [
+ "YOUTUBE_VIDEO",
+ ],
+ "user_id": "d6ac2187-7407-4664-8734-af09926d161e",
+ "search_space_id": 2
+ }
+ }
+
+ try:
+ # Initialize state with database session and engine
+ initial_state = State(db_session=db_session, engine=engine)
+
+ # Run the graph directly
+ print("\nRunning the complete researcher workflow...")
+ result = await graph.ainvoke(initial_state, config)
+
+ # Extract the answer outline for display
+ if "answer_outline" in result and result["answer_outline"]:
+ print(f"\nGenerated answer outline with {len(result['answer_outline'].answer_outline)} sections")
+
+ # Print the outline
+ print("\nGenerated Answer Outline:")
+ for section in result["answer_outline"].answer_outline:
+ print(f"\nSection {section.section_id}: {section.section_title}")
+ print("Research Questions:")
+ for q in section.questions:
+ print(f" - {q}")
+
+ # Check if we got a final report
+ if "final_written_report" in result and result["final_written_report"]:
+ final_report = result["final_written_report"]
+ print("\nFinal Research Report generated successfully!")
+ print(f"Report length: {len(final_report)} characters")
+
+ # Display the final report
+ print("\n==== FINAL RESEARCH REPORT ====\n")
+ print(final_report)
+ else:
+ print("\nNo final report was generated.")
+ print(f"Available result keys: {list(result.keys())}")
+
+ return result
+
+ except Exception as e:
+ print(f"Error running researcher agent: {str(e)}")
+ import traceback
+ traceback.print_exc()
+ raise
+
+async def main():
+ """Main entry point for the test script."""
+ try:
+ result = await run_test()
+ print("\nTest completed successfully.")
+ return result
+ except Exception as e:
+ print(f"\nTest failed with error: {e}")
+ import traceback
+ traceback.print_exc()
+ return None
+
+if __name__ == "__main__":
+ # Run the async test
+ result = asyncio.run(main())
+
+ # Keep terminal open if run directly in VSCode
+ if 'VSCODE_PID' in os.environ:
+ input("\nPress Enter to close this window...")
\ No newline at end of file
diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py
index f4226ed..82517a8 100644
--- a/surfsense_backend/app/config/__init__.py
+++ b/surfsense_backend/app/config/__init__.py
@@ -42,10 +42,8 @@ class Config:
# GPT Researcher
FAST_LLM = os.getenv("FAST_LLM")
- SMART_LLM = os.getenv("SMART_LLM")
STRATEGIC_LLM = os.getenv("STRATEGIC_LLM")
fast_llm_instance = ChatLiteLLM(model=extract_model_name(FAST_LLM))
- smart_llm_instance = ChatLiteLLM(model=extract_model_name(SMART_LLM))
strategic_llm_instance = ChatLiteLLM(model=extract_model_name(STRATEGIC_LLM))