diff --git a/README.md b/README.md
index 4b9ea5a..aa337c7 100644
--- a/README.md
+++ b/README.md
@@ -120,7 +120,6 @@ This is the core of SurfSense. Before we begin let's look at `.env` variables' t
 | RERANKERS_MODEL_NAME| Name of the reranker model for search result reranking. Eg. `ms-marco-MiniLM-L-12-v2`|
 | RERANKERS_MODEL_TYPE| Type of reranker model being used. Eg. `flashrank`|
 | FAST_LLM| LiteLLM routed Smaller, faster LLM for quick responses. Eg. `litellm:openai/gpt-4o`|
-| SMART_LLM| LiteLLM routed Balanced LLM for general use. Eg. `litellm:openai/gpt-4o`|
 | STRATEGIC_LLM| LiteLLM routed Advanced LLM for complex reasoning tasks. Eg. `litellm:openai/gpt-4o`|
 | LONG_CONTEXT_LLM| LiteLLM routed LLM capable of handling longer context windows. Eg. `litellm:gemini/gemini-2.0-flash`|
 | UNSTRUCTURED_API_KEY| API key for Unstructured.io service for document parsing|
@@ -221,15 +220,15 @@ After filling in your SurfSense API key you should be able to use extension now.
 - **Alembic**: A database migrations tool for SQLAlchemy.
 - **FastAPI Users**: Authentication and user management with JWT and OAuth support
-
-- **LangChain**: Framework for developing AI-powered applications
-- **GPT Integration**: Integration with LLM models through LiteLLM
+- **LangGraph**: Framework for developing AI agents.
+
+- **LangChain**: Framework for developing AI-powered applications.
+
+- **LLM Integration**: Integration with LLM models through LiteLLM
 - **Rerankers**: Advanced result ranking for improved search relevance
-- **GPT-Researcher**: Advanced research capabilities
-
 - **Hybrid Search**: Combines vector similarity and full-text search for optimal results using Reciprocal Rank Fusion (RRF)
 - **Vector Embeddings**: Document and text embeddings for semantic search
diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example
index 380296e..bc3f5ce 100644
--- a/surfsense_backend/.env.example
+++ b/surfsense_backend/.env.example
@@ -10,9 +10,8 @@ RERANKERS_MODEL_NAME="ms-marco-MiniLM-L-12-v2"
 RERANKERS_MODEL_TYPE="flashrank"
 
 FAST_LLM="litellm:openai/gpt-4o-mini"
-SMART_LLM="litellm:openai/gpt-4o-mini"
-STRATEGIC_LLM="litellm:openai/gpt-4o-mini"
-LONG_CONTEXT_LLM="litellm:gemini/gemini-2.0-flash-thinking-exp-01-21"
+STRATEGIC_LLM="litellm:openai/gpt-4o"
+LONG_CONTEXT_LLM="litellm:gemini/gemini-2.0-flash"
 
 OPENAI_API_KEY="sk-proj-iA"
 GEMINI_API_KEY="AIzaSyB6-1641124124124124124124124124124"
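The README bullet above credits result merging to Reciprocal Rank Fusion. A minimal sketch of RRF as usually formulated; the function name and the constant k=60 are illustrative assumptions (k=60 is the conventional default), not SurfSense's actual implementation:

```python
# Hedged sketch of Reciprocal Rank Fusion (RRF): each ranked list contributes
# 1/(k + rank) to a document's fused score; higher fused score = better.
from collections import defaultdict

def rrf_merge(rankings: list[list[str]], k: int = 60) -> list[str]:
    """Fuse several ranked lists of document ids into a single ranking."""
    scores: dict[str, float] = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)

# Example: fuse a vector-similarity ranking with a full-text-search ranking
print(rrf_merge([["d1", "d2", "d3"], ["d3", "d1", "d4"]]))
# ['d1', 'd3', 'd2', 'd4']
```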
diff --git a/surfsense_backend/app/agents/researcher/nodes.py b/surfsense_backend/app/agents/researcher/nodes.py
index c88ad33..15935f2 100644
--- a/surfsense_backend/app/agents/researcher/nodes.py
+++ b/surfsense_backend/app/agents/researcher/nodes.py
@@ -1,17 +1,23 @@
-from .configuration import Configuration
-from langchain_core.runnables import RunnableConfig
-from .state import State
-from typing import Any, Dict, List
-from app.config import config as app_config
-from .prompts import get_answer_outline_system_prompt
-from langchain_core.messages import HumanMessage, SystemMessage
-from pydantic import BaseModel, Field
-import json
 import asyncio
-from .sub_section_writer.graph import graph as sub_section_writer_graph
+import json
+from typing import Any, Dict, List
+
+from app.config import config as app_config
+from app.db import async_session_maker
 from app.utils.connector_service import ConnectorService
+from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.runnables import RunnableConfig
+from pydantic import BaseModel, Field
 from sqlalchemy.ext.asyncio import AsyncSession
+
+from .configuration import Configuration
+from .prompts import get_answer_outline_system_prompt
+from .state import State
+from .sub_section_writer.graph import graph as sub_section_writer_graph
+
+from langgraph.types import StreamWriter
+
+
 class Section(BaseModel):
     """A section in the answer outline."""
     section_id: int = Field(..., description="The zero-based index of the section")
@@ -22,7 +28,7 @@ class AnswerOutline(BaseModel):
     """The complete answer outline with all sections."""
     answer_outline: List[Section] = Field(..., description="List of sections in the answer outline")
 
-async def write_answer_outline(state: State, config: RunnableConfig) -> Dict[str, Any]:
+async def write_answer_outline(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]:
     """
     Create a structured answer outline based on the user query.
@@ -33,12 +39,18 @@ async def write_answer_outline(state: State, config: RunnableConfig) -> Dict[str
     Returns:
         Dict containing the answer outline in the "answer_outline" key for state update.
     """
+    streaming_service = state.streaming_service
+    streaming_service.only_update_terminal("Generating answer outline...")
+    writer({"yeild_value": streaming_service._format_annotations()})
+
     # Get configuration from runnable config
     configuration = Configuration.from_runnable_config(config)
     user_query = configuration.user_query
     num_sections = configuration.num_sections
 
+    streaming_service.only_update_terminal(f"Planning research approach for query: {user_query[:100]}...")
+    writer({"yeild_value": streaming_service._format_annotations()})
+
     # Initialize LLM
     llm = app_config.strategic_llm_instance
@@ -66,6 +78,9 @@
     Your output MUST be valid JSON in exactly this format. Do not include any other text or explanation.
     """
 
+    streaming_service.only_update_terminal("Designing structured outline with AI...")
+    writer({"yeild_value": streaming_service._format_annotations()})
+
     # Create messages for the LLM
     messages = [
         SystemMessage(content=get_answer_outline_system_prompt()),
@@ -73,6 +88,9 @@
     ]
 
     # Call the LLM directly without using structured output
+    streaming_service.only_update_terminal("Processing answer structure...")
+    writer({"yeild_value": streaming_service._format_annotations()})
+
     response = await llm.ainvoke(messages)
 
     # Parse the JSON response manually
@@ -92,16 +110,27 @@
             # Convert to Pydantic model
             answer_outline = AnswerOutline(**parsed_data)
 
+            total_questions = sum(len(section.questions) for section in answer_outline.answer_outline)
+            streaming_service.only_update_terminal(f"Successfully generated outline with {len(answer_outline.answer_outline)} sections and {total_questions} research questions")
+            writer({"yeild_value": streaming_service._format_annotations()})
+
             print(f"Successfully generated answer outline with {len(answer_outline.answer_outline)} sections")
 
             # Return state update
             return {"answer_outline": answer_outline}
         else:
             # If JSON structure not found, raise a clear error
-            raise ValueError(f"Could not find valid JSON in LLM response. Raw response: {content}")
+            error_message = f"Could not find valid JSON in LLM response. Raw response: {content}"
+            streaming_service.only_update_terminal(error_message, "error")
+            writer({"yeild_value": streaming_service._format_annotations()})
+            raise ValueError(error_message)
     except (json.JSONDecodeError, ValueError) as e:
         # Log the error and re-raise it
+        error_message = f"Error parsing LLM response: {str(e)}"
+        streaming_service.only_update_terminal(error_message, "error")
+        writer({"yeild_value": streaming_service._format_annotations()})
+
         print(f"Error parsing LLM response: {str(e)}")
         print(f"Raw response: {response.content}")
         raise
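The hunk above parses the outline JSON out of a raw LLM reply instead of using structured output. The locating logic itself sits outside the diff context, so the following is a hedged sketch of the usual pattern (slice from the first `{` to the last `}`), not SurfSense's exact code:

```python
# Assumed sketch of the manual JSON extraction the surrounding code implies.
import json

def extract_json_block(content: str) -> dict:
    start = content.find("{")
    end = content.rfind("}")
    if start == -1 or end == -1 or end <= start:
        # Mirrors the ValueError raised in the hunk when no JSON is found
        raise ValueError(f"Could not find valid JSON in LLM response. Raw response: {content}")
    return json.loads(content[start:end + 1])  # may raise json.JSONDecodeError

# parsed_data = extract_json_block(response.content)
# answer_outline = AnswerOutline(**parsed_data)
```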
@@ -112,18 +141,21 @@ async def fetch_relevant_documents(
     search_space_id: int,
     db_session: AsyncSession,
     connectors_to_search: List[str],
-    top_k: int = 5
+    writer: StreamWriter = None,
+    state: State = None,
+    top_k: int = 20
 ) -> List[Dict[str, Any]]:
     """
     Fetch relevant documents for research questions using the provided connectors.
 
     Args:
-        section_title: The title of the section being researched
         research_questions: List of research questions to find documents for
         user_id: The user ID
         search_space_id: The search space ID
         db_session: The database session
         connectors_to_search: List of connectors to search
+        writer: StreamWriter for sending progress updates
+        state: The current state containing the streaming service
         top_k: Number of top results to retrieve per connector per question
 
     Returns:
@@ -131,83 +163,237 @@ async def fetch_relevant_documents(
     """
     # Initialize services
     connector_service = ConnectorService(db_session)
-
-    all_raw_documents = []  # Store all raw documents before reranking
-    for user_query in research_questions:
+    # Only use streaming if both writer and state are provided
+    streaming_service = state.streaming_service if state is not None else None
+
+    # Stream initial status update
+    if streaming_service and writer:
+        streaming_service.only_update_terminal(f"Starting research on {len(research_questions)} questions using {len(connectors_to_search)} connectors...")
+        writer({"yeild_value": streaming_service._format_annotations()})
+
+    all_raw_documents = []  # Store all raw documents
+    all_sources = []  # Store all sources
+
+    for i, user_query in enumerate(research_questions):
+        # Stream question being researched
+        if streaming_service and writer:
+            streaming_service.only_update_terminal(f"Researching question {i+1}/{len(research_questions)}: {user_query[:100]}...")
+            writer({"yeild_value": streaming_service._format_annotations()})
+
         # Use original research question as the query
         reformulated_query = user_query
 
         # Process each selected connector
         for connector in connectors_to_search:
+            # Stream connector being searched
+            if streaming_service and writer:
+                streaming_service.only_update_terminal(f"Searching {connector} for relevant information...")
+                writer({"yeild_value": streaming_service._format_annotations()})
+
             try:
                 if connector == "YOUTUBE_VIDEO":
-                    _, youtube_chunks = await connector_service.search_youtube(
+                    source_object, youtube_chunks = await connector_service.search_youtube(
                         user_query=reformulated_query,
                         user_id=user_id,
                         search_space_id=search_space_id,
                         top_k=top_k
                     )
+
+                    # Add to sources and raw documents
+                    if source_object:
+                        all_sources.append(source_object)
                     all_raw_documents.extend(youtube_chunks)
 
+                    # Stream found document count
+                    if streaming_service and writer:
+                        streaming_service.only_update_terminal(f"Found {len(youtube_chunks)} YouTube chunks relevant to the query")
+                        writer({"yeild_value": streaming_service._format_annotations()})
+
                 elif connector == "EXTENSION":
-                    _, extension_chunks = await connector_service.search_extension(
+                    source_object, extension_chunks = await connector_service.search_extension(
                         user_query=reformulated_query,
                         user_id=user_id,
                         search_space_id=search_space_id,
                         top_k=top_k
                     )
+
+                    # Add to sources and raw documents
+                    if source_object:
+                        all_sources.append(source_object)
                     all_raw_documents.extend(extension_chunks)
 
+                    # Stream found document count
+                    if streaming_service and writer:
+                        streaming_service.only_update_terminal(f"Found {len(extension_chunks)} extension chunks relevant to the query")
+                        writer({"yeild_value": streaming_service._format_annotations()})
+
                 elif connector == "CRAWLED_URL":
-                    _, crawled_urls_chunks = await connector_service.search_crawled_urls(
+                    source_object, crawled_urls_chunks = await connector_service.search_crawled_urls(
                         user_query=reformulated_query,
                         user_id=user_id,
                         search_space_id=search_space_id,
                         top_k=top_k
                     )
+
+                    # Add to sources and raw documents
+                    if source_object:
+                        all_sources.append(source_object)
                     all_raw_documents.extend(crawled_urls_chunks)
 
+                    # Stream found document count
+                    if streaming_service and writer:
+                        streaming_service.only_update_terminal(f"Found {len(crawled_urls_chunks)} crawled URL chunks relevant to the query")
+                        writer({"yeild_value": streaming_service._format_annotations()})
+
                 elif connector == "FILE":
-                    _, files_chunks = await connector_service.search_files(
+                    source_object, files_chunks = await connector_service.search_files(
                         user_query=reformulated_query,
                         user_id=user_id,
                         search_space_id=search_space_id,
                         top_k=top_k
                     )
+
+                    # Add to sources and raw documents
+                    if source_object:
+                        all_sources.append(source_object)
                     all_raw_documents.extend(files_chunks)
 
+                    # Stream found document count
+                    if streaming_service and writer:
+                        streaming_service.only_update_terminal(f"Found {len(files_chunks)} file chunks relevant to the query")
+                        writer({"yeild_value": streaming_service._format_annotations()})
+
                 elif connector == "TAVILY_API":
-                    _, tavily_chunks = await connector_service.search_tavily(
+                    source_object, tavily_chunks = await connector_service.search_tavily(
                         user_query=reformulated_query,
                         user_id=user_id,
                         top_k=top_k
                     )
+
+                    # Add to sources and raw documents
+                    if source_object:
+                        all_sources.append(source_object)
                     all_raw_documents.extend(tavily_chunks)
 
+                    # Stream found document count
+                    if streaming_service and writer:
+                        streaming_service.only_update_terminal(f"Found {len(tavily_chunks)} web search results relevant to the query")
+                        writer({"yeild_value": streaming_service._format_annotations()})
+
                 elif connector == "SLACK_CONNECTOR":
-                    _, slack_chunks = await connector_service.search_slack(
+                    source_object, slack_chunks = await connector_service.search_slack(
                         user_query=reformulated_query,
                         user_id=user_id,
                         search_space_id=search_space_id,
                         top_k=top_k
                     )
+
+                    # Add to sources and raw documents
+                    if source_object:
+                        all_sources.append(source_object)
                     all_raw_documents.extend(slack_chunks)
 
+                    # Stream found document count
+                    if streaming_service and writer:
+                        streaming_service.only_update_terminal(f"Found {len(slack_chunks)} Slack messages relevant to the query")
+                        writer({"yeild_value": streaming_service._format_annotations()})
+
                 elif connector == "NOTION_CONNECTOR":
-                    _, notion_chunks = await connector_service.search_notion(
+                    source_object, notion_chunks = await connector_service.search_notion(
                         user_query=reformulated_query,
                         user_id=user_id,
                         search_space_id=search_space_id,
                         top_k=top_k
                     )
+
+                    # Add to sources and raw documents
+                    if source_object:
+                        all_sources.append(source_object)
                     all_raw_documents.extend(notion_chunks)
 
+                    # Stream found document count
+                    if streaming_service and writer:
+                        streaming_service.only_update_terminal(f"Found {len(notion_chunks)} Notion pages/blocks relevant to the query")
+                        writer({"yeild_value": streaming_service._format_annotations()})
+
+                elif connector == "GITHUB_CONNECTOR":
+                    source_object, github_chunks = await connector_service.search_github(
+                        user_query=reformulated_query,
+                        user_id=user_id,
+                        search_space_id=search_space_id,
+                        top_k=top_k
+                    )
+
+                    # Add to sources and raw documents
+                    if source_object:
+                        all_sources.append(source_object)
+                    all_raw_documents.extend(github_chunks)
+
+                    # Stream found document count
+                    if streaming_service and writer:
+                        streaming_service.only_update_terminal(f"Found {len(github_chunks)} GitHub files/issues relevant to the query")
+                        writer({"yeild_value": streaming_service._format_annotations()})
+
+                elif connector == "LINEAR_CONNECTOR":
+                    source_object, linear_chunks = await connector_service.search_linear(
+                        user_query=reformulated_query,
+                        user_id=user_id,
+                        search_space_id=search_space_id,
+                        top_k=top_k
+                    )
+
+                    # Add to sources and raw documents
+                    if source_object:
+                        all_sources.append(source_object)
+                    all_raw_documents.extend(linear_chunks)
+
+                    # Stream found document count
+                    if streaming_service and writer:
+                        streaming_service.only_update_terminal(f"Found {len(linear_chunks)} Linear issues relevant to the query")
+                        writer({"yeild_value": streaming_service._format_annotations()})
             except Exception as e:
-                print(f"Error searching connector {connector}: {str(e)}")
+                error_message = f"Error searching connector {connector}: {str(e)}"
+                print(error_message)
+
+                # Stream error message
+                if streaming_service and writer:
+                    streaming_service.only_update_terminal(error_message, "error")
+                    writer({"yeild_value": streaming_service._format_annotations()})
+
                 # Continue with other connectors on error
                 continue
 
-    # Deduplicate documents based on chunk_id or content
+    # Deduplicate source objects by ID before streaming
+    deduplicated_sources = []
+    seen_source_keys = set()
+
+    for source_obj in all_sources:
+        # Use combination of source ID and type as a unique identifier
+        # This ensures we don't accidentally deduplicate sources from different connectors
+        source_id = source_obj.get('id')
+        source_type = source_obj.get('type')
+
+        if source_id and source_type:
+            source_key = f"{source_type}_{source_id}"
+            if source_key not in seen_source_keys:
+                seen_source_keys.add(source_key)
+                deduplicated_sources.append(source_obj)
+        else:
+            # If there's no ID or type, just add it to be safe
+            deduplicated_sources.append(source_obj)
+
+    # Stream info about deduplicated sources
+    if streaming_service and writer:
+        streaming_service.only_update_terminal(f"Collected {len(deduplicated_sources)} unique sources across all connectors")
+        writer({"yeild_value": streaming_service._format_annotations()})
+
+    # After all sources are collected and deduplicated, stream them
+    if streaming_service and writer:
+        streaming_service.only_update_sources(deduplicated_sources)
+        writer({"yeild_value": streaming_service._format_annotations()})
+
+    # Deduplicate raw documents based on chunk_id or content
     seen_chunk_ids = set()
     seen_content_hashes = set()
     deduplicated_docs = []
@@ -227,11 +413,15 @@
             seen_content_hashes.add(content_hash)
             deduplicated_docs.append(doc)
 
+    # Stream info about deduplicated documents
+    if streaming_service and writer:
+        streaming_service.only_update_terminal(f"Found {len(deduplicated_docs)} unique document chunks after deduplication")
+        writer({"yeild_value": streaming_service._format_annotations()})
+
+    # Return deduplicated documents
     return deduplicated_docs
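Every progress update in this function is wrapped in the same `if streaming_service and writer:` guard. A hypothetical helper, not present in the diff, that captures the repeated pattern:

```python
# Assumed convenience wrapper; SurfSense inlines this guard at each call site.
def _emit(streaming_service, writer, text: str, message_type: str = "info") -> None:
    """Send a terminal update only when streaming is wired up."""
    if streaming_service and writer:
        streaming_service.only_update_terminal(text, message_type)
        writer({"yeild_value": streaming_service._format_annotations()})

# _emit(streaming_service, writer, f"Searching {connector} for relevant information...")
```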
 
-
-
-async def process_sections(state: State, config: RunnableConfig) -> Dict[str, Any]:
+async def process_sections(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]:
     """
     Process all sections in parallel and combine the results.
@@ -245,89 +435,97 @@
     # Get configuration and answer outline from state
     configuration = Configuration.from_runnable_config(config)
     answer_outline = state.answer_outline
+    streaming_service = state.streaming_service
+
+    streaming_service.only_update_terminal("Starting to process research sections...")
+    writer({"yeild_value": streaming_service._format_annotations()})
 
     print(f"Processing sections from outline: {answer_outline is not None}")
 
     if not answer_outline:
+        streaming_service.only_update_terminal("Error: No answer outline was provided. Cannot generate report.", "error")
+        writer({"yeild_value": streaming_service._format_annotations()})
         return {
             "final_written_report": "No answer outline was provided. Cannot generate final report."
         }
 
-    # Create session maker from the engine or directly use the session
-    from sqlalchemy.ext.asyncio import AsyncSession
-    from sqlalchemy.orm import sessionmaker
-
-    # Use the engine if available, otherwise create a new session for each task
-    if state.engine:
-        session_maker = sessionmaker(
-            state.engine, class_=AsyncSession, expire_on_commit=False
-        )
-    else:
-        # Fallback to using the same session (less optimal but will work)
-        print("Warning: No engine available. Using same session for all tasks.")
-        # Create a mock session maker that returns the same session
-        async def mock_session_maker():
-            class ContextManager:
-                async def __aenter__(self):
-                    return state.db_session
-                async def __aexit__(self, exc_type, exc_val, exc_tb):
-                    pass
-            return ContextManager()
-        session_maker = mock_session_maker
-
     # Collect all questions from all sections
     all_questions = []
     for section in answer_outline.answer_outline:
         all_questions.extend(section.questions)
 
     print(f"Collected {len(all_questions)} questions from all sections")
+    streaming_service.only_update_terminal(f"Found {len(all_questions)} research questions across {len(answer_outline.answer_outline)} sections")
+    writer({"yeild_value": streaming_service._format_annotations()})
 
     # Fetch relevant documents once for all questions
+    streaming_service.only_update_terminal("Searching for relevant information across all connectors...")
+    writer({"yeild_value": streaming_service._format_annotations()})
+
     relevant_documents = []
-    async with session_maker() as db_session:
-
+    async with async_session_maker() as db_session:
         try:
             relevant_documents = await fetch_relevant_documents(
                 research_questions=all_questions,
                 user_id=configuration.user_id,
                 search_space_id=configuration.search_space_id,
                 db_session=db_session,
-                connectors_to_search=configuration.connectors_to_search
+                connectors_to_search=configuration.connectors_to_search,
+                writer=writer,
+                state=state
             )
         except Exception as e:
-            print(f"Error fetching relevant documents: {str(e)}")
+            error_message = f"Error fetching relevant documents: {str(e)}"
+            print(error_message)
+            streaming_service.only_update_terminal(error_message, "error")
+            writer({"yeild_value": streaming_service._format_annotations()})
            # Log the error and continue with an empty list of documents
            # This allows the process to continue, but the report might lack information
             relevant_documents = []
            # Consider adding more robust error handling or reporting if needed
 
     print(f"Fetched {len(relevant_documents)} relevant documents for all sections")
+    streaming_service.only_update_terminal(f"Starting to draft {len(answer_outline.answer_outline)} sections using {len(relevant_documents)} relevant document chunks")
+    writer({"yeild_value": streaming_service._format_annotations()})
 
     # Create tasks to process each section in parallel with the same document set
     section_tasks = []
+    streaming_service.only_update_terminal("Creating processing tasks for each section...")
+    writer({"yeild_value": streaming_service._format_annotations()})
+
     for section in answer_outline.answer_outline:
         section_tasks.append(
             process_section_with_documents(
                 section_title=section.section_title,
                 section_questions=section.questions,
+                user_query=configuration.user_query,
                 user_id=configuration.user_id,
                 search_space_id=configuration.search_space_id,
-                session_maker=session_maker,
-                relevant_documents=relevant_documents
+                relevant_documents=relevant_documents,
+                state=state,
+                writer=writer
             )
         )
 
     # Run all section processing tasks in parallel
     print(f"Running {len(section_tasks)} section processing tasks in parallel")
+    streaming_service.only_update_terminal(f"Processing {len(section_tasks)} sections simultaneously...")
+    writer({"yeild_value": streaming_service._format_annotations()})
+
     section_results = await asyncio.gather(*section_tasks, return_exceptions=True)
 
     # Handle any exceptions in the results
+    streaming_service.only_update_terminal("Combining section results into final report...")
+    writer({"yeild_value": streaming_service._format_annotations()})
+
     processed_results = []
     for i, result in enumerate(section_results):
         if isinstance(result, Exception):
             section_title = answer_outline.answer_outline[i].section_title
             error_message = f"Error processing section '{section_title}': {str(result)}"
             print(error_message)
+            streaming_service.only_update_terminal(error_message, "error")
+            writer({"yeild_value": streaming_service._format_annotations()})
             processed_results.append(error_message)
         else:
             processed_results.append(result)
@@ -337,12 +535,33 @@
     for i, (section, content) in enumerate(zip(answer_outline.answer_outline, processed_results)):
         # Skip adding the section header since the content already contains the title
         final_report.append(content)
-        final_report.append("\n")  # Add spacing between sections
+        final_report.append("\n")
 
+    # Join all sections with newlines
     final_written_report = "\n".join(final_report)
     print(f"Generated final report with {len(final_report)} parts")
 
+    streaming_service.only_update_terminal("Final research report generated successfully!")
+    writer({"yeild_value": streaming_service._format_annotations()})
+
+    if hasattr(state, 'streaming_service') and state.streaming_service:
+        # Convert the final report to the expected format for UI:
+        # A list of strings where empty strings represent line breaks
+        formatted_report = []
+        for section in final_report:
+            if section == "\n":
+                # Add an empty string for line breaks
+                formatted_report.append("")
+            else:
+                # Split any multiline content by newlines and add each line
+                section_lines = section.split("\n")
+                formatted_report.extend(section_lines)
+
+        state.streaming_service.only_update_answer(formatted_report)
+        writer({"yeild_value": state.streaming_service._format_annotations()})
+
     return {
         "final_written_report": final_written_report
     }
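The conversion above targets the UI's answer format, a list of strings in which an empty string stands for a line break. A small worked example of that transformation:

```python
# Worked example: multiline section strings become individual lines, and the
# "\n" spacer entries become "" (which the frontend renders as a line break).
final_report = ["## Intro\nFirst paragraph.", "\n", "## Details\nMore text."]

formatted_report = []
for section in final_report:
    if section == "\n":
        formatted_report.append("")                    # empty string = line break
    else:
        formatted_report.extend(section.split("\n"))   # one list entry per line

print(formatted_report)
# ['## Intro', 'First paragraph.', '', '## Details', 'More text.']
```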
 
@@ -352,8 +571,10 @@ async def process_section_with_documents(
     section_questions: List[str],
     user_id: str,
     search_space_id: int,
-    session_maker,
-    relevant_documents: List[Dict[str, Any]]
+    relevant_documents: List[Dict[str, Any]],
+    user_query: str,
+    state: State = None,
+    writer: StreamWriter = None
 ) -> str:
     """
     Process a single section using pre-fetched documents.
@@ -363,31 +584,42 @@
         section_questions: List of research questions for this section
         user_id: The user ID
         search_space_id: The search space ID
-        session_maker: Factory for creating new database sessions
         relevant_documents: Pre-fetched documents to use for this section
+        state: The current state
+        writer: StreamWriter for sending progress updates
 
     Returns:
         The written section content
     """
     try:
+        # Use the provided documents
+        documents_to_use = relevant_documents
+
+        # Send status update via streaming if available
+        if state and state.streaming_service and writer:
+            state.streaming_service.only_update_terminal(f"Writing section: {section_title} with {len(section_questions)} research questions")
+            writer({"yeild_value": state.streaming_service._format_annotations()})
+
+        # Fallback if no documents found
+        if not documents_to_use:
+            print(f"No relevant documents found for section: {section_title}")
+            if state and state.streaming_service and writer:
+                state.streaming_service.only_update_terminal(f"Warning: No relevant documents found for section: {section_title}", "warning")
+                writer({"yeild_value": state.streaming_service._format_annotations()})
+
+            documents_to_use = [
+                {"content": f"No specific information was found for: {question}"}
+                for question in section_questions
+            ]
+
         # Create a new database session for this section
-        async with session_maker() as db_session:
-            # Use the provided documents
-            documents_to_use = relevant_documents
-
-            # Fallback if no documents found
-            if not documents_to_use:
-                print(f"No relevant documents found for section: {section_title}")
-                documents_to_use = [
-                    {"content": f"No specific information was found for: {question}"}
-                    for question in section_questions
-                ]
-
+        async with async_session_maker() as db_session:
             # Call the sub_section_writer graph with the appropriate config
             config = {
                 "configurable": {
                     "sub_section_title": section_title,
                     "sub_section_questions": section_questions,
+                    "user_query": user_query,
                     "relevant_documents": documents_to_use,
                     "user_id": user_id,
                     "search_space_id": search_space_id
@@ -395,16 +627,32 @@
                 }
             }
 
             # Create the initial state with db_session
-            state = {"db_session": db_session}
+            sub_state = {"db_session": db_session}
 
             # Invoke the sub-section writer graph
             print(f"Invoking sub_section_writer for: {section_title}")
-            result = await sub_section_writer_graph.ainvoke(state, config)
+            if state and state.streaming_service and writer:
+                state.streaming_service.only_update_terminal(f"Analyzing information and drafting content for section: {section_title}")
+                writer({"yeild_value": state.streaming_service._format_annotations()})
+
+            result = await sub_section_writer_graph.ainvoke(sub_state, config)
 
             # Return the final answer from the sub_section_writer
             final_answer = result.get("final_answer", "No content was generated for this section.")
+
+            # Send section content update via streaming if available
+            if state and state.streaming_service and writer:
+                state.streaming_service.only_update_terminal(f"Completed writing section: {section_title}")
+                writer({"yeild_value": state.streaming_service._format_annotations()})
+
             return final_answer
     except Exception as e:
         print(f"Error processing section '{section_title}': {str(e)}")
+
+        # Send error update via streaming if available
+        if state and state.streaming_service and writer:
+            state.streaming_service.only_update_terminal(f"Error processing section '{section_title}': {str(e)}", "error")
+            writer({"yeild_value": state.streaming_service._format_annotations()})
+
         return f"Error processing section: {section_title}. Details: {str(e)}"
diff --git a/surfsense_backend/app/agents/researcher/state.py b/surfsense_backend/app/agents/researcher/state.py
index 483e96a..dd36163 100644
--- a/surfsense_backend/app/agents/researcher/state.py
+++ b/surfsense_backend/app/agents/researcher/state.py
@@ -3,10 +3,9 @@ from __future__ import annotations
 
 from dataclasses import dataclass, field
-from typing import List, Optional, Any, Dict, Annotated
-from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine
-from langchain_core.messages import BaseMessage, HumanMessage
-from pydantic import BaseModel
+from typing import Optional, Any
+from sqlalchemy.ext.asyncio import AsyncSession
+from app.utils.streaming_service import StreamingService
 
 @dataclass
 class State:
@@ -18,7 +17,9 @@ class State:
     """
     # Runtime context (not part of actual graph state)
     db_session: AsyncSession
-    engine: Optional[AsyncEngine] = None
+
+    # Streaming service
+    streaming_service: StreamingService
 
     # Intermediate state - populated during workflow
     # Using field to explicitly mark as part of state
diff --git a/surfsense_backend/app/agents/researcher/sub_section_writer/configuration.py b/surfsense_backend/app/agents/researcher/sub_section_writer/configuration.py
index fbde94d..9e1ca32 100644
--- a/surfsense_backend/app/agents/researcher/sub_section_writer/configuration.py
+++ b/surfsense_backend/app/agents/researcher/sub_section_writer/configuration.py
@@ -15,6 +15,7 @@ class Configuration:
     # Input parameters provided at invocation
     sub_section_title: str
     sub_section_questions: List[str]
+    user_query: str
     relevant_documents: List[Any]  # Documents provided directly to the agent
     user_id: str
     search_space_id: int
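Both graphs receive their per-run parameters through LangGraph's `config["configurable"]` dict and read them back with `Configuration.from_runnable_config`. The helper body is not part of this diff, so the following is a sketch assuming the common LangGraph template pattern:

```python
# Assumed shape of the from_runnable_config helper the hunks rely on.
from dataclasses import dataclass, fields
from typing import Any, Optional
from langchain_core.runnables import RunnableConfig

@dataclass
class Configuration:
    user_query: str
    num_sections: int
    # ... remaining fields elided ...

    @classmethod
    def from_runnable_config(cls, config: Optional[RunnableConfig] = None) -> "Configuration":
        configurable = (config or {}).get("configurable", {})
        # Keep only keys that match declared dataclass fields
        values: dict[str, Any] = {
            f.name: configurable[f.name] for f in fields(cls) if f.name in configurable
        }
        return cls(**values)
```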
diff --git a/surfsense_backend/app/agents/researcher/sub_section_writer/nodes.py b/surfsense_backend/app/agents/researcher/sub_section_writer/nodes.py
index af807e3..5b11141 100644
--- a/surfsense_backend/app/agents/researcher/sub_section_writer/nodes.py
+++ b/surfsense_backend/app/agents/researcher/sub_section_writer/nodes.py
@@ -119,6 +119,7 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
     # Create the query that uses the section title and questions
     section_title = configuration.sub_section_title
     sub_section_questions = configuration.sub_section_questions
+    user_query = configuration.user_query  # Get the original user query
     documents_text = "\n".join(formatted_documents)
 
     # Format the questions as bullet points for clarity
@@ -126,17 +127,16 @@
     # Construct a clear, structured query for the LLM
     human_message_content = f"""
-    Please write a comprehensive answer for the title:
+    The user's query is:
+
+    {user_query}
+
-
+    The sub-section title is:
+    <sub_section_title>
     {section_title}
-
+    </sub_section_title>
-    Focus on answering these specific questions related to the title:
-
-    {questions_text}
-
-    Use the provided documents as your source material and cite them properly using the IEEE citation format [X] where X is the source_id.
 
     {documents_text}
diff --git a/surfsense_backend/app/agents/researcher/test_researcher.py b/surfsense_backend/app/agents/researcher/test_researcher.py
deleted file mode 100644
index 15c993e..0000000
--- a/surfsense_backend/app/agents/researcher/test_researcher.py
+++ /dev/null
@@ -1,126 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test script for the Researcher LangGraph agent.
-
-This script demonstrates how to invoke the researcher agent with a sample query.
-Run this script directly from VSCode using the "Run Python File" button or
-right-click and select "Run Python File in Terminal".
-
-Before running:
-1. Make sure your Python environment has all required dependencies
-2. Create a .env file with any required API keys
-3. Ensure database connection is properly configured
-"""
-
-import asyncio
-import os
-import sys
-from pathlib import Path
-
-# Add project root to Python path so that 'app' can be found as a module
-# Get the absolute path to the surfsense_backend directory which contains the app module
-project_root = str(Path(__file__).resolve().parents[3])  # Go up 3 levels from the script: app/agents/researcher -> app/agents -> app -> surfsense_backend
-print(f"Adding to Python path: {project_root}")
-sys.path.insert(0, project_root)
-
-# Now import the modules after fixing the path
-from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
-from sqlalchemy.orm import sessionmaker
-from dotenv import load_dotenv
-
-# These imports should now work with the correct path
-from app.agents.researcher.graph import graph
-from app.agents.researcher.state import State
-
-
-# Load environment variables
-load_dotenv()
-
-# Database connection string - use a test database or mock
-DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense"
-
-# Create async engine and session
-engine = create_async_engine(DATABASE_URL)
-async_session_maker = sessionmaker(
-    engine, class_=AsyncSession, expire_on_commit=False
-)
-
-async def run_test():
-    """Run a test of the researcher agent."""
-    print("Starting researcher agent test...")
-
-    # Create a database session
-    async with async_session_maker() as db_session:
-        # Sample configuration
-        config = {
-            "configurable": {
-                "user_query": "What are the best clash royale decks recommended by Surgical Goblin?",
-                "num_sections": 1,
-                "connectors_to_search": [
-                    "YOUTUBE_VIDEO",
-                ],
-                "user_id": "d6ac2187-7407-4664-8734-af09926d161e",
-                "search_space_id": 2
-            }
-        }
-
-        try:
-            # Initialize state with database session and engine
-            initial_state = State(db_session=db_session, engine=engine)
-
-            # Run the graph directly
-            print("\nRunning the complete researcher workflow...")
-            result = await graph.ainvoke(initial_state, config)
-
-            # Extract the answer outline for display
-            if "answer_outline" in result and result["answer_outline"]:
-                print(f"\nGenerated answer outline with {len(result['answer_outline'].answer_outline)} sections")
-
-                # Print the outline
-                print("\nGenerated Answer Outline:")
-                for section in result["answer_outline"].answer_outline:
-                    print(f"\nSection {section.section_id}: {section.section_title}")
-                    print("Research Questions:")
-                    for q in section.questions:
-                        print(f"  - {q}")
-
-            # Check if we got a final report
-            if "final_written_report" in result and result["final_written_report"]:
-                final_report = result["final_written_report"]
-                print("\nFinal Research Report generated successfully!")
-                print(f"Report length: {len(final_report)} characters")
-
-                # Display the final report
-                print("\n==== FINAL RESEARCH REPORT ====\n")
-                print(final_report)
-            else:
-                print("\nNo final report was generated.")
-                print(f"Available result keys: {list(result.keys())}")
-
-            return result
-
-        except Exception as e:
-            print(f"Error running researcher agent: {str(e)}")
-            import traceback
-            traceback.print_exc()
-            raise
-
-async def main():
-    """Main entry point for the test script."""
-    try:
-        result = await run_test()
-        print("\nTest completed successfully.")
-        return result
-    except Exception as e:
-        print(f"\nTest failed with error: {e}")
-        import traceback
-        traceback.print_exc()
-        return None
-
-if __name__ == "__main__":
-    # Run the async test
-    result = asyncio.run(main())
-
-    # Keep terminal open if run directly in VSCode
-    if 'VSCODE_PID' in os.environ:
-        input("\nPress Enter to close this window...")
\ No newline at end of file
diff --git a/surfsense_backend/app/routes/chats_routes.py b/surfsense_backend/app/routes/chats_routes.py
index e10aa50..74ea97b 100644
--- a/surfsense_backend/app/routes/chats_routes.py
+++ b/surfsense_backend/app/routes/chats_routes.py
@@ -1,14 +1,15 @@
-from fastapi import APIRouter, Depends, HTTPException, Query
-from fastapi.responses import StreamingResponse
-from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.future import select
-from sqlalchemy.exc import IntegrityError, OperationalError
 from typing import List
-from app.db import get_async_session, User, SearchSpace, Chat
-from app.schemas import ChatCreate, ChatUpdate, ChatRead, AISDKChatRequest
+
+from app.db import Chat, SearchSpace, User, get_async_session
+from app.schemas import AISDKChatRequest, ChatCreate, ChatRead, ChatUpdate
 from app.tasks.stream_connector_search_results import stream_connector_search_results
 from app.users import current_active_user
 from app.utils.check_ownership import check_ownership
+from fastapi import APIRouter, Depends, HTTPException
+from fastapi.responses import StreamingResponse
+from sqlalchemy.exc import IntegrityError, OperationalError
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.future import select
 
 router = APIRouter()
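This route module only shows import changes here; the endpoint body is outside the diff. A hedged sketch of how such a route typically wires the generator into a streaming response; the path, request field names, and media type below are assumptions, not SurfSense's actual endpoint:

```python
# Assumed endpoint shape; passing user.id (a UUID) is why the task function
# below accepts Union[str, UUID].
from fastapi import Depends
from fastapi.responses import StreamingResponse

@router.post("/chat")
async def chat(
    request: AISDKChatRequest,                              # field names assumed
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    return StreamingResponse(
        stream_connector_search_results(
            user_query=request.messages[-1].content,        # assumed
            user_id=user.id,
            search_space_id=request.data["search_space_id"],  # assumed
            session=session,
            research_mode=request.data["research_mode"],      # assumed
            selected_connectors=request.data["selected_connectors"],  # assumed
        ),
        media_type="text/event-stream",  # assumed; the AI SDK also accepts plain text
    )
```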
diff --git a/surfsense_backend/app/tasks/stream_connector_search_results.py b/surfsense_backend/app/tasks/stream_connector_search_results.py
index a1dc0a3..c7eb076 100644
--- a/surfsense_backend/app/tasks/stream_connector_search_results.py
+++ b/surfsense_backend/app/tasks/stream_connector_search_results.py
@@ -1,20 +1,15 @@
-import json
-from sqlalchemy.ext.asyncio import AsyncSession
-from typing import List, AsyncGenerator, Dict, Any
-import asyncio
-import re
+from typing import AsyncGenerator, List, Union
+from uuid import UUID
 
-from app.utils.connector_service import ConnectorService
-from app.utils.research_service import ResearchService
+from app.agents.researcher.graph import graph as researcher_graph
+from app.agents.researcher.state import State
 from app.utils.streaming_service import StreamingService
-from app.utils.reranker_service import RerankerService
-from app.utils.query_service import QueryService
-from app.config import config
-from app.utils.document_converters import convert_chunks_to_langchain_documents
+from sqlalchemy.ext.asyncio import AsyncSession
+
 
 async def stream_connector_search_results(
     user_query: str,
-    user_id: str,
+    user_id: Union[str, UUID],
     search_space_id: int,
     session: AsyncSession,
     research_mode: str,
@@ -25,7 +20,7 @@
 
     Args:
         user_query: The user's query
-        user_id: The user's ID
+        user_id: The user's ID (can be UUID object or string)
         search_space_id: The search space ID
         session: The database session
         research_mode: The research mode
@@ -34,418 +29,45 @@
     Yields:
         str: Formatted response strings
     """
-    # Initialize services
-    connector_service = ConnectorService(session)
     streaming_service = StreamingService()
-
-    # Reformulate the user query using the strategic LLM
-    yield streaming_service.add_terminal_message("Reformulating your query for better results...", "info")
-    reformulated_query = await QueryService.reformulate_query(user_query)
-    yield streaming_service.add_terminal_message(f"Searching for: {reformulated_query}", "success")
-
-    reranker_service = RerankerService.get_reranker_instance(config)
-
-    all_raw_documents = []  # Store all raw documents before reranking
-    all_sources = []
-    TOP_K = 20
 
     if research_mode == "GENERAL":
-        TOP_K = 20
+        NUM_SECTIONS = 1
     elif research_mode == "DEEP":
-        TOP_K = 40
+        NUM_SECTIONS = 3
     elif research_mode == "DEEPER":
-        TOP_K = 60
+        NUM_SECTIONS = 6
 
-
-    # Process each selected connector
-    for connector in selected_connectors:
-        if connector == "YOUTUBE_VIDEO":
-            # Send terminal message about starting search
-            yield streaming_service.add_terminal_message("Starting to search for youtube videos...")
-
-            # Search for YouTube videos using reformulated query
-            result_object, youtube_chunks = await connector_service.search_youtube(
-                user_query=reformulated_query,
-                user_id=user_id,
-                search_space_id=search_space_id,
-                top_k=TOP_K
-            )
-
-            # Send terminal message about search results
-            yield streaming_service.add_terminal_message(
-                f"Found {len(result_object['sources'])} relevant YouTube videos",
-                "success"
-            )
-
-            # Update sources
-            all_sources.append(result_object)
-            yield streaming_service.update_sources(all_sources)
-
-            # Add documents to collection
-            all_raw_documents.extend(youtube_chunks)
-
-        # Extension Docs
-        if connector == "EXTENSION":
-            # Send terminal message about starting search
-            yield streaming_service.add_terminal_message("Starting to search for extension...")
-
-            # Search for crawled URLs using reformulated query
-            result_object, extension_chunks = await connector_service.search_extension(
-                user_query=reformulated_query,
-                user_id=user_id,
-                search_space_id=search_space_id,
-                top_k=TOP_K
-            )
-
-            # Send terminal message about search results
-            yield streaming_service.add_terminal_message(
-                f"Found {len(result_object['sources'])} relevant extension documents",
-                "success"
-            )
-
-            # Update sources
-            all_sources.append(result_object)
-            yield streaming_service.update_sources(all_sources)
-
-            # Add documents to collection
-            all_raw_documents.extend(extension_chunks)
-
-        # Crawled URLs
-        if connector == "CRAWLED_URL":
-            # Send terminal message about starting search
-            yield streaming_service.add_terminal_message("Starting to search for crawled URLs...")
-
-            # Search for crawled URLs using reformulated query
-            result_object, crawled_urls_chunks = await connector_service.search_crawled_urls(
-                user_query=reformulated_query,
-                user_id=user_id,
-                search_space_id=search_space_id,
-                top_k=TOP_K
-            )
-
-            # Send terminal message about search results
-            yield streaming_service.add_terminal_message(
-                f"Found {len(result_object['sources'])} relevant crawled URLs",
-                "success"
-            )
-
-            # Update sources
-            all_sources.append(result_object)
-            yield streaming_service.update_sources(all_sources)
-
-            # Add documents to collection
-            all_raw_documents.extend(crawled_urls_chunks)
-
-        # Files
-        if connector == "FILE":
-            # Send terminal message about starting search
-            yield streaming_service.add_terminal_message("Starting to search for files...")
-
-            # Search for files using reformulated query
-            result_object, files_chunks = await connector_service.search_files(
-                user_query=reformulated_query,
-                user_id=user_id,
-                search_space_id=search_space_id,
-                top_k=TOP_K
-            )
-
-            # Send terminal message about search results
-            yield streaming_service.add_terminal_message(
-                f"Found {len(result_object['sources'])} relevant files",
-                "success"
-            )
-
-            # Update sources
-            all_sources.append(result_object)
-            yield streaming_service.update_sources(all_sources)
-
-            # Add documents to collection
-            all_raw_documents.extend(files_chunks)
-
-        # Tavily Connector
-        if connector == "TAVILY_API":
-            # Send terminal message about starting search
-            yield streaming_service.add_terminal_message("Starting to search with Tavily API...")
-
-            # Search using Tavily API with reformulated query
-            result_object, tavily_chunks = await connector_service.search_tavily(
-                user_query=reformulated_query,
-                user_id=user_id,
-                top_k=TOP_K
-            )
-
-            # Send terminal message about search results
-            yield streaming_service.add_terminal_message(
-                f"Found {len(result_object['sources'])} relevant results from Tavily",
-                "success"
-            )
-
-            # Update sources
-            all_sources.append(result_object)
-            yield streaming_service.update_sources(all_sources)
-
-            # Add documents to collection
-            all_raw_documents.extend(tavily_chunks)
-
-        # Slack Connector
-        if connector == "SLACK_CONNECTOR":
-            # Send terminal message about starting search
-            yield streaming_service.add_terminal_message("Starting to search for slack connector...")
-
-            # Search using Slack API with reformulated query
-            result_object, slack_chunks = await connector_service.search_slack(
-                user_query=reformulated_query,
-                user_id=user_id,
-                search_space_id=search_space_id,
-                top_k=TOP_K
-            )
-
-            # Send terminal message about search results
-            yield streaming_service.add_terminal_message(
-                f"Found {len(result_object['sources'])} relevant results from Slack",
-                "success"
-            )
-
-            # Update sources
-            all_sources.append(result_object)
-            yield streaming_service.update_sources(all_sources)
-
-            # Add documents to collection
-            all_raw_documents.extend(slack_chunks)
-
-        # Notion Connector
-        if connector == "NOTION_CONNECTOR":
-            # Send terminal message about starting search
-            yield streaming_service.add_terminal_message("Starting to search for notion connector...")
-
-            # Search using Notion API with reformulated query
-            result_object, notion_chunks = await connector_service.search_notion(
-                user_query=reformulated_query,
-                user_id=user_id,
-                search_space_id=search_space_id,
-                top_k=TOP_K
-            )
-
-            # Send terminal message about search results
-            yield streaming_service.add_terminal_message(
-                f"Found {len(result_object['sources'])} relevant results from Notion",
-                "success"
-            )
-
-            # Update sources
-            all_sources.append(result_object)
-            yield streaming_service.update_sources(all_sources)
-
-            # Add documents to collection
-            all_raw_documents.extend(notion_chunks)
-
-        # Github Connector
-        if connector == "GITHUB_CONNECTOR":
-            # Send terminal message about starting search
-            yield streaming_service.add_terminal_message("Starting to search for GitHub connector...")
-            print("Starting to search for GitHub connector...")
-            # Search using Github API with reformulated query
-            result_object, github_chunks = await connector_service.search_github(
-                user_query=reformulated_query,
-                user_id=user_id,
-                search_space_id=search_space_id,
-                top_k=TOP_K
-            )
-
-            # Send terminal message about search results
-            yield streaming_service.add_terminal_message(
-                f"Found {len(result_object['sources'])} relevant results from Github",
-                "success"
-            )
-
-            # Update sources
-            all_sources.append(result_object)
-            yield streaming_service.update_sources(all_sources)
-
-            # Add documents to collection
-            all_raw_documents.extend(github_chunks)
-
-        # Linear Connector
-        if connector == "LINEAR_CONNECTOR":
-            # Send terminal message about starting search
-            yield streaming_service.add_terminal_message("Starting to search for Linear issues...")
-
-            # Search using Linear API with reformulated query
-            result_object, linear_chunks = await connector_service.search_linear(
-                user_query=reformulated_query,
-                user_id=user_id,
-                search_space_id=search_space_id,
-                top_k=TOP_K
-            )
-
-            # Send terminal message about search results
-            yield streaming_service.add_terminal_message(
-                f"Found {len(result_object['sources'])} relevant results from Linear",
-                "success"
-            )
-
-            # Update sources
-            all_sources.append(result_object)
-            yield streaming_service.update_sources(all_sources)
-
-            # Add documents to collection
-            all_raw_documents.extend(linear_chunks)
-
-
+    # Convert UUID to string if needed
+    user_id_str = str(user_id) if isinstance(user_id, UUID) else user_id
 
-
-    # If we have documents to research
-    if all_raw_documents:
-        # Rerank all documents if reranker is available
-        if reranker_service:
-            yield streaming_service.add_terminal_message("Reranking documents for better relevance...", "info")
-
-            # Convert documents to format expected by reranker
-            reranker_input_docs = [
-                {
-                    "chunk_id": doc.get("chunk_id", f"chunk_{i}"),
-                    "content": doc.get("content", ""),
-                    "score": doc.get("score", 0.0),
-                    "document": {
-                        "id": doc.get("document", {}).get("id", ""),
-                        "title": doc.get("document", {}).get("title", ""),
-                        "document_type": doc.get("document", {}).get("document_type", ""),
-                        "metadata": doc.get("document", {}).get("metadata", {})
-                    }
-                } for i, doc in enumerate(all_raw_documents)
-            ]
-
-            # Rerank documents using the reformulated query
-            reranked_docs = reranker_service.rerank_documents(reformulated_query, reranker_input_docs)
-
-            # Sort by score in descending order
-            reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
-
-            # Convert back to langchain documents format
-            from langchain.schema import Document as LangchainDocument
-            all_langchain_documents_to_research = [
-                LangchainDocument(
-                    page_content= f"""{doc.get("document", {}).get("id", "")}{doc.get("content", "")}""",
-                    metadata={
-                        # **doc.get("document", {}).get("metadata", {}),
-                        # "score": doc.get("score", 0.0),
-                        # "rank": doc.get("rank", 0),
-                        # "document_id": doc.get("document", {}).get("id", ""),
-                        # "document_title": doc.get("document", {}).get("title", ""),
-                        # "document_type": doc.get("document", {}).get("document_type", ""),
-                        # # Explicitly set source_id for citation purposes
-                        "source_id": str(doc.get("document", {}).get("id", ""))
-                    }
-                ) for doc in reranked_docs
-            ]
-
-            yield streaming_service.add_terminal_message(f"Reranked {len(all_langchain_documents_to_research)} documents", "success")
-        else:
-            # Use raw documents if no reranker is available
-            all_langchain_documents_to_research = convert_chunks_to_langchain_documents(all_raw_documents)
-
-        # Send terminal message about starting research
-        yield streaming_service.add_terminal_message("Starting to research...", "info")
-
-        # Create a buffer to collect report content
-        report_buffer = []
-
-        # Use the streaming research method
-        yield streaming_service.add_terminal_message("Generating report...", "info")
-
-        # Create a wrapper to handle the streaming
-        class StreamHandler:
-            def __init__(self):
-                self.queue = asyncio.Queue()
-
-            async def handle_progress(self, data):
-                result = None
-                if data.get("type") == "logs":
-                    # Handle log messages
-                    result = streaming_service.add_terminal_message(data.get("output", ""), "info")
-                elif data.get("type") == "report":
-                    # Handle report content
-                    content = data.get("output", "")
-
-                    # Fix incorrect citation formats using regex
-
-                    # More specific pattern to match only numeric citations in markdown-style links
-                    # This matches patterns like ([1](https://github.com/...)) but not general links like ([Click here](https://...))
-                    pattern = r'\(\[(\d+)\]\((https?://[^\)]+)\)\)'
-
-                    # Replace with just [X] where X is the number
-                    content = re.sub(pattern, r'[\1]', content)
-
-                    # Also match other incorrect formats like ([1]) and convert to [1]
-                    # Only match if the content inside brackets is a number
-                    content = re.sub(r'\(\[(\d+)\]\)', r'[\1]', content)
-
-                    report_buffer.append(content)
-                    # Update the answer with the accumulated content
-                    result = streaming_service.update_answer(report_buffer)
-
-                if result:
-                    await self.queue.put(result)
-                return result
-
-            async def get_next(self):
-                try:
-                    return await self.queue.get()
-                except Exception as e:
-                    print(f"Error getting next item from queue: {e}")
-                    return None
-
-            def task_done(self):
-                self.queue.task_done()
-
-        # Create the stream handler
-        stream_handler = StreamHandler()
-
-        # Start the research process in a separate task
-        research_task = asyncio.create_task(
-            ResearchService.stream_research(
-                user_query=reformulated_query,
-                documents=all_langchain_documents_to_research,
-                on_progress=stream_handler.handle_progress,
-                research_mode=research_mode
-            )
-        )
-
-        # Stream results as they become available
-        while not research_task.done() or not stream_handler.queue.empty():
-            try:
-                # Get the next result with a timeout
-                result = await asyncio.wait_for(stream_handler.get_next(), timeout=0.1)
-                stream_handler.task_done()
-                yield result
-            except asyncio.TimeoutError:
-                # No result available yet, check if the research task is done
-                if research_task.done():
-                    # If the queue is empty and the task is done, we're finished
-                    if stream_handler.queue.empty():
-                        break
-
-        # Get the final report
-        try:
-            final_report = await research_task
-
-            # Send terminal message about research completion
-            yield streaming_service.add_terminal_message("Research completed", "success")
-
-            # Update the answer with the final report
-            final_report_lines = final_report.split('\n')
-            yield streaming_service.update_answer(final_report_lines)
-        except Exception as e:
-            # Handle any exceptions
-            yield streaming_service.add_terminal_message(f"Error during research: {str(e)}", "error")
-
-    # Send completion message
-    yield streaming_service.format_completion()
+    # Build the researcher graph configuration
+    config = {
+        "configurable": {
+            "user_query": user_query,
+            "num_sections": NUM_SECTIONS,
+            "connectors_to_search": selected_connectors,
+            "user_id": user_id_str,
+            "search_space_id": search_space_id
+        }
+    }
 
+    # Initialize state with database session and streaming service
+    initial_state = State(
+        db_session=session,
+        streaming_service=streaming_service
+    )
+
+    # Run the graph directly
+    print("\nRunning the complete researcher workflow...")
+
+    # Use streaming with config parameter
+    async for chunk in researcher_graph.astream(
+        initial_state,
+        config=config,
+        stream_mode="custom",
+    ):
+        # If the chunk contains a 'yeild_value' key, yield its value
+        # Note: 'yeild_value' is a typo, but it must match the key the graph nodes emit
+        if isinstance(chunk, dict) and 'yeild_value' in chunk:
+            yield chunk['yeild_value']
+
+    yield streaming_service.format_completion()
\ No newline at end of file
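The `stream_mode="custom"` loop above yields exactly the dicts that nodes pass to their injected `StreamWriter`. A minimal, self-contained illustration of that mechanism (not the SurfSense graph):

```python
# LangGraph injects a StreamWriter into any node whose signature declares one;
# values passed to it are surfaced by stream(..., stream_mode="custom").
from typing import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.types import StreamWriter

class S(TypedDict):
    done: bool

def node(state: S, writer: StreamWriter) -> S:
    writer({"yeild_value": "progress update"})  # same (intentionally misspelled) key
    return {"done": True}

g = StateGraph(S)
g.add_node("node", node)
g.add_edge(START, "node")
g.add_edge("node", END)
graph = g.compile()

for chunk in graph.stream({"done": False}, stream_mode="custom"):
    print(chunk)  # {'yeild_value': 'progress update'}
```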
- - - - - Write in clear, professional language suitable for academic or technical audiences - - Organize your response with appropriate paragraphs, headings, and structure - - Every fact from the documents must have an IEEE-style citation in square brackets [X] where X is the EXACT source_id from the document's metadata - - Citations should appear at the end of the sentence containing the information they support - - Multiple citations should be separated by commas: [X], [Y], [Z] - - No need to return references section. Just citation numbers in answer. - - NEVER create your own citation numbering system - use the exact source_id values from the documents. - - NEVER format citations as clickable links or as markdown links like "([1](https://example.com))". Always use plain square brackets only. - - NEVER make up citation numbers if you are unsure about the source_id. It is better to omit the citation than to guess. - - - - - - 1 - - - - The Great Barrier Reef is the world's largest coral reef system, stretching over 2,300 kilometers along the coast of Queensland, Australia. It comprises over 2,900 individual reefs and 900 islands. - - - - - - - 13 - - - - Climate change poses a significant threat to coral reefs worldwide. Rising ocean temperatures have led to mass coral bleaching events in the Great Barrier Reef in 2016, 2017, and 2020. - - - - - - - 21 - - - - The Great Barrier Reef was designated a UNESCO World Heritage Site in 1981 due to its outstanding universal value and biological diversity. It is home to over 1,500 species of fish and 400 types of coral. - - - - - - - The Great Barrier Reef is the world's largest coral reef system, stretching over 2,300 kilometers along the coast of Queensland, Australia [1]. It was designated a UNESCO World Heritage Site in 1981 due to its outstanding universal value and biological diversity [21]. The reef is home to over 1,500 species of fish and 400 types of coral [21]. Unfortunately, climate change poses a significant threat to coral reefs worldwide, with rising ocean temperatures leading to mass coral bleaching events in the Great Barrier Reef in 2016, 2017, and 2020 [13]. The reef system comprises over 2,900 individual reefs and 900 islands [1], making it an ecological treasure that requires protection from multiple threats [1], [13]. - - - - DO NOT use any of these incorrect citation formats: - - Using parentheses and markdown links: ([1](https://github.com/MODSetter/SurfSense)) - - Using parentheses around brackets: ([1]) - - Using hyperlinked text: [link to source 1](https://example.com) - - Using footnote style: ... reef system¹ - - Making up citation numbers when source_id is unknown - - ONLY use plain square brackets [1] or multiple citations [1], [2], [3] - - - Note that the citation numbers match exactly with the source_id values (1, 13, and 21) and are not renumbered sequentially. Citations follow IEEE style with square brackets and appear at the end of sentences. 
-
-        Now, please research the following query:
-
-        <user_query>
-        {user_query}
-        </user_query>
-        """
-
-        return citation_prompt
-
-    @staticmethod
-    async def stream_research(
-        user_query: str,
-        documents: List[Document] = None,
-        on_progress: Optional[Callable] = None,
-        research_mode: str = "GENERAL"
-    ) -> str:
-        """
-        Stream the research process using GPTResearcher
-
-        Args:
-            user_query: The user's query
-            documents: List of Document objects to use for research
-            on_progress: Optional callback for progress updates
-            research_mode: Research mode to use
-
-        Returns:
-            str: The final research report
-        """
-        # Create a custom websocket-like object to capture streaming output
-        class StreamingWebsocket:
-            async def send_json(self, data):
-                if on_progress:
-                    try:
-                        # Filter out excessive logging of the prompt
-                        if data.get("type") == "logs":
-                            output = data.get("output", "")
-                            # Check if this is a verbose prompt log
-                            if "You are a research assistant tasked with analyzing documents" in output and len(output) > 500:
-                                # Replace with a shorter message
-                                data["output"] = f"Processing research for query: {user_query}"
-
-                        result = await on_progress(data)
-                        return result
-                    except Exception as e:
-                        print(f"Error in on_progress callback: {e}")
-                return None
-
-        streaming_websocket = StreamingWebsocket()
-
-        custom_prompt_for_ieee_citations = await ResearchService.create_custom_prompt(user_query)
-
-        if(research_mode == "GENERAL"):
-            research_report_type = ReportType.CustomReport.value
-        elif(research_mode == "DEEP"):
-            research_report_type = ReportType.ResearchReport.value
-        elif(research_mode == "DEEPER"):
-            research_report_type = ReportType.DetailedReport.value
-        # elif(research_mode == "DEEPEST"):
-        #     research_report_type = ReportType.DeepResearch.value
-
-        # Initialize GPTResearcher with the streaming websocket
-        researcher = GPTResearcher(
-            query=custom_prompt_for_ieee_citations,
-            report_type=research_report_type,
-            report_format="IEEE",
-            report_source=ReportSource.LangChainDocuments.value,
-            tone=Tone.Formal,
-            documents=documents,
-            verbose=True,
-            websocket=streaming_websocket
-        )
-
-        # Conduct research
-        await researcher.conduct_research()
-
-        # Generate report with streaming
-        report = await researcher.write_report()
-
-        # Fix citation format
-        report = ResearchService.fix_citation_format(report)
-
-        return report
-
-    @staticmethod
-    def fix_citation_format(text: str) -> str:
-        """
-        Fix any incorrectly formatted citations in the text.
-
-        Args:
-            text: The text to fix
-
-        Returns:
-            str: The text with fixed citations
-        """
-        if not text:
-            return text
-
-        # More specific pattern to match only numeric citations in markdown-style links
-        # This matches patterns like ([1](https://github.com/...)) but not general links like ([Click here](https://...))
-        pattern = r'\(\[(\d+)\]\((https?://[^\)]+)\)\)'
-
-        # Replace with just [X] where X is the number
-        text = re.sub(pattern, r'[\1]', text)
-
-        # Also match other incorrect formats like ([1]) and convert to [1]
-        # Only match if the content inside brackets is a number
-        text = re.sub(r'\(\[(\d+)\]\)', r'[\1]', text)
-
-        return text
\ No newline at end of file
diff --git a/surfsense_backend/app/utils/streaming_service.py b/surfsense_backend/app/utils/streaming_service.py
index 4f2b7d9..08a47a9 100644
--- a/surfsense_backend/app/utils/streaming_service.py
+++ b/surfsense_backend/app/utils/streaming_service.py
@@ -1,5 +1,6 @@
 import json
-from typing import List, Dict, Any, Generator
+from typing import Any, Dict, List
+
 
 class StreamingService:
     def __init__(self):
@@ -18,55 +19,7 @@ class StreamingService:
                 "content": []
             }
         ]
-
-    def add_terminal_message(self, text: str, message_type: str = "info") -> str:
-        """
-        Add a terminal message to the annotations and return the formatted response
-
-        Args:
-            text: The message text
-            message_type: The message type (info, success, error)
-
-        Returns:
-            str: The formatted response string
-        """
-        self.message_annotations[0]["content"].append({
-            "id": self.terminal_idx,
-            "text": text,
-            "type": message_type
-        })
-        self.terminal_idx += 1
-        return self._format_annotations()
-
-    def update_sources(self, sources: List[Dict[str, Any]]) -> str:
-        """
-        Update the sources in the annotations and return the formatted response
-
-        Args:
-            sources: List of source objects
-
-        Returns:
-            str: The formatted response string
-        """
-        self.message_annotations[1]["content"] = sources
-        return self._format_annotations()
-
-    def update_answer(self, answer_content: List[str]) -> str:
-        """
-        Update the answer in the annotations and return the formatted response
-
-        Args:
-            answer_content: The answer content as a list of strings
-
-        Returns:
-            str: The formatted response string
-        """
-        self.message_annotations[2] = {
-            "type": "ANSWER",
-            "content": answer_content
-        }
-        return self._format_annotations()
-
+    # Serializes the current annotations into a streamable frame for the frontend
     def _format_annotations(self) -> str:
         """
         Format the annotations as a string
@@ -76,6 +29,7 @@ class StreamingService:
         """
         return f'8:{json.dumps(self.message_annotations)}\n'
 
+    # Emits the final completion frame that ends the stream
    def format_completion(self, prompt_tokens: int = 156, completion_tokens: int = 204) -> str:
        """
        Format a completion message
@@ -96,4 +50,23 @@
                 "totalTokens": total_tokens
             }
         }
-        return f'd:{json.dumps(completion_data)}\n'
\ No newline at end of file
+        return f'd:{json.dumps(completion_data)}\n'
+
+    def only_update_terminal(self, text: str, message_type: str = "info") -> List[Dict[str, Any]]:
+        self.message_annotations[0]["content"].append({
+            "id": self.terminal_idx,
+            "text": text,
+            "type": message_type
+        })
+        self.terminal_idx += 1
+        return self.message_annotations
+
+    def only_update_sources(self, sources: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        self.message_annotations[1]["content"] = sources
+        return self.message_annotations
+
+    def only_update_answer(self, answer: List[str]) -> List[Dict[str, Any]]:
+        self.message_annotations[2]["content"] = answer
+        return self.message_annotations
+
\ No newline at end of file
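The refactor above splits mutation from serialization: the `only_update_*` helpers now just mutate `message_annotations`, and each caller serializes explicitly with `_format_annotations()` and pushes the frame through LangGraph's stream writer. Below is a minimal sketch of that flow under stated assumptions: the node signature, state shape, and `graph` handle are illustrative stand-ins, not the actual SurfSense graph.

```python
from langgraph.types import StreamWriter

# Producer side (inside a LangGraph node): mutate, then serialize and emit.
async def example_node(state, config, writer: StreamWriter):
    streaming_service = state.streaming_service
    streaming_service.only_update_terminal("Fetching documents...")
    # _format_annotations() wraps the annotations as an '8:<json>\n' frame
    writer({"yeild_value": streaming_service._format_annotations()})

# Consumer side: forward only chunks carrying the (intentionally misspelled) key.
async def stream_response(graph, initial_state, config):
    async for chunk in graph.astream(initial_state, config, stream_mode="custom"):
        if isinstance(chunk, dict) and "yeild_value" in chunk:
            yield chunk["yeild_value"]
    yield initial_state.streaming_service.format_completion()
```

The `8:`/`d:` prefixes produced by `_format_annotations()` and `format_completion()` appear to follow the Vercel AI SDK data stream conventions for message annotations and completion metadata, which is presumably why the frontend can consume the frames unchanged.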
diff --git a/surfsense_backend/pyproject.toml b/surfsense_backend/pyproject.toml
index 0d8682e..95b111e 100644
--- a/surfsense_backend/pyproject.toml
+++ b/surfsense_backend/pyproject.toml
@@ -11,7 +11,6 @@ dependencies = [
     "fastapi>=0.115.8",
     "fastapi-users[oauth,sqlalchemy]>=14.0.1",
     "firecrawl-py>=1.12.0",
-    "gpt-researcher>=0.12.12",
     "github3.py==4.0.1",
     "langchain-community>=0.3.17",
     "langchain-unstructured>=0.1.6",
diff --git a/surfsense_web/components/markdown-viewer.tsx b/surfsense_web/components/markdown-viewer.tsx
index 4398d10..03fe7ba 100644
--- a/surfsense_web/components/markdown-viewer.tsx
+++ b/surfsense_web/components/markdown-viewer.tsx
@@ -117,13 +117,21 @@ function processCitationsInReactChildren(children: React.ReactNode, getCitationS
 
 // Process citation references in text content
 function processCitationsInText(text: string, getCitationSource: (id: number) => Source | null): React.ReactNode[] {
+  // Match bracketed citation numbers like [1] or [42] anywhere in the text,
+  // including at the end of a line or sentence
   const citationRegex = /\[(\d+)\]/g;
   const parts: React.ReactNode[] = [];
   let lastIndex = 0;
   let match;
   let position = 0;
 
+  // Debug log for troubleshooting
+  console.log("Processing citations in text:", text);
+
   while ((match = citationRegex.exec(text)) !== null) {
+    // Log each match for debugging
+    console.log("Citation match found:", match[0], "at index", match.index);
+
     // Add text before the citation
     if (match.index > lastIndex) {
       parts.push(text.substring(lastIndex, match.index));
@@ -131,13 +139,18 @@ function processCitationsInText(text: string, getCitationSource: (id: number) =>
 
     // Add the citation component
     const citationId = parseInt(match[1], 10);
+    const source = getCitationSource(citationId);
+
+    // Log the citation details
+    console.log("Citation ID:", citationId, "Source:", source ? "found" : "not found");
+
     parts.push( );
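For reference, the splitting logic that `processCitationsInText` implements pairs with the citation cleanup that the deleted `ResearchService.fix_citation_format` used to perform on the backend. The sketch below (illustrative only, and written in Python for brevity; the real component returns React nodes rather than tuples) shows the two steps end to end: normalize stray markdown-link citations to plain `[N]`, then tokenize the text around them.

```python
import re

# Pattern the frontend matches: bracketed citation numbers like [1] or [42]
CITATION_RE = re.compile(r"\[(\d+)\]")

def fix_citation_format(text: str) -> str:
    # ([1](https://example.com)) -> [1], then ([1]) -> [1],
    # mirroring the regexes in the removed research_service.py
    text = re.sub(r"\(\[(\d+)\]\((https?://[^\)]+)\)\)", r"[\1]", text)
    return re.sub(r"\(\[(\d+)\]\)", r"[\1]", text)

def tokenize_citations(text: str):
    """Split text into ('text', str) and ('citation', int) tokens."""
    parts, last = [], 0
    for match in CITATION_RE.finditer(text):
        if match.start() > last:
            parts.append(("text", text[last:match.start()]))
        parts.append(("citation", int(match.group(1))))
        last = match.end()
    if last < len(text):
        parts.append(("text", text[last:]))
    return parts

print(tokenize_citations(fix_citation_format(
    "The reef spans 2,300 km ([1](https://example.com)), [13]."
)))
# [('text', 'The reef spans 2,300 km '), ('citation', 1), ('text', ', '), ('citation', 13), ('text', '.')]
```

Keeping the citation IDs as the raw `source_id` values, rather than renumbering them sequentially, is what lets the frontend resolve each token through `getCitationSource` without any extra mapping table.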