mirror of
https://github.com/MODSetter/SurfSense.git
synced 2025-09-01 18:19:08 +00:00
fix: citations for user-selected documents.
This commit is contained in:
parent
3d93fe8186
commit
9dba1930de
1 changed files with 85 additions and 10 deletions
|
@ -26,6 +26,77 @@ from .sub_section_writer.graph import graph as sub_section_writer_graph
|
|||
from .utils import AnswerOutline, get_connector_emoji, get_connector_friendly_name
|
||||
|
||||
|
||||
def extract_sources_from_documents(
    all_documents: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """
    Extract sources from all_documents and group them by document type.

    Args:
        all_documents: List of document chunks from user-selected documents
            and connector-fetched documents. Each chunk is expected to carry
            "source", "content", optionally "chunk_id", and a nested
            "document" dict with "title", "document_type", and "metadata".

    Returns:
        List of source objects grouped by type for streaming. Each group is
        a dict with "id", "name", "type", and "sources"; each entry in
        "sources" has "id", "title", "description", and "url".
    """
    # Group documents by their source type. Prefer the nested document's
    # "document_type" when it is known; otherwise fall back to the chunk's
    # top-level "source" field.
    documents_by_type: dict[str, list[dict[str, Any]]] = {}
    for doc in all_documents:
        source_type = doc.get("source", "UNKNOWN")
        document_info = doc.get("document", {})
        document_type = document_info.get("document_type", source_type)

        # Use document_type if available, otherwise use source
        group_type = document_type if document_type != "UNKNOWN" else source_type
        documents_by_type.setdefault(group_type, []).append(doc)

    # Create one source object per document type.
    source_objects: list[dict[str, Any]] = []
    source_id_counter = 1  # fallback id for chunks that lack a "chunk_id"

    for group_id, (doc_type, docs) in enumerate(documents_by_type.items(), start=1):
        sources_list = []

        for doc in docs:
            document_info = doc.get("document", {})
            metadata = document_info.get("metadata", {})
            # Hoist content once instead of re-fetching it three times.
            content = doc.get("content", "")

            # Create source entry based on document structure
            source = {
                "id": doc.get("chunk_id", source_id_counter),
                "title": document_info.get("title", "Untitled Document"),
                # Truncate long content so the streamed preview stays short.
                "description": content[:100] + "..." if len(content) > 100 else content,
                "url": metadata.get("url", metadata.get("page_url", "")),
            }

            source_id_counter += 1
            sources_list.append(source)

        # Human-friendly group name; connector lookup only for known types.
        group_name = (
            get_connector_friendly_name(doc_type)
            if doc_type != "UNKNOWN"
            else "Unknown Sources"
        )

        source_objects.append(
            {
                "id": group_id,
                "name": group_name,
                "type": doc_type,
                "sources": sources_list,
            }
        )

    return source_objects
|
||||
|
||||
|
||||
async def fetch_documents_by_ids(
|
||||
document_ids: list[int], user_id: str, db_session: AsyncSession
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||
|
@ -1169,16 +1240,6 @@ async def fetch_relevant_documents(
|
|||
}
|
||||
)
|
||||
|
||||
# After all sources are collected and deduplicated, stream them
|
||||
if streaming_service and writer:
|
||||
writer(
|
||||
{
|
||||
"yield_value": streaming_service.format_sources_delta(
|
||||
deduplicated_sources
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
# Deduplicate raw documents based on chunk_id or content
|
||||
seen_chunk_ids = set()
|
||||
seen_content_hashes = set()
|
||||
|
@ -1355,6 +1416,13 @@ async def process_sections(
|
|||
)
|
||||
print(f"Total documents for sections: {len(all_documents)}")
|
||||
|
||||
# Extract and stream sources from all_documents
|
||||
if all_documents:
|
||||
sources_to_stream = extract_sources_from_documents(all_documents)
|
||||
writer(
|
||||
{"yield_value": streaming_service.format_sources_delta(sources_to_stream)}
|
||||
)
|
||||
|
||||
writer(
|
||||
{
|
||||
"yield_value": streaming_service.format_terminal_info_delta(
|
||||
|
@ -1781,6 +1849,13 @@ async def handle_qna_workflow(
|
|||
print(f"Added {len(user_selected_documents)} user-selected documents for QNA")
|
||||
print(f"Total documents for QNA: {len(all_documents)}")
|
||||
|
||||
# Extract and stream sources from all_documents
|
||||
if all_documents:
|
||||
sources_to_stream = extract_sources_from_documents(all_documents)
|
||||
writer(
|
||||
{"yield_value": streaming_service.format_sources_delta(sources_to_stream)}
|
||||
)
|
||||
|
||||
writer(
|
||||
{
|
||||
"yield_value": streaming_service.format_terminal_info_delta(
|
||||
|
|
Loading…
Add table
Reference in a new issue