diff --git a/surfsense_backend/app/agents/researcher/nodes.py b/surfsense_backend/app/agents/researcher/nodes.py
index d8c5ac1..cbcd44f 100644
--- a/surfsense_backend/app/agents/researcher/nodes.py
+++ b/surfsense_backend/app/agents/researcher/nodes.py
@@ -266,27 +266,36 @@ async def write_answer_outline(state: State, config: RunnableConfig, writer: Str
     from app.db import get_async_session

     streaming_service = state.streaming_service
-
-    streaming_service.only_update_terminal("🔍 Generating answer outline...")
-    writer({"yeild_value": streaming_service._format_annotations()})
+
+    writer(
+        {
+            "yield_value": streaming_service.format_terminal_info_delta(
+                "🔍 Generating answer outline..."
+            )
+        }
+    )

     # Get configuration from runnable config
     configuration = Configuration.from_runnable_config(config)
     reformulated_query = state.reformulated_query
     user_query = configuration.user_query
     num_sections = configuration.num_sections
     user_id = configuration.user_id
-
-    streaming_service.only_update_terminal(f"🤔 Planning research approach for: \"{user_query[:100]}...\"")
-    writer({"yeild_value": streaming_service._format_annotations()})
-
+
+    writer(
+        {
+            "yield_value": streaming_service.format_terminal_info_delta(
+                f'🤔 Planning research approach for: "{user_query[:100]}..."'
+            )
+        }
+    )
+
    # Get user's strategic LLM
     llm = await get_user_strategic_llm(state.db_session, user_id)
     if not llm:
         error_message = f"No strategic LLM configured for user {user_id}"
-        streaming_service.only_update_terminal(f"❌ {error_message}", "error")
-        writer({"yeild_value": streaming_service._format_annotations()})
+        writer({"yield_value": streaming_service.format_error(error_message)})
         raise RuntimeError(error_message)
-
+
     # Create the human message content
     human_message_content = f"""
     Now Please create an answer outline for the following query:
@@ -310,10 +319,15 @@ async def write_answer_outline(state: State, config: RunnableConfig, writer: Str
     Your output MUST be valid JSON in exactly this format. Do not include any other text or explanation.
     """
-
-    streaming_service.only_update_terminal("📝 Designing structured outline with AI...")
-    writer({"yeild_value": streaming_service._format_annotations()})
-
+
+    writer(
+        {
+            "yield_value": streaming_service.format_terminal_info_delta(
+                "📝 Designing structured outline with AI..."
+            )
+        }
+    )
+
     # Create messages for the LLM
     messages = [
         SystemMessage(content=get_answer_outline_system_prompt()),
@@ -321,9 +335,14 @@ async def write_answer_outline(state: State, config: RunnableConfig, writer: Str
     ]

     # Call the LLM directly without using structured output
-    streaming_service.only_update_terminal("⚙️ Processing answer structure...")
-    writer({"yeild_value": streaming_service._format_annotations()})
-
+    writer(
+        {
+            "yield_value": streaming_service.format_terminal_info_delta(
+                "⚙️ Processing answer structure..."
+            )
+        }
+    )
+
     response = await llm.ainvoke(messages)

     # Parse the JSON response manually
@@ -344,26 +363,34 @@ async def write_answer_outline(state: State, config: RunnableConfig, writer: Str
             answer_outline = AnswerOutline(**parsed_data)

             total_questions = sum(len(section.questions) for section in answer_outline.answer_outline)
-            streaming_service.only_update_terminal(f"✅ Successfully generated outline with {len(answer_outline.answer_outline)} sections and {total_questions} research questions!")
-            writer({"yeild_value": streaming_service._format_annotations()})
-
-            print(f"Successfully generated answer outline with {len(answer_outline.answer_outline)} sections")
+            writer(
+                {
+                    "yield_value": streaming_service.format_terminal_info_delta(
+                        f"✅ Successfully generated outline with {len(answer_outline.answer_outline)} sections and {total_questions} research questions!"
+                    )
+                }
+            )
+
+            print(
+                f"Successfully generated answer outline with {len(answer_outline.answer_outline)} sections"
+            )
+
             # Return state update
             return {"answer_outline": answer_outline}
         else:
             # If JSON structure not found, raise a clear error
-            error_message = f"Could not find valid JSON in LLM response. Raw response: {content}"
-            streaming_service.only_update_terminal(f"❌ {error_message}", "error")
-            writer({"yeild_value": streaming_service._format_annotations()})
+            error_message = (
+                f"Could not find valid JSON in LLM response. Raw response: {content}"
+            )
+            writer({"yield_value": streaming_service.format_error(error_message)})
             raise ValueError(error_message)
-
+
     except (json.JSONDecodeError, ValueError) as e:
         # Log the error and re-raise it
         error_message = f"Error parsing LLM response: {str(e)}"
-        streaming_service.only_update_terminal(f"❌ {error_message}", "error")
-        writer({"yeild_value": streaming_service._format_annotations()})
-
+        writer({"yield_value": streaming_service.format_error(error_message)})
+
         print(f"Error parsing LLM response: {str(e)}")
         print(f"Raw response: {response.content}")
         raise
@@ -414,8 +441,13 @@ async def fetch_relevant_documents(
     if streaming_service and writer:
         connector_names = [get_connector_friendly_name(connector) for connector in connectors_to_search]
         connector_names_str = ", ".join(connector_names)
-        streaming_service.only_update_terminal(f"🔎 Starting research on {len(research_questions)} questions using {connector_names_str} data sources")
-        writer({"yeild_value": streaming_service._format_annotations()})
+        writer(
+            {
+                "yield_value": streaming_service.format_terminal_info_delta(
+                    f"🔎 Starting research on {len(research_questions)} questions using {connector_names_str} data sources"
+                )
+            }
+        )

     all_raw_documents = []  # Store all raw documents
     all_sources = []  # Store all sources
@@ -423,9 +455,14 @@ async def fetch_relevant_documents(
     for i, user_query in enumerate(research_questions):
         # Stream question being researched
         if streaming_service and writer:
-            streaming_service.only_update_terminal(f"🧠 Researching question {i+1}/{len(research_questions)}: \"{user_query[:100]}...\"")
-            writer({"yeild_value": streaming_service._format_annotations()})
-
+            writer(
+                {
+                    "yield_value": streaming_service.format_terminal_info_delta(
+                        f'🧠 Researching question {i + 1}/{len(research_questions)}: "{user_query[:100]}..."'
+                    )
+                }
+            )
+
         # Use original research question as the query
         reformulated_query = user_query
@@ -435,9 +472,14 @@ async def fetch_relevant_documents(
             if streaming_service and writer:
                 connector_emoji = get_connector_emoji(connector)
                 friendly_name = get_connector_friendly_name(connector)
-                streaming_service.only_update_terminal(f"{connector_emoji} Searching {friendly_name} for relevant information...")
-                writer({"yeild_value": streaming_service._format_annotations()})
-
+                writer(
+                    {
+                        "yield_value": streaming_service.format_terminal_info_delta(
+                            f"{connector_emoji} Searching {friendly_name} for relevant information..."
+                        )
+                    }
+                )
+
             try:
                 if connector == "YOUTUBE_VIDEO":
                     source_object, youtube_chunks = await connector_service.search_youtube(
@@ -455,9 +497,14 @@ async def fetch_relevant_documents(

                     # Stream found document count
                     if streaming_service and writer:
-                        streaming_service.only_update_terminal(f"📹 Found {len(youtube_chunks)} YouTube chunks related to your query")
-                        writer({"yeild_value": streaming_service._format_annotations()})
-
+                        writer(
+                            {
+                                "yield_value": streaming_service.format_terminal_info_delta(
+                                    f"📹 Found {len(youtube_chunks)} YouTube chunks related to your query"
+                                )
+                            }
+                        )
+
                 elif connector == "EXTENSION":
                     source_object, extension_chunks = await connector_service.search_extension(
                         user_query=reformulated_query,
@@ -474,9 +521,14 @@ async def fetch_relevant_documents(

                     # Stream found document count
                     if streaming_service and writer:
-                        streaming_service.only_update_terminal(f"🧩 Found {len(extension_chunks)} Browser Extension chunks related to your query")
-                        writer({"yeild_value": streaming_service._format_annotations()})
-
+                        writer(
+                            {
+                                "yield_value": streaming_service.format_terminal_info_delta(
+                                    f"🧩 Found {len(extension_chunks)} Browser Extension chunks related to your query"
+                                )
+                            }
+                        )
+
                 elif connector == "CRAWLED_URL":
                     source_object, crawled_urls_chunks = await connector_service.search_crawled_urls(
                         user_query=reformulated_query,
@@ -493,9 +545,14 @@ async def fetch_relevant_documents(

                     # Stream found document count
                     if streaming_service and writer:
-                        streaming_service.only_update_terminal(f"🌐 Found {len(crawled_urls_chunks)} Web Pages chunks related to your query")
-                        writer({"yeild_value": streaming_service._format_annotations()})
-
+                        writer(
+                            {
+                                "yield_value": streaming_service.format_terminal_info_delta(
+                                    f"🌐 Found {len(crawled_urls_chunks)} Web Pages chunks related to your query"
+                                )
+                            }
+                        )
+
                 elif connector == "FILE":
                     source_object, files_chunks = await connector_service.search_files(
                         user_query=reformulated_query,
@@ -512,10 +569,14 @@ async def fetch_relevant_documents(

                     # Stream found document count
                     if streaming_service and writer:
-                        streaming_service.only_update_terminal(f"📄 Found {len(files_chunks)} Files chunks related to your query")
-                        writer({"yeild_value": streaming_service._format_annotations()})
-
-
+                        writer(
+                            {
+                                "yield_value": streaming_service.format_terminal_info_delta(
+                                    f"📄 Found {len(files_chunks)} Files chunks related to your query"
+                                )
+                            }
+                        )
+
                 elif connector == "SLACK_CONNECTOR":
                     source_object, slack_chunks = await connector_service.search_slack(
                         user_query=reformulated_query,
@@ -532,9 +593,14 @@ async def fetch_relevant_documents(

                     # Stream found document count
                     if streaming_service and writer:
-                        streaming_service.only_update_terminal(f"💬 Found {len(slack_chunks)} Slack messages related to your query")
-                        writer({"yeild_value": streaming_service._format_annotations()})
-
+                        writer(
+                            {
+                                "yield_value": streaming_service.format_terminal_info_delta(
+                                    f"💬 Found {len(slack_chunks)} Slack messages related to your query"
+                                )
+                            }
+                        )
+
                 elif connector == "NOTION_CONNECTOR":
                     source_object, notion_chunks = await connector_service.search_notion(
                         user_query=reformulated_query,
@@ -551,9 +617,14 @@ async def fetch_relevant_documents(

                     # Stream found document count
                     if streaming_service and writer:
-                        streaming_service.only_update_terminal(f"📘 Found {len(notion_chunks)} Notion pages/blocks related to your query")
-                        writer({"yeild_value": streaming_service._format_annotations()})
-
+                        writer(
+                            {
+                                "yield_value": streaming_service.format_terminal_info_delta(
+                                    f"📘 Found {len(notion_chunks)} Notion pages/blocks related to your query"
+                                )
+                            }
+                        )
+
                 elif connector == "GITHUB_CONNECTOR":
                     source_object, github_chunks = await connector_service.search_github(
                         user_query=reformulated_query,
@@ -570,9 +641,14 @@ async def fetch_relevant_documents(

                     # Stream found document count
                     if streaming_service and writer:
-                        streaming_service.only_update_terminal(f"🐙 Found {len(github_chunks)} GitHub files/issues related to your query")
-                        writer({"yeild_value": streaming_service._format_annotations()})
-
+                        writer(
+                            {
+                                "yield_value": streaming_service.format_terminal_info_delta(
+                                    f"🐙 Found {len(github_chunks)} GitHub files/issues related to your query"
+                                )
+                            }
+                        )
+
                 elif connector == "LINEAR_CONNECTOR":
                     source_object, linear_chunks = await connector_service.search_linear(
                         user_query=reformulated_query,
@@ -589,9 +665,14 @@ async def fetch_relevant_documents(

                     # Stream found document count
                     if streaming_service and writer:
-                        streaming_service.only_update_terminal(f"📊 Found {len(linear_chunks)} Linear issues related to your query")
-                        writer({"yeild_value": streaming_service._format_annotations()})
-
+                        writer(
+                            {
+                                "yield_value": streaming_service.format_terminal_info_delta(
+                                    f"📊 Found {len(linear_chunks)} Linear issues related to your query"
+                                )
+                            }
+                        )
+
                 elif connector == "TAVILY_API":
                     source_object, tavily_chunks = await connector_service.search_tavily(
                         user_query=reformulated_query,
@@ -606,9 +687,14 @@ async def fetch_relevant_documents(

                     # Stream found document count
                     if streaming_service and writer:
-                        streaming_service.only_update_terminal(f"🔍 Found {len(tavily_chunks)} Web Search results related to your query")
-                        writer({"yeild_value": streaming_service._format_annotations()})
-
+                        writer(
+                            {
+                                "yield_value": streaming_service.format_terminal_info_delta(
+                                    f"🔍 Found {len(tavily_chunks)} Web Search results related to your query"
+                                )
+                            }
+                        )
+
                 elif connector == "LINKUP_API":
                     if top_k > 10:
                         linkup_mode = "deep"
@@ -628,9 +714,14 @@ async def fetch_relevant_documents(

                     # Stream found document count
                     if streaming_service and writer:
-                        streaming_service.only_update_terminal(f"🔗 Found {len(linkup_chunks)} Linkup results related to your query")
-                        writer({"yeild_value": streaming_service._format_annotations()})
-
+                        writer(
+                            {
+                                "yield_value": streaming_service.format_terminal_info_delta(
+                                    f"🔗 Found {len(linkup_chunks)} Linkup results related to your query"
+                                )
+                            }
+                        )
+
                 elif connector == "DISCORD_CONNECTOR":
                     source_object, discord_chunks = await connector_service.search_discord(
                         user_query=reformulated_query,
@@ -645,9 +736,13 @@ async def fetch_relevant_documents(
                     all_raw_documents.extend(discord_chunks)

                     # Stream found document count
                     if streaming_service and writer:
-                        streaming_service.only_update_terminal(f"🗨️ Found {len(discord_chunks)} Discord messages related to your query")
-                        writer({"yeild_value": streaming_service._format_annotations()})
-
+                        writer(
+                            {
+                                "yield_value": streaming_service.format_terminal_info_delta(
+                                    f"🗨️ Found {len(discord_chunks)} Discord messages related to your query"
+                                )
+                            }
+                        )
             except Exception as e:
                 error_message = f"Error searching connector {connector}: {str(e)}"
@@ -656,9 +751,14 @@ async def fetch_relevant_documents(
                 # Stream error message
                 if streaming_service and writer:
                     friendly_name = get_connector_friendly_name(connector)
-                    streaming_service.only_update_terminal(f"⚠️ Error searching {friendly_name}: {str(e)}", "error")
-                    writer({"yeild_value": streaming_service._format_annotations()})
-
+                    writer(
+                        {
+                            "yield_value": streaming_service.format_error(
+                                f"Error searching {friendly_name}: {str(e)}"
+                            )
+                        }
+                    )
+
                 # Continue with other connectors on error
                 continue
@@ -700,14 +800,19 @@ async def fetch_relevant_documents(
     if streaming_service and writer:
         user_source_count = len(user_selected_sources) if user_selected_sources else 0
         connector_source_count = len(deduplicated_sources) - user_source_count
-        streaming_service.only_update_terminal(f"📚 Collected {len(deduplicated_sources)} total sources ({user_source_count} user-selected + {connector_source_count} from connectors)")
-        writer({"yeild_value": streaming_service._format_annotations()})
-
+        writer(
+            {
+                "yield_value": streaming_service.format_terminal_info_delta(
+                    f"📚 Collected {len(deduplicated_sources)} total sources ({user_source_count} user-selected + {connector_source_count} from connectors)"
+                )
+            }
+        )
+
     # After all sources are collected and deduplicated, stream them
     if streaming_service and writer:
         streaming_service.only_update_sources(deduplicated_sources)
-        writer({"yeild_value": streaming_service._format_annotations()})
-
+        writer({"yield_value": streaming_service._format_annotations()})
+
     # Deduplicate raw documents based on chunk_id or content
     seen_chunk_ids = set()
     seen_content_hashes = set()
@@ -730,9 +835,14 @@ async def fetch_relevant_documents(

     # Stream info about deduplicated documents
     if streaming_service and writer:
-        streaming_service.only_update_terminal(f"🧹 Found {len(deduplicated_docs)} unique document chunks after removing duplicates")
-        writer({"yeild_value": streaming_service._format_annotations()})
-
+        writer(
+            {
+                "yield_value": streaming_service.format_terminal_info_delta(
+                    f"🧹 Found {len(deduplicated_docs)} unique document chunks after removing duplicates"
+                )
+            }
+        )
+
     # Return deduplicated documents
     return deduplicated_docs
@@ -756,15 +866,20 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
     # Initialize a dictionary to track content for all sections
     # This is used to maintain section content while streaming multiple sections
     section_contents = {}
-
-    streaming_service.only_update_terminal(f"🚀 Starting to process research sections...")
-    writer({"yeild_value": streaming_service._format_annotations()})
-
+
+    writer(
+        {
+            "yield_value": streaming_service.format_terminal_info_delta(
+                "🚀 Starting to process research sections..."
+            )
+        }
+    )
+
     print(f"Processing sections from outline: {answer_outline is not None}")

     if not answer_outline:
-        streaming_service.only_update_terminal("❌ Error: No answer outline was provided. Cannot generate report.", "error")
-        writer({"yeild_value": streaming_service._format_annotations()})
+        error_message = "No answer outline was provided. Cannot generate report."
+        writer({"yield_value": streaming_service.format_error(error_message)})
         return {
             "final_written_report": "No answer outline was provided. Cannot generate final report."
         }
@@ -775,13 +890,23 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
             all_questions.extend(section.questions)

     print(f"Collected {len(all_questions)} questions from all sections")
-    streaming_service.only_update_terminal(f"🧩 Found {len(all_questions)} research questions across {len(answer_outline.answer_outline)} sections")
-    writer({"yeild_value": streaming_service._format_annotations()})
-
+    writer(
+        {
+            "yield_value": streaming_service.format_terminal_info_delta(
+                f"🧩 Found {len(all_questions)} research questions across {len(answer_outline.answer_outline)} sections"
+            )
+        }
+    )
+
     # Fetch relevant documents once for all questions
-    streaming_service.only_update_terminal("🔍 Searching for relevant information across all connectors...")
-    writer({"yeild_value": streaming_service._format_annotations()})
-
+    writer(
+        {
+            "yield_value": streaming_service.format_terminal_info_delta(
+                "🔍 Searching for relevant information across all connectors..."
+            )
+        }
+    )
+
     if configuration.num_sections == 1:
         TOP_K = 10
     elif configuration.num_sections == 3:
@@ -798,9 +923,14 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
     try:
         # First, fetch user-selected documents if any
         if configuration.document_ids_to_add_in_context:
-            streaming_service.only_update_terminal(f"📋 Including {len(configuration.document_ids_to_add_in_context)} user-selected documents...")
-            writer({"yeild_value": streaming_service._format_annotations()})
-
+            writer(
+                {
+                    "yield_value": streaming_service.format_terminal_info_delta(
+                        f"📋 Including {len(configuration.document_ids_to_add_in_context)} user-selected documents..."
+                    )
+                }
+            )
+
             user_selected_sources, user_selected_documents = await fetch_documents_by_ids(
                 document_ids=configuration.document_ids_to_add_in_context,
                 user_id=configuration.user_id,
@@ -808,9 +938,14 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
             )

             if user_selected_documents:
-                streaming_service.only_update_terminal(f"✅ Successfully added {len(user_selected_documents)} user-selected documents to context")
-                writer({"yeild_value": streaming_service._format_annotations()})
-
+                writer(
+                    {
+                        "yield_value": streaming_service.format_terminal_info_delta(
+                            f"✅ Successfully added {len(user_selected_documents)} user-selected documents to context"
+                        )
+                    }
+                )
+
         # Create connector service using state db_session
         connector_service = ConnectorService(state.db_session, user_id=configuration.user_id)
         await connector_service.initialize_counter()
@@ -831,8 +966,7 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
     except Exception as e:
         error_message = f"Error fetching relevant documents: {str(e)}"
         print(error_message)
-        streaming_service.only_update_terminal(f"❌ {error_message}", "error")
-        writer({"yeild_value": streaming_service._format_annotations()})
+        writer({"yield_value": streaming_service.format_error(error_message)})
         # Log the error and continue with an empty list of documents
         # This allows the process to continue, but the report might lack information
         relevant_documents = []
@@ -843,15 +977,25 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
     print(f"Fetched {len(relevant_documents)} relevant documents for all sections")
     print(f"Added {len(user_selected_documents)} user-selected documents for all sections")
     print(f"Total documents for sections: {len(all_documents)}")
-
-    streaming_service.only_update_terminal(f"✨ Starting to draft {len(answer_outline.answer_outline)} sections using {len(all_documents)} total document chunks ({len(user_selected_documents)} user-selected + {len(relevant_documents)} connector-found)")
-    writer({"yeild_value": streaming_service._format_annotations()})
-
+
+    writer(
+        {
+            "yield_value": streaming_service.format_terminal_info_delta(
+                f"✨ Starting to draft {len(answer_outline.answer_outline)} sections using {len(all_documents)} total document chunks ({len(user_selected_documents)} user-selected + {len(relevant_documents)} connector-found)"
+            )
+        }
+    )
+
     # Create tasks to process each section in parallel with the same document set
     section_tasks = []
-    streaming_service.only_update_terminal("⚙️ Creating processing tasks for each section...")
-    writer({"yeild_value": streaming_service._format_annotations()})
-
+    writer(
+        {
+            "yield_value": streaming_service.format_terminal_info_delta(
+                "⚙️ Creating processing tasks for each section..."
+            )
+        }
+    )
+
     for i, section in enumerate(answer_outline.answer_outline):
         if i == 0:
             sub_section_type = SubSectionType.START
@@ -885,23 +1029,32 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW

     # Run all section processing tasks in parallel
     print(f"Running {len(section_tasks)} section processing tasks in parallel")
-    streaming_service.only_update_terminal(f"⏳ Processing {len(section_tasks)} sections simultaneously...")
-    writer({"yeild_value": streaming_service._format_annotations()})
-
+    writer(
+        {
+            "yield_value": streaming_service.format_terminal_info_delta(
+                f"⏳ Processing {len(section_tasks)} sections simultaneously..."
+            )
+        }
+    )
+
     section_results = await asyncio.gather(*section_tasks, return_exceptions=True)

     # Handle any exceptions in the results
-    streaming_service.only_update_terminal("🧵 Combining section results into final report...")
-    writer({"yeild_value": streaming_service._format_annotations()})
-
+    writer(
+        {
+            "yield_value": streaming_service.format_terminal_info_delta(
+                "🧵 Combining section results into final report..."
+            )
+        }
+    )
+
     processed_results = []
     for i, result in enumerate(section_results):
         if isinstance(result, Exception):
             section_title = answer_outline.answer_outline[i].section_title
             error_message = f"Error processing section '{section_title}': {str(result)}"
             print(error_message)
-            streaming_service.only_update_terminal(f"⚠️ {error_message}", "error")
-            writer({"yeild_value": streaming_service._format_annotations()})
+            writer({"yield_value": streaming_service.format_error(error_message)})
             processed_results.append(error_message)
         else:
             processed_results.append(result)
@@ -912,18 +1065,27 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
         # Skip adding the section header since the content already contains the title
         final_report.append(content)
         final_report.append("\n")
+
+        # Stream each section with its title
+        writer(
+            {
+                "yield_value": state.streaming_service.format_text_chunk(f"# {section.section_title}\n\n{content}")
+            }
+        )

     # Join all sections with newlines
     final_written_report = "\n".join(final_report)
     print(f"Generated final report with {len(final_report)} parts")
-
-    streaming_service.only_update_terminal("🎉 Final research report generated successfully!")
-    writer({"yeild_value": streaming_service._format_annotations()})
-
-    # Skip the final update since we've been streaming incremental updates
-    # The final answer from each section is already shown in the UI
-
+
+    writer(
+        {
+            "yield_value": streaming_service.format_terminal_info_delta(
+                "🎉 Final research report generated successfully!"
+            )
+        }
+    )
+
     # Use the shared documents for further question generation
     # Since all sections used the same document pool, we can use it directly
     return {
@@ -969,16 +1131,26 @@ async def process_section_with_documents(

     # Send status update via streaming if available
     if state and state.streaming_service and writer:
-        state.streaming_service.only_update_terminal(f"📝 Writing section: \"{section_title}\" with {len(section_questions)} research questions")
-        writer({"yeild_value": state.streaming_service._format_annotations()})
-
+        writer(
+            {
+                "yield_value": state.streaming_service.format_terminal_info_delta(
+                    f'📝 Writing section: "{section_title}" with {len(section_questions)} research questions'
+                )
+            }
+        )
+
     # Fallback if no documents found
     if not documents_to_use:
         print(f"No relevant documents found for section: {section_title}")
         if state and state.streaming_service and writer:
-            state.streaming_service.only_update_terminal(f"⚠️ Warning: No relevant documents found for section: \"{section_title}\"", "warning")
-            writer({"yeild_value": state.streaming_service._format_annotations()})
-
+            writer(
+                {
+                    "yield_value": state.streaming_service.format_error(
+                        f'Warning: No relevant documents found for section: "{section_title}"'
+                    )
+                }
+            )
+
         documents_to_use = [
             {"content": f"No specific information was found for: {question}"}
             for question in section_questions
@@ -993,7 +1165,7 @@ async def process_section_with_documents(
             "user_query": user_query,
             "relevant_documents": documents_to_use,
             "user_id": user_id,
-            "search_space_id": search_space_id
+            "search_space_id": search_space_id,
         }
     }
@@ -1006,9 +1178,14 @@ async def process_section_with_documents(
         # Invoke the sub-section writer graph with streaming
         print(f"Invoking sub_section_writer for: {section_title}")
         if state and state.streaming_service and writer:
-            state.streaming_service.only_update_terminal(f"🧠 Analyzing information and drafting content for section: \"{section_title}\"")
-            writer({"yeild_value": state.streaming_service._format_annotations()})
-
+            writer(
+                {
+                    "yield_value": state.streaming_service.format_terminal_info_delta(
+                        f'🧠 Analyzing information and drafting content for section: "{section_title}"'
+                    )
+                }
+            )
+
         # Variables to track streaming state
         complete_content = ""  # Tracks the complete content received so far
@@ -1025,8 +1202,14 @@ async def process_section_with_documents(
                 # Only stream if there's actual new content
                 if delta and state and state.streaming_service and writer:
                     # Update terminal with real-time progress indicator
-                    state.streaming_service.only_update_terminal(f"✍️ Writing section {section_id+1}... ({len(complete_content.split())} words)")
-
+                    writer(
+                        {
+                            "yield_value": state.streaming_service.format_terminal_info_delta(
+                                f"✍️ Writing section {section_id + 1}... ({len(complete_content.split())} words)"
+                            )
+                        }
+                    )
+
                     # Update section_contents with just the new delta
                     section_contents[section_id]["content"] += delta
@@ -1043,10 +1226,7 @@ async def process_section_with_documents(
                         complete_answer.extend(content_lines)
                         complete_answer.append("")  # Empty line after content

-                    # Update answer in UI in real-time
-                    state.streaming_service.only_update_answer(complete_answer)
-                    writer({"yeild_value": state.streaming_service._format_annotations()})
-
+
         # Set default if no content was received
         if not complete_content:
             complete_content = "No content was generated for this section."
@@ -1054,18 +1234,28 @@ async def process_section_with_documents(

         # Final terminal update
         if state and state.streaming_service and writer:
-            state.streaming_service.only_update_terminal(f"✅ Completed section: \"{section_title}\"")
-            writer({"yeild_value": state.streaming_service._format_annotations()})
-
+            writer(
+                {
+                    "yield_value": state.streaming_service.format_terminal_info_delta(
+                        f'✅ Completed section: "{section_title}"'
+                    )
+                }
+            )
+
         return complete_content

     except Exception as e:
         print(f"Error processing section '{section_title}': {str(e)}")
         # Send error update via streaming if available
         if state and state.streaming_service and writer:
-            state.streaming_service.only_update_terminal(f"❌ Error processing section \"{section_title}\": {str(e)}", "error")
-            writer({"yeild_value": state.streaming_service._format_annotations()})
-
+            writer(
+                {
+                    "yield_value": state.streaming_service.format_error(
+                        f'Error processing section "{section_title}": {str(e)}'
+                    )
+                }
+            )
+
         return f"Error processing section: {section_title}. Details: {str(e)}"
@@ -1102,17 +1292,32 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre

     reformulated_query = state.reformulated_query
     user_query = configuration.user_query
-
-    streaming_service.only_update_terminal("🤔 Starting Q&A research workflow...")
-    writer({"yeild_value": streaming_service._format_annotations()})
-
-    streaming_service.only_update_terminal(f"🔍 Researching: \"{user_query[:100]}...\"")
-    writer({"yeild_value": streaming_service._format_annotations()})
-
+
+    writer(
+        {
+            "yield_value": streaming_service.format_terminal_info_delta(
+                "🤔 Starting Q&A research workflow..."
+            )
+        }
+    )
+
+    writer(
+        {
+            "yield_value": streaming_service.format_terminal_info_delta(
+                f'🔍 Researching: "{user_query[:100]}..."'
+            )
+        }
+    )
+
     # Fetch relevant documents for the QNA query
-    streaming_service.only_update_terminal("🔍 Searching for relevant information across all connectors...")
-    writer({"yeild_value": streaming_service._format_annotations()})
-
+    writer(
+        {
+            "yield_value": streaming_service.format_terminal_info_delta(
+                "🔍 Searching for relevant information across all connectors..."
+            )
+        }
+    )
+
     # Use a reasonable top_k for QNA - not too many documents to avoid overwhelming the LLM
     TOP_K = 15
@@ -1123,9 +1328,14 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre
     try:
         # First, fetch user-selected documents if any
         if configuration.document_ids_to_add_in_context:
-            streaming_service.only_update_terminal(f"📋 Including {len(configuration.document_ids_to_add_in_context)} user-selected documents...")
-            writer({"yeild_value": streaming_service._format_annotations()})
-
+            writer(
+                {
+                    "yield_value": streaming_service.format_terminal_info_delta(
+                        f"📋 Including {len(configuration.document_ids_to_add_in_context)} user-selected documents..."
+                    )
+                }
+            )
+
             user_selected_sources, user_selected_documents = await fetch_documents_by_ids(
                 document_ids=configuration.document_ids_to_add_in_context,
                 user_id=configuration.user_id,
@@ -1133,9 +1343,14 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre
            )

             if user_selected_documents:
-                streaming_service.only_update_terminal(f"✅ Successfully added {len(user_selected_documents)} user-selected documents to context")
-                writer({"yeild_value": streaming_service._format_annotations()})
-
+                writer(
+                    {
+                        "yield_value": streaming_service.format_terminal_info_delta(
+                            f"✅ Successfully added {len(user_selected_documents)} user-selected documents to context"
+                        )
+                    }
+                )
+
         # Create connector service using state db_session
         connector_service = ConnectorService(state.db_session, user_id=configuration.user_id)
         await connector_service.initialize_counter()
@@ -1159,8 +1374,7 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre
     except Exception as e:
         error_message = f"Error fetching relevant documents for QNA: {str(e)}"
         print(error_message)
-        streaming_service.only_update_terminal(f"❌ {error_message}", "error")
-        writer({"yeild_value": streaming_service._format_annotations()})
+        writer({"yield_value": streaming_service.format_error(error_message)})
         # Continue with empty documents - the QNA agent will handle this gracefully
         relevant_documents = []
@@ -1170,10 +1384,15 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre
     print(f"Fetched {len(relevant_documents)} relevant documents for QNA")
     print(f"Added {len(user_selected_documents)} user-selected documents for QNA")
     print(f"Total documents for QNA: {len(all_documents)}")
-
-    streaming_service.only_update_terminal(f"🧠 Generating comprehensive answer using {len(all_documents)} total sources ({len(user_selected_documents)} user-selected + {len(relevant_documents)} connector-found)...")
-    writer({"yeild_value": streaming_service._format_annotations()})
-
+
+    writer(
+        {
+            "yield_value": streaming_service.format_terminal_info_delta(
+                f"🧠 Generating comprehensive answer using {len(all_documents)} total sources ({len(user_selected_documents)} user-selected + {len(relevant_documents)} connector-found)..."
+            )
+        }
+    )
+
     # Prepare configuration for the QNA agent
     qna_config = {
         "configurable": {
@@ -1192,9 +1411,14 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre
     }

     try:
-        streaming_service.only_update_terminal("✍️ Writing comprehensive answer with citations...")
-        writer({"yeild_value": streaming_service._format_annotations()})
-
+        writer(
+            {
+                "yield_value": streaming_service.format_terminal_info_delta(
+                    "✍️ Writing comprehensive answer with citations..."
+                )
+            }
+        )
+
         # Track streaming content for real-time updates
         complete_content = ""
         captured_reranked_documents = []
@@ -1212,13 +1436,18 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre
                 if delta:
                     # Update terminal with progress
                     word_count = len(complete_content.split())
-                    streaming_service.only_update_terminal(f"✍️ Writing answer... ({word_count} words)")
-
-                    # Update the answer in real-time
-                    answer_lines = complete_content.split("\n")
-                    streaming_service.only_update_answer(answer_lines)
-                    writer({"yeild_value": streaming_service._format_annotations()})
-
+                    writer(
+                        {
+                            "yield_value": streaming_service.format_terminal_info_delta(
+                                f"✍️ Writing answer... ({word_count} words)"
+                            )
+                        }
+                    )
+
+                    writer(
+                        {"yield_value": streaming_service.format_text_chunk(delta)}
+                    )
+
             # Capture reranked documents from QNA agent for further question generation
             if "reranked_documents" in chunk:
                 captured_reranked_documents = chunk["reranked_documents"]
@@ -1226,10 +1455,15 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre
         # Set default if no content was received
         if not complete_content:
             complete_content = "I couldn't find relevant information in your knowledge base to answer this question."
-
-        streaming_service.only_update_terminal("🎉 Q&A answer generated successfully!")
-        writer({"yeild_value": streaming_service._format_annotations()})
-
+
+        writer(
+            {
+                "yield_value": streaming_service.format_terminal_info_delta(
+                    "🎉 Q&A answer generated successfully!"
+                )
+            }
+        )
+
         # Return the final answer and captured reranked documents for further question generation
         return {
             "final_written_report": complete_content,
@@ -1239,12 +1473,9 @@
     except Exception as e:
         error_message = f"Error generating QNA answer: {str(e)}"
         print(error_message)
-        streaming_service.only_update_terminal(f"❌ {error_message}", "error")
-        writer({"yeild_value": streaming_service._format_annotations()})
-
-        return {
-            "final_written_report": f"Error generating answer: {str(e)}"
-        }
+        writer({"yield_value": streaming_service.format_error(error_message)})
+
+        return {"final_written_report": f"Error generating answer: {str(e)}"}


 async def generate_further_questions(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]:
@@ -1268,20 +1499,24 @@ async def generate_further_questions(state: State, config: RunnableConfig, write

     # Get reranked documents from the state (will be populated by sub-agents)
     reranked_documents = getattr(state, 'reranked_documents', None) or []
-
-    streaming_service.only_update_terminal("🤔 Generating follow-up questions...")
-    writer({"yeild_value": streaming_service._format_annotations()})
-
+
+    writer(
+        {
+            "yield_value": streaming_service.format_terminal_info_delta(
+                "🤔 Generating follow-up questions..."
+            )
+        }
+    )
+
     # Get user's fast LLM
     llm = await get_user_fast_llm(state.db_session, user_id)
     if not llm:
         error_message = f"No fast LLM configured for user {user_id}"
         print(error_message)
-        streaming_service.only_update_terminal(f"❌ {error_message}", "error")
-
+        writer({"yield_value": streaming_service.format_error(error_message)})
+
         # Stream empty further questions to UI
-        streaming_service.only_update_further_questions([])
-        writer({"yeild_value": streaming_service._format_annotations()})
+        writer({"yield_value": streaming_service.format_further_questions_delta([])})
         return {"further_questions": []}

     # Format chat history for the prompt
@@ -1338,10 +1573,15 @@ async def generate_further_questions(state: State, config: RunnableConfig, write

     Do not include any other text or explanation. Only return the JSON.
     """
-
-    streaming_service.only_update_terminal("🧠 Analyzing conversation context to suggest relevant questions...")
-    writer({"yeild_value": streaming_service._format_annotations()})
-
+
+    writer(
+        {
+            "yield_value": streaming_service.format_terminal_info_delta(
+                "🧠 Analyzing conversation context to suggest relevant questions..."
+            )
+        }
+    )
+
     # Create messages for the LLM
     messages = [
         SystemMessage(content=get_further_questions_system_prompt()),
@@ -1366,47 +1606,67 @@ async def generate_further_questions(state: State, config: RunnableConfig, write

             # Extract the further_questions array
             further_questions = parsed_data.get("further_questions", [])
-
-            streaming_service.only_update_terminal(f"✅ Generated {len(further_questions)} contextual follow-up questions!")
-
+
+            writer(
+                {
+                    "yield_value": streaming_service.format_terminal_info_delta(
+                        f"✅ Generated {len(further_questions)} contextual follow-up questions!"
+                    )
+                }
+            )
+
             # Stream the further questions to the UI
-            streaming_service.only_update_further_questions(further_questions)
-            writer({"yeild_value": streaming_service._format_annotations()})
-
+            writer(
+                {
+                    "yield_value": streaming_service.format_further_questions_delta(
+                        further_questions
+                    )
+                }
+            )
+
             print(f"Successfully generated {len(further_questions)} further questions")
             return {"further_questions": further_questions}
         else:
             # If JSON structure not found, return empty list
-            error_message = "Could not find valid JSON in LLM response for further questions"
+            error_message = (
+                "Could not find valid JSON in LLM response for further questions"
+            )
             print(error_message)
-            streaming_service.only_update_terminal(f"⚠️ {error_message}", "warning")
-
+            writer(
+                {
+                    "yield_value": streaming_service.format_error(
+                        f"Warning: {error_message}"
+                    )
+                }
+            )
+
             # Stream empty further questions to UI
-            streaming_service.only_update_further_questions([])
-            writer({"yeild_value": streaming_service._format_annotations()})
+            writer(
+                {"yield_value": streaming_service.format_further_questions_delta([])}
+            )
             return {"further_questions": []}

     except (json.JSONDecodeError, ValueError) as e:
         # Log the error and return empty list
         error_message = f"Error parsing further questions response: {str(e)}"
         print(error_message)
-        streaming_service.only_update_terminal(f"⚠️ {error_message}", "warning")
-
+        writer(
+            {"yield_value": streaming_service.format_error(f"Warning: {error_message}")}
+        )
+
         # Stream empty further questions to UI
-        streaming_service.only_update_further_questions([])
-        writer({"yeild_value": streaming_service._format_annotations()})
+        writer({"yield_value": streaming_service.format_further_questions_delta([])})
         return {"further_questions": []}

     except Exception as e:
         # Handle any other errors
         error_message = f"Error generating further questions: {str(e)}"
         print(error_message)
-        streaming_service.only_update_terminal(f"⚠️ {error_message}", "warning")
-
+        writer(
+            {"yield_value": streaming_service.format_error(f"Warning: {error_message}")}
+        )
+
         # Stream empty further questions to UI
-        streaming_service.only_update_further_questions([])
-        writer({"yeild_value": streaming_service._format_annotations()})
+        writer({"yield_value": streaming_service.format_further_questions_delta([])})
         return {"further_questions": []}
-
-
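Reviewer note: the nodes.py changes above replace the old re-send-everything pattern (`only_update_terminal()` followed by `_format_annotations()`) with self-contained delta frames pushed through the LangGraph `StreamWriter`. A minimal sketch of the new contract, assuming a recent langgraph; the two-node graph and `DemoState` are illustrative only, while `StreamingService` and the `{"yield_value": ...}` envelope come from this diff:

    # Hypothetical mini-graph showing the delta-streaming contract.
    from typing import Any, TypedDict

    from langgraph.graph import END, START, StateGraph
    from langgraph.types import StreamWriter

    from app.services.streaming_service import StreamingService

    streaming_service = StreamingService()


    class DemoState(TypedDict):
        query: str


    async def demo_node(state: DemoState, writer: StreamWriter) -> dict[str, Any]:
        # Each call emits one self-contained frame instead of the full
        # annotation state, so the frontend can apply it as a delta.
        writer(
            {
                "yield_value": streaming_service.format_terminal_info_delta(
                    f"🔍 Working on: {state['query']}"
                )
            }
        )
        return {}


    graph = StateGraph(DemoState)
    graph.add_node("demo", demo_node)
    graph.add_edge(START, "demo")
    graph.add_edge("demo", END)
    app = graph.compile()

Anything passed to `writer()` surfaces as a chunk under `stream_mode="custom"`, which is exactly what the task loop in stream_connector_search_results.py (further below) filters on.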
diff --git a/surfsense_backend/app/routes/chats_routes.py b/surfsense_backend/app/routes/chats_routes.py
index 9db77f4..dc7c126 100644
--- a/surfsense_backend/app/routes/chats_routes.py
+++ b/surfsense_backend/app/routes/chats_routes.py
@@ -54,32 +54,23 @@ async def handle_chat_data(
             if message['role'] == "user":
                 langchain_chat_history.append(HumanMessage(content=message['content']))
             elif message['role'] == "assistant":
-                # Find the last "ANSWER" annotation specifically
-                answer_annotation = None
-                for annotation in reversed(message['annotations']):
-                    if annotation['type'] == "ANSWER":
-                        answer_annotation = annotation
-                        break
-
-                if answer_annotation:
-                    answer_text = answer_annotation['content']
-                    # If content is a list, join it into a single string
-                    if isinstance(answer_text, list):
-                        answer_text = "\n".join(answer_text)
-                    langchain_chat_history.append(AIMessage(content=answer_text))
+                langchain_chat_history.append(AIMessage(content=message['content']))

-        response = StreamingResponse(stream_connector_search_results(
-            user_query,
-            user.id,
-            search_space_id,  # Already converted to int in lines 32-37
-            session,
-            research_mode,
-            selected_connectors,
-            langchain_chat_history,
-            search_mode_str,
-            document_ids_to_add_in_context
-        ))
-        response.headers['x-vercel-ai-data-stream'] = 'v1'
+        response = StreamingResponse(
+            stream_connector_search_results(
+                user_query,
+                user.id,
+                search_space_id,
+                session,
+                research_mode,
+                selected_connectors,
+                langchain_chat_history,
+                search_mode_str,
+                document_ids_to_add_in_context,
+            )
+        )
+
+        response.headers["x-vercel-ai-data-stream"] = "v1"

         return response
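Reviewer note: with assistant turns now persisted as plain markdown in `message['content']`, rebuilding the LangChain history is a straight mapping; no more scanning the annotations array for the last `ANSWER` entry. A small sketch with hypothetical stored messages:

    from langchain_core.messages import AIMessage, HumanMessage

    # Hypothetical rows as the route would now receive them.
    stored_messages = [
        {"role": "user", "content": "Summarize my Slack threads"},
        {"role": "assistant", "content": "## Summary\n- thread one..."},
    ]

    langchain_chat_history = []
    for message in stored_messages:
        if message["role"] == "user":
            langchain_chat_history.append(HumanMessage(content=message["content"]))
        elif message["role"] == "assistant":
            # The final answer text is stored directly on the message.
            langchain_chat_history.append(AIMessage(content=message["content"]))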
diff --git a/surfsense_backend/app/services/streaming_service.py b/surfsense_backend/app/services/streaming_service.py
index 514e76b..ce1188a 100644
--- a/surfsense_backend/app/services/streaming_service.py
+++ b/surfsense_backend/app/services/streaming_service.py
@@ -23,17 +23,138 @@ class StreamingService:
                 "content": []
             }
         ]
-        # It is used to send annotations to the frontend
+
+    # DEPRECATED: This sends the full annotation array every time (inefficient)
     def _format_annotations(self) -> str:
         """
         Format the annotations as a string
-
+
+        DEPRECATED: This method sends the full annotation state every time.
+        Use the delta formatters instead for optimal streaming.
+
         Returns:
             str: The formatted annotations string
         """
         return f'8:{json.dumps(self.message_annotations)}\n'
-
-    # It is used to end Streaming
+
+    def format_terminal_info_delta(self, text: str, message_type: str = "info") -> str:
+        """
+        Format a single terminal info message as a delta annotation
+
+        Args:
+            text: The terminal message text
+            message_type: The message type (info, error, success, etc.)
+
+        Returns:
+            str: The formatted annotation delta string
+        """
+        message = {"id": self.terminal_idx, "text": text, "type": message_type}
+        self.terminal_idx += 1
+
+        # Update internal state for reference
+        self.message_annotations[0]["content"].append(message)
+
+        # Return only the delta annotation
+        annotation = {"type": "TERMINAL_INFO", "content": [message]}
+        return f"8:[{json.dumps(annotation)}]\n"
+
+    def format_sources_delta(self, sources: List[Dict[str, Any]]) -> str:
+        """
+        Format sources as a delta annotation
+
+        Args:
+            sources: List of source objects
+
+        Returns:
+            str: The formatted annotation delta string
+        """
+        # Update internal state
+        self.message_annotations[1]["content"] = sources
+
+        # Return only the delta annotation
+        annotation = {"type": "SOURCES", "content": sources}
+        return f"8:[{json.dumps(annotation)}]\n"
+
+    def format_answer_delta(self, answer_chunk: str) -> str:
+        """
+        Format a single answer chunk as a delta annotation
+
+        Args:
+            answer_chunk: The new answer chunk to add
+
+        Returns:
+            str: The formatted annotation delta string
+        """
+        # Update internal state by appending the chunk
+        if isinstance(self.message_annotations[2]["content"], list):
+            self.message_annotations[2]["content"].append(answer_chunk)
+        else:
+            self.message_annotations[2]["content"] = [answer_chunk]
+
+        # Return only the delta annotation with the new chunk
+        annotation = {"type": "ANSWER", "content": [answer_chunk]}
+        return f"8:[{json.dumps(annotation)}]\n"
+
+    def format_answer_annotation(self, answer_lines: List[str]) -> str:
+        """
+        Format the complete answer as a replacement annotation
+
+        Args:
+            answer_lines: Complete list of answer lines
+
+        Returns:
+            str: The formatted annotation string
+        """
+        # Update internal state
+        self.message_annotations[2]["content"] = answer_lines
+
+        # Return the full answer annotation
+        annotation = {"type": "ANSWER", "content": answer_lines}
+        return f"8:[{json.dumps(annotation)}]\n"
+
+    def format_further_questions_delta(
+        self, further_questions: List[Dict[str, Any]]
+    ) -> str:
+        """
+        Format further questions as a delta annotation
+
+        Args:
+            further_questions: List of further question objects
+
+        Returns:
+            str: The formatted annotation delta string
+        """
+        # Update internal state
+        self.message_annotations[3]["content"] = further_questions
+
+        # Return only the delta annotation
+        annotation = {"type": "FURTHER_QUESTIONS", "content": further_questions}
+        return f"8:[{json.dumps(annotation)}]\n"
+
+    def format_text_chunk(self, text: str) -> str:
+        """
+        Format a text chunk using the text stream part
+
+        Args:
+            text: The text chunk to stream
+
+        Returns:
+            str: The formatted text part string
+        """
+        return f"0:{json.dumps(text)}\n"
+
+    def format_error(self, error_message: str) -> str:
+        """
+        Format an error using the error stream part
+
+        Args:
+            error_message: The error message
+
+        Returns:
+            str: The formatted error part string
+        """
+        return f"3:{json.dumps(error_message)}\n"
+
     def format_completion(self, prompt_tokens: int = 156, completion_tokens: int = 204) -> str:
         """
         Format a completion message
@@ -56,7 +177,12 @@ class StreamingService:
             }
         }
         return f'd:{json.dumps(completion_data)}\n'

+
+    # DEPRECATED METHODS: Keep for backward compatibility but mark as deprecated
     def only_update_terminal(self, text: str, message_type: str = "info") -> str:
+        """
+        DEPRECATED: Use format_terminal_info_delta() instead for optimal streaming
+        """
         self.message_annotations[0]["content"].append({
             "id": self.terminal_idx,
             "text": text,
@@ -66,17 +192,23 @@ class StreamingService:
         return self.message_annotations

     def only_update_sources(self, sources: List[Dict[str, Any]]) -> str:
+        """
+        DEPRECATED: Use format_sources_delta() instead for optimal streaming
+        """
         self.message_annotations[1]["content"] = sources
         return self.message_annotations

     def only_update_answer(self, answer: List[str]) -> str:
+        """
+        DEPRECATED: Use format_answer_delta() or format_answer_annotation() instead for optimal streaming
+        """
         self.message_annotations[2]["content"] = answer
         return self.message_annotations
-
+
     def only_update_further_questions(self, further_questions: List[Dict[str, Any]]) -> str:
         """
-        Update the further questions annotation
-
+        DEPRECATED: Use format_further_questions_delta() instead for optimal streaming
+
         Args:
             further_questions: List of further question objects with id and question fields
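Reviewer note: the new formatters each emit one framed line following the Vercel AI data stream protocol (`8:` message annotation, `0:` text part, `3:` error part, `d:` finish part). A standalone sanity check of the wire format; the output in comments is approximate and depends on where the `terminal_idx` counter starts:

    from app.services.streaming_service import StreamingService

    svc = StreamingService()

    # One TERMINAL_INFO delta per call; `id` comes from the running counter.
    print(svc.format_terminal_info_delta("🔍 Searching..."), end="")
    # 8:[{"type": "TERMINAL_INFO", "content": [{"id": 0, "text": "🔍 Searching...", "type": "info"}]}]

    # Raw answer text streams as text parts, not annotations.
    print(svc.format_text_chunk("Partial answer text"), end="")
    # 0:"Partial answer text"

    # Errors use the dedicated error part.
    print(svc.format_error("Connector timed out"), end="")
    # 3:"Connector timed out"

Note each delta formatter still appends to `message_annotations`, so `_format_annotations()` remains a valid full-state fallback during the deprecation window.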
diff --git a/surfsense_backend/app/tasks/stream_connector_search_results.py b/surfsense_backend/app/tasks/stream_connector_search_results.py
index 145214c..f66bf1a 100644
--- a/surfsense_backend/app/tasks/stream_connector_search_results.py
+++ b/surfsense_backend/app/tasks/stream_connector_search_results.py
@@ -83,9 +83,8 @@ async def stream_connector_search_results(
             config=config,
             stream_mode="custom",
         ):
-            # If the chunk contains a 'yeild_value' key, print its value
-            # Note: there's a typo in 'yeild_value' in the code, but we need to match it
-            if isinstance(chunk, dict) and 'yeild_value' in chunk:
-                yield chunk['yeild_value']
-
-        yield streaming_service.format_completion()
\ No newline at end of file
+            if isinstance(chunk, dict):
+                if "yield_value" in chunk:
+                    yield chunk["yield_value"]
+
+        yield streaming_service.format_completion()
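Reviewer note: with the `yeild_value` typo fixed at both ends, the generator forwards only dict chunks carrying `"yield_value"`; every frame is already a protocol-formatted string by the time it reaches this loop. The forwarding contract in isolation (a hypothetical helper, handy for unit tests):

    from typing import Any, AsyncIterator


    async def forward_stream(chunks: AsyncIterator[Any]) -> AsyncIterator[str]:
        """Pass through protocol-framed strings; drop any other custom events."""
        async for chunk in chunks:
            if isinstance(chunk, dict) and "yield_value" in chunk:
                yield chunk["yield_value"]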
diff --git a/surfsense_web/app/dashboard/[search_space_id]/researcher/[chat_id]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/researcher/[chat_id]/page.tsx
index 3c0f20d..8a0bde7 100644
--- a/surfsense_web/app/dashboard/[search_space_id]/researcher/[chat_id]/page.tsx
+++ b/surfsense_web/app/dashboard/[search_space_id]/researcher/[chat_id]/page.tsx
@@ -981,19 +981,16 @@ const ChatPage = () => {
 	const renderTerminalContent = (message: any) => {
 		if (!message.annotations) return null;

-		// Get all TERMINAL_INFO annotations
-		const terminalInfoAnnotations = (message.annotations as any[]).filter(
-			(a) => a.type === "TERMINAL_INFO",
-		);
-
-		// Get the latest TERMINAL_INFO annotation
-		const latestTerminalInfo =
-			terminalInfoAnnotations.length > 0
-				? terminalInfoAnnotations[terminalInfoAnnotations.length - 1]
-				: null;
+		// Get all TERMINAL_INFO annotations content
+		const terminalInfoAnnotations = (message.annotations as any[]).map(item => {
+			if(item.type === "TERMINAL_INFO") {
+				return item.content.map((a: any) => a.text)
+			}
+		}).flat().filter(Boolean)

 		// Render the content of the latest TERMINAL_INFO annotation
-		return latestTerminalInfo?.content.map((item: any, idx: number) => (
+		return terminalInfoAnnotations.map((item: any, idx: number) => (
 					[{String(idx).padStart(2, "0")}:
@@ -1008,7 +1005,7 @@ const ChatPage = () => {
 							${item.type === "warning" ? "text-yellow-300" : ""}
 						`}
 					>
-						{item.text}
+						{item}
 			));
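Reviewer note: end to end, a client now sees one framed line per event and can dispatch on the prefix. A smoke-test sketch against a local backend; the endpoint path and request body here are assumptions for illustration, not documented API:

    import asyncio

    import httpx


    async def main() -> None:
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream(
                "POST",
                "http://localhost:8000/chat",  # hypothetical route
                json={"messages": [{"role": "user", "content": "What is SurfSense?"}]},
            ) as response:
                assert response.headers.get("x-vercel-ai-data-stream") == "v1"
                async for line in response.aiter_lines():
                    if not line:
                        continue
                    prefix, _, payload = line.partition(":")
                    # 8 = annotation delta, 0 = text chunk, 3 = error, d = completion
                    print(prefix, payload[:80])


    asyncio.run(main())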