Mirror of https://github.com/MODSetter/SurfSense.git, synced 2025-09-01 18:19:08 +00:00
refactor: optimized document handling and added token management in Q&A and sub-section writing agents
This commit is contained in:
parent 051580d145
commit a22228f36e
3 changed files with 323 additions and 89 deletions
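At a high level, both agents now budget tokens before the final prompt is assembled: count the tokens of the prompt without any documents, subtract that from the model's context window, and keep only as many reranked documents as fit in the remainder. The snippet below is a minimal sketch of that arithmetic using litellm's token_counter; the model name and numbers are illustrative and not taken from this commit.

from litellm import token_counter

# Illustrative values; the real code reads the model from app_config.FAST_LLM
# and derives the window from litellm's get_model_info().
model = "gpt-4o-mini"
context_window = 128_000

base_messages = [
    {"role": "system", "content": "You are a helpful research assistant."},
    {"role": "user", "content": "Summarise what my notes say about topic X."},
]

# Tokens used by chat history + system prompt + question alone.
base_tokens = token_counter(model=model, messages=base_messages)

# Whatever remains can be spent on retrieved documents.
available_for_docs = context_window - base_tokens
print(f"base={base_tokens}, available for documents={available_for_docs}")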
@@ -5,6 +5,11 @@ from typing import Any, Dict
from app.config import config as app_config
from .prompts import get_qna_citation_system_prompt, get_qna_no_documents_system_prompt
from langchain_core.messages import HumanMessage, SystemMessage
from ..utils import (
optimize_documents_for_token_limit,
calculate_token_count,
format_documents_section
)

async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, Any]:
"""
@@ -82,48 +87,61 @@ async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any

# Get configuration and relevant documents from configuration
configuration = Configuration.from_runnable_config(config)
documents = configuration.relevant_documents
documents = state.reranked_documents
user_query = configuration.user_query

# Initialize LLM
llm = app_config.fast_llm_instance

# Check if we have documents to determine which prompt to use
has_documents = documents and len(documents) > 0
# Determine if we have documents and optimize for token limits
has_documents_initially = documents and len(documents) > 0

# Prepare documents for citation formatting (if any)
documents_text = ""
if has_documents:
formatted_documents = []
for _i, doc in enumerate(documents):
# Extract content and metadata
content = doc.get("content", "")
doc_info = doc.get("document", {})
document_id = doc_info.get("id") # Use document ID

# Format document according to the citation system prompt's expected format
formatted_doc = f"""
<document>
<metadata>
<source_id>{document_id}</source_id>
<source_type>{doc_info.get("document_type", "CRAWLED_URL")}</source_type>
</metadata>
<content>
{content}
</content>
</document>
"""
formatted_documents.append(formatted_doc)
if has_documents_initially:
# Create base message template for token calculation (without documents)
base_human_message_template = f"""

# Create the formatted documents text
documents_text = f"""
Source material from your personal knowledge base:
<documents>
{"\n".join(formatted_documents)}
</documents>
User's question:
<user_query>
{user_query}
</user_query>

Please provide a detailed, comprehensive answer to the user's question using the information from their personal knowledge sources. Make sure to cite all information appropriately and engage in a conversational manner.
"""

# Use initial system prompt for token calculation
initial_system_prompt = get_qna_citation_system_prompt()
base_messages = state.chat_history + [
SystemMessage(content=initial_system_prompt),
HumanMessage(content=base_human_message_template)
]

# Optimize documents to fit within token limits
optimized_documents, has_optimized_documents = optimize_documents_for_token_limit(
documents, base_messages, app_config.FAST_LLM
)

# Update state based on optimization result
documents = optimized_documents
has_documents = has_optimized_documents
else:
has_documents = False

# Choose system prompt based on final document availability
system_prompt = get_qna_citation_system_prompt() if has_documents else get_qna_no_documents_system_prompt()

# Generate documents section
documents_text = format_documents_section(
documents,
"Source material from your personal knowledge base"
) if has_documents else ""

# Create final human message content
instruction_text = (
"Please provide a detailed, comprehensive answer to the user's question using the information from their personal knowledge sources. Make sure to cite all information appropriately and engage in a conversational manner."
if has_documents else
"Please provide a helpful answer to the user's question based on our conversation history and your general knowledge. Engage in a conversational manner."
)

# Construct a clear, structured query for the LLM
human_message_content = f"""
{documents_text}
@@ -132,18 +150,20 @@ async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any
{user_query}
</user_query>

{"Please provide a detailed, comprehensive answer to the user's question using the information from their personal knowledge sources. Make sure to cite all information appropriately and engage in a conversational manner." if has_documents else "Please provide a helpful answer to the user's question based on our conversation history and your general knowledge. Engage in a conversational manner."}
{instruction_text}
"""

# Choose the appropriate system prompt based on document availability
system_prompt = get_qna_citation_system_prompt() if has_documents else get_qna_no_documents_system_prompt()

# Create messages for the LLM, including chat history for context
# Create final messages for the LLM
messages_with_chat_history = state.chat_history + [
SystemMessage(content=system_prompt),
HumanMessage(content=human_message_content)
]

# Log final token count
total_tokens = calculate_token_count(messages_with_chat_history, app_config.FAST_LLM)
print(f"Final token count: {total_tokens}")

# Call the LLM and get the response
response = await llm.ainvoke(messages_with_chat_history)
final_answer = response.content
@@ -6,6 +6,11 @@ from app.config import config as app_config
from .prompts import get_citation_system_prompt, get_no_documents_system_prompt
from langchain_core.messages import HumanMessage, SystemMessage
from .configuration import SubSectionType
from ..utils import (
optimize_documents_for_token_limit,
calculate_token_count,
format_documents_section
)

async def rerank_documents(state: State, config: RunnableConfig) -> Dict[str, Any]:
"""
@@ -89,64 +94,87 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A

# Get configuration and relevant documents from configuration
configuration = Configuration.from_runnable_config(config)
documents = configuration.relevant_documents
documents = state.reranked_documents

# Initialize LLM
llm = app_config.fast_llm_instance

# Check if we have documents to determine which prompt to use
has_documents = documents and len(documents) > 0

# Prepare documents for citation formatting (if any)
documents_text = ""
if has_documents:
formatted_documents = []
for i, doc in enumerate(documents):
# Extract content and metadata
content = doc.get("content", "")
doc_info = doc.get("document", {})
document_id = doc_info.get("id") # Use document ID

# Format document according to the citation system prompt's expected format
formatted_doc = f"""
<document>
<metadata>
<source_id>{document_id}</source_id>
<source_type>{doc_info.get("document_type", "CRAWLED_URL")}</source_type>
</metadata>
<content>
{content}
</content>
</document>
"""
formatted_documents.append(formatted_doc)

documents_text = f"""
Source material:
<documents>
{"\n".join(formatted_documents)}
</documents>
"""

# Create the query that uses the section title and questions
# Extract configuration data
section_title = configuration.sub_section_title
sub_section_questions = configuration.sub_section_questions
user_query = configuration.user_query # Get the original user query
user_query = configuration.user_query
sub_section_type = configuration.sub_section_type

# Format the questions as bullet points for clarity
questions_text = "\n".join([f"- {question}" for question in sub_section_questions])

# Provide more context based on the subsection type
section_position_context = ""
if sub_section_type == SubSectionType.START:
section_position_context = "This is the INTRODUCTION section. "
elif sub_section_type == SubSectionType.MIDDLE:
section_position_context = "This is a MIDDLE section. Ensure this content flows naturally from previous sections and into subsequent ones. This could be any middle section in the document, so maintain coherence with the overall structure while addressing the specific topic of this section. Do not provide any conclusions in this section, as conclusions should only appear in the final section."
elif sub_section_type == SubSectionType.END:
section_position_context = "This is the CONCLUSION section. Focus on summarizing key points, providing closure."
# Provide context based on the subsection type
section_position_context_map = {
SubSectionType.START: "This is the INTRODUCTION section.",
SubSectionType.MIDDLE: "This is a MIDDLE section. Ensure this content flows naturally from previous sections and into subsequent ones. This could be any middle section in the document, so maintain coherence with the overall structure while addressing the specific topic of this section. Do not provide any conclusions in this section, as conclusions should only appear in the final section.",
SubSectionType.END: "This is the CONCLUSION section. Focus on summarizing key points, providing closure."
}
section_position_context = section_position_context_map.get(sub_section_type, "")

# Determine if we have documents and optimize for token limits
has_documents_initially = documents and len(documents) > 0

if has_documents_initially:
# Create base message template for token calculation (without documents)
base_human_message_template = f"""

Now user's query is:
<user_query>
{user_query}
</user_query>

The sub-section title is:
<sub_section_title>
{section_title}
</sub_section_title>

<section_position>
{section_position_context}
</section_position>

<guiding_questions>
{questions_text}
</guiding_questions>

Please write content for this sub-section using the provided source material and cite all information appropriately.
"""

# Use initial system prompt for token calculation
initial_system_prompt = get_citation_system_prompt()
base_messages = state.chat_history + [
SystemMessage(content=initial_system_prompt),
HumanMessage(content=base_human_message_template)
]

# Optimize documents to fit within token limits
optimized_documents, has_optimized_documents = optimize_documents_for_token_limit(
documents, base_messages, app_config.FAST_LLM
)

# Update state based on optimization result
documents = optimized_documents
has_documents = has_optimized_documents
else:
has_documents = False

# Choose system prompt based on final document availability
system_prompt = get_citation_system_prompt() if has_documents else get_no_documents_system_prompt()

# Generate documents section
documents_text = format_documents_section(documents, "Source material") if has_documents else ""

# Create final human message content
instruction_text = (
"Please write content for this sub-section using the provided source material and cite all information appropriately."
if has_documents else
"Please write content for this sub-section based on our conversation history and your general knowledge."
)

# Construct a clear, structured query for the LLM
human_message_content = f"""
{documents_text}
@@ -168,18 +196,19 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
{questions_text}
</guiding_questions>

{"Please write content for this sub-section using the provided source material and cite all information appropriately." if has_documents else "Please write content for this sub-section based on our conversation history and your general knowledge."}
{instruction_text}
"""

# Choose the appropriate system prompt based on document availability
system_prompt = get_citation_system_prompt() if has_documents else get_no_documents_system_prompt()

# Create messages for the LLM
# Create final messages for the LLM
messages_with_chat_history = state.chat_history + [
SystemMessage(content=system_prompt),
HumanMessage(content=human_message_content)
]

# Log final token count
total_tokens = calculate_token_count(messages_with_chat_history, app_config.FAST_LLM)
print(f"Final token count: {total_tokens}")

# Call the LLM and get the response
response = await llm.ainvoke(messages_with_chat_history)
final_answer = response.content
185 surfsense_backend/app/agents/researcher/utils.py Normal file
@@ -0,0 +1,185 @@
from typing import List, Dict, Any, Tuple, NamedTuple
from langchain_core.messages import BaseMessage
from litellm import token_counter, get_model_info
from app.config import config as app_config


class DocumentTokenInfo(NamedTuple):
    """Information about a document and its token cost."""
    index: int
    document: Dict[str, Any]
    formatted_content: str
    token_count: int


def convert_langchain_messages_to_dict(messages: List[BaseMessage]) -> List[Dict[str, str]]:
    """Convert LangChain messages to format expected by token_counter."""
    role_mapping = {
        'system': 'system',
        'human': 'user',
        'ai': 'assistant'
    }

    converted_messages = []
    for msg in messages:
        role = role_mapping.get(getattr(msg, 'type', None), 'user')
        converted_messages.append({
            "role": role,
            "content": str(msg.content)
        })

    return converted_messages


def format_document_for_citation(document: Dict[str, Any]) -> str:
    """Format a single document for citation in the standard XML format."""
    content = document.get("content", "")
    doc_info = document.get("document", {})
    document_id = doc_info.get("id", "")
    document_type = doc_info.get("document_type", "CRAWLED_URL")

    return f"""<document>
<metadata>
<source_id>{document_id}</source_id>
<source_type>{document_type}</source_type>
</metadata>
<content>
{content}
</content>
</document>"""


def format_documents_section(documents: List[Dict[str, Any]], section_title: str = "Source material") -> str:
    """Format multiple documents into a complete documents section."""
    if not documents:
        return ""

    formatted_docs = [format_document_for_citation(doc) for doc in documents]

    return f"""{section_title}:
<documents>
{chr(10).join(formatted_docs)}
</documents>"""


def calculate_document_token_costs(documents: List[Dict[str, Any]], model: str) -> List[DocumentTokenInfo]:
    """Pre-calculate token costs for each document."""
    document_token_info = []

    for i, doc in enumerate(documents):
        formatted_doc = format_document_for_citation(doc)

        # Calculate token count for this document
        token_count = token_counter(
            messages=[{"role": "user", "content": formatted_doc}],
            model=model
        )

        document_token_info.append(DocumentTokenInfo(
            index=i,
            document=doc,
            formatted_content=formatted_doc,
            token_count=token_count
        ))

    return document_token_info


def find_optimal_documents_with_binary_search(
    document_tokens: List[DocumentTokenInfo],
    available_tokens: int
) -> List[DocumentTokenInfo]:
    """Use binary search to find the maximum number of documents that fit within token limit."""
    if not document_tokens or available_tokens <= 0:
        return []

    left, right = 0, len(document_tokens)
    optimal_docs = []

    while left <= right:
        mid = (left + right) // 2
        current_docs = document_tokens[:mid]
        current_token_sum = sum(
            doc_info.token_count for doc_info in current_docs)

        if current_token_sum <= available_tokens:
            optimal_docs = current_docs
            left = mid + 1
        else:
            right = mid - 1

    return optimal_docs


def get_model_context_window(model_name: str) -> int:
    """Get the total context window size for a model (input + output tokens)."""
    try:
        model_info = get_model_info(model_name)
        context_window = model_info.get(
            'max_input_tokens', 4096)  # Default fallback
        return context_window
    except Exception as e:
        print(
            f"Warning: Could not get model info for {model_name}, using default 4096 tokens. Error: {e}")
        return 4096  # Conservative fallback


def optimize_documents_for_token_limit(
    documents: List[Dict[str, Any]],
    base_messages: List[BaseMessage],
    model_name: str = None
) -> Tuple[List[Dict[str, Any]], bool]:
    """
    Optimize documents to fit within token limits using binary search.

    Args:
        documents: List of documents with content and metadata
        base_messages: Base messages without documents (chat history + system + human message template)
        model_name: Model name for token counting (defaults to app_config.FAST_LLM)

    Returns:
        Tuple of (optimized_documents, has_documents_remaining)
    """
    if not documents:
        return [], False

    model = model_name or app_config.FAST_LLM
    context_window = get_model_context_window(model)

    # Calculate base token cost
    base_messages_dict = convert_langchain_messages_to_dict(base_messages)
    base_tokens = token_counter(messages=base_messages_dict, model=model)
    available_tokens_for_docs = context_window - base_tokens

    print(
        f"Token optimization: Context window={context_window}, Base={base_tokens}, Available for docs={available_tokens_for_docs}")

    if available_tokens_for_docs <= 0:
        print("No tokens available for documents after base content and output buffer")
        return [], False

    # Calculate token costs for all documents
    document_token_info = calculate_document_token_costs(documents, model)

    # Find optimal number of documents using binary search
    optimal_doc_info = find_optimal_documents_with_binary_search(
        document_token_info,
        available_tokens_for_docs
    )

    # Extract the original document objects
    optimized_documents = [doc_info.document for doc_info in optimal_doc_info]
    has_documents_remaining = len(optimized_documents) > 0

    print(
        f"Token optimization result: Using {len(optimized_documents)}/{len(documents)} documents")

    return optimized_documents, has_documents_remaining


def calculate_token_count(messages: List[BaseMessage], model_name: str = None) -> int:
    """Calculate token count for a list of LangChain messages."""
    model = model_name or app_config.FAST_LLM
    messages_dict = convert_langchain_messages_to_dict(messages)
    return token_counter(messages=messages_dict, model=model)
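The prefix selection above can be exercised in isolation: given per-document token costs in rerank order, a binary search finds the largest prefix whose total stays within the budget, which is the same idea find_optimal_documents_with_binary_search implements. A self-contained sketch with illustrative numbers (not part of the commit):

from typing import List

def largest_prefix_within_budget(token_costs: List[int], budget: int) -> int:
    """Return how many leading documents fit within `budget` tokens."""
    if budget <= 0:
        return 0
    left, right, best = 0, len(token_costs), 0
    while left <= right:
        mid = (left + right) // 2
        if sum(token_costs[:mid]) <= budget:
            best = mid  # this prefix fits; try a longer one
            left = mid + 1
        else:
            right = mid - 1  # too expensive; shrink the prefix
    return best

# Example: reranked document costs against a 1,000-token budget.
print(largest_prefix_within_budget([400, 350, 300, 250], 1000))  # -> 2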