From 2e702902e4bf1ff436ffa9588143792a10bc0a74 Mon Sep 17 00:00:00 2001
From: "DESKTOP-RTLN3BA\\$punk"
Date: Wed, 19 Mar 2025 23:57:02 -0700
Subject: [PATCH] feat: Integrate query reformulation in stream_connector_search_results

---
 .../tasks/stream_connector_search_results.py | 31 ++++----
 surfsense_backend/app/utils/query_service.py | 76 +++++++++++++++++++
 2 files changed, 94 insertions(+), 13 deletions(-)
 create mode 100644 surfsense_backend/app/utils/query_service.py

diff --git a/surfsense_backend/app/tasks/stream_connector_search_results.py b/surfsense_backend/app/tasks/stream_connector_search_results.py
index 35171b5..4fc9181 100644
--- a/surfsense_backend/app/tasks/stream_connector_search_results.py
+++ b/surfsense_backend/app/tasks/stream_connector_search_results.py
@@ -8,6 +8,7 @@ from app.utils.connector_service import ConnectorService
 from app.utils.research_service import ResearchService
 from app.utils.streaming_service import StreamingService
 from app.utils.reranker_service import RerankerService
+from app.utils.query_service import QueryService
 from app.config import config
 from app.utils.document_converters import convert_chunks_to_langchain_documents
 
@@ -37,6 +38,10 @@ async def stream_connector_search_results(
     connector_service = ConnectorService(session)
     streaming_service = StreamingService()
 
+    # Reformulate the user query using the strategic LLM
+    yield streaming_service.add_terminal_message("Reformulating your query for better results...", "info")
+    reformulated_query = await QueryService.reformulate_query(user_query)
+    yield streaming_service.add_terminal_message(f"Searching for: {reformulated_query}", "success")
 
     reranker_service = RerankerService.get_reranker_instance(config)
 
@@ -59,9 +64,9 @@ async def stream_connector_search_results(
         # Send terminal message about starting search
         yield streaming_service.add_terminal_message("Starting to search for crawled URLs...")
 
-        # Search for crawled URLs
+        # Search for crawled URLs using reformulated query
         result_object, crawled_urls_chunks = await connector_service.search_crawled_urls(
-            user_query=user_query,
+            user_query=reformulated_query,
             user_id=user_id,
             search_space_id=search_space_id,
             top_k=TOP_K
@@ -86,9 +91,9 @@ async def stream_connector_search_results(
         # Send terminal message about starting search
         yield streaming_service.add_terminal_message("Starting to search for files...")
 
-        # Search for files
+        # Search for files using reformulated query
         result_object, files_chunks = await connector_service.search_files(
-            user_query=user_query,
+            user_query=reformulated_query,
             user_id=user_id,
             search_space_id=search_space_id,
             top_k=TOP_K
@@ -112,9 +117,9 @@ async def stream_connector_search_results(
         # Send terminal message about starting search
         yield streaming_service.add_terminal_message("Starting to search with Tavily API...")
 
-        # Search using Tavily API
+        # Search using Tavily API with reformulated query
         result_object, tavily_chunks = await connector_service.search_tavily(
-            user_query=user_query,
+            user_query=reformulated_query,
             user_id=user_id,
             top_k=TOP_K
         )
@@ -137,9 +142,9 @@ async def stream_connector_search_results(
         # Send terminal message about starting search
         yield streaming_service.add_terminal_message("Starting to search for slack connector...")
 
-        # Search using Slack API
+        # Search using Slack API with reformulated query
         result_object, slack_chunks = await connector_service.search_slack(
-            user_query=user_query,
+            user_query=reformulated_query,
             user_id=user_id,
             search_space_id=search_space_id,
             top_k=TOP_K
@@ -164,9 +169,9 @@ async def stream_connector_search_results(
         # Send terminal message about starting search
         yield streaming_service.add_terminal_message("Starting to search for notion connector...")
 
-        # Search using Notion API
+        # Search using Notion API with reformulated query
         result_object, notion_chunks = await connector_service.search_notion(
-            user_query=user_query,
+            user_query=reformulated_query,
             user_id=user_id,
             search_space_id=search_space_id,
             top_k=TOP_K
@@ -209,8 +214,8 @@ async def stream_connector_search_results(
                 }
                 for i, doc in enumerate(all_raw_documents)
             ]
-            # Rerank documents
-            reranked_docs = reranker_service.rerank_documents(user_query, reranker_input_docs)
+            # Rerank documents using the reformulated query
+            reranked_docs = reranker_service.rerank_documents(reformulated_query, reranker_input_docs)
 
             # Sort by score in descending order
             reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
@@ -301,7 +306,7 @@ async def stream_connector_search_results(
         # Start the research process in a separate task
         research_task = asyncio.create_task(
             ResearchService.stream_research(
-                user_query=user_query,
+                user_query=reformulated_query,
                 documents=all_langchain_documents_to_research,
                 on_progress=stream_handler.handle_progress,
                 research_mode=research_mode
diff --git a/surfsense_backend/app/utils/query_service.py b/surfsense_backend/app/utils/query_service.py
new file mode 100644
index 0000000..9c295f4
--- /dev/null
+++ b/surfsense_backend/app/utils/query_service.py
@@ -0,0 +1,76 @@
+from typing import Dict, Any
+from langchain.schema import LLMResult, HumanMessage, SystemMessage
+from app.config import config
+
+class QueryService:
+    """
+    Service for query-related operations, including reformulation and processing.
+    """
+
+    @staticmethod
+    async def reformulate_query(user_query: str) -> str:
+        """
+        Reformulate the user query using the STRATEGIC_LLM to make it more
+        effective for information retrieval and research purposes.
+
+        Args:
+            user_query: The original user query
+
+        Returns:
+            str: The reformulated query
+        """
+        if not user_query or not user_query.strip():
+            return user_query
+
+        try:
+            # Get the strategic LLM instance from config
+            llm = config.strategic_llm_instance
+
+            # Create system message with instructions
+            system_message = SystemMessage(
+                content="""
+                You are an expert at reformulating user queries to optimize information retrieval.
+                Your job is to take a user query and reformulate it to:
+
+                1. Make it more specific and detailed
+                2. Expand ambiguous terms
+                3. Include relevant synonyms and alternative phrasings
+                4. Break down complex questions into their core components
+                5. Ensure it's comprehensive for research purposes
+
+                The query will be used with the following data sources/connectors:
+                - SERPER_API: Web search for retrieving current information from the internet
+                - TAVILY_API: Research-focused search API for comprehensive information
+                - SLACK_CONNECTOR: Retrieves information from indexed Slack workspace conversations
+                - NOTION_CONNECTOR: Retrieves information from indexed Notion documents and databases
+                - FILE: Searches through user's uploaded files
+                - CRAWLED_URL: Searches through previously crawled web pages
+
+                Please optimize the query to work effectively across these different data sources.
+
+                Return ONLY the reformulated query without explanations, prefixes, or commentary.
+                Do not include phrases like "Reformulated query:" or any other text except the query itself.
+                """
+            )
+
+            # Create human message with the user query
+            human_message = HumanMessage(
+                content=f"Reformulate this query for better research results: {user_query}"
+            )
+
+            # Get the response from the LLM
+            response = await llm.agenerate(messages=[[system_message, human_message]])
+
+            # Extract the reformulated query from the response
+            reformulated_query = response.generations[0][0].text.strip()
+
+            # Return the original query if the reformulation is empty
+            if not reformulated_query:
+                return user_query
+
+            return reformulated_query
+
+        except Exception as e:
+            # Log the error and return the original query
+            print(f"Error reformulating query: {e}")
+            return user_query
\ No newline at end of file
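
Usage note (outside the diff): a minimal sketch of exercising the new QueryService on its
own, for example to sanity-check the reformulation step without running the full streaming
task. It assumes the surfsense_backend app package is importable and that
config.strategic_llm_instance is configured; the example query string is made up.

    import asyncio

    from app.utils.query_service import QueryService

    async def main() -> None:
        # reformulate_query falls back to the original text if the LLM call
        # fails or returns an empty string, so this always prints something.
        original = "slack messages about the Q1 roadmap"  # hypothetical query
        reformulated = await QueryService.reformulate_query(original)
        print(f"original:     {original}")
        print(f"reformulated: {reformulated}")

    if __name__ == "__main__":
        asyncio.run(main())

Because reformulate_query swallows exceptions and returns the input on failure, the
streaming task keeps working even when the strategic LLM is unreachable; the trade-off is
that such errors only surface as a print statement in the backend output.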