mirror of https://github.com/Alishahryar1/free-claude-code.git
synced 2026-04-28 03:20:01 +00:00

commit 79a1ae0c54 (parent f9e8226120)
minor refactor using minimax m2.5

9 changed files with 170 additions and 156 deletions
@@ -6,6 +6,11 @@ from loguru import logger
 from config.settings import NVIDIA_NIM_BASE_URL, Settings
 from config.settings import get_settings as _get_settings
 from providers.base import BaseProvider, ProviderConfig
+from providers.exceptions import AuthenticationError
+from providers.lmstudio import LMStudioProvider
+from providers.nvidia_nim import NvidiaNimProvider
+from providers.open_router import OpenRouterProvider
+from providers.open_router.client import OPENROUTER_BASE_URL

 # Global provider instance (singleton)
 _provider: BaseProvider | None = None

@@ -20,15 +25,10 @@ def _create_provider(settings: Settings) -> BaseProvider:
     """Construct and return a new provider instance from settings."""
     if settings.provider_type == "nvidia_nim":
         if not settings.nvidia_nim_api_key or not settings.nvidia_nim_api_key.strip():
-            raise HTTPException(
-                status_code=503,
-                detail=(
-                    "NVIDIA_NIM_API_KEY is not set. Add it to your .env file. "
-                    "Get a key at https://build.nvidia.com/settings/api-keys"
-                ),
+            raise AuthenticationError(
+                "NVIDIA_NIM_API_KEY is not set. Add it to your .env file. "
+                "Get a key at https://build.nvidia.com/settings/api-keys"
             )
-        from providers.nvidia_nim import NvidiaNimProvider
-
         config = ProviderConfig(
             api_key=settings.nvidia_nim_api_key,
             base_url=NVIDIA_NIM_BASE_URL,

@@ -42,18 +42,13 @@ def _create_provider(settings: Settings) -> BaseProvider:
         provider = NvidiaNimProvider(config, nim_settings=settings.nim)
     elif settings.provider_type == "open_router":
         if not settings.open_router_api_key or not settings.open_router_api_key.strip():
-            raise HTTPException(
-                status_code=503,
-                detail=(
-                    "OPENROUTER_API_KEY is not set. Add it to your .env file. "
-                    "Get a key at https://openrouter.ai/keys"
-                ),
+            raise AuthenticationError(
+                "OPENROUTER_API_KEY is not set. Add it to your .env file. "
+                "Get a key at https://openrouter.ai/keys"
             )
-        from providers.open_router import OpenRouterProvider
-
         config = ProviderConfig(
             api_key=settings.open_router_api_key,
-            base_url="https://openrouter.ai/api/v1",
+            base_url=OPENROUTER_BASE_URL,
             rate_limit=settings.provider_rate_limit,
             rate_window=settings.provider_rate_window,
             max_concurrency=settings.provider_max_concurrency,

@@ -63,8 +58,6 @@ def _create_provider(settings: Settings) -> BaseProvider:
         )
         provider = OpenRouterProvider(config)
     elif settings.provider_type == "lmstudio":
-        from providers.lmstudio import LMStudioProvider
-
         config = ProviderConfig(
             api_key="lm-studio",
             base_url=settings.lm_studio_base_url,

@@ -93,7 +86,10 @@ def get_provider() -> BaseProvider:
     """Get or create the provider instance based on settings.provider_type."""
     global _provider
     if _provider is None:
-        _provider = _create_provider(get_settings())
+        try:
+            _provider = _create_provider(get_settings())
+        except AuthenticationError as e:
+            raise HTTPException(status_code=503, detail=str(e)) from e
     return _provider

@@ -101,13 +97,6 @@ async def cleanup_provider():
     """Cleanup provider resources."""
     global _provider
     if _provider:
-        client = getattr(_provider, "_client", None)
-        if client and hasattr(client, "aclose"):
-            await client.aclose()
-        elif client:
-            logger.warning(
-                "Provider client %r has no aclose(); skipping async cleanup",
-                type(client).__name__,
-            )
+        await _provider.cleanup()
         _provider = None
         logger.debug("Provider cleanup completed")
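The net effect of the factory hunks above: `_create_provider` now raises a transport-agnostic `AuthenticationError`, and the translation to HTTP happens exactly once, at the `get_provider` boundary. A minimal runnable sketch of that pattern (names simplified; the real module wires this through `config.settings`):

```python
from fastapi import HTTPException


class AuthenticationError(Exception):
    """Domain-level error: credentials missing or invalid."""


_provider: object | None = None


def _create_provider(api_key: str | None) -> object:
    # Domain layer: no knowledge of HTTP status codes.
    if not api_key or not api_key.strip():
        raise AuthenticationError("API key is not set. Add it to your .env file.")
    return object()  # stand-in for a real provider instance


def get_provider(api_key: str | None = None) -> object:
    # Web boundary: translate the domain error into a 503 exactly once.
    global _provider
    if _provider is None:
        try:
            _provider = _create_provider(api_key)
        except AuthenticationError as e:
            raise HTTPException(status_code=503, detail=str(e)) from e
    return _provider
```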
@@ -34,7 +34,9 @@ class _SnapshotQueue(asyncio.Queue[str]):
     """Remove item from queue if present. Returns True if removed."""
     if item not in self._queue:
         return False
-    object.__setattr__(self, "_queue", deque(x for x in self._queue if x != item))
+    items = [x for x in self._queue if x != item]
+    self._queue.clear()
+    self._queue.extend(items)
     return True

@@ -335,6 +337,12 @@ class MessageTree:
             return True
         return False

+    def _set_node_error_sync(self, node: MessageNode, error_message: str) -> None:
+        """Synchronously mark a node as ERROR. Caller must ensure no concurrent access."""
+        node.state = MessageState.ERROR
+        node.error_message = error_message
+        node.completed_at = datetime.now(UTC)
+
     def drain_queue_and_mark_cancelled(
         self, error_message: str = "Cancelled by user"
     ) -> list[MessageNode]:

@@ -350,8 +358,7 @@ class MessageTree:
                 break
             node = self._nodes.get(node_id)
             if node:
-                node.state = MessageState.ERROR
-                node.error_message = error_message
+                self._set_node_error_sync(node, error_message)
                 nodes.append(node)
         return nodes

@@ -372,6 +379,11 @@ class MessageTree:
             "nodes": {nid: node.to_dict() for nid, node in self._nodes.items()},
         }

+    def _add_node_from_dict(self, node: MessageNode) -> None:
+        """Register a deserialized node into the tree's internal indices."""
+        self._nodes[node.node_id] = node
+        self._status_to_node[node.status_message_id] = node.node_id
+
     @classmethod
     def from_dict(cls, data: dict) -> MessageTree:
         """Deserialize tree from dictionary."""

@@ -386,8 +398,7 @@ class MessageTree:
         for node_id, node_data in nodes_data.items():
             if node_id != root_id:
                 node = MessageNode.from_dict(node_data)
-                tree._nodes[node_id] = node
-                tree._status_to_node[node.status_message_id] = node_id
+                tree._add_node_from_dict(node)

         return tree
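The `_SnapshotQueue` hunk is worth isolating: rebinding `self._queue` via `object.__setattr__` swaps in a brand-new deque, while `clear()`/`extend()` mutates the existing one in place, so anything else holding a reference to the deque keeps seeing live state. A sketch of the in-place variant (note it relies on `asyncio.Queue`'s private `_queue` deque, as the source does, and assumes single-task access between the membership check and the rewrite):

```python
import asyncio
from typing import Any


class RemovableQueue(asyncio.Queue):
    def remove(self, item: Any) -> bool:
        """Remove item from the queue if present. Returns True if removed."""
        if item not in self._queue:
            return False
        remaining = [x for x in self._queue if x != item]
        self._queue.clear()            # mutate the existing deque in place...
        self._queue.extend(remaining)  # ...so existing references stay valid
        return True
```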
@@ -28,7 +28,11 @@ class BaseProvider(ABC):
     """Base class for all providers. Extend this to add your own."""

     def __init__(self, config: ProviderConfig):
-        self.config = config
+        self._config = config

+    @abstractmethod
+    async def cleanup(self) -> None:
+        """Release any resources held by this provider."""
+
     @abstractmethod
     async def stream_response(
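The new abstract `cleanup()` makes the resource-release contract explicit. Unlike a base method that raises `NotImplementedError` when called, an `@abstractmethod` fails at instantiation time, so a forgotten override surfaces immediately. A sketch:

```python
from abc import ABC, abstractmethod


class Base(ABC):
    @abstractmethod
    async def cleanup(self) -> None:
        """Release any resources held by this instance."""


class Good(Base):
    async def cleanup(self) -> None:
        pass  # nothing to release


class Bad(Base):
    pass  # forgot cleanup()


Good()   # fine
# Bad()  # TypeError: Can't instantiate abstract class Bad
```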
@@ -28,18 +28,18 @@ class HeuristicToolParser:
     instead of using the structured API.
     """

-    def __init__(self):
-        self.state = ParserState.TEXT
-        self.buffer = ""
-        self.current_tool_id = None
-        self.current_function_name = None
-        self.current_parameters = {}
+    # Class-level compiled patterns (compiled once, not per instance)
+    _FUNC_START_PATTERN = re.compile(r"●\s*<function=([^>]+)>")
+    _PARAM_PATTERN = re.compile(
+        r"<parameter=([^>]+)>(.*?)(?:</parameter>|$)", re.DOTALL
+    )

-        # Regex patterns
-        self.func_start_pattern = re.compile(r"●\s*<function=([^>]+)>")
-        self.param_pattern = re.compile(
-            r"<parameter=([^>]+)>(.*?)(?:</parameter>|$)", re.DOTALL
-        )
+    def __init__(self):
+        self._state = ParserState.TEXT
+        self._buffer = ""
+        self._current_tool_id = None
+        self._current_function_name = None
+        self._current_parameters = {}

     def _strip_control_tokens(self, text: str) -> str:
         # Remove complete sentinel tokens. If a token is split across chunks it

@@ -53,15 +53,15 @@ class HeuristicToolParser:

         This prevents leaking raw sentinel fragments to the user when streaming.
         """
-        start = self.buffer.rfind(_CONTROL_TOKEN_START)
+        start = self._buffer.rfind(_CONTROL_TOKEN_START)
         if start == -1:
             return ""
-        end = self.buffer.find(_CONTROL_TOKEN_END, start)
+        end = self._buffer.find(_CONTROL_TOKEN_END, start)
         if end != -1:
             return ""

-        prefix = self.buffer[:start]
-        self.buffer = self.buffer[start:]
+        prefix = self._buffer[:start]
+        self._buffer = self._buffer[start:]
         return prefix

     def feed(self, text: str) -> tuple[str, list[dict[str, Any]]]:
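Hoisting the compiled patterns to class attributes means `re.compile` runs once at import time instead of on every parser construction, and all instances share the same compiled objects. A sketch of the idiom:

```python
import re


class Parser:
    # Compiled once at class definition time, shared by every instance.
    _FUNC_START_PATTERN = re.compile(r"●\s*<function=([^>]+)>")

    def first_function(self, text: str) -> str | None:
        match = self._FUNC_START_PATTERN.search(text)
        return match.group(1) if match else None


print(Parser().first_function("● <function=Grep>"))  # -> "Grep"
```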
@@ -72,58 +72,59 @@ class HeuristicToolParser:
         filtered_text: Text that should be passed through as normal message content.
         detected_tools: List of Anthropic-format tool_use blocks.
         """
-        self.buffer += text
-        self.buffer = self._strip_control_tokens(self.buffer)
+        self._buffer += text
+        self._buffer = self._strip_control_tokens(self._buffer)
         detected_tools = []
-        filtered_output = ""
+        filtered_output_parts: list[str] = []

         while True:
-            if self.state == ParserState.TEXT:
+            if self._state == ParserState.TEXT:
                 # Look for the trigger character
-                if "●" in self.buffer:
-                    idx = self.buffer.find("●")
-                    filtered_output += self.buffer[:idx]
-                    self.buffer = self.buffer[idx:]
-                    self.state = ParserState.MATCHING_FUNCTION
+                if "●" in self._buffer:
+                    idx = self._buffer.find("●")
+                    filtered_output_parts.append(self._buffer[:idx])
+                    self._buffer = self._buffer[idx:]
+                    self._state = ParserState.MATCHING_FUNCTION
                 else:
                     # Avoid emitting an incomplete "<|...|>" sentinel fragment if the
                     # token got split across streaming chunks.
                     safe_prefix = self._split_incomplete_control_token_tail()
                     if safe_prefix:
-                        filtered_output += safe_prefix
+                        filtered_output_parts.append(safe_prefix)
                         break

-                    filtered_output += self.buffer
-                    self.buffer = ""
+                    filtered_output_parts.append(self._buffer)
+                    self._buffer = ""
                     break

-            if self.state == ParserState.MATCHING_FUNCTION:
+            if self._state == ParserState.MATCHING_FUNCTION:
                 # We need enough buffer to match the function tag
                 # e.g. "● <function=Grep>"
-                match = self.func_start_pattern.search(self.buffer)
+                match = self._FUNC_START_PATTERN.search(self._buffer)
                 if match:
-                    self.current_function_name = match.group(1).strip()
-                    self.current_tool_id = f"toolu_heuristic_{uuid.uuid4().hex[:8]}"
-                    self.current_parameters = {}
+                    self._current_function_name = match.group(1).strip()
+                    self._current_tool_id = f"toolu_heuristic_{uuid.uuid4().hex[:8]}"
+                    self._current_parameters = {}

                     # Consume the function start from buffer
-                    self.buffer = self.buffer[match.end() :]
-                    self.state = ParserState.PARSING_PARAMETERS
+                    self._buffer = self._buffer[match.end() :]
+                    self._state = ParserState.PARSING_PARAMETERS
                     logger.debug(
-                        f"Heuristic bypass: Detected start of tool call '{self.current_function_name}'"
+                        "Heuristic bypass: Detected start of tool call '{}'",
+                        self._current_function_name,
                     )
                 else:
                     # If we have "●" but not the full tag yet, wait for more data
                     # Unless the buffer has grown too large without a match
-                    if len(self.buffer) > 100:
+                    if len(self._buffer) > 100:
                         # Probably not a tool call, treat as text
-                        filtered_output += self.buffer[0]
-                        self.buffer = self.buffer[1:]
-                        self.state = ParserState.TEXT
+                        filtered_output_parts.append(self._buffer[0])
+                        self._buffer = self._buffer[1:]
+                        self._state = ParserState.TEXT
                     else:
                         break

-            if self.state == ParserState.PARSING_PARAMETERS:
+            if self._state == ParserState.PARSING_PARAMETERS:
                 # Look for parameters. We look for </parameter> to know a param is complete.
                 # Or wait for another <parameter or the end of the text if it seems complete.

@@ -134,22 +135,17 @@ class HeuristicToolParser:

                 # Check if we have any complete parameters
                 while True:
-                    param_match = self.param_pattern.search(self.buffer)
+                    param_match = self._PARAM_PATTERN.search(self._buffer)
                     if param_match and "</parameter>" in param_match.group(0):
                         # Detect any content before the parameter match and preserve it
-                        pre_match_text = self.buffer[: param_match.start()]
-                        if pre_match_text.strip():
-                            # If there's non-whitespace text, we should probably treat it as content
-                            # However, purely whitespace might be formatting
-                            filtered_output += pre_match_text
-                        elif pre_match_text:
-                            # Preserve whitespace too just in case
-                            filtered_output += pre_match_text
+                        pre_match_text = self._buffer[: param_match.start()]
+                        if pre_match_text:
+                            filtered_output_parts.append(pre_match_text)

                         key = param_match.group(1).strip()
                         val = param_match.group(2).strip()
-                        self.current_parameters[key] = val
-                        self.buffer = self.buffer[param_match.end() :]
+                        self._current_parameters[key] = val
+                        self._buffer = self._buffer[param_match.end() :]
                     else:
                         break

@@ -158,27 +154,27 @@ class HeuristicToolParser:
                 # 2. Significant pause (not handled here, handled by caller via flush if needed)
                 # 3. Another ● character (start of NEXT tool call)

-                if "●" in self.buffer:
+                if "●" in self._buffer:
                     # Next tool call starting or something else, close current
                     # But first, capture any text before the ●
-                    idx = self.buffer.find("●")
+                    idx = self._buffer.find("●")
                     if idx > 0:
-                        filtered_output += self.buffer[:idx]
-                        self.buffer = self.buffer[idx:]
+                        filtered_output_parts.append(self._buffer[:idx])
+                        self._buffer = self._buffer[idx:]
                     finished_tool_call = True
                 elif (
-                    len(self.buffer) > 0
-                    and not self.buffer.strip().startswith("<")
-                    and not self.buffer.lstrip().startswith("<")
+                    len(self._buffer) > 0
+                    and not self._buffer.strip().startswith("<")
+                    and not self._buffer.lstrip().startswith("<")
                 ):
                     # We have text that doesn't look like a tag, and we already parsed some or are in param state
                     # Let's see if we have trailing param starts
-                    if "<parameter=" not in self.buffer:
+                    if "<parameter=" not in self._buffer:
                         # Treat the buffer as text (it's not a parameter)
                         # But wait, we are in PARSING_PARAMETERS.
                         # If we have " some text", we should emit it and finish tool call.
-                        filtered_output += self.buffer
-                        self.buffer = ""
+                        filtered_output_parts.append(self._buffer)
+                        self._buffer = ""
                         finished_tool_call = True

                 if finished_tool_call:

@@ -186,47 +182,49 @@ class HeuristicToolParser:
                     detected_tools.append(
                         {
                             "type": "tool_use",
-                            "id": self.current_tool_id,
-                            "name": self.current_function_name,
-                            "input": self.current_parameters,
+                            "id": self._current_tool_id,
+                            "name": self._current_function_name,
+                            "input": self._current_parameters,
                         }
                     )
                     logger.debug(
-                        f"Heuristic bypass: Emitting tool call '{self.current_function_name}' with {len(self.current_parameters)} params"
+                        "Heuristic bypass: Emitting tool call '{}' with {} params",
+                        self._current_function_name,
+                        len(self._current_parameters),
                     )
-                    self.state = ParserState.TEXT
+                    self._state = ParserState.TEXT
                     # Continue loop to process remaining buffer (which is empty or starts with ●)
                 else:
                     break

-        return filtered_output, detected_tools
+        return "".join(filtered_output_parts), detected_tools

     def flush(self) -> list[dict[str, Any]]:
         """
         Flush any remaining tool calls in the buffer.
         """
-        self.buffer = self._strip_control_tokens(self.buffer)
+        self._buffer = self._strip_control_tokens(self._buffer)
         detected_tools = []
-        if self.state == ParserState.PARSING_PARAMETERS:
+        if self._state == ParserState.PARSING_PARAMETERS:
             # Try to extract any partial parameters remaining in buffer
             # Even without </parameter>
             partial_matches = re.finditer(
-                r"<parameter=([^>]+)>(.*)$", self.buffer, re.DOTALL
+                r"<parameter=([^>]+)>(.*)$", self._buffer, re.DOTALL
             )
             for m in partial_matches:
                 key = m.group(1).strip()
                 val = m.group(2).strip()
-                self.current_parameters[key] = val
+                self._current_parameters[key] = val

             detected_tools.append(
                 {
                     "type": "tool_use",
-                    "id": self.current_tool_id,
-                    "name": self.current_function_name,
-                    "input": self.current_parameters,
+                    "id": self._current_tool_id,
+                    "name": self._current_function_name,
+                    "input": self._current_parameters,
                 }
             )
-            self.state = ParserState.TEXT
-            self.buffer = ""
+            self._state = ParserState.TEXT
+            self._buffer = ""

         return detected_tools
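A recurring change in this commit (here and in the SSE builder below) replaces `out += chunk` accumulation with a parts list joined once at the end. Repeated string concatenation can copy the growing accumulator on every streamed delta, which is quadratic in the worst case over a long stream; append-then-join is linear. A sketch:

```python
def accumulate(chunks: list[str]) -> str:
    parts: list[str] = []
    for chunk in chunks:
        parts.append(chunk)  # O(1) amortized per streamed delta
    return "".join(parts)    # one O(n) pass at the end


assert accumulate(['{"patt', 'ern":', " 1}"]) == '{"pattern": 1}'
```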
@@ -41,7 +41,7 @@ class ContentBlockManager:
     thinking_started: bool = False
     text_started: bool = False
     tool_indices: dict[int, int] = field(default_factory=dict)
-    tool_contents: dict[int, str] = field(default_factory=dict)
+    tool_contents: dict[int, list[str]] = field(default_factory=dict)
     tool_names: dict[int, str] = field(default_factory=dict)
     tool_ids: dict[int, str] = field(default_factory=dict)
     tool_started: dict[int, bool] = field(default_factory=dict)

@@ -134,13 +134,13 @@ class SSEBuilder:
         self.model = model
         self.input_tokens = input_tokens
         self.blocks = ContentBlockManager()
-        self._accumulated_text = ""
-        self._accumulated_reasoning = ""
+        self._accumulated_text_parts: list[str] = []
+        self._accumulated_reasoning_parts: list[str] = []

     def _format_event(self, event_type: str, data: dict[str, Any]) -> str:
         """Format as SSE string."""
         event_str = f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
-        logger.debug(f"SSE_EVENT: {event_type} - {event_str.strip()}")
+        logger.debug("SSE_EVENT: {} - {}", event_type, event_str.strip())
         return event_str

     # Message lifecycle events

@@ -161,7 +161,6 @@
                 "stop_sequence": None,
                 "usage": usage,
             },
-            "usage": usage,
         },
     )

@@ -247,7 +246,7 @@

     def emit_thinking_delta(self, content: str) -> str:
         """Emit thinking content delta."""
-        self._accumulated_reasoning += content
+        self._accumulated_reasoning_parts.append(content)
         return self.content_block_delta(
             self.blocks.thinking_index, "thinking_delta", content
         )

@@ -266,7 +265,7 @@

     def emit_text_delta(self, content: str) -> str:
         """Emit text content delta."""
-        self._accumulated_text += content
+        self._accumulated_text_parts.append(content)
         return self.content_block_delta(self.blocks.text_index, "text_delta", content)

     def stop_text_block(self) -> str:

@@ -279,14 +278,14 @@
         """Start a tool_use block."""
         block_idx = self.blocks.allocate_index()
         self.blocks.tool_indices[tool_index] = block_idx
-        self.blocks.tool_contents[tool_index] = ""
+        self.blocks.tool_contents[tool_index] = []
         self.blocks.tool_ids[tool_index] = tool_id
         self.blocks.task_args_emitted.setdefault(tool_index, False)
         return self.content_block_start(block_idx, "tool_use", id=tool_id, name=name)

     def emit_tool_delta(self, tool_index: int, partial_json: str) -> str:
         """Emit tool input delta."""
-        self.blocks.tool_contents[tool_index] += partial_json
+        self.blocks.tool_contents[tool_index].append(partial_json)
         block_idx = self.blocks.tool_indices[tool_index]
         return self.content_block_delta(block_idx, "input_json_delta", partial_json)

@@ -338,38 +337,40 @@
     @property
     def accumulated_text(self) -> str:
         """Get accumulated text content."""
-        return self._accumulated_text
+        return "".join(self._accumulated_text_parts)

     @property
     def accumulated_reasoning(self) -> str:
         """Get accumulated reasoning content."""
-        return self._accumulated_reasoning
+        return "".join(self._accumulated_reasoning_parts)

     def estimate_output_tokens(self) -> int:
         """Estimate output tokens from accumulated content."""
+        accumulated_text = self.accumulated_text
+        accumulated_reasoning = self.accumulated_reasoning
         if ENCODER:
-            text_tokens = len(ENCODER.encode(self._accumulated_text))
-            reasoning_tokens = len(ENCODER.encode(self._accumulated_reasoning))
+            text_tokens = len(ENCODER.encode(accumulated_text))
+            reasoning_tokens = len(ENCODER.encode(accumulated_reasoning))
             # Tool calls are harder to tokenize exactly without reconstruction, but we can approximate
             # by tokenizing the json dumps of tool contents
             tool_tokens = 0
-            for idx, content in self.blocks.tool_contents.items():
+            for idx, content_parts in self.blocks.tool_contents.items():
                 name = self.blocks.tool_names.get(idx, "")
                 tool_tokens += len(ENCODER.encode(name))
-                tool_tokens += len(ENCODER.encode(content))
+                tool_tokens += len(ENCODER.encode("".join(content_parts)))
                 tool_tokens += 15  # Control tokens overhead per tool

             # Per-block overhead (~4 tokens per content block)
             block_count = (
-                (1 if self._accumulated_reasoning else 0)
-                + (1 if self._accumulated_text else 0)
+                (1 if accumulated_reasoning else 0)
+                + (1 if accumulated_text else 0)
                 + len(self.blocks.tool_indices)
             )
             block_overhead = block_count * 4

             return text_tokens + reasoning_tokens + tool_tokens + block_overhead

-        text_tokens = len(self._accumulated_text) // 4
-        reasoning_tokens = len(self._accumulated_reasoning) // 4
+        text_tokens = len(accumulated_text) // 4
+        reasoning_tokens = len(accumulated_reasoning) // 4
         tool_tokens = len(self.blocks.tool_indices) * 50
         return text_tokens + reasoning_tokens + tool_tokens
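The token estimator now joins the part lists once and reuses the result on both paths. Its two-tier shape — a real tokenizer when available, a chars-divided-by-4 heuristic otherwise — looks like this in isolation (`tiktoken` and the encoding name are assumptions for illustration; the repo's `ENCODER` may be constructed differently):

```python
try:
    import tiktoken  # optional dependency, assumed here for illustration
    ENCODER = tiktoken.get_encoding("cl100k_base")
except ImportError:
    ENCODER = None


def estimate_tokens(text: str) -> int:
    if ENCODER:
        return len(ENCODER.encode(text))
    return len(text) // 4  # rough heuristic: ~4 characters per token
```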
@@ -46,22 +46,23 @@ class ThinkTagParser:
         Feed content and yield parsed chunks.

         Handles partial tags by buffering content near potential tag boundaries.
+        Uses an iterative loop instead of mutual recursion to avoid stack overflow
+        on inputs with many consecutive think tags.
         """
        self._buffer += content

        while self._buffer:
+            prev_len = len(self._buffer)
            if not self._in_think_tag:
                chunk = self._parse_outside_think()
-                if chunk:
-                    yield chunk
-                else:
-                    break
            else:
                chunk = self._parse_inside_think()
-                if chunk:
-                    yield chunk
-                else:
-                    break

+            if chunk:
+                yield chunk
+            elif len(self._buffer) == prev_len:
+                # No progress: waiting for more data
+                break

     def _parse_outside_think(self) -> ContentChunk | None:
         """Parse content outside think tags."""

@@ -75,8 +76,8 @@ class ThinkTagParser:
             self._buffer = self._buffer[orphan_close + self.CLOSE_TAG_LEN :]
             if pre_orphan:
                 return ContentChunk(ContentType.TEXT, pre_orphan)
-            # Continue parsing after stripping orphan tag
-            return self._parse_outside_think()
+            # Buffer shrunk; the feed() loop will continue parsing
+            return None

         if think_start == -1:
             # No tag found - check for partial tag at end

@@ -112,8 +113,9 @@ class ThinkTagParser:
             self._in_think_tag = True
             if pre_think:
                 return ContentChunk(ContentType.TEXT, pre_think)
-            # Continue parsing inside think tag
-            return self._parse_inside_think()
+            # Buffer shrunk (consumed <think>); the feed() loop will continue
+            # parsing inside the think tag on the next iteration
+            return None

     def _parse_inside_think(self) -> ContentChunk | None:
         """Parse content inside think tags."""

@@ -147,8 +149,9 @@ class ThinkTagParser:
             self._in_think_tag = False
             if thinking_content:
                 return ContentChunk(ContentType.THINKING, thinking_content)
-            # Continue parsing outside think tag
-            return self._parse_outside_think()
+            # Buffer shrunk (consumed </think>); the feed() loop will continue
+            # parsing outside the think tag on the next iteration
+            return None

     def flush(self) -> ContentChunk | None:
         """Flush any remaining buffered content."""

@@ -160,8 +163,3 @@ class ThinkTagParser:
             self._buffer = ""
             return ContentChunk(chunk_type, content)
         return None
-
-    def reset(self):
-        """Reset parser state."""
-        self._buffer = ""
-        self._in_think_tag = False
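The parser hunks above break the `_parse_outside_think` ↔ `_parse_inside_think` mutual recursion: each helper now consumes what it can and returns `None`, and `feed()`'s loop re-dispatches as long as the buffer keeps shrinking. A compact iterative restatement of the idea (simplified: whole input at once, no partial-tag buffering):

```python
def split_think(text: str) -> list[tuple[str, str]]:
    """Split text into ("text" | "thinking", content) chunks with a flat loop."""
    chunks: list[tuple[str, str]] = []
    in_think = False
    buf = text
    while buf:
        tag = "</think>" if in_think else "<think>"
        idx = buf.find(tag)
        if idx == -1:
            chunks.append(("thinking" if in_think else "text", buf))
            break
        if idx > 0:
            chunks.append(("thinking" if in_think else "text", buf[:idx]))
        buf = buf[idx + len(tag):]  # consume the tag; stack depth never grows
        in_think = not in_think
    return chunks


print(split_think("a<think>b</think>c"))
# [('text', 'a'), ('thinking', 'b'), ('text', 'c')]
```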
@@ -2,6 +2,7 @@

 import json
 import uuid
+from abc import abstractmethod
 from collections.abc import AsyncIterator, Iterator
 from typing import Any

@@ -9,6 +10,7 @@ import httpx
 from loguru import logger
 from openai import AsyncOpenAI

+from config.nim import NimSettings
 from providers.base import BaseProvider, ProviderConfig
 from providers.common import (
     ContentType,

@@ -31,7 +33,7 @@ class OpenAICompatibleProvider(BaseProvider):
         provider_name: str,
         base_url: str,
         api_key: str,
-        nim_settings: Any | None = None,
+        nim_settings: NimSettings | None = None,
     ):
         super().__init__(config)
         self._provider_name = provider_name

@@ -55,15 +57,26 @@ class OpenAICompatibleProvider(BaseProvider):
             ),
         )

+    async def cleanup(self) -> None:
+        """Release HTTP client resources."""
+        client = getattr(self, "_client", None)
+        if client and hasattr(client, "aclose"):
+            await client.aclose()
+        elif client:
+            logger.warning(
+                "Provider client %r has no aclose(); skipping async cleanup",
+                type(client).__name__,
+            )
+
+    @abstractmethod
     def _build_request_body(self, request: Any) -> dict:
-        """Build request body. Override in subclasses."""
-        raise NotImplementedError
+        """Build request body. Must be implemented by subclasses."""

     def _handle_extra_reasoning(self, delta: Any, sse: SSEBuilder) -> Iterator[str]:
         """Hook for provider-specific reasoning (e.g. OpenRouter reasoning_details)."""
         return iter(())

-    def _process_tool_call(self, tc: dict, sse: Any) -> Iterator[str]:
+    def _process_tool_call(self, tc: dict, sse: SSEBuilder) -> Iterator[str]:
         """Process a single tool call delta and yield SSE events."""
         tc_index = tc.get("index", 0)
         if tc_index < 0:

@@ -105,7 +118,7 @@ class OpenAICompatibleProvider(BaseProvider):

         yield sse.emit_tool_delta(tc_index, args)

-    def _flush_task_arg_buffers(self, sse: Any) -> Iterator[str]:
+    def _flush_task_arg_buffers(self, sse: SSEBuilder) -> Iterator[str]:
         """Emit buffered Task args as a single JSON delta (best-effort)."""
         for tool_index, out in sse.blocks.flush_task_arg_buffers():
             yield sse.emit_tool_delta(tool_index, out)
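The `cleanup()` hoisted into `OpenAICompatibleProvider` (from the module-level `cleanup_provider` in the first file) is duck-typed on purpose: close the wrapped client if it exposes an async `aclose()`, otherwise warn and carry on, so shutdown never crashes on a client that lacks one. A standalone sketch (`_SyncClient` is a hypothetical stand-in):

```python
import asyncio

from loguru import logger


class _SyncClient:
    """Hypothetical stand-in for an SDK client without aclose()."""


async def close_client(client) -> None:
    if hasattr(client, "aclose"):
        await client.aclose()
    else:
        logger.warning(
            "Client {} has no aclose(); skipping async cleanup",
            type(client).__name__,
        )


asyncio.run(close_client(_SyncClient()))  # logs a warning instead of raising
```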
@@ -6,7 +6,7 @@ import time
 from collections import deque
 from collections.abc import AsyncIterator, Callable
 from contextlib import asynccontextmanager
-from typing import Any, TypeVar
+from typing import Any, ClassVar, TypeVar

 import openai
 from loguru import logger

@@ -28,7 +28,7 @@ class GlobalRateLimiter:
     Concurrency limit - caps simultaneously open streams.
     """

-    _instance: GlobalRateLimiter | None = None
+    _instance: ClassVar[GlobalRateLimiter | None] = None

     def __init__(
         self,
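`ClassVar` documents that `_instance` is shared class-level state rather than a per-instance field, and lets type checkers reject accidental assignment through an instance. The singleton slot in isolation (names simplified):

```python
from __future__ import annotations

from typing import ClassVar


class GlobalLimiter:
    _instance: ClassVar[GlobalLimiter | None] = None  # one shared slot

    @classmethod
    def instance(cls) -> GlobalLimiter:
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance


assert GlobalLimiter.instance() is GlobalLimiter.instance()
```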
@@ -239,7 +239,7 @@ class TestSSEBuilderHighLevelHelpers:
         sse = builder.emit_tool_delta(0, '{"pattern":')
         data = _parse_sse(sse)
         assert data["delta"]["partial_json"] == '{"pattern":'
-        assert builder.blocks.tool_contents[0] == '{"pattern":'
+        assert "".join(builder.blocks.tool_contents[0]) == '{"pattern":'

     def test_stop_tool_block(self):
         builder = SSEBuilder("msg_1", "model")