feat: Stabilized Citation Logic

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2025-06-04 17:19:31 -07:00
parent 7c8b84a46c
commit 52a9ad04bd
3 changed files with 44 additions and 40 deletions

View file

@@ -572,12 +572,15 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
TOP_K = 20
elif configuration.num_sections == 6:
TOP_K = 30
else:
TOP_K = 10
relevant_documents = []
async with async_session_maker() as db_session:
try:
# Create connector service inside the db_session scope
connector_service = ConnectorService(db_session)
connector_service = ConnectorService(db_session, user_id=configuration.user_id)
await connector_service.initialize_counter()
relevant_documents = await fetch_relevant_documents(
research_questions=all_questions,
@@ -875,7 +878,8 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre
async with async_session_maker() as db_session:
try:
# Create connector service inside the db_session scope
connector_service = ConnectorService(db_session)
connector_service = ConnectorService(db_session, user_id=configuration.user_id)
await connector_service.initialize_counter()
# Use the reformulated query as a single research question
research_questions = [reformulated_query]

View file

@@ -4,7 +4,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever
from app.db import SearchSourceConnector, SearchSourceConnectorType
from app.db import SearchSourceConnector, SearchSourceConnectorType, Chunk, Document
from tavily import TavilyClient
from linkup import LinkupClient
@@ -12,13 +12,36 @@ from app.agents.researcher.configuration import SearchMode
class ConnectorService:
def __init__(self, session: AsyncSession):
def __init__(self, session: AsyncSession, user_id: str = None):
self.session = session
self.chunk_retriever = ChucksHybridSearchRetriever(session)
self.document_retriever = DocumentHybridSearchRetriever(session)
self.user_id = user_id
self.source_id_counter = 1
self.counter_lock = asyncio.Lock() # Lock to protect counter in multithreaded environments
async def initialize_counter(self):
    """
    Initialize ``source_id_counter`` from the user's existing chunk total.

    Counts the chunks attached to documents owned by ``self.user_id`` and
    starts the counter one past that total, so citation/source IDs stay
    unique across sessions. If no ``user_id`` was supplied the counter is
    left at the default set in ``__init__``. A database failure is
    best-effort: it logs and falls back to 1 rather than raising.
    """
    if not self.user_id:
        return

    # Hoisted out of the try block so an ImportError is not silently
    # swallowed by the broad best-effort handler below.
    from sqlalchemy import func

    try:
        # Count total chunks for documents belonging to this user.
        result = await self.session.execute(
            select(func.count(Chunk.id))
            .join(Document)
            .filter(Document.user_id == self.user_id)
        )
        chunk_count = result.scalar() or 0
        self.source_id_counter = chunk_count + 1
        print(f"Initialized source_id_counter to {self.source_id_counter} for user {self.user_id}")
    except Exception as e:
        # Fallback to default value — a failed count must not break the
        # request; IDs may then collide across sessions, matching the
        # pre-initialization behavior.
        print(f"Error initializing source_id_counter: {str(e)}")
        self.source_id_counter = 1
async def search_crawled_urls(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
"""
Search for crawled URLs and return both the source information and langchain documents
@@ -58,15 +81,13 @@ class ConnectorService:
sources_list = []
async with self.counter_lock:
for i, chunk in enumerate(crawled_urls_chunks):
# Fix for UI
crawled_urls_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata
document = chunk.get('document', {})
metadata = document.get('metadata', {})
# Create a source entry
source = {
"id": self.source_id_counter,
"id": document.get('id', self.source_id_counter),
"title": document.get('title', 'Untitled Document'),
"description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])),
"url": metadata.get('url', '')
@@ -124,15 +145,13 @@ class ConnectorService:
sources_list = []
async with self.counter_lock:
for i, chunk in enumerate(files_chunks):
# Fix for UI
files_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata
document = chunk.get('document', {})
metadata = document.get('metadata', {})
# Create a source entry
source = {
"id": self.source_id_counter,
"id": document.get('id', self.source_id_counter),
"title": document.get('title', 'Untitled Document'),
"description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])),
"url": metadata.get('url', '')
@@ -338,8 +357,6 @@ class ConnectorService:
sources_list = []
async with self.counter_lock:
for i, chunk in enumerate(slack_chunks):
# Fix for UI
slack_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata
document = chunk.get('document', {})
metadata = document.get('metadata', {})
@@ -365,7 +382,7 @@ class ConnectorService:
url = f"https://slack.com/app_redirect?channel={channel_id}"
source = {
"id": self.source_id_counter,
"id": document.get('id', self.source_id_counter),
"title": title,
"description": description,
"url": url,
@@ -429,9 +446,6 @@ class ConnectorService:
sources_list = []
async with self.counter_lock:
for i, chunk in enumerate(notion_chunks):
# Fix for UI
notion_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata
document = chunk.get('document', {})
metadata = document.get('metadata', {})
@@ -458,7 +472,7 @@ class ConnectorService:
url = f"https://notion.so/{page_id.replace('-', '')}"
source = {
"id": self.source_id_counter,
"id": document.get('id', self.source_id_counter),
"title": title,
"description": description,
"url": url,
@@ -522,9 +536,6 @@ class ConnectorService:
sources_list = []
async with self.counter_lock:
for i, chunk in enumerate(extension_chunks):
# Fix for UI
extension_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata
document = chunk.get('document', {})
metadata = document.get('metadata', {})
@@ -569,7 +580,7 @@ class ConnectorService:
pass
source = {
"id": self.source_id_counter,
"id": document.get('id', self.source_id_counter),
"title": title,
"description": description,
"url": webpage_url
@@ -633,9 +644,6 @@ class ConnectorService:
sources_list = []
async with self.counter_lock:
for i, chunk in enumerate(youtube_chunks):
# Fix for UI
youtube_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata
document = chunk.get('document', {})
metadata = document.get('metadata', {})
@@ -660,7 +668,7 @@ class ConnectorService:
url = f"https://www.youtube.com/watch?v={video_id}" if video_id else ""
source = {
"id": self.source_id_counter,
"id": document.get('id', self.source_id_counter),
"title": title,
"description": description,
"url": url,
@@ -720,16 +728,13 @@ class ConnectorService:
sources_list = []
async with self.counter_lock:
for i, chunk in enumerate(github_chunks):
# Fix for UI - assign a unique ID for citation/source tracking
github_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata
document = chunk.get('document', {})
metadata = document.get('metadata', {})
# Create a source entry
source = {
"id": self.source_id_counter,
"id": document.get('id', self.source_id_counter),
"title": document.get('title', 'GitHub Document'), # Use specific title if available
"description": metadata.get('description', chunk.get('content', '')[:100]), # Use description or content preview
"url": metadata.get('url', '') # Use URL if available in metadata
@@ -793,9 +798,6 @@ class ConnectorService:
sources_list = []
async with self.counter_lock:
for i, chunk in enumerate(linear_chunks):
# Fix for UI
linear_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata
document = chunk.get('document', {})
metadata = document.get('metadata', {})
@@ -831,7 +833,7 @@ class ConnectorService:
url = f"https://linear.app/issue/{issue_identifier}"
source = {
"id": self.source_id_counter,
"id": document.get('id', self.source_id_counter),
"title": title,
"description": description,
"url": url,
@@ -1004,8 +1006,6 @@ class ConnectorService:
sources_list = []
async with self.counter_lock:
for i, chunk in enumerate(discord_chunks):
# Fix for UI
discord_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata
document = chunk.get('document', {})
metadata = document.get('metadata', {})
@@ -1034,7 +1034,7 @@ class ConnectorService:
url = f"https://discord.com/channels/@me/{channel_id}"
source = {
"id": self.source_id_counter,
"id": document.get('id', self.source_id_counter),
"title": title,
"description": description,
"url": url,

View file

@@ -183,8 +183,8 @@ const SourcesDialogContent = ({
</div>
<div className="space-y-3 mt-4">
{paginatedSources.map((source: any) => (
<Card key={source.id} className="p-3 hover:bg-gray-50 dark:hover:bg-gray-800 cursor-pointer">
{paginatedSources.map((source: any, index: number) => (
<Card key={`${connector.type}-${source.id}-${index}`} className="p-3 hover:bg-gray-50 dark:hover:bg-gray-800 cursor-pointer">
<div className="flex items-start gap-3">
<div className="flex-shrink-0 w-6 h-6 flex items-center justify-center">
{getConnectorIcon(connector.type)}
@@ -845,8 +845,8 @@ const ChatPage = () => {
{messageConnectorSources.map(connector => (
<TabsContent key={connector.id} value={connector.type} className="mt-0">
<div className="space-y-3">
{connector.sources?.slice(0, INITIAL_SOURCES_DISPLAY)?.map((source: any) => (
<Card key={source.id} className="p-3 hover:bg-gray-50 dark:hover:bg-gray-800 cursor-pointer">
{connector.sources?.slice(0, INITIAL_SOURCES_DISPLAY)?.map((source: any, index: number) => (
<Card key={`${connector.type}-${source.id}-${index}`} className="p-3 hover:bg-gray-50 dark:hover:bg-gray-800 cursor-pointer">
<div className="flex items-start gap-3">
<div className="flex-shrink-0 w-6 h-6 flex items-center justify-center">
{getConnectorIcon(connector.type)}