Merge pull request #141 from MODSetter/dev

feat: Stabilized Citation Logic
This commit is contained in:
Rohan Verma 2025-06-04 17:43:54 -07:00 committed by GitHub
commit e8a19c496b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 54 additions and 49 deletions

View file

@@ -572,12 +572,15 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
TOP_K = 20 TOP_K = 20
elif configuration.num_sections == 6: elif configuration.num_sections == 6:
TOP_K = 30 TOP_K = 30
else:
TOP_K = 10
relevant_documents = [] relevant_documents = []
async with async_session_maker() as db_session: async with async_session_maker() as db_session:
try: try:
# Create connector service inside the db_session scope # Create connector service inside the db_session scope
connector_service = ConnectorService(db_session) connector_service = ConnectorService(db_session, user_id=configuration.user_id)
await connector_service.initialize_counter()
relevant_documents = await fetch_relevant_documents( relevant_documents = await fetch_relevant_documents(
research_questions=all_questions, research_questions=all_questions,
@@ -875,7 +878,8 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre
async with async_session_maker() as db_session: async with async_session_maker() as db_session:
try: try:
# Create connector service inside the db_session scope # Create connector service inside the db_session scope
connector_service = ConnectorService(db_session) connector_service = ConnectorService(db_session, user_id=configuration.user_id)
await connector_service.initialize_counter()
# Use the reformulated query as a single research question # Use the reformulated query as a single research question
research_questions = [reformulated_query] research_questions = [reformulated_query]

View file

@@ -4,21 +4,45 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever
from app.db import SearchSourceConnector, SearchSourceConnectorType from app.db import SearchSourceConnector, SearchSourceConnectorType, Chunk, Document
from tavily import TavilyClient from tavily import TavilyClient
from linkup import LinkupClient from linkup import LinkupClient
from sqlalchemy import func
from app.agents.researcher.configuration import SearchMode from app.agents.researcher.configuration import SearchMode
class ConnectorService: class ConnectorService:
def __init__(self, session: AsyncSession): def __init__(self, session: AsyncSession, user_id: str = None):
self.session = session self.session = session
self.chunk_retriever = ChucksHybridSearchRetriever(session) self.chunk_retriever = ChucksHybridSearchRetriever(session)
self.document_retriever = DocumentHybridSearchRetriever(session) self.document_retriever = DocumentHybridSearchRetriever(session)
self.source_id_counter = 1 self.user_id = user_id
self.source_id_counter = 100000 # High starting value to avoid collisions with existing IDs
self.counter_lock = asyncio.Lock() # Lock to protect counter in multithreaded environments self.counter_lock = asyncio.Lock() # Lock to protect counter in multithreaded environments
async def initialize_counter(self):
"""
Initialize the source_id_counter based on the total number of chunks for the user.
This ensures unique IDs across different sessions.
"""
if self.user_id:
try:
# Count total chunks for documents belonging to this user
result = await self.session.execute(
select(func.count(Chunk.id))
.join(Document)
.filter(Document.user_id == self.user_id)
)
chunk_count = result.scalar() or 0
self.source_id_counter = chunk_count + 1
print(f"Initialized source_id_counter to {self.source_id_counter} for user {self.user_id}")
except Exception as e:
print(f"Error initializing source_id_counter: {str(e)}")
# Fallback to default value
self.source_id_counter = 1
async def search_crawled_urls(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple: async def search_crawled_urls(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
""" """
Search for crawled URLs and return both the source information and langchain documents Search for crawled URLs and return both the source information and langchain documents
@@ -57,16 +81,14 @@ class ConnectorService:
# Process each chunk and create sources directly without deduplication # Process each chunk and create sources directly without deduplication
sources_list = [] sources_list = []
async with self.counter_lock: async with self.counter_lock:
for i, chunk in enumerate(crawled_urls_chunks): for _i, chunk in enumerate(crawled_urls_chunks):
# Fix for UI
crawled_urls_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata # Extract document metadata
document = chunk.get('document', {}) document = chunk.get('document', {})
metadata = document.get('metadata', {}) metadata = document.get('metadata', {})
# Create a source entry # Create a source entry
source = { source = {
"id": self.source_id_counter, "id": document.get('id', self.source_id_counter),
"title": document.get('title', 'Untitled Document'), "title": document.get('title', 'Untitled Document'),
"description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])), "description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])),
"url": metadata.get('url', '') "url": metadata.get('url', '')
@@ -123,16 +145,14 @@ class ConnectorService:
# Process each chunk and create sources directly without deduplication # Process each chunk and create sources directly without deduplication
sources_list = [] sources_list = []
async with self.counter_lock: async with self.counter_lock:
for i, chunk in enumerate(files_chunks): for _i, chunk in enumerate(files_chunks):
# Fix for UI
files_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata # Extract document metadata
document = chunk.get('document', {}) document = chunk.get('document', {})
metadata = document.get('metadata', {}) metadata = document.get('metadata', {})
# Create a source entry # Create a source entry
source = { source = {
"id": self.source_id_counter, "id": document.get('id', self.source_id_counter),
"title": document.get('title', 'Untitled Document'), "title": document.get('title', 'Untitled Document'),
"description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])), "description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])),
"url": metadata.get('url', '') "url": metadata.get('url', '')
@@ -337,9 +357,7 @@ class ConnectorService:
# Process each chunk and create sources directly without deduplication # Process each chunk and create sources directly without deduplication
sources_list = [] sources_list = []
async with self.counter_lock: async with self.counter_lock:
for i, chunk in enumerate(slack_chunks): for _i, chunk in enumerate(slack_chunks):
# Fix for UI
slack_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata # Extract document metadata
document = chunk.get('document', {}) document = chunk.get('document', {})
metadata = document.get('metadata', {}) metadata = document.get('metadata', {})
@@ -365,7 +383,7 @@ class ConnectorService:
url = f"https://slack.com/app_redirect?channel={channel_id}" url = f"https://slack.com/app_redirect?channel={channel_id}"
source = { source = {
"id": self.source_id_counter, "id": document.get('id', self.source_id_counter),
"title": title, "title": title,
"description": description, "description": description,
"url": url, "url": url,
@@ -428,10 +446,7 @@ class ConnectorService:
# Process each chunk and create sources directly without deduplication # Process each chunk and create sources directly without deduplication
sources_list = [] sources_list = []
async with self.counter_lock: async with self.counter_lock:
for i, chunk in enumerate(notion_chunks): for _i, chunk in enumerate(notion_chunks):
# Fix for UI
notion_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata # Extract document metadata
document = chunk.get('document', {}) document = chunk.get('document', {})
metadata = document.get('metadata', {}) metadata = document.get('metadata', {})
@@ -458,7 +473,7 @@ class ConnectorService:
url = f"https://notion.so/{page_id.replace('-', '')}" url = f"https://notion.so/{page_id.replace('-', '')}"
source = { source = {
"id": self.source_id_counter, "id": document.get('id', self.source_id_counter),
"title": title, "title": title,
"description": description, "description": description,
"url": url, "url": url,
@@ -522,9 +537,6 @@ class ConnectorService:
sources_list = [] sources_list = []
async with self.counter_lock: async with self.counter_lock:
for i, chunk in enumerate(extension_chunks): for i, chunk in enumerate(extension_chunks):
# Fix for UI
extension_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata # Extract document metadata
document = chunk.get('document', {}) document = chunk.get('document', {})
metadata = document.get('metadata', {}) metadata = document.get('metadata', {})
@@ -569,7 +581,7 @@ class ConnectorService:
pass pass
source = { source = {
"id": self.source_id_counter, "id": document.get('id', self.source_id_counter),
"title": title, "title": title,
"description": description, "description": description,
"url": webpage_url "url": webpage_url
@@ -632,10 +644,7 @@ class ConnectorService:
# Process each chunk and create sources directly without deduplication # Process each chunk and create sources directly without deduplication
sources_list = [] sources_list = []
async with self.counter_lock: async with self.counter_lock:
for i, chunk in enumerate(youtube_chunks): for _i, chunk in enumerate(youtube_chunks):
# Fix for UI
youtube_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata # Extract document metadata
document = chunk.get('document', {}) document = chunk.get('document', {})
metadata = document.get('metadata', {}) metadata = document.get('metadata', {})
@@ -644,7 +653,7 @@ class ConnectorService:
video_title = metadata.get('video_title', 'Untitled Video') video_title = metadata.get('video_title', 'Untitled Video')
video_id = metadata.get('video_id', '') video_id = metadata.get('video_id', '')
channel_name = metadata.get('channel_name', '') channel_name = metadata.get('channel_name', '')
published_date = metadata.get('published_date', '') # published_date = metadata.get('published_date', '')
# Create a more descriptive title for YouTube videos # Create a more descriptive title for YouTube videos
title = video_title title = video_title
@@ -660,7 +669,7 @@ class ConnectorService:
url = f"https://www.youtube.com/watch?v={video_id}" if video_id else "" url = f"https://www.youtube.com/watch?v={video_id}" if video_id else ""
source = { source = {
"id": self.source_id_counter, "id": document.get('id', self.source_id_counter),
"title": title, "title": title,
"description": description, "description": description,
"url": url, "url": url,
@@ -719,17 +728,14 @@ class ConnectorService:
# Process each chunk and create sources directly without deduplication # Process each chunk and create sources directly without deduplication
sources_list = [] sources_list = []
async with self.counter_lock: async with self.counter_lock:
for i, chunk in enumerate(github_chunks): for _i, chunk in enumerate(github_chunks):
# Fix for UI - assign a unique ID for citation/source tracking
github_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata # Extract document metadata
document = chunk.get('document', {}) document = chunk.get('document', {})
metadata = document.get('metadata', {}) metadata = document.get('metadata', {})
# Create a source entry # Create a source entry
source = { source = {
"id": self.source_id_counter, "id": document.get('id', self.source_id_counter),
"title": document.get('title', 'GitHub Document'), # Use specific title if available "title": document.get('title', 'GitHub Document'), # Use specific title if available
"description": metadata.get('description', chunk.get('content', '')[:100]), # Use description or content preview "description": metadata.get('description', chunk.get('content', '')[:100]), # Use description or content preview
"url": metadata.get('url', '') # Use URL if available in metadata "url": metadata.get('url', '') # Use URL if available in metadata
@@ -792,10 +798,7 @@ class ConnectorService:
# Process each chunk and create sources directly without deduplication # Process each chunk and create sources directly without deduplication
sources_list = [] sources_list = []
async with self.counter_lock: async with self.counter_lock:
for i, chunk in enumerate(linear_chunks): for _i, chunk in enumerate(linear_chunks):
# Fix for UI
linear_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata # Extract document metadata
document = chunk.get('document', {}) document = chunk.get('document', {})
metadata = document.get('metadata', {}) metadata = document.get('metadata', {})
@@ -831,7 +834,7 @@ class ConnectorService:
url = f"https://linear.app/issue/{issue_identifier}" url = f"https://linear.app/issue/{issue_identifier}"
source = { source = {
"id": self.source_id_counter, "id": document.get('id', self.source_id_counter),
"title": title, "title": title,
"description": description, "description": description,
"url": url, "url": url,
@@ -1004,8 +1007,6 @@ class ConnectorService:
sources_list = [] sources_list = []
async with self.counter_lock: async with self.counter_lock:
for i, chunk in enumerate(discord_chunks): for i, chunk in enumerate(discord_chunks):
# Fix for UI
discord_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata # Extract document metadata
document = chunk.get('document', {}) document = chunk.get('document', {})
metadata = document.get('metadata', {}) metadata = document.get('metadata', {})
@@ -1034,7 +1035,7 @@ class ConnectorService:
url = f"https://discord.com/channels/@me/{channel_id}" url = f"https://discord.com/channels/@me/{channel_id}"
source = { source = {
"id": self.source_id_counter, "id": document.get('id', self.source_id_counter),
"title": title, "title": title,
"description": description, "description": description,
"url": url, "url": url,

View file

@@ -183,8 +183,8 @@ const SourcesDialogContent = ({
</div> </div>
<div className="space-y-3 mt-4"> <div className="space-y-3 mt-4">
{paginatedSources.map((source: any) => ( {paginatedSources.map((source: any, index: number) => (
<Card key={source.id} className="p-3 hover:bg-gray-50 dark:hover:bg-gray-800 cursor-pointer"> <Card key={`${connector.type}-${source.id}-${index}`} className="p-3 hover:bg-gray-50 dark:hover:bg-gray-800 cursor-pointer">
<div className="flex items-start gap-3"> <div className="flex items-start gap-3">
<div className="flex-shrink-0 w-6 h-6 flex items-center justify-center"> <div className="flex-shrink-0 w-6 h-6 flex items-center justify-center">
{getConnectorIcon(connector.type)} {getConnectorIcon(connector.type)}
@@ -845,8 +845,8 @@ const ChatPage = () => {
{messageConnectorSources.map(connector => ( {messageConnectorSources.map(connector => (
<TabsContent key={connector.id} value={connector.type} className="mt-0"> <TabsContent key={connector.id} value={connector.type} className="mt-0">
<div className="space-y-3"> <div className="space-y-3">
{connector.sources?.slice(0, INITIAL_SOURCES_DISPLAY)?.map((source: any) => ( {connector.sources?.slice(0, INITIAL_SOURCES_DISPLAY)?.map((source: any, index: number) => (
<Card key={source.id} className="p-3 hover:bg-gray-50 dark:hover:bg-gray-800 cursor-pointer"> <Card key={`${connector.type}-${source.id}-${index}`} className="p-3 hover:bg-gray-50 dark:hover:bg-gray-800 cursor-pointer">
<div className="flex items-start gap-3"> <div className="flex items-start gap-3">
<div className="flex-shrink-0 w-6 h-6 flex items-center justify-center"> <div className="flex-shrink-0 w-6 h-6 flex items-center justify-center">
{getConnectorIcon(connector.type)} {getConnectorIcon(connector.type)}