Mirror of https://github.com/MODSetter/SurfSense.git, synced 2025-09-01 18:19:08 +00:00
feat: Stabilized Citation Logic
commit 52a9ad04bd · parent 7c8b84a46c
3 changed files with 44 additions and 40 deletions
@@ -572,12 +572,15 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
         TOP_K = 20
     elif configuration.num_sections == 6:
         TOP_K = 30
     else:
         TOP_K = 10

     relevant_documents = []
     async with async_session_maker() as db_session:
         try:
             # Create connector service inside the db_session scope
-            connector_service = ConnectorService(db_session)
+            connector_service = ConnectorService(db_session, user_id=configuration.user_id)
+            await connector_service.initialize_counter()
+
             relevant_documents = await fetch_relevant_documents(
                 research_questions=all_questions,
@@ -875,7 +878,8 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre
     async with async_session_maker() as db_session:
         try:
             # Create connector service inside the db_session scope
-            connector_service = ConnectorService(db_session)
+            connector_service = ConnectorService(db_session, user_id=configuration.user_id)
+            await connector_service.initialize_counter()

             # Use the reformulated query as a single research question
             research_questions = [reformulated_query]
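Both workflow nodes now share the same setup: the connector service is created with the requesting user's id and its citation counter is seeded before any documents are fetched. A minimal sketch of that pattern, assuming an import path for ConnectorService (not shown in the diff) and an illustrative helper name:

    from sqlalchemy.ext.asyncio import AsyncSession

    from app.services.connector_service import ConnectorService  # import path assumed

    async def make_connector_service(db_session: AsyncSession, user_id: str) -> ConnectorService:
        # Scope the service to one user and seed its source_id_counter
        # from that user's existing chunks before any retrieval runs.
        service = ConnectorService(db_session, user_id=user_id)
        await service.initialize_counter()
        return service

process_sections passes configuration.user_id here and then calls fetch_relevant_documents with the seeded service; handle_qna_workflow does the same with the reformulated query as its single research question.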
@@ -4,7 +4,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.future import select
 from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
 from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever
-from app.db import SearchSourceConnector, SearchSourceConnectorType
+from app.db import SearchSourceConnector, SearchSourceConnectorType, Chunk, Document
 from tavily import TavilyClient
 from linkup import LinkupClient

@@ -12,13 +12,36 @@ from app.agents.researcher.configuration import SearchMode


 class ConnectorService:
-    def __init__(self, session: AsyncSession):
+    def __init__(self, session: AsyncSession, user_id: str = None):
         self.session = session
         self.chunk_retriever = ChucksHybridSearchRetriever(session)
         self.document_retriever = DocumentHybridSearchRetriever(session)
+        self.user_id = user_id
         self.source_id_counter = 1
         self.counter_lock = asyncio.Lock() # Lock to protect counter in multithreaded environments

+    async def initialize_counter(self):
+        """
+        Initialize the source_id_counter based on the total number of chunks for the user.
+        This ensures unique IDs across different sessions.
+        """
+        if self.user_id:
+            try:
+                # Count total chunks for documents belonging to this user
+                from sqlalchemy import func
+                result = await self.session.execute(
+                    select(func.count(Chunk.id))
+                    .join(Document)
+                    .filter(Document.user_id == self.user_id)
+                )
+                chunk_count = result.scalar() or 0
+                self.source_id_counter = chunk_count + 1
+                print(f"Initialized source_id_counter to {self.source_id_counter} for user {self.user_id}")
+            except Exception as e:
+                print(f"Error initializing source_id_counter: {str(e)}")
+                # Fallback to default value
+                self.source_id_counter = 1
+
     async def search_crawled_urls(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
         """
         Search for crawled URLs and return both the source information and langchain documents
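The seeding logic above is just an aggregate count over the user's chunks. As a standalone sketch of the same query, assuming Chunk rows join to Document through a foreign key (which the .join(Document) in the commit implies):

    from sqlalchemy import func, select
    from app.db import Chunk, Document  # models imported at the top of connector_service in this commit

    async def count_user_chunks(session, user_id: str) -> int:
        # COUNT(chunk.id) across all chunks whose parent document belongs to this user
        result = await session.execute(
            select(func.count(Chunk.id))
            .join(Document)
            .filter(Document.user_id == user_id)
        )
        return result.scalar() or 0

Seeding source_id_counter to this count plus one is what keeps ids handed out in a new session from colliding with ids already shown for the same user in earlier sessions.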
@@ -58,15 +81,13 @@ class ConnectorService:
         sources_list = []
-        async with self.counter_lock:
         for i, chunk in enumerate(crawled_urls_chunks):
-            # Fix for UI
-            crawled_urls_chunks[i]['document']['id'] = self.source_id_counter
             # Extract document metadata
             document = chunk.get('document', {})
             metadata = document.get('metadata', {})

             # Create a source entry
             source = {
-                "id": self.source_id_counter,
+                "id": document.get('id', self.source_id_counter),
                 "title": document.get('title', 'Untitled Document'),
                 "description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])),
                 "url": metadata.get('url', '')
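From this hunk onward, every connector builds its source entry the same way: it reuses the id already attached to the retrieved document and only falls back to the running counter when none is present. A tiny self-contained illustration of that dict.get fallback (values are made up):

    # Hypothetical retrieval results: one document already carries an id, one does not.
    doc_with_id = {"id": 42, "title": "Crawled page"}
    doc_without_id = {"title": "Crawled page"}

    source_id_counter = 7  # as seeded by initialize_counter()

    print(doc_with_id.get("id", source_id_counter))     # 42 -> the stored id wins
    print(doc_without_id.get("id", source_id_counter))  # 7  -> counter used as fallback

The same one-line change repeats in the files, Slack, Notion, extension, YouTube, GitHub, Linear, and Discord hunks below.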
@@ -124,15 +145,13 @@ class ConnectorService:
         sources_list = []
-        async with self.counter_lock:
         for i, chunk in enumerate(files_chunks):
-            # Fix for UI
-            files_chunks[i]['document']['id'] = self.source_id_counter
             # Extract document metadata
             document = chunk.get('document', {})
             metadata = document.get('metadata', {})

             # Create a source entry
             source = {
-                "id": self.source_id_counter,
+                "id": document.get('id', self.source_id_counter),
                 "title": document.get('title', 'Untitled Document'),
                 "description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])),
                 "url": metadata.get('url', '')
@@ -338,8 +357,6 @@ class ConnectorService:
         sources_list = []
-        async with self.counter_lock:
         for i, chunk in enumerate(slack_chunks):
-            # Fix for UI
-            slack_chunks[i]['document']['id'] = self.source_id_counter
             # Extract document metadata
             document = chunk.get('document', {})
             metadata = document.get('metadata', {})
@@ -365,7 +382,7 @@ class ConnectorService:
             url = f"https://slack.com/app_redirect?channel={channel_id}"

             source = {
-                "id": self.source_id_counter,
+                "id": document.get('id', self.source_id_counter),
                 "title": title,
                 "description": description,
                 "url": url,
@@ -429,9 +446,6 @@ class ConnectorService:
         sources_list = []
-        async with self.counter_lock:
         for i, chunk in enumerate(notion_chunks):
-            # Fix for UI
-            notion_chunks[i]['document']['id'] = self.source_id_counter

             # Extract document metadata
             document = chunk.get('document', {})
             metadata = document.get('metadata', {})
@@ -458,7 +472,7 @@ class ConnectorService:
             url = f"https://notion.so/{page_id.replace('-', '')}"

             source = {
-                "id": self.source_id_counter,
+                "id": document.get('id', self.source_id_counter),
                 "title": title,
                 "description": description,
                 "url": url,
@@ -522,9 +536,6 @@ class ConnectorService:
         sources_list = []
-        async with self.counter_lock:
         for i, chunk in enumerate(extension_chunks):
-            # Fix for UI
-            extension_chunks[i]['document']['id'] = self.source_id_counter

             # Extract document metadata
             document = chunk.get('document', {})
             metadata = document.get('metadata', {})
@@ -569,7 +580,7 @@ class ConnectorService:
                 pass

             source = {
-                "id": self.source_id_counter,
+                "id": document.get('id', self.source_id_counter),
                 "title": title,
                 "description": description,
                 "url": webpage_url
@@ -633,9 +644,6 @@ class ConnectorService:
         sources_list = []
-        async with self.counter_lock:
         for i, chunk in enumerate(youtube_chunks):
-            # Fix for UI
-            youtube_chunks[i]['document']['id'] = self.source_id_counter

             # Extract document metadata
             document = chunk.get('document', {})
             metadata = document.get('metadata', {})
@@ -660,7 +668,7 @@ class ConnectorService:
             url = f"https://www.youtube.com/watch?v={video_id}" if video_id else ""

             source = {
-                "id": self.source_id_counter,
+                "id": document.get('id', self.source_id_counter),
                 "title": title,
                 "description": description,
                 "url": url,
@@ -720,16 +728,13 @@ class ConnectorService:
         sources_list = []
-        async with self.counter_lock:
         for i, chunk in enumerate(github_chunks):
-            # Fix for UI - assign a unique ID for citation/source tracking
-            github_chunks[i]['document']['id'] = self.source_id_counter

             # Extract document metadata
             document = chunk.get('document', {})
             metadata = document.get('metadata', {})

             # Create a source entry
             source = {
-                "id": self.source_id_counter,
+                "id": document.get('id', self.source_id_counter),
                 "title": document.get('title', 'GitHub Document'), # Use specific title if available
                 "description": metadata.get('description', chunk.get('content', '')[:100]), # Use description or content preview
                 "url": metadata.get('url', '') # Use URL if available in metadata
@@ -793,9 +798,6 @@ class ConnectorService:
         sources_list = []
-        async with self.counter_lock:
         for i, chunk in enumerate(linear_chunks):
-            # Fix for UI
-            linear_chunks[i]['document']['id'] = self.source_id_counter

             # Extract document metadata
             document = chunk.get('document', {})
             metadata = document.get('metadata', {})
@@ -831,7 +833,7 @@ class ConnectorService:
             url = f"https://linear.app/issue/{issue_identifier}"

             source = {
-                "id": self.source_id_counter,
+                "id": document.get('id', self.source_id_counter),
                 "title": title,
                 "description": description,
                 "url": url,
@@ -1004,8 +1006,6 @@ class ConnectorService:
         sources_list = []
-        async with self.counter_lock:
         for i, chunk in enumerate(discord_chunks):
-            # Fix for UI
-            discord_chunks[i]['document']['id'] = self.source_id_counter
             # Extract document metadata
             document = chunk.get('document', {})
             metadata = document.get('metadata', {})
@@ -1034,7 +1034,7 @@ class ConnectorService:
             url = f"https://discord.com/channels/@me/{channel_id}"

             source = {
-                "id": self.source_id_counter,
+                "id": document.get('id', self.source_id_counter),
                 "title": title,
                 "description": description,
                 "url": url,
@@ -183,8 +183,8 @@ const SourcesDialogContent = ({
                 </div>

                 <div className="space-y-3 mt-4">
-                    {paginatedSources.map((source: any) => (
-                        <Card key={source.id} className="p-3 hover:bg-gray-50 dark:hover:bg-gray-800 cursor-pointer">
+                    {paginatedSources.map((source: any, index: number) => (
+                        <Card key={`${connector.type}-${source.id}-${index}`} className="p-3 hover:bg-gray-50 dark:hover:bg-gray-800 cursor-pointer">
                             <div className="flex items-start gap-3">
                                 <div className="flex-shrink-0 w-6 h-6 flex items-center justify-center">
                                     {getConnectorIcon(connector.type)}
@@ -845,8 +845,8 @@ const ChatPage = () => {
                 {messageConnectorSources.map(connector => (
                     <TabsContent key={connector.id} value={connector.type} className="mt-0">
                         <div className="space-y-3">
-                            {connector.sources?.slice(0, INITIAL_SOURCES_DISPLAY)?.map((source: any) => (
-                                <Card key={source.id} className="p-3 hover:bg-gray-50 dark:hover:bg-gray-800 cursor-pointer">
+                            {connector.sources?.slice(0, INITIAL_SOURCES_DISPLAY)?.map((source: any, index: number) => (
+                                <Card key={`${connector.type}-${source.id}-${index}`} className="p-3 hover:bg-gray-50 dark:hover:bg-gray-800 cursor-pointer">
                                     <div className="flex items-start gap-3">
                                         <div className="flex-shrink-0 w-6 h-6 flex items-center justify-center">
                                             {getConnectorIcon(connector.type)}