mirror of
https://github.com/MODSetter/SurfSense.git
synced 2025-09-01 10:09:08 +00:00
Merge pull request #303 from MODSetter/dev
fix: added basic context window check for summarization
This commit is contained in:
commit
4fa0777094
1 changed file with 98 additions and 1 deletion
@@ -1,10 +1,99 @@
 import hashlib
 
+from litellm import get_model_info, token_counter
+
 from app.config import config
 from app.db import Chunk
 from app.prompts import SUMMARY_PROMPT_TEMPLATE
 
 
+def get_model_context_window(model_name: str) -> int:
+    """Get the maximum input context window size for a model, in tokens."""
+    try:
+        model_info = get_model_info(model_name)
+        context_window = model_info.get("max_input_tokens", 4096)  # Default fallback
+        return context_window
+    except Exception as e:
+        print(
+            f"Warning: Could not get model info for {model_name}, using default 4096 tokens. Error: {e}"
+        )
+        return 4096  # Conservative fallback
+
+
+def optimize_content_for_context_window(
+    content: str, document_metadata: dict | None, model_name: str
+) -> str:
+    """
+    Optimize content length to fit within model context window using binary search.
+
+    Args:
+        content: Original document content
+        document_metadata: Optional metadata dictionary
+        model_name: Model name for token counting
+
+    Returns:
+        Optimized content that fits within context window
+    """
+    if not content:
+        return content
+
+    # Get model context window
+    context_window = get_model_context_window(model_name)
+
+    # Reserve tokens for: system prompt, metadata, template overhead, and output
+    # Conservative estimate: 2000 tokens for prompt + metadata + output buffer
+    # TODO: Calculate Summary System Prompt Token Count Here
+    reserved_tokens = 2000
+
+    # Add metadata token cost if present
+    if document_metadata:
+        metadata_text = (
+            f"<DOCUMENT_METADATA>\n\n{document_metadata}\n\n</DOCUMENT_METADATA>"
+        )
+        metadata_tokens = token_counter(
+            messages=[{"role": "user", "content": metadata_text}], model=model_name
+        )
+        reserved_tokens += metadata_tokens
+
+    available_tokens = context_window - reserved_tokens
+
+    if available_tokens <= 100:  # Minimum viable content
+        print(f"Warning: Very limited tokens available for content: {available_tokens}")
+        return content[:500]  # Fallback to first 500 chars
+
+    # Binary search to find optimal content length
+    left, right = 0, len(content)
+    optimal_length = 0
+
+    while left <= right:
+        mid = (left + right) // 2
+        test_content = content[:mid]
+
+        # Test token count for this content length
+        test_document = f"<DOCUMENT_CONTENT>\n\n{test_content}\n\n</DOCUMENT_CONTENT>"
+        test_tokens = token_counter(
+            messages=[{"role": "user", "content": test_document}], model=model_name
+        )
+
+        if test_tokens <= available_tokens:
+            optimal_length = mid
+            left = mid + 1
+        else:
+            right = mid - 1
+
+    optimized_content = (
+        content[:optimal_length] if optimal_length > 0 else content[:500]
+    )
+
+    if optimal_length < len(content):
+        print(
+            f"Content optimized: {len(content)} -> {optimal_length} chars "
+            f"to fit in {available_tokens} available tokens"
+        )
+
+    return optimized_content
+
+
 async def generate_document_summary(
     content: str,
     user_llm,
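Together, the two new helpers binary-search for the longest prefix of the document that litellm's token_counter reports as fitting the model's input window, so only O(log n) counting calls are needed rather than one per candidate length. A minimal sketch of exercising them on their own; the model name and oversized sample text here are illustrative, not part of the commit:

from litellm import token_counter

# Illustrative inputs only (any model name litellm's registry knows will do).
model = "gpt-3.5-turbo"
doc = "word " * 200_000  # deliberately longer than the model's window
meta = {"title": "Example doc", "source": "demo"}

window = get_model_context_window(model)
fitted = optimize_content_for_context_window(doc, meta, model)

# The trimmed prefix plus the reserved budget should now fit the window.
used = token_counter(messages=[{"role": "user", "content": fitted}], model=model)
print(f"window={window}, {len(doc)} -> {len(fitted)} chars, ~{used} tokens")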
@@ -21,8 +110,16 @@ async def generate_document_summary(
     Returns:
         Tuple of (enhanced_summary_content, summary_embedding)
     """
+    # Get model name from user_llm for token counting
+    model_name = getattr(user_llm, "model", "gpt-3.5-turbo")  # Fallback to default
+
+    # Optimize content to fit within context window
+    optimized_content = optimize_content_for_context_window(
+        content, document_metadata, model_name
+    )
+
     summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
-    content_with_metadata = f"<DOCUMENT><DOCUMENT_METADATA>\n\n{document_metadata}\n\n</DOCUMENT_METADATA>\n\n<DOCUMENT_CONTENT>\n\n{content}\n\n</DOCUMENT_CONTENT></DOCUMENT>"
+    content_with_metadata = f"<DOCUMENT><DOCUMENT_METADATA>\n\n{document_metadata}\n\n</DOCUMENT_METADATA>\n\n<DOCUMENT_CONTENT>\n\n{optimized_content}\n\n</DOCUMENT_CONTENT></DOCUMENT>"
     summary_result = await summary_chain.ainvoke({"document": content_with_metadata})
     summary_content = summary_result.content
 
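The flat reserved_tokens = 2000 is what the TODO in the first hunk points at: the summary prompt's real overhead is never measured. A hedged sketch of measuring it instead, assuming SUMMARY_PROMPT_TEMPLATE is a LangChain prompt template whose only input variable is "document" (which the ainvoke call above suggests):

from litellm import token_counter

def measure_prompt_overhead(model_name: str) -> int:
    # Render the template with an empty document so only the fixed prompt
    # text is counted (assumes "document" is the template's sole variable).
    rendered = SUMMARY_PROMPT_TEMPLATE.format(document="")
    return token_counter(
        messages=[{"role": "user", "content": str(rendered)}], model=model_name
    )

Seeding reserved_tokens with this measurement (plus an output buffer) would keep the budget honest if the prompt text ever grows.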
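For reference, get_model_context_window leans on litellm's bundled model registry. A quick illustration of what get_model_info exposes; the exact figures depend on the installed litellm version:

from litellm import get_model_info

info = get_model_info("gpt-3.5-turbo")
# The registry reports separate input and output limits, which is why the
# helper reads "max_input_tokens" rather than a combined total.
print(info.get("max_input_tokens"))   # e.g. 16385 in litellm's registry
print(info.get("max_output_tokens"))  # e.g. 4096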