Merge pull request #303 from MODSetter/dev

fix: added basic context window check for summarization
Rohan Verma 2025-08-28 23:01:33 -07:00 committed by GitHub
commit 4fa0777094

@@ -1,10 +1,99 @@
import hashlib
from litellm import get_model_info, token_counter
from app.config import config
from app.db import Chunk
from app.prompts import SUMMARY_PROMPT_TEMPLATE


def get_model_context_window(model_name: str) -> int:
    """Get the maximum input context window size for a model, in tokens."""
    try:
        model_info = get_model_info(model_name)
        # litellm may report None for max_input_tokens, so fall back explicitly
        context_window = model_info.get("max_input_tokens") or 4096
        return context_window
    except Exception as e:
        print(
            f"Warning: Could not get model info for {model_name}, using default 4096 tokens. Error: {e}"
        )
        return 4096  # Conservative fallback
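

# Illustrative usage (not part of this diff; the model names and the quoted
# limit are assumptions): a model known to litellm returns its reported input
# limit, while an unmapped name makes get_model_info raise and trips the fallback.
#
#     get_model_context_window("gpt-4o-mini")     # e.g. 128000, per litellm's model registry
#     get_model_context_window("my-local-model")  # 4096 -- unknown to litellm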


def optimize_content_for_context_window(
    content: str, document_metadata: dict | None, model_name: str
) -> str:
    """
    Optimize content length to fit within the model context window using binary search.

    Args:
        content: Original document content
        document_metadata: Optional metadata dictionary
        model_name: Model name for token counting

    Returns:
        Optimized content that fits within the context window
    """
    if not content:
        return content
    # Get the model context window
    context_window = get_model_context_window(model_name)

    # Reserve tokens for: system prompt, metadata, template overhead, and output.
    # Conservative estimate: 2000 tokens for prompt + metadata + output buffer.
    # TODO: Calculate Summary System Prompt Token Count Here
    reserved_tokens = 2000
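    # One possible approach for the TODO above (sketch only; it assumes the
    # rendered system prompt is available as a plain string -- the names
    # `summary_prompt_text` and `output_buffer` are hypothetical):
    #
    #     prompt_tokens = token_counter(
    #         messages=[{"role": "system", "content": summary_prompt_text}],
    #         model=model_name,
    #     )
    #     reserved_tokens = prompt_tokens + output_buffer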

    # Add the metadata token cost if metadata is present
    if document_metadata:
        metadata_text = (
            f"<DOCUMENT_METADATA>\n\n{document_metadata}\n\n</DOCUMENT_METADATA>"
        )
        metadata_tokens = token_counter(
            messages=[{"role": "user", "content": metadata_text}], model=model_name
        )
        reserved_tokens += metadata_tokens

    available_tokens = context_window - reserved_tokens
    if available_tokens <= 100:  # Minimum viable content
        print(f"Warning: Very limited tokens available for content: {available_tokens}")
        return content[:500]  # Fall back to the first 500 characters
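    # Worked example of the budget arithmetic (numbers are illustrative): with
    # an 8192-token window, the 2000-token reserve, and, say, 192 metadata
    # tokens, 8192 - 2192 = 6000 tokens remain for the document content itself.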

    # Binary search for the longest prefix of the content that fits
    left, right = 0, len(content)
    optimal_length = 0

    while left <= right:
        mid = (left + right) // 2
        test_content = content[:mid]

        # Test the token count for this content length
        test_document = f"<DOCUMENT_CONTENT>\n\n{test_content}\n\n</DOCUMENT_CONTENT>"
        test_tokens = token_counter(
            messages=[{"role": "user", "content": test_document}], model=model_name
        )

        if test_tokens <= available_tokens:
            optimal_length = mid
            left = mid + 1
        else:
            right = mid - 1

    optimized_content = (
        content[:optimal_length] if optimal_length > 0 else content[:500]
    )

    if optimal_length < len(content):
        print(
            f"Content optimized: {len(content)} -> {optimal_length} chars "
            f"to fit in {available_tokens} available tokens"
        )

    return optimized_content
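

# Sketch of the expected behavior (illustrative only; the model name and input
# size are assumptions): each probe costs one token_counter call, so the search
# makes O(log2(len(content))) counts -- about 18 for a 200k-char input.
#
#     long_text = "word " * 40_000  # ~200k chars
#     trimmed = optimize_content_for_context_window(long_text, None, "gpt-3.5-turbo")
#     assert trimmed == long_text[:len(trimmed)]  # always a prefix, never longer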


async def generate_document_summary(
    content: str,
    user_llm,
@@ -21,8 +110,16 @@ async def generate_document_summary(
    Returns:
        Tuple of (enhanced_summary_content, summary_embedding)
    """
    # Get the model name from user_llm for token counting
    model_name = getattr(user_llm, "model", "gpt-3.5-turbo")  # Fall back to a default

    # Optimize the content to fit within the context window
    optimized_content = optimize_content_for_context_window(
        content, document_metadata, model_name
    )

    summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
    content_with_metadata = f"<DOCUMENT><DOCUMENT_METADATA>\n\n{document_metadata}\n\n</DOCUMENT_METADATA>\n\n<DOCUMENT_CONTENT>\n\n{optimized_content}\n\n</DOCUMENT_CONTENT></DOCUMENT>"

    summary_result = await summary_chain.ainvoke({"document": content_with_metadata})
    summary_content = summary_result.content
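
# Hedged usage sketch (not from the diff: generate_document_summary's full
# parameter list is truncated in the hunk above, so the call shape and the
# argument names below are assumptions based on the visible body and docstring):
#
#     summary, embedding = await generate_document_summary(
#         content=raw_text,
#         user_llm=llm,                       # LangChain-style chat model with a .model attribute
#         document_metadata={"title": "..."},
#     )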