mirror of
https://github.com/MODSetter/SurfSense.git
synced 2025-09-02 18:49:09 +00:00
feat: Integrate query reformulation in stream_connector_search_results
This commit is contained in:
parent
613b13b33b
commit
2e702902e4
2 changed files with 94 additions and 13 deletions
|
@ -8,6 +8,7 @@ from app.utils.connector_service import ConnectorService
|
||||||
from app.utils.research_service import ResearchService
|
from app.utils.research_service import ResearchService
|
||||||
from app.utils.streaming_service import StreamingService
|
from app.utils.streaming_service import StreamingService
|
||||||
from app.utils.reranker_service import RerankerService
|
from app.utils.reranker_service import RerankerService
|
||||||
|
from app.utils.query_service import QueryService
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from app.utils.document_converters import convert_chunks_to_langchain_documents
|
from app.utils.document_converters import convert_chunks_to_langchain_documents
|
||||||
|
|
||||||
|
@ -37,6 +38,10 @@ async def stream_connector_search_results(
|
||||||
connector_service = ConnectorService(session)
|
connector_service = ConnectorService(session)
|
||||||
streaming_service = StreamingService()
|
streaming_service = StreamingService()
|
||||||
|
|
||||||
|
# Reformulate the user query using the strategic LLM
|
||||||
|
yield streaming_service.add_terminal_message("Reformulating your query for better results...", "info")
|
||||||
|
reformulated_query = await QueryService.reformulate_query(user_query)
|
||||||
|
yield streaming_service.add_terminal_message(f"Searching for: {reformulated_query}", "success")
|
||||||
|
|
||||||
reranker_service = RerankerService.get_reranker_instance(config)
|
reranker_service = RerankerService.get_reranker_instance(config)
|
||||||
|
|
||||||
|
@ -59,9 +64,9 @@ async def stream_connector_search_results(
|
||||||
# Send terminal message about starting search
|
# Send terminal message about starting search
|
||||||
yield streaming_service.add_terminal_message("Starting to search for crawled URLs...")
|
yield streaming_service.add_terminal_message("Starting to search for crawled URLs...")
|
||||||
|
|
||||||
# Search for crawled URLs
|
# Search for crawled URLs using reformulated query
|
||||||
result_object, crawled_urls_chunks = await connector_service.search_crawled_urls(
|
result_object, crawled_urls_chunks = await connector_service.search_crawled_urls(
|
||||||
user_query=user_query,
|
user_query=reformulated_query,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
top_k=TOP_K
|
top_k=TOP_K
|
||||||
|
@ -86,9 +91,9 @@ async def stream_connector_search_results(
|
||||||
# Send terminal message about starting search
|
# Send terminal message about starting search
|
||||||
yield streaming_service.add_terminal_message("Starting to search for files...")
|
yield streaming_service.add_terminal_message("Starting to search for files...")
|
||||||
|
|
||||||
# Search for files
|
# Search for files using reformulated query
|
||||||
result_object, files_chunks = await connector_service.search_files(
|
result_object, files_chunks = await connector_service.search_files(
|
||||||
user_query=user_query,
|
user_query=reformulated_query,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
top_k=TOP_K
|
top_k=TOP_K
|
||||||
|
@ -112,9 +117,9 @@ async def stream_connector_search_results(
|
||||||
# Send terminal message about starting search
|
# Send terminal message about starting search
|
||||||
yield streaming_service.add_terminal_message("Starting to search with Tavily API...")
|
yield streaming_service.add_terminal_message("Starting to search with Tavily API...")
|
||||||
|
|
||||||
# Search using Tavily API
|
# Search using Tavily API with reformulated query
|
||||||
result_object, tavily_chunks = await connector_service.search_tavily(
|
result_object, tavily_chunks = await connector_service.search_tavily(
|
||||||
user_query=user_query,
|
user_query=reformulated_query,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
top_k=TOP_K
|
top_k=TOP_K
|
||||||
)
|
)
|
||||||
|
@ -137,9 +142,9 @@ async def stream_connector_search_results(
|
||||||
# Send terminal message about starting search
|
# Send terminal message about starting search
|
||||||
yield streaming_service.add_terminal_message("Starting to search for slack connector...")
|
yield streaming_service.add_terminal_message("Starting to search for slack connector...")
|
||||||
|
|
||||||
# Search using Slack API
|
# Search using Slack API with reformulated query
|
||||||
result_object, slack_chunks = await connector_service.search_slack(
|
result_object, slack_chunks = await connector_service.search_slack(
|
||||||
user_query=user_query,
|
user_query=reformulated_query,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
top_k=TOP_K
|
top_k=TOP_K
|
||||||
|
@ -164,9 +169,9 @@ async def stream_connector_search_results(
|
||||||
# Send terminal message about starting search
|
# Send terminal message about starting search
|
||||||
yield streaming_service.add_terminal_message("Starting to search for notion connector...")
|
yield streaming_service.add_terminal_message("Starting to search for notion connector...")
|
||||||
|
|
||||||
# Search using Notion API
|
# Search using Notion API with reformulated query
|
||||||
result_object, notion_chunks = await connector_service.search_notion(
|
result_object, notion_chunks = await connector_service.search_notion(
|
||||||
user_query=user_query,
|
user_query=reformulated_query,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
top_k=TOP_K
|
top_k=TOP_K
|
||||||
|
@ -209,8 +214,8 @@ async def stream_connector_search_results(
|
||||||
} for i, doc in enumerate(all_raw_documents)
|
} for i, doc in enumerate(all_raw_documents)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Rerank documents
|
# Rerank documents using the reformulated query
|
||||||
reranked_docs = reranker_service.rerank_documents(user_query, reranker_input_docs)
|
reranked_docs = reranker_service.rerank_documents(reformulated_query, reranker_input_docs)
|
||||||
|
|
||||||
# Sort by score in descending order
|
# Sort by score in descending order
|
||||||
reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
|
reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
|
||||||
|
@ -301,7 +306,7 @@ async def stream_connector_search_results(
|
||||||
# Start the research process in a separate task
|
# Start the research process in a separate task
|
||||||
research_task = asyncio.create_task(
|
research_task = asyncio.create_task(
|
||||||
ResearchService.stream_research(
|
ResearchService.stream_research(
|
||||||
user_query=user_query,
|
user_query=reformulated_query,
|
||||||
documents=all_langchain_documents_to_research,
|
documents=all_langchain_documents_to_research,
|
||||||
on_progress=stream_handler.handle_progress,
|
on_progress=stream_handler.handle_progress,
|
||||||
research_mode=research_mode
|
research_mode=research_mode
|
||||||
|
|
76
surfsense_backend/app/utils/query_service.py
Normal file
76
surfsense_backend/app/utils/query_service.py
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
from typing import Dict, Any
|
||||||
|
from langchain.schema import LLMResult, HumanMessage, SystemMessage
|
||||||
|
from app.config import config
|
||||||
|
|
||||||
|
class QueryService:
    """
    Service for query-related operations, including reformulation and processing.
    """

    @staticmethod
    async def reformulate_query(user_query: str) -> str:
        """
        Reformulate the user query using the STRATEGIC_LLM to make it more
        effective for information retrieval and research purposes.

        Falls back to returning the original query unchanged whenever the
        query is empty/blank, the LLM produces no usable text, or any error
        occurs during the LLM call — reformulation is strictly best-effort.

        Args:
            user_query: The original user query

        Returns:
            str: The reformulated query, or the original query on any failure
        """
        # Nothing to reformulate for empty/whitespace-only input.
        if not user_query or not user_query.strip():
            return user_query

        try:
            # Get the strategic LLM instance from config
            llm = config.strategic_llm_instance

            # Create system message with instructions
            system_message = SystemMessage(
                content="""
                You are an expert at reformulating user queries to optimize information retrieval.
                Your job is to take a user query and reformulate it to:

                1. Make it more specific and detailed
                2. Expand ambiguous terms
                3. Include relevant synonyms and alternative phrasings
                4. Break down complex questions into their core components
                5. Ensure it's comprehensive for research purposes

                The query will be used with the following data sources/connectors:
                - SERPER_API: Web search for retrieving current information from the internet
                - TAVILY_API: Research-focused search API for comprehensive information
                - SLACK_CONNECTOR: Retrieves information from indexed Slack workspace conversations
                - NOTION_CONNECTOR: Retrieves information from indexed Notion documents and databases
                - FILE: Searches through user's uploaded files
                - CRAWLED_URL: Searches through previously crawled web pages

                Please optimize the query to work effectively across these different data sources.

                Return ONLY the reformulated query without explanations, prefixes, or commentary.
                Do not include phrases like "Reformulated query:" or any other text except the query itself.
                """
            )

            # Create human message with the user query
            human_message = HumanMessage(
                content=f"Reformulate this query for better research results: {user_query}"
            )

            # Get the response from the LLM
            response = await llm.agenerate(messages=[[system_message, human_message]])

            # Guard against an empty generations structure before indexing:
            # response.generations[0][0] would raise IndexError otherwise and
            # defeat the intended graceful fallback below.
            if not response.generations or not response.generations[0]:
                return user_query

            # Extract the reformulated query from the response
            reformulated_query = response.generations[0][0].text.strip()

            # Return the original query if the reformulation is empty
            if not reformulated_query:
                return user_query

            return reformulated_query

        except Exception as e:
            # Log the error and return the original query; reformulation is
            # best-effort, so the pipeline must not fail here.
            # NOTE(review): consider the logging module instead of print to
            # match typical backend logging conventions — confirm project style.
            print(f"Error reformulating query: {e}")
            return user_query
Loading…
Add table
Reference in a new issue