From 52a9ad04bdb43c25927316186386afff04e28305 Mon Sep 17 00:00:00 2001
From: "DESKTOP-RTLN3BA\\$punk"
Date: Wed, 4 Jun 2025 17:19:31 -0700
Subject: [PATCH 1/3] feat: Stabilized Citation Logic

---
 .../app/agents/researcher/nodes.py            |  8 ++-
 .../app/utils/connector_service.py            | 68 +++++++++----------
 .../researcher/[chat_id]/page.tsx             |  8 +--
 3 files changed, 44 insertions(+), 40 deletions(-)

diff --git a/surfsense_backend/app/agents/researcher/nodes.py b/surfsense_backend/app/agents/researcher/nodes.py
index 2d4584f..e73aefa 100644
--- a/surfsense_backend/app/agents/researcher/nodes.py
+++ b/surfsense_backend/app/agents/researcher/nodes.py
@@ -572,12 +572,15 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
         TOP_K = 20
     elif configuration.num_sections == 6:
         TOP_K = 30
+    else:
+        TOP_K = 10
 
     relevant_documents = []
     async with async_session_maker() as db_session:
         try:
             # Create connector service inside the db_session scope
-            connector_service = ConnectorService(db_session)
+            connector_service = ConnectorService(db_session, user_id=configuration.user_id)
+            await connector_service.initialize_counter()
 
             relevant_documents = await fetch_relevant_documents(
                 research_questions=all_questions,
@@ -875,7 +878,8 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre
     async with async_session_maker() as db_session:
         try:
             # Create connector service inside the db_session scope
-            connector_service = ConnectorService(db_session)
+            connector_service = ConnectorService(db_session, user_id=configuration.user_id)
+            await connector_service.initialize_counter()
 
             # Use the reformulated query as a single research question
             research_questions = [reformulated_query]
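The new `else` branch closes a real gap: with only the `if`/`elif` arms, an unexpected `num_sections` value would have left `TOP_K` unbound and raised a `NameError` at first use. A minimal sketch of the same mapping as a dict lookup; only the `== 6` arm and the values 20/30/10 are visible in this hunk, so the other key is an assumption made for illustration:

```python
def top_k_for(num_sections: int) -> int:
    # Hypothetical equivalent of the if/elif/else chain above; the key for
    # the value 20 is assumed, since its branch condition is outside the hunk.
    section_top_k = {3: 20, 6: 30}
    return section_top_k.get(num_sections, 10)  # 10 mirrors the new `else`

assert top_k_for(6) == 30
assert top_k_for(42) == 10  # before the patch, this case left TOP_K unbound
```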
+ """ + if self.user_id: + try: + # Count total chunks for documents belonging to this user + from sqlalchemy import func + result = await self.session.execute( + select(func.count(Chunk.id)) + .join(Document) + .filter(Document.user_id == self.user_id) + ) + chunk_count = result.scalar() or 0 + self.source_id_counter = chunk_count + 1 + print(f"Initialized source_id_counter to {self.source_id_counter} for user {self.user_id}") + except Exception as e: + print(f"Error initializing source_id_counter: {str(e)}") + # Fallback to default value + self.source_id_counter = 1 + async def search_crawled_urls(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple: """ Search for crawled URLs and return both the source information and langchain documents @@ -58,15 +81,13 @@ class ConnectorService: sources_list = [] async with self.counter_lock: for i, chunk in enumerate(crawled_urls_chunks): - # Fix for UI - crawled_urls_chunks[i]['document']['id'] = self.source_id_counter # Extract document metadata document = chunk.get('document', {}) metadata = document.get('metadata', {}) # Create a source entry source = { - "id": self.source_id_counter, + "id": document.get('id', self.source_id_counter), "title": document.get('title', 'Untitled Document'), "description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])), "url": metadata.get('url', '') @@ -124,15 +145,13 @@ class ConnectorService: sources_list = [] async with self.counter_lock: for i, chunk in enumerate(files_chunks): - # Fix for UI - files_chunks[i]['document']['id'] = self.source_id_counter # Extract document metadata document = chunk.get('document', {}) metadata = document.get('metadata', {}) # Create a source entry source = { - "id": self.source_id_counter, + "id": document.get('id', self.source_id_counter), "title": document.get('title', 'Untitled Document'), "description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])), "url": metadata.get('url', '') @@ -338,8 +357,6 @@ class ConnectorService: sources_list = [] async with self.counter_lock: for i, chunk in enumerate(slack_chunks): - # Fix for UI - slack_chunks[i]['document']['id'] = self.source_id_counter # Extract document metadata document = chunk.get('document', {}) metadata = document.get('metadata', {}) @@ -365,7 +382,7 @@ class ConnectorService: url = f"https://slack.com/app_redirect?channel={channel_id}" source = { - "id": self.source_id_counter, + "id": document.get('id', self.source_id_counter), "title": title, "description": description, "url": url, @@ -429,9 +446,6 @@ class ConnectorService: sources_list = [] async with self.counter_lock: for i, chunk in enumerate(notion_chunks): - # Fix for UI - notion_chunks[i]['document']['id'] = self.source_id_counter - # Extract document metadata document = chunk.get('document', {}) metadata = document.get('metadata', {}) @@ -458,7 +472,7 @@ class ConnectorService: url = f"https://notion.so/{page_id.replace('-', '')}" source = { - "id": self.source_id_counter, + "id": document.get('id', self.source_id_counter), "title": title, "description": description, "url": url, @@ -522,9 +536,6 @@ class ConnectorService: sources_list = [] async with self.counter_lock: for i, chunk in enumerate(extension_chunks): - # Fix for UI - extension_chunks[i]['document']['id'] = self.source_id_counter - # Extract document metadata document = chunk.get('document', {}) metadata = 
@@ -58,15 +81,13 @@ class ConnectorService:
         sources_list = []
         async with self.counter_lock:
             for i, chunk in enumerate(crawled_urls_chunks):
-                # Fix for UI
-                crawled_urls_chunks[i]['document']['id'] = self.source_id_counter
                 # Extract document metadata
                 document = chunk.get('document', {})
                 metadata = document.get('metadata', {})
 
                 # Create a source entry
                 source = {
-                    "id": self.source_id_counter,
+                    "id": document.get('id', self.source_id_counter),
                     "title": document.get('title', 'Untitled Document'),
                     "description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])),
                     "url": metadata.get('url', '')
@@ -124,15 +145,13 @@ class ConnectorService:
         sources_list = []
         async with self.counter_lock:
             for i, chunk in enumerate(files_chunks):
-                # Fix for UI
-                files_chunks[i]['document']['id'] = self.source_id_counter
                 # Extract document metadata
                 document = chunk.get('document', {})
                 metadata = document.get('metadata', {})
 
                 # Create a source entry
                 source = {
-                    "id": self.source_id_counter,
+                    "id": document.get('id', self.source_id_counter),
                     "title": document.get('title', 'Untitled Document'),
                     "description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])),
                     "url": metadata.get('url', '')
@@ -338,8 +357,6 @@ class ConnectorService:
         sources_list = []
         async with self.counter_lock:
             for i, chunk in enumerate(slack_chunks):
-                # Fix for UI
-                slack_chunks[i]['document']['id'] = self.source_id_counter
                 # Extract document metadata
                 document = chunk.get('document', {})
                 metadata = document.get('metadata', {})
@@ -365,7 +382,7 @@ class ConnectorService:
                     url = f"https://slack.com/app_redirect?channel={channel_id}"
 
                 source = {
-                    "id": self.source_id_counter,
+                    "id": document.get('id', self.source_id_counter),
                     "title": title,
                     "description": description,
                     "url": url,
@@ -429,9 +446,6 @@ class ConnectorService:
         sources_list = []
         async with self.counter_lock:
             for i, chunk in enumerate(notion_chunks):
-                # Fix for UI
-                notion_chunks[i]['document']['id'] = self.source_id_counter
-
                 # Extract document metadata
                 document = chunk.get('document', {})
                 metadata = document.get('metadata', {})
@@ -458,7 +472,7 @@ class ConnectorService:
                     url = f"https://notion.so/{page_id.replace('-', '')}"
 
                 source = {
-                    "id": self.source_id_counter,
+                    "id": document.get('id', self.source_id_counter),
                     "title": title,
                     "description": description,
                     "url": url,
@@ -522,9 +536,6 @@ class ConnectorService:
         sources_list = []
         async with self.counter_lock:
             for i, chunk in enumerate(extension_chunks):
-                # Fix for UI
-                extension_chunks[i]['document']['id'] = self.source_id_counter
-
                 # Extract document metadata
                 document = chunk.get('document', {})
                 metadata = document.get('metadata', {})
@@ -569,7 +580,7 @@ class ConnectorService:
                     pass
 
                 source = {
-                    "id": self.source_id_counter,
+                    "id": document.get('id', self.source_id_counter),
                     "title": title,
                     "description": description,
                     "url": webpage_url
@@ -633,9 +644,6 @@ class ConnectorService:
         sources_list = []
         async with self.counter_lock:
             for i, chunk in enumerate(youtube_chunks):
-                # Fix for UI
-                youtube_chunks[i]['document']['id'] = self.source_id_counter
-
                 # Extract document metadata
                 document = chunk.get('document', {})
                 metadata = document.get('metadata', {})
@@ -660,7 +668,7 @@ class ConnectorService:
                 url = f"https://www.youtube.com/watch?v={video_id}" if video_id else ""
 
                 source = {
-                    "id": self.source_id_counter,
+                    "id": document.get('id', self.source_id_counter),
                     "title": title,
                     "description": description,
                     "url": url,
@@ -720,16 +728,13 @@ class ConnectorService:
         sources_list = []
         async with self.counter_lock:
             for i, chunk in enumerate(github_chunks):
-                # Fix for UI - assign a unique ID for citation/source tracking
-                github_chunks[i]['document']['id'] = self.source_id_counter
-
                 # Extract document metadata
                 document = chunk.get('document', {})
                 metadata = document.get('metadata', {})
 
                 # Create a source entry
                 source = {
-                    "id": self.source_id_counter,
+                    "id": document.get('id', self.source_id_counter),
                     "title": document.get('title', 'GitHub Document'),  # Use specific title if available
                     "description": metadata.get('description', chunk.get('content', '')[:100]),  # Use description or content preview
                     "url": metadata.get('url', '')  # Use URL if available in metadata
@@ -793,9 +798,6 @@ class ConnectorService:
         sources_list = []
         async with self.counter_lock:
             for i, chunk in enumerate(linear_chunks):
-                # Fix for UI
-                linear_chunks[i]['document']['id'] = self.source_id_counter
-
                 # Extract document metadata
                 document = chunk.get('document', {})
                 metadata = document.get('metadata', {})
@@ -831,7 +833,7 @@ class ConnectorService:
                     url = f"https://linear.app/issue/{issue_identifier}"
 
                 source = {
-                    "id": self.source_id_counter,
+                    "id": document.get('id', self.source_id_counter),
                     "title": title,
                     "description": description,
                     "url": url,
@@ -1004,8 +1006,6 @@ class ConnectorService:
         sources_list = []
         async with self.counter_lock:
             for i, chunk in enumerate(discord_chunks):
-                # Fix for UI
-                discord_chunks[i]['document']['id'] = self.source_id_counter
                 # Extract document metadata
                 document = chunk.get('document', {})
                 metadata = document.get('metadata', {})
@@ -1034,7 +1034,7 @@ class ConnectorService:
                     url = f"https://discord.com/channels/@me/{channel_id}"
 
                 source = {
-                    "id": self.source_id_counter,
+                    "id": document.get('id', self.source_id_counter),
                     "title": title,
                     "description": description,
                     "url": url,
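Every connector section above converges on the same id-selection pattern: prefer the document's stable database id, and fall back to the session counter only when a chunk arrives without one. This keeps a citation's id consistent across searches instead of depending on iteration order. A minimal sketch of the pattern in isolation:

```python
def source_id(document: dict, counter: int) -> int:
    # Prefer the stable database id; the session counter is only a fallback.
    return document.get("id", counter)

assert source_id({"id": 7, "title": "Doc"}, 100000) == 7  # normal path
assert source_id({}, 100000) == 100000                    # fallback path
```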
diff --git a/surfsense_web/app/dashboard/[search_space_id]/researcher/[chat_id]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/researcher/[chat_id]/page.tsx
index 56893b7..6a547a6 100644
--- a/surfsense_web/app/dashboard/[search_space_id]/researcher/[chat_id]/page.tsx
+++ b/surfsense_web/app/dashboard/[search_space_id]/researcher/[chat_id]/page.tsx
@@ -183,8 +183,8 @@ const SourcesDialogContent = ({
-							{paginatedSources.map((source: any) => (
-
+							{paginatedSources.map((source: any, index: number) => (
+
 							{getConnectorIcon(connector.type)}
@@ -845,8 +845,8 @@ const ChatPage = () => {
 							{messageConnectorSources.map(connector => (
-								{connector.sources?.slice(0, INITIAL_SOURCES_DISPLAY)?.map((source: any) => (
-
+								{connector.sources?.slice(0, INITIAL_SOURCES_DISPLAY)?.map((source: any, index: number) => (
+
 								{getConnectorIcon(connector.type)}
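On the frontend, both source lists now thread the map `index` through alongside `source.id` (the JSX elements that carried the key props were lost in this copy of the diff, but the changed `map()` signatures survive). A language-neutral illustration of why a composite key matters once ids can repeat, with Python dict keys standing in for React list keys:

```python
# Two sources that share an id collapse under a single key; an (id, index)
# composite keeps them distinct, which is the same reason the React lists
# above now receive `index` from map().
sources = [{"id": 7, "title": "a"}, {"id": 7, "title": "b"}]

by_id = {s["id"]: s for s in sources}
by_composite = {(s["id"], i): s for i, s in enumerate(sources)}

assert len(by_id) == 1         # the duplicate id swallowed one source
assert len(by_composite) == 2  # composite keys preserve both
```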
From ef252e821eaa69eb3057a3c67d27925373b4f4c9 Mon Sep 17 00:00:00 2001
From: "DESKTOP-RTLN3BA\\$punk"
Date: Wed, 4 Jun 2025 17:30:15 -0700
Subject: [PATCH 2/3] fix: prevent collisions in case of fallback, which
 should never happen

---
 surfsense_backend/app/utils/connector_service.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/surfsense_backend/app/utils/connector_service.py b/surfsense_backend/app/utils/connector_service.py
index 1c5f42a..1c60766 100644
--- a/surfsense_backend/app/utils/connector_service.py
+++ b/surfsense_backend/app/utils/connector_service.py
@@ -17,7 +17,7 @@ class ConnectorService:
         self.chunk_retriever = ChucksHybridSearchRetriever(session)
         self.document_retriever = DocumentHybridSearchRetriever(session)
         self.user_id = user_id
-        self.source_id_counter = 1
+        self.source_id_counter = 100000  # High starting value to avoid collisions with existing IDs
         self.counter_lock = asyncio.Lock()  # Lock to protect counter in multithreaded environments
 
     async def initialize_counter(self):
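Patch 2 raises the constructor's default from 1 to a high sentinel so that, if `initialize_counter` is never called or fails, fallback ids still cannot shadow the ids of real chunks. A toy illustration of the collision being guarded against (the id values are hypothetical):

```python
# If the fallback counter started at 1, a fallback-assigned id could equal
# a real chunk id, and two different citations would merge in the UI.
real_chunk_ids = {1, 2, 3}   # ids that real documents normally carry
FALLBACK_START = 100000      # sentinel introduced by this patch

first_fallback_id = FALLBACK_START
assert first_fallback_id not in real_chunk_ids
```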
From bc1642488f0c138b081dd910378661709fef0318 Mon Sep 17 00:00:00 2001
From: "DESKTOP-RTLN3BA\\$punk"
Date: Wed, 4 Jun 2025 17:38:44 -0700
Subject: [PATCH 3/3] fix: CodeRabbit & Recurse

---
 .../app/utils/connector_service.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/surfsense_backend/app/utils/connector_service.py b/surfsense_backend/app/utils/connector_service.py
index 1c60766..1052939 100644
--- a/surfsense_backend/app/utils/connector_service.py
+++ b/surfsense_backend/app/utils/connector_service.py
@@ -7,6 +7,7 @@ from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever
 from app.db import SearchSourceConnector, SearchSourceConnectorType, Chunk, Document
 from tavily import TavilyClient
 from linkup import LinkupClient
+from sqlalchemy import func
 
 from app.agents.researcher.configuration import SearchMode
 
@@ -28,7 +29,7 @@ class ConnectorService:
         if self.user_id:
             try:
                 # Count total chunks for documents belonging to this user
-                from sqlalchemy import func
+
                 result = await self.session.execute(
                     select(func.count(Chunk.id))
                     .join(Document)
@@ -80,7 +81,7 @@ class ConnectorService:
         # Process each chunk and create sources directly without deduplication
         sources_list = []
         async with self.counter_lock:
-            for i, chunk in enumerate(crawled_urls_chunks):
+            for _i, chunk in enumerate(crawled_urls_chunks):
                 # Extract document metadata
                 document = chunk.get('document', {})
                 metadata = document.get('metadata', {})
@@ -144,7 +145,7 @@ class ConnectorService:
         # Process each chunk and create sources directly without deduplication
         sources_list = []
         async with self.counter_lock:
-            for i, chunk in enumerate(files_chunks):
+            for _i, chunk in enumerate(files_chunks):
                 # Extract document metadata
                 document = chunk.get('document', {})
                 metadata = document.get('metadata', {})
@@ -356,7 +357,7 @@ class ConnectorService:
         # Process each chunk and create sources directly without deduplication
         sources_list = []
         async with self.counter_lock:
-            for i, chunk in enumerate(slack_chunks):
+            for _i, chunk in enumerate(slack_chunks):
                 # Extract document metadata
                 document = chunk.get('document', {})
                 metadata = document.get('metadata', {})
@@ -445,7 +446,7 @@ class ConnectorService:
         # Process each chunk and create sources directly without deduplication
         sources_list = []
         async with self.counter_lock:
-            for i, chunk in enumerate(notion_chunks):
+            for _i, chunk in enumerate(notion_chunks):
                 # Extract document metadata
                 document = chunk.get('document', {})
                 metadata = document.get('metadata', {})
@@ -643,7 +644,7 @@ class ConnectorService:
         # Process each chunk and create sources directly without deduplication
         sources_list = []
         async with self.counter_lock:
-            for i, chunk in enumerate(youtube_chunks):
+            for _i, chunk in enumerate(youtube_chunks):
                 # Extract document metadata
                 document = chunk.get('document', {})
                 metadata = document.get('metadata', {})
@@ -652,7 +653,7 @@ class ConnectorService:
                 video_title = metadata.get('video_title', 'Untitled Video')
                 video_id = metadata.get('video_id', '')
                 channel_name = metadata.get('channel_name', '')
-                published_date = metadata.get('published_date', '')
+                # published_date = metadata.get('published_date', '')
 
                 # Create a more descriptive title for YouTube videos
                 title = video_title
@@ -727,7 +728,7 @@ class ConnectorService:
         # Process each chunk and create sources directly without deduplication
         sources_list = []
         async with self.counter_lock:
-            for i, chunk in enumerate(github_chunks):
+            for _i, chunk in enumerate(github_chunks):
                 # Extract document metadata
                 document = chunk.get('document', {})
                 metadata = document.get('metadata', {})
@@ -797,7 +798,7 @@ class ConnectorService:
         # Process each chunk and create sources directly without deduplication
         sources_list = []
         async with self.counter_lock:
-            for i, chunk in enumerate(linear_chunks):
+            for _i, chunk in enumerate(linear_chunks):
                 # Extract document metadata
                 document = chunk.get('document', {})
                 metadata = document.get('metadata', {})
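Taken together, the series changes the service's calling convention: construct with a `user_id`, then await `initialize_counter()` before any `search_*` call. A minimal sketch of that order, with a fake class standing in for `ConnectorService` (the real one lives in `app.utils.connector_service` and needs a database session):

```python
import asyncio

class FakeConnectorService:
    """Toy stand-in mirroring the patched constructor and initializer."""

    def __init__(self, session, user_id: str = None):
        self.session = session
        self.user_id = user_id
        self.source_id_counter = 100000  # sentinel default from patch 2

    async def initialize_counter(self) -> None:
        # The real method runs a COUNT over the user's chunks; fake 42 here.
        chunk_count = 42
        self.source_id_counter = chunk_count + 1

async def main() -> None:
    service = FakeConnectorService(session=None, user_id="user-123")
    await service.initialize_counter()  # must run before any search_* call
    print(service.source_id_counter)    # 43: fallback ids start past real ones

asyncio.run(main())
```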