feat: Initial version of SurfSense own LangGraph Agent.

DESKTOP-RTLN3BA\$punk 2025-04-19 23:25:06 -07:00
parent 6e1a254fcd
commit 34300ead02
13 changed files with 884 additions and 167 deletions

View file: app/agents/researcher/__init__.py

@@ -0,0 +1 @@
"""This is upcoming research agent. Work in progress."""

View file: app/agents/researcher/configuration.py

@@ -0,0 +1,30 @@
"""Define the configurable parameters for the agent."""
from __future__ import annotations
from dataclasses import dataclass, fields
from typing import Optional, List, Any
from langchain_core.runnables import RunnableConfig
@dataclass(kw_only=True)
class Configuration:
"""The configuration for the agent."""
# Input parameters provided at invocation
user_query: str
num_sections: int
connectors_to_search: List[str]
user_id: str
search_space_id: int
@classmethod
def from_runnable_config(
cls, config: Optional[RunnableConfig] = None
) -> Configuration:
"""Create a Configuration instance from a RunnableConfig object."""
configurable = (config.get("configurable") or {}) if config else {}
_fields = {f.name for f in fields(cls) if f.init}
return cls(**{k: v for k, v in configurable.items() if k in _fields})
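
For orientation, a minimal sketch (illustrative values, not part of the commit) of how from_runnable_config pulls these fields out of a LangGraph invocation config:

    from langchain_core.runnables import RunnableConfig

    # Illustrative invocation config; keys mirror the dataclass fields above.
    config: RunnableConfig = {
        "configurable": {
            "user_query": "What is LangGraph?",
            "num_sections": 3,
            "connectors_to_search": ["CRAWLED_URL"],
            "user_id": "user-123",
            "search_space_id": 1,
        }
    }
    configuration = Configuration.from_runnable_config(config)
    assert configuration.num_sections == 3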

View file: app/agents/researcher/graph.py

@@ -0,0 +1,43 @@
from langgraph.graph import StateGraph
from .state import State
from .nodes import write_answer_outline, process_sections
from .configuration import Configuration
from typing import TypedDict, List, Dict, Any, Optional
# Define what keys are in our state dict
class GraphState(TypedDict):
# Intermediate data produced during workflow
answer_outline: Optional[Any]
# Final output
final_written_report: Optional[str]
def build_graph():
"""
Build and return the LangGraph workflow.
This function constructs the researcher agent graph with proper state management
and node connections following LangGraph best practices.
Returns:
A compiled LangGraph workflow
"""
# Define a new graph with state class
workflow = StateGraph(State, config_schema=Configuration)
# Add nodes to the graph
workflow.add_node("write_answer_outline", write_answer_outline)
workflow.add_node("process_sections", process_sections)
# Define the edges - create a linear flow
workflow.add_edge("__start__", "write_answer_outline")
workflow.add_edge("write_answer_outline", "process_sections")
workflow.add_edge("process_sections", "__end__")
# Compile the workflow into an executable graph
graph = workflow.compile()
graph.name = "Surfsense Researcher" # This defines the custom name in LangSmith
return graph
# Compile the graph once when the module is loaded
graph = build_graph()
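
A minimal sketch of driving the compiled graph end to end; the connection string is a placeholder and the configurable keys mirror the researcher Configuration above:

    import asyncio
    from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
    from sqlalchemy.orm import sessionmaker

    async def run_once() -> str:
        engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/db")  # placeholder DSN
        session_maker = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
        async with session_maker() as db_session:
            result = await graph.ainvoke(
                State(db_session=db_session, engine=engine),
                config={
                    "configurable": {
                        "user_query": "What is LangGraph?",
                        "num_sections": 2,
                        "connectors_to_search": ["CRAWLED_URL"],
                        "user_id": "user-123",
                        "search_space_id": 1,
                    }
                },
            )
        # The result is a mapping of the state keys produced by the graph.
        return result["final_written_report"]

    report = asyncio.run(run_once())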

View file: app/agents/researcher/nodes.py

@@ -0,0 +1,476 @@
from .configuration import Configuration
from langchain_core.runnables import RunnableConfig
from .state import State
from typing import Any, Dict, List
from app.config import config as app_config
from .prompts import answer_outline_system_prompt
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field
import json
import asyncio
from .sub_section_writer.graph import graph as sub_section_writer_graph
from app.utils.connector_service import ConnectorService
from app.utils.reranker_service import RerankerService
from sqlalchemy.ext.asyncio import AsyncSession
import copy
class Section(BaseModel):
"""A section in the answer outline."""
section_id: int = Field(..., description="The zero-based index of the section")
section_title: str = Field(..., description="The title of the section")
questions: List[str] = Field(..., description="Questions to research for this section")
class AnswerOutline(BaseModel):
"""The complete answer outline with all sections."""
answer_outline: List[Section] = Field(..., description="List of sections in the answer outline")
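
A quick illustration of how these models validate the JSON shape the LLM is asked to emit (model_validate_json is the pydantic v2 API; on v1 the equivalent is parse_raw):

    raw = '{"answer_outline": [{"section_id": 0, "section_title": "Intro", "questions": ["What is X?"]}]}'
    outline = AnswerOutline.model_validate_json(raw)
    print(outline.answer_outline[0].section_title)  # -> "Intro"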
async def write_answer_outline(state: State, config: RunnableConfig) -> Dict[str, Any]:
"""
Create a structured answer outline based on the user query.
This node takes the user query and number of sections from the configuration and uses
an LLM to generate a comprehensive outline with logical sections and research questions
for each section.
Returns:
Dict containing the answer outline in the "answer_outline" key for state update.
"""
# Get configuration from runnable config
configuration = Configuration.from_runnable_config(config)
user_query = configuration.user_query
num_sections = configuration.num_sections
# Initialize LLM
llm = app_config.strategic_llm_instance
# Create the human message content
human_message_content = f"""
Now please create an answer outline for the following query:
User Query: {user_query}
Number of Sections: {num_sections}
Remember to format your response as valid JSON exactly matching this structure:
{{
"answer_outline": [
{{
"section_id": 0,
"section_title": "Section Title",
"questions": [
"Question 1 to research for this section",
"Question 2 to research for this section"
]
}}
]
}}
Your output MUST be valid JSON in exactly this format. Do not include any other text or explanation.
"""
# Create messages for the LLM
messages = [
SystemMessage(content=answer_outline_system_prompt),
HumanMessage(content=human_message_content)
]
# Call the LLM directly without using structured output
response = await llm.ainvoke(messages)
# Parse the JSON response manually
try:
# Extract JSON content from the response
content = response.content
# Find the JSON in the content (handle case where LLM might add additional text)
json_start = content.find('{')
json_end = content.rfind('}') + 1
if json_start >= 0 and json_end > json_start:
json_str = content[json_start:json_end]
# Parse the JSON string
parsed_data = json.loads(json_str)
# Convert to Pydantic model
answer_outline = AnswerOutline(**parsed_data)
print(f"Successfully generated answer outline with {len(answer_outline.answer_outline)} sections")
# Return state update
return {"answer_outline": answer_outline}
else:
# If JSON structure not found, raise a clear error
raise ValueError(f"Could not find valid JSON in LLM response. Raw response: {content}")
except (json.JSONDecodeError, ValueError) as e:
# Log the error and re-raise it
print(f"Error parsing LLM response: {str(e)}")
print(f"Raw response: {response.content}")
raise
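
If the configured model supports tool or function calling, the manual brace-scanning above could be replaced with LangChain's structured-output helper; a hedged sketch (provider support via ChatLiteLLM varies, so this is an assumption, not a drop-in change):

    # Sketch only: bind the AnswerOutline schema directly to the model.
    structured_llm = app_config.strategic_llm_instance.with_structured_output(AnswerOutline)
    answer_outline = await structured_llm.ainvoke(messages)  # returns a validated AnswerOutline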
async def fetch_relevant_documents(
research_questions: List[str],
user_id: str,
search_space_id: int,
db_session: AsyncSession,
connectors_to_search: List[str],
top_k: int = 5
) -> List[Dict[str, Any]]:
"""
Fetch relevant documents for research questions using the provided connectors.
Args:
research_questions: List of research questions to find documents for
user_id: The user ID
search_space_id: The search space ID
db_session: The database session
connectors_to_search: List of connectors to search
top_k: Number of top results to retrieve per connector per question
Returns:
List of relevant documents
"""
# Initialize services
connector_service = ConnectorService(db_session)
reranker_service = RerankerService.get_reranker_instance(app_config)
all_raw_documents = [] # Store all raw documents before reranking
for user_query in research_questions:
# Use original research question as the query
reformulated_query = user_query
# Process each selected connector
for connector in connectors_to_search:
try:
if connector == "YOUTUBE_VIDEO":
_, youtube_chunks = await connector_service.search_youtube(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
)
all_raw_documents.extend(youtube_chunks)
elif connector == "EXTENSION":
_, extension_chunks = await connector_service.search_extension(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
)
all_raw_documents.extend(extension_chunks)
elif connector == "CRAWLED_URL":
_, crawled_urls_chunks = await connector_service.search_crawled_urls(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
)
all_raw_documents.extend(crawled_urls_chunks)
elif connector == "FILE":
_, files_chunks = await connector_service.search_files(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
)
all_raw_documents.extend(files_chunks)
elif connector == "TAVILY_API":
_, tavily_chunks = await connector_service.search_tavily(
user_query=reformulated_query,
user_id=user_id,
top_k=top_k
)
all_raw_documents.extend(tavily_chunks)
elif connector == "SLACK_CONNECTOR":
_, slack_chunks = await connector_service.search_slack(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
)
all_raw_documents.extend(slack_chunks)
elif connector == "NOTION_CONNECTOR":
_, notion_chunks = await connector_service.search_notion(
user_query=reformulated_query,
user_id=user_id,
search_space_id=search_space_id,
top_k=top_k
)
all_raw_documents.extend(notion_chunks)
except Exception as e:
print(f"Error searching connector {connector}: {str(e)}")
# Continue with other connectors on error
continue
# Deduplicate documents based on chunk_id or content
seen_chunk_ids = set()
seen_content_hashes = set()
deduplicated_docs = []
for doc in all_raw_documents:
chunk_id = doc.get("chunk_id")
content = doc.get("content", "")
content_hash = hash(content)
# Skip if we've seen this chunk_id or content before
if (chunk_id and chunk_id in seen_chunk_ids) or content_hash in seen_content_hashes:
continue
# Add to our tracking sets and keep this document
if chunk_id:
seen_chunk_ids.add(chunk_id)
seen_content_hashes.add(content_hash)
deduplicated_docs.append(doc)
return deduplicated_docs
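
The if/elif ladder above could also be driven by a dispatch table, which keeps per-connector quirks in one place; a sketch of the idea (hypothetical refactor, using only the method names that appear above):

    # Map connector names to (ConnectorService method, needs search_space_id).
    CONNECTOR_DISPATCH = {
        "YOUTUBE_VIDEO": ("search_youtube", True),
        "EXTENSION": ("search_extension", True),
        "CRAWLED_URL": ("search_crawled_urls", True),
        "FILE": ("search_files", True),
        "TAVILY_API": ("search_tavily", False),  # takes no search_space_id, as above
        "SLACK_CONNECTOR": ("search_slack", True),
        "NOTION_CONNECTOR": ("search_notion", True),
    }

    async def search_one_connector(connector_service, connector, query, user_id, search_space_id, top_k):
        method_name, needs_space = CONNECTOR_DISPATCH[connector]
        kwargs = {"user_query": query, "user_id": user_id, "top_k": top_k}
        if needs_space:
            kwargs["search_space_id"] = search_space_id
        _, chunks = await getattr(connector_service, method_name)(**kwargs)
        return chunks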
async def process_section(
section_title: str,
user_id: str,
search_space_id: int,
session_maker,
research_questions: List[str],
connectors_to_search: List[str]
) -> str:
"""
Process a single section by sending it to the sub_section_writer graph.
Args:
section_title: The title of the section
user_id: The user ID
search_space_id: The search space ID
session_maker: Factory for creating new database sessions
research_questions: List of research questions for this section
connectors_to_search: List of connectors to search
Returns:
The written section content
"""
try:
# Create a new database session for this section
async with session_maker() as db_session:
# Fetch relevant documents using all research questions for this section
relevant_documents = await fetch_relevant_documents(
research_questions=research_questions,
user_id=user_id,
search_space_id=search_space_id,
db_session=db_session,
connectors_to_search=connectors_to_search
)
# Fallback if no documents found
if not relevant_documents:
print(f"No relevant documents found for section: {section_title}")
                relevant_documents = [
                    {"content": f"No specific information was found for: {question}"}
                    for question in research_questions
                ]
# Call the sub_section_writer graph with the appropriate config
config = {
"configurable": {
"sub_section_title": section_title,
"relevant_documents": relevant_documents,
"user_id": user_id,
"search_space_id": search_space_id
}
}
# Create the initial state with db_session
state = {"db_session": db_session}
# Invoke the sub-section writer graph
print(f"Invoking sub_section_writer for: {section_title}")
result = await sub_section_writer_graph.ainvoke(state, config)
# Return the final answer from the sub_section_writer
final_answer = result.get("final_answer", "No content was generated for this section.")
return final_answer
except Exception as e:
print(f"Error processing section '{section_title}': {str(e)}")
return f"Error processing section: {section_title}. Details: {str(e)}"
async def process_sections(state: State, config: RunnableConfig) -> Dict[str, Any]:
"""
Process all sections in parallel and combine the results.
This node takes the answer outline from the previous step, fetches relevant documents
for all questions across all sections once, and then processes each section in parallel
using the sub_section_writer graph with the shared document pool.
Returns:
Dict containing the final written report in the "final_written_report" key.
"""
# Get configuration and answer outline from state
configuration = Configuration.from_runnable_config(config)
answer_outline = state.answer_outline
print(f"Processing sections from outline: {answer_outline is not None}")
if not answer_outline:
return {
"final_written_report": "No answer outline was provided. Cannot generate final report."
}
# Create session maker from the engine or directly use the session
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import sessionmaker
# Use the engine if available, otherwise create a new session for each task
if state.engine:
session_maker = sessionmaker(
state.engine, class_=AsyncSession, expire_on_commit=False
)
else:
# Fallback to using the same session (less optimal but will work)
print("Warning: No engine available. Using same session for all tasks.")
# Create a mock session maker that returns the same session
        def mock_session_maker():
            # Must be a plain (non-async) factory: "async with session_maker()"
            # needs an async context manager, not a coroutine.
            class ContextManager:
                async def __aenter__(self):
                    return state.db_session

                async def __aexit__(self, exc_type, exc_val, exc_tb):
                    pass

            return ContextManager()

        session_maker = mock_session_maker
# Collect all questions from all sections
all_questions = []
for section in answer_outline.answer_outline:
all_questions.extend(section.questions)
print(f"Collected {len(all_questions)} questions from all sections")
# Fetch relevant documents once for all questions
relevant_documents = []
async with session_maker() as db_session:
relevant_documents = await fetch_relevant_documents(
research_questions=all_questions,
user_id=configuration.user_id,
search_space_id=configuration.search_space_id,
db_session=db_session,
connectors_to_search=configuration.connectors_to_search
)
print(f"Fetched {len(relevant_documents)} relevant documents for all sections")
# Create tasks to process each section in parallel with the same document set
section_tasks = []
for section in answer_outline.answer_outline:
section_tasks.append(
process_section_with_documents(
section_title=section.section_title,
section_questions=section.questions,
user_id=configuration.user_id,
search_space_id=configuration.search_space_id,
session_maker=session_maker,
relevant_documents=relevant_documents
)
)
# Run all section processing tasks in parallel
print(f"Running {len(section_tasks)} section processing tasks in parallel")
section_results = await asyncio.gather(*section_tasks, return_exceptions=True)
# Handle any exceptions in the results
processed_results = []
for i, result in enumerate(section_results):
if isinstance(result, Exception):
section_title = answer_outline.answer_outline[i].section_title
error_message = f"Error processing section '{section_title}': {str(result)}"
print(error_message)
processed_results.append(error_message)
else:
processed_results.append(result)
# Combine the results into a final report with section titles
final_report = []
for i, (section, content) in enumerate(zip(answer_outline.answer_outline, processed_results)):
section_header = f"## {section.section_title}"
final_report.append(section_header)
final_report.append(content)
final_report.append("\n") # Add spacing between sections
# Join all sections with newlines
final_written_report = "\n".join(final_report)
print(f"Generated final report with {len(final_report)} parts")
return {
"final_written_report": final_written_report
}
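
As a side note, the hand-rolled fallback context manager above can be written more compactly with contextlib; a sketch with the same behavior, assuming it is defined inside process_sections where state is in scope:

    from contextlib import asynccontextmanager

    @asynccontextmanager
    async def same_session_maker():
        # Hands every task the one existing session; no per-task cleanup.
        yield state.db_session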
async def process_section_with_documents(
section_title: str,
section_questions: List[str],
user_id: str,
search_space_id: int,
session_maker,
relevant_documents: List[Dict[str, Any]]
) -> str:
"""
Process a single section using pre-fetched documents.
Args:
section_title: The title of the section
section_questions: List of research questions for this section
user_id: The user ID
search_space_id: The search space ID
session_maker: Factory for creating new database sessions
relevant_documents: Pre-fetched documents to use for this section
Returns:
The written section content
"""
try:
# Create a new database session for this section
async with session_maker() as db_session:
# Use the provided documents
documents_to_use = relevant_documents
# Fallback if no documents found
if not documents_to_use:
print(f"No relevant documents found for section: {section_title}")
                documents_to_use = [
                    {"content": f"No specific information was found for: {question}"}
                    for question in section_questions
                ]
# Call the sub_section_writer graph with the appropriate config
config = {
"configurable": {
"sub_section_title": section_title,
"sub_section_questions": section_questions,
"relevant_documents": documents_to_use,
"user_id": user_id,
"search_space_id": search_space_id
}
}
# Create the initial state with db_session
state = {"db_session": db_session}
# Invoke the sub-section writer graph
print(f"Invoking sub_section_writer for: {section_title}")
result = await sub_section_writer_graph.ainvoke(state, config)
# Return the final answer from the sub_section_writer
final_answer = result.get("final_answer", "No content was generated for this section.")
return final_answer
except Exception as e:
print(f"Error processing section '{section_title}': {str(e)}")
return f"Error processing section: {section_title}. Details: {str(e)}"

View file: app/agents/researcher/prompts.py

@@ -0,0 +1,91 @@
import datetime
answer_outline_system_prompt = f"""
Today's date: {datetime.datetime.now().strftime("%Y-%m-%d")}
<answer_outline_system>
You are an expert research assistant specializing in structuring information. Your task is to create a detailed and logical research outline based on the user's query. This outline will serve as the blueprint for generating a comprehensive research report.
<input>
- user_query (string): The main question or topic the user wants researched. This guides the entire outline creation process.
- num_sections (integer): The target number of distinct sections the final research report should have. This helps control the granularity and structure of the outline.
</input>
<output_format>
A JSON object with the following structure:
{{
"answer_outline": [
{{
"section_id": 0,
"section_title": "Section Title",
"questions": [
"Question 1 to research for this section",
"Question 2 to research for this section"
]
}}
]
}}
</output_format>
<instructions>
1. **Deconstruct the `user_query`:** Identify the key concepts, entities, and the core information requested by the user.
2. **Determine Section Themes:** Based on the analysis and the requested `num_sections`, divide the topic into distinct, logical themes or sub-topics. Each theme will become a section. Ensure these themes collectively address the `user_query` comprehensively.
3. **Develop Sections:** For *each* of the `num_sections`:
* **Assign `section_id`:** Start with 0 and increment sequentially for each section.
* **Craft `section_title`:** Write a concise, descriptive title that clearly defines the scope and focus of the section's theme.
* **Formulate Research `questions`:** Generate 2 to 5 specific, targeted research questions for this section. These questions must:
* Directly relate to the `section_title` and explore its key aspects.
* Be answerable through focused research (e.g., searching documents, databases, or knowledge bases).
* Be distinct from each other and from questions in other sections. Avoid redundancy.
* Collectively guide the gathering of information needed to fully address the section's theme.
4. **Ensure Logical Flow:** Arrange the sections in a coherent and intuitive sequence. Consider structures like:
* General background -> Specific details -> Analysis/Comparison -> Applications/Implications
* Problem definition -> Proposed solutions -> Evaluation -> Conclusion
* Chronological progression
5. **Verify Completeness and Cohesion:** Review the entire outline (`section_titles` and `questions`) to confirm that:
* All sections together provide a complete and well-structured answer to the original `user_query`.
* There are no significant overlaps or gaps in coverage between sections.
6. **Adhere Strictly to Output Format:** Ensure the final output is a valid JSON object matching the specified structure exactly, including correct field names (`answer_outline`, `section_id`, `section_title`, `questions`) and data types.
</instructions>
<examples>
User Query: "What are the health benefits of meditation?"
Number of Sections: 3
{{
"answer_outline": [
{{
"section_id": 0,
"section_title": "Physical Health Benefits of Meditation",
"questions": [
"What physiological changes occur in the body during meditation?",
"How does regular meditation affect blood pressure and heart health?",
"What impact does meditation have on inflammation and immune function?",
"Can meditation help with pain management, and if so, how?"
]
}},
{{
"section_id": 1,
"section_title": "Mental Health Benefits of Meditation",
"questions": [
"How does meditation affect stress and anxiety levels?",
"What changes in brain structure or function have been observed in meditation practitioners?",
"Can meditation help with depression and mood disorders?",
"What is the relationship between meditation and cognitive function?"
]
}},
{{
"section_id": 2,
"section_title": "Best Meditation Practices for Maximum Benefits",
"questions": [
"What are the most effective meditation techniques for beginners?",
"How long and how frequently should one meditate to see benefits?",
"Are there specific meditation approaches best suited for particular health goals?",
"What common obstacles prevent people from experiencing meditation benefits?"
]
}}
]
}}
</examples>
</answer_outline_system>
"""

View file: app/agents/researcher/state.py

@@ -0,0 +1,30 @@
"""Define the state structures for the agent."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import List, Optional, Any, Dict, Annotated
from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine
from langchain_core.messages import BaseMessage, HumanMessage
from pydantic import BaseModel
@dataclass
class State:
"""Defines the dynamic state for the agent during execution.
This state tracks the database session and the outputs generated by the agent's nodes.
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
for more information.
"""
# Runtime context (not part of actual graph state)
db_session: AsyncSession
engine: Optional[AsyncEngine] = None
# Intermediate state - populated during workflow
# Using field to explicitly mark as part of state
answer_outline: Optional[Any] = field(default=None)
# OUTPUT: Populated by agent nodes
# Using field to explicitly mark as part of state
final_written_report: Optional[str] = field(default=None)
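
Because this dataclass is the graph's state schema, nodes return partial updates rather than whole State objects, and LangGraph merges the returned keys into the state; for example:

    # Illustrative node: only the returned key is merged back into State.
    async def example_node(state: State, config) -> dict:
        return {"final_written_report": "draft"}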

View file: app/agents/researcher/sub_section_writer/configuration.py

@@ -3,7 +3,7 @@
 from __future__ import annotations
 from dataclasses import dataclass, fields
-from typing import Optional, List
+from typing import Optional, List, Any
 from langchain_core.runnables import RunnableConfig
@@ -14,11 +14,10 @@ class Configuration:
     # Input parameters provided at invocation
     sub_section_title: str
-    sub_questions: List[str]
-    connectors_to_search: List[str]
+    sub_section_questions: List[str]
+    relevant_documents: List[Any]  # Documents provided directly to the agent
     user_id: str
     search_space_id: int
-    top_k: int = 20  # Default top_k value

     @classmethod

View file: app/agents/researcher/sub_section_writer/graph.py

@@ -1,20 +1,18 @@
 from langgraph.graph import StateGraph
 from .state import State
-from .nodes import fetch_relevant_documents, write_sub_section
+from .nodes import write_sub_section, rerank_documents
 from .configuration import Configuration

 # Define a new graph
 workflow = StateGraph(State, config_schema=Configuration)

 # Add the nodes to the graph
-workflow.add_node("fetch_relevant_documents", fetch_relevant_documents)
+workflow.add_node("rerank_documents", rerank_documents)
 workflow.add_node("write_sub_section", write_sub_section)

-# Entry point
-workflow.add_edge("__start__", "fetch_relevant_documents")
-# Connect fetch_relevant_documents to write_sub_section
-workflow.add_edge("fetch_relevant_documents", "write_sub_section")
-# Exit point
+# Connect the nodes
+workflow.add_edge("__start__", "rerank_documents")
+workflow.add_edge("rerank_documents", "write_sub_section")
 workflow.add_edge("write_sub_section", "__end__")

 # Compile the workflow into an executable graph

View file: app/agents/researcher/sub_section_writer/nodes.py

@@ -1,172 +1,80 @@
 from .configuration import Configuration
 from langchain_core.runnables import RunnableConfig
 from .state import State
-from typing import Any, Dict
-from app.utils.connector_service import ConnectorService
-from app.utils.reranker_service import RerankerService
+from typing import Any, Dict, List
 from app.config import config as app_config
 from .prompts import citation_system_prompt
 from langchain_core.messages import HumanMessage, SystemMessage
-async def fetch_relevant_documents(state: State, config: RunnableConfig) -> Dict[str, Any]:
+async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, Any]:
     """
-    Fetch relevant documents for the sub-section using specified connectors.
+    Rerank the documents based on relevance to the sub-section title.

-    This node retrieves documents from various data sources based on the sub-questions
-    derived from the sub-section title. It searches across all selected connectors
-    (YouTube, Extension, Crawled URLs, Files, Tavily API, Slack, Notion) and reranks
-    the results to provide the most relevant information for the agent workflow.
+    This node takes the relevant documents provided in the configuration,
+    reranks them using the reranker service based on the sub-section title,
+    and updates the state with the reranked documents.

     Returns:
-        Dict containing the reranked documents in the "relevant_documents_fetched" key.
+        Dict containing the reranked documents.
     """
-    # Get configuration
+    # Get configuration and relevant documents
     configuration = Configuration.from_runnable_config(config)
-    # Extract state parameters
-    db_session = state.db_session
-    # Extract config parameters
-    user_id = configuration.user_id
-    search_space_id = configuration.search_space_id
-    TOP_K = configuration.top_k
-    # Initialize services
-    connector_service = ConnectorService(db_session)
-    reranker_service = RerankerService.get_reranker_instance(app_config)
+    documents = configuration.relevant_documents
+    sub_section_questions = configuration.sub_section_questions

-    all_raw_documents = []  # Store all raw documents before reranking
+    # If no documents were provided, return empty list
+    if not documents or len(documents) == 0:
+        return {
+            "reranked_documents": []
+        }
-    for user_query in configuration.sub_questions:
-        # Reformulate query (optional, consider if needed for each sub-question)
-        # reformulated_query = await QueryService.reformulate_query(user_query)
-        reformulated_query = user_query  # Using original sub-question for now
-        # Process each selected connector
-        for connector in configuration.connectors_to_search:
-            if connector == "YOUTUBE_VIDEO":
-                _, youtube_chunks = await connector_service.search_youtube(
-                    user_query=reformulated_query,
-                    user_id=user_id,
-                    search_space_id=search_space_id,
-                    top_k=TOP_K
-                )
-                all_raw_documents.extend(youtube_chunks)
-            elif connector == "EXTENSION":
-                _, extension_chunks = await connector_service.search_extension(
-                    user_query=reformulated_query,
-                    user_id=user_id,
-                    search_space_id=search_space_id,
-                    top_k=TOP_K
-                )
-                all_raw_documents.extend(extension_chunks)
-            elif connector == "CRAWLED_URL":
-                _, crawled_urls_chunks = await connector_service.search_crawled_urls(
-                    user_query=reformulated_query,
-                    user_id=user_id,
-                    search_space_id=search_space_id,
-                    top_k=TOP_K
-                )
-                all_raw_documents.extend(crawled_urls_chunks)
-            elif connector == "FILE":
-                _, files_chunks = await connector_service.search_files(
-                    user_query=reformulated_query,
-                    user_id=user_id,
-                    search_space_id=search_space_id,
-                    top_k=TOP_K
-                )
-                all_raw_documents.extend(files_chunks)
-            elif connector == "TAVILY_API":
-                _, tavily_chunks = await connector_service.search_tavily(
-                    user_query=reformulated_query,
-                    user_id=user_id,
-                    top_k=TOP_K
-                )
-                all_raw_documents.extend(tavily_chunks)
-            elif connector == "SLACK_CONNECTOR":
-                _, slack_chunks = await connector_service.search_slack(
-                    user_query=reformulated_query,
-                    user_id=user_id,
-                    search_space_id=search_space_id,
-                    top_k=TOP_K
-                )
-                all_raw_documents.extend(slack_chunks)
-            elif connector == "NOTION_CONNECTOR":
-                _, notion_chunks = await connector_service.search_notion(
-                    user_query=reformulated_query,
-                    user_id=user_id,
-                    search_space_id=search_space_id,
-                    top_k=TOP_K
-                )
-                all_raw_documents.extend(notion_chunks)
+    # Get reranker service from app config
+    reranker_service = getattr(app_config, "reranker_service", None)
+    # If we have documents and a reranker is available, rerank them
-    # Deduplicate documents based on chunk_id or content to avoid processing duplicates
-    seen_chunk_ids = set()
-    seen_content_hashes = set()
-    deduplicated_docs = []
+    # Use documents as is if no reranker service is available
+    reranked_docs = documents
-    for doc in all_raw_documents:
-        chunk_id = doc.get("chunk_id")
-        content = doc.get("content", "")
-        content_hash = hash(content)
-        # Skip if we've seen this chunk_id or content before
-        if (chunk_id and chunk_id in seen_chunk_ids) or content_hash in seen_content_hashes:
-            continue
+    if reranker_service:
+        try:
+            # Use the sub-section questions for reranking context
+            rerank_query = "\n".join(sub_section_questions)
-        # Add to our tracking sets and keep this document
-        if chunk_id:
-            seen_chunk_ids.add(chunk_id)
-        seen_content_hashes.add(content_hash)
-        deduplicated_docs.append(doc)
+            # Convert documents to format expected by reranker if needed
+            reranker_input_docs = [
+                {
+                    "chunk_id": doc.get("chunk_id", f"chunk_{i}"),
+                    "content": doc.get("content", ""),
+                    "score": doc.get("score", 0.0),
+                    "document": {
+                        "id": doc.get("document", {}).get("id", ""),
+                        "title": doc.get("document", {}).get("title", ""),
+                        "document_type": doc.get("document", {}).get("document_type", ""),
+                        "metadata": doc.get("document", {}).get("metadata", {})
+                    }
+                } for i, doc in enumerate(documents)
+            ]
+            # Rerank documents using the section title
+            reranked_docs = reranker_service.rerank_documents(rerank_query, reranker_input_docs)
+            # Sort by score in descending order
+            reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
+            print(f"Reranked {len(reranked_docs)} documents for section: {configuration.sub_section_title}")
+        except Exception as e:
+            print(f"Error during reranking: {str(e)}")
+            # Use original docs if reranking fails
-    # Use deduplicated documents for reranking
-    reranked_docs = deduplicated_docs
-    if deduplicated_docs and reranker_service:
-        # Use the main sub_section_title for reranking context
-        rerank_query = configuration.sub_section_title
-        # Convert documents to format expected by reranker
-        reranker_input_docs = [
-            {
-                "chunk_id": doc.get("chunk_id", f"chunk_{i}"),
-                "content": doc.get("content", ""),
-                "score": doc.get("score", 0.0),
-                "document": {
-                    "id": doc.get("document", {}).get("id", ""),
-                    "title": doc.get("document", {}).get("title", ""),
-                    "document_type": doc.get("document", {}).get("document_type", ""),
-                    "metadata": doc.get("document", {}).get("metadata", {})
-                }
-            } for i, doc in enumerate(deduplicated_docs)
-        ]
-        # Rerank documents using the main title query
-        reranked_docs = reranker_service.rerank_documents(rerank_query, reranker_input_docs)
-        # Sort by score in descending order
-        reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
-    # Update state with fetched documents
     return {
-        "relevant_documents_fetched": reranked_docs
+        "reranked_documents": reranked_docs
     }
 async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, Any]:
     """
-    Write the sub-section using the fetched documents.
+    Write the sub-section using the provided documents.

-    This node takes the relevant documents fetched in the previous node and uses
-    an LLM to generate a comprehensive answer to the sub-section questions with
+    This node takes the relevant documents provided in the configuration and uses
+    an LLM to generate a comprehensive answer to the sub-section title with
     proper citations. The citations follow IEEE format using source IDs from the
     documents.
@@ -174,17 +82,17 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, Any]:
         Dict containing the final answer in the "final_answer" key.
     """
-    # Get configuration and relevant documents
+    # Get configuration and relevant documents from configuration
     configuration = Configuration.from_runnable_config(config)
-    documents = state.relevant_documents_fetched
+    documents = configuration.relevant_documents

     # Initialize LLM
     llm = app_config.fast_llm_instance

-    # If no documents were found, return a message indicating this
+    # If no documents were provided, return a message indicating this
     if not documents or len(documents) == 0:
         return {
-            "final_answer": "No relevant documents were found to answer this question. Please try refining your search or providing more specific questions."
+            "final_answer": "No relevant documents were provided to answer this question. Please provide documents or try a different approach."
         }

     # Prepare documents for citation formatting
@@ -208,18 +116,25 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, Any]:
         """
         formatted_documents.append(formatted_doc)
-    # Create the query that combines the section title and questions
-    # section_title = configuration.sub_section_title
-    questions = "\n".join([f"- {q}" for q in configuration.sub_questions])
+    # Create the query that uses the section title and questions
+    section_title = configuration.sub_section_title
+    sub_section_questions = configuration.sub_section_questions
+    documents_text = "\n".join(formatted_documents)

+    # Format the questions as bullet points for clarity
+    questions_text = "\n".join([f"- {question}" for question in sub_section_questions])

+    # Construct a clear, structured query for the LLM
     human_message_content = f"""
     Please write a comprehensive answer for the title:
-    Address the following questions:
+
+    <title>
+    {section_title}
+    </title>
+
+    Focus on answering these specific questions related to the title:
     <questions>
-    {questions}
+    {questions_text}
     </questions>

     Use the provided documents as your source material and cite them properly using the IEEE citation format [X] where X is the source_id.
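
For orientation, the parent researcher graph invokes this subgraph with a config shaped like the following (mirroring process_section_with_documents above; the values are illustrative):

    config = {
        "configurable": {
            "sub_section_title": "Physical Health Benefits of Meditation",
            "sub_section_questions": ["How does meditation affect blood pressure?"],
            "relevant_documents": [{"chunk_id": "c1", "content": "...", "score": 0.9, "document": {}}],
            "user_id": "user-123",
            "search_space_id": 1,
        }
    }
    result = await sub_section_writer_graph.ainvoke({"db_session": db_session}, config)
    print(result.get("final_answer"))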

View file: app/agents/researcher/sub_section_writer/prompts.py

@@ -1,4 +1,8 @@
+import datetime
+
 citation_system_prompt = f"""
+Today's date: {datetime.datetime.now().strftime("%Y-%m-%d")}
+
 You are a research assistant tasked with analyzing documents and providing comprehensive answers with proper citations in IEEE format.

 <instructions>

View file: app/agents/researcher/sub_section_writer/state.py

@@ -18,6 +18,6 @@ class State:
     db_session: AsyncSession

     # OUTPUT: Populated by agent nodes
-    relevant_documents_fetched: Optional[List[Any]] = None
+    reranked_documents: Optional[List[Any]] = None
     final_answer: Optional[str] = None

View file

@@ -0,0 +1,132 @@
#!/usr/bin/env python3
"""
Test script for the Researcher LangGraph agent.
This script demonstrates how to invoke the researcher agent with a sample query.
Run this script directly from VSCode using the "Run Python File" button or
right-click and select "Run Python File in Terminal".
Before running:
1. Make sure your Python environment has all required dependencies
2. Create a .env file with any required API keys
3. Ensure database connection is properly configured
"""
import asyncio
import os
import sys
from pathlib import Path
# Add project root to Python path so that 'app' can be found as a module
# Get the absolute path to the surfsense_backend directory which contains the app module
project_root = str(Path(__file__).resolve().parents[3]) # Go up 3 levels from the script: app/agents/researcher -> app/agents -> app -> surfsense_backend
print(f"Adding to Python path: {project_root}")
sys.path.insert(0, project_root)
# Now import the modules after fixing the path
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from dotenv import load_dotenv
# These imports should now work with the correct path
from app.agents.researcher.graph import graph
from app.agents.researcher.state import State
from app.agents.researcher.nodes import write_answer_outline, process_sections
# Load environment variables
load_dotenv()
# Database connection string - use a test database or mock
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense"
# Create async engine and session
engine = create_async_engine(DATABASE_URL)
async_session_maker = sessionmaker(
engine, class_=AsyncSession, expire_on_commit=False
)
async def run_test():
"""Run a test of the researcher agent."""
print("Starting researcher agent test...")
# Create a database session
async with async_session_maker() as db_session:
# Sample configuration
config = {
"configurable": {
"user_query": "What are the best clash royale decks recommended by Surgical Goblin?",
"num_sections": 1,
"connectors_to_search": [
"YOUTUBE_VIDEO",
],
"user_id": "d6ac2187-7407-4664-8734-af09926d161e",
"search_space_id": 2
}
}
try:
# Initialize state with database session and engine
initial_state = State(db_session=db_session, engine=engine)
# Instead of using the graph directly, let's run the nodes manually
# to track the state transitions explicitly
print("\nSTEP 1: Running write_answer_outline node...")
outline_result = await write_answer_outline(initial_state, config)
# Update the state with the outline
if "answer_outline" in outline_result:
initial_state.answer_outline = outline_result["answer_outline"]
print(f"Generated answer outline with {len(initial_state.answer_outline.answer_outline)} sections")
# Print the outline
print("\nGenerated Answer Outline:")
for section in initial_state.answer_outline.answer_outline:
print(f"\nSection {section.section_id}: {section.section_title}")
print("Research Questions:")
for q in section.questions:
print(f" - {q}")
# Run the second node with the updated state
print("\nSTEP 2: Running process_sections node...")
sections_result = await process_sections(initial_state, config)
# Check if we got a final report
if "final_written_report" in sections_result:
final_report = sections_result["final_written_report"]
print("\nFinal Research Report generated successfully!")
print(f"Report length: {len(final_report)} characters")
# Display the final report
print("\n==== FINAL RESEARCH REPORT ====\n")
print(final_report)
else:
print("\nNo final report was generated.")
print(f"Result keys: {list(sections_result.keys())}")
return sections_result
except Exception as e:
print(f"Error running researcher agent: {str(e)}")
import traceback
traceback.print_exc()
raise
async def main():
"""Main entry point for the test script."""
try:
result = await run_test()
print("\nTest completed successfully.")
return result
except Exception as e:
print(f"\nTest failed with error: {e}")
import traceback
traceback.print_exc()
return None
if __name__ == "__main__":
# Run the async test
result = asyncio.run(main())
# Keep terminal open if run directly in VSCode
if 'VSCODE_PID' in os.environ:
input("\nPress Enter to close this window...")

View file: app/config.py

@@ -42,10 +42,8 @@ class Config:
# GPT Researcher
FAST_LLM = os.getenv("FAST_LLM")
SMART_LLM = os.getenv("SMART_LLM")
STRATEGIC_LLM = os.getenv("STRATEGIC_LLM")
fast_llm_instance = ChatLiteLLM(model=extract_model_name(FAST_LLM))
smart_llm_instance = ChatLiteLLM(model=extract_model_name(SMART_LLM))
strategic_llm_instance = ChatLiteLLM(model=extract_model_name(STRATEGIC_LLM))
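
The extract_model_name helper is not shown in this hunk; as a hedged sketch, assuming GPT Researcher-style "provider:model" env values (e.g. FAST_LLM=openai:gpt-4o-mini), such a helper might look like:

    # Hypothetical helper, for illustration only; the real implementation is not shown here.
    def extract_model_name(llm_string: str | None) -> str | None:
        if not llm_string:
            return llm_string
        # "openai:gpt-4o-mini" -> "gpt-4o-mini"; pass through if there is no prefix
        return llm_string.split(":", 1)[-1]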