diff --git a/surfsense_backend/app/agents/researcher/nodes.py b/surfsense_backend/app/agents/researcher/nodes.py
index c5ec229..5f153d6 100644
--- a/surfsense_backend/app/agents/researcher/nodes.py
+++ b/surfsense_backend/app/agents/researcher/nodes.py
@@ -26,6 +26,77 @@ from .sub_section_writer.graph import graph as sub_section_writer_graph
 from .utils import AnswerOutline, get_connector_emoji, get_connector_friendly_name
 
 
+def extract_sources_from_documents(
+    all_documents: list[dict[str, Any]],
+) -> list[dict[str, Any]]:
+    """
+    Extract sources from all_documents and group them by document type.
+
+    Args:
+        all_documents: List of document chunks from user-selected documents and connector-fetched documents
+
+    Returns:
+        List of source objects grouped by type for streaming
+    """
+    # Group documents by their source type
+    documents_by_type = {}
+
+    for doc in all_documents:
+        # Get source type from the document
+        source_type = doc.get("source", "UNKNOWN")
+        document_info = doc.get("document", {})
+        document_type = document_info.get("document_type", source_type)
+
+        # Use document_type if available, otherwise use source
+        group_type = document_type if document_type != "UNKNOWN" else source_type
+
+        if group_type not in documents_by_type:
+            documents_by_type[group_type] = []
+        documents_by_type[group_type].append(doc)
+
+    # Create source objects for each document type
+    source_objects = []
+    source_id_counter = 1
+
+    for doc_type, docs in documents_by_type.items():
+        sources_list = []
+
+        for doc in docs:
+            document_info = doc.get("document", {})
+            metadata = document_info.get("metadata", {})
+
+            # Create source entry based on document structure
+            source = {
+                "id": doc.get("chunk_id", source_id_counter),
+                "title": document_info.get("title", "Untitled Document"),
+                "description": doc.get("content", "")[:100] + "..."
+                if len(doc.get("content", "")) > 100
+                else doc.get("content", ""),
+                "url": metadata.get("url", metadata.get("page_url", "")),
+            }
+
+            source_id_counter += 1
+            sources_list.append(source)
+
+        # Create group object
+        group_name = (
+            get_connector_friendly_name(doc_type)
+            if doc_type != "UNKNOWN"
+            else "Unknown Sources"
+        )
+
+        source_object = {
+            "id": len(source_objects) + 1,
+            "name": group_name,
+            "type": doc_type,
+            "sources": sources_list,
+        }
+
+        source_objects.append(source_object)
+
+    return source_objects
+
+
 async def fetch_documents_by_ids(
     document_ids: list[int], user_id: str, db_session: AsyncSession
 ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
@@ -1169,16 +1240,6 @@ async def fetch_relevant_documents(
                 }
             )
 
-    # After all sources are collected and deduplicated, stream them
-    if streaming_service and writer:
-        writer(
-            {
-                "yield_value": streaming_service.format_sources_delta(
-                    deduplicated_sources
-                )
-            }
-        )
-
     # Deduplicate raw documents based on chunk_id or content
     seen_chunk_ids = set()
     seen_content_hashes = set()
@@ -1355,6 +1416,13 @@ async def process_sections(
     )
     print(f"Total documents for sections: {len(all_documents)}")
 
+    # Extract and stream sources from all_documents
+    if all_documents:
+        sources_to_stream = extract_sources_from_documents(all_documents)
+        writer(
+            {"yield_value": streaming_service.format_sources_delta(sources_to_stream)}
+        )
+
     writer(
         {
             "yield_value": streaming_service.format_terminal_info_delta(
@@ -1781,6 +1849,13 @@ async def handle_qna_workflow(
     print(f"Added {len(user_selected_documents)} user-selected documents for QNA")
     print(f"Total documents for QNA: {len(all_documents)}")
 
+    # Extract and stream sources from all_documents
+    if all_documents:
+        sources_to_stream = extract_sources_from_documents(all_documents)
+        writer(
+            {"yield_value": streaming_service.format_sources_delta(sources_to_stream)}
+        )
+
     writer(
         {
             "yield_value": streaming_service.format_terminal_info_delta(
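For context, a rough usage sketch of the new helper follows. The input shape (the `source`, `document`, `chunk_id`, `content`, and `metadata` keys) mirrors what the function reads in the hunk above; the concrete values, the `NOTION_CONNECTOR` type, and the import path are illustrative assumptions, and the friendly group name ultimately comes from `get_connector_friendly_name`.

```python
# Assumed import path, based on the file touched by this diff.
from app.agents.researcher.nodes import extract_sources_from_documents

# Illustrative document chunks only; titles, URLs, and the connector type are made up.
all_documents = [
    {
        "chunk_id": 101,
        "content": "Quarterly planning notes ...",
        "source": "NOTION_CONNECTOR",
        "document": {
            "document_type": "NOTION_CONNECTOR",
            "title": "Planning Notes",
            "metadata": {"url": "https://example.com/notion/planning"},
        },
    },
    {
        "chunk_id": 102,
        "content": "Follow-up discussion ...",
        "source": "NOTION_CONNECTOR",
        "document": {
            "document_type": "NOTION_CONNECTOR",
            "title": "Follow-up Thread",
            "metadata": {"url": "https://example.com/notion/follow-up"},
        },
    },
]

sources = extract_sources_from_documents(all_documents)
# Expected shape: one group per document type, e.g.
# [
#     {
#         "id": 1,
#         "name": "<friendly name from get_connector_friendly_name>",
#         "type": "NOTION_CONNECTOR",
#         "sources": [
#             {"id": 101, "title": "Planning Notes", "description": "...", "url": "https://example.com/notion/planning"},
#             {"id": 102, "title": "Follow-up Thread", "description": "...", "url": "https://example.com/notion/follow-up"},
#         ],
#     },
# ]
```

The grouped list is what `process_sections` and `handle_qna_workflow` now pass to `streaming_service.format_sources_delta`, replacing the per-connector streaming that this diff removes from `fetch_relevant_documents`.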