free-claude-code/core/anthropic/sse.py
Alishahryar1 abae61d85b Fix null usage in SSE for OpenAI-compatible streams (#209, #123)
- Only use provider completion_tokens when it is an int; otherwise estimate
- Coerce message_start/message_delta usage fields to safe integers in SSEBuilder
- Add regression tests for null upstream completion_tokens and builder edge cases

Claude Code could crash (e.g. undefined access on usage) when NIM/GLM or
similar sent usage with null token fields in streamed message_delta.
2026-04-27 18:20:33 -07:00

416 lines
14 KiB
Python

"""SSE event builder for Anthropic-format streaming responses."""
import hashlib
import json
from collections.abc import Iterator
from dataclasses import dataclass, field
from loguru import logger
# Optional tokenizer used for output-token estimation. ``ENCODER`` is the
# tiktoken ``cl100k_base`` encoding when it can be loaded, otherwise ``None``.
try:
    import tiktoken

    ENCODER = tiktoken.get_encoding("cl100k_base")
except Exception:
    # Best-effort: any failure while importing tiktoken or loading the encoding
    # leaves ENCODER as None; code that reads ENCODER checks its truthiness and
    # falls back to a character-count heuristic.
    ENCODER = None
# Standard headers for Anthropic-style ``text/event-stream`` responses from this proxy.
# NOTE(review): the three entries presumably exist to keep intermediaries from
# buffering or caching the stream and to keep the connection open — confirm
# against the deployment's reverse-proxy configuration.
ANTHROPIC_SSE_RESPONSE_HEADERS: dict[str, str] = {
    "X-Accel-Buffering": "no",
    "Cache-Control": "no-cache",
    "Connection": "keep-alive",
}
# Lookup table from OpenAI ``finish_reason`` values to Anthropic ``stop_reason``
# values. ``map_stop_reason`` consults it and uses "end_turn" for any reason
# that is missing from the table.
STOP_REASON_MAP: dict[str, str] = {
    "stop": "end_turn",
    "length": "max_tokens",
    "tool_calls": "tool_use",
    "content_filter": "end_turn",
}
def map_stop_reason(openai_reason: str | None) -> str:
"""Map OpenAI finish_reason to Anthropic stop_reason."""
return (
STOP_REASON_MAP.get(openai_reason, "end_turn") if openai_reason else "end_turn"
)
def _safe_usage_int(value: object) -> int:
"""Coerce streamed usage counters to int; non-integers become 0."""
return value if isinstance(value, int) else 0
def format_sse_event(event_type: str, data: dict) -> str:
    """Serialize one Anthropic-style SSE frame; performs no logging.

    The frame is an ``event:`` line, a ``data:`` line carrying ``data`` as
    JSON, and a terminating blank line.
    """
    event_line = f"event: {event_type}\n"
    payload = json.dumps(data)
    data_line = f"data: {payload}\n"
    return event_line + data_line + "\n"
@dataclass
class ToolCallState:
    """State for a single streaming tool call."""

    # Content block index for this tool call; -1 is used as a placeholder
    # until a block index has been allocated for it.
    block_index: int
    # Tool call id; empty string until one has been recorded.
    tool_id: str
    # Tool name; may be built up from several streamed fragments.
    name: str
    # Argument JSON fragments emitted for this tool call, in arrival order.
    contents: list[str] = field(default_factory=list)
    # True once a ``tool_use`` content block has been started for this call.
    started: bool = False
    # Raw argument text buffered until it parses as complete JSON.
    task_arg_buffer: str = ""
    # True once the buffered arguments have been emitted (or flushed).
    task_args_emitted: bool = False
    # NOTE(review): not referenced in the visible part of this file; presumably
    # holds argument text received before the block starts — confirm with callers.
    pre_start_args: str = ""
@dataclass
class ContentBlockManager:
    """Manage content block indices and per-tool streaming state.

    Block indices are handed out sequentially by ``allocate_index``. A value of
    ``-1`` in ``thinking_index`` / ``text_index`` (or in a tool state's
    ``block_index``) means no block index has been assigned yet.
    """

    next_index: int = 0
    thinking_index: int = -1
    text_index: int = -1
    thinking_started: bool = False
    text_started: bool = False
    # Keyed by the caller-supplied tool index, which is distinct from the
    # content block index stored inside each state.
    tool_states: dict[int, ToolCallState] = field(default_factory=dict)

    def allocate_index(self) -> int:
        """Return the next unused content block index and advance the counter."""
        idx = self.next_index
        self.next_index += 1
        return idx

    def ensure_tool_state(self, index: int) -> ToolCallState:
        """Create tool stream state for ``index`` when the first tool delta arrives."""
        if index not in self.tool_states:
            self.tool_states[index] = ToolCallState(block_index=-1, tool_id="", name="")
        return self.tool_states[index]

    def set_stream_tool_id(self, index: int, tool_id: str | None) -> None:
        """Record OpenAI tool call id before ``content_block_start`` (split-stream providers).

        A falsy ``tool_id`` is ignored and does not create state.
        """
        if not tool_id:
            return
        state = self.ensure_tool_state(index)
        state.tool_id = str(tool_id)

    def register_tool_name(self, index: int, name: str) -> None:
        """Record tool name fragments as they arrive from chunked OpenAI streams.

        Names may be split across deltas; later chunks can extend (``ab`` + ``c``)
        or repeat prefixes, so we merge conservatively.
        """
        if index not in self.tool_states:
            self.tool_states[index] = ToolCallState(
                block_index=-1, tool_id="", name=name
            )
            return
        state = self.tool_states[index]
        prev = state.name
        if not prev or name.startswith(prev):
            # First fragment, or a chunk that restates and extends the name so far.
            state.name = name
        elif not prev.startswith(name):
            # A genuinely new fragment: append. (A repeated prefix is ignored.)
            state.name = prev + name

    def buffer_task_args(self, index: int, args: str) -> dict | None:
        """Accumulate argument text for ``index``; return the parsed object once complete.

        Args:
            index: Tool index whose state receives the fragment.
            args: Next fragment of argument JSON text.

        Returns:
            The parsed argument dict, with ``run_in_background`` forced to
            ``False``, once the buffer parses as a JSON object. ``None`` when no
            state exists for ``index``, the arguments were already emitted, or
            the buffer is not (yet) a JSON object.
        """
        state = self.tool_states.get(index)
        if state is None or state.task_args_emitted:
            return None
        state.task_arg_buffer += args
        try:
            args_json = json.loads(state.task_arg_buffer)
        except Exception:
            # Incomplete JSON so far; keep buffering.
            return None
        if not isinstance(args_json, dict):
            # Valid JSON but not an object (``null``, a list, a number, ...).
            # Normalizing it would raise AttributeError, so leave the text in
            # the buffer; ``flush_task_arg_buffers`` reports and replaces it.
            return None
        _normalize_task_run_in_background(args_json)
        state.task_args_emitted = True
        state.task_arg_buffer = ""
        return args_json

    def has_emitted_tool_block(self) -> bool:
        """True when native OpenAI tool streaming has started a ``tool_use`` block."""
        return any(s.started for s in self.tool_states.values())

    def flush_task_arg_buffers(self) -> list[tuple[int, str]]:
        """Emit whatever argument text is still buffered, one entry per tool.

        Returns:
            ``(tool_index, json_text)`` pairs for every tool that still had
            buffered, un-emitted arguments. ``json_text`` is the normalized
            argument object, or ``"{}"`` when the buffer is not a JSON object
            (a warning is logged with a hash prefix instead of the raw text).
        """
        results: list[tuple[int, str]] = []
        for tool_index, state in list(self.tool_states.items()):
            if not state.task_arg_buffer or state.task_args_emitted:
                continue
            out = "{}"
            try:
                args_json = json.loads(state.task_arg_buffer)
                if not isinstance(args_json, dict):
                    # Route non-object JSON through the same warning path as
                    # malformed JSON instead of crashing in the normalizer.
                    raise TypeError(
                        f"expected a JSON object, got {type(args_json).__name__}"
                    )
                _normalize_task_run_in_background(args_json)
                out = json.dumps(args_json)
            except (json.JSONDecodeError, TypeError, ValueError) as e:
                digest = hashlib.sha256(
                    state.task_arg_buffer.encode("utf-8", errors="replace")
                ).hexdigest()[:16]
                logger.warning(
                    "Task args invalid JSON (id={} len={} buffer_sha256_prefix={}): {}",
                    state.tool_id or "unknown",
                    len(state.task_arg_buffer),
                    digest,
                    e,
                )
            state.task_args_emitted = True
            state.task_arg_buffer = ""
            results.append((tool_index, out))
        return results
def _normalize_task_run_in_background(args_json: dict) -> None:
"""Force Claude Code Task subagents to run in foreground (single shared rule)."""
if args_json.get("run_in_background") is not False:
args_json["run_in_background"] = False
class SSEBuilder:
"""Builder for Anthropic SSE streaming events."""
def __init__(
self,
message_id: str,
model: str,
input_tokens: int = 0,
*,
log_raw_events: bool = False,
):
self.message_id = message_id
self.model = model
self.input_tokens = input_tokens
self._log_raw_events = log_raw_events
self.blocks = ContentBlockManager()
self._accumulated_text_parts: list[str] = []
self._accumulated_reasoning_parts: list[str] = []
def _format_event(self, event_type: str, data: dict) -> str:
event_str = format_sse_event(event_type, data)
if self._log_raw_events:
logger.debug("SSE_EVENT: {} - {}", event_type, event_str.strip())
else:
logger.debug(
"SSE_EVENT: event_type={} serialized_bytes={}",
event_type,
len(event_str.encode("utf-8")),
)
return event_str
def message_start(self) -> str:
safe_input = _safe_usage_int(self.input_tokens)
usage = {"input_tokens": safe_input, "output_tokens": 1}
return self._format_event(
"message_start",
{
"type": "message_start",
"message": {
"id": self.message_id,
"type": "message",
"role": "assistant",
"content": [],
"model": self.model,
"stop_reason": None,
"stop_sequence": None,
"usage": usage,
},
},
)
def message_delta(self, stop_reason: str, output_tokens: int | None) -> str:
safe_in = _safe_usage_int(self.input_tokens)
safe_out = output_tokens if isinstance(output_tokens, int) else 0
return self._format_event(
"message_delta",
{
"type": "message_delta",
"delta": {"stop_reason": stop_reason, "stop_sequence": None},
"usage": {
"input_tokens": safe_in,
"output_tokens": safe_out,
},
},
)
def message_stop(self) -> str:
return self._format_event("message_stop", {"type": "message_stop"})
def content_block_start(self, index: int, block_type: str, **kwargs) -> str:
content_block: dict = {"type": block_type}
if block_type == "thinking":
content_block["thinking"] = kwargs.get("thinking", "")
elif block_type == "text":
content_block["text"] = kwargs.get("text", "")
elif block_type == "tool_use":
content_block["id"] = kwargs.get("id", "")
content_block["name"] = kwargs.get("name", "")
content_block["input"] = kwargs.get("input", {})
return self._format_event(
"content_block_start",
{
"type": "content_block_start",
"index": index,
"content_block": content_block,
},
)
def content_block_delta(self, index: int, delta_type: str, content: str) -> str:
delta: dict = {"type": delta_type}
if delta_type == "thinking_delta":
delta["thinking"] = content
elif delta_type == "text_delta":
delta["text"] = content
elif delta_type == "input_json_delta":
delta["partial_json"] = content
return self._format_event(
"content_block_delta",
{
"type": "content_block_delta",
"index": index,
"delta": delta,
},
)
def content_block_stop(self, index: int) -> str:
return self._format_event(
"content_block_stop",
{
"type": "content_block_stop",
"index": index,
},
)
def start_thinking_block(self) -> str:
self.blocks.thinking_index = self.blocks.allocate_index()
self.blocks.thinking_started = True
return self.content_block_start(self.blocks.thinking_index, "thinking")
def emit_thinking_delta(self, content: str) -> str:
self._accumulated_reasoning_parts.append(content)
return self.content_block_delta(
self.blocks.thinking_index, "thinking_delta", content
)
def stop_thinking_block(self) -> str:
self.blocks.thinking_started = False
return self.content_block_stop(self.blocks.thinking_index)
def start_text_block(self) -> str:
self.blocks.text_index = self.blocks.allocate_index()
self.blocks.text_started = True
return self.content_block_start(self.blocks.text_index, "text")
def emit_text_delta(self, content: str) -> str:
self._accumulated_text_parts.append(content)
return self.content_block_delta(self.blocks.text_index, "text_delta", content)
def stop_text_block(self) -> str:
self.blocks.text_started = False
return self.content_block_stop(self.blocks.text_index)
def start_tool_block(self, tool_index: int, tool_id: str, name: str) -> str:
block_idx = self.blocks.allocate_index()
if tool_index in self.blocks.tool_states:
state = self.blocks.tool_states[tool_index]
state.block_index = block_idx
state.tool_id = tool_id
state.started = True
else:
self.blocks.tool_states[tool_index] = ToolCallState(
block_index=block_idx,
tool_id=tool_id,
name=name,
started=True,
)
return self.content_block_start(block_idx, "tool_use", id=tool_id, name=name)
def emit_tool_delta(self, tool_index: int, partial_json: str) -> str:
state = self.blocks.tool_states[tool_index]
state.contents.append(partial_json)
return self.content_block_delta(
state.block_index, "input_json_delta", partial_json
)
def stop_tool_block(self, tool_index: int) -> str:
block_idx = self.blocks.tool_states[tool_index].block_index
return self.content_block_stop(block_idx)
def ensure_thinking_block(self) -> Iterator[str]:
if self.blocks.text_started:
yield self.stop_text_block()
if not self.blocks.thinking_started:
yield self.start_thinking_block()
def ensure_text_block(self) -> Iterator[str]:
if self.blocks.thinking_started:
yield self.stop_thinking_block()
if not self.blocks.text_started:
yield self.start_text_block()
def close_content_blocks(self) -> Iterator[str]:
if self.blocks.thinking_started:
yield self.stop_thinking_block()
if self.blocks.text_started:
yield self.stop_text_block()
def close_all_blocks(self) -> Iterator[str]:
yield from self.close_content_blocks()
for tool_index, state in list(self.blocks.tool_states.items()):
if state.started:
yield self.stop_tool_block(tool_index)
def emit_error(self, error_message: str) -> Iterator[str]:
error_index = self.blocks.allocate_index()
yield self.content_block_start(error_index, "text")
yield self.content_block_delta(error_index, "text_delta", error_message)
yield self.content_block_stop(error_index)
def emit_top_level_error(self, error_message: str) -> str:
"""Emit a top-level ``event: error`` (not assistant text) for transport failures."""
return self._format_event(
"error",
{
"type": "error",
"error": {
"type": "api_error",
"message": error_message,
},
},
)
@property
def accumulated_text(self) -> str:
return "".join(self._accumulated_text_parts)
@property
def accumulated_reasoning(self) -> str:
return "".join(self._accumulated_reasoning_parts)
def estimate_output_tokens(self) -> int:
accumulated_text = self.accumulated_text
accumulated_reasoning = self.accumulated_reasoning
if ENCODER:
text_tokens = len(ENCODER.encode(accumulated_text))
reasoning_tokens = len(ENCODER.encode(accumulated_reasoning))
tool_tokens = 0
started_tool_count = 0
for state in self.blocks.tool_states.values():
tool_tokens += len(ENCODER.encode(state.name))
tool_tokens += len(ENCODER.encode("".join(state.contents)))
tool_tokens += 15
if state.started:
started_tool_count += 1
block_count = (
(1 if accumulated_reasoning else 0)
+ (1 if accumulated_text else 0)
+ started_tool_count
)
return text_tokens + reasoning_tokens + tool_tokens + (block_count * 4)
text_tokens = len(accumulated_text) // 4
reasoning_tokens = len(accumulated_reasoning) // 4
tool_tokens = (
sum(1 for state in self.blocks.tool_states.values() if state.started) * 50
)
return text_tokens + reasoning_tokens + tool_tokens