"""SSE parsing and Anthropic stream assertions for smoke tests.""" from __future__ import annotations import json from collections.abc import Iterable from dataclasses import dataclass from typing import Any @dataclass(frozen=True, slots=True) class SSEEvent: event: str data: dict[str, Any] raw: str def parse_sse_lines(lines: Iterable[str]) -> list[SSEEvent]: events: list[SSEEvent] = [] current_event = "" data_parts: list[str] = [] raw_parts: list[str] = [] for line in lines: stripped = line.rstrip("\r\n") if stripped == "": _append_event(events, current_event, data_parts, raw_parts) current_event = "" data_parts = [] raw_parts = [] continue raw_parts.append(stripped) if stripped.startswith("event:"): current_event = stripped.split(":", 1)[1].strip() elif stripped.startswith("data:"): data_parts.append(stripped.split(":", 1)[1].strip()) _append_event(events, current_event, data_parts, raw_parts) return events def parse_sse_text(text: str) -> list[SSEEvent]: return parse_sse_lines(text.splitlines()) def _append_event( events: list[SSEEvent], current_event: str, data_parts: list[str], raw_parts: list[str], ) -> None: if not current_event and not data_parts: return data_text = "\n".join(data_parts) data: dict[str, Any] try: parsed = json.loads(data_text) if data_text else {} data = parsed if isinstance(parsed, dict) else {"value": parsed} except json.JSONDecodeError: data = {"raw": data_text} events.append(SSEEvent(current_event, data, "\n".join(raw_parts))) def assert_anthropic_stream_contract( events: list[SSEEvent], *, allow_error: bool = False ) -> None: assert events, "stream produced no SSE events" event_names = [event.event for event in events] assert "message_start" in event_names, event_names assert event_names[-1] == "message_stop", event_names open_blocks: dict[int, str] = {} seen_blocks: set[int] = set() for event in events: if event.event == "error" and not allow_error: raise AssertionError(f"unexpected SSE error event: {event.data}") if event.event == "content_block_start": index = _event_index(event) block = event.data.get("content_block", {}) assert isinstance(block, dict), event.data block_type = str(block.get("type", "")) assert block_type in {"text", "thinking", "tool_use"}, event.data assert index not in open_blocks, f"block {index} started twice" assert index not in seen_blocks, f"block {index} reused after stop" open_blocks[index] = block_type seen_blocks.add(index) continue if event.event == "content_block_delta": index = _event_index(event) assert index in open_blocks, f"delta for unopened block {index}" delta = event.data.get("delta", {}) assert isinstance(delta, dict), event.data delta_type = str(delta.get("type", "")) expected = { "text": "text_delta", "thinking": "thinking_delta", "tool_use": "input_json_delta", }[open_blocks[index]] assert delta_type == expected, ( f"block {index} is {open_blocks[index]}, got {delta_type}" ) continue if event.event == "content_block_stop": index = _event_index(event) assert index in open_blocks, f"stop for unopened block {index}" open_blocks.pop(index) assert not open_blocks, f"unclosed blocks: {open_blocks}" assert seen_blocks, "stream did not emit any content blocks" def event_names(events: list[SSEEvent]) -> list[str]: return [event.event for event in events] def text_content(events: list[SSEEvent]) -> str: parts: list[str] = [] for event in events: delta = event.data.get("delta", {}) if isinstance(delta, dict) and delta.get("type") == "text_delta": parts.append(str(delta.get("text", ""))) return "".join(parts) def thinking_content(events: list[SSEEvent]) -> str: parts: list[str] = [] for event in events: delta = event.data.get("delta", {}) if isinstance(delta, dict) and delta.get("type") == "thinking_delta": parts.append(str(delta.get("thinking", ""))) return "".join(parts) def has_tool_use(events: list[SSEEvent]) -> bool: for event in events: block = event.data.get("content_block", {}) if isinstance(block, dict) and block.get("type") == "tool_use": return True return False def _event_index(event: SSEEvent) -> int: value = event.data.get("index") assert isinstance(value, int), event.data return value