Merge pull request #276 from MODSetter/dev

fix: citations for manual user selected docs.
Rohan Verma 2025-08-20 12:06:57 -07:00 committed by GitHub
commit 0db3c32144

@@ -26,6 +26,77 @@ from .sub_section_writer.graph import graph as sub_section_writer_graph
from .utils import AnswerOutline, get_connector_emoji, get_connector_friendly_name
def extract_sources_from_documents(
    all_documents: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """
    Extract sources from all_documents and group them by document type.

    Args:
        all_documents: List of document chunks from user-selected documents and connector-fetched documents

    Returns:
        List of source objects grouped by type for streaming
    """
    # Group documents by their source type
    documents_by_type = {}
    for doc in all_documents:
        # Get source type from the document
        source_type = doc.get("source", "UNKNOWN")
        document_info = doc.get("document", {})
        document_type = document_info.get("document_type", source_type)

        # Use document_type if available, otherwise use source
        group_type = document_type if document_type != "UNKNOWN" else source_type
        if group_type not in documents_by_type:
            documents_by_type[group_type] = []
        documents_by_type[group_type].append(doc)

    # Create source objects for each document type
    source_objects = []
    source_id_counter = 1
    for doc_type, docs in documents_by_type.items():
        sources_list = []
        for doc in docs:
            document_info = doc.get("document", {})
            metadata = document_info.get("metadata", {})

            # Create source entry based on document structure
            source = {
                "id": doc.get("chunk_id", source_id_counter),
                "title": document_info.get("title", "Untitled Document"),
                "description": doc.get("content", "")[:100] + "..."
                if len(doc.get("content", "")) > 100
                else doc.get("content", ""),
                "url": metadata.get("url", metadata.get("page_url", "")),
            }
            source_id_counter += 1
            sources_list.append(source)

        # Create group object
        group_name = (
            get_connector_friendly_name(doc_type)
            if doc_type != "UNKNOWN"
            else "Unknown Sources"
        )
        source_object = {
            "id": len(source_objects) + 1,
            "name": group_name,
            "type": doc_type,
            "sources": sources_list,
        }
        source_objects.append(source_object)

    return source_objects


async def fetch_documents_by_ids(
    document_ids: list[int], user_id: str, db_session: AsyncSession
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
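For orientation, a minimal sketch of the shape this helper consumes and returns, using an invented chunk; the group label depends on get_connector_friendly_name, which is not shown in this diff, so it is left as a placeholder:

# Illustrative input only; all field values below are invented for the sketch.
sample_chunks = [
    {
        "chunk_id": 42,
        "source": "NOTION_CONNECTOR",
        "content": "Quarterly planning notes covering roadmap priorities and owners.",
        "document": {
            "document_type": "NOTION_CONNECTOR",
            "title": "Q3 Planning",
            "metadata": {"url": "https://example.com/q3-planning"},
        },
    }
]

# extract_sources_from_documents(sample_chunks) would return roughly:
# [
#     {
#         "id": 1,
#         "name": "<friendly name for NOTION_CONNECTOR>",  # via get_connector_friendly_name
#         "type": "NOTION_CONNECTOR",
#         "sources": [
#             {
#                 "id": 42,
#                 "title": "Q3 Planning",
#                 "description": "Quarterly planning notes covering roadmap priorities and owners.",
#                 "url": "https://example.com/q3-planning",
#             }
#         ],
#     }
# ]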
@@ -1169,16 +1240,6 @@ async def fetch_relevant_documents(
            }
        )

    # After all sources are collected and deduplicated, stream them
    if streaming_service and writer:
        writer(
            {
                "yield_value": streaming_service.format_sources_delta(
                    deduplicated_sources
                )
            }
        )

    # Deduplicate raw documents based on chunk_id or content
    seen_chunk_ids = set()
    seen_content_hashes = set()
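The deduplication loop itself falls outside this hunk; as a rough sketch of the pattern the two seen-sets imply (the loop body and the variable names all_raw_documents / deduplicated_documents are assumptions, not the file's actual code):

import hashlib

deduplicated_documents = []  # stand-in name; the real accumulator is not shown here
for doc in all_raw_documents:  # all_raw_documents is also a stand-in name
    chunk_id = doc.get("chunk_id")
    content_hash = hashlib.sha256(doc.get("content", "").encode("utf-8")).hexdigest()
    # Skip anything already seen by chunk_id or by identical content
    if chunk_id in seen_chunk_ids or content_hash in seen_content_hashes:
        continue
    if chunk_id is not None:
        seen_chunk_ids.add(chunk_id)
    seen_content_hashes.add(content_hash)
    deduplicated_documents.append(doc)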
@@ -1355,6 +1416,13 @@ async def process_sections(
    )
    print(f"Total documents for sections: {len(all_documents)}")

    # Extract and stream sources from all_documents
    if all_documents:
        sources_to_stream = extract_sources_from_documents(all_documents)
        writer(
            {"yield_value": streaming_service.format_sources_delta(sources_to_stream)}
        )

    writer(
        {
            "yield_value": streaming_service.format_terminal_info_delta(
@@ -1781,6 +1849,13 @@ async def handle_qna_workflow(
        print(f"Added {len(user_selected_documents)} user-selected documents for QNA")
    print(f"Total documents for QNA: {len(all_documents)}")

    # Extract and stream sources from all_documents
    if all_documents:
        sources_to_stream = extract_sources_from_documents(all_documents)
        writer(
            {"yield_value": streaming_service.format_sources_delta(sources_to_stream)}
        )

    writer(
        {
            "yield_value": streaming_service.format_terminal_info_delta(