Mirror of https://github.com/MODSetter/SurfSense.git, synced 2025-09-01 18:19:08 +00:00
fix: citations for user selected documents.
This commit is contained in:
parent 3d93fe8186
commit 9dba1930de
1 changed file with 85 additions and 10 deletions
@@ -26,6 +26,77 @@ from .sub_section_writer.graph import graph as sub_section_writer_graph
 from .utils import AnswerOutline, get_connector_emoji, get_connector_friendly_name
 
 
+def extract_sources_from_documents(
+    all_documents: list[dict[str, Any]],
+) -> list[dict[str, Any]]:
+    """
+    Extract sources from all_documents and group them by document type.
+
+    Args:
+        all_documents: List of document chunks from user-selected documents and connector-fetched documents
+
+    Returns:
+        List of source objects grouped by type for streaming
+    """
+    # Group documents by their source type
+    documents_by_type = {}
+
+    for doc in all_documents:
+        # Get source type from the document
+        source_type = doc.get("source", "UNKNOWN")
+        document_info = doc.get("document", {})
+        document_type = document_info.get("document_type", source_type)
+
+        # Use document_type if available, otherwise use source
+        group_type = document_type if document_type != "UNKNOWN" else source_type
+
+        if group_type not in documents_by_type:
+            documents_by_type[group_type] = []
+        documents_by_type[group_type].append(doc)
+
+    # Create source objects for each document type
+    source_objects = []
+    source_id_counter = 1
+
+    for doc_type, docs in documents_by_type.items():
+        sources_list = []
+
+        for doc in docs:
+            document_info = doc.get("document", {})
+            metadata = document_info.get("metadata", {})
+
+            # Create source entry based on document structure
+            source = {
+                "id": doc.get("chunk_id", source_id_counter),
+                "title": document_info.get("title", "Untitled Document"),
+                "description": doc.get("content", "")[:100] + "..."
+                if len(doc.get("content", "")) > 100
+                else doc.get("content", ""),
+                "url": metadata.get("url", metadata.get("page_url", "")),
+            }
+
+            source_id_counter += 1
+            sources_list.append(source)
+
+        # Create group object
+        group_name = (
+            get_connector_friendly_name(doc_type)
+            if doc_type != "UNKNOWN"
+            else "Unknown Sources"
+        )
+
+        source_object = {
+            "id": len(source_objects) + 1,
+            "name": group_name,
+            "type": doc_type,
+            "sources": sources_list,
+        }
+
+        source_objects.append(source_object)
+
+    return source_objects
+
+
 async def fetch_documents_by_ids(
     document_ids: list[int], user_id: str, db_session: AsyncSession
 ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
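For orientation, here is a minimal usage sketch of the new helper. Only extract_sources_from_documents, the expected input keys, and "Unknown Sources" come from the diff above; the import path is hypothetical (the file name is not shown here), and the friendly name returned by get_connector_friendly_name is assumed.

# Minimal usage sketch. Assumptions: the module path is illustrative, and
# get_connector_friendly_name is assumed to map "NOTION_CONNECTOR" to "Notion".
from researcher_nodes import extract_sources_from_documents  # hypothetical path

all_documents = [
    {
        "chunk_id": 42,
        "source": "NOTION_CONNECTOR",
        "content": "Quarterly planning notes covering Q3 goals and owners.",
        "document": {
            "document_type": "NOTION_CONNECTOR",
            "title": "Q3 Planning",
            "metadata": {"url": "https://example.com/q3-planning"},
        },
    },
    {
        # No "document" payload: grouping falls back to the chunk's source type.
        "chunk_id": 43,
        "source": "UNKNOWN",
        "content": "Orphaned chunk",
    },
]

groups = extract_sources_from_documents(all_documents)
# Expected shape: one group per document type, e.g.
# [
#     {"id": 1, "name": "Notion", "type": "NOTION_CONNECTOR", "sources": [...]},
#     {"id": 2, "name": "Unknown Sources", "type": "UNKNOWN", "sources": [...]},
# ]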
(The unconditional source streaming inside fetch_relevant_documents is removed here; sources are now streamed at the call sites, after user-selected documents have been merged in, as the later hunks show.)

@@ -1169,16 +1240,6 @@ async def fetch_relevant_documents(
             }
         )
 
-    # After all sources are collected and deduplicated, stream them
-    if streaming_service and writer:
-        writer(
-            {
-                "yield_value": streaming_service.format_sources_delta(
-                    deduplicated_sources
-                )
-            }
-        )
-
     # Deduplicate raw documents based on chunk_id or content
     seen_chunk_ids = set()
     seen_content_hashes = set()
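The context lines above only name the two dedup sets; the loop body lies outside this diff. As a rough sketch of what chunk_id- and content-based deduplication typically looks like (dedupe_documents and the hashing choice are assumptions, not SurfSense code):

import hashlib
from typing import Any


# Hypothetical helper illustrating the dedup the context lines refer to.
def dedupe_documents(docs: list[dict[str, Any]]) -> list[dict[str, Any]]:
    seen_chunk_ids: set[int] = set()
    seen_content_hashes: set[str] = set()
    unique: list[dict[str, Any]] = []
    for doc in docs:
        chunk_id = doc.get("chunk_id")
        content_hash = hashlib.sha256(doc.get("content", "").encode()).hexdigest()
        # Skip a chunk if either its id or its exact content was already seen.
        if chunk_id in seen_chunk_ids or content_hash in seen_content_hashes:
            continue
        if chunk_id is not None:
            seen_chunk_ids.add(chunk_id)
        seen_content_hashes.add(content_hash)
        unique.append(doc)
    return unique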
@@ -1355,6 +1416,13 @@ async def process_sections(
     )
     print(f"Total documents for sections: {len(all_documents)}")
 
+    # Extract and stream sources from all_documents
+    if all_documents:
+        sources_to_stream = extract_sources_from_documents(all_documents)
+        writer(
+            {"yield_value": streaming_service.format_sources_delta(sources_to_stream)}
+        )
+
     writer(
         {
             "yield_value": streaming_service.format_terminal_info_delta(
@@ -1781,6 +1849,13 @@ async def handle_qna_workflow(
     print(f"Added {len(user_selected_documents)} user-selected documents for QNA")
     print(f"Total documents for QNA: {len(all_documents)}")
 
+    # Extract and stream sources from all_documents
+    if all_documents:
+        sources_to_stream = extract_sources_from_documents(all_documents)
+        writer(
+            {"yield_value": streaming_service.format_sources_delta(sources_to_stream)}
+        )
+
     writer(
         {
             "yield_value": streaming_service.format_terminal_info_delta(
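Both call sites use the same handshake: format_sources_delta turns the grouped sources into a payload, and writer forwards it under the "yield_value" key. The diff does not show StreamingService itself, so the stub below only illustrates that contract; the delta layout inside the stub is an assumption.

from typing import Any


class StubStreamingService:
    """Stand-in for SurfSense's StreamingService; only the method name and
    call pattern come from the diff, the payload layout is assumed."""

    def format_sources_delta(self, sources: list[dict[str, Any]]) -> dict[str, Any]:
        # Tag the grouped sources so a client could render citation groups
        # as they stream in.
        return {"type": "sources", "sources": sources}


def demo_writer(event: dict[str, Any]) -> None:
    # Stand-in for the graph's stream writer used at both call sites.
    print(event["yield_value"])


streaming_service = StubStreamingService()
sources_to_stream = [
    {"id": 1, "name": "Unknown Sources", "type": "UNKNOWN", "sources": []},
]
demo_writer({"yield_value": streaming_service.format_sources_delta(sources_to_stream)})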