mirror of
https://github.com/MODSetter/SurfSense.git
synced 2025-09-10 14:28:57 +00:00
feat: Introduce the RAPTOR Search.
This commit is contained in:
parent
fc937edf03
commit
a9db0a8ceb
11 changed files with 318 additions and 127 deletions
|
@ -4,32 +4,47 @@ import asyncio
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||
from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
from tavily import TavilyClient
|
||||
from linkup import LinkupClient
|
||||
|
||||
from app.agents.researcher.configuration import SearchMode
|
||||
|
||||
|
||||
class ConnectorService:
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
self.retriever = ChucksHybridSearchRetriever(session)
|
||||
self.chunk_retriever = ChucksHybridSearchRetriever(session)
|
||||
self.document_retriever = DocumentHybridSearchRetriever(session)
|
||||
self.source_id_counter = 1
|
||||
self.counter_lock = asyncio.Lock() # Lock to protect counter in multithreaded environments
|
||||
|
||||
async def search_crawled_urls(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
async def search_crawled_urls(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
|
||||
"""
|
||||
Search for crawled URLs and return both the source information and langchain documents
|
||||
|
||||
Returns:
|
||||
tuple: (sources_info, langchain_documents)
|
||||
"""
|
||||
crawled_urls_chunks = await self.retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="CRAWLED_URL"
|
||||
)
|
||||
if search_mode == SearchMode.CHUNKS:
|
||||
crawled_urls_chunks = await self.chunk_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="CRAWLED_URL"
|
||||
)
|
||||
elif search_mode == SearchMode.DOCUMENTS:
|
||||
crawled_urls_chunks = await self.document_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="CRAWLED_URL"
|
||||
)
|
||||
# Transform document retriever results to match expected format
|
||||
crawled_urls_chunks = self._transform_document_results(crawled_urls_chunks)
|
||||
|
||||
# Early return if no results
|
||||
if not crawled_urls_chunks:
|
||||
|
@ -71,20 +86,31 @@ class ConnectorService:
|
|||
|
||||
return result_object, crawled_urls_chunks
|
||||
|
||||
async def search_files(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
async def search_files(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
|
||||
"""
|
||||
Search for files and return both the source information and langchain documents
|
||||
|
||||
Returns:
|
||||
tuple: (sources_info, langchain_documents)
|
||||
"""
|
||||
files_chunks = await self.retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="FILE"
|
||||
)
|
||||
if search_mode == SearchMode.CHUNKS:
|
||||
files_chunks = await self.chunk_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="FILE"
|
||||
)
|
||||
elif search_mode == SearchMode.DOCUMENTS:
|
||||
files_chunks = await self.document_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="FILE"
|
||||
)
|
||||
# Transform document retriever results to match expected format
|
||||
files_chunks = self._transform_document_results(files_chunks)
|
||||
|
||||
# Early return if no results
|
||||
if not files_chunks:
|
||||
|
@ -126,6 +152,31 @@ class ConnectorService:
|
|||
|
||||
return result_object, files_chunks
|
||||
|
||||
def _transform_document_results(self, document_results: List[Dict]) -> List[Dict]:
|
||||
"""
|
||||
Transform results from document_retriever.hybrid_search() to match the format
|
||||
expected by the processing code.
|
||||
|
||||
Args:
|
||||
document_results: Results from document_retriever.hybrid_search()
|
||||
|
||||
Returns:
|
||||
List of transformed results in the format expected by the processing code
|
||||
"""
|
||||
transformed_results = []
|
||||
for doc in document_results:
|
||||
transformed_results.append({
|
||||
'document': {
|
||||
'id': doc.get('document_id'),
|
||||
'title': doc.get('title', 'Untitled Document'),
|
||||
'document_type': doc.get('document_type'),
|
||||
'metadata': doc.get('metadata', {}),
|
||||
},
|
||||
'content': doc.get('chunks_content', doc.get('content', '')),
|
||||
'score': doc.get('score', 0.0)
|
||||
})
|
||||
return transformed_results
|
||||
|
||||
async def get_connector_by_type(self, user_id: str, connector_type: SearchSourceConnectorType) -> Optional[SearchSourceConnector]:
|
||||
"""
|
||||
Get a connector by type for a specific user
|
||||
|
@ -249,20 +300,31 @@ class ConnectorService:
|
|||
"sources": [],
|
||||
}, []
|
||||
|
||||
async def search_slack(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
async def search_slack(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
|
||||
"""
|
||||
Search for slack and return both the source information and langchain documents
|
||||
|
||||
Returns:
|
||||
tuple: (sources_info, langchain_documents)
|
||||
"""
|
||||
slack_chunks = await self.retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="SLACK_CONNECTOR"
|
||||
)
|
||||
if search_mode == SearchMode.CHUNKS:
|
||||
slack_chunks = await self.chunk_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="SLACK_CONNECTOR"
|
||||
)
|
||||
elif search_mode == SearchMode.DOCUMENTS:
|
||||
slack_chunks = await self.document_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="SLACK_CONNECTOR"
|
||||
)
|
||||
# Transform document retriever results to match expected format
|
||||
slack_chunks = self._transform_document_results(slack_chunks)
|
||||
|
||||
# Early return if no results
|
||||
if not slack_chunks:
|
||||
|
@ -323,7 +385,7 @@ class ConnectorService:
|
|||
|
||||
return result_object, slack_chunks
|
||||
|
||||
async def search_notion(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
async def search_notion(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
|
||||
"""
|
||||
Search for Notion pages and return both the source information and langchain documents
|
||||
|
||||
|
@ -336,14 +398,25 @@ class ConnectorService:
|
|||
Returns:
|
||||
tuple: (sources_info, langchain_documents)
|
||||
"""
|
||||
notion_chunks = await self.retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="NOTION_CONNECTOR"
|
||||
)
|
||||
|
||||
if search_mode == SearchMode.CHUNKS:
|
||||
notion_chunks = await self.chunk_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="NOTION_CONNECTOR"
|
||||
)
|
||||
elif search_mode == SearchMode.DOCUMENTS:
|
||||
notion_chunks = await self.document_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="NOTION_CONNECTOR"
|
||||
)
|
||||
# Transform document retriever results to match expected format
|
||||
notion_chunks = self._transform_document_results(notion_chunks)
|
||||
|
||||
# Early return if no results
|
||||
if not notion_chunks:
|
||||
return {
|
||||
|
@ -405,7 +478,7 @@ class ConnectorService:
|
|||
|
||||
return result_object, notion_chunks
|
||||
|
||||
async def search_extension(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
async def search_extension(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
|
||||
"""
|
||||
Search for extension data and return both the source information and langchain documents
|
||||
|
||||
|
@ -418,14 +491,25 @@ class ConnectorService:
|
|||
Returns:
|
||||
tuple: (sources_info, langchain_documents)
|
||||
"""
|
||||
extension_chunks = await self.retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="EXTENSION"
|
||||
)
|
||||
|
||||
if search_mode == SearchMode.CHUNKS:
|
||||
extension_chunks = await self.chunk_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="EXTENSION"
|
||||
)
|
||||
elif search_mode == SearchMode.DOCUMENTS:
|
||||
extension_chunks = await self.document_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="EXTENSION"
|
||||
)
|
||||
# Transform document retriever results to match expected format
|
||||
extension_chunks = self._transform_document_results(extension_chunks)
|
||||
|
||||
# Early return if no results
|
||||
if not extension_chunks:
|
||||
return {
|
||||
|
@ -505,7 +589,7 @@ class ConnectorService:
|
|||
|
||||
return result_object, extension_chunks
|
||||
|
||||
async def search_youtube(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
async def search_youtube(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
|
||||
"""
|
||||
Search for YouTube videos and return both the source information and langchain documents
|
||||
|
||||
|
@ -518,13 +602,24 @@ class ConnectorService:
|
|||
Returns:
|
||||
tuple: (sources_info, langchain_documents)
|
||||
"""
|
||||
youtube_chunks = await self.retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="YOUTUBE_VIDEO"
|
||||
)
|
||||
if search_mode == SearchMode.CHUNKS:
|
||||
youtube_chunks = await self.chunk_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="YOUTUBE_VIDEO"
|
||||
)
|
||||
elif search_mode == SearchMode.DOCUMENTS:
|
||||
youtube_chunks = await self.document_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="YOUTUBE_VIDEO"
|
||||
)
|
||||
# Transform document retriever results to match expected format
|
||||
youtube_chunks = self._transform_document_results(youtube_chunks)
|
||||
|
||||
# Early return if no results
|
||||
if not youtube_chunks:
|
||||
|
@ -587,20 +682,31 @@ class ConnectorService:
|
|||
|
||||
return result_object, youtube_chunks
|
||||
|
||||
async def search_github(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
async def search_github(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
|
||||
"""
|
||||
Search for GitHub documents and return both the source information and langchain documents
|
||||
|
||||
Returns:
|
||||
tuple: (sources_info, langchain_documents)
|
||||
"""
|
||||
github_chunks = await self.retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="GITHUB_CONNECTOR"
|
||||
)
|
||||
if search_mode == SearchMode.CHUNKS:
|
||||
github_chunks = await self.chunk_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="GITHUB_CONNECTOR"
|
||||
)
|
||||
elif search_mode == SearchMode.DOCUMENTS:
|
||||
github_chunks = await self.document_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="GITHUB_CONNECTOR"
|
||||
)
|
||||
# Transform document retriever results to match expected format
|
||||
github_chunks = self._transform_document_results(github_chunks)
|
||||
|
||||
# Early return if no results
|
||||
if not github_chunks:
|
||||
|
@ -643,7 +749,7 @@ class ConnectorService:
|
|||
|
||||
return result_object, github_chunks
|
||||
|
||||
async def search_linear(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
async def search_linear(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS) -> tuple:
|
||||
"""
|
||||
Search for Linear issues and comments and return both the source information and langchain documents
|
||||
|
||||
|
@ -656,14 +762,25 @@ class ConnectorService:
|
|||
Returns:
|
||||
tuple: (sources_info, langchain_documents)
|
||||
"""
|
||||
linear_chunks = await self.retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="LINEAR_CONNECTOR"
|
||||
)
|
||||
|
||||
if search_mode == SearchMode.CHUNKS:
|
||||
linear_chunks = await self.chunk_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="LINEAR_CONNECTOR"
|
||||
)
|
||||
elif search_mode == SearchMode.DOCUMENTS:
|
||||
linear_chunks = await self.document_retriever.hybrid_search(
|
||||
query_text=user_query,
|
||||
top_k=top_k,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
document_type="LINEAR_CONNECTOR"
|
||||
)
|
||||
# Transform document retriever results to match expected format
|
||||
linear_chunks = self._transform_document_results(linear_chunks)
|
||||
|
||||
# Early return if no results
|
||||
if not linear_chunks:
|
||||
return {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue