Mirror of https://github.com/MODSetter/SurfSense.git (synced 2025-09-05 03:59:06 +00:00)

feat: Added chat_history to researcher agent

parent eda1d43935, commit a58550818b

10 changed files with 129 additions and 64 deletions
@@ -1 +0,0 @@
-"""This is upcoming research agent. Work in progress."""
@@ -1,6 +1,6 @@
 from langgraph.graph import StateGraph
 from .state import State
-from .nodes import write_answer_outline, process_sections
+from .nodes import reformulate_user_query, write_answer_outline, process_sections
 from .configuration import Configuration
 from typing import TypedDict, List, Dict, Any, Optional

@@ -25,11 +25,13 @@ def build_graph():
     workflow = StateGraph(State, config_schema=Configuration)

     # Add nodes to the graph
+    workflow.add_node("reformulate_user_query", reformulate_user_query)
     workflow.add_node("write_answer_outline", write_answer_outline)
     workflow.add_node("process_sections", process_sections)

     # Define the edges - create a linear flow
-    workflow.add_edge("__start__", "write_answer_outline")
+    workflow.add_edge("__start__", "reformulate_user_query")
+    workflow.add_edge("reformulate_user_query", "write_answer_outline")
     workflow.add_edge("write_answer_outline", "process_sections")
     workflow.add_edge("process_sections", "__end__")
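The new node slots in at the head of the existing linear flow: __start__ -> reformulate_user_query -> write_answer_outline -> process_sections -> __end__. A minimal, runnable sketch of the same wiring pattern, with an illustrative state and stub nodes (not the repo's exact code):

from dataclasses import dataclass
from typing import Optional
from langgraph.graph import StateGraph

@dataclass
class DemoState:
    user_query: str = ""
    reformulated_query: Optional[str] = None
    outline: Optional[str] = None

def reformulate_user_query(state: DemoState) -> dict:
    # Nodes return partial updates; LangGraph merges them into the state.
    return {"reformulated_query": state.user_query.strip()}

def write_answer_outline(state: DemoState) -> dict:
    return {"outline": f"Outline for: {state.reformulated_query}"}

workflow = StateGraph(DemoState)
workflow.add_node("reformulate_user_query", reformulate_user_query)
workflow.add_node("write_answer_outline", write_answer_outline)
workflow.add_edge("__start__", "reformulate_user_query")
workflow.add_edge("reformulate_user_query", "write_answer_outline")
workflow.add_edge("write_answer_outline", "__end__")
graph = workflow.compile()

print(graph.invoke({"user_query": "  what changed?  "}))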
@@ -16,6 +16,8 @@ from .state import State
 from .sub_section_writer.graph import graph as sub_section_writer_graph
 from .sub_section_writer.configuration import SubSectionType

+from app.utils.query_service import QueryService
+
 from langgraph.types import StreamWriter

@@ -47,6 +49,7 @@ async def write_answer_outline(state: State, config: RunnableConfig, writer: StreamWriter):
     writer({"yeild_value": streaming_service._format_annotations()})
     # Get configuration from runnable config
     configuration = Configuration.from_runnable_config(config)
+    reformulated_query = state.reformulated_query
     user_query = configuration.user_query
     num_sections = configuration.num_sections

@@ -60,7 +63,7 @@ async def write_answer_outline(state: State, config: RunnableConfig, writer: StreamWriter):
     human_message_content = f"""
     Now Please create an answer outline for the following query:

-    User Query: {user_query}
+    User Query: {reformulated_query}
     Number of Sections: {num_sections}

     Remember to format your response as valid JSON exactly matching this structure:
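write_answer_outline now takes the query from graph state (state.reformulated_query), while per-run settings such as num_sections still come from the runnable config. A hedged sketch of the from_runnable_config pattern this implies; the field names mirror the diff, but the repo's actual Configuration class may differ:

from dataclasses import dataclass, fields
from typing import Optional
from langchain_core.runnables import RunnableConfig

@dataclass
class Configuration:
    user_query: str = ""
    num_sections: int = 3

    @classmethod
    def from_runnable_config(cls, config: Optional[RunnableConfig] = None) -> "Configuration":
        # Per-run values travel under config["configurable"], the usual LangGraph convention.
        configurable = (config or {}).get("configurable", {})
        return cls(**{f.name: configurable[f.name] for f in fields(cls) if f.name in configurable})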
@@ -719,8 +722,11 @@ async def process_section_with_documents(
             }
         }

-        # Create the initial state with db_session
-        sub_state = {"db_session": db_session}
+        # Create the initial state with db_session and chat_history
+        sub_state = {
+            "db_session": db_session,
+            "chat_history": state.chat_history
+        }

         # Invoke the sub-section writer graph
         print(f"Invoking sub_section_writer for: {section_title}")
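With chat_history added to sub_state, the sub-section writer sub-graph receives the same conversation context as the parent graph. A hedged sketch of the invocation that follows; sub_section_writer_graph, db_session, state, and config come from the surrounding function, and the result key is assumed from the sub-graph's State shown further below:

result = await sub_section_writer_graph.ainvoke(
    {"db_session": db_session, "chat_history": state.chat_history},
    config,
)
section_text = result.get("final_answer")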
@@ -749,3 +755,23 @@ async def process_section_with_documents(

         return f"Error processing section: {section_title}. Details: {str(e)}"

+
+async def reformulate_user_query(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]:
+    """
+    Reformulates the user query based on the chat history.
+    """
+
+    configuration = Configuration.from_runnable_config(config)
+    user_query = configuration.user_query
+    chat_history_str = await QueryService.langchain_chat_history_to_str(state.chat_history)
+    if len(state.chat_history) == 0:
+        reformulated_query = user_query
+    else:
+        reformulated_query = await QueryService.reformulate_query_with_chat_history(user_query, chat_history_str)
+
+    return {
+        "reformulated_query": reformulated_query
+    }
@@ -21,7 +21,9 @@ class State:
     # Streaming service
     streaming_service: StreamingService

-    # chat_history: Optional[List[Any]] = field(default=None)
+    chat_history: Optional[List[Any]] = field(default_factory=list)
+
+    reformulated_query: Optional[str] = field(default=None)
     # Using field to explicitly mark as part of state
     answer_outline: Optional[Any] = field(default=None)
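Note the switch from the commented-out default=None to field(default_factory=list): a dataclass field cannot take a bare [] default, and default_factory gives each State its own list. A quick standalone illustration of why:

from dataclasses import dataclass, field
from typing import Any, List, Optional

@dataclass
class Demo:
    # `history: List[Any] = []` would raise ValueError("mutable default ...")
    # at class-definition time; default_factory builds a fresh list per instance.
    history: Optional[List[Any]] = field(default_factory=list)

a, b = Demo(), Demo()
a.history.append("hi")
assert b.history == []  # no cross-instance leakage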
@@ -164,13 +164,13 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, Any]:
     """

     # Create messages for the LLM
-    messages = [
+    messages_with_chat_history = state.chat_history + [
         SystemMessage(content=get_citation_system_prompt()),
         HumanMessage(content=human_message_content)
     ]

     # Call the LLM and get the response
-    response = await llm.ainvoke(messages)
+    response = await llm.ainvoke(messages_with_chat_history)
     final_answer = response.content

     return {
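The change prepends the accumulated chat history to the per-call messages, so the citation system prompt and the new request still arrive last. A small sketch of the resulting message list; the prompt text is a stand-in, and the ordering follows the diff:

from langchain.schema import AIMessage, HumanMessage, SystemMessage

chat_history = [
    HumanMessage(content="Summarize the report."),
    AIMessage(content="Here is the summary..."),
]
messages_with_chat_history = chat_history + [
    SystemMessage(content="Cite sources with [n] markers."),  # stand-in for get_citation_system_prompt()
    HumanMessage(content="Write the 'Key Risks' sub-section."),
]
# response = await llm.ainvoke(messages_with_chat_history)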
@@ -2,7 +2,7 @@

 from __future__ import annotations

-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import List, Optional, Any
 from sqlalchemy.ext.asyncio import AsyncSession

@@ -17,6 +17,7 @@ class State:
     # Runtime context
     db_session: AsyncSession

+    chat_history: Optional[List[Any]] = field(default_factory=list)
     # OUTPUT: Populated by agent nodes
     reranked_documents: Optional[List[Any]] = None
     final_answer: Optional[str] = None
@@ -10,7 +10,7 @@ from fastapi.responses import StreamingResponse
 from sqlalchemy.exc import IntegrityError, OperationalError
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.future import select
+from langchain.schema import HumanMessage, AIMessage

 router = APIRouter()

 @router.post("/chat")

@@ -20,11 +20,11 @@ async def handle_chat_data(
     user: User = Depends(current_active_user)
 ):
     messages = request.messages
-    if messages[-1].role != "user":
+    if messages[-1]['role'] != "user":
         raise HTTPException(
             status_code=400, detail="Last message must be a user message")

-    user_query = messages[-1].content
+    user_query = messages[-1]['content']
     search_space_id = request.data.get('search_space_id')
     research_mode: str = request.data.get('research_mode')
     selected_connectors: List[str] = request.data.get('selected_connectors')

@@ -43,6 +43,21 @@ async def handle_chat_data(
     except HTTPException:
         raise HTTPException(
             status_code=403, detail="You don't have access to this search space")

+    langchain_chat_history = []
+    for message in messages[:-1]:
+        if message['role'] == "user":
+            langchain_chat_history.append(HumanMessage(content=message['content']))
+        elif message['role'] == "assistant":
+            # Last annotation type will always be "ANSWER" here
+            answer_annotation = message['annotations'][-1]
+            answer_text = ""
+            if answer_annotation['type'] == "ANSWER":
+                answer_text = answer_annotation['content']
+                # If content is a list, join it into a single string
+                if isinstance(answer_text, list):
+                    answer_text = "\n".join(answer_text)
+            langchain_chat_history.append(AIMessage(content=answer_text))
+
     response = StreamingResponse(stream_connector_search_results(
         user_query,

@@ -50,7 +65,8 @@ async def handle_chat_data(
         search_space_id, # Already converted to int in lines 32-37
         session,
         research_mode,
-        selected_connectors
+        selected_connectors,
+        langchain_chat_history
     ))
     response.headers['x-vercel-ai-data-stream'] = 'v1'
     return response
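The loop above rebuilds LangChain history from the AI SDK payload, assuming each assistant message stores its final answer as the last entry of annotations. An illustrative payload matching that access pattern (all values are made up):

messages = [
    {"role": "user", "content": "What is SurfSense?"},
    {
        "role": "assistant",
        "content": "",
        "annotations": [
            {"type": "ANSWER", "content": ["SurfSense is an AI", "research agent."]},
        ],
    },
    {"role": "user", "content": "Does it keep chat history?"},  # becomes user_query
]
# messages[:-1] -> langchain_chat_history; messages[-1]['content'] -> user_query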
@@ -27,14 +27,14 @@ class ToolInvocation(BaseModel):
     result: dict


-class ClientMessage(BaseModel):
-    role: str
-    content: str
-    experimental_attachments: Optional[List[ClientAttachment]] = None
-    toolInvocations: Optional[List[ToolInvocation]] = None
+# class ClientMessage(BaseModel):
+#     role: str
+#     content: str
+#     experimental_attachments: Optional[List[ClientAttachment]] = None
+#     toolInvocations: Optional[List[ToolInvocation]] = None

 class AISDKChatRequest(BaseModel):
-    messages: List[ClientMessage]
+    messages: List[Any]
     data: Optional[Dict[str, Any]] = None

 class ChatCreate(ChatBase):
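Relaxing messages to List[Any] lets the annotation-bearing AI SDK payload through Pydantic unmodified, at the cost of schema validation; the shape is now enforced where the fields are read in the route. A stricter alternative (hypothetical, not in the commit) would type the new field instead:

from typing import Any, Dict, List, Optional
from pydantic import BaseModel

class AnnotatedClientMessage(BaseModel):  # hypothetical stricter model
    role: str
    content: Any  # str, or a list of content parts
    annotations: Optional[List[Dict[str, Any]]] = None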
@@ -1,4 +1,4 @@
-from typing import AsyncGenerator, List, Union
+from typing import Any, AsyncGenerator, List, Union
 from uuid import UUID

 from app.agents.researcher.graph import graph as researcher_graph

@@ -13,7 +13,8 @@ async def stream_connector_search_results(
     search_space_id: int,
     session: AsyncSession,
     research_mode: str,
-    selected_connectors: List[str]
+    selected_connectors: List[str],
+    langchain_chat_history: List[Any]
 ) -> AsyncGenerator[str, None]:
     """
     Stream connector search results to the client

@@ -53,7 +54,8 @@ async def stream_connector_search_results(
     # Initialize state with database session and streaming service
     initial_state = State(
         db_session=session,
-        streaming_service=streaming_service
+        streaming_service=streaming_service,
+        chat_history=langchain_chat_history
     )

     # Run the graph directly
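State is constructed once with the session, the streaming service, and the new chat history, then handed to the researcher graph. A hedged sketch of driving the graph from here; the diff elides the exact invocation, so the astream call and the config contents are assumptions:

initial_state = State(
    db_session=session,
    streaming_service=streaming_service,
    chat_history=langchain_chat_history,
)
# Hypothetical config; the actual code derives it from the request parameters.
config = {"configurable": {"user_query": user_query, "num_sections": 3}}
async for chunk in researcher_graph.astream(initial_state, config=config):
    ...  # forward formatted chunks to the client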
@@ -1,8 +1,7 @@
-"""
-NOTE: This is not used anymore. Might be removed in the future.
-"""
-from langchain.schema import HumanMessage, SystemMessage
+from langchain.schema import HumanMessage, SystemMessage, AIMessage
 from app.config import config
+from typing import Any, List, Optional


 class QueryService:
     """

@@ -10,72 +9,90 @@ class QueryService:
     """

     @staticmethod
-    async def reformulate_query(user_query: str) -> str:
+    async def reformulate_query_with_chat_history(user_query: str, chat_history_str: Optional[str] = None) -> str:
         """
         Reformulate the user query using the STRATEGIC_LLM to make it more
         effective for information retrieval and research purposes.

         Args:
             user_query: The original user query
+            chat_history_str: Optional string of previous chat messages

         Returns:
             str: The reformulated query
         """
         if not user_query or not user_query.strip():
             return user_query

         try:
             # Get the strategic LLM instance from config
             llm = config.strategic_llm_instance

             # Create system message with instructions
             system_message = SystemMessage(
-                content="""
-                You are an expert at reformulating user queries to optimize information retrieval.
-                Your job is to take a user query and reformulate it to:
-
-                1. Make it more specific and detailed
-                2. Expand ambiguous terms
-                3. Include relevant synonyms and alternative phrasings
-                4. Break down complex questions into their core components
-                5. Ensure it's comprehensive for research purposes
-
-                The query will be used with the following data sources/connectors:
-                - SERPER_API: Web search for retrieving current information from the internet
-                - TAVILY_API: Research-focused search API for comprehensive information
-                - SLACK_CONNECTOR: Retrieves information from indexed Slack workspace conversations
-                - NOTION_CONNECTOR: Retrieves information from indexed Notion documents and databases
-                - FILE: Searches through user's uploaded files
-                - CRAWLED_URL: Searches through previously crawled web pages
-
-                IMPORTANT: Keep the reformulated query as concise as possible while still being effective.
-                Avoid unnecessary verbosity and limit the query to only essential terms and concepts.
-
-                Please optimize the query to work effectively across these different data sources.
-                Return ONLY the reformulated query without explanations, prefixes, or commentary.
-                Do not include phrases like "Reformulated query:" or any other text except the query itself.
-                """
+                content=f"""
+                You are a highly skilled AI assistant specializing in query optimization for advanced research.
+                Your primary objective is to transform a user's initial query into a highly effective search query.
+                This reformulated query will be used to retrieve information from diverse data sources.
+
+                **Chat History Context:**
+                {chat_history_str if chat_history_str else "No prior conversation history is available."}
+                If chat history is provided, analyze it to understand the user's evolving information needs and the broader context of their request. Use this understanding to refine the current query, ensuring it builds upon or clarifies previous interactions.
+
+                **Query Reformulation Guidelines:**
+                Your reformulated query should:
+                1. **Enhance Specificity and Detail:** Add precision to narrow the search focus effectively, making the query less ambiguous and more targeted.
+                2. **Resolve Ambiguities:** Identify and clarify vague terms or phrases. If a term has multiple meanings, orient the query towards the most likely one given the context.
+                3. **Expand Key Concepts:** Incorporate relevant synonyms, related terms, and alternative phrasings for core concepts. This helps capture a wider range of relevant documents.
+                4. **Deconstruct Complex Questions:** If the original query is multifaceted, break it down into its core searchable components or rephrase it to address each aspect clearly. The final output must still be a single, coherent query string.
+                5. **Optimize for Comprehensiveness:** Ensure the query is structured to uncover all essential facets of the original request, aiming for thorough information retrieval suitable for research.
+                6. **Maintain User Intent:** The reformulated query must stay true to the original intent of the user's query. Do not introduce new topics or shift the focus significantly.
+
+                **Crucial Constraints:**
+                * **Conciseness and Effectiveness:** While aiming for comprehensiveness, the reformulated query MUST be as concise as possible. Eliminate all unnecessary verbosity. Focus on essential keywords, entities, and concepts that directly contribute to effective retrieval.
+                * **Single, Direct Output:** Return ONLY the reformulated query itself. Do NOT include any explanations, introductory phrases (e.g., "Reformulated query:", "Here is the optimized query:"), or any other surrounding text or markdown formatting.
+
+                Your output should be a single, optimized query string, ready for immediate use in a search system.
+                """
             )

             # Create human message with the user query
             human_message = HumanMessage(
                 content=f"Reformulate this query for better research results: {user_query}"
             )

             # Get the response from the LLM
             response = await llm.agenerate(messages=[[system_message, human_message]])

             # Extract the reformulated query from the response
             reformulated_query = response.generations[0][0].text.strip()

             # Return the original query if the reformulation is empty
             if not reformulated_query:
                 return user_query

             return reformulated_query

         except Exception as e:
             # Log the error and return the original query
             print(f"Error reformulating query: {e}")
             return user_query

+    @staticmethod
+    async def langchain_chat_history_to_str(chat_history: List[Any]) -> str:
+        """
+        Convert a list of chat history messages to a string.
+        """
+        chat_history_str = "<chat_history>\n"
+
+        for chat_message in chat_history:
+            if isinstance(chat_message, HumanMessage):
+                chat_history_str += f"<user>{chat_message.content}</user>\n"
+            elif isinstance(chat_message, AIMessage):
+                chat_history_str += f"<assistant>{chat_message.content}</assistant>\n"
+            elif isinstance(chat_message, SystemMessage):
+                chat_history_str += f"<system>{chat_message.content}</system>\n"
+
+        chat_history_str += "</chat_history>"
+        return chat_history_str
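Taken together, a hedged usage sketch of the two QueryService methods (run inside an async context):

from langchain.schema import AIMessage, HumanMessage
from app.utils.query_service import QueryService

history = [
    HumanMessage(content="Tell me about vector databases."),
    AIMessage(content="Vector databases index embeddings..."),
]

history_str = await QueryService.langchain_chat_history_to_str(history)
# -> "<chat_history>\n<user>...</user>\n<assistant>...</assistant>\n</chat_history>"

reformulated = await QueryService.reformulate_query_with_chat_history(
    "which one is fastest?", history_str
)
# Falls back to the original query on empty LLM output or any exception.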