from typing import Any, NamedTuple
from langchain_core.messages import BaseMessage
from litellm import get_model_info, token_counter
from pydantic import BaseModel, Field
class Section(BaseModel):
"""A section in the answer outline."""
section_id: int = Field(..., description="The zero-based index of the section")
section_title: str = Field(..., description="The title of the section")
questions: list[str] = Field(
..., description="Questions to research for this section"
)
class AnswerOutline(BaseModel):
"""The complete answer outline with all sections."""
answer_outline: list[Section] = Field(
..., description="List of sections in the answer outline"
)
class DocumentTokenInfo(NamedTuple):
"""Information about a document and its token cost."""
index: int
document: dict[str, Any]
formatted_content: str
token_count: int
def get_connector_emoji(connector_name: str) -> str:
"""Get an appropriate emoji for a connector type."""
connector_emojis = {
"YOUTUBE_VIDEO": "📹",
"EXTENSION": "🧩",
"CRAWLED_URL": "🌐",
"FILE": "📄",
"SLACK_CONNECTOR": "💬",
"NOTION_CONNECTOR": "📘",
"GITHUB_CONNECTOR": "🐙",
"LINEAR_CONNECTOR": "📊",
"JIRA_CONNECTOR": "🎫",
"DISCORD_CONNECTOR": "🗨️",
"TAVILY_API": "🔍",
"LINKUP_API": "🔗",
"GOOGLE_CALENDAR_CONNECTOR": "📅",
}
return connector_emojis.get(connector_name, "🔎")
def get_connector_friendly_name(connector_name: str) -> str:
"""Convert technical connector IDs to user-friendly names."""
connector_friendly_names = {
"YOUTUBE_VIDEO": "YouTube",
"EXTENSION": "Browser Extension",
"CRAWLED_URL": "Web Pages",
"FILE": "Files",
"SLACK_CONNECTOR": "Slack",
"NOTION_CONNECTOR": "Notion",
"GITHUB_CONNECTOR": "GitHub",
"LINEAR_CONNECTOR": "Linear",
"JIRA_CONNECTOR": "Jira",
"CONFLUENCE_CONNECTOR": "Confluence",
"GOOGLE_CALENDAR_CONNECTOR": "Google Calendar",
"DISCORD_CONNECTOR": "Discord",
"TAVILY_API": "Tavily Search",
"LINKUP_API": "Linkup Search",
}
return connector_friendly_names.get(connector_name, connector_name)
def convert_langchain_messages_to_dict(
messages: list[BaseMessage],
) -> list[dict[str, str]]:
"""Convert LangChain messages to format expected by token_counter."""
role_mapping = {"system": "system", "human": "user", "ai": "assistant"}
converted_messages = []
for msg in messages:
role = role_mapping.get(getattr(msg, "type", None), "user")
converted_messages.append({"role": role, "content": str(msg.content)})
return converted_messages
def format_document_for_citation(document: dict[str, Any]) -> str:
"""Format a single document for citation in the standard XML format."""
content = document.get("content", "")
doc_info = document.get("document", {})
document_id = document.get("chunk_id", "")
document_type = doc_info.get("document_type", "CRAWLED_URL")
return f"""
{document_id}
{document_type}
{content}
"""
def format_documents_section(
documents: list[dict[str, Any]], section_title: str = "Source material"
) -> str:
"""Format multiple documents into a complete documents section."""
if not documents:
return ""
formatted_docs = [format_document_for_citation(doc) for doc in documents]
return f"""{section_title}:
{chr(10).join(formatted_docs)}
"""
def calculate_document_token_costs(
documents: list[dict[str, Any]], model: str
) -> list[DocumentTokenInfo]:
"""Pre-calculate token costs for each document."""
document_token_info = []
for i, doc in enumerate(documents):
formatted_doc = format_document_for_citation(doc)
# Calculate token count for this document
token_count = token_counter(
messages=[{"role": "user", "content": formatted_doc}], model=model
)
document_token_info.append(
DocumentTokenInfo(
index=i,
document=doc,
formatted_content=formatted_doc,
token_count=token_count,
)
)
return document_token_info
def find_optimal_documents_with_binary_search(
document_tokens: list[DocumentTokenInfo], available_tokens: int
) -> list[DocumentTokenInfo]:
"""Use binary search to find the maximum number of documents that fit within token limit."""
if not document_tokens or available_tokens <= 0:
return []
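    # Per-document token counts are non-negative, so the total for the first `mid`
    # documents grows monotonically with `mid`; binary searching the prefix length
    # finds the largest prefix that still fits within the budget.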
left, right = 0, len(document_tokens)
optimal_docs = []
while left <= right:
mid = (left + right) // 2
current_docs = document_tokens[:mid]
current_token_sum = sum(doc_info.token_count for doc_info in current_docs)
if current_token_sum <= available_tokens:
optimal_docs = current_docs
left = mid + 1
else:
right = mid - 1
return optimal_docs
def get_model_context_window(model_name: str) -> int:
"""Get the total context window size for a model (input + output tokens)."""
try:
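        # litellm's get_model_info exposes per-model limits; max_input_tokens is the
        # prompt-side budget available for input messages.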
model_info = get_model_info(model_name)
        context_window = model_info.get("max_input_tokens") or 4096  # Default fallback if unset
return context_window
except Exception as e:
print(
f"Warning: Could not get model info for {model_name}, using default 4096 tokens. Error: {e}"
)
return 4096 # Conservative fallback
def optimize_documents_for_token_limit(
documents: list[dict[str, Any]], base_messages: list[BaseMessage], model_name: str
) -> tuple[list[dict[str, Any]], bool]:
"""
Optimize documents to fit within token limits using binary search.
Args:
documents: List of documents with content and metadata
base_messages: Base messages without documents (chat history + system + human message template)
model_name: Model name for token counting (required)
Returns:
Tuple of (optimized_documents, has_documents_remaining)
"""
if not documents:
return [], False
model = model_name
context_window = get_model_context_window(model)
# Calculate base token cost
base_messages_dict = convert_langchain_messages_to_dict(base_messages)
base_tokens = token_counter(messages=base_messages_dict, model=model)
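    # Whatever remains of the input window after the base prompt is the budget for documents.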
available_tokens_for_docs = context_window - base_tokens
print(
f"Token optimization: Context window={context_window}, Base={base_tokens}, Available for docs={available_tokens_for_docs}"
)
if available_tokens_for_docs <= 0:
print("No tokens available for documents after base content and output buffer")
return [], False
# Calculate token costs for all documents
document_token_info = calculate_document_token_costs(documents, model)
# Find optimal number of documents using binary search
optimal_doc_info = find_optimal_documents_with_binary_search(
document_token_info, available_tokens_for_docs
)
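    # Binary search keeps a prefix of the list, so the documents' original ordering
    # (e.g., a relevance ranking from retrieval) is preserved and only the tail is dropped.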
# Extract the original document objects
optimized_documents = [doc_info.document for doc_info in optimal_doc_info]
has_documents_remaining = len(optimized_documents) > 0
print(
f"Token optimization result: Using {len(optimized_documents)}/{len(documents)} documents"
)
return optimized_documents, has_documents_remaining
def calculate_token_count(messages: list[BaseMessage], model_name: str) -> int:
"""Calculate token count for a list of LangChain messages."""
model = model_name
messages_dict = convert_langchain_messages_to_dict(messages)
return token_counter(messages=messages_dict, model=model)
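
# Illustrative usage sketch (assumes the caller already has retrieved `documents`, a
# system prompt, and a user query; the model name below is only an example):
#
#   from langchain_core.messages import HumanMessage, SystemMessage
#
#   base_messages = [SystemMessage(content=system_prompt), HumanMessage(content=query)]
#   docs, has_docs = optimize_documents_for_token_limit(documents, base_messages, "gpt-4o-mini")
#   sources_block = format_documents_section(docs, "Source material")
#   prompt_tokens = calculate_token_count(base_messages, "gpt-4o-mini")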