feat: Added chat_history to researcher agent

DESKTOP-RTLN3BA\$punk committed 2025-05-10 20:06:19 -07:00
parent eda1d43935
commit a58550818b
10 changed files with 129 additions and 64 deletions

View file

@@ -1 +0,0 @@
"""This is upcoming research agent. Work in progress."""

View file

@@ -1,6 +1,6 @@
from langgraph.graph import StateGraph
from .state import State
from .nodes import write_answer_outline, process_sections
from .nodes import reformulate_user_query, write_answer_outline, process_sections
from .configuration import Configuration
from typing import TypedDict, List, Dict, Any, Optional
@@ -25,11 +25,13 @@ def build_graph():
workflow = StateGraph(State, config_schema=Configuration)
# Add nodes to the graph
workflow.add_node("reformulate_user_query", reformulate_user_query)
workflow.add_node("write_answer_outline", write_answer_outline)
workflow.add_node("process_sections", process_sections)
# Define the edges - create a linear flow
workflow.add_edge("__start__", "write_answer_outline")
workflow.add_edge("__start__", "reformulate_user_query")
workflow.add_edge("reformulate_user_query", "write_answer_outline")
workflow.add_edge("write_answer_outline", "process_sections")
workflow.add_edge("process_sections", "__end__")

View file

@@ -16,6 +16,8 @@ from .state import State
from .sub_section_writer.graph import graph as sub_section_writer_graph
from .sub_section_writer.configuration import SubSectionType
from app.utils.query_service import QueryService
from langgraph.types import StreamWriter
@@ -47,6 +49,7 @@ async def write_answer_outline(state: State, config: RunnableConfig, writer: StreamWriter
writer({"yeild_value": streaming_service._format_annotations()})
# Get configuration from runnable config
configuration = Configuration.from_runnable_config(config)
reformulated_query = state.reformulated_query
user_query = configuration.user_query
num_sections = configuration.num_sections
@@ -60,7 +63,7 @@ async def write_answer_outline(state: State, config: RunnableConfig, writer: StreamWriter
human_message_content = f"""
Now please create an answer outline for the following query:
User Query: {user_query}
User Query: {reformulated_query}
Number of Sections: {num_sections}
Remember to format your response as valid JSON exactly matching this structure:
@@ -719,8 +722,11 @@ async def process_section_with_documents(
}
}
# Create the initial state with db_session
sub_state = {"db_session": db_session}
# Create the initial state with db_session and chat_history
sub_state = {
"db_session": db_session,
"chat_history": state.chat_history
}
# Invoke the sub-section writer graph
print(f"Invoking sub_section_writer for: {section_title}")
@@ -749,3 +755,23 @@ async def process_section_with_documents(
return f"Error processing section: {section_title}. Details: {str(e)}"
async def reformulate_user_query(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]:
"""
Reformulates the user query based on the chat history.
"""
configuration = Configuration.from_runnable_config(config)
user_query = configuration.user_query
chat_history_str = await QueryService.langchain_chat_history_to_str(state.chat_history)
if len(state.chat_history) == 0:
reformulated_query = user_query
else:
reformulated_query = await QueryService.reformulate_query_with_chat_history(user_query, chat_history_str)
return {
"reformulated_query": reformulated_query
}
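
The node is a pass-through for new conversations: with an empty chat_history the original query is returned unchanged and the LLM round-trip is skipped. A small behavior sketch with an invented stand-in for the QueryService call (not the repo's code):

import asyncio

async def fake_reformulate(user_query: str, chat_history: list) -> str:
    # Stand-in for QueryService.reformulate_query_with_chat_history.
    async def llm_rewrite(query: str, history_str: str) -> str:
        return f"{query} (rewritten using {history_str})"
    if len(chat_history) == 0:
        return user_query            # no history: keep the query as-is
    return await llm_rewrite(user_query, f"{len(chat_history)} prior turns")

print(asyncio.run(fake_reformulate("tell me more", [])))           # unchanged
print(asyncio.run(fake_reformulate("tell me more", ["q1", "a1"]))) # rewritten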

View file

@@ -21,7 +21,9 @@ class State:
# Streaming service
streaming_service: StreamingService
# chat_history: Optional[List[Any]] = field(default=None)
chat_history: Optional[List[Any]] = field(default_factory=list)
reformulated_query: Optional[str] = field(default=None)
# Using field to explicitly mark as part of state
answer_outline: Optional[Any] = field(default=None)
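
With the two new fields, constructing the researcher state might look like the sketch below. The import path is inferred from this diff's context, and the session and streaming service are mocked stand-ins for the real runtime objects:

from unittest.mock import Mock
from langchain.schema import AIMessage, HumanMessage
from app.agents.researcher.state import State  # assumed module path

state = State(
    db_session=Mock(),          # stand-in for an AsyncSession from the request scope
    streaming_service=Mock(),   # stand-in for the repo's StreamingService
    chat_history=[
        HumanMessage(content="Earlier question"),
        AIMessage(content="Earlier answer"),
    ],
    # reformulated_query and answer_outline keep their defaults (None)
)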

View file

@@ -164,13 +164,13 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, Any]:
"""
# Create messages for the LLM
messages = [
messages_with_chat_history = state.chat_history + [
SystemMessage(content=get_citation_system_prompt()),
HumanMessage(content=human_message_content)
]
# Call the LLM and get the response
response = await llm.ainvoke(messages)
response = await llm.ainvoke(messages_with_chat_history)
final_answer = response.content
return {
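
One consequence of prepending state.chat_history is the resulting message order: history turns come first, then the citation system prompt, then the new request. A runnable illustration with invented content (langchain only):

from langchain.schema import AIMessage, HumanMessage, SystemMessage

chat_history = [
    HumanMessage(content="Summarize the repo architecture."),
    AIMessage(content="It is a LangGraph pipeline with sub-graphs."),
]
messages_with_chat_history = chat_history + [
    SystemMessage(content="Cite sources as [n]."),  # stands in for get_citation_system_prompt()
    HumanMessage(content="Write the 'Background' sub-section."),
]
for message in messages_with_chat_history:
    print(type(message).__name__, "->", message.content)

Note that the system message lands mid-list rather than first; most chat models tolerate this, but a provider that requires a leading system prompt would need the order flipped.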

View file

@@ -2,7 +2,7 @@
from __future__ import annotations
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import List, Optional, Any
from sqlalchemy.ext.asyncio import AsyncSession
@@ -17,6 +17,7 @@ class State:
# Runtime context
db_session: AsyncSession
chat_history: Optional[List[Any]] = field(default_factory=list)
# OUTPUT: Populated by agent nodes
reranked_documents: Optional[List[Any]] = None
final_answer: Optional[str] = None

View file

@@ -10,7 +10,7 @@ from fastapi.responses import StreamingResponse
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from langchain.schema import HumanMessage, AIMessage
router = APIRouter()
@router.post("/chat")
@@ -20,11 +20,11 @@ async def handle_chat_data(
user: User = Depends(current_active_user)
):
messages = request.messages
if messages[-1].role != "user":
if messages[-1]['role'] != "user":
raise HTTPException(
status_code=400, detail="Last message must be a user message")
user_query = messages[-1].content
user_query = messages[-1]['content']
search_space_id = request.data.get('search_space_id')
research_mode: str = request.data.get('research_mode')
selected_connectors: List[str] = request.data.get('selected_connectors')
@@ -43,6 +43,21 @@ async def handle_chat_data(
except HTTPException:
raise HTTPException(
status_code=403, detail="You don't have access to this search space")
langchain_chat_history = []
for message in messages[:-1]:
if message['role'] == "user":
langchain_chat_history.append(HumanMessage(content=message['content']))
elif message['role'] == "assistant":
# Last annotation type will always be "ANSWER" here
answer_annotation = message['annotations'][-1]
answer_text = ""
if answer_annotation['type'] == "ANSWER":
answer_text = answer_annotation['content']
# If content is a list, join it into a single string
if isinstance(answer_text, list):
answer_text = "\n".join(answer_text)
langchain_chat_history.append(AIMessage(content=answer_text))
response = StreamingResponse(stream_connector_search_results(
user_query,
@@ -50,7 +65,8 @@ async def handle_chat_data(
search_space_id, # Already converted to int in lines 32-37
session,
research_mode,
selected_connectors
selected_connectors,
langchain_chat_history
))
response.headers['x-vercel-ai-data-stream'] = 'v1'
return response
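
To make the history-building loop in this route concrete, here is a worked example with a hypothetical two-turn payload (the field shapes mirror this route's expectations; all content is invented):

from langchain.schema import AIMessage, HumanMessage

messages = [
    {"role": "user", "content": "What connectors exist?"},
    {"role": "assistant", "annotations": [
        {"type": "SOURCES", "content": ["..."]},
        {"type": "ANSWER", "content": ["Slack, Notion,", "files, crawled URLs."]},
    ]},
    {"role": "user", "content": "Which one indexes conversations?"},  # the live query
]

langchain_chat_history = []
for message in messages[:-1]:  # the last message is the live query, not history
    if message["role"] == "user":
        langchain_chat_history.append(HumanMessage(content=message["content"]))
    elif message["role"] == "assistant":
        answer_annotation = message["annotations"][-1]
        answer_text = ""
        if answer_annotation["type"] == "ANSWER":
            answer_text = answer_annotation["content"]
            if isinstance(answer_text, list):
                answer_text = "\n".join(answer_text)
        langchain_chat_history.append(AIMessage(content=answer_text))

print(langchain_chat_history)
# [HumanMessage("What connectors exist?"), AIMessage("Slack, Notion,\nfiles, crawled URLs.")]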

View file

@@ -27,14 +27,14 @@ class ToolInvocation(BaseModel):
result: dict
class ClientMessage(BaseModel):
role: str
content: str
experimental_attachments: Optional[List[ClientAttachment]] = None
toolInvocations: Optional[List[ToolInvocation]] = None
# class ClientMessage(BaseModel):
# role: str
# content: str
# experimental_attachments: Optional[List[ClientAttachment]] = None
# toolInvocations: Optional[List[ToolInvocation]] = None
class AISDKChatRequest(BaseModel):
messages: List[ClientMessage]
messages: List[Any]
data: Optional[Dict[str, Any]] = None
class ChatCreate(ChatBase):
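
Relaxing messages to List[Any] is the quick fix here, since assistant turns now carry an annotations field that ClientMessage never modeled. A stricter, purely hypothetical alternative would extend the model instead:

from typing import Any, Dict, List, Optional
from pydantic import BaseModel

class AnnotatedClientMessage(BaseModel):  # hypothetical; not part of this commit
    role: str
    content: Optional[str] = None                       # assistant turns may omit plain content
    annotations: Optional[List[Dict[str, Any]]] = None  # e.g. {"type": "ANSWER", "content": [...]}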

View file

@@ -1,4 +1,4 @@
from typing import AsyncGenerator, List, Union
from typing import Any, AsyncGenerator, List, Union
from uuid import UUID
from app.agents.researcher.graph import graph as researcher_graph
@@ -13,7 +13,8 @@ async def stream_connector_search_results(
search_space_id: int,
session: AsyncSession,
research_mode: str,
selected_connectors: List[str]
selected_connectors: List[str],
langchain_chat_history: List[Any]
) -> AsyncGenerator[str, None]:
"""
Stream connector search results to the client
@@ -53,7 +54,8 @@ async def stream_connector_search_results(
# Initialize state with database session and streaming service
initial_state = State(
db_session=session,
streaming_service=streaming_service
streaming_service=streaming_service,
chat_history=langchain_chat_history
)
# Run the graph directly

View file

@@ -1,8 +1,7 @@
"""
NOTE: This is not used anymore. Might be removed in the future.
"""
from langchain.schema import HumanMessage, SystemMessage
from langchain.schema import HumanMessage, SystemMessage, AIMessage
from app.config import config
from typing import Any, List, Optional
class QueryService:
"""
@@ -10,72 +9,90 @@ class QueryService:
"""
@staticmethod
async def reformulate_query(user_query: str) -> str:
async def reformulate_query_with_chat_history(user_query: str, chat_history_str: Optional[str] = None) -> str:
"""
Reformulate the user query using the STRATEGIC_LLM to make it more
effective for information retrieval and research purposes.
Args:
user_query: The original user query
chat_history_str: Optional string rendering of the prior chat messages
Returns:
str: The reformulated query
"""
if not user_query or not user_query.strip():
return user_query
try:
# Get the strategic LLM instance from config
llm = config.strategic_llm_instance
# Create system message with instructions
system_message = SystemMessage(
content="""
You are an expert at reformulating user queries to optimize information retrieval.
Your job is to take a user query and reformulate it to:
1. Make it more specific and detailed
2. Expand ambiguous terms
3. Include relevant synonyms and alternative phrasings
4. Break down complex questions into their core components
5. Ensure it's comprehensive for research purposes
The query will be used with the following data sources/connectors:
- SERPER_API: Web search for retrieving current information from the internet
- TAVILY_API: Research-focused search API for comprehensive information
- SLACK_CONNECTOR: Retrieves information from indexed Slack workspace conversations
- NOTION_CONNECTOR: Retrieves information from indexed Notion documents and databases
- FILE: Searches through user's uploaded files
- CRAWLED_URL: Searches through previously crawled web pages
IMPORTANT: Keep the reformulated query as concise as possible while still being effective.
Avoid unnecessary verbosity and limit the query to only essential terms and concepts.
Please optimize the query to work effectively across these different data sources.
Return ONLY the reformulated query without explanations, prefixes, or commentary.
Do not include phrases like "Reformulated query:" or any other text except the query itself.
content=f"""
You are a highly skilled AI assistant specializing in query optimization for advanced research.
Your primary objective is to transform a user's initial query into a highly effective search query.
This reformulated query will be used to retrieve information from diverse data sources.
**Chat History Context:**
{chat_history_str if chat_history_str else "No prior conversation history is available."}
If chat history is provided, analyze it to understand the user's evolving information needs and the broader context of their request. Use this understanding to refine the current query, ensuring it builds upon or clarifies previous interactions.
**Query Reformulation Guidelines:**
Your reformulated query should:
1. **Enhance Specificity and Detail:** Add precision to narrow the search focus effectively, making the query less ambiguous and more targeted.
2. **Resolve Ambiguities:** Identify and clarify vague terms or phrases. If a term has multiple meanings, orient the query towards the most likely one given the context.
3. **Expand Key Concepts:** Incorporate relevant synonyms, related terms, and alternative phrasings for core concepts. This helps capture a wider range of relevant documents.
4. **Deconstruct Complex Questions:** If the original query is multifaceted, break it down into its core searchable components or rephrase it to address each aspect clearly. The final output must still be a single, coherent query string.
5. **Optimize for Comprehensiveness:** Ensure the query is structured to uncover all essential facets of the original request, aiming for thorough information retrieval suitable for research.
6. **Maintain User Intent:** The reformulated query must stay true to the original intent of the user's query. Do not introduce new topics or shift the focus significantly.
**Crucial Constraints:**
* **Conciseness and Effectiveness:** While aiming for comprehensiveness, the reformulated query MUST be as concise as possible. Eliminate all unnecessary verbosity. Focus on essential keywords, entities, and concepts that directly contribute to effective retrieval.
* **Single, Direct Output:** Return ONLY the reformulated query itself. Do NOT include any explanations, introductory phrases (e.g., "Reformulated query:", "Here is the optimized query:"), or any other surrounding text or markdown formatting.
Your output should be a single, optimized query string, ready for immediate use in a search system.
"""
)
# Create human message with the user query
human_message = HumanMessage(
content=f"Reformulate this query for better research results: {user_query}"
)
# Get the response from the LLM
response = await llm.agenerate(messages=[[system_message, human_message]])
# Extract the reformulated query from the response
reformulated_query = response.generations[0][0].text.strip()
# Return the original query if the reformulation is empty
if not reformulated_query:
return user_query
return reformulated_query
except Exception as e:
# Log the error and return the original query
print(f"Error reformulating query: {e}")
return user_query
return user_query
@staticmethod
async def langchain_chat_history_to_str(chat_history: List[Any]) -> str:
"""
Convert a list of chat history messages to a string.
"""
chat_history_str = "<chat_history>\n"
for chat_message in chat_history:
if isinstance(chat_message, HumanMessage):
chat_history_str += f"<user>{chat_message.content}</user>\n"
elif isinstance(chat_message, AIMessage):
chat_history_str += f"<assistant>{chat_message.content}</assistant>\n"
elif isinstance(chat_message, SystemMessage):
chat_history_str += f"<system>{chat_message.content}</system>\n"
chat_history_str += "</chat_history>"
return chat_history_str
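
A quick usage sketch of this helper, using the QueryService import path shown earlier in this diff (the messages are invented):

import asyncio
from langchain.schema import AIMessage, HumanMessage
from app.utils.query_service import QueryService

history = [
    HumanMessage(content="What does this commit add?"),
    AIMessage(content="Chat history support in the researcher agent."),
]
print(asyncio.run(QueryService.langchain_chat_history_to_str(history)))
# <chat_history>
# <user>What does this commit add?</user>
# <assistant>Chat history support in the researcher agent.</assistant>
# </chat_history>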