feat: Stabilized Citation Logic

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2025-06-04 17:19:31 -07:00
parent 7c8b84a46c
commit 52a9ad04bd
3 changed files with 44 additions and 40 deletions

View file

@ -572,12 +572,15 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
TOP_K = 20 TOP_K = 20
elif configuration.num_sections == 6: elif configuration.num_sections == 6:
TOP_K = 30 TOP_K = 30
else:
TOP_K = 10
relevant_documents = [] relevant_documents = []
async with async_session_maker() as db_session: async with async_session_maker() as db_session:
try: try:
# Create connector service inside the db_session scope # Create connector service inside the db_session scope
connector_service = ConnectorService(db_session) connector_service = ConnectorService(db_session, user_id=configuration.user_id)
await connector_service.initialize_counter()
relevant_documents = await fetch_relevant_documents( relevant_documents = await fetch_relevant_documents(
research_questions=all_questions, research_questions=all_questions,
@ -875,7 +878,8 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre
async with async_session_maker() as db_session: async with async_session_maker() as db_session:
try: try:
# Create connector service inside the db_session scope # Create connector service inside the db_session scope
connector_service = ConnectorService(db_session) connector_service = ConnectorService(db_session, user_id=configuration.user_id)
await connector_service.initialize_counter()
# Use the reformulated query as a single research question # Use the reformulated query as a single research question
research_questions = [reformulated_query] research_questions = [reformulated_query]

View file

@ -4,7 +4,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever
from app.db import SearchSourceConnector, SearchSourceConnectorType from app.db import SearchSourceConnector, SearchSourceConnectorType, Chunk, Document
from tavily import TavilyClient from tavily import TavilyClient
from linkup import LinkupClient from linkup import LinkupClient
@ -12,13 +12,36 @@ from app.agents.researcher.configuration import SearchMode
class ConnectorService: class ConnectorService:
def __init__(self, session: AsyncSession): def __init__(self, session: AsyncSession, user_id: str = None):
self.session = session self.session = session
self.chunk_retriever = ChucksHybridSearchRetriever(session) self.chunk_retriever = ChucksHybridSearchRetriever(session)
self.document_retriever = DocumentHybridSearchRetriever(session) self.document_retriever = DocumentHybridSearchRetriever(session)
self.user_id = user_id
self.source_id_counter = 1 self.source_id_counter = 1
self.counter_lock = asyncio.Lock() # Lock to protect counter in multithreaded environments self.counter_lock = asyncio.Lock() # Lock to protect counter in multithreaded environments
async def initialize_counter(self):
"""
Initialize the source_id_counter based on the total number of chunks for the user.
This ensures unique IDs across different sessions.
"""
if self.user_id:
try:
# Count total chunks for documents belonging to this user
from sqlalchemy import func
result = await self.session.execute(
select(func.count(Chunk.id))
.join(Document)
.filter(Document.user_id == self.user_id)
)
chunk_count = result.scalar() or 0
self.source_id_counter = chunk_count + 1
print(f"Initialized source_id_counter to {self.source_id_counter} for user {self.user_id}")
except Exception as e:
print(f"Error initializing source_id_counter: {str(e)}")
# Fallback to default value
self.source_id_counter = 1
async def search_crawled_urls(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple: async def search_crawled_urls(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
""" """
Search for crawled URLs and return both the source information and langchain documents Search for crawled URLs and return both the source information and langchain documents
@ -58,15 +81,13 @@ class ConnectorService:
sources_list = [] sources_list = []
async with self.counter_lock: async with self.counter_lock:
for i, chunk in enumerate(crawled_urls_chunks): for i, chunk in enumerate(crawled_urls_chunks):
# Fix for UI
crawled_urls_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata # Extract document metadata
document = chunk.get('document', {}) document = chunk.get('document', {})
metadata = document.get('metadata', {}) metadata = document.get('metadata', {})
# Create a source entry # Create a source entry
source = { source = {
"id": self.source_id_counter, "id": document.get('id', self.source_id_counter),
"title": document.get('title', 'Untitled Document'), "title": document.get('title', 'Untitled Document'),
"description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])), "description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])),
"url": metadata.get('url', '') "url": metadata.get('url', '')
@ -124,15 +145,13 @@ class ConnectorService:
sources_list = [] sources_list = []
async with self.counter_lock: async with self.counter_lock:
for i, chunk in enumerate(files_chunks): for i, chunk in enumerate(files_chunks):
# Fix for UI
files_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata # Extract document metadata
document = chunk.get('document', {}) document = chunk.get('document', {})
metadata = document.get('metadata', {}) metadata = document.get('metadata', {})
# Create a source entry # Create a source entry
source = { source = {
"id": self.source_id_counter, "id": document.get('id', self.source_id_counter),
"title": document.get('title', 'Untitled Document'), "title": document.get('title', 'Untitled Document'),
"description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])), "description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])),
"url": metadata.get('url', '') "url": metadata.get('url', '')
@ -338,8 +357,6 @@ class ConnectorService:
sources_list = [] sources_list = []
async with self.counter_lock: async with self.counter_lock:
for i, chunk in enumerate(slack_chunks): for i, chunk in enumerate(slack_chunks):
# Fix for UI
slack_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata # Extract document metadata
document = chunk.get('document', {}) document = chunk.get('document', {})
metadata = document.get('metadata', {}) metadata = document.get('metadata', {})
@ -365,7 +382,7 @@ class ConnectorService:
url = f"https://slack.com/app_redirect?channel={channel_id}" url = f"https://slack.com/app_redirect?channel={channel_id}"
source = { source = {
"id": self.source_id_counter, "id": document.get('id', self.source_id_counter),
"title": title, "title": title,
"description": description, "description": description,
"url": url, "url": url,
@ -429,9 +446,6 @@ class ConnectorService:
sources_list = [] sources_list = []
async with self.counter_lock: async with self.counter_lock:
for i, chunk in enumerate(notion_chunks): for i, chunk in enumerate(notion_chunks):
# Fix for UI
notion_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata # Extract document metadata
document = chunk.get('document', {}) document = chunk.get('document', {})
metadata = document.get('metadata', {}) metadata = document.get('metadata', {})
@ -458,7 +472,7 @@ class ConnectorService:
url = f"https://notion.so/{page_id.replace('-', '')}" url = f"https://notion.so/{page_id.replace('-', '')}"
source = { source = {
"id": self.source_id_counter, "id": document.get('id', self.source_id_counter),
"title": title, "title": title,
"description": description, "description": description,
"url": url, "url": url,
@ -522,9 +536,6 @@ class ConnectorService:
sources_list = [] sources_list = []
async with self.counter_lock: async with self.counter_lock:
for i, chunk in enumerate(extension_chunks): for i, chunk in enumerate(extension_chunks):
# Fix for UI
extension_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata # Extract document metadata
document = chunk.get('document', {}) document = chunk.get('document', {})
metadata = document.get('metadata', {}) metadata = document.get('metadata', {})
@ -569,7 +580,7 @@ class ConnectorService:
pass pass
source = { source = {
"id": self.source_id_counter, "id": document.get('id', self.source_id_counter),
"title": title, "title": title,
"description": description, "description": description,
"url": webpage_url "url": webpage_url
@ -633,9 +644,6 @@ class ConnectorService:
sources_list = [] sources_list = []
async with self.counter_lock: async with self.counter_lock:
for i, chunk in enumerate(youtube_chunks): for i, chunk in enumerate(youtube_chunks):
# Fix for UI
youtube_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata # Extract document metadata
document = chunk.get('document', {}) document = chunk.get('document', {})
metadata = document.get('metadata', {}) metadata = document.get('metadata', {})
@ -660,7 +668,7 @@ class ConnectorService:
url = f"https://www.youtube.com/watch?v={video_id}" if video_id else "" url = f"https://www.youtube.com/watch?v={video_id}" if video_id else ""
source = { source = {
"id": self.source_id_counter, "id": document.get('id', self.source_id_counter),
"title": title, "title": title,
"description": description, "description": description,
"url": url, "url": url,
@ -720,16 +728,13 @@ class ConnectorService:
sources_list = [] sources_list = []
async with self.counter_lock: async with self.counter_lock:
for i, chunk in enumerate(github_chunks): for i, chunk in enumerate(github_chunks):
# Fix for UI - assign a unique ID for citation/source tracking
github_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata # Extract document metadata
document = chunk.get('document', {}) document = chunk.get('document', {})
metadata = document.get('metadata', {}) metadata = document.get('metadata', {})
# Create a source entry # Create a source entry
source = { source = {
"id": self.source_id_counter, "id": document.get('id', self.source_id_counter),
"title": document.get('title', 'GitHub Document'), # Use specific title if available "title": document.get('title', 'GitHub Document'), # Use specific title if available
"description": metadata.get('description', chunk.get('content', '')[:100]), # Use description or content preview "description": metadata.get('description', chunk.get('content', '')[:100]), # Use description or content preview
"url": metadata.get('url', '') # Use URL if available in metadata "url": metadata.get('url', '') # Use URL if available in metadata
@ -793,9 +798,6 @@ class ConnectorService:
sources_list = [] sources_list = []
async with self.counter_lock: async with self.counter_lock:
for i, chunk in enumerate(linear_chunks): for i, chunk in enumerate(linear_chunks):
# Fix for UI
linear_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata # Extract document metadata
document = chunk.get('document', {}) document = chunk.get('document', {})
metadata = document.get('metadata', {}) metadata = document.get('metadata', {})
@ -831,7 +833,7 @@ class ConnectorService:
url = f"https://linear.app/issue/{issue_identifier}" url = f"https://linear.app/issue/{issue_identifier}"
source = { source = {
"id": self.source_id_counter, "id": document.get('id', self.source_id_counter),
"title": title, "title": title,
"description": description, "description": description,
"url": url, "url": url,
@ -1004,8 +1006,6 @@ class ConnectorService:
sources_list = [] sources_list = []
async with self.counter_lock: async with self.counter_lock:
for i, chunk in enumerate(discord_chunks): for i, chunk in enumerate(discord_chunks):
# Fix for UI
discord_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata # Extract document metadata
document = chunk.get('document', {}) document = chunk.get('document', {})
metadata = document.get('metadata', {}) metadata = document.get('metadata', {})
@ -1034,7 +1034,7 @@ class ConnectorService:
url = f"https://discord.com/channels/@me/{channel_id}" url = f"https://discord.com/channels/@me/{channel_id}"
source = { source = {
"id": self.source_id_counter, "id": document.get('id', self.source_id_counter),
"title": title, "title": title,
"description": description, "description": description,
"url": url, "url": url,

View file

@ -183,8 +183,8 @@ const SourcesDialogContent = ({
</div> </div>
<div className="space-y-3 mt-4"> <div className="space-y-3 mt-4">
{paginatedSources.map((source: any) => ( {paginatedSources.map((source: any, index: number) => (
<Card key={source.id} className="p-3 hover:bg-gray-50 dark:hover:bg-gray-800 cursor-pointer"> <Card key={`${connector.type}-${source.id}-${index}`} className="p-3 hover:bg-gray-50 dark:hover:bg-gray-800 cursor-pointer">
<div className="flex items-start gap-3"> <div className="flex items-start gap-3">
<div className="flex-shrink-0 w-6 h-6 flex items-center justify-center"> <div className="flex-shrink-0 w-6 h-6 flex items-center justify-center">
{getConnectorIcon(connector.type)} {getConnectorIcon(connector.type)}
@ -845,8 +845,8 @@ const ChatPage = () => {
{messageConnectorSources.map(connector => ( {messageConnectorSources.map(connector => (
<TabsContent key={connector.id} value={connector.type} className="mt-0"> <TabsContent key={connector.id} value={connector.type} className="mt-0">
<div className="space-y-3"> <div className="space-y-3">
{connector.sources?.slice(0, INITIAL_SOURCES_DISPLAY)?.map((source: any) => ( {connector.sources?.slice(0, INITIAL_SOURCES_DISPLAY)?.map((source: any, index: number) => (
<Card key={source.id} className="p-3 hover:bg-gray-50 dark:hover:bg-gray-800 cursor-pointer"> <Card key={`${connector.type}-${source.id}-${index}`} className="p-3 hover:bg-gray-50 dark:hover:bg-gray-800 cursor-pointer">
<div className="flex items-start gap-3"> <div className="flex items-start gap-3">
<div className="flex-shrink-0 w-6 h-6 flex items-center justify-center"> <div className="flex-shrink-0 w-6 h-6 flex items-center justify-center">
{getConnectorIcon(connector.type)} {getConnectorIcon(connector.type)}