Merge pull request #141 from MODSetter/dev

feat: Stabilized Citation Logic
Rohan Verma 2025-06-04 17:43:54 -07:00 committed by GitHub
commit e8a19c496b
3 changed files with 54 additions and 49 deletions


@@ -572,12 +572,15 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
TOP_K = 20
elif configuration.num_sections == 6:
TOP_K = 30
else:
TOP_K = 10
relevant_documents = []
async with async_session_maker() as db_session:
try:
# Create connector service inside the db_session scope
- connector_service = ConnectorService(db_session)
+ connector_service = ConnectorService(db_session, user_id=configuration.user_id)
+ await connector_service.initialize_counter()
relevant_documents = await fetch_relevant_documents(
research_questions=all_questions,
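
Both hunks in this file make the same change: the connector service is now constructed with the requesting user's id, and its citation counter is seeded before any retrieval runs. A minimal sketch of the new call pattern, using only names that appear in the hunks (the remaining fetch_relevant_documents arguments are elided in the diff and stay elided here):

    async with async_session_maker() as db_session:
        # user_id lets the service derive a per-user starting id for sources
        connector_service = ConnectorService(db_session, user_id=configuration.user_id)
        # Must run before any search_* call so synthetic source ids start
        # past the user's existing chunk count
        await connector_service.initialize_counter()
        relevant_documents = await fetch_relevant_documents(
            research_questions=all_questions,
            # ...remaining arguments elided in the diff
        )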
@@ -875,7 +878,8 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre
async with async_session_maker() as db_session:
try:
# Create connector service inside the db_session scope
- connector_service = ConnectorService(db_session)
+ connector_service = ConnectorService(db_session, user_id=configuration.user_id)
+ await connector_service.initialize_counter()
# Use the reformulated query as a single research question
research_questions = [reformulated_query]


@@ -4,21 +4,45 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever
- from app.db import SearchSourceConnector, SearchSourceConnectorType
+ from app.db import SearchSourceConnector, SearchSourceConnectorType, Chunk, Document
from tavily import TavilyClient
from linkup import LinkupClient
+ from sqlalchemy import func
from app.agents.researcher.configuration import SearchMode
class ConnectorService:
- def __init__(self, session: AsyncSession):
+ def __init__(self, session: AsyncSession, user_id: str = None):
self.session = session
self.chunk_retriever = ChucksHybridSearchRetriever(session)
self.document_retriever = DocumentHybridSearchRetriever(session)
- self.source_id_counter = 1
+ self.user_id = user_id
+ self.source_id_counter = 100000  # High starting value to avoid collisions with existing IDs
self.counter_lock = asyncio.Lock() # Lock to protect counter in multithreaded environments
+ async def initialize_counter(self):
+     """
+     Initialize the source_id_counter based on the total number of chunks for the user.
+     This ensures unique IDs across different sessions.
+     """
+     if self.user_id:
+         try:
+             # Count total chunks for documents belonging to this user
+             result = await self.session.execute(
+                 select(func.count(Chunk.id))
+                 .join(Document)
+                 .filter(Document.user_id == self.user_id)
+             )
+             chunk_count = result.scalar() or 0
+             self.source_id_counter = chunk_count + 1
+             print(f"Initialized source_id_counter to {self.source_id_counter} for user {self.user_id}")
+         except Exception as e:
+             print(f"Error initializing source_id_counter: {str(e)}")
+             # Fallback to default value
+             self.source_id_counter = 1
async def search_crawled_urls(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
"""
Search for crawled URLs and return both the source information and langchain documents
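
The id scheme above is two-staged: the constructor defaults the counter to 100000 so a service created without a user id cannot collide with real document ids, and initialize_counter replaces that default with chunk_count + 1 once the user's chunks have been counted. A rough illustration of the resulting states, assuming a user with three existing chunks (the count is invented for the example):

    svc = ConnectorService(db_session)            # no user_id passed
    # svc.source_id_counter -> 100000 (collision-avoiding default)

    svc = ConnectorService(db_session, user_id="user-123")
    await svc.initialize_counter()
    # svc.source_id_counter -> 4 (three existing chunks, plus one);
    # if the count query fails, the except branch resets it to 1

Note that the error fallback returns the counter to 1, the pre-change default, so the 100000 guard only applies when initialize_counter is never called or the query succeeds.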
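
For completeness, a hypothetical call to the search API whose signature starts above; the keyword names come from the diff, while the literal values are invented:

    sources, documents = await connector_service.search_crawled_urls(
        user_query="stabilized citation logic",  # invented example query
        user_id=configuration.user_id,
        search_space_id=1,                        # invented example id
        top_k=20,
        search_mode=SearchMode.CHUNKS,
    )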
@@ -57,16 +81,14 @@ class ConnectorService:
# Process each chunk and create sources directly without deduplication
sources_list = []
async with self.counter_lock:
- for i, chunk in enumerate(crawled_urls_chunks):
-     # Fix for UI
-     crawled_urls_chunks[i]['document']['id'] = self.source_id_counter
+ for _i, chunk in enumerate(crawled_urls_chunks):
# Extract document metadata
document = chunk.get('document', {})
metadata = document.get('metadata', {})
# Create a source entry
source = {
"id": self.source_id_counter,
"id": document.get('id', self.source_id_counter),
"title": document.get('title', 'Untitled Document'),
"description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])),
"url": metadata.get('url', '')
@@ -123,16 +145,14 @@ class ConnectorService:
# Process each chunk and create sources directly without deduplication
sources_list = []
async with self.counter_lock:
- for i, chunk in enumerate(files_chunks):
-     # Fix for UI
-     files_chunks[i]['document']['id'] = self.source_id_counter
+ for _i, chunk in enumerate(files_chunks):
# Extract document metadata
document = chunk.get('document', {})
metadata = document.get('metadata', {})
# Create a source entry
source = {
"id": self.source_id_counter,
"id": document.get('id', self.source_id_counter),
"title": document.get('title', 'Untitled Document'),
"description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])),
"url": metadata.get('url', '')
@@ -337,9 +357,7 @@ class ConnectorService:
# Process each chunk and create sources directly without deduplication
sources_list = []
async with self.counter_lock:
- for i, chunk in enumerate(slack_chunks):
-     # Fix for UI
-     slack_chunks[i]['document']['id'] = self.source_id_counter
+ for _i, chunk in enumerate(slack_chunks):
# Extract document metadata
document = chunk.get('document', {})
metadata = document.get('metadata', {})
@@ -365,7 +383,7 @@ class ConnectorService:
url = f"https://slack.com/app_redirect?channel={channel_id}"
source = {
"id": self.source_id_counter,
"id": document.get('id', self.source_id_counter),
"title": title,
"description": description,
"url": url,
@@ -428,10 +446,7 @@ class ConnectorService:
# Process each chunk and create sources directly without deduplication
sources_list = []
async with self.counter_lock:
- for i, chunk in enumerate(notion_chunks):
-     # Fix for UI
-     notion_chunks[i]['document']['id'] = self.source_id_counter
+ for _i, chunk in enumerate(notion_chunks):
# Extract document metadata
document = chunk.get('document', {})
metadata = document.get('metadata', {})
@@ -458,7 +473,7 @@ class ConnectorService:
url = f"https://notion.so/{page_id.replace('-', '')}"
source = {
"id": self.source_id_counter,
"id": document.get('id', self.source_id_counter),
"title": title,
"description": description,
"url": url,
@@ -522,9 +537,6 @@ class ConnectorService:
sources_list = []
async with self.counter_lock:
for i, chunk in enumerate(extension_chunks):
-     # Fix for UI
-     extension_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata
document = chunk.get('document', {})
metadata = document.get('metadata', {})
@@ -569,7 +581,7 @@ class ConnectorService:
pass
source = {
"id": self.source_id_counter,
"id": document.get('id', self.source_id_counter),
"title": title,
"description": description,
"url": webpage_url
@@ -632,10 +644,7 @@ class ConnectorService:
# Process each chunk and create sources directly without deduplication
sources_list = []
async with self.counter_lock:
- for i, chunk in enumerate(youtube_chunks):
-     # Fix for UI
-     youtube_chunks[i]['document']['id'] = self.source_id_counter
+ for _i, chunk in enumerate(youtube_chunks):
# Extract document metadata
document = chunk.get('document', {})
metadata = document.get('metadata', {})
@@ -644,7 +653,7 @@ class ConnectorService:
video_title = metadata.get('video_title', 'Untitled Video')
video_id = metadata.get('video_id', '')
channel_name = metadata.get('channel_name', '')
- published_date = metadata.get('published_date', '')
+ # published_date = metadata.get('published_date', '')
# Create a more descriptive title for YouTube videos
title = video_title
@@ -660,7 +669,7 @@ class ConnectorService:
url = f"https://www.youtube.com/watch?v={video_id}" if video_id else ""
source = {
"id": self.source_id_counter,
"id": document.get('id', self.source_id_counter),
"title": title,
"description": description,
"url": url,
@@ -719,17 +728,14 @@ class ConnectorService:
# Process each chunk and create sources directly without deduplication
sources_list = []
async with self.counter_lock:
- for i, chunk in enumerate(github_chunks):
-     # Fix for UI - assign a unique ID for citation/source tracking
-     github_chunks[i]['document']['id'] = self.source_id_counter
+ for _i, chunk in enumerate(github_chunks):
# Extract document metadata
document = chunk.get('document', {})
metadata = document.get('metadata', {})
# Create a source entry
source = {
"id": self.source_id_counter,
"id": document.get('id', self.source_id_counter),
"title": document.get('title', 'GitHub Document'), # Use specific title if available
"description": metadata.get('description', chunk.get('content', '')[:100]), # Use description or content preview
"url": metadata.get('url', '') # Use URL if available in metadata
@@ -792,10 +798,7 @@ class ConnectorService:
# Process each chunk and create sources directly without deduplication
sources_list = []
async with self.counter_lock:
- for i, chunk in enumerate(linear_chunks):
-     # Fix for UI
-     linear_chunks[i]['document']['id'] = self.source_id_counter
+ for _i, chunk in enumerate(linear_chunks):
# Extract document metadata
document = chunk.get('document', {})
metadata = document.get('metadata', {})
@@ -831,7 +834,7 @@ class ConnectorService:
url = f"https://linear.app/issue/{issue_identifier}"
source = {
"id": self.source_id_counter,
"id": document.get('id', self.source_id_counter),
"title": title,
"description": description,
"url": url,
@@ -1004,8 +1007,6 @@ class ConnectorService:
sources_list = []
async with self.counter_lock:
for i, chunk in enumerate(discord_chunks):
-     # Fix for UI
-     discord_chunks[i]['document']['id'] = self.source_id_counter
# Extract document metadata
document = chunk.get('document', {})
metadata = document.get('metadata', {})
@@ -1034,7 +1035,7 @@ class ConnectorService:
url = f"https://discord.com/channels/@me/{channel_id}"
source = {
"id": self.source_id_counter,
"id": document.get('id', self.source_id_counter),
"title": title,
"description": description,
"url": url,


@@ -183,8 +183,8 @@ const SourcesDialogContent = ({
</div>
<div className="space-y-3 mt-4">
- {paginatedSources.map((source: any) => (
-     <Card key={source.id} className="p-3 hover:bg-gray-50 dark:hover:bg-gray-800 cursor-pointer">
+ {paginatedSources.map((source: any, index: number) => (
+     <Card key={`${connector.type}-${source.id}-${index}`} className="p-3 hover:bg-gray-50 dark:hover:bg-gray-800 cursor-pointer">
<div className="flex items-start gap-3">
<div className="flex-shrink-0 w-6 h-6 flex items-center justify-center">
{getConnectorIcon(connector.type)}
@@ -845,8 +845,8 @@ const ChatPage = () => {
{messageConnectorSources.map(connector => (
<TabsContent key={connector.id} value={connector.type} className="mt-0">
<div className="space-y-3">
- {connector.sources?.slice(0, INITIAL_SOURCES_DISPLAY)?.map((source: any) => (
-     <Card key={source.id} className="p-3 hover:bg-gray-50 dark:hover:bg-gray-800 cursor-pointer">
+ {connector.sources?.slice(0, INITIAL_SOURCES_DISPLAY)?.map((source: any, index: number) => (
+     <Card key={`${connector.type}-${source.id}-${index}`} className="p-3 hover:bg-gray-50 dark:hover:bg-gray-800 cursor-pointer">
<div className="flex items-start gap-3">
<div className="flex-shrink-0 w-6 h-6 flex items-center justify-center">
{getConnectorIcon(connector.type)}
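
Frontend note: both render paths previously keyed cards on source.id alone, which collides once ids can repeat across connectors or fall back to the shared counter. Compounding the connector type, the source id, and the list index guarantees key uniqueness within each list; the index component ties identity to position, which appears safe here because these lists are rendered from fetched snapshots rather than reordered in place.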