fix: lower chat model max_tokens from 10000 to 8192 (the maximum is now 8192)

This commit is contained in:
LUIS NOVO 2025-10-18 13:21:53 -03:00
parent 059ee29e18
commit 8b5daa86bc
3 changed files with 70 additions and 53 deletions

View file

@ -26,10 +26,12 @@ class SourceChatState(TypedDict):
context_indicators: Optional[Dict[str, List[str]]]
def call_model_with_source_context(state: SourceChatState, config: RunnableConfig) -> dict:
def call_model_with_source_context(
state: SourceChatState, config: RunnableConfig
) -> dict:
"""
Main function that builds source context and calls the model.
This function:
1. Uses ContextBuilder to build source-specific context
2. Applies the source_chat Jinja2 prompt template
@ -39,7 +41,7 @@ def call_model_with_source_context(state: SourceChatState, config: RunnableConfi
source_id = state.get("source_id")
if not source_id:
raise ValueError("source_id is required in state")
# Build source context using ContextBuilder (run async code in new loop)
def build_context():
"""Build context in a new event loop"""
@ -50,57 +52,66 @@ def call_model_with_source_context(state: SourceChatState, config: RunnableConfi
source_id=source_id,
include_insights=True,
include_notes=False, # Focus on source-specific content
max_tokens=50000 # Reasonable limit for source context
max_tokens=50000, # Reasonable limit for source context
)
return new_loop.run_until_complete(context_builder.build())
finally:
new_loop.close()
asyncio.set_event_loop(None)
# Get the built context
try:
# Try to get the current event loop
asyncio.get_running_loop()
# If we're in an event loop, run in a thread with a new loop
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(build_context)
context_data = future.result()
except RuntimeError:
# No event loop running, safe to create a new one
context_data = build_context()
# Extract source and insights from context
source = None
insights = []
context_indicators: dict[str, list[str | None]] = {"sources": [], "insights": [], "notes": []}
context_indicators: dict[str, list[str | None]] = {
"sources": [],
"insights": [],
"notes": [],
}
if context_data.get("sources"):
source_info = context_data["sources"][0] # First source
source = Source(**source_info) if isinstance(source_info, dict) else source_info
context_indicators["sources"].append(source.id)
if context_data.get("insights"):
for insight_data in context_data["insights"]:
insight = SourceInsight(**insight_data) if isinstance(insight_data, dict) else insight_data
insight = (
SourceInsight(**insight_data)
if isinstance(insight_data, dict)
else insight_data
)
insights.append(insight)
context_indicators["insights"].append(insight.id)
# Format context for the prompt
formatted_context = _format_source_context(context_data)
# Build prompt data for the template
prompt_data = {
"source": source.model_dump() if source else None,
"insights": [insight.model_dump() for insight in insights] if insights else [],
"context": formatted_context,
"context_indicators": context_indicators
"context_indicators": context_indicators,
}
# Apply the source_chat prompt template
system_prompt = Prompter(prompt_template="source_chat").render(data=prompt_data)
payload = [SystemMessage(content=system_prompt)] + state.get("messages", [])
# Handle async model provisioning from sync context
def run_in_new_loop():
"""Run the async function in a new event loop"""
@ -110,20 +121,22 @@ def call_model_with_source_context(state: SourceChatState, config: RunnableConfi
return new_loop.run_until_complete(
provision_langchain_model(
str(payload),
config.get("configurable", {}).get("model_id") or state.get("model_override"),
config.get("configurable", {}).get("model_id")
or state.get("model_override"),
"chat",
max_tokens=10000,
max_tokens=8192,
)
)
finally:
new_loop.close()
asyncio.set_event_loop(None)
try:
# Try to get the current event loop
asyncio.get_running_loop()
# If we're in an event loop, run in a thread with a new loop
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_new_loop)
model = future.result()
@ -132,36 +145,37 @@ def call_model_with_source_context(state: SourceChatState, config: RunnableConfi
model = asyncio.run(
provision_langchain_model(
str(payload),
config.get("configurable", {}).get("model_id") or state.get("model_override"),
config.get("configurable", {}).get("model_id")
or state.get("model_override"),
"chat",
max_tokens=10000,
max_tokens=8192,
)
)
ai_message = model.invoke(payload)
# Update state with context information
return {
"messages": ai_message,
"source": source,
"insights": insights,
"context": formatted_context,
"context_indicators": context_indicators
"context_indicators": context_indicators,
}
def _format_source_context(context_data: Dict) -> str:
"""
Format the context data into a readable string for the prompt.
Args:
context_data: Context data from ContextBuilder
Returns:
Formatted context string
"""
context_parts = []
# Add source information
if context_data.get("sources"):
context_parts.append("## SOURCE CONTENT")
@ -176,17 +190,21 @@ def _format_source_context(context_data: Dict) -> str:
full_text = full_text[:5000] + "...\n[Content truncated]"
context_parts.append(f"**Content:**\n{full_text}")
context_parts.append("") # Empty line for separation
# Add insights
if context_data.get("insights"):
context_parts.append("## SOURCE INSIGHTS")
for insight in context_data["insights"]:
if isinstance(insight, dict):
context_parts.append(f"**Insight ID:** {insight.get('id', 'Unknown')}")
context_parts.append(f"**Type:** {insight.get('insight_type', 'Unknown')}")
context_parts.append(f"**Content:** {insight.get('content', 'No content')}")
context_parts.append(
f"**Type:** {insight.get('insight_type', 'Unknown')}"
)
context_parts.append(
f"**Content:** {insight.get('content', 'No content')}"
)
context_parts.append("") # Empty line for separation
# Add metadata
if context_data.get("metadata"):
metadata = context_data["metadata"]
@ -195,7 +213,7 @@ def _format_source_context(context_data: Dict) -> str:
context_parts.append(f"- Insight count: {metadata.get('insight_count', 0)}")
context_parts.append(f"- Total tokens: {context_data.get('total_tokens', 0)}")
context_parts.append("")
return "\n".join(context_parts)
@ -211,4 +229,4 @@ source_chat_state = StateGraph(SourceChatState)
source_chat_state.add_node("source_chat_agent", call_model_with_source_context)
source_chat_state.add_edge(START, "source_chat_agent")
source_chat_state.add_edge("source_chat_agent", END)
source_chat_graph = source_chat_state.compile(checkpointer=memory)
source_chat_graph = source_chat_state.compile(checkpointer=memory)