mirror of
https://github.com/MODSetter/SurfSense.git
synced 2025-09-05 03:59:06 +00:00
refactor: remove process_section function and streamline test workflow
This commit is contained in:
parent
154c5748fd
commit
7be68ebf41
2 changed files with 14 additions and 85 deletions
|
@ -232,71 +232,7 @@ async def fetch_relevant_documents(
|
||||||
|
|
||||||
return deduplicated_docs
|
return deduplicated_docs
|
||||||
|
|
||||||
async def process_section(
    section_title: str,
    user_id: str,
    search_space_id: int,
    session_maker,
    research_questions: List[str],
    connectors_to_search: List[str]
) -> str:
    """
    Write one report section by delegating to the sub_section_writer graph.

    A dedicated database session is opened for the section, documents are
    fetched for its research questions, and the sub-section writer graph is
    invoked on them. Exceptions are not propagated: they are printed and
    reported through the returned string instead.

    Args:
        section_title: Title of the section being written
        user_id: ID of the user, forwarded to the document fetch and the graph config
        search_space_id: ID of the search space, forwarded the same way
        session_maker: Callable returning an async context manager that yields
            a database session
        research_questions: Questions used to fetch documents for this section
        connectors_to_search: Connectors passed through to the document fetch

    Returns:
        The "final_answer" entry of the graph result, a placeholder sentence
        when that entry is absent, or an error description if anything raised.
    """
    try:
        # Each section gets its own session, which stays open for the graph call
        async with session_maker() as db_session:
            docs = await fetch_relevant_documents(
                research_questions=research_questions,
                user_id=user_id,
                search_space_id=search_space_id,
                db_session=db_session,
                connectors_to_search=connectors_to_search
            )

            # With nothing retrieved, hand the writer one placeholder per question
            if not docs:
                print(f"No relevant documents found for section: {section_title}")
                docs = [
                    {"content": f"No specific information was found for: {question}"}
                    for question in research_questions
                ]

            # Per-section settings travel in the "configurable" part of the config
            writer_config = {
                "configurable": {
                    "sub_section_title": section_title,
                    "relevant_documents": docs,
                    "user_id": user_id,
                    "search_space_id": search_space_id
                }
            }

            # The graph's initial state carries only the database session
            initial_state = {"db_session": db_session}

            print(f"Invoking sub_section_writer for: {section_title}")
            output = await sub_section_writer_graph.ainvoke(initial_state, writer_config)

            return output.get("final_answer", "No content was generated for this section.")
    except Exception as e:
        # Best-effort: a failing section yields an error message as its content
        print(f"Error processing section '{section_title}': {str(e)}")
        return f"Error processing section: {section_title}. Details: {str(e)}"
|
|
||||||
|
|
||||||
async def process_sections(state: State, config: RunnableConfig) -> Dict[str, Any]:
|
async def process_sections(state: State, config: RunnableConfig) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
|
@ -395,8 +331,7 @@ async def process_sections(state: State, config: RunnableConfig) -> Dict[str, An
|
||||||
# Combine the results into a final report with section titles
|
# Combine the results into a final report with section titles
|
||||||
final_report = []
|
final_report = []
|
||||||
for i, (section, content) in enumerate(zip(answer_outline.answer_outline, processed_results)):
|
for i, (section, content) in enumerate(zip(answer_outline.answer_outline, processed_results)):
|
||||||
section_header = f"## {section.section_title}"
|
# Skip adding the section header since the content already contains the title
|
||||||
final_report.append(section_header)
|
|
||||||
final_report.append(content)
|
final_report.append(content)
|
||||||
final_report.append("\n") # Add spacing between sections
|
final_report.append("\n") # Add spacing between sections
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ from dotenv import load_dotenv
|
||||||
# These imports should now work with the correct path
|
# These imports should now work with the correct path
|
||||||
from app.agents.researcher.graph import graph
|
from app.agents.researcher.graph import graph
|
||||||
from app.agents.researcher.state import State
|
from app.agents.researcher.state import State
|
||||||
from app.agents.researcher.nodes import write_answer_outline, process_sections
|
|
||||||
|
|
||||||
# Load environment variables
|
# Load environment variables
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
@ -68,31 +68,25 @@ async def run_test():
|
||||||
# Initialize state with database session and engine
|
# Initialize state with database session and engine
|
||||||
initial_state = State(db_session=db_session, engine=engine)
|
initial_state = State(db_session=db_session, engine=engine)
|
||||||
|
|
||||||
# Instead of using the graph directly, let's run the nodes manually
|
# Run the graph directly
|
||||||
# to track the state transitions explicitly
|
print("\nRunning the complete researcher workflow...")
|
||||||
print("\nSTEP 1: Running write_answer_outline node...")
|
result = await graph.ainvoke(initial_state, config)
|
||||||
outline_result = await write_answer_outline(initial_state, config)
|
|
||||||
|
|
||||||
# Update the state with the outline
|
# Extract the answer outline for display
|
||||||
if "answer_outline" in outline_result:
|
if "answer_outline" in result and result["answer_outline"]:
|
||||||
initial_state.answer_outline = outline_result["answer_outline"]
|
print(f"\nGenerated answer outline with {len(result['answer_outline'].answer_outline)} sections")
|
||||||
print(f"Generated answer outline with {len(initial_state.answer_outline.answer_outline)} sections")
|
|
||||||
|
|
||||||
# Print the outline
|
# Print the outline
|
||||||
print("\nGenerated Answer Outline:")
|
print("\nGenerated Answer Outline:")
|
||||||
for section in initial_state.answer_outline.answer_outline:
|
for section in result["answer_outline"].answer_outline:
|
||||||
print(f"\nSection {section.section_id}: {section.section_title}")
|
print(f"\nSection {section.section_id}: {section.section_title}")
|
||||||
print("Research Questions:")
|
print("Research Questions:")
|
||||||
for q in section.questions:
|
for q in section.questions:
|
||||||
print(f" - {q}")
|
print(f" - {q}")
|
||||||
|
|
||||||
# Run the second node with the updated state
|
|
||||||
print("\nSTEP 2: Running process_sections node...")
|
|
||||||
sections_result = await process_sections(initial_state, config)
|
|
||||||
|
|
||||||
# Check if we got a final report
|
# Check if we got a final report
|
||||||
if "final_written_report" in sections_result:
|
if "final_written_report" in result and result["final_written_report"]:
|
||||||
final_report = sections_result["final_written_report"]
|
final_report = result["final_written_report"]
|
||||||
print("\nFinal Research Report generated successfully!")
|
print("\nFinal Research Report generated successfully!")
|
||||||
print(f"Report length: {len(final_report)} characters")
|
print(f"Report length: {len(final_report)} characters")
|
||||||
|
|
||||||
|
@ -101,9 +95,9 @@ async def run_test():
|
||||||
print(final_report)
|
print(final_report)
|
||||||
else:
|
else:
|
||||||
print("\nNo final report was generated.")
|
print("\nNo final report was generated.")
|
||||||
print(f"Result keys: {list(sections_result.keys())}")
|
print(f"Available result keys: {list(result.keys())}")
|
||||||
|
|
||||||
return sections_result
|
return result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error running researcher agent: {str(e)}")
|
print(f"Error running researcher agent: {str(e)}")
|
||||||
|
|
Loading…
Add table
Reference in a new issue