feat: Integrate query reformulation in stream_connector_search_results

DESKTOP-RTLN3BA\$punk committed 2025-03-19 23:57:02 -07:00
parent 613b13b33b
commit 2e702902e4
2 changed files with 94 additions and 13 deletions


@@ -8,6 +8,7 @@ from app.utils.connector_service import ConnectorService
 from app.utils.research_service import ResearchService
 from app.utils.streaming_service import StreamingService
 from app.utils.reranker_service import RerankerService
+from app.utils.query_service import QueryService
 from app.config import config
 from app.utils.document_converters import convert_chunks_to_langchain_documents
@@ -37,6 +38,10 @@ async def stream_connector_search_results(
     connector_service = ConnectorService(session)
     streaming_service = StreamingService()
+    # Reformulate the user query using the strategic LLM
+    yield streaming_service.add_terminal_message("Reformulating your query for better results...", "info")
+    reformulated_query = await QueryService.reformulate_query(user_query)
+    yield streaming_service.add_terminal_message(f"Searching for: {reformulated_query}", "success")
     reranker_service = RerankerService.get_reranker_instance(config)
@@ -59,9 +64,9 @@ async def stream_connector_search_results(
     # Send terminal message about starting search
     yield streaming_service.add_terminal_message("Starting to search for crawled URLs...")
-    # Search for crawled URLs
+    # Search for crawled URLs using reformulated query
     result_object, crawled_urls_chunks = await connector_service.search_crawled_urls(
-        user_query=user_query,
+        user_query=reformulated_query,
         user_id=user_id,
         search_space_id=search_space_id,
         top_k=TOP_K
@@ -86,9 +91,9 @@ async def stream_connector_search_results(
     # Send terminal message about starting search
     yield streaming_service.add_terminal_message("Starting to search for files...")
-    # Search for files
+    # Search for files using reformulated query
     result_object, files_chunks = await connector_service.search_files(
-        user_query=user_query,
+        user_query=reformulated_query,
         user_id=user_id,
         search_space_id=search_space_id,
         top_k=TOP_K
@@ -112,9 +117,9 @@ async def stream_connector_search_results(
     # Send terminal message about starting search
     yield streaming_service.add_terminal_message("Starting to search with Tavily API...")
-    # Search using Tavily API
+    # Search using Tavily API with reformulated query
     result_object, tavily_chunks = await connector_service.search_tavily(
-        user_query=user_query,
+        user_query=reformulated_query,
         user_id=user_id,
         top_k=TOP_K
     )
@@ -137,9 +142,9 @@ async def stream_connector_search_results(
     # Send terminal message about starting search
     yield streaming_service.add_terminal_message("Starting to search for slack connector...")
-    # Search using Slack API
+    # Search using Slack API with reformulated query
     result_object, slack_chunks = await connector_service.search_slack(
-        user_query=user_query,
+        user_query=reformulated_query,
         user_id=user_id,
         search_space_id=search_space_id,
         top_k=TOP_K
@@ -164,9 +169,9 @@ async def stream_connector_search_results(
     # Send terminal message about starting search
     yield streaming_service.add_terminal_message("Starting to search for notion connector...")
-    # Search using Notion API
+    # Search using Notion API with reformulated query
     result_object, notion_chunks = await connector_service.search_notion(
-        user_query=user_query,
+        user_query=reformulated_query,
         user_id=user_id,
         search_space_id=search_space_id,
         top_k=TOP_K
@@ -209,8 +214,8 @@ async def stream_connector_search_results(
         } for i, doc in enumerate(all_raw_documents)
     ]
-    # Rerank documents
-    reranked_docs = reranker_service.rerank_documents(user_query, reranker_input_docs)
+    # Rerank documents using the reformulated query
+    reranked_docs = reranker_service.rerank_documents(reformulated_query, reranker_input_docs)
     # Sort by score in descending order
     reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
@@ -301,7 +306,7 @@ async def stream_connector_search_results(
     # Start the research process in a separate task
     research_task = asyncio.create_task(
         ResearchService.stream_research(
-            user_query=user_query,
+            user_query=reformulated_query,
            documents=all_langchain_documents_to_research,
            on_progress=stream_handler.handle_progress,
            research_mode=research_mode
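
Taken together, the changes above follow a single pattern: reformulate the query once, up front, then pass the reformulated string to every downstream consumer (each connector search, the reranker, and the research task). Below is a condensed, hedged sketch of that flow; the service classes are hypothetical stand-ins, not the real app.utils implementations.

# Condensed sketch of the integration pattern above. The classes here are
# hypothetical stubs used only to illustrate "reformulate once, reuse everywhere".
import asyncio
from typing import AsyncIterator, List


class FakeQueryService:
    @staticmethod
    async def reformulate_query(user_query: str) -> str:
        # Stand-in for the LLM-backed reformulation in QueryService.
        return f"{user_query} (expanded with synonyms and sub-questions)"


class FakeConnectorService:
    async def search_crawled_urls(self, user_query: str, top_k: int) -> List[str]:
        return [f"crawled result for: {user_query}"][:top_k]

    async def search_files(self, user_query: str, top_k: int) -> List[str]:
        return [f"file result for: {user_query}"][:top_k]


async def stream_results(user_query: str) -> AsyncIterator[str]:
    # Reformulate once, then reuse the result for every downstream search.
    reformulated = await FakeQueryService.reformulate_query(user_query)
    yield f"Searching for: {reformulated}"

    connectors = FakeConnectorService()
    for chunk in await connectors.search_crawled_urls(reformulated, top_k=3):
        yield chunk
    for chunk in await connectors.search_files(reformulated, top_k=3):
        yield chunk


async def main() -> None:
    async for message in stream_results("surfsense slack notion search"):
        print(message)


if __name__ == "__main__":
    asyncio.run(main())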


@@ -0,0 +1,76 @@
from typing import Dict, Any
from langchain.schema import LLMResult, HumanMessage, SystemMessage
from app.config import config


class QueryService:
    """
    Service for query-related operations, including reformulation and processing.
    """

    @staticmethod
    async def reformulate_query(user_query: str) -> str:
        """
        Reformulate the user query using the STRATEGIC_LLM to make it more
        effective for information retrieval and research purposes.

        Args:
            user_query: The original user query

        Returns:
            str: The reformulated query
        """
        if not user_query or not user_query.strip():
            return user_query

        try:
            # Get the strategic LLM instance from config
            llm = config.strategic_llm_instance

            # Create system message with instructions
            system_message = SystemMessage(
                content="""
                You are an expert at reformulating user queries to optimize information retrieval.
                Your job is to take a user query and reformulate it to:
                1. Make it more specific and detailed
                2. Expand ambiguous terms
                3. Include relevant synonyms and alternative phrasings
                4. Break down complex questions into their core components
                5. Ensure it's comprehensive for research purposes

                The query will be used with the following data sources/connectors:
                - SERPER_API: Web search for retrieving current information from the internet
                - TAVILY_API: Research-focused search API for comprehensive information
                - SLACK_CONNECTOR: Retrieves information from indexed Slack workspace conversations
                - NOTION_CONNECTOR: Retrieves information from indexed Notion documents and databases
                - FILE: Searches through user's uploaded files
                - CRAWLED_URL: Searches through previously crawled web pages

                Please optimize the query to work effectively across these different data sources.
                Return ONLY the reformulated query without explanations, prefixes, or commentary.
                Do not include phrases like "Reformulated query:" or any other text except the query itself.
                """
            )

            # Create human message with the user query
            human_message = HumanMessage(
                content=f"Reformulate this query for better research results: {user_query}"
            )

            # Get the response from the LLM
            response = await llm.agenerate(messages=[[system_message, human_message]])

            # Extract the reformulated query from the response
            reformulated_query = response.generations[0][0].text.strip()

            # Return the original query if the reformulation is empty
            if not reformulated_query:
                return user_query

            return reformulated_query

        except Exception as e:
            # Log the error and return the original query
            print(f"Error reformulating query: {e}")
            return user_query
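
For reference, a minimal usage sketch of the new helper. It assumes the app package is importable and config.strategic_llm_instance is already configured; the query string is made up for illustration.

# Minimal usage sketch (hypothetical query string; assumes the app config
# already provides a working strategic LLM instance).
import asyncio

from app.utils.query_service import QueryService


async def main() -> None:
    original = "slack discussion about onboarding docs"
    # Falls back to the original query if reformulation fails or returns empty.
    reformulated = await QueryService.reformulate_query(original)
    print(f"Original:     {original}")
    print(f"Reformulated: {reformulated}")


if __name__ == "__main__":
    asyncio.run(main())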