Mirror of https://github.com/lfnovo/open-notebook.git, synced 2026-05-01 21:00:43 +00:00
fix: max tokens max is 8192 now
commit 8b5daa86bc
parent 059ee29e18
3 changed files with 70 additions and 53 deletions
@@ -42,8 +42,7 @@ def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict
                     str(payload),
                     model_id,
                     "chat",
-                    max_tokens=10000,
-                )
+                    max_tokens=8192
                 )
         finally:
             new_loop.close()
@@ -64,7 +63,7 @@ def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict
                 str(payload),
                 model_id,
                 "chat",
-                max_tokens=10000,
+                max_tokens=8192,
             )
         )
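
Note on the change itself: the old budget of 10000 completion tokens exceeds the 8192-token ceiling that several popular chat models enforce, so oversized requests could fail at the provider. A minimal sketch (not code from this repo) of the clamping idea behind the new value:

    # Sketch only: why 10000 -> 8192. Providers reject completion budgets
    # above the model ceiling, so clamp the request instead of failing.
    MAX_COMPLETION_TOKENS = 8192

    def clamp_max_tokens(requested: int) -> int:
        """Clamp a requested completion budget to the provider ceiling."""
        return min(requested, MAX_COMPLETION_TOKENS)

    assert clamp_max_tokens(10000) == 8192  # the old value would be rejected
    assert clamp_max_tokens(4096) == 4096   # smaller budgets pass through unchanged
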
@@ -26,10 +26,12 @@ class SourceChatState(TypedDict):
     context_indicators: Optional[Dict[str, List[str]]]
 
 
-def call_model_with_source_context(state: SourceChatState, config: RunnableConfig) -> dict:
+def call_model_with_source_context(
+    state: SourceChatState, config: RunnableConfig
+) -> dict:
     """
     Main function that builds source context and calls the model.
 
     This function:
     1. Uses ContextBuilder to build source-specific context
     2. Applies the source_chat Jinja2 prompt template
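
For orientation, the graph state visible in this diff amounts to roughly the following sketch; only fields actually referenced in the hunks are shown, and the real class in the repo also carries the source, insights, and context fields seen in the return value further down:

    # Sketch reconstructed from fields referenced in this diff; not the
    # complete SourceChatState definition from the repo.
    from typing import Dict, List, Optional
    from typing_extensions import TypedDict

    class SourceChatState(TypedDict):
        messages: list                                      # chat history
        source_id: Optional[str]                            # required at runtime
        model_override: Optional[str]                       # optional model choice
        context_indicators: Optional[Dict[str, List[str]]]  # ids of context used
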
@@ -39,7 +41,7 @@ def call_model_with_source_context(state: SourceChatState, config: RunnableConfig) -> dict:
     source_id = state.get("source_id")
     if not source_id:
         raise ValueError("source_id is required in state")
 
     # Build source context using ContextBuilder (run async code in new loop)
     def build_context():
         """Build context in a new event loop"""
@@ -50,57 +52,66 @@ def call_model_with_source_context(state: SourceChatState, config: RunnableConfig) -> dict:
                 source_id=source_id,
                 include_insights=True,
                 include_notes=False,  # Focus on source-specific content
-                max_tokens=50000  # Reasonable limit for source context
+                max_tokens=50000,  # Reasonable limit for source context
             )
             return new_loop.run_until_complete(context_builder.build())
         finally:
             new_loop.close()
             asyncio.set_event_loop(None)
 
     # Get the built context
     try:
         # Try to get the current event loop
         asyncio.get_running_loop()
         # If we're in an event loop, run in a thread with a new loop
         import concurrent.futures
 
         with concurrent.futures.ThreadPoolExecutor() as executor:
             future = executor.submit(build_context)
             context_data = future.result()
     except RuntimeError:
         # No event loop running, safe to create a new one
         context_data = build_context()
 
     # Extract source and insights from context
     source = None
     insights = []
-    context_indicators: dict[str, list[str | None]] = {"sources": [], "insights": [], "notes": []}
+    context_indicators: dict[str, list[str | None]] = {
+        "sources": [],
+        "insights": [],
+        "notes": [],
+    }
 
     if context_data.get("sources"):
         source_info = context_data["sources"][0]  # First source
         source = Source(**source_info) if isinstance(source_info, dict) else source_info
         context_indicators["sources"].append(source.id)
 
     if context_data.get("insights"):
         for insight_data in context_data["insights"]:
-            insight = SourceInsight(**insight_data) if isinstance(insight_data, dict) else insight_data
+            insight = (
+                SourceInsight(**insight_data)
+                if isinstance(insight_data, dict)
+                else insight_data
+            )
             insights.append(insight)
             context_indicators["insights"].append(insight.id)
 
     # Format context for the prompt
     formatted_context = _format_source_context(context_data)
 
     # Build prompt data for the template
     prompt_data = {
         "source": source.model_dump() if source else None,
         "insights": [insight.model_dump() for insight in insights] if insights else [],
         "context": formatted_context,
-        "context_indicators": context_indicators
+        "context_indicators": context_indicators,
     }
 
     # Apply the source_chat prompt template
     system_prompt = Prompter(prompt_template="source_chat").render(data=prompt_data)
     payload = [SystemMessage(content=system_prompt)] + state.get("messages", [])
 
     # Handle async model provisioning from sync context
     def run_in_new_loop():
         """Run the async function in a new event loop"""
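
The `Source(**source_info) if isinstance(source_info, dict) else source_info` line, and the `SourceInsight` ternary the formatter rewrapped above, both normalize context entries that may arrive either as raw dicts or as already-parsed models. A self-contained sketch of that pattern with a stand-in pydantic model:

    # Stand-in sketch of the dict-or-model normalization used above; "Doc"
    # plays the role of the repo's Source / SourceInsight models.
    from typing import Union
    from pydantic import BaseModel

    class Doc(BaseModel):
        id: str
        content: str = ""

    def coerce(item: Union[dict, Doc]) -> Doc:
        """Accept either a raw dict or an already-parsed model."""
        return Doc(**item) if isinstance(item, dict) else item

    assert coerce({"id": "doc:1"}).id == "doc:1"
    assert coerce(Doc(id="doc:2")).id == "doc:2"
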
@@ -110,20 +121,22 @@ def call_model_with_source_context(state: SourceChatState, config: RunnableConfig) -> dict:
             return new_loop.run_until_complete(
                 provision_langchain_model(
                     str(payload),
-                    config.get("configurable", {}).get("model_id") or state.get("model_override"),
+                    config.get("configurable", {}).get("model_id")
+                    or state.get("model_override"),
                     "chat",
-                    max_tokens=10000,
+                    max_tokens=8192,
                 )
             )
         finally:
             new_loop.close()
             asyncio.set_event_loop(None)
 
     try:
         # Try to get the current event loop
         asyncio.get_running_loop()
         # If we're in an event loop, run in a thread with a new loop
         import concurrent.futures
 
         with concurrent.futures.ThreadPoolExecutor() as executor:
             future = executor.submit(run_in_new_loop)
             model = future.result()
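
`build_context` and `run_in_new_loop` share one bridge pattern: LangGraph invokes these nodes synchronously, while context building and model provisioning are async. If `asyncio.get_running_loop()` succeeds, a loop is already running and the coroutine must be driven from a worker thread with its own fresh loop; if it raises `RuntimeError`, a loop can be created directly. A distilled, runnable sketch of the pattern (helper names here are illustrative, not from the repo):

    import asyncio
    import concurrent.futures
    from typing import Any, Coroutine

    def _drive_in_new_loop(coro: Coroutine) -> Any:
        """Run a coroutine on a brand-new event loop, then tear it down."""
        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)
        try:
            return new_loop.run_until_complete(coro)
        finally:
            new_loop.close()
            asyncio.set_event_loop(None)

    def run_sync(coro: Coroutine) -> Any:
        """Block on a coroutine whether or not an event loop is running."""
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop running: safe to create one right here.
            return _drive_in_new_loop(coro)
        # Already inside a loop: block on a worker thread with its own loop.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            return executor.submit(_drive_in_new_loop, coro).result()

    async def _demo() -> str:
        return "ok"

    print(run_sync(_demo()))  # -> ok
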
@@ -132,36 +145,37 @@ def call_model_with_source_context(state: SourceChatState, config: RunnableConfig) -> dict:
         model = asyncio.run(
             provision_langchain_model(
                 str(payload),
-                config.get("configurable", {}).get("model_id") or state.get("model_override"),
+                config.get("configurable", {}).get("model_id")
+                or state.get("model_override"),
                 "chat",
-                max_tokens=10000,
+                max_tokens=8192,
             )
         )
 
     ai_message = model.invoke(payload)
 
     # Update state with context information
     return {
         "messages": ai_message,
         "source": source,
         "insights": insights,
         "context": formatted_context,
-        "context_indicators": context_indicators
+        "context_indicators": context_indicators,
     }
 
 
 def _format_source_context(context_data: Dict) -> str:
     """
     Format the context data into a readable string for the prompt.
 
     Args:
         context_data: Context data from ContextBuilder
 
     Returns:
         Formatted context string
     """
     context_parts = []
 
     # Add source information
     if context_data.get("sources"):
         context_parts.append("## SOURCE CONTENT")
@@ -176,17 +190,21 @@ def _format_source_context(context_data: Dict) -> str:
             full_text = full_text[:5000] + "...\n[Content truncated]"
             context_parts.append(f"**Content:**\n{full_text}")
         context_parts.append("")  # Empty line for separation
 
     # Add insights
     if context_data.get("insights"):
         context_parts.append("## SOURCE INSIGHTS")
         for insight in context_data["insights"]:
             if isinstance(insight, dict):
                 context_parts.append(f"**Insight ID:** {insight.get('id', 'Unknown')}")
-                context_parts.append(f"**Type:** {insight.get('insight_type', 'Unknown')}")
-                context_parts.append(f"**Content:** {insight.get('content', 'No content')}")
+                context_parts.append(
+                    f"**Type:** {insight.get('insight_type', 'Unknown')}"
+                )
+                context_parts.append(
+                    f"**Content:** {insight.get('content', 'No content')}"
+                )
         context_parts.append("")  # Empty line for separation
 
     # Add metadata
     if context_data.get("metadata"):
         metadata = context_data["metadata"]
@@ -195,7 +213,7 @@ def _format_source_context(context_data: Dict) -> str:
     context_parts.append(f"- Insight count: {metadata.get('insight_count', 0)}")
     context_parts.append(f"- Total tokens: {context_data.get('total_tokens', 0)}")
     context_parts.append("")
 
     return "\n".join(context_parts)
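
`_format_source_context` is a plain accumulate-and-join formatter: append markdown-style lines to `context_parts`, use empty strings as paragraph breaks, and join with newlines at the end. A tiny sketch of that shape with made-up data (field values here are hypothetical):

    # Sketch of the accumulate-and-join shape used by _format_source_context.
    def format_context(context_data: dict) -> str:
        parts = []
        if context_data.get("sources"):
            parts.append("## SOURCE CONTENT")
            for src in context_data["sources"]:
                parts.append(f"**Content:**\n{src.get('full_text', '')[:5000]}")
            parts.append("")  # empty line for separation
        parts.append(f"- Total tokens: {context_data.get('total_tokens', 0)}")
        return "\n".join(parts)

    print(format_context({"sources": [{"full_text": "Example text"}], "total_tokens": 3}))
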
@@ -211,4 +229,4 @@ source_chat_state = StateGraph(SourceChatState)
 source_chat_state.add_node("source_chat_agent", call_model_with_source_context)
 source_chat_state.add_edge(START, "source_chat_agent")
 source_chat_state.add_edge("source_chat_agent", END)
-source_chat_graph = source_chat_state.compile(checkpointer=memory)
\ No newline at end of file
+source_chat_graph = source_chat_state.compile(checkpointer=memory)
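
The wiring at the bottom of the file is standard LangGraph: one node, an edge in from START, an edge out to END, compiled with a checkpointer. A minimal runnable analogue, with a trivial node standing in for `call_model_with_source_context`:

    # Minimal analogue of the graph wiring above; echo_node stands in for
    # call_model_with_source_context.
    from typing_extensions import TypedDict
    from langgraph.graph import StateGraph, START, END
    from langgraph.checkpoint.memory import MemorySaver

    class State(TypedDict):
        messages: list

    def echo_node(state: State) -> dict:
        return {"messages": state["messages"] + ["ack"]}

    builder = StateGraph(State)
    builder.add_node("source_chat_agent", echo_node)
    builder.add_edge(START, "source_chat_agent")
    builder.add_edge("source_chat_agent", END)
    graph = builder.compile(checkpointer=MemorySaver())

    result = graph.invoke(
        {"messages": ["hi"]},
        config={"configurable": {"thread_id": "demo"}},  # checkpointers key on thread_id
    )
    print(result["messages"])  # -> ['hi', 'ack']
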
@@ -12,7 +12,7 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter
 from .token_utils import token_count
 
 # Pattern for matching thinking content in AI responses
-THINK_PATTERN = re.compile(r'<think>(.*?)</think>', re.DOTALL)
+THINK_PATTERN = re.compile(r"<think>(.*?)</think>", re.DOTALL)
 
 
 def split_text(txt: str, chunk_size=500):
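
The `THINK_PATTERN` edit is quote style only; behavior is unchanged. The two details that matter are the non-greedy `(.*?)`, which stops at the first closing tag instead of swallowing everything up to the last one, and `re.DOTALL`, which lets `.` match newlines so multi-line thinking blocks are captured. A quick check:

    import re

    THINK_PATTERN = re.compile(r"<think>(.*?)</think>", re.DOTALL)

    text = "<think>step 1\nstep 2</think>Answer A<think>later</think>Answer B"
    assert THINK_PATTERN.findall(text) == ["step 1\nstep 2", "later"]  # DOTALL spans lines
    assert THINK_PATTERN.sub("", text) == "Answer AAnswer B"           # non-greedy per block
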
@@ -76,66 +76,66 @@ def remove_non_printable(text: str) -> str:
 def parse_thinking_content(content: str) -> Tuple[str, str]:
     """
     Parse message content to extract thinking content from <think> tags.
 
     Args:
         content (str): The original message content
 
     Returns:
         Tuple[str, str]: (thinking_content, cleaned_content)
         - thinking_content: Content from within <think> tags
         - cleaned_content: Original content with <think> blocks removed
 
     Example:
         >>> content = "<think>Let me analyze this</think>Here's my answer"
         >>> thinking, cleaned = parse_thinking_content(content)
         >>> print(thinking)
         "Let me analyze this"
-        >>> print(cleaned)
+        >>> print(cleaned)
         "Here's my answer"
     """
     # Input validation
     if not isinstance(content, str):
         return "", str(content) if content is not None else ""
 
     # Limit processing for very large content (100KB limit)
     if len(content) > 100000:
         return "", content
 
     # Find all thinking blocks
     thinking_matches = THINK_PATTERN.findall(content)
 
     if not thinking_matches:
         return "", content
 
     # Join all thinking content with double newlines
     thinking_content = "\n\n".join(match.strip() for match in thinking_matches)
 
     # Remove all <think>...</think> blocks from the original content
     cleaned_content = THINK_PATTERN.sub("", content)
 
     # Clean up extra whitespace
-    cleaned_content = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned_content).strip()
+    cleaned_content = re.sub(r"\n\s*\n\s*\n", "\n\n", cleaned_content).strip()
 
     return thinking_content, cleaned_content
 
 
 def clean_thinking_content(content: str) -> str:
     """
     Remove thinking content from AI responses, returning only the cleaned content.
 
     This is a convenience function for cases where you only need the cleaned
     content and don't need access to the thinking process.
 
     Args:
         content (str): The original message content with potential <think> tags
 
     Returns:
         str: Content with <think> blocks removed and whitespace cleaned
 
     Example:
         >>> content = "<think>Let me think...</think>Here's the answer"
         >>> clean_thinking_content(content)
         "Here's the answer"
     """
     _, cleaned_content = parse_thinking_content(content)
-    return cleaned_content
\ No newline at end of file
+    return cleaned_content
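
Beyond quote style and the end-of-file newline, these functions are unchanged. The one line worth illustrating is the whitespace cleanup: after the `<think>` blocks are stripped out, runs of three or more (possibly space-padded) newlines collapse to a single blank line. A self-contained check:

    # Self-contained check of the whitespace-collapse step in
    # parse_thinking_content.
    import re

    messy = "Answer starts here.\n\n   \n\nAnswer continues."
    cleaned = re.sub(r"\n\s*\n\s*\n", "\n\n", messy).strip()
    assert cleaned == "Answer starts here.\n\nAnswer continues."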