import asyncio
import json
import re
from typing import List, AsyncGenerator, Dict, Any

from sqlalchemy.ext.asyncio import AsyncSession

from app.utils.connector_service import ConnectorService
from app.utils.research_service import ResearchService
from app.utils.streaming_service import StreamingService
from app.utils.reranker_service import RerankerService
from app.config import config
from app.utils.document_converters import convert_chunks_to_langchain_documents


async def stream_connector_search_results(
    user_query: str,
    user_id: int,
    search_space_id: int,
    session: AsyncSession,
    research_mode: str,
    selected_connectors: List[str]
) -> AsyncGenerator[str, None]:
    """
    Stream connector search results to the client.

    Args:
        user_query: The user's query
        user_id: The user's ID
        search_space_id: The search space ID
        session: The database session
        research_mode: The research mode ("GENERAL", "DEEP", or "DEEPER")
        selected_connectors: List of selected connector identifiers

    Yields:
        str: Formatted response strings for the streaming client
    """
    # Initialize services
    connector_service = ConnectorService(session)
    streaming_service = StreamingService()

    # The reranker is optional; this may be None if no reranker is configured
    reranker_service = RerankerService.get_reranker_instance(config)

    all_raw_documents = []  # Store all raw documents before reranking
    all_sources = []

    # Number of chunks to retrieve per connector, scaled by research mode
    TOP_K = 20
    if research_mode == "GENERAL":
        TOP_K = 20
    elif research_mode == "DEEP":
        TOP_K = 40
    elif research_mode == "DEEPER":
        TOP_K = 60

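    # Each connector branch below follows the same pattern: announce the search
    # in the terminal stream, query the connector, report how many sources were
    # found, push the updated source list to the client, and accumulate the raw
    # chunks for the research step.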
    # Process each selected connector
    for connector in selected_connectors:
        # Crawled URLs
        if connector == "CRAWLED_URL":
            # Send terminal message about starting search
            yield streaming_service.add_terminal_message("Starting to search for crawled URLs...")

            # Search for crawled URLs
            result_object, crawled_urls_chunks = await connector_service.search_crawled_urls(
                user_query=user_query,
                user_id=user_id,
                search_space_id=search_space_id,
                top_k=TOP_K
            )

            # Send terminal message about search results
            yield streaming_service.add_terminal_message(
                f"Found {len(result_object['sources'])} relevant crawled URLs",
                "success"
            )

            # Update sources
            all_sources.append(result_object)
            yield streaming_service.update_sources(all_sources)

            # Add documents to collection
            all_raw_documents.extend(crawled_urls_chunks)

        # Files
        if connector == "FILE":
            # Send terminal message about starting search
            yield streaming_service.add_terminal_message("Starting to search for files...")

            # Search for files
            result_object, files_chunks = await connector_service.search_files(
                user_query=user_query,
                user_id=user_id,
                search_space_id=search_space_id,
                top_k=TOP_K
            )

            # Send terminal message about search results
            yield streaming_service.add_terminal_message(
                f"Found {len(result_object['sources'])} relevant files",
                "success"
            )

            # Update sources
            all_sources.append(result_object)
            yield streaming_service.update_sources(all_sources)

            # Add documents to collection
            all_raw_documents.extend(files_chunks)

        # Tavily Connector
        if connector == "TAVILY_API":
            # Send terminal message about starting search
            yield streaming_service.add_terminal_message("Starting to search with Tavily API...")

            # Search using Tavily API
            result_object, tavily_chunks = await connector_service.search_tavily(
                user_query=user_query,
                user_id=user_id,
                top_k=TOP_K
            )

            # Send terminal message about search results
            yield streaming_service.add_terminal_message(
                f"Found {len(result_object['sources'])} relevant results from Tavily",
                "success"
            )

            # Update sources
            all_sources.append(result_object)
            yield streaming_service.update_sources(all_sources)

            # Add documents to collection
            all_raw_documents.extend(tavily_chunks)

        # Slack Connector
        if connector == "SLACK_CONNECTOR":
            # Send terminal message about starting search
            yield streaming_service.add_terminal_message("Starting to search the Slack connector...")

            # Search using Slack API
            result_object, slack_chunks = await connector_service.search_slack(
                user_query=user_query,
                user_id=user_id,
                search_space_id=search_space_id,
                top_k=TOP_K
            )

            # Send terminal message about search results
            yield streaming_service.add_terminal_message(
                f"Found {len(result_object['sources'])} relevant results from Slack",
                "success"
            )

            # Update sources
            all_sources.append(result_object)
            yield streaming_service.update_sources(all_sources)

            # Add documents to collection
            all_raw_documents.extend(slack_chunks)

        # Notion Connector
        if connector == "NOTION_CONNECTOR":
            # Send terminal message about starting search
            yield streaming_service.add_terminal_message("Starting to search the Notion connector...")

            # Search using Notion API
            result_object, notion_chunks = await connector_service.search_notion(
                user_query=user_query,
                user_id=user_id,
                search_space_id=search_space_id,
                top_k=TOP_K
            )

            # Send terminal message about search results
            yield streaming_service.add_terminal_message(
                f"Found {len(result_object['sources'])} relevant results from Notion",
                "success"
            )

            # Update sources
            all_sources.append(result_object)
            yield streaming_service.update_sources(all_sources)

            # Add documents to collection
            all_raw_documents.extend(notion_chunks)

    # If we have documents to research
    if all_raw_documents:
        # Rerank all documents if reranker is available
        if reranker_service:
            yield streaming_service.add_terminal_message("Reranking documents for better relevance...", "info")

            # Convert documents to the format expected by the reranker
            reranker_input_docs = [
                {
                    "chunk_id": doc.get("chunk_id", f"chunk_{i}"),
                    "content": doc.get("content", ""),
                    "score": doc.get("score", 0.0),
                    "document": {
                        "id": doc.get("document", {}).get("id", ""),
                        "title": doc.get("document", {}).get("title", ""),
                        "document_type": doc.get("document", {}).get("document_type", ""),
                        "metadata": doc.get("document", {}).get("metadata", {})
                    }
                } for i, doc in enumerate(all_raw_documents)
            ]

            # Rerank documents
            reranked_docs = reranker_service.rerank_documents(user_query, reranker_input_docs)

            # Sort by score in descending order
            reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)

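            # Each chunk is wrapped in a small XML-like envelope that carries the
            # source document id, so the downstream research step can emit
            # citations that reference the originating document.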
            # Convert back to langchain documents format
            from langchain.schema import Document as LangchainDocument
            all_langchain_documents_to_research = [
                LangchainDocument(
                    page_content=f"""<document><metadata><source_id>{doc.get("document", {}).get("id", "")}</source_id></metadata><content>{doc.get("content", "")}</content></document>""",
                    metadata={
                        # **doc.get("document", {}).get("metadata", {}),
                        # "score": doc.get("score", 0.0),
                        # "rank": doc.get("rank", 0),
                        # "document_id": doc.get("document", {}).get("id", ""),
                        # "document_title": doc.get("document", {}).get("title", ""),
                        # "document_type": doc.get("document", {}).get("document_type", ""),
                        # Explicitly set source_id for citation purposes
                        "source_id": str(doc.get("document", {}).get("id", ""))
                    }
                ) for doc in reranked_docs
            ]

            yield streaming_service.add_terminal_message(f"Reranked {len(all_langchain_documents_to_research)} documents", "success")
        else:
            # Use raw documents if no reranker is available
            all_langchain_documents_to_research = convert_chunks_to_langchain_documents(all_raw_documents)

        # Send terminal message about starting research
        yield streaming_service.add_terminal_message("Starting to research...", "info")

        # Create a buffer to collect report content
        report_buffer = []

        # Use the streaming research method
        yield streaming_service.add_terminal_message("Generating report...", "info")

        # Create a wrapper to handle the streaming
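        # StreamHandler bridges ResearchService's progress callbacks and this
        # async generator: handle_progress formats each event and puts it on an
        # asyncio.Queue, and the loop further below drains that queue and yields
        # the items to the client as they arrive.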
        class StreamHandler:
            def __init__(self):
                self.queue = asyncio.Queue()

            async def handle_progress(self, data):
                result = None
                if data.get("type") == "logs":
                    # Handle log messages
                    result = streaming_service.add_terminal_message(data.get("output", ""), "info")
                elif data.get("type") == "report":
                    # Handle report content
                    content = data.get("output", "")

                    # Fix incorrect citation formats using regex

                    # More specific pattern to match only numeric citations in markdown-style links
                    # This matches patterns like ([1](https://github.com/...)) but not general links like ([Click here](https://...))
                    pattern = r'\(\[(\d+)\]\((https?://[^\)]+)\)\)'

                    # Replace with just [X] where X is the number
                    content = re.sub(pattern, r'[\1]', content)

                    # Also match other incorrect formats like ([1]) and convert to [1]
                    # Only match if the content inside brackets is a number
                    content = re.sub(r'\(\[(\d+)\]\)', r'[\1]', content)

                    report_buffer.append(content)
                    # Update the answer with the accumulated content
                    result = streaming_service.update_answer(report_buffer)

                if result:
                    await self.queue.put(result)
                return result

            async def get_next(self):
                try:
                    return await self.queue.get()
                except Exception as e:
                    print(f"Error getting next item from queue: {e}")
                    return None

            def task_done(self):
                self.queue.task_done()

        # Create the stream handler
        stream_handler = StreamHandler()

        # Start the research process in a separate task
        research_task = asyncio.create_task(
            ResearchService.stream_research(
                user_query=user_query,
                documents=all_langchain_documents_to_research,
                on_progress=stream_handler.handle_progress,
                research_mode=research_mode
            )
        )

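        # The short wait_for timeout keeps this loop responsive: queued items are
        # yielded as soon as they arrive, and when nothing is queued the loop
        # periodically checks whether the research task has finished so it can exit.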
        # Stream results as they become available
        while not research_task.done() or not stream_handler.queue.empty():
            try:
                # Get the next result with a timeout
                result = await asyncio.wait_for(stream_handler.get_next(), timeout=0.1)
                stream_handler.task_done()
                yield result
            except asyncio.TimeoutError:
                # No result available yet, check if the research task is done
                if research_task.done():
                    # If the queue is empty and the task is done, we're finished
                    if stream_handler.queue.empty():
                        break

        # Get the final report
        try:
            final_report = await research_task

            # Send terminal message about research completion
            yield streaming_service.add_terminal_message("Research completed", "success")

            # Update the answer with the final report
            final_report_lines = final_report.split('\n')
            yield streaming_service.update_answer(final_report_lines)
        except Exception as e:
            # Handle any exceptions
            yield streaming_service.add_terminal_message(f"Error during research: {str(e)}", "error")

    # Send completion message
    yield streaming_service.format_completion()