diff --git a/.env.example b/.env.example
index 42d51a3..e9e1990 100644
--- a/.env.example
+++ b/.env.example
@@ -47,8 +47,8 @@ OPENROUTER_PROXY=""
LMSTUDIO_PROXY=""
LLAMACPP_PROXY=""
-PROVIDER_RATE_LIMIT=40
-PROVIDER_RATE_WINDOW=60
+PROVIDER_RATE_LIMIT=1
+PROVIDER_RATE_WINDOW=3
PROVIDER_MAX_CONCURRENCY=5
@@ -102,3 +102,25 @@ ENABLE_NETWORK_PROBE_MOCK=true
ENABLE_TITLE_GENERATION_SKIP=true
ENABLE_SUGGESTION_MODE_SKIP=true
ENABLE_FILEPATH_EXTRACTION_MOCK=true
+
+
+# Local Anthropic web_search / web_fetch handling (performs outbound HTTP; on by default)
+ENABLE_WEB_SERVER_TOOLS=true
+WEB_FETCH_ALLOWED_SCHEMES=http,https
+WEB_FETCH_ALLOW_PRIVATE_NETWORKS=false
+
+
+# Verbose diagnostics (avoid logging raw prompts / SSE bodies in production)
+DEBUG_PLATFORM_EDITS=false
+DEBUG_SUBAGENT_STACK=false
+# When true, also enables DEBUG-level httpx/httpcore/telegram log noise (not just payload logging).
+LOG_RAW_API_PAYLOADS=false
+LOG_RAW_SSE_EVENTS=false
+# When true, log full exception text and tracebacks for unhandled errors (may leak request-derived data).
+LOG_API_ERROR_TRACEBACKS=false
+# When true, log message/transcription text previews in messaging adapters (may leak user content).
+LOG_RAW_MESSAGING_CONTENT=false
+# When true, log full Claude CLI stderr, non-JSON stdout lines, and parser error text.
+LOG_RAW_CLI_DIAGNOSTICS=false
+# When true, log full exception and CLI error message strings in messaging (may leak user content).
+LOG_MESSAGING_ERROR_DETAILS=false
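+
+# Example (illustrative; local debugging only, never in production):
+#   LOG_RAW_API_PAYLOADS=true
+#   LOG_RAW_SSE_EVENTS=true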
diff --git a/AGENTS.md b/AGENTS.md
index f4335d2..265b3a8 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -32,13 +32,14 @@
- **Platform-agnostic naming**: Use generic names (e.g. `PLATFORM_EDIT`) not platform-specific ones (e.g. `TELEGRAM_EDIT`) in shared code.
- **No type ignores**: Do not add `# type: ignore` or `# ty: ignore`. Fix the underlying type issue.
- **Complete migrations**: When moving modules, update imports to the new owner and remove old compatibility shims in the same change unless preserving a published interface is explicitly required.
+- **Maximum test coverage**: Keep test coverage as high as possible for everything, preferably with live smoke tests, to catch bugs early.
## COGNITIVE WORKFLOW
1. **ANALYZE**: Read relevant files. Do not guess.
2. **PLAN**: Map out the logic. Identify root cause or required changes. Order changes by dependency.
3. **EXECUTE**: Fix the cause, not the symptom. Execute incrementally with clear commits.
-4. **VERIFY**: Run ci checks. Confirm the fix via logs or output.
+4. **VERIFY**: Run CI checks and relevant smoke tests. Confirm the fix via logs or output.
5. **SPECIFICITY**: Do exactly as much as asked; nothing more, nothing less.
6. **PROPAGATION**: Changes impact multiple files; propagate updates correctly.
diff --git a/PLAN.md b/PLAN.md
index 9f4607b..f5011f6 100644
--- a/PLAN.md
+++ b/PLAN.md
@@ -56,13 +56,20 @@ with **runtime composition** (e.g. `api.runtime` constructs `cli` and `messaging
**Contract highlights:** `api/` may import only `providers.base`, `providers.exceptions`,
and `providers.registry` from the providers package (not per-adapter modules).
`core/` stays free of `api`, `messaging`, `cli`, `providers`, `config`, and `smoke`.
-`messaging/` does not import `api`, `cli`, or `providers`. Neutral stream contract
-assertions for default CI live under `core/anthropic/stream_contracts.py`;
-`smoke.lib.sse` re-exports them for live smoke. Process-cached provider helpers
-(`api.dependencies.get_provider` / `get_provider_for_type`) exist for scripts and
-unit tests; production HTTP handlers must use `resolve_provider` with
+`messaging/` does not import `api`, `cli`, or `smoke`, and may import `providers`
+only via `providers.nvidia_nim.voice` (NVIDIA/Riva offline ASR). Stream contract
+helpers live in `core/anthropic/stream_contracts.py`; live smoke imports that
+module directly (no dedicated smoke SSE shim). NVIDIA NIM chat tuning uses the
+canonical `config.nim.NimSettings` model on `Settings`; `providers.registry`
+passes `settings.nim` into `NvidiaNimProvider` without a duplicate schema.
+Default upstream base URLs use a single constant per endpoint in
+`providers/defaults.py` (e.g. `NVIDIA_NIM_DEFAULT_BASE`). Process-cached provider
+helpers (`api.dependencies.get_provider` / `get_provider_for_type`) exist for
+scripts and unit tests; production HTTP handlers must use `resolve_provider` with
`request.app` so the app-scoped `ProviderRegistry` is used. The `api` package
-`__all__` exposes HTTP models and `create_app` only (not those helpers).
+`__all__` exposes HTTP models and `create_app` only (not `app`, not those helpers).
+`api.app:create_app` is the ASGI factory (e.g. `uvicorn api.app:create_app --factory`);
+`server.py` still exposes `server:app` as a module-level instance for convenience.
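+A minimal local run (illustrative): `uvicorn api.app:create_app --factory --reload`;
+the factory configures logging from settings before the app is built.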
## Target Boundaries
diff --git a/api/__init__.py b/api/__init__.py
index 03094ba..5f6a6e4 100644
--- a/api/__init__.py
+++ b/api/__init__.py
@@ -1,6 +1,6 @@
"""API layer for Claude Code Proxy."""
-from .app import app, create_app
+from .app import create_app
from .models import (
MessagesRequest,
MessagesResponse,
@@ -13,6 +13,5 @@ __all__ = [
"MessagesResponse",
"TokenCountRequest",
"TokenCountResponse",
- "app",
"create_app",
]
diff --git a/api/app.py b/api/app.py
index b47deb3..5b70133 100644
--- a/api/app.py
+++ b/api/app.py
@@ -1,6 +1,6 @@
"""FastAPI application factory and configuration."""
-import os
+import traceback
from contextlib import asynccontextmanager
from typing import Any
@@ -16,13 +16,7 @@ from providers.exceptions import ProviderError
from .routes import router
from .runtime import AppRuntime
-
-# Opt-in to future behavior for python-telegram-bot
-os.environ["PTB_TIMEDELTA"] = "1"
-
-# Configure logging first (before any module logs)
-_settings = get_settings()
-configure_logging(_settings.log_file)
+from .validation_log import summarize_request_validation_body
@asynccontextmanager
@@ -38,6 +32,11 @@ async def lifespan(app: FastAPI):
def create_app() -> FastAPI:
"""Create and configure the FastAPI application."""
+ settings = get_settings()
+ configure_logging(
+ settings.log_file, verbose_third_party=settings.log_raw_api_payloads
+ )
+
app = FastAPI(
title="Claude Code Proxy",
version="2.0.0",
@@ -57,33 +56,7 @@ def create_app() -> FastAPI:
except Exception as e:
body = {"_json_error": type(e).__name__}
- messages = body.get("messages") if isinstance(body, dict) else None
- message_summary: list[dict[str, Any]] = []
- if isinstance(messages, list):
- for msg in messages:
- if not isinstance(msg, dict):
- message_summary.append({"message_kind": type(msg).__name__})
- continue
- content = msg.get("content")
- item: dict[str, Any] = {
- "role": msg.get("role"),
- "content_kind": type(content).__name__,
- }
- if isinstance(content, list):
- item["block_types"] = [
- block.get("type", "dict")
- if isinstance(block, dict)
- else type(block).__name__
- for block in content[:12]
- ]
- item["block_keys"] = [
- sorted(str(key) for key in block)[:12]
- for block in content[:5]
- if isinstance(block, dict)
- ]
- elif isinstance(content, str):
- item["content_length"] = len(content)
- message_summary.append(item)
+ message_summary, tool_names = summarize_request_validation_body(body)
logger.debug(
"Request validation failed: path={} query={} error_locs={} error_types={} message_summary={} tool_names={}",
@@ -92,20 +65,27 @@ def create_app() -> FastAPI:
[list(error.get("loc", ())) for error in exc.errors()],
[str(error.get("type", "")) for error in exc.errors()],
message_summary,
- [
- str(tool.get("name", ""))
- for tool in body.get("tools", [])
- if isinstance(body, dict)
- and isinstance(body.get("tools"), list)
- and isinstance(tool, dict)
- ],
+ tool_names,
)
return await request_validation_exception_handler(request, exc)
@app.exception_handler(ProviderError)
async def provider_error_handler(request: Request, exc: ProviderError):
"""Handle provider-specific errors and return Anthropic format."""
- logger.error(f"Provider Error: {exc.error_type} - {exc.message}")
+ err_settings = get_settings()
+ if err_settings.log_api_error_tracebacks:
+ logger.error(
+ "Provider Error: error_type={} status_code={} message={}",
+ exc.error_type,
+ exc.status_code,
+ exc.message,
+ )
+ else:
+ logger.error(
+ "Provider Error: error_type={} status_code={}",
+ exc.error_type,
+ exc.status_code,
+ )
return JSONResponse(
status_code=exc.status_code,
content=exc.to_anthropic_format(),
@@ -114,10 +94,17 @@ def create_app() -> FastAPI:
@app.exception_handler(Exception)
async def general_error_handler(request: Request, exc: Exception):
"""Handle general errors and return Anthropic format."""
- logger.error(f"General Error: {exc!s}")
- import traceback
-
- logger.error(traceback.format_exc())
+ settings = get_settings()
+ if settings.log_api_error_tracebacks:
+ logger.error("General Error: {}", exc)
+ logger.error(traceback.format_exc())
+ else:
+ logger.error(
+ "General Error: path={} method={} exc_type={}",
+ request.url.path,
+ request.method,
+ type(exc).__name__,
+ )
return JSONResponse(
status_code=500,
content={
@@ -130,7 +117,3 @@ def create_app() -> FastAPI:
)
return app
-
-
-# Default app instance for uvicorn
-app = create_app()
diff --git a/api/command_utils.py b/api/command_utils.py
index ea5251b..76963cd 100644
--- a/api/command_utils.py
+++ b/api/command_utils.py
@@ -135,5 +135,5 @@ def extract_filepaths_from_command(command: str, output: str) -> str:
return "\n"
- except Exception:
+ except ValueError:
return "\n"
diff --git a/api/dependencies.py b/api/dependencies.py
index 3527eb8..c3e5743 100644
--- a/api/dependencies.py
+++ b/api/dependencies.py
@@ -8,7 +8,11 @@ from config.settings import Settings
from config.settings import get_settings as _get_settings
from core.anthropic import get_user_facing_error_message
from providers.base import BaseProvider
-from providers.exceptions import AuthenticationError, UnknownProviderTypeError
+from providers.exceptions import (
+ AuthenticationError,
+ ServiceUnavailableError,
+ UnknownProviderTypeError,
+)
from providers.registry import PROVIDER_DESCRIPTORS, ProviderRegistry
# Process-level cache: only for :func:`get_provider_for_type` / :func:`get_provider`
@@ -18,7 +22,7 @@ _providers: dict[str, BaseProvider] = {}
def get_settings() -> Settings:
- """Get application settings via dependency injection."""
+ """Return cached :class:`~config.settings.Settings` (FastAPI-friendly alias)."""
return _get_settings()
@@ -31,10 +35,9 @@ def resolve_provider(
"""Resolve a provider using the app-scoped registry when ``app`` is set.
When ``app`` is not ``None``, the app-owned :attr:`app.state.provider_registry`
- is always used. If the registry is missing (e.g. a test app without
- :class:`~api.runtime.AppRuntime` startup), a new :class:`ProviderRegistry`
- is installed on ``app.state`` so the process cache is never mixed with
- per-request app identity.
+ must exist (installed by :class:`~api.runtime.AppRuntime` during startup).
+ Callers that construct a bare ``FastAPI`` without lifespan must set
+ ``app.state.provider_registry`` explicitly.
When ``app`` is ``None`` (no HTTP context), uses the process-level
:data:`_providers` cache only.
@@ -42,8 +45,10 @@ def resolve_provider(
if app is not None:
reg = getattr(app.state, "provider_registry", None)
if reg is None:
- reg = ProviderRegistry()
- app.state.provider_registry = reg
+ raise ServiceUnavailableError(
+ "Provider registry is not configured. Ensure AppRuntime startup ran "
+ "or assign app.state.provider_registry for test apps."
+ )
return _resolve_with_registry(reg, provider_type, settings)
return _resolve_with_registry(ProviderRegistry(_providers), provider_type, settings)
@@ -55,9 +60,10 @@ def _resolve_with_registry(
try:
provider = registry.get(provider_type, settings)
except AuthenticationError as e:
- raise HTTPException(
- status_code=503, detail=get_user_facing_error_message(e)
- ) from e
+ # Provider :class:`~providers.exceptions.AuthenticationError` messages are
+ # curated configuration hints (env var names, docs links), not upstream noise.
+ detail = str(e).strip() or get_user_facing_error_message(e)
+ raise HTTPException(status_code=503, detail=detail) from e
except UnknownProviderTypeError:
logger.error(
"Unknown provider_type: '{}'. Supported: {}",
@@ -73,8 +79,9 @@ def _resolve_with_registry(
def get_provider_for_type(provider_type: str) -> BaseProvider:
"""Get or create a provider in the process-level cache (no ``app``/Request).
- For server requests, use :func:`resolve_provider` with the active
- :attr:`request.app` so the app-scoped provider registry is used.
+ HTTP route handlers should call :func:`resolve_provider` with the active
+ :attr:`request.app` (via :class:`~api.runtime.AppRuntime`) instead of this
+ process-wide cache.
"""
return resolve_provider(provider_type, app=None, settings=get_settings())
diff --git a/api/detection.py b/api/detection.py
index a299151..7977c1b 100644
--- a/api/detection.py
+++ b/api/detection.py
@@ -65,8 +65,8 @@ def is_prefix_detection_request(request_data: MessagesRequest) -> tuple[bool, st
try:
cmd_start = content.rfind("Command:") + len("Command:")
return True, content[cmd_start:].strip()
- except Exception:
- pass
+ except TypeError:
+ return False, ""
return False, ""
@@ -121,19 +121,16 @@ def is_filepath_extraction_request(
if not user_has_filepaths and not system_has_extract:
return False, "", ""
- try:
- cmd_start = content.find("Command:") + len("Command:")
- output_marker = content.find("Output:", cmd_start)
- if output_marker == -1:
- return False, "", ""
-
- command = content[cmd_start:output_marker].strip()
- output = content[output_marker + len("Output:") :].strip()
-
- for marker in ["<", "\n\n"]:
- if marker in output:
- output = output.split(marker)[0].strip()
-
- return True, command, output
- except Exception:
+ cmd_start = content.find("Command:") + len("Command:")
+ output_marker = content.find("Output:", cmd_start)
+ if output_marker == -1:
return False, "", ""
+
+ command = content[cmd_start:output_marker].strip()
+ output = content[output_marker + len("Output:") :].strip()
+
+ for marker in ["<", "\n\n"]:
+ if marker in output:
+ output = output.split(marker)[0].strip()
+
+ return True, command, output
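+
+
+# Illustrative parse: content "Command: ls src\nOutput: a.py\nb.py\n\n<system>"
+# yields (True, "ls src", "a.py\nb.py"); the "<" and blank-line markers trim trailing noise.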
diff --git a/api/models/anthropic.py b/api/models/anthropic.py
index 12b6533..5bda060 100644
--- a/api/models/anthropic.py
+++ b/api/models/anthropic.py
@@ -3,7 +3,7 @@
from enum import StrEnum
from typing import Any, Literal
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict, Field
# =============================================================================
@@ -15,41 +15,68 @@ class Role(StrEnum):
system = "system"
-class ContentBlockText(BaseModel):
+class _AnthropicBlockBase(BaseModel):
+ """Pass through provider fields (e.g. ``cache_control``) for native transports."""
+
+ model_config = ConfigDict(extra="allow")
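+
+
+# Illustrative: extra="allow" keeps unknown provider fields on parsed blocks, e.g.
+#   ContentBlockText(type="text", text="hi", cache_control={"type": "ephemeral"})
+# retains cache_control (via model_extra) and includes it in model_dump().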
+
+
+class ContentBlockText(_AnthropicBlockBase):
type: Literal["text"]
text: str
-class ContentBlockImage(BaseModel):
+class ContentBlockImage(_AnthropicBlockBase):
type: Literal["image"]
source: dict[str, Any]
-class ContentBlockToolUse(BaseModel):
+class ContentBlockToolUse(_AnthropicBlockBase):
type: Literal["tool_use"]
id: str
name: str
input: dict[str, Any]
-class ContentBlockToolResult(BaseModel):
+class ContentBlockToolResult(_AnthropicBlockBase):
type: Literal["tool_result"]
tool_use_id: str
content: str | list[Any] | dict[str, Any]
-class ContentBlockThinking(BaseModel):
+class ContentBlockThinking(_AnthropicBlockBase):
type: Literal["thinking"]
thinking: str
signature: str | None = None
-class ContentBlockRedactedThinking(BaseModel):
+class ContentBlockRedactedThinking(_AnthropicBlockBase):
type: Literal["redacted_thinking"]
data: str
-class SystemContent(BaseModel):
+class ContentBlockServerToolUse(_AnthropicBlockBase):
+ """Anthropic server-side tool invocation (e.g. ``web_search``, ``web_fetch``)."""
+
+ type: Literal["server_tool_use"]
+ id: str
+ name: str
+ input: dict[str, Any]
+
+
+class ContentBlockWebSearchToolResult(_AnthropicBlockBase):
+ type: Literal["web_search_tool_result"]
+ tool_use_id: str
+ content: Any
+
+
+class ContentBlockWebFetchToolResult(_AnthropicBlockBase):
+ type: Literal["web_fetch_tool_result"]
+ tool_use_id: str
+ content: Any
+
+
+class SystemContent(_AnthropicBlockBase):
type: Literal["text"]
text: str
@@ -68,12 +95,15 @@ class Message(BaseModel):
| ContentBlockToolResult
| ContentBlockThinking
| ContentBlockRedactedThinking
+ | ContentBlockServerToolUse
+ | ContentBlockWebSearchToolResult
+ | ContentBlockWebFetchToolResult
]
)
reasoning_content: str | None = None
-class Tool(BaseModel):
+class Tool(_AnthropicBlockBase):
name: str
# Anthropic server tools (e.g. web_search beta tools) include a ``type`` and
# may omit ``input_schema`` because the provider owns the schema.
@@ -92,7 +122,12 @@ class ThinkingConfig(BaseModel):
# Request Models
# =============================================================================
class MessagesRequest(BaseModel):
+ model_config = ConfigDict(extra="allow")
+
model: str
+ # Internal routing / debug: accepted on parse but not serialized to providers.
+ original_model: str | None = Field(default=None, exclude=True)
+ resolved_provider_model: str | None = Field(default=None, exclude=True)
max_tokens: int | None = None
messages: list[Message]
system: str | list[SystemContent] | None = None
@@ -105,13 +140,24 @@ class MessagesRequest(BaseModel):
tools: list[Tool] | None = None
tool_choice: dict[str, Any] | None = None
thinking: ThinkingConfig | None = None
+ # Native Anthropic / SDK client hints: ignored (not forwarded) for OpenAI Chat conversion.
+ context_management: dict[str, Any] | None = None
+ output_config: dict[str, Any] | None = None
+ mcp_servers: list[dict[str, Any]] | None = None
extra_body: dict[str, Any] | None = None
class TokenCountRequest(BaseModel):
+ model_config = ConfigDict(extra="allow")
+
model: str
+ original_model: str | None = Field(default=None, exclude=True)
+ resolved_provider_model: str | None = Field(default=None, exclude=True)
messages: list[Message]
system: str | list[SystemContent] | None = None
tools: list[Tool] | None = None
thinking: ThinkingConfig | None = None
tool_choice: dict[str, Any] | None = None
+ context_management: dict[str, Any] | None = None
+ output_config: dict[str, Any] | None = None
+ mcp_servers: list[dict[str, Any]] | None = None
diff --git a/api/runtime.py b/api/runtime.py
index 56cfe07..56df000 100644
--- a/api/runtime.py
+++ b/api/runtime.py
@@ -23,15 +23,31 @@ _SHUTDOWN_TIMEOUT_S = 5.0
async def best_effort(
- name: str, awaitable: Any, timeout_s: float = _SHUTDOWN_TIMEOUT_S
+ name: str,
+ awaitable: Any,
+ timeout_s: float = _SHUTDOWN_TIMEOUT_S,
+ *,
+ log_verbose_errors: bool = False,
) -> None:
"""Run a shutdown step with timeout; never raise to callers."""
try:
await asyncio.wait_for(awaitable, timeout=timeout_s)
except TimeoutError:
- logger.warning(f"Shutdown step timed out: {name} ({timeout_s}s)")
+ logger.warning("Shutdown step timed out: {} ({}s)", name, timeout_s)
except Exception as e:
- logger.warning(f"Shutdown step failed: {name}: {type(e).__name__}: {e}")
+ if log_verbose_errors:
+ logger.warning(
+ "Shutdown step failed: {}: {}: {}",
+ name,
+ type(e).__name__,
+ e,
+ )
+ else:
+ logger.warning(
+ "Shutdown step failed: {}: exc_type={}",
+ name,
+ type(e).__name__,
+ )
def warn_if_process_auth_token(settings: Settings) -> None:
@@ -73,20 +89,37 @@ class AppRuntime:
self._publish_state()
async def shutdown(self) -> None:
+ verbose = self.settings.log_api_error_tracebacks
if self.message_handler is not None:
try:
self.message_handler.session_store.flush_pending_save()
except Exception as e:
- logger.warning(f"Session store flush on shutdown: {e}")
+ if verbose:
+ logger.warning("Session store flush on shutdown: {}", e)
+ else:
+ logger.warning(
+ "Session store flush on shutdown: exc_type={}",
+ type(e).__name__,
+ )
logger.info("Shutdown requested, cleaning up...")
if self.messaging_platform:
- await best_effort("messaging_platform.stop", self.messaging_platform.stop())
+ await best_effort(
+ "messaging_platform.stop",
+ self.messaging_platform.stop(),
+ log_verbose_errors=verbose,
+ )
if self.cli_manager:
- await best_effort("cli_manager.stop_all", self.cli_manager.stop_all())
+ await best_effort(
+ "cli_manager.stop_all",
+ self.cli_manager.stop_all(),
+ log_verbose_errors=verbose,
+ )
if self._provider_registry is not None:
await best_effort(
- "provider_registry.cleanup", self._provider_registry.cleanup()
+ "provider_registry.cleanup",
+ self._provider_registry.cleanup(),
+ log_verbose_errors=verbose,
)
await self._shutdown_limiter()
logger.info("Server shut down cleanly")
@@ -110,6 +143,10 @@ class AppRuntime:
whisper_device=self.settings.whisper_device,
hf_token=self.settings.hf_token,
nvidia_nim_api_key=self.settings.nvidia_nim_api_key,
+ messaging_rate_limit=self.settings.messaging_rate_limit,
+ messaging_rate_window=self.settings.messaging_rate_window,
+ log_raw_messaging_content=self.settings.log_raw_messaging_content,
+ log_api_error_tracebacks=self.settings.log_api_error_tracebacks,
),
)
@@ -117,12 +154,24 @@ class AppRuntime:
await self._start_message_handler()
except ImportError as e:
- logger.warning(f"Messaging module import error: {e}")
+ if self.settings.log_api_error_tracebacks:
+ logger.warning("Messaging module import error: {}", e)
+ else:
+ logger.warning(
+ "Messaging module import error: exc_type={}",
+ type(e).__name__,
+ )
except Exception as e:
- logger.error(f"Failed to start messaging platform: {e}")
- import traceback
+ if self.settings.log_api_error_tracebacks:
+ logger.error("Failed to start messaging platform: {}", e)
+ import traceback
- logger.error(traceback.format_exc())
+ logger.error(traceback.format_exc())
+ else:
+ logger.error(
+ "Failed to start messaging platform: exc_type={}",
+ type(e).__name__,
+ )
async def _start_message_handler(self) -> None:
from cli.manager import CLISessionManager
@@ -151,10 +200,13 @@ class AppRuntime:
allowed_dirs=allowed_dirs,
plans_directory=plans_directory,
claude_bin=self.settings.claude_cli_bin,
+ log_raw_cli_diagnostics=self.settings.log_raw_cli_diagnostics,
+ log_messaging_error_details=self.settings.log_messaging_error_details,
)
session_store = SessionStore(
- storage_path=os.path.join(data_path, "sessions.json")
+ storage_path=os.path.join(data_path, "sessions.json"),
+ message_log_cap=self.settings.max_message_log_entries_per_chat,
)
platform = self.messaging_platform
assert platform is not None
@@ -162,6 +214,11 @@ class AppRuntime:
platform=platform,
cli_manager=self.cli_manager,
session_store=session_store,
+ debug_platform_edits=self.settings.debug_platform_edits,
+ debug_subagent_stack=self.settings.debug_subagent_stack,
+ log_raw_messaging_content=self.settings.log_raw_messaging_content,
+ log_raw_cli_diagnostics=self.settings.log_raw_cli_diagnostics,
+ log_messaging_error_details=self.settings.log_messaging_error_details,
)
self._restore_tree_state(session_store)
@@ -201,18 +258,26 @@ class AppRuntime:
self.app.state.cli_manager = self.cli_manager
async def _shutdown_limiter(self) -> None:
+ verbose = self.settings.log_api_error_tracebacks
try:
from messaging.limiter import MessagingRateLimiter
except Exception as e:
- logger.debug(
- "Rate limiter shutdown skipped (import failed): {}: {}",
- type(e).__name__,
- e,
- )
+ if verbose:
+ logger.debug(
+ "Rate limiter shutdown skipped (import failed): {}: {}",
+ type(e).__name__,
+ e,
+ )
+ else:
+ logger.debug(
+ "Rate limiter shutdown skipped (import failed): exc_type={}",
+ type(e).__name__,
+ )
return
await best_effort(
"MessagingRateLimiter.shutdown_instance",
MessagingRateLimiter.shutdown_instance(),
timeout_s=2.0,
+ log_verbose_errors=verbose,
)
diff --git a/api/services.py b/api/services.py
index d0bbcaa..219287c 100644
--- a/api/services.py
+++ b/api/services.py
@@ -4,7 +4,7 @@ from __future__ import annotations
import traceback
import uuid
-from collections.abc import Callable
+from collections.abc import AsyncIterator, Callable
from typing import Any
from fastapi import HTTPException
@@ -13,6 +13,7 @@ from loguru import logger
from config.settings import Settings
from core.anthropic import get_token_count, get_user_facing_error_message
+from core.anthropic.sse import ANTHROPIC_SSE_RESPONSE_HEADERS
from providers.base import BaseProvider
from providers.exceptions import InvalidRequestError, ProviderError
@@ -20,15 +21,67 @@ from .model_router import ModelRouter
from .models.anthropic import MessagesRequest, TokenCountRequest
from .models.responses import TokenCountResponse
from .optimization_handlers import try_optimizations
-from .web_server_tools import (
+from .web_tools.egress import WebFetchEgressPolicy
+from .web_tools.request import (
is_web_server_tool_request,
- stream_web_server_tool_response,
+ openai_chat_upstream_server_tool_error,
)
+from .web_tools.streaming import stream_web_server_tool_response
TokenCounter = Callable[[list[Any], str | list[Any] | None, list[Any] | None], int]
ProviderGetter = Callable[[str], BaseProvider]
+# Providers that use ``/chat/completions`` + Anthropic-to-OpenAI conversion (not native Messages).
+_OPENAI_CHAT_UPSTREAM_IDS = frozenset({"nvidia_nim", "deepseek"})
+
+
+def anthropic_sse_streaming_response(
+ body: AsyncIterator[str],
+) -> StreamingResponse:
+ """Return a :class:`StreamingResponse` for Anthropic-style SSE streams."""
+ return StreamingResponse(
+ body,
+ media_type="text/event-stream",
+ headers=ANTHROPIC_SSE_RESPONSE_HEADERS,
+ )
+
+
+def _http_status_for_unexpected_service_exception(_exc: BaseException) -> int:
+ """HTTP status for uncaught non-provider failures (stable client contract)."""
+ return 500
+
+
+def _log_unexpected_service_exception(
+ settings: Settings,
+ exc: BaseException,
+ *,
+ context: str,
+ request_id: str | None = None,
+) -> None:
+ """Log service-layer failures without echoing exception text unless opted in."""
+ if settings.log_api_error_tracebacks:
+ if request_id is not None:
+ logger.error("{} request_id={}: {}", context, request_id, exc)
+ else:
+ logger.error("{}: {}", context, exc)
+ logger.error(traceback.format_exc())
+ return
+ if request_id is not None:
+ logger.error(
+ "{} request_id={} exc_type={}",
+ context,
+ request_id,
+ type(exc).__name__,
+ )
+ else:
+ logger.error("{} exc_type={}", context, type(exc).__name__)
+
+
+def _require_non_empty_messages(messages: list[Any]) -> None:
+ if not messages:
+ raise InvalidRequestError("messages cannot be empty")
+
class ClaudeProxyService:
"""Coordinate request optimization, model routing, token count, and providers."""
@@ -48,25 +101,35 @@ class ClaudeProxyService:
def create_message(self, request_data: MessagesRequest) -> object:
"""Create a message response or streaming response."""
try:
- if not request_data.messages:
- raise InvalidRequestError("messages cannot be empty")
+ _require_non_empty_messages(request_data.messages)
routed = self._model_router.resolve_messages_request(request_data)
- if is_web_server_tool_request(routed.request):
+ if routed.resolved.provider_id in _OPENAI_CHAT_UPSTREAM_IDS:
+ tool_err = openai_chat_upstream_server_tool_error(
+ routed.request,
+ web_tools_enabled=self._settings.enable_web_server_tools,
+ )
+ if tool_err is not None:
+ raise InvalidRequestError(tool_err)
+
+ if self._settings.enable_web_server_tools and is_web_server_tool_request(
+ routed.request
+ ):
input_tokens = self._token_counter(
routed.request.messages, routed.request.system, routed.request.tools
)
logger.info("Optimization: Handling Anthropic web server tool")
- return StreamingResponse(
+ egress = WebFetchEgressPolicy(
+ allow_private_network_targets=self._settings.web_fetch_allow_private_networks,
+ allowed_schemes=self._settings.web_fetch_allowed_scheme_set(),
+ )
+ return anthropic_sse_streaming_response(
stream_web_server_tool_response(
- routed.request, input_tokens=input_tokens
+ routed.request,
+ input_tokens=input_tokens,
+ web_fetch_egress=egress,
+ verbose_client_errors=self._settings.log_api_error_tracebacks,
),
- media_type="text/event-stream",
- headers={
- "X-Accel-Buffering": "no",
- "Cache-Control": "no-cache",
- "Connection": "keep-alive",
- },
)
optimized = try_optimizations(routed.request, self._settings)
@@ -75,6 +138,10 @@ class ClaudeProxyService:
logger.debug("No optimization matched, routing to provider")
provider = self._provider_getter(routed.resolved.provider_id)
+ provider.preflight_stream(
+ routed.request,
+ thinking_enabled=routed.resolved.thinking_enabled,
+ )
request_id = f"req_{uuid.uuid4().hex[:12]}"
logger.info(
@@ -83,34 +150,31 @@ class ClaudeProxyService:
routed.request.model,
len(routed.request.messages),
)
- logger.debug(
- "FULL_PAYLOAD [{}]: {}", request_id, routed.request.model_dump()
- )
+ if self._settings.log_raw_api_payloads:
+ logger.debug(
+ "FULL_PAYLOAD [{}]: {}", request_id, routed.request.model_dump()
+ )
input_tokens = self._token_counter(
routed.request.messages, routed.request.system, routed.request.tools
)
- return StreamingResponse(
+ return anthropic_sse_streaming_response(
provider.stream_response(
routed.request,
input_tokens=input_tokens,
request_id=request_id,
thinking_enabled=routed.resolved.thinking_enabled,
),
- media_type="text/event-stream",
- headers={
- "X-Accel-Buffering": "no",
- "Cache-Control": "no-cache",
- "Connection": "keep-alive",
- },
)
except ProviderError:
raise
except Exception as e:
- logger.error(f"Error: {e!s}\n{traceback.format_exc()}")
+ _log_unexpected_service_exception(
+ self._settings, e, context="CREATE_MESSAGE_ERROR"
+ )
raise HTTPException(
- status_code=getattr(e, "status_code", 500),
+ status_code=_http_status_for_unexpected_service_exception(e),
detail=get_user_facing_error_message(e),
) from e
@@ -119,6 +183,7 @@ class ClaudeProxyService:
request_id = f"req_{uuid.uuid4().hex[:12]}"
with logger.contextualize(request_id=request_id):
try:
+ _require_non_empty_messages(request_data.messages)
routed = self._model_router.resolve_token_count_request(request_data)
tokens = self._token_counter(
routed.request.messages, routed.request.system, routed.request.tools
@@ -131,13 +196,16 @@ class ClaudeProxyService:
tokens,
)
return TokenCountResponse(input_tokens=tokens)
+ except ProviderError:
+ raise
except Exception as e:
- logger.error(
- "COUNT_TOKENS_ERROR: request_id={} error={}\n{}",
- request_id,
- get_user_facing_error_message(e),
- traceback.format_exc(),
+ _log_unexpected_service_exception(
+ self._settings,
+ e,
+ context="COUNT_TOKENS_ERROR",
+ request_id=request_id,
)
raise HTTPException(
- status_code=500, detail=get_user_facing_error_message(e)
+ status_code=_http_status_for_unexpected_service_exception(e),
+ detail=get_user_facing_error_message(e),
) from e
diff --git a/api/validation_log.py b/api/validation_log.py
new file mode 100644
index 0000000..9ccdff0
--- /dev/null
+++ b/api/validation_log.py
@@ -0,0 +1,48 @@
+"""Safe metadata summaries for HTTP 422 validation logging (no raw text content)."""
+
+from __future__ import annotations
+
+from typing import Any
+
+
+def summarize_request_validation_body(
+ body: Any,
+) -> tuple[list[dict[str, Any]], list[str]]:
+ """Return message shape summary and tool name list for debug logs."""
+ messages = body.get("messages") if isinstance(body, dict) else None
+ message_summary: list[dict[str, Any]] = []
+ if isinstance(messages, list):
+ for msg in messages:
+ if not isinstance(msg, dict):
+ message_summary.append({"message_kind": type(msg).__name__})
+ continue
+ content = msg.get("content")
+ item: dict[str, Any] = {
+ "role": msg.get("role"),
+ "content_kind": type(content).__name__,
+ }
+ if isinstance(content, list):
+ item["block_types"] = [
+ block.get("type", "dict")
+ if isinstance(block, dict)
+ else type(block).__name__
+ for block in content[:12]
+ ]
+ item["block_keys"] = [
+ sorted(str(key) for key in block)[:12]
+ for block in content[:5]
+ if isinstance(block, dict)
+ ]
+ elif isinstance(content, str):
+ item["content_length"] = len(content)
+ message_summary.append(item)
+
+ tool_names: list[str] = []
+ if isinstance(body, dict) and isinstance(body.get("tools"), list):
+ tool_names = [
+ str(tool.get("name", ""))
+ for tool in body["tools"]
+ if isinstance(tool, dict)
+ ]
+
+ return message_summary, tool_names
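+
+
+# Shape sketch (illustrative):
+#   body = {"messages": [{"role": "user", "content": "hi"}], "tools": [{"name": "bash"}]}
+#   summarize_request_validation_body(body)
+#   -> ([{"role": "user", "content_kind": "str", "content_length": 2}], ["bash"])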
diff --git a/api/web_server_tools.py b/api/web_server_tools.py
index 5daced4..cedaf95 100644
--- a/api/web_server_tools.py
+++ b/api/web_server_tools.py
@@ -1,331 +1,22 @@
-"""Local handlers for Anthropic web server tools.
-
-OpenAI-compatible upstreams can emit regular function calls, but Anthropic's
-web tools are server-side: the API response itself must include the tool result.
-"""
+"""Compatibility re-exports for :mod:`api.web_tools` (web_search / web_fetch)."""
from __future__ import annotations
-import html
-import json
-import re
-import uuid
-from collections.abc import AsyncIterator
-from datetime import UTC, datetime
-from html.parser import HTMLParser
-from typing import Any
-from urllib.parse import parse_qs, unquote, urlparse
-
import httpx
-from .models.anthropic import MessagesRequest
+from api.web_tools.egress import (
+ WebFetchEgressPolicy,
+ WebFetchEgressViolation,
+ enforce_web_fetch_egress,
+)
+from api.web_tools.request import is_web_server_tool_request
+from api.web_tools.streaming import stream_web_server_tool_response
-_REQUEST_TIMEOUT_S = 20.0
-_MAX_SEARCH_RESULTS = 10
-_MAX_FETCH_CHARS = 24_000
-
-
-class _SearchResultParser(HTMLParser):
- def __init__(self) -> None:
- super().__init__()
- self.results: list[dict[str, str]] = []
- self._href: str | None = None
- self._title_parts: list[str] = []
-
- def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
- if tag != "a":
- return
- href = dict(attrs).get("href")
- if not href or "uddg=" not in href:
- return
- parsed = urlparse(href)
- query = parse_qs(parsed.query)
- uddg = query.get("uddg", [""])[0]
- if not uddg:
- return
- self._href = unquote(uddg)
- self._title_parts = []
-
- def handle_data(self, data: str) -> None:
- if self._href is not None:
- self._title_parts.append(data)
-
- def handle_endtag(self, tag: str) -> None:
- if tag != "a" or self._href is None:
- return
- title = " ".join("".join(self._title_parts).split())
- if title and not any(result["url"] == self._href for result in self.results):
- self.results.append({"title": html.unescape(title), "url": self._href})
- self._href = None
- self._title_parts = []
-
-
-class _HTMLTextParser(HTMLParser):
- def __init__(self) -> None:
- super().__init__()
- self.title = ""
- self.text_parts: list[str] = []
- self._in_title = False
- self._skip_depth = 0
-
- def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
- if tag in {"script", "style", "noscript"}:
- self._skip_depth += 1
- elif tag == "title":
- self._in_title = True
-
- def handle_endtag(self, tag: str) -> None:
- if tag in {"script", "style", "noscript"} and self._skip_depth:
- self._skip_depth -= 1
- elif tag == "title":
- self._in_title = False
-
- def handle_data(self, data: str) -> None:
- text = " ".join(data.split())
- if not text:
- return
- if self._in_title:
- self.title = f"{self.title} {text}".strip()
- elif not self._skip_depth:
- self.text_parts.append(text)
-
-
-def _format_event(event_type: str, data: dict[str, Any]) -> str:
- return f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
-
-
-def _content_text(content: Any) -> str:
- if isinstance(content, str):
- return content
- if isinstance(content, list):
- parts = []
- for item in content:
- if isinstance(item, dict):
- parts.append(str(item.get("text", "")))
- else:
- parts.append(str(getattr(item, "text", "")))
- return "\n".join(part for part in parts if part)
- return str(content)
-
-
-def _request_text(request: MessagesRequest) -> str:
- return "\n".join(_content_text(message.content) for message in request.messages)
-
-
-def _web_tool_name(request: MessagesRequest) -> str | None:
- for tool in request.tools or []:
- name = tool.name
- tool_type = tool.type or ""
- if name in {"web_search", "web_fetch"} or tool_type.startswith("web_"):
- return name
- return None
-
-
-def is_web_server_tool_request(request: MessagesRequest) -> bool:
- return _web_tool_name(request) in {"web_search", "web_fetch"}
-
-
-def _extract_query(text: str) -> str:
- match = re.search(r"query:\s*(.+)", text, flags=re.IGNORECASE | re.DOTALL)
- if match:
- return match.group(1).strip().strip("\"'")
- return text.strip()
-
-
-def _extract_url(text: str) -> str:
- match = re.search(r"https?://\S+", text)
- return match.group(0).rstrip(").,]") if match else text.strip()
-
-
-async def _run_web_search(query: str) -> list[dict[str, str]]:
- async with httpx.AsyncClient(
- timeout=_REQUEST_TIMEOUT_S,
- follow_redirects=True,
- headers={"User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0"},
- ) as client:
- response = await client.get(
- "https://lite.duckduckgo.com/lite/",
- params={"q": query},
- )
- response.raise_for_status()
-
- parser = _SearchResultParser()
- parser.feed(response.text)
- return parser.results[:_MAX_SEARCH_RESULTS]
-
-
-async def _run_web_fetch(url: str) -> dict[str, str]:
- async with httpx.AsyncClient(
- timeout=_REQUEST_TIMEOUT_S,
- follow_redirects=True,
- headers={"User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0"},
- ) as client:
- response = await client.get(url)
- response.raise_for_status()
-
- content_type = response.headers.get("content-type", "text/plain")
- title = url
- data = response.text
- if "html" in content_type.lower():
- parser = _HTMLTextParser()
- parser.feed(response.text)
- title = parser.title or url
- data = "\n".join(parser.text_parts)
- return {
- "url": str(response.url),
- "title": title,
- "media_type": "text/plain",
- "data": data[:_MAX_FETCH_CHARS],
- }
-
-
-def _search_summary(query: str, results: list[dict[str, str]]) -> str:
- if not results:
- return f"No web search results found for: {query}"
- lines = [f"Search results for: {query}"]
- for index, result in enumerate(results, start=1):
- lines.append(f"{index}. {result['title']}\n{result['url']}")
- return "\n\n".join(lines)
-
-
-async def stream_web_server_tool_response(
- request: MessagesRequest, input_tokens: int
-) -> AsyncIterator[str]:
- tool_name = _web_tool_name(request)
- if tool_name is None:
- return
-
- text = _request_text(request)
- message_id = f"msg_{uuid.uuid4()}"
- tool_id = f"srvtoolu_{uuid.uuid4().hex}"
- output_tokens = 1
- usage_key = (
- "web_search_requests" if tool_name == "web_search" else "web_fetch_requests"
- )
- tool_input = (
- {"query": _extract_query(text)}
- if tool_name == "web_search"
- else {"url": _extract_url(text)}
- )
-
- yield _format_event(
- "message_start",
- {
- "type": "message_start",
- "message": {
- "id": message_id,
- "type": "message",
- "role": "assistant",
- "content": [],
- "model": request.model,
- "stop_reason": None,
- "stop_sequence": None,
- "usage": {"input_tokens": input_tokens, "output_tokens": 1},
- },
- },
- )
- yield _format_event(
- "content_block_start",
- {
- "type": "content_block_start",
- "index": 0,
- "content_block": {
- "type": "server_tool_use",
- "id": tool_id,
- "name": tool_name,
- "input": tool_input,
- },
- },
- )
- yield _format_event(
- "content_block_stop", {"type": "content_block_stop", "index": 0}
- )
-
- try:
- if tool_name == "web_search":
- query = str(tool_input["query"])
- results = await _run_web_search(query)
- result_content: Any = [
- {
- "type": "web_search_result",
- "title": result["title"],
- "url": result["url"],
- }
- for result in results
- ]
- summary = _search_summary(query, results)
- result_block_type = "web_search_tool_result"
- else:
- fetched = await _run_web_fetch(str(tool_input["url"]))
- result_content = {
- "type": "web_fetch_result",
- "url": fetched["url"],
- "content": {
- "type": "document",
- "source": {
- "type": "text",
- "media_type": fetched["media_type"],
- "data": fetched["data"],
- },
- "title": fetched["title"],
- "citations": {"enabled": True},
- },
- "retrieved_at": datetime.now(UTC).isoformat(),
- }
- summary = fetched["data"][:_MAX_FETCH_CHARS]
- result_block_type = "web_fetch_tool_result"
- except Exception as error:
- result_block_type = (
- "web_search_tool_result"
- if tool_name == "web_search"
- else "web_fetch_tool_result"
- )
- error_type = (
- "web_search_tool_result_error"
- if tool_name == "web_search"
- else "web_fetch_tool_error"
- )
- result_content = {"type": error_type, "error_code": "unavailable"}
- summary = f"{tool_name} failed: {type(error).__name__}"
-
- output_tokens = max(1, len(summary) // 4)
-
- yield _format_event(
- "content_block_start",
- {
- "type": "content_block_start",
- "index": 1,
- "content_block": {
- "type": result_block_type,
- "tool_use_id": tool_id,
- "content": result_content,
- },
- },
- )
- yield _format_event(
- "content_block_stop", {"type": "content_block_stop", "index": 1}
- )
- yield _format_event(
- "content_block_start",
- {
- "type": "content_block_start",
- "index": 2,
- "content_block": {"type": "text", "text": summary},
- },
- )
- yield _format_event(
- "content_block_stop", {"type": "content_block_stop", "index": 2}
- )
- yield _format_event(
- "message_delta",
- {
- "type": "message_delta",
- "delta": {"stop_reason": "end_turn", "stop_sequence": None},
- "usage": {
- "input_tokens": input_tokens,
- "output_tokens": output_tokens,
- "server_tool_use": {usage_key: 1},
- },
- },
- )
- yield _format_event("message_stop", {"type": "message_stop"})
+__all__ = [
+ "WebFetchEgressPolicy",
+ "WebFetchEgressViolation",
+ "enforce_web_fetch_egress",
+ "httpx",
+ "is_web_server_tool_request",
+ "stream_web_server_tool_response",
+]
diff --git a/api/web_tools/__init__.py b/api/web_tools/__init__.py
new file mode 100644
index 0000000..e0fd14c
--- /dev/null
+++ b/api/web_tools/__init__.py
@@ -0,0 +1,17 @@
+"""Submodules for Anthropic web server tool handling (search/fetch, egress, streaming)."""
+
+from .egress import (
+ WebFetchEgressPolicy,
+ WebFetchEgressViolation,
+ enforce_web_fetch_egress,
+)
+from .request import is_web_server_tool_request
+from .streaming import stream_web_server_tool_response
+
+__all__ = [
+ "WebFetchEgressPolicy",
+ "WebFetchEgressViolation",
+ "enforce_web_fetch_egress",
+ "is_web_server_tool_request",
+ "stream_web_server_tool_response",
+]
diff --git a/api/web_tools/constants.py b/api/web_tools/constants.py
new file mode 100644
index 0000000..e7b2c01
--- /dev/null
+++ b/api/web_tools/constants.py
@@ -0,0 +1,15 @@
+"""Limits and defaults for outbound web server tool HTTP."""
+
+_REQUEST_TIMEOUT_S = 20.0
+_MAX_SEARCH_RESULTS = 10
+_MAX_FETCH_CHARS = 24_000
+# Hard cap on raw bytes read from HTTP responses before decode / HTML parse (memory bound).
+_MAX_WEB_FETCH_RESPONSE_BYTES = 2 * 1024 * 1024
+# Drain at most this many bytes from redirect responses before following Location.
+_REDIRECT_RESPONSE_BODY_CAP_BYTES = 65_536
+_MAX_WEB_FETCH_REDIRECTS = 10
+_WEB_FETCH_REDIRECT_STATUSES = frozenset({301, 302, 303, 307, 308})
+
+_WEB_TOOL_HTTP_HEADERS = {
+ "User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0",
+}
diff --git a/api/web_tools/egress.py b/api/web_tools/egress.py
new file mode 100644
index 0000000..30b29d5
--- /dev/null
+++ b/api/web_tools/egress.py
@@ -0,0 +1,99 @@
+"""Egress policy for user-controlled web_fetch URLs (SSRF guard)."""
+
+from __future__ import annotations
+
+import ipaddress
+import socket
+from dataclasses import dataclass
+from urllib.parse import ParseResult, urlparse
+
+
+@dataclass(frozen=True, slots=True)
+class WebFetchEgressPolicy:
+ """Egress rules for user-influenced web_fetch URLs."""
+
+ allow_private_network_targets: bool
+ allowed_schemes: frozenset[str]
+
+
+class WebFetchEgressViolation(ValueError):
+ """Raised when a web_fetch URL is rejected by egress policy (SSRF guard)."""
+
+
+def _port_for_url(parsed: ParseResult) -> int:
+ if parsed.port is not None:
+ return parsed.port
+ return 443 if (parsed.scheme or "").lower() == "https" else 80
+
+
+def _stream_getaddrinfo_or_raise(host: str, port: int) -> list[tuple]:
+ try:
+ return socket.getaddrinfo(
+ host, port, type=socket.SOCK_STREAM, proto=socket.IPPROTO_TCP
+ )
+ except OSError as exc:
+ raise WebFetchEgressViolation(
+ f"Could not resolve host {host!r}: {exc}"
+ ) from exc
+
+
+def get_validated_stream_addrinfos_for_egress(
+ url: str, policy: WebFetchEgressPolicy
+) -> list[tuple]:
+ """Resolve and validate a URL for web_fetch, returning getaddrinfo rows for pinning.
+
+ Each HTTP connect pins to only these `getaddrinfo` results so a malicious DNS
+ server cannot rebind to a disallowed address between resolution and the TCP
+ connect (used by :func:`api.web_tools.outbound._run_web_fetch`).
+ """
+ parsed = urlparse(url)
+ scheme = (parsed.scheme or "").lower()
+ if scheme not in policy.allowed_schemes:
+ raise WebFetchEgressViolation(
+ f"URL scheme {scheme!r} is not allowed for web_fetch"
+ )
+
+ host = parsed.hostname
+ if host is None or host == "":
+ raise WebFetchEgressViolation("web_fetch URL must include a host")
+
+ port = _port_for_url(parsed)
+
+ if policy.allow_private_network_targets:
+ return _stream_getaddrinfo_or_raise(host, port)
+
+ host_lower = host.lower()
+ if host_lower == "localhost" or host_lower.endswith(".localhost"):
+ raise WebFetchEgressViolation("localhost targets are not allowed for web_fetch")
+ if host_lower.endswith(".local"):
+ raise WebFetchEgressViolation(".local hostnames are not allowed for web_fetch")
+
+ try:
+ parsed_ip = ipaddress.ip_address(host)
+ except ValueError:
+ parsed_ip = None
+
+ if parsed_ip is not None:
+ if not parsed_ip.is_global:
+ raise WebFetchEgressViolation(
+ f"Non-public IP host {host!r} is not allowed for web_fetch"
+ )
+ return _stream_getaddrinfo_or_raise(host, port)
+
+ infos = _stream_getaddrinfo_or_raise(host, port)
+ for *_, sockaddr in infos:
+ addr = sockaddr[0]
+ try:
+ resolved = ipaddress.ip_address(addr)
+ except ValueError:
+ continue
+ if not resolved.is_global:
+ raise WebFetchEgressViolation(
+ f"Host {host!r} resolves to a non-public address ({resolved})"
+ )
+ return infos
+
+
+def enforce_web_fetch_egress(url: str, policy: WebFetchEgressPolicy) -> None:
+ """Validate ``url`` (scheme, host, and resolved addresses) for web_fetch."""
+ get_validated_stream_addrinfos_for_egress(url, policy)
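+
+
+# Usage sketch (assumed values; api.services builds the policy from settings):
+#   policy = WebFetchEgressPolicy(
+#       allow_private_network_targets=False,
+#       allowed_schemes=frozenset({"http", "https"}),
+#   )
+#   enforce_web_fetch_egress("http://169.254.169.254/", policy)
+#   -> raises WebFetchEgressViolation: link-local metadata IPs are not globally routable.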
diff --git a/api/web_tools/outbound.py b/api/web_tools/outbound.py
new file mode 100644
index 0000000..4f93f9b
--- /dev/null
+++ b/api/web_tools/outbound.py
@@ -0,0 +1,278 @@
+"""Outbound HTTP for web_search / web_fetch (client, body caps, logging)."""
+
+from __future__ import annotations
+
+import asyncio
+import socket
+from collections.abc import AsyncIterator
+from urllib.parse import urljoin, urlparse
+
+import aiohttp
+import httpx
+from aiohttp import ClientSession, ClientTimeout, TCPConnector
+from aiohttp.abc import AbstractResolver, ResolveResult
+from loguru import logger
+
+from . import constants
+from .constants import (
+ _MAX_FETCH_CHARS,
+ _MAX_SEARCH_RESULTS,
+ _REDIRECT_RESPONSE_BODY_CAP_BYTES,
+ _REQUEST_TIMEOUT_S,
+ _WEB_FETCH_REDIRECT_STATUSES,
+ _WEB_TOOL_HTTP_HEADERS,
+)
+from .egress import (
+ WebFetchEgressPolicy,
+ WebFetchEgressViolation,
+ get_validated_stream_addrinfos_for_egress,
+)
+from .parsers import HTMLTextParser, SearchResultParser
+
+
+def _safe_public_host_for_logs(url: str) -> str:
+ host = urlparse(url).hostname or ""
+ return host[:253]
+
+
+def _log_web_tool_failure(
+ tool_name: str,
+ error: BaseException,
+ *,
+ fetch_url: str | None = None,
+) -> None:
+ exc_type = type(error).__name__
+ if isinstance(error, WebFetchEgressViolation):
+ host = _safe_public_host_for_logs(fetch_url) if fetch_url else ""
+ logger.warning(
+ "web_tool_egress_rejected tool={} exc_type={} host={!r}",
+ tool_name,
+ exc_type,
+ host,
+ )
+ return
+ if tool_name == "web_fetch" and fetch_url:
+ logger.warning(
+ "web_tool_failure tool={} exc_type={} host={!r}",
+ tool_name,
+ exc_type,
+ _safe_public_host_for_logs(fetch_url),
+ )
+ else:
+ logger.warning("web_tool_failure tool={} exc_type={}", tool_name, exc_type)
+
+
+def _web_tool_client_error_summary(
+ tool_name: str,
+ error: BaseException,
+ *,
+ verbose: bool,
+) -> str:
+ if verbose:
+ return f"{tool_name} failed: {type(error).__name__}"
+ return "Web tool request failed."
+
+
+async def _iter_response_body_under_cap(
+ response: httpx.Response, max_bytes: int
+) -> AsyncIterator[bytes]:
+ if max_bytes <= 0:
+ return
+ received = 0
+ async for chunk in response.aiter_bytes(chunk_size=65_536):
+ if received >= max_bytes:
+ break
+ remaining = max_bytes - received
+ if len(chunk) <= remaining:
+ received += len(chunk)
+ yield chunk
+ if received >= max_bytes:
+ break
+ else:
+ yield chunk[:remaining]
+ break
+
+
+async def _drain_response_body_capped(response: httpx.Response, max_bytes: int) -> None:
+ async for _ in _iter_response_body_under_cap(response, max_bytes):
+ pass
+
+
+async def _read_response_body_capped(response: httpx.Response, max_bytes: int) -> bytes:
+ return b"".join(
+ [piece async for piece in _iter_response_body_under_cap(response, max_bytes)]
+ )
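+
+
+# Behavior sketch: with max_bytes=10 and a 64 KiB first chunk, the iterator yields a
+# single 10-byte slice and stops, so callers never buffer more than the cap.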
+
+
+_NUMERIC_RESOLVE_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV
+_NAME_RESOLVE_FLAGS = socket.NI_NUMERICHOST | socket.NI_NUMERICSERV
+
+
+def getaddrinfo_rows_to_resolve_results(
+ host: str, addrinfos: list[tuple]
+) -> list[ResolveResult]:
+ """Map :func:`socket.getaddrinfo` rows to aiohttp :class:`ResolveResult` (ThreadedResolver logic)."""
+ out: list[ResolveResult] = []
+ for family, _type, proto, _canon, sockaddr in addrinfos:
+ if family == socket.AF_INET6:
+ if len(sockaddr) < 3:
+ continue
+ if sockaddr[3]:
+ resolved_host, port = socket.getnameinfo(sockaddr, _NAME_RESOLVE_FLAGS)
+ else:
+ resolved_host, port = sockaddr[:2]
+ else:
+ assert family == socket.AF_INET, family
+ resolved_host, port = sockaddr[0], sockaddr[1]
+ resolved_host = str(resolved_host)
+ port = int(port)
+ out.append(
+ ResolveResult(
+ hostname=host,
+ host=resolved_host,
+                port=port,
+ family=family,
+ proto=proto,
+ flags=_NUMERIC_RESOLVE_FLAGS,
+ )
+ )
+ return out
+
+
+class _PinnedEgressStaticResolver(AbstractResolver):
+ """Return only pre-validated :class:`ResolveResult` for the outbound request."""
+
+ def __init__(self, results: list[ResolveResult]) -> None:
+ self._results = results
+
+ async def resolve(
+ self, host: str, port: int = 0, family: int = socket.AF_INET
+ ) -> list[ResolveResult]:
+ return self._results
+
+ async def close(self) -> None: # pragma: no cover - aiohttp contract
+ return
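+
+
+# Usage sketch (mirrors _run_web_fetch below): pin aiohttp to pre-validated addresses
+# so the hostname is never re-resolved between egress validation and TCP connect:
+#   results = getaddrinfo_rows_to_resolve_results(host, addr_infos)
+#   connector = TCPConnector(resolver=_PinnedEgressStaticResolver(results), force_close=True)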
+
+
+async def _read_aiohttp_body_capped(
+ response: aiohttp.ClientResponse, max_bytes: int
+) -> bytes:
+ received = 0
+ parts: list[bytes] = []
+ async for chunk in response.content.iter_chunked(65_536):
+ if received >= max_bytes:
+ break
+ remaining = max_bytes - received
+ if len(chunk) <= remaining:
+ received += len(chunk)
+ parts.append(chunk)
+ else:
+ parts.append(chunk[:remaining])
+ break
+ return b"".join(parts)
+
+
+async def _drain_aiohttp_body_capped(
+ response: aiohttp.ClientResponse, max_bytes: int
+) -> None:
+ if max_bytes <= 0:
+ return
+ received = 0
+ async for chunk in response.content.iter_chunked(65_536):
+ received += len(chunk)
+ if received >= max_bytes:
+ break
+
+
+async def _run_web_search(query: str) -> list[dict[str, str]]:
+ async with (
+ httpx.AsyncClient(
+ timeout=_REQUEST_TIMEOUT_S,
+ follow_redirects=True,
+ headers=_WEB_TOOL_HTTP_HEADERS,
+ ) as client,
+ client.stream(
+ "GET",
+ "https://lite.duckduckgo.com/lite/",
+ params={"q": query},
+ ) as response,
+ ):
+ response.raise_for_status()
+ body_bytes = await _read_response_body_capped(
+ response, constants._MAX_WEB_FETCH_RESPONSE_BYTES
+ )
+ text = body_bytes.decode("utf-8", errors="replace")
+ parser = SearchResultParser()
+ parser.feed(text)
+ return parser.results[:_MAX_SEARCH_RESULTS]
+
+
+async def _run_web_fetch(url: str, egress: WebFetchEgressPolicy) -> dict[str, str]:
+ """Fetch URL with manual redirects; each hop is DNS-pinned to validated addresses."""
+ current_url = url
+ redirect_hops = 0
+ timeout = ClientTimeout(total=_REQUEST_TIMEOUT_S)
+
+ while True:
+ addr_infos = await asyncio.to_thread(
+ get_validated_stream_addrinfos_for_egress, current_url, egress
+ )
+ host = urlparse(current_url).hostname or ""
+ results = getaddrinfo_rows_to_resolve_results(host, addr_infos)
+ resolver = _PinnedEgressStaticResolver(results)
+ connector = TCPConnector(
+ resolver=resolver,
+ force_close=True,
+ )
+ try:
+ async with (
+ ClientSession(
+ timeout=timeout,
+ headers=_WEB_TOOL_HTTP_HEADERS,
+ connector=connector,
+ ) as session,
+ session.get(current_url, allow_redirects=False) as response,
+ ):
+ if response.status in _WEB_FETCH_REDIRECT_STATUSES:
+ await _drain_aiohttp_body_capped(
+ response, _REDIRECT_RESPONSE_BODY_CAP_BYTES
+ )
+ if redirect_hops >= constants._MAX_WEB_FETCH_REDIRECTS:
+ raise WebFetchEgressViolation(
+ "web_fetch exceeded maximum redirects "
+ f"({constants._MAX_WEB_FETCH_REDIRECTS})"
+ )
+ location = response.headers.get("location")
+ if not location or not location.strip():
+ raise WebFetchEgressViolation(
+ "web_fetch redirect response missing Location header"
+ )
+ current_url = urljoin(str(response.url), location.strip())
+ redirect_hops += 1
+ continue
+ response.raise_for_status()
+ content_type = response.headers.get("content-type", "text/plain")
+ final_url = str(response.url)
+ encoding = response.get_encoding() or "utf-8"
+ body_bytes = await _read_aiohttp_body_capped(
+ response, constants._MAX_WEB_FETCH_RESPONSE_BYTES
+ )
+ finally:
+ await connector.close()
+
+ break
+
+ text = body_bytes.decode(encoding, errors="replace")
+ title = final_url
+ data = text
+ if "html" in content_type.lower():
+ parser = HTMLTextParser()
+ parser.feed(text)
+ title = parser.title or final_url
+ data = "\n".join(parser.text_parts)
+ return {
+ "url": final_url,
+ "title": title,
+ "media_type": "text/plain",
+ "data": data[:_MAX_FETCH_CHARS],
+ }
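+
+
+# Note: every redirect hop re-enters the loop, so each Location target is re-validated
+# against the egress policy and re-pinned before any bytes are fetched from it.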
diff --git a/api/web_tools/parsers.py b/api/web_tools/parsers.py
new file mode 100644
index 0000000..198b412
--- /dev/null
+++ b/api/web_tools/parsers.py
@@ -0,0 +1,104 @@
+"""HTML parsing for web_search / web_fetch."""
+
+from __future__ import annotations
+
+import html
+import re
+from html.parser import HTMLParser
+from typing import Any
+from urllib.parse import parse_qs, unquote, urlparse
+
+
+class SearchResultParser(HTMLParser):
+ """DuckDuckGo lite HTML: extract result links and titles."""
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.results: list[dict[str, str]] = []
+ self._href: str | None = None
+ self._title_parts: list[str] = []
+
+ def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
+ if tag != "a":
+ return
+ href = dict(attrs).get("href")
+ if not href or "uddg=" not in href:
+ return
+ parsed = urlparse(href)
+ query = parse_qs(parsed.query)
+ uddg = query.get("uddg", [""])[0]
+ if not uddg:
+ return
+ self._href = unquote(uddg)
+ self._title_parts = []
+
+ def handle_data(self, data: str) -> None:
+ if self._href is not None:
+ self._title_parts.append(data)
+
+ def handle_endtag(self, tag: str) -> None:
+ if tag != "a" or self._href is None:
+ return
+ title = " ".join("".join(self._title_parts).split())
+ if title and not any(result["url"] == self._href for result in self.results):
+ self.results.append({"title": html.unescape(title), "url": self._href})
+ self._href = None
+ self._title_parts = []
+
+
+class HTMLTextParser(HTMLParser):
+ """Strip scripts/styles and collect visible text + title for fetch previews."""
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.title = ""
+ self.text_parts: list[str] = []
+ self._in_title = False
+ self._skip_depth = 0
+
+ def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
+ if tag in {"script", "style", "noscript"}:
+ self._skip_depth += 1
+ elif tag == "title":
+ self._in_title = True
+
+ def handle_endtag(self, tag: str) -> None:
+ if tag in {"script", "style", "noscript"} and self._skip_depth:
+ self._skip_depth -= 1
+ elif tag == "title":
+ self._in_title = False
+
+ def handle_data(self, data: str) -> None:
+ text = " ".join(data.split())
+ if not text:
+ return
+ if self._in_title:
+ self.title = f"{self.title} {text}".strip()
+ elif not self._skip_depth:
+ self.text_parts.append(text)
+
+
+def content_text(content: Any) -> str:
+ if isinstance(content, str):
+ return content
+ if isinstance(content, list):
+ parts = []
+ for item in content:
+ if isinstance(item, dict):
+ parts.append(str(item.get("text", "")))
+ else:
+ parts.append(str(getattr(item, "text", "")))
+ return "\n".join(part for part in parts if part)
+ return str(content)
+
+
+def extract_query(text: str) -> str:
+ match = re.search(r"query:\s*(.+)", text, flags=re.IGNORECASE | re.DOTALL)
+ if match:
+ return match.group(1).strip().strip("\"'")
+ return text.strip()
+
+
+def extract_url(text: str) -> str:
+ match = re.search(r"https?://\S+", text)
+ return match.group(0).rstrip(").,]") if match else text.strip()
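+
+
+# Extractor behaviour (sketch):
+#
+#     extract_query("query: latest aiohttp release")  # -> "latest aiohttp release"
+#     extract_url("fetch https://example.com/docs).")  # -> "https://example.com/docs"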
diff --git a/api/web_tools/request.py b/api/web_tools/request.py
new file mode 100644
index 0000000..b794685
--- /dev/null
+++ b/api/web_tools/request.py
@@ -0,0 +1,87 @@
+"""Detect forced Anthropic web server tool requests."""
+
+from __future__ import annotations
+
+from api.models.anthropic import MessagesRequest, Tool
+
+
+def request_text(request: MessagesRequest) -> str:
+ """Join all user/assistant message content into one string for tool input parsing."""
+ from .parsers import content_text
+
+ return "\n".join(content_text(message.content) for message in request.messages)
+
+
+def forced_tool_turn_text(request: MessagesRequest) -> str:
+ """Text for parsing forced server-tool inputs: latest user turn only (avoids stale history)."""
+ if not request.messages:
+ return ""
+
+ from .parsers import content_text
+
+ for message in reversed(request.messages):
+ if message.role == "user":
+ return content_text(message.content)
+ return ""
+
+
+def forced_server_tool_name(request: MessagesRequest) -> str | None:
+ """Return web_search or web_fetch only when tool_choice forces that server tool."""
+ tc = request.tool_choice
+ if not isinstance(tc, dict):
+ return None
+ if tc.get("type") != "tool":
+ return None
+ name = tc.get("name")
+ if name in {"web_search", "web_fetch"}:
+ return str(name)
+ return None
+
+
+def has_tool_named(request: MessagesRequest, name: str) -> bool:
+ return any(tool.name == name for tool in request.tools or [])
+
+
+def is_web_server_tool_request(request: MessagesRequest) -> bool:
+ """True when the client forces a web server tool via tool_choice (not merely listed)."""
+ forced = forced_server_tool_name(request)
+ if forced is None:
+ return False
+ return has_tool_named(request, forced)
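+
+# Example (hedged sketch; the Tool construction and attribute assignment below
+# are illustrative only):
+#
+#     request.tool_choice = {"type": "tool", "name": "web_search"}
+#     request.tools = [Tool(name="web_search", input_schema={})]
+#     is_web_server_tool_request(request)  # -> True (forced AND listed)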
+
+
+def is_anthropic_server_tool_definition(tool: Tool) -> bool:
+ """Whether ``tool`` refers to an Anthropic server tool (web_search / web_fetch family)."""
+ name = (tool.name or "").strip()
+ if name in ("web_search", "web_fetch"):
+ return True
+ typ = tool.type
+ if isinstance(typ, str):
+ return typ.startswith("web_search") or typ.startswith("web_fetch")
+ return False
+
+
+def has_listed_anthropic_server_tools(request: MessagesRequest) -> bool:
+ """True when tools include web_search / web_fetch-style entries (listed, forced or not)."""
+ return any(is_anthropic_server_tool_definition(t) for t in (request.tools or []))
+
+
+def openai_chat_upstream_server_tool_error(
+ request: MessagesRequest, *, web_tools_enabled: bool
+) -> str | None:
+ """Return a user-facing error when OpenAI Chat upstream cannot satisfy server-tool semantics."""
+ forced = forced_server_tool_name(request)
+ if forced and not web_tools_enabled:
+ return (
+ f"tool_choice forces Anthropic server tool {forced!r}, but local web server tools are "
+ "disabled (ENABLE_WEB_SERVER_TOOLS=false). Enable them or use a native Anthropic "
+ "Messages transport (e.g. open_router, ollama, lmstudio)."
+ )
+ if not forced and has_listed_anthropic_server_tools(request):
+ return (
+ "OpenAI Chat upstreams (NVIDIA NIM, DeepSeek) cannot use listed Anthropic server tools "
+ "(web_search / web_fetch) without the local web server tool handler. Use a native "
+ "Anthropic transport, set ENABLE_WEB_SERVER_TOOLS=true and force the tool with "
+ "tool_choice, or remove these tools from the request."
+ )
+ return None
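+
+
+# Decision sketch: forced tool + web tools disabled -> first message; merely
+# listed server tools on an OpenAI Chat upstream -> second; otherwise None.
+#
+#     openai_chat_upstream_server_tool_error(request, web_tools_enabled=False)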
diff --git a/api/web_tools/streaming.py b/api/web_tools/streaming.py
new file mode 100644
index 0000000..eb9b414
--- /dev/null
+++ b/api/web_tools/streaming.py
@@ -0,0 +1,206 @@
+"""SSE streaming for local web_search / web_fetch server tool results."""
+
+from __future__ import annotations
+
+import uuid
+from collections.abc import AsyncIterator
+from datetime import UTC, datetime
+from typing import Any
+
+from api.models.anthropic import MessagesRequest
+from core.anthropic.server_tool_sse import (
+ SERVER_TOOL_USE,
+ WEB_FETCH_TOOL_ERROR,
+ WEB_FETCH_TOOL_RESULT,
+ WEB_SEARCH_TOOL_RESULT,
+ WEB_SEARCH_TOOL_RESULT_ERROR,
+)
+from core.anthropic.sse import format_sse_event
+
+from . import outbound
+from .constants import _MAX_FETCH_CHARS
+from .egress import WebFetchEgressPolicy
+from .parsers import extract_query, extract_url
+from .request import (
+ forced_server_tool_name,
+ forced_tool_turn_text,
+ has_tool_named,
+)
+
+
+def _search_summary(query: str, results: list[dict[str, str]]) -> str:
+ if not results:
+ return f"No web search results found for: {query}"
+ lines = [f"Search results for: {query}"]
+ for index, result in enumerate(results, start=1):
+ lines.append(f"{index}. {result['title']}\n{result['url']}")
+ return "\n\n".join(lines)
+
+
+async def stream_web_server_tool_response(
+ request: MessagesRequest,
+ input_tokens: int,
+ *,
+ web_fetch_egress: WebFetchEgressPolicy,
+ verbose_client_errors: bool = False,
+) -> AsyncIterator[str]:
+ """Stream a minimal Anthropic-shaped turn for forced `web_search` / `web_fetch` (local fallback).
+
+ When `ENABLE_WEB_SERVER_TOOLS` is on, this is a proxy-side execution path — not a full
+ hosted Anthropic citation or encrypted-content pipeline.
+ """
+ tool_name = forced_server_tool_name(request)
+ if tool_name is None or not has_tool_named(request, tool_name):
+ return
+
+ text = forced_tool_turn_text(request)
+ message_id = f"msg_{uuid.uuid4()}"
+ tool_id = f"srvtoolu_{uuid.uuid4().hex}"
+ usage_key = (
+ "web_search_requests" if tool_name == "web_search" else "web_fetch_requests"
+ )
+ tool_input = (
+ {"query": extract_query(text)}
+ if tool_name == "web_search"
+ else {"url": extract_url(text)}
+ )
+ _result_block_for_tool = {
+ "web_search": WEB_SEARCH_TOOL_RESULT,
+ "web_fetch": WEB_FETCH_TOOL_RESULT,
+ }
+ _error_payload_type_for_tool = {
+ "web_search": WEB_SEARCH_TOOL_RESULT_ERROR,
+ "web_fetch": WEB_FETCH_TOOL_ERROR,
+ }
+
+ yield format_sse_event(
+ "message_start",
+ {
+ "type": "message_start",
+ "message": {
+ "id": message_id,
+ "type": "message",
+ "role": "assistant",
+ "content": [],
+ "model": request.model,
+ "stop_reason": None,
+ "stop_sequence": None,
+ "usage": {"input_tokens": input_tokens, "output_tokens": 1},
+ },
+ },
+ )
+ yield format_sse_event(
+ "content_block_start",
+ {
+ "type": "content_block_start",
+ "index": 0,
+ "content_block": {
+ "type": SERVER_TOOL_USE,
+ "id": tool_id,
+ "name": tool_name,
+ "input": tool_input,
+ },
+ },
+ )
+ yield format_sse_event(
+ "content_block_stop", {"type": "content_block_stop", "index": 0}
+ )
+
+ try:
+ if tool_name == "web_search":
+ query = str(tool_input["query"])
+ results = await outbound._run_web_search(query)
+ result_content: Any = [
+ {
+ "type": "web_search_result",
+ "title": result["title"],
+ "url": result["url"],
+ }
+ for result in results
+ ]
+ summary = _search_summary(query, results)
+ result_block_type = WEB_SEARCH_TOOL_RESULT
+ else:
+ fetched = await outbound._run_web_fetch(
+ str(tool_input["url"]), web_fetch_egress
+ )
+ result_content = {
+ "type": "web_fetch_result",
+ "url": fetched["url"],
+ "content": {
+ "type": "document",
+ "source": {
+ "type": "text",
+ "media_type": fetched["media_type"],
+ "data": fetched["data"],
+ },
+ "title": fetched["title"],
+ "citations": {"enabled": True},
+ },
+ "retrieved_at": datetime.now(UTC).isoformat(),
+ }
+ summary = fetched["data"][:_MAX_FETCH_CHARS]
+ result_block_type = WEB_FETCH_TOOL_RESULT
+ except Exception as error:
+ fetch_url = str(tool_input["url"]) if tool_name == "web_fetch" else None
+ outbound._log_web_tool_failure(tool_name, error, fetch_url=fetch_url)
+ result_block_type = _result_block_for_tool[tool_name]
+ result_content = {
+ "type": _error_payload_type_for_tool[tool_name],
+ "error_code": "unavailable",
+ }
+ summary = outbound._web_tool_client_error_summary(
+ tool_name, error, verbose=verbose_client_errors
+ )
+
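+ # Rough heuristic (~4 characters per token) for the reported usage; not a tokenizer count.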
+ output_tokens = max(1, len(summary) // 4)
+
+ yield format_sse_event(
+ "content_block_start",
+ {
+ "type": "content_block_start",
+ "index": 1,
+ "content_block": {
+ "type": result_block_type,
+ "tool_use_id": tool_id,
+ "content": result_content,
+ },
+ },
+ )
+ yield format_sse_event(
+ "content_block_stop", {"type": "content_block_stop", "index": 1}
+ )
+ # Model-facing summary: stream as normal text deltas (CLI/transcript code reads `text_delta`,
+ # not eager `text` on `content_block_start`).
+ yield format_sse_event(
+ "content_block_start",
+ {
+ "type": "content_block_start",
+ "index": 2,
+ "content_block": {"type": "text", "text": ""},
+ },
+ )
+ yield format_sse_event(
+ "content_block_delta",
+ {
+ "type": "content_block_delta",
+ "index": 2,
+ "delta": {"type": "text_delta", "text": summary},
+ },
+ )
+ yield format_sse_event(
+ "content_block_stop", {"type": "content_block_stop", "index": 2}
+ )
+ yield format_sse_event(
+ "message_delta",
+ {
+ "type": "message_delta",
+ "delta": {"stop_reason": "end_turn", "stop_sequence": None},
+ "usage": {
+ "input_tokens": input_tokens,
+ "output_tokens": output_tokens,
+ "server_tool_use": {usage_key: 1},
+ },
+ },
+ )
+ yield format_sse_event("message_stop", {"type": "message_stop"})
diff --git a/cli/entrypoints.py b/cli/entrypoints.py
index ba61a9e..c161de9 100644
--- a/cli/entrypoints.py
+++ b/cli/entrypoints.py
@@ -30,7 +30,8 @@ def serve() -> None:
settings = get_settings()
try:
uvicorn.run(
- "api.app:app",
+ "api.app:create_app",
+ factory=True,
host=settings.host,
port=settings.port,
log_level="debug",
diff --git a/cli/manager.py b/cli/manager.py
index b267633..cb27419 100644
--- a/cli/manager.py
+++ b/cli/manager.py
@@ -29,6 +29,9 @@ class CLISessionManager:
allowed_dirs: list[str] | None = None,
plans_directory: str | None = None,
claude_bin: str = "claude",
+ *,
+ log_raw_cli_diagnostics: bool = False,
+ log_messaging_error_details: bool = False,
):
"""
Initialize the session manager.
@@ -44,6 +47,8 @@ class CLISessionManager:
self.allowed_dirs = allowed_dirs or []
self.plans_directory = plans_directory
self.claude_bin = claude_bin
+ self._log_raw_cli_diagnostics = log_raw_cli_diagnostics
+ self._log_messaging_error_details = log_messaging_error_details
self._sessions: dict[str, CLISession] = {}
self._pending_sessions: dict[str, CLISession] = {}
@@ -79,6 +84,7 @@ class CLISessionManager:
allowed_dirs=self.allowed_dirs,
plans_directory=self.plans_directory,
claude_bin=self.claude_bin,
+ log_raw_cli_diagnostics=self._log_raw_cli_diagnostics,
)
self._pending_sessions[temp_id] = new_session
logger.info(f"Created new session: {temp_id}")
@@ -130,7 +136,17 @@ class CLISessionManager:
try:
await session.stop()
except Exception as e:
- logger.error(f"Error stopping session: {e}")
+ if self._log_messaging_error_details:
+ logger.error(
+ "Error stopping session: {}: {}",
+ type(e).__name__,
+ e,
+ )
+ else:
+ logger.error(
+ "Error stopping session: exc_type={}",
+ type(e).__name__,
+ )
self._sessions.clear()
self._pending_sessions.clear()
diff --git a/cli/session.py b/cli/session.py
index db18f6c..d007975 100644
--- a/cli/session.py
+++ b/cli/session.py
@@ -11,6 +11,9 @@ from loguru import logger
from .process_registry import register_pid, unregister_pid
+# Cap stderr capture so a runaway child cannot exhaust memory; pipe is still drained.
+_MAX_STDERR_CAPTURE_BYTES = 256 * 1024
+
@dataclass(frozen=True, slots=True)
class ClaudeCliConfig:
@@ -33,6 +36,8 @@ class CLISession:
allowed_dirs: list[str] | None = None,
plans_directory: str | None = None,
claude_bin: str = "claude",
+ *,
+ log_raw_cli_diagnostics: bool = False,
):
self.config = ClaudeCliConfig(
workspace_path=os.path.normpath(os.path.abspath(workspace_path)),
@@ -46,11 +51,40 @@ class CLISession:
self.allowed_dirs = self.config.allowed_dirs
self.plans_directory = self.config.plans_directory
self.claude_bin = self.config.claude_bin
+ self._log_raw_cli_diagnostics = log_raw_cli_diagnostics
self.process: asyncio.subprocess.Process | None = None
self.current_session_id: str | None = None
self._is_busy = False
self._cli_lock = asyncio.Lock()
+ @staticmethod
+ async def _drain_stderr_bounded(
+ process: asyncio.subprocess.Process,
+ *,
+ max_bytes: int = _MAX_STDERR_CAPTURE_BYTES,
+ ) -> bytes:
+ """Read stderr concurrently with stdout to avoid subprocess pipe deadlocks.
+
+ Retains at most ``max_bytes`` for logging; any excess is discarded, but
+ the pipe is read until EOF so a noisy child cannot fill the buffer and
+ block forever.
+ """
+ if not process.stderr:
+ return b""
+ parts: list[bytes] = []
+ received = 0
+ while True:
+ chunk = await process.stderr.read(65_536)
+ if not chunk:
+ break
+ if received < max_bytes:
+ take = min(len(chunk), max_bytes - received)
+ if take:
+ parts.append(chunk[:take])
+ received += take
+ # If already at cap, keep reading and discarding until EOF.
+ return b"".join(parts)
+
@property
def is_busy(self) -> bool:
"""Check if a task is currently running."""
@@ -140,6 +174,11 @@ class CLISession:
session_id_extracted = False
buffer = bytearray()
+ stderr_task: asyncio.Task[bytes] | None = None
+ if self.process.stderr:
+ stderr_task = asyncio.create_task(
+ self._drain_stderr_bounded(self.process)
+ )
try:
while True:
@@ -179,23 +218,27 @@ class CLISession:
except asyncio.CancelledError:
# Cancelling the handler task should not leave a Claude CLI
# subprocess running in the background.
- try:
- await asyncio.shield(self.stop())
- finally:
- raise
+ await asyncio.shield(self.stop())
+ raise
+ finally:
+ stderr_bytes = b""
+ if stderr_task is not None:
+ stderr_bytes = await stderr_task
stderr_text = None
- if self.process.stderr:
- stderr_output = await self.process.stderr.read()
- if stderr_output:
- stderr_text = stderr_output.decode(
- "utf-8", errors="replace"
- ).strip()
- logger.error(f"Claude CLI Stderr: {stderr_text}")
- # Yield stderr as error event so it shows in UI
- if stderr_text:
- logger.info("CLI_SESSION: Yielding error event from stderr")
- yield {"type": "error", "error": {"message": stderr_text}}
+ if stderr_bytes:
+ stderr_text = stderr_bytes.decode("utf-8", errors="replace").strip()
+ if stderr_text:
+ if self._log_raw_cli_diagnostics:
+ logger.error("Claude CLI stderr: {}", stderr_text)
+ else:
+ logger.error(
+ "Claude CLI stderr: bytes={} text_chars={}",
+ len(stderr_bytes),
+ len(stderr_text),
+ )
+ logger.info("CLI_SESSION: Yielding error event from stderr")
+ yield {"type": "error", "error": {"message": stderr_text}}
return_code = await self.process.wait()
logger.info(
@@ -230,7 +273,10 @@ class CLISession:
yield event
except json.JSONDecodeError:
- logger.debug(f"Non-JSON output: {line_str}")
+ if self._log_raw_cli_diagnostics:
+ logger.debug("Non-JSON output: {}", line_str)
+ else:
+ logger.debug("Non-JSON CLI line: char_len={}", len(line_str))
yield {"type": "raw", "content": line_str}
def _extract_session_id(self, event: Any) -> str | None:
@@ -273,6 +319,16 @@ class CLISession:
unregister_pid(self.process.pid)
return True
except Exception as e:
- logger.error(f"Error stopping process: {e}")
+ if self._log_raw_cli_diagnostics:
+ logger.error(
+ "Error stopping process: {}: {}",
+ type(e).__name__,
+ e,
+ )
+ else:
+ logger.error(
+ "Error stopping process: exc_type={}",
+ type(e).__name__,
+ )
return False
return False
diff --git a/config/constants.py b/config/constants.py
new file mode 100644
index 0000000..987027a
--- /dev/null
+++ b/config/constants.py
@@ -0,0 +1,10 @@
+"""Shared defaults used by config models and provider adapters."""
+
+# HTTP client connect timeout (seconds). Keep aligned with README.md and .env.example.
+HTTP_CONNECT_TIMEOUT_DEFAULT = 10.0
+
+# Anthropic Messages API default when the client omits max_tokens.
+ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS = 81920
+
+# Max bytes read from a non-200 native messages response when verbose error logging is on.
+NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES = 4096
diff --git a/config/logging_config.py b/config/logging_config.py
index e1a31e4..7c5d8fd 100644
--- a/config/logging_config.py
+++ b/config/logging_config.py
@@ -8,6 +8,7 @@ included at top level for easy grep/filter.
import json
import logging
+import re
from pathlib import Path
from loguru import logger
@@ -17,6 +18,22 @@ _configured = False
# Context keys we promote to top-level JSON for traceability
_CONTEXT_KEYS = ("request_id", "node_id", "chat_id")
+_TELEGRAM_BOT_RE = re.compile(
+ r"(https?://api\.telegram\.org/)bot([0-9]+:[A-Za-z0-9_-]+)(/?)",
+ re.IGNORECASE,
+)
+# Authorization: Bearer (HTTP client / proxy debug lines)
+_AUTH_BEARER_RE = re.compile(
+ r"(\bAuthorization\s*:\s*Bearer\s+)([^\s'\"]+)",
+ re.IGNORECASE,
+)
+
+
+def _redact_sensitive_substrings(message: str) -> str:
+ """Remove obvious API tokens and secrets before JSON log line emission."""
+ text = _TELEGRAM_BOT_RE.sub(r"\1bot\3", message)
+ return _AUTH_BEARER_RE.sub(r"\1", text)
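+
+
+# Redaction example (sketch):
+#
+#     _redact_sensitive_substrings(
+#         "https://api.telegram.org/bot123456:AAAbbb/x Authorization: Bearer sk-1"
+#     )
+#     # -> "https://api.telegram.org/bot/x Authorization: Bearer "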
+
def _serialize_with_context(record) -> str:
"""Format record as JSON with context vars at top level.
@@ -26,7 +43,7 @@ def _serialize_with_context(record) -> str:
out = {
"time": str(record["time"]),
"level": record["level"].name,
- "message": record["message"],
+ "message": _redact_sensitive_substrings(str(record["message"])),
"module": record["name"],
"function": record["function"],
"line": record["line"],
@@ -57,11 +74,16 @@ class InterceptHandler(logging.Handler):
)
-def configure_logging(log_file: str, *, force: bool = False) -> None:
+def configure_logging(
+ log_file: str, *, force: bool = False, verbose_third_party: bool = False
+) -> None:
"""Configure loguru with JSON output to log_file and intercept stdlib logging.
Idempotent: skips if already configured (e.g. hot reload).
Use force=True to reconfigure (e.g. in tests with a different log path).
+
+ When ``verbose_third_party`` is false, noisy HTTP and Telegram loggers are capped
+ at WARNING unless explicitly configured otherwise.
"""
global _configured
if _configured and not force:
@@ -88,3 +110,16 @@ def configure_logging(log_file: str, *, force: bool = False) -> None:
intercept = InterceptHandler()
logging.root.handlers = [intercept]
logging.root.setLevel(logging.DEBUG)
+
+ third_party = (
+ "httpx",
+ "httpcore",
+ "httpcore.http11",
+ "httpcore.connection",
+ "telegram",
+ "telegram.ext",
+ )
+ for name in third_party:
+ logging.getLogger(name).setLevel(
+ logging.NOTSET if verbose_third_party else logging.WARNING
+ )
diff --git a/config/nim.py b/config/nim.py
index b3a9c3b..5bf7761 100644
--- a/config/nim.py
+++ b/config/nim.py
@@ -2,6 +2,8 @@
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
+from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
+
class NimSettings(BaseModel):
"""Fixed NVIDIA NIM settings (not configurable via env)."""
@@ -14,7 +16,9 @@ class NimSettings(BaseModel):
)
top_k: int = -1
max_tokens: int = Field(
- 81920, ge=1, description="Maximum number of tokens in output."
+ ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
+ ge=1,
+ description="Maximum number of tokens in output.",
)
presence_penalty: float = Field(0.0, ge=-2.0, le=2.0)
frequency_penalty: float = Field(0.0, ge=-2.0, le=2.0)
@@ -68,7 +72,7 @@ class NimSettings(BaseModel):
return field_defaults.get(key, 1.0)
try:
val = float(v)
- except Exception as err:
+ except (TypeError, ValueError) as err:
raise ValueError(
f"{info.field_name} must be a float. Got {type(v).__name__}."
) from err
@@ -78,15 +82,15 @@ class NimSettings(BaseModel):
@classmethod
def validate_int_fields(cls, v, info: ValidationInfo):
field_defaults = {
- "max_tokens": 81920,
+ "max_tokens": ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
"min_tokens": 0,
}
if v is None or v == "":
key = info.field_name or "max_tokens"
- return field_defaults.get(key, 81920)
+ return field_defaults.get(key, ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS)
try:
val = int(v)
- except Exception as err:
+ except (TypeError, ValueError) as err:
raise ValueError(
f"{info.field_name} must be an int. Got {type(v).__name__}."
) from err
@@ -99,7 +103,7 @@ class NimSettings(BaseModel):
return None
try:
return int(v)
- except Exception as err:
+ except (TypeError, ValueError) as err:
raise ValueError(
f"{info.field_name} must be an int or empty/None."
) from err
diff --git a/config/provider_catalog.py b/config/provider_catalog.py
new file mode 100644
index 0000000..b0b731e
--- /dev/null
+++ b/config/provider_catalog.py
@@ -0,0 +1,108 @@
+"""Neutral provider catalog: IDs, credentials, defaults, proxy and capability metadata.
+
+Adapter factories live in :mod:`providers.registry`; this module stays free of
+provider implementation imports (see contract tests).
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Literal
+
+TransportType = Literal["openai_chat", "anthropic_messages"]
+
+# Default upstream base URLs (also re-exported via :mod:`providers.defaults`)
+NVIDIA_NIM_DEFAULT_BASE = "https://integrate.api.nvidia.com/v1"
+DEEPSEEK_DEFAULT_BASE = "https://api.deepseek.com"
+OPENROUTER_DEFAULT_BASE = "https://openrouter.ai/api/v1"
+LMSTUDIO_DEFAULT_BASE = "http://localhost:1234/v1"
+LLAMACPP_DEFAULT_BASE = "http://localhost:8080/v1"
+OLLAMA_DEFAULT_BASE = "http://localhost:11434"
+
+
+@dataclass(frozen=True, slots=True)
+class ProviderDescriptor:
+ """Metadata for building :class:`~providers.base.ProviderConfig` and factory wiring."""
+
+ provider_id: str
+ transport_type: TransportType
+ capabilities: tuple[str, ...]
+ credential_env: str | None = None
+ credential_url: str | None = None
+ credential_attr: str | None = None
+ static_credential: str | None = None
+ default_base_url: str | None = None
+ base_url_attr: str | None = None
+ proxy_attr: str | None = None
+
+
+PROVIDER_CATALOG: dict[str, ProviderDescriptor] = {
+ "nvidia_nim": ProviderDescriptor(
+ provider_id="nvidia_nim",
+ transport_type="openai_chat",
+ credential_env="NVIDIA_NIM_API_KEY",
+ credential_url="https://build.nvidia.com/settings/api-keys",
+ credential_attr="nvidia_nim_api_key",
+ default_base_url=NVIDIA_NIM_DEFAULT_BASE,
+ proxy_attr="nvidia_nim_proxy",
+ capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"),
+ ),
+ "open_router": ProviderDescriptor(
+ provider_id="open_router",
+ transport_type="anthropic_messages",
+ credential_env="OPENROUTER_API_KEY",
+ credential_url="https://openrouter.ai/keys",
+ credential_attr="open_router_api_key",
+ default_base_url=OPENROUTER_DEFAULT_BASE,
+ proxy_attr="open_router_proxy",
+ capabilities=("chat", "streaming", "tools", "thinking", "native_anthropic"),
+ ),
+ "deepseek": ProviderDescriptor(
+ provider_id="deepseek",
+ transport_type="openai_chat",
+ credential_env="DEEPSEEK_API_KEY",
+ credential_url="https://platform.deepseek.com/api_keys",
+ credential_attr="deepseek_api_key",
+ default_base_url=DEEPSEEK_DEFAULT_BASE,
+ capabilities=("chat", "streaming", "thinking"),
+ ),
+ "lmstudio": ProviderDescriptor(
+ provider_id="lmstudio",
+ transport_type="anthropic_messages",
+ static_credential="lm-studio",
+ default_base_url=LMSTUDIO_DEFAULT_BASE,
+ base_url_attr="lm_studio_base_url",
+ proxy_attr="lmstudio_proxy",
+ capabilities=("chat", "streaming", "tools", "native_anthropic", "local"),
+ ),
+ "llamacpp": ProviderDescriptor(
+ provider_id="llamacpp",
+ transport_type="anthropic_messages",
+ static_credential="llamacpp",
+ default_base_url=LLAMACPP_DEFAULT_BASE,
+ base_url_attr="llamacpp_base_url",
+ proxy_attr="llamacpp_proxy",
+ capabilities=("chat", "streaming", "tools", "native_anthropic", "local"),
+ ),
+ "ollama": ProviderDescriptor(
+ provider_id="ollama",
+ transport_type="anthropic_messages",
+ static_credential="ollama",
+ default_base_url=OLLAMA_DEFAULT_BASE,
+ base_url_attr="ollama_base_url",
+ capabilities=(
+ "chat",
+ "streaming",
+ "tools",
+ "thinking",
+ "native_anthropic",
+ "local",
+ ),
+ ),
+}
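+
+# Lookup example (sketch):
+#
+#     PROVIDER_CATALOG["deepseek"].transport_type  # -> "openai_chat"
+#     "local" in PROVIDER_CATALOG["ollama"].capabilities  # -> True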
+
+# Order is PROVIDER_CATALOG insertion order (matches docs / historical error text).
+# A length-vs-set check on the keys can never fire here: a duplicated key in the
+# dict literal is silently collapsed before this module-level code runs.
+SUPPORTED_PROVIDER_IDS: tuple[str, ...] = tuple(PROVIDER_CATALOG.keys())
diff --git a/config/provider_ids.py b/config/provider_ids.py
index e3e50bf..fd08ab7 100644
--- a/config/provider_ids.py
+++ b/config/provider_ids.py
@@ -1,18 +1,7 @@
-"""Canonical list of model provider type prefixes (provider_id values).
-
-`providers.registry.PROVIDER_DESCRIPTORS` is the full metadata source; this
-module holds the id set for config validation and must stay in sync
-(registries assert in `providers.registry`).
-"""
+"""Canonical provider id tuple (re-exported from the provider catalog)."""
from __future__ import annotations
-# Order matches docs / historical error text; must match PROVIDER_DESCRIPTORS keys.
-SUPPORTED_PROVIDER_IDS: tuple[str, ...] = (
- "nvidia_nim",
- "open_router",
- "deepseek",
- "lmstudio",
- "llamacpp",
- "ollama",
-)
+from .provider_catalog import SUPPORTED_PROVIDER_IDS
+
+__all__ = ("SUPPORTED_PROVIDER_IDS",)
diff --git a/config/settings.py b/config/settings.py
index f7129fb..c4afb1b 100644
--- a/config/settings.py
+++ b/config/settings.py
@@ -10,6 +10,7 @@ from dotenv import dotenv_values
from pydantic import Field, field_validator, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
+from .constants import HTTP_CONNECT_TIMEOUT_DEFAULT
from .nim import NimSettings
from .provider_ids import SUPPORTED_PROVIDER_IDS
@@ -105,6 +106,12 @@ class Settings(BaseSettings):
messaging_platform: str = Field(
default="discord", validation_alias="MESSAGING_PLATFORM"
)
+ messaging_rate_limit: int = Field(
+ default=1, validation_alias="MESSAGING_RATE_LIMIT"
+ )
+ messaging_rate_window: float = Field(
+ default=1.0, validation_alias="MESSAGING_RATE_WINDOW"
+ )
# ==================== NVIDIA NIM Config ====================
nvidia_nim_api_key: str = ""
@@ -173,7 +180,8 @@ class Settings(BaseSettings):
default=10.0, validation_alias="HTTP_WRITE_TIMEOUT"
)
http_connect_timeout: float = Field(
- default=2.0, validation_alias="HTTP_CONNECT_TIMEOUT"
+ default=HTTP_CONNECT_TIMEOUT_DEFAULT,
+ validation_alias="HTTP_CONNECT_TIMEOUT",
)
# ==================== Fast Prefix Detection ====================
@@ -185,6 +193,51 @@ class Settings(BaseSettings):
enable_suggestion_mode_skip: bool = True
enable_filepath_extraction_mock: bool = True
+ # ==================== Local web server tools (web_search / web_fetch) ====================
+ # Off by default: these tools perform outbound HTTP from the proxy (SSRF risk).
+ enable_web_server_tools: bool = Field(
+ default=False, validation_alias="ENABLE_WEB_SERVER_TOOLS"
+ )
+ # Comma-separated URL schemes allowed for web_fetch (default: http,https).
+ web_fetch_allowed_schemes: str = Field(
+ default="http,https", validation_alias="WEB_FETCH_ALLOWED_SCHEMES"
+ )
+ # When true, skip private/loopback/link-local IP blocking for web_fetch (lab only).
+ web_fetch_allow_private_networks: bool = Field(
+ default=False, validation_alias="WEB_FETCH_ALLOW_PRIVATE_NETWORKS"
+ )
+
+ # ==================== Debug / diagnostic logging (avoid sensitive content) ====================
+ # When false (default), API and SSE helpers log only metadata (counts, lengths, ids).
+ log_raw_api_payloads: bool = Field(
+ default=False, validation_alias="LOG_RAW_API_PAYLOADS"
+ )
+ log_raw_sse_events: bool = Field(
+ default=False, validation_alias="LOG_RAW_SSE_EVENTS"
+ )
+ # When false (default), unhandled exceptions log only type + route metadata (no message/traceback).
+ log_api_error_tracebacks: bool = Field(
+ default=False, validation_alias="LOG_API_ERROR_TRACEBACKS"
+ )
+ # When false (default), messaging logs omit text/transcription previews (metadata only).
+ log_raw_messaging_content: bool = Field(
+ default=False, validation_alias="LOG_RAW_MESSAGING_CONTENT"
+ )
+ # When true, log full Claude CLI stderr, non-JSON lines, and parser error text.
+ log_raw_cli_diagnostics: bool = Field(
+ default=False, validation_alias="LOG_RAW_CLI_DIAGNOSTICS"
+ )
+ # When true, log exception text / CLI error strings in messaging (may leak user content).
+ log_messaging_error_details: bool = Field(
+ default=False, validation_alias="LOG_MESSAGING_ERROR_DETAILS"
+ )
+ debug_platform_edits: bool = Field(
+ default=False, validation_alias="DEBUG_PLATFORM_EDITS"
+ )
+ debug_subagent_stack: bool = Field(
+ default=False, validation_alias="DEBUG_SUBAGENT_STACK"
+ )
+
# ==================== NIM Settings ====================
nim: NimSettings = Field(default_factory=NimSettings)
@@ -215,6 +268,9 @@ class Settings(BaseSettings):
claude_workspace: str = "./agent_workspace"
allowed_dir: str = ""
claude_cli_bin: str = Field(default="claude", validation_alias="CLAUDE_CLI_BIN")
+ max_message_log_entries_per_chat: int | None = Field(
+ default=None, validation_alias="MAX_MESSAGE_LOG_ENTRIES_PER_CHAT"
+ )
# ==================== Server ====================
host: str = "0.0.0.0"
@@ -254,6 +310,13 @@ class Settings(BaseSettings):
return None
return v
+ @field_validator("max_message_log_entries_per_chat", mode="before")
+ @classmethod
+ def parse_optional_log_cap(cls, v: Any) -> Any:
+ if v == "" or v is None:
+ return None
+ return v
+
@field_validator("whisper_device")
@classmethod
def validate_whisper_device(cls, v: str) -> str:
@@ -272,6 +335,33 @@ class Settings(BaseSettings):
)
return v
+ @field_validator("messaging_rate_limit")
+ @classmethod
+ def validate_messaging_rate_limit(cls, v: int) -> int:
+ if v <= 0:
+ raise ValueError("messaging_rate_limit must be > 0")
+ return v
+
+ @field_validator("messaging_rate_window")
+ @classmethod
+ def validate_messaging_rate_window(cls, v: float) -> float:
+ if v <= 0:
+ raise ValueError("messaging_rate_window must be > 0")
+ return float(v)
+
+ @field_validator("web_fetch_allowed_schemes")
+ @classmethod
+ def validate_web_fetch_allowed_schemes(cls, v: str) -> str:
+ schemes = [part.strip().lower() for part in v.split(",") if part.strip()]
+ if not schemes:
+ raise ValueError("web_fetch_allowed_schemes must list at least one scheme")
+ for scheme in schemes:
+ if not scheme.isascii() or not scheme.isalpha():
+ raise ValueError(
+ f"Invalid URL scheme in web_fetch_allowed_schemes: {scheme!r}"
+ )
+ return ",".join(schemes)
+
@field_validator("ollama_base_url")
@classmethod
def validate_ollama_base_url(cls, v: str) -> str:
@@ -329,12 +419,12 @@ class Settings(BaseSettings):
@property
def provider_type(self) -> str:
"""Extract provider type from the default model string."""
- return self.model.split("/", 1)[0]
+ return Settings.parse_provider_type(self.model)
@property
def model_name(self) -> str:
"""Extract the actual model name from the default model string."""
- return self.model.split("/", 1)[1]
+ return Settings.parse_model_name(self.model)
def resolve_model(self, claude_model_name: str) -> str:
"""Resolve a Claude model name to the configured provider/model string.
@@ -362,6 +452,14 @@ class Settings(BaseSettings):
return self.enable_sonnet_thinking
return self.enable_model_thinking
+ def web_fetch_allowed_scheme_set(self) -> frozenset[str]:
+ """Return normalized schemes allowed for web_fetch."""
+ return frozenset(
+ part.strip().lower()
+ for part in self.web_fetch_allowed_schemes.split(",")
+ if part.strip()
+ )
+
@staticmethod
def parse_provider_type(model_string: str) -> str:
"""Extract provider type from any 'provider/model' string."""
diff --git a/core/anthropic/__init__.py b/core/anthropic/__init__.py
index 9af5987..eaea8a1 100644
--- a/core/anthropic/__init__.py
+++ b/core/anthropic/__init__.py
@@ -1,9 +1,19 @@
"""Anthropic protocol helpers shared across API, providers, and integrations."""
from .content import extract_text_from_content, get_block_attr, get_block_type
-from .conversion import AnthropicToOpenAIConverter, build_base_request_body
-from .errors import append_request_id, get_user_facing_error_message
-from .sse import ContentBlockManager, SSEBuilder, map_stop_reason
+from .conversion import (
+ AnthropicToOpenAIConverter,
+ OpenAIConversionError,
+ build_base_request_body,
+)
+from .errors import (
+ append_request_id,
+ format_user_error_preview,
+ get_user_facing_error_message,
+)
+from .native_messages_request import sanitize_native_messages_thinking_policy
+from .provider_stream_error import iter_provider_stream_error_sse_events
+from .sse import ContentBlockManager, SSEBuilder, format_sse_event, map_stop_reason
from .thinking import ContentChunk, ContentType, ThinkTagParser
from .tokens import get_token_count
from .tools import HeuristicToolParser
@@ -15,15 +25,20 @@ __all__ = [
"ContentChunk",
"ContentType",
"HeuristicToolParser",
+ "OpenAIConversionError",
"SSEBuilder",
"ThinkTagParser",
"append_request_id",
"build_base_request_body",
"extract_text_from_content",
+ "format_sse_event",
+ "format_user_error_preview",
"get_block_attr",
"get_block_type",
"get_token_count",
"get_user_facing_error_message",
+ "iter_provider_stream_error_sse_events",
"map_stop_reason",
+ "sanitize_native_messages_thinking_policy",
"set_if_not_none",
]
diff --git a/core/anthropic/conversion.py b/core/anthropic/conversion.py
index 5979ba1..c484ceb 100644
--- a/core/anthropic/conversion.py
+++ b/core/anthropic/conversion.py
@@ -3,10 +3,34 @@
import json
from typing import Any
+from pydantic import BaseModel
+
from .content import get_block_attr, get_block_type
from .utils import set_if_not_none
+class OpenAIConversionError(Exception):
+ """Raised when Anthropic content cannot be converted to OpenAI chat without data loss."""
+
+
+def _openai_reject_native_only_top_level_fields(request_data: Any) -> None:
+ """OpenAI chat providers may only convert known top-level request fields.
+
+ First-class model fields (e.g. ``context_management``) are not forwarded to
+ the OpenAI API but are allowed so clients do not hit spurious 400s.
+ Unknown extra keys (``__pydantic_extra__``) are still rejected.
+ """
+ if not isinstance(request_data, BaseModel):
+ return
+ extra = getattr(request_data, "__pydantic_extra__", None)
+ if not extra:
+ return
+ raise OpenAIConversionError(
+ "OpenAI chat conversion does not support these top-level request fields: "
+ f"{sorted(str(k) for k in extra)}. Use a native Anthropic transport provider."
+ )
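+
+# Example (sketch; assumes the request model allows extra fields):
+#
+#     request = MessagesRequest(..., some_unknown_field=1)
+#     _openai_reject_native_only_top_level_fields(request)  # raises OpenAIConversionError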
+
+
def _tool_name(tool: Any) -> str:
return str(getattr(tool, "name", "") or "")
@@ -18,6 +42,27 @@ def _tool_input_schema(tool: Any) -> dict[str, Any]:
return {"type": "object", "properties": {}}
+def _serialize_tool_result_content(tool_content: Any) -> str:
+ """Serialize tool_result content for OpenAI ``role: tool`` messages (stable JSON for structured values)."""
+ if tool_content is None:
+ return ""
+ if isinstance(tool_content, str):
+ return tool_content
+ if isinstance(tool_content, dict):
+ return json.dumps(tool_content, ensure_ascii=False)
+ if isinstance(tool_content, list):
+ parts: list[str] = []
+ for item in tool_content:
+ if isinstance(item, dict) and item.get("type") == "text":
+ parts.append(str(item.get("text", "")))
+ elif isinstance(item, dict):
+ parts.append(json.dumps(item, ensure_ascii=False))
+ else:
+ parts.append(str(item))
+ return "\n".join(parts)
+ return str(tool_content)
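+
+# Example (sketch): text blocks flatten, structured items become stable JSON.
+#
+#     _serialize_tool_result_content(
+#         [{"type": "text", "text": "ok"}, {"status": 200}]
+#     )
+#     # -> 'ok\n{"status": 200}'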
+
+
class AnthropicToOpenAIConverter:
"""Convert Anthropic message format to OpenAI-compatible format."""
@@ -64,20 +109,38 @@ class AnthropicToOpenAIConverter:
content_parts: list[str] = []
thinking_parts: list[str] = []
tool_calls: list[dict[str, Any]] = []
+ seen_tool_use = False
for block in content:
block_type = get_block_type(block)
if block_type == "text":
+ if seen_tool_use:
+ raise OpenAIConversionError(
+ "OpenAI chat conversion does not support assistant text after "
+ "tool_use in the same message; split the transcript or use a "
+ "native Anthropic provider."
+ )
content_parts.append(get_block_attr(block, "text", ""))
elif block_type == "thinking":
if not include_thinking:
continue
+ if seen_tool_use:
+ raise OpenAIConversionError(
+ "OpenAI chat conversion does not support assistant thinking after "
+ "tool_use in the same message; split the transcript or use a "
+ "native Anthropic provider."
+ )
thinking = get_block_attr(block, "thinking", "")
content_parts.append(f"\n{thinking}\n")
if include_reasoning_content:
thinking_parts.append(thinking)
+ elif block_type == "redacted_thinking":
+ # Opaque provider continuation data; do not materialize as model-visible text
+ # or reasoning_content for OpenAI chat upstreams.
+ continue
elif block_type == "tool_use":
+ seen_tool_use = True
tool_input = get_block_attr(block, "input", {})
tool_calls.append(
{
@@ -91,6 +154,19 @@ class AnthropicToOpenAIConverter:
},
}
)
+ elif block_type == "image":
+ raise OpenAIConversionError(
+ "Assistant image blocks are not supported for OpenAI chat conversion."
+ )
+ elif block_type in (
+ "server_tool_use",
+ "web_search_tool_result",
+ "web_fetch_tool_result",
+ ):
+ raise OpenAIConversionError(
+ "OpenAI chat conversion does not support Anthropic server tool blocks "
+ f"({block_type!r} in an assistant message). Use a native Anthropic transport provider."
+ )
content_str = "\n\n".join(content_parts)
if not content_str and not tool_calls:
@@ -122,21 +198,21 @@ class AnthropicToOpenAIConverter:
if block_type == "text":
text_parts.append(get_block_attr(block, "text", ""))
+ elif block_type == "image":
+ raise OpenAIConversionError(
+ "User message image blocks are not supported for OpenAI chat "
+ "conversion; use a vision-capable native Anthropic provider or "
+ "extend the converter."
+ )
elif block_type == "tool_result":
flush_text()
tool_content = get_block_attr(block, "content", "")
- if isinstance(tool_content, list):
- tool_content = "\n".join(
- item.get("text", str(item))
- if isinstance(item, dict)
- else str(item)
- for item in tool_content
- )
+ serialized = _serialize_tool_result_content(tool_content)
result.append(
{
"role": "tool",
"tool_call_id": get_block_attr(block, "tool_use_id"),
- "content": str(tool_content) if tool_content else "",
+ "content": serialized if serialized else "",
}
)
@@ -199,6 +275,7 @@ def build_base_request_body(
include_reasoning_content: bool = False,
) -> dict[str, Any]:
"""Build the common parts of an OpenAI-format request body."""
+ _openai_reject_native_only_top_level_fields(request_data)
messages = AnthropicToOpenAIConverter.convert_messages(
request_data.messages,
include_thinking=include_thinking,
diff --git a/core/anthropic/emitted_sse_tracker.py b/core/anthropic/emitted_sse_tracker.py
new file mode 100644
index 0000000..079dd51
--- /dev/null
+++ b/core/anthropic/emitted_sse_tracker.py
@@ -0,0 +1,97 @@
+"""Track content-block state for native Anthropic SSE strings we emit to clients."""
+
+from __future__ import annotations
+
+import uuid
+from collections.abc import Iterator
+from contextlib import suppress
+from typing import Any
+
+from core.anthropic.sse import SSEBuilder, format_sse_event
+from core.anthropic.stream_contracts import SSEEvent, event_index, parse_sse_lines
+
+
+class EmittedNativeSseTracker:
+ """Parse emitted SSE frames so mid-stream errors can close blocks and pick a fresh index."""
+
+ def __init__(self) -> None:
+ self._buf = ""
+ self._open_stack: list[int] = []
+ self._max_index = -1
+ self.message_id: str | None = None
+ self.model: str = ""
+
+ def feed(self, chunk: str) -> None:
+ """Record SSE frames completed by ``chunk`` (handles splitting across reads)."""
+ self._buf += chunk
+ while True:
+ sep = self._buf.find("\n\n")
+ if sep < 0:
+ break
+ frame = self._buf[:sep]
+ self._buf = self._buf[sep + 2 :]
+ if not frame.strip():
+ continue
+ for event in parse_sse_lines(frame.splitlines()):
+ self._observe(event)
+
+ def _observe(self, event: SSEEvent) -> None:
+ if event.event == "message_start":
+ message = event.data.get("message")
+ if isinstance(message, dict):
+ mid = message.get("id")
+ if isinstance(mid, str) and mid:
+ self.message_id = mid
+ model = message.get("model")
+ if isinstance(model, str) and model:
+ self.model = model
+ return
+
+ if event.event == "content_block_start":
+ idx = event_index(event)
+ self._max_index = max(self._max_index, idx)
+ self._open_stack.append(idx)
+ return
+
+ if event.event == "content_block_stop":
+ idx = event_index(event)
+ if self._open_stack and self._open_stack[-1] == idx:
+ self._open_stack.pop()
+ else:
+ with suppress(ValueError):
+ self._open_stack.remove(idx)
+
+ def next_content_index(self) -> int:
+ """Next unused content block index based on emitted starts."""
+ return self._max_index + 1
+
+ def iter_close_unclosed_blocks(self) -> Iterator[str]:
+ """Yield ``content_block_stop`` events for blocks that were started but not stopped."""
+ while self._open_stack:
+ idx = self._open_stack.pop()
+ yield format_sse_event(
+ "content_block_stop",
+ {"type": "content_block_stop", "index": idx},
+ )
+
+ def iter_midstream_error_tail(
+ self,
+ error_message: str,
+ *,
+ request: Any,
+ input_tokens: int,
+ log_raw_sse_events: bool,
+ ) -> Iterator[str]:
+ """Close dangling blocks, emit a text error block at a fresh index, then message tail."""
+ mid = self.message_id or f"msg_{uuid.uuid4()}"
+ model = self.model or (getattr(request, "model", "") or "")
+ sse = SSEBuilder(
+ mid,
+ model,
+ input_tokens,
+ log_raw_events=log_raw_sse_events,
+ )
+ sse.blocks.next_index = self.next_content_index()
+ yield from sse.emit_error(error_message)
+ yield sse.message_delta("end_turn", 1)
+ yield sse.message_stop()
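+
+
+# Usage sketch (assumes ``parse_sse_lines`` accepts these raw lines, exactly as
+# ``_observe`` above relies on):
+#
+#     tracker = EmittedNativeSseTracker()
+#     tracker.feed(
+#         'event: content_block_start\n'
+#         'data: {"type": "content_block_start", "index": 0,'
+#         ' "content_block": {"type": "text", "text": ""}}\n\n'
+#     )
+#     list(tracker.iter_close_unclosed_blocks())  # closes index 0
+#     tracker.next_content_index()  # -> 1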
diff --git a/core/anthropic/errors.py b/core/anthropic/errors.py
index 38e3d12..afe093d 100644
--- a/core/anthropic/errors.py
+++ b/core/anthropic/errors.py
@@ -9,11 +9,12 @@ def get_user_facing_error_message(
*,
read_timeout_s: float | None = None,
) -> str:
- """Return a readable, non-empty error message for users."""
- message = str(e).strip()
- if message:
- return message
+ """Return a readable, non-empty error message for users.
+ Known transport and OpenAI SDK exception types are mapped to stable wording
+ before falling back to ``str(e)``, so empty or noisy SDK messages do not skip
+ the mapped path.
+ """
if isinstance(e, httpx.ReadTimeout):
if read_timeout_s is not None:
return f"Provider request timed out after {read_timeout_s:g}s."
@@ -25,13 +26,20 @@ def get_user_facing_error_message(
return f"Provider request timed out after {read_timeout_s:g}s."
return "Request timed out."
+ if isinstance(e, openai.RateLimitError):
+ return "Provider rate limit reached. Please retry shortly."
+ if isinstance(e, openai.AuthenticationError):
+ return "Provider authentication failed. Check API key."
+ if isinstance(e, openai.BadRequestError):
+ return "Invalid request sent to provider."
+
name = type(e).__name__
status_code = getattr(e, "status_code", None)
- if isinstance(e, openai.RateLimitError) or name == "RateLimitError":
+ if name == "RateLimitError":
return "Provider rate limit reached. Please retry shortly."
- if isinstance(e, openai.AuthenticationError) or name == "AuthenticationError":
+ if name == "AuthenticationError":
return "Provider authentication failed. Check API key."
- if isinstance(e, openai.BadRequestError) or name == "InvalidRequestError":
+ if name == "InvalidRequestError":
return "Invalid request sent to provider."
if name == "OverloadedError":
return "Provider is currently overloaded. Please retry."
@@ -42,9 +50,18 @@ def get_user_facing_error_message(
if name.endswith("ProviderError") or name == "ProviderError":
return "Provider request failed."
+ message = str(e).strip()
+ if message:
+ return message
+
return "Provider request failed unexpectedly."
+def format_user_error_preview(exc: Exception, *, max_len: int = 200) -> str:
+ """Truncate a user-facing error string for short chat replies."""
+ return get_user_facing_error_message(exc)[:max_len]
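+
+# Example (sketch): an unmapped exception falls through to its own message.
+#
+#     format_user_error_preview(RuntimeError("boom"))  # -> "boom"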
+
+
def append_request_id(message: str, request_id: str | None) -> str:
"""Append request_id suffix when available."""
base = message.strip() or "Provider request failed unexpectedly."
diff --git a/core/anthropic/native_messages_request.py b/core/anthropic/native_messages_request.py
new file mode 100644
index 0000000..ed24c6c
--- /dev/null
+++ b/core/anthropic/native_messages_request.py
@@ -0,0 +1,260 @@
+"""Native Anthropic Messages request body construction (JSON-ready dicts).
+
+Provider adapters supply policy via parameters (defaults, OpenRouter post-steps).
+"""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+from typing import Any
+
+from pydantic import BaseModel
+
+_REQUEST_FIELDS = (
+ "model",
+ "messages",
+ "system",
+ "max_tokens",
+ "stop_sequences",
+ "stream",
+ "temperature",
+ "top_p",
+ "top_k",
+ "metadata",
+ "tools",
+ "tool_choice",
+ "thinking",
+ "context_management",
+ "output_config",
+ "mcp_servers",
+ "extra_body",
+)
+
+# Keys that would override routed canonical request fields if merged from ``extra_body``.
+_OPENROUTER_EXTRA_BODY_FORBIDDEN_KEYS = frozenset(
+ {
+ "model",
+ "messages",
+ "system",
+ "tools",
+ "tool_choice",
+ "stream",
+ "max_tokens",
+ "temperature",
+ "top_p",
+ "top_k",
+ "metadata",
+ "stop_sequences",
+ "context_management",
+ "output_config",
+ "mcp_servers",
+ }
+)
+
+
+class OpenRouterExtraBodyError(ValueError):
+ """``extra_body`` contained reserved keys that would override canonical fields."""
+
+
+def validate_openrouter_extra_body(extra: Any) -> None:
+ """Reject ``extra_body`` keys that must not override routed request fields."""
+ if not isinstance(extra, dict) or not extra:
+ return
+ bad = _OPENROUTER_EXTRA_BODY_FORBIDDEN_KEYS & extra.keys()
+ if bad:
+ raise OpenRouterExtraBodyError(
+ f"extra_body must not override canonical request fields: {sorted(bad)}"
+ )
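+
+# Example (sketch):
+#
+#     validate_openrouter_extra_body({"provider": {"order": ["anthropic"]}})  # ok
+#     validate_openrouter_extra_body({"model": "x"})  # raises OpenRouterExtraBodyError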
+
+
+_INTERNAL_FIELDS = {
+ "thinking",
+ "extra_body",
+}
+
+
+def _serialize_value(value: Any) -> Any:
+ """Convert Pydantic models and lightweight objects into JSON-ready values."""
+ if isinstance(value, BaseModel):
+ return value.model_dump(exclude_none=True)
+ if isinstance(value, dict):
+ return {
+ key: _serialize_value(item)
+ for key, item in value.items()
+ if item is not None
+ }
+ if isinstance(value, Sequence) and not isinstance(value, str | bytes | bytearray):
+ return [_serialize_value(item) for item in value]
+ if value is None or isinstance(value, str | int | float | bool):
+ return value
+ if hasattr(value, "__dict__"):
+ return {
+ key: _serialize_value(item)
+ for key, item in vars(value).items()
+ if not key.startswith("_") and item is not None
+ }
+ return value
+
+
+def _dump_request_fields(request_data: Any) -> dict[str, Any]:
+ """Extract the public request fields (OpenRouter-style explicit field list)."""
+ if isinstance(request_data, BaseModel):
+ return request_data.model_dump(exclude_none=True)
+
+ dumped: dict[str, Any] = {}
+ for field in _REQUEST_FIELDS:
+ value = getattr(request_data, field, None)
+ if value is not None:
+ dumped[field] = _serialize_value(value)
+ return dumped
+
+
+def sanitize_native_messages_thinking_policy(
+ messages: Any, *, thinking_enabled: bool
+) -> Any:
+ """Filter assistant message thinking blocks for upstream native Anthropic JSON.
+
+ When ``thinking_enabled`` is false, remove ``thinking`` and ``redacted_thinking``
+ history so disabled policy is not undermined by prior turns.
+
+ When true, keep ``redacted_thinking`` and signed ``thinking``; remove only
+ unsigned plain ``thinking`` blocks (not replayable).
+ """
+ if not isinstance(messages, list):
+ return messages
+
+ sanitized_messages: list[Any] = []
+ for message in messages:
+ if not isinstance(message, dict):
+ sanitized_messages.append(message)
+ continue
+
+ if message.get("role") != "assistant":
+ sanitized_messages.append(message)
+ continue
+
+ content = message.get("content")
+ if not isinstance(content, list):
+ sanitized_messages.append(message)
+ continue
+
+ if not thinking_enabled:
+ sanitized_content = [
+ block
+ for block in content
+ if not (
+ isinstance(block, dict)
+ and block.get("type") in ("thinking", "redacted_thinking")
+ )
+ ]
+ else:
+ sanitized_content = [
+ block
+ for block in content
+ if not (
+ isinstance(block, dict)
+ and block.get("type") == "thinking"
+ and not isinstance(block.get("signature"), str)
+ )
+ ]
+
+ sanitized_message = dict(message)
+ sanitized_message["content"] = sanitized_content or ""
+ sanitized_messages.append(sanitized_message)
+
+ return sanitized_messages
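+
+# Example (sketch): with thinking enabled, unsigned plain thinking blocks are
+# still dropped (they cannot be replayed upstream).
+#
+#     sanitize_native_messages_thinking_policy(
+#         [{"role": "assistant", "content": [
+#             {"type": "thinking", "thinking": "t"},
+#             {"type": "text", "text": "hi"},
+#         ]}],
+#         thinking_enabled=True,
+#     )
+#     # -> [{"role": "assistant", "content": [{"type": "text", "text": "hi"}]}]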
+
+
+def _normalize_system_prompt_for_openrouter(system: Any) -> Any:
+ """Flatten Claude SDK system blocks for OpenRouter's native endpoint."""
+ if not isinstance(system, list):
+ return system
+
+ text_parts: list[str] = []
+ for block in system:
+ if not isinstance(block, dict):
+ continue
+ if block.get("type") == "text" and isinstance(block.get("text"), str):
+ text_parts.append(block["text"])
+ return "\n\n".join(text_parts).strip() if text_parts else system
+
+
+def _apply_openrouter_reasoning_policy(body: dict[str, Any], thinking_cfg: Any) -> None:
+ """Map Anthropic thinking controls onto OpenRouter reasoning controls."""
+ reasoning = body.setdefault("reasoning", {"enabled": True})
+ if not isinstance(reasoning, dict):
+ return
+ reasoning.setdefault("enabled", True)
+ if not isinstance(thinking_cfg, dict):
+ return
+ budget_tokens = thinking_cfg.get("budget_tokens")
+ if isinstance(budget_tokens, int):
+ reasoning.setdefault("max_tokens", budget_tokens)
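+
+# Example (sketch): Anthropic budget_tokens maps onto reasoning.max_tokens.
+#
+#     body: dict = {}
+#     _apply_openrouter_reasoning_policy(body, {"budget_tokens": 2048})
+#     # body -> {"reasoning": {"enabled": True, "max_tokens": 2048}}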
+
+
+def build_base_native_anthropic_request_body(
+ request: Any,
+ *,
+ default_max_tokens: int,
+ thinking_enabled: bool,
+) -> dict[str, Any]:
+ """Serialize a Pydantic messages request to a generic native Anthropic body."""
+ body = request.model_dump(exclude_none=True)
+
+ body.pop("extra_body", None)
+
+ if "thinking" in body:
+ thinking_cfg = body.pop("thinking")
+ if thinking_enabled and isinstance(thinking_cfg, dict):
+ thinking_payload: dict[str, Any] = {"type": "enabled"}
+ budget_tokens = thinking_cfg.get("budget_tokens")
+ if isinstance(budget_tokens, int):
+ thinking_payload["budget_tokens"] = budget_tokens
+ body["thinking"] = thinking_payload
+
+ if "max_tokens" not in body:
+ body["max_tokens"] = default_max_tokens
+
+ if "messages" in body:
+ body["messages"] = sanitize_native_messages_thinking_policy(
+ body["messages"],
+ thinking_enabled=thinking_enabled,
+ )
+
+ return body
+
+
+def build_openrouter_native_request_body(
+ request_data: Any,
+ *,
+ thinking_enabled: bool,
+ default_max_tokens: int,
+) -> dict[str, Any]:
+ """Build an Anthropic-format request body for OpenRouter (policy hooks built-in)."""
+ dumped_request = _dump_request_fields(request_data)
+ request_extra = dumped_request.pop("extra_body", None)
+ thinking_cfg = dumped_request.get("thinking")
+ body: dict[str, Any] = {
+ key: value
+ for key, value in dumped_request.items()
+ if key not in _INTERNAL_FIELDS
+ }
+
+ if isinstance(request_extra, dict):
+ validate_openrouter_extra_body(request_extra)
+ body.update(request_extra)
+
+ body["messages"] = sanitize_native_messages_thinking_policy(
+ body.get("messages"),
+ thinking_enabled=thinking_enabled,
+ )
+ if "system" in body:
+ body["system"] = _normalize_system_prompt_for_openrouter(body["system"])
+ body["stream"] = True
+ if body.get("max_tokens") is None:
+ body["max_tokens"] = default_max_tokens
+
+ if thinking_enabled:
+ _apply_openrouter_reasoning_policy(body, thinking_cfg)
+
+ return body
diff --git a/core/anthropic/native_sse_block_policy.py b/core/anthropic/native_sse_block_policy.py
new file mode 100644
index 0000000..7a3fd81
--- /dev/null
+++ b/core/anthropic/native_sse_block_policy.py
@@ -0,0 +1,313 @@
+"""Shared native Anthropic SSE thinking policy, block remapping, and overlap repair.
+
+Used by :class:`OpenRouterProvider` and line-mode
+:class:`providers.anthropic_messages.AnthropicMessagesTransport` providers.
+"""
+
+from __future__ import annotations
+
+import copy
+import json
+from dataclasses import dataclass, field
+from typing import Any
+
+__all__ = [
+ "NativeSseBlockPolicyState",
+ "format_native_sse_event",
+ "is_terminal_openrouter_done_event",
+ "parse_native_sse_event",
+ "transform_native_sse_block_event",
+]
+
+
+@dataclass
+class _UpstreamBlockState:
+ """Per-upstream content block: segment index and liveness in the model stream."""
+
+ block_type: str
+ down_index: int
+ open: bool
+ last_start_block: dict[str, Any] | None = None
+
+
+@dataclass
+class NativeSseBlockPolicyState:
+ """Track per-upstream content blocks and remapped Anthropic ``index`` field."""
+
+ next_index: int = 0
+ by_upstream: dict[int, _UpstreamBlockState] = field(default_factory=dict)
+ dropped_indexes: set[int] = field(default_factory=set)
+ pending_suppressed_stops: set[int] = field(default_factory=set)
+ message_stopped: bool = False
+
+
+def format_native_sse_event(event_name: str | None, data_text: str) -> str:
+ """Format an SSE event from its event name and data payload."""
+ lines: list[str] = []
+ if event_name:
+ lines.append(f"event: {event_name}")
+ lines.extend(f"data: {line}" for line in data_text.splitlines())
+ return "\n".join(lines) + "\n\n"
+
+
+def parse_native_sse_event(event: str) -> tuple[str | None, str]:
+ """Extract the event name and raw data payload from an SSE event."""
+ event_name = None
+ data_lines: list[str] = []
+ for line in event.strip().splitlines():
+ if line.startswith("event:"):
+ event_name = line[6:].strip()
+ elif line.startswith("data:"):
+ data_lines.append(line[5:].lstrip())
+ return event_name, "\n".join(data_lines)
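+
+# Round-trip example (sketch):
+#
+#     parse_native_sse_event('event: ping\ndata: {"type": "ping"}\n\n')
+#     # -> ("ping", '{"type": "ping"}')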
+
+
+def is_terminal_openrouter_done_event(event_name: str | None, data_text: str) -> bool:
+ """Return whether an event is OpenAI-style terminal noise (``[DONE]``)."""
+ return (event_name is None or event_name in {"data", "done"}) and (
+ data_text.strip().upper() == "[DONE]"
+ )
+
+
+def _delta_type_to_block_kind(delta_type: Any) -> str | None:
+ """Map a content_block_delta type to a content block kind (text/thinking/tool_use)."""
+ if not isinstance(delta_type, str):
+ return None
+ if delta_type in {"thinking_delta", "signature_delta"}:
+ return "thinking"
+ if delta_type == "text_delta":
+ return "text"
+ if delta_type == "input_json_delta":
+ return "tool_use"
+ return None
+
+
+def _synthetic_start_content_block(
+ block_kind: str,
+ *,
+ upstream_index: int,
+ stored_tool_block: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+ """Build a `content_block` for a `content_block_start` with empty streaming fields."""
+ if block_kind == "tool_use":
+ if (
+ isinstance(stored_tool_block, dict)
+ and stored_tool_block.get("type") == "tool_use"
+ ):
+ tool_id = stored_tool_block.get("id")
+ name = stored_tool_block.get("name")
+ inp = stored_tool_block.get("input")
+ return {
+ "type": "tool_use",
+ "id": tool_id
+ if isinstance(tool_id, str) and tool_id
+ else f"toolu_or_{upstream_index}",
+ "name": name if isinstance(name, str) else "",
+ "input": inp if isinstance(inp, dict) else {},
+ }
+ return {
+ "type": "tool_use",
+ "id": f"toolu_or_{upstream_index}",
+ "name": "",
+ "input": {},
+ }
+ if block_kind == "thinking":
+ return {"type": "thinking", "thinking": ""}
+ if block_kind == "text":
+ return {"type": "text", "text": ""}
+ return {"type": "text", "text": ""}
+
+
+def _should_drop_block_type(block_type: Any, *, thinking_enabled: bool) -> bool:
+ if not isinstance(block_type, str):
+ return False
+ if block_type.startswith("redacted_thinking"):
+ return not thinking_enabled
+ return not thinking_enabled and "thinking" in block_type
+
+
+def _synthetic_close_other_open_blocks(
+ state: NativeSseBlockPolicyState, current_upstream: int
+) -> str:
+ """Close every open block except `current_upstream` and track duplicate upstream stops."""
+ out: list[str] = []
+ for upstream, seg in list(state.by_upstream.items()):
+ if upstream == current_upstream or not seg.open:
+ continue
+ out.append(
+ format_native_sse_event(
+ "content_block_stop",
+ json.dumps(
+ {
+ "type": "content_block_stop",
+ "index": seg.down_index,
+ }
+ ),
+ )
+ )
+ seg.open = False
+ state.pending_suppressed_stops.add(upstream)
+ return "".join(out)
+
+
+def _allocate_new_segment(
+ state: NativeSseBlockPolicyState,
+ upstream_index: int,
+ block_type: str,
+ *,
+ last_start_block: dict[str, Any] | None = None,
+) -> int:
+ """Assign a new downstream `index` for a segment and record upstream state."""
+ new_idx = state.next_index
+ state.next_index += 1
+ state.by_upstream[upstream_index] = _UpstreamBlockState(
+ block_type=block_type,
+ down_index=new_idx,
+ open=True,
+ last_start_block=last_start_block,
+ )
+ return new_idx
+
+
+def transform_native_sse_block_event(
+ event: str,
+ state: NativeSseBlockPolicyState,
+ *,
+ thinking_enabled: bool,
+) -> str | None:
+ """Normalize native Anthropic SSE events and enforce local thinking policy."""
+ event_name, data_text = parse_native_sse_event(event)
+ if not event_name or not data_text:
+ return event
+
+ try:
+ payload = json.loads(data_text)
+ except json.JSONDecodeError:
+ return event
+
+ if event_name == "content_block_start":
+ block = payload.get("content_block")
+ if not isinstance(block, dict):
+ return event
+ block_type = block.get("type")
+ upstream_index = payload.get("index")
+ if not isinstance(upstream_index, int):
+ return event
+ if _should_drop_block_type(block_type, thinking_enabled=thinking_enabled):
+ state.dropped_indexes.add(upstream_index)
+ return None
+
+ if not isinstance(block_type, str):
+ return event
+ prefix = _synthetic_close_other_open_blocks(state, upstream_index)
+ stored = copy.deepcopy(block)
+ new_idx = _allocate_new_segment(
+ state,
+ upstream_index,
+ block_type=block_type,
+ last_start_block=stored,
+ )
+ payload["index"] = new_idx
+ return prefix + format_native_sse_event(event_name, json.dumps(payload))
+
+ if event_name == "content_block_delta":
+ delta = payload.get("delta")
+ if not isinstance(delta, dict):
+ return event
+ delta_type = delta.get("type")
+ upstream_index = payload.get("index")
+ if not isinstance(upstream_index, int):
+ return event
+ if upstream_index in state.dropped_indexes:
+ return None
+ if _should_drop_block_type(delta_type, thinking_enabled=thinking_enabled):
+ return None
+
+ block_kind = _delta_type_to_block_kind(delta_type)
+ if block_kind is None:
+ return event
+
+ seg = state.by_upstream.get(upstream_index)
+ if seg and seg.open:
+ payload["index"] = seg.down_index
+ return format_native_sse_event(event_name, json.dumps(payload))
+
+ if seg is not None and not seg.open:
+ # More deltas for an upstream block after a synthetic (or other) close:
+ # reopen with a new downstream `index` and emit a synthetic `content_block_start` first.
+ state.pending_suppressed_stops.discard(upstream_index)
+ carry = seg.last_start_block
+ new_idx = _allocate_new_segment(
+ state,
+ upstream_index,
+ block_type=block_kind,
+ last_start_block=carry,
+ )
+ stored_tool = (
+ carry
+ if isinstance(carry, dict) and carry.get("type") == "tool_use"
+ else None
+ )
+ start_payload = {
+ "type": "content_block_start",
+ "index": new_idx,
+ "content_block": _synthetic_start_content_block(
+ block_kind,
+ upstream_index=upstream_index,
+ stored_tool_block=stored_tool,
+ ),
+ }
+ prefix = format_native_sse_event(
+ "content_block_start", json.dumps(start_payload)
+ )
+ payload["index"] = new_idx
+ return prefix + format_native_sse_event(event_name, json.dumps(payload))
+
+ # Delta with no prior `content_block_start` in this stream
+ if block_kind in ("text", "tool_use"):
+ synthetic_block = _synthetic_start_content_block(
+ block_kind,
+ upstream_index=upstream_index,
+ )
+ new_idx = _allocate_new_segment(
+ state,
+ upstream_index,
+ block_type=block_kind,
+ last_start_block=copy.deepcopy(synthetic_block),
+ )
+ start_payload = {
+ "type": "content_block_start",
+ "index": new_idx,
+ "content_block": synthetic_block,
+ }
+ prefix = format_native_sse_event(
+ "content_block_start", json.dumps(start_payload)
+ )
+ payload["index"] = new_idx
+ return prefix + format_native_sse_event(event_name, json.dumps(payload))
+ # thinking: pass through raw (unusual upstream shape)
+ return event
+
+ if event_name == "content_block_stop":
+ upstream_index = payload.get("index")
+ if not isinstance(upstream_index, int):
+ return event
+ if upstream_index in state.dropped_indexes:
+ return None
+ if upstream_index in state.pending_suppressed_stops:
+ state.pending_suppressed_stops.discard(upstream_index)
+ return None
+
+ seg = state.by_upstream.get(upstream_index)
+ if seg is not None and seg.open:
+ payload["index"] = seg.down_index
+ seg.open = False
+ return format_native_sse_event(event_name, json.dumps(payload))
+ if seg is not None:
+ # Spurious or duplicate `content_block_stop` for a closed block.
+ return None
+ if not thinking_enabled:
+ return None
+ return event
+
+ return event
diff --git a/core/anthropic/provider_stream_error.py b/core/anthropic/provider_stream_error.py
new file mode 100644
index 0000000..5534c50
--- /dev/null
+++ b/core/anthropic/provider_stream_error.py
@@ -0,0 +1,34 @@
+"""Canonical Anthropic-style SSE sequence for provider-side streaming errors."""
+
+from __future__ import annotations
+
+import uuid
+from collections.abc import Iterator
+from typing import Any
+
+from core.anthropic.sse import SSEBuilder
+
+
+def iter_provider_stream_error_sse_events(
+ *,
+ request: Any,
+ input_tokens: int,
+ error_message: str,
+ sent_any_event: bool,
+ log_raw_sse_events: bool,
+ message_id: str | None = None,
+) -> Iterator[str]:
+ """Yield message_start (if needed), a text block with the error, then message_delta/stop."""
+ mid = message_id or f"msg_{uuid.uuid4()}"
+ model = getattr(request, "model", "") or ""
+ sse = SSEBuilder(
+ mid,
+ model,
+ input_tokens,
+ log_raw_events=log_raw_sse_events,
+ )
+ if not sent_any_event:
+ yield sse.message_start()
+ yield from sse.emit_error(error_message)
+ yield sse.message_delta("end_turn", 1)
+ yield sse.message_stop()
diff --git a/core/anthropic/server_tool_sse.py b/core/anthropic/server_tool_sse.py
new file mode 100644
index 0000000..1c8727d
--- /dev/null
+++ b/core/anthropic/server_tool_sse.py
@@ -0,0 +1,14 @@
+"""SSE content_block ``type`` values for Anthropic web server tools (local handlers).
+
+Shared by :mod:`api.web_tools` and stream contract tests to avoid drift.
+"""
+
+from __future__ import annotations
+
+from typing import Final
+
+SERVER_TOOL_USE: Final = "server_tool_use"
+WEB_SEARCH_TOOL_RESULT: Final = "web_search_tool_result"
+WEB_FETCH_TOOL_RESULT: Final = "web_fetch_tool_result"
+WEB_SEARCH_TOOL_RESULT_ERROR: Final = "web_search_tool_result_error"
+WEB_FETCH_TOOL_ERROR: Final = "web_fetch_tool_error"
diff --git a/core/anthropic/sse.py b/core/anthropic/sse.py
index 7998839..3bf5ebe 100644
--- a/core/anthropic/sse.py
+++ b/core/anthropic/sse.py
@@ -1,5 +1,6 @@
"""SSE event builder for Anthropic-format streaming responses."""
+import hashlib
import json
from collections.abc import Iterator
from dataclasses import dataclass, field
@@ -14,6 +15,13 @@ except Exception:
ENCODER = None
+# Standard headers for Anthropic-style ``text/event-stream`` responses from this proxy.
+ANTHROPIC_SSE_RESPONSE_HEADERS: dict[str, str] = {
+ "X-Accel-Buffering": "no",
+ "Cache-Control": "no-cache",
+ "Connection": "keep-alive",
+}
+
STOP_REASON_MAP = {
"stop": "end_turn",
"length": "max_tokens",
@@ -29,6 +37,11 @@ def map_stop_reason(openai_reason: str | None) -> str:
)
+def format_sse_event(event_type: str, data: dict) -> str:
+ """Format one Anthropic-style SSE event (no logging)."""
+ return f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
+
+
@dataclass
class ToolCallState:
"""State for a single streaming tool call."""
@@ -40,6 +53,7 @@ class ToolCallState:
started: bool = False
task_arg_buffer: str = ""
task_args_emitted: bool = False
+ pre_start_args: str = ""
@dataclass
@@ -58,7 +72,25 @@ class ContentBlockManager:
self.next_index += 1
return idx
+ def ensure_tool_state(self, index: int) -> ToolCallState:
+ """Create tool stream state for ``index`` when the first tool delta arrives."""
+ if index not in self.tool_states:
+ self.tool_states[index] = ToolCallState(block_index=-1, tool_id="", name="")
+ return self.tool_states[index]
+
+ def set_stream_tool_id(self, index: int, tool_id: str | None) -> None:
+ """Record OpenAI tool call id before ``content_block_start`` (split-stream providers)."""
+ if not tool_id:
+ return
+ state = self.ensure_tool_state(index)
+ state.tool_id = str(tool_id)
+
def register_tool_name(self, index: int, name: str) -> None:
+ """Record tool name fragments as they arrive from chunked OpenAI streams.
+
+ Names may be split across deltas; later chunks can extend (``ab`` + ``c``)
+ or repeat prefixes, so we merge conservatively.
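+
+        For example, ``ab`` followed by ``c`` merges to ``abc``, while ``ab``
+        followed by ``abc`` is treated as a repeated prefix and stays ``abc``.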
+ """
if index not in self.tool_states:
self.tool_states[index] = ToolCallState(
block_index=-1, tool_id="", name=name
@@ -82,8 +114,7 @@ class ContentBlockManager:
except Exception:
return None
- if args_json.get("run_in_background") is not False:
- args_json["run_in_background"] = False
+ _normalize_task_run_in_background(args_json)
state.task_args_emitted = True
state.task_arg_buffer = ""
@@ -98,16 +129,17 @@ class ContentBlockManager:
out = "{}"
try:
args_json = json.loads(state.task_arg_buffer)
- if args_json.get("run_in_background") is not False:
- args_json["run_in_background"] = False
+ _normalize_task_run_in_background(args_json)
out = json.dumps(args_json)
- except Exception as e:
- prefix = state.task_arg_buffer[:120]
+ except (json.JSONDecodeError, TypeError, ValueError) as e:
+ digest = hashlib.sha256(
+ state.task_arg_buffer.encode("utf-8", errors="replace")
+ ).hexdigest()[:16]
logger.warning(
- "Task args invalid JSON (id={} len={} prefix={!r}): {}",
+ "Task args invalid JSON (id={} len={} buffer_sha256_prefix={}): {}",
state.tool_id or "unknown",
len(state.task_arg_buffer),
- prefix,
+ digest,
e,
)
@@ -117,20 +149,41 @@ class ContentBlockManager:
return results
+def _normalize_task_run_in_background(args_json: dict) -> None:
+ """Force Claude Code Task subagents to run in foreground (single shared rule)."""
+ if args_json.get("run_in_background") is not False:
+ args_json["run_in_background"] = False
+
+
class SSEBuilder:
"""Builder for Anthropic SSE streaming events."""
- def __init__(self, message_id: str, model: str, input_tokens: int = 0):
+ def __init__(
+ self,
+ message_id: str,
+ model: str,
+ input_tokens: int = 0,
+ *,
+ log_raw_events: bool = False,
+ ):
self.message_id = message_id
self.model = model
self.input_tokens = input_tokens
+ self._log_raw_events = log_raw_events
self.blocks = ContentBlockManager()
self._accumulated_text_parts: list[str] = []
self._accumulated_reasoning_parts: list[str] = []
def _format_event(self, event_type: str, data: dict) -> str:
- event_str = f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
- logger.debug("SSE_EVENT: {} - {}", event_type, event_str.strip())
+ event_str = format_sse_event(event_type, data)
+ if self._log_raw_events:
+ logger.debug("SSE_EVENT: {} - {}", event_type, event_str.strip())
+ else:
+ logger.debug(
+ "SSE_EVENT: event_type={} serialized_bytes={}",
+ event_type,
+ len(event_str.encode("utf-8")),
+ )
return event_str
def message_start(self) -> str:
@@ -289,10 +342,7 @@ class SSEBuilder:
yield self.stop_text_block()
def close_all_blocks(self) -> Iterator[str]:
- if self.blocks.thinking_started:
- yield self.stop_thinking_block()
- if self.blocks.text_started:
- yield self.stop_text_block()
+ yield from self.close_content_blocks()
for tool_index, state in list(self.blocks.tool_states.items()):
if state.started:
yield self.stop_tool_block(tool_index)
diff --git a/core/anthropic/stream_contracts.py b/core/anthropic/stream_contracts.py
index 8058392..ba2b4d8 100644
--- a/core/anthropic/stream_contracts.py
+++ b/core/anthropic/stream_contracts.py
@@ -1,7 +1,6 @@
"""Neutral SSE parsing and Anthropic stream shape assertions.
-Used by default CI contract tests and re-exported from ``smoke.lib.sse`` for
-opt-in smoke scenarios.
+Used by default CI contract tests and by opt-in live smoke scenarios.
"""
from __future__ import annotations
@@ -11,6 +10,36 @@ from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any
+from .server_tool_sse import (
+ SERVER_TOOL_USE,
+ WEB_FETCH_TOOL_RESULT,
+ WEB_SEARCH_TOOL_RESULT,
+)
+
+# Content blocks that only use content_block_start/stop (no deltas):
+# Anthropic server tools, redacted thinking, and eager text emitted in a
+# single start event.
+_NO_DELTA_BLOCK_KINDS = frozenset(
+ {
+ SERVER_TOOL_USE,
+ WEB_SEARCH_TOOL_RESULT,
+ WEB_FETCH_TOOL_RESULT,
+ "text_eager",
+ "redacted_thinking",
+ }
+)
+
+_ALLOWED_BLOCK_START_TYPES = frozenset(
+ {
+ "text",
+ "thinking",
+ "tool_use",
+ "redacted_thinking",
+ SERVER_TOOL_USE,
+ WEB_SEARCH_TOOL_RESULT,
+ WEB_FETCH_TOOL_RESULT,
+ }
+)
+
@dataclass(frozen=True, slots=True)
class SSEEvent:
@@ -86,35 +115,46 @@ def assert_anthropic_stream_contract(
raise AssertionError(f"unexpected SSE error event: {event.data}")
if event.event == "content_block_start":
- index = _event_index(event)
+ index = event_index(event)
block = event.data.get("content_block", {})
assert isinstance(block, dict), event.data
block_type = str(block.get("type", ""))
- assert block_type in {"text", "thinking", "tool_use"}, event.data
+ assert block_type in _ALLOWED_BLOCK_START_TYPES, event.data
assert index not in open_blocks, f"block {index} started twice"
assert index not in seen_blocks, f"block {index} reused after stop"
- open_blocks[index] = block_type
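+            # A text block whose start already carries non-empty text is
+            # treated as "eager": the contract forbids later deltas for it.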
+ if block_type == "text" and str(block.get("text", "")).strip():
+ storage = "text_eager"
+ else:
+ storage = block_type
+ open_blocks[index] = storage
seen_blocks.add(index)
continue
if event.event == "content_block_delta":
- index = _event_index(event)
+ index = event_index(event)
assert index in open_blocks, f"delta for unopened block {index}"
+ kind = open_blocks[index]
+ assert kind not in _NO_DELTA_BLOCK_KINDS, (
+ f"unexpected delta for start/stop-only block {kind} at index {index}"
+ )
delta = event.data.get("delta", {})
assert isinstance(delta, dict), event.data
delta_type = str(delta.get("type", ""))
+ if kind == "thinking":
+ assert delta_type in (
+ "thinking_delta",
+ "signature_delta",
+ ), f"block {index} is {kind}, got {delta_type}"
+ continue
expected = {
"text": "text_delta",
- "thinking": "thinking_delta",
"tool_use": "input_json_delta",
- }[open_blocks[index]]
- assert delta_type == expected, (
- f"block {index} is {open_blocks[index]}, got {delta_type}"
- )
+ }[kind]
+ assert delta_type == expected, f"block {index} is {kind}, got {delta_type}"
continue
if event.event == "content_block_stop":
- index = _event_index(event)
+ index = event_index(event)
assert index in open_blocks, f"stop for unopened block {index}"
open_blocks.pop(index)
@@ -129,6 +169,12 @@ def event_names(events: list[SSEEvent]) -> list[str]:
def text_content(events: list[SSEEvent]) -> str:
parts: list[str] = []
for event in events:
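+        # Eager text rides on the start event itself (cf. "text_eager" in the
+        # contract assertions), so collect it here as well as from deltas.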
+ if event.event == "content_block_start":
+ block = event.data.get("content_block", {})
+ if isinstance(block, dict) and block.get("type") == "text":
+ eager = str(block.get("text", ""))
+ if eager:
+ parts.append(eager)
delta = event.data.get("delta", {})
if isinstance(delta, dict) and delta.get("type") == "text_delta":
parts.append(str(delta.get("text", "")))
@@ -152,7 +198,8 @@ def has_tool_use(events: list[SSEEvent]) -> bool:
return False
-def _event_index(event: SSEEvent) -> int:
+def event_index(event: SSEEvent) -> int:
+ """Return the content block ``index`` field from an SSE payload (strict)."""
value = event.data.get("index")
assert isinstance(value, int), event.data
return value
diff --git a/core/anthropic/tokens.py b/core/anthropic/tokens.py
index dfa1115..081d28b 100644
--- a/core/anthropic/tokens.py
+++ b/core/anthropic/tokens.py
@@ -68,6 +68,27 @@ def get_token_count(
total_tokens += len(ENCODER.encode(json.dumps(content)))
total_tokens += len(ENCODER.encode(str(tool_use_id)))
total_tokens += 8
+ elif b_type in (
+ "server_tool_use",
+ "web_search_tool_result",
+ "web_fetch_tool_result",
+ ):
+ if hasattr(block, "model_dump"):
+ blob: object = block.model_dump()
+ else:
+ blob = block
+ try:
+ total_tokens += len(
+ ENCODER.encode(
+ json.dumps(blob, default=str, ensure_ascii=False)
+ )
+ )
+ except (TypeError, ValueError, OverflowError) as e:
+ logger.debug(
+ "Block encode fallback b_type={} err={}", b_type, e
+ )
+ total_tokens += len(ENCODER.encode(str(blob)))
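+            # Approximate fixed per-block overhead, analogous to the constants
+            # used for the other block types above.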
+ total_tokens += 12
else:
logger.debug(
"Unexpected block type %r, falling back to json/str encoding",
diff --git a/core/rate_limit.py b/core/rate_limit.py
new file mode 100644
index 0000000..8d9bd81
--- /dev/null
+++ b/core/rate_limit.py
@@ -0,0 +1,60 @@
+"""Shared strict sliding-window rate limiting primitives."""
+
+from __future__ import annotations
+
+import asyncio
+import time
+from collections import deque
+
+
+class StrictSlidingWindowLimiter:
+ """Strict sliding window limiter.
+
+ Guarantees: at most ``rate_limit`` acquisitions in any interval of length
+ ``rate_window`` (seconds).
+
+ Implemented as an async context manager so call sites can do::
+
+ async with limiter:
+ ...
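+
+    An illustrative call site (the values and ``send_message`` are examples,
+    not defaults from this module)::
+
+        limiter = StrictSlidingWindowLimiter(rate_limit=5, rate_window=1.0)
+        async with limiter:
+            await send_message()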
+ """
+
+ def __init__(self, rate_limit: int, rate_window: float) -> None:
+ if rate_limit <= 0:
+ raise ValueError("rate_limit must be > 0")
+ if rate_window <= 0:
+ raise ValueError("rate_window must be > 0")
+
+ self._rate_limit = int(rate_limit)
+ self._rate_window = float(rate_window)
+ self._times: deque[float] = deque()
+ self._lock = asyncio.Lock()
+
+ async def acquire(self) -> None:
+ while True:
+ wait_time = 0.0
+ async with self._lock:
+ now = time.monotonic()
+ cutoff = now - self._rate_window
+
+ while self._times and self._times[0] <= cutoff:
+ self._times.popleft()
+
+ if len(self._times) < self._rate_limit:
+ self._times.append(now)
+ return
+
+ oldest = self._times[0]
+ wait_time = max(0.0, (oldest + self._rate_window) - now)
+
+ if wait_time > 0:
+ await asyncio.sleep(wait_time)
+ else:
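+                # Defensive: if the computed wait collapses to zero, still
+                # yield to the event loop before re-checking the window.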
+ await asyncio.sleep(0)
+
+ async def __aenter__(self) -> StrictSlidingWindowLimiter:
+ await self.acquire()
+ return self
+
+    async def __aexit__(self, exc_type, exc, tb) -> bool:
+        """Exit without suppressing exceptions from the ``async with`` body."""
+        return False
diff --git a/messaging/cli_event_constants.py b/messaging/cli_event_constants.py
new file mode 100644
index 0000000..3d531ec
--- /dev/null
+++ b/messaging/cli_event_constants.py
@@ -0,0 +1,67 @@
+"""CLI event types and status-line mapping for transcript / UI updates."""
+
+from collections.abc import Callable
+from typing import Any
+
+# Status message prefixes used to filter our own messages (ignore echo)
+STATUS_MESSAGE_PREFIXES = (
+ "⏳",
+ "💭",
+ "🔧",
+ "✅",
+ "❌",
+ "🚀",
+ "🤖",
+ "📋",
+ "📊",
+ "🔄",
+)
+
+# Event types that update the transcript (frozenset for O(1) membership)
+TRANSCRIPT_EVENT_TYPES = frozenset(
+ {
+ "thinking_start",
+ "thinking_delta",
+ "thinking_chunk",
+ "thinking_stop",
+ "text_start",
+ "text_delta",
+ "text_chunk",
+ "text_stop",
+ "tool_use_start",
+ "tool_use_delta",
+ "tool_use_stop",
+ "tool_use",
+ "tool_result",
+ "block_stop",
+ "error",
+ }
+)
+
+# Event type -> (emoji, label) for status updates (O(1) lookup)
+_EVENT_STATUS_MAP: dict[str, tuple[str, str]] = {
+ "thinking_start": ("🧠", "Claude is thinking..."),
+ "thinking_delta": ("🧠", "Claude is thinking..."),
+ "thinking_chunk": ("🧠", "Claude is thinking..."),
+ "text_start": ("🧠", "Claude is working..."),
+ "text_delta": ("🧠", "Claude is working..."),
+ "text_chunk": ("🧠", "Claude is working..."),
+ "tool_result": ("⏳", "Executing tools..."),
+}
+
+
+def get_status_for_event(
+ ptype: str,
+ parsed: dict[str, Any],
+ format_status_fn: Callable[..., str],
+) -> str | None:
+ """Return status string for event type, or None if no status update needed."""
+ entry = _EVENT_STATUS_MAP.get(ptype)
+ if entry is not None:
+ emoji, label = entry
+ return format_status_fn(emoji, label)
+ if ptype in ("tool_use_start", "tool_use_delta", "tool_use"):
+ if parsed.get("name") == "Task":
+ return format_status_fn("🤖", "Subagent working...")
+ return format_status_fn("⏳", "Executing tools...")
+ return None
diff --git a/messaging/event_parser.py b/messaging/event_parser.py
index 2499248..87eb82a 100644
--- a/messaging/event_parser.py
+++ b/messaging/event_parser.py
@@ -9,12 +9,14 @@ from typing import Any
from loguru import logger
-def parse_cli_event(event: Any) -> list[dict]:
+def parse_cli_event(event: Any, *, log_raw_cli: bool = False) -> list[dict]:
"""
Parse a CLI event and return a structured result.
Args:
event: Raw event dictionary from CLI
+ log_raw_cli: When True, log full error text from the CLI. Default is
+ metadata-only (lengths / exit codes) to avoid leaking user content.
Returns:
List of parsed event dicts. Empty list if not recognized.
@@ -140,7 +142,11 @@ def parse_cli_event(event: Any) -> list[dict]:
if etype == "error":
err = event.get("error")
msg = err.get("message") if isinstance(err, dict) else str(err)
- logger.info(f"CLI_PARSER: Parsed error event: {msg}")
+ if log_raw_cli:
+ logger.info("CLI_PARSER: Parsed error event: {}", msg)
+ else:
+ mlen = len(msg) if isinstance(msg, str) else 0
+ logger.info("CLI_PARSER: Parsed error event: message_chars={}", mlen)
return [{"type": "error", "message": msg}]
elif etype == "exit":
code = event.get("code", 0)
@@ -151,7 +157,19 @@ def parse_cli_event(event: Any) -> list[dict]:
else:
# Non-zero exit is an error
error_msg = stderr if stderr else f"Process exited with code {code}"
- logger.warning(f"CLI_PARSER: Error exit (code={code}): {error_msg}")
+ if log_raw_cli:
+ logger.warning(
+ "CLI_PARSER: Error exit (code={}): {}",
+ code,
+ error_msg,
+ )
+ else:
+ em = error_msg if isinstance(error_msg, str) else str(error_msg)
+ logger.warning(
+ "CLI_PARSER: Error exit (code={}): message_chars={}",
+ code,
+ len(em),
+ )
return [
{"type": "error", "message": error_msg},
{"type": "complete", "status": "failed"},
diff --git a/messaging/handler.py b/messaging/handler.py
index 949cafb..7702e11 100644
--- a/messaging/handler.py
+++ b/messaging/handler.py
@@ -7,13 +7,12 @@ Uses tree-based queuing for message ordering.
"""
import asyncio
-import os
-import time
from loguru import logger
-from core.anthropic import get_user_facing_error_message
+from core.anthropic import format_user_error_preview, get_user_facing_error_message
+from .cli_event_constants import STATUS_MESSAGE_PREFIXES
from .command_dispatcher import (
dispatch_command,
message_kind_for_command,
@@ -21,8 +20,10 @@ from .command_dispatcher import (
)
from .event_parser import parse_cli_event
from .models import IncomingMessage
+from .node_event_pipeline import handle_session_info_event, process_parsed_cli_event
from .platforms.base import MessagingPlatform, SessionManagerInterface
from .rendering.profiles import build_rendering_profile
+from .safe_diagnostics import format_exception_for_log
from .session import SessionStore
from .transcript import RenderCtx, TranscriptBuffer
from .trees.queue_manager import (
@@ -31,54 +32,7 @@ from .trees.queue_manager import (
MessageTree,
TreeQueueManager,
)
-
-# Status message prefixes used to filter our own messages (ignore echo)
-STATUS_MESSAGE_PREFIXES = ("⏳", "💭", "🔧", "✅", "❌", "🚀", "🤖", "📋", "📊", "🔄")
-
-# Event types that update the transcript (frozenset for O(1) membership)
-TRANSCRIPT_EVENT_TYPES = frozenset(
- {
- "thinking_start",
- "thinking_delta",
- "thinking_chunk",
- "thinking_stop",
- "text_start",
- "text_delta",
- "text_chunk",
- "text_stop",
- "tool_use_start",
- "tool_use_delta",
- "tool_use_stop",
- "tool_use",
- "tool_result",
- "block_stop",
- "error",
- }
-)
-
-# Event type -> (emoji, label) for status updates (O(1) lookup)
-_EVENT_STATUS_MAP = {
- "thinking_start": ("🧠", "Claude is thinking..."),
- "thinking_delta": ("🧠", "Claude is thinking..."),
- "thinking_chunk": ("🧠", "Claude is thinking..."),
- "text_start": ("🧠", "Claude is working..."),
- "text_delta": ("🧠", "Claude is working..."),
- "text_chunk": ("🧠", "Claude is working..."),
- "tool_result": ("⏳", "Executing tools..."),
-}
-
-
-def _get_status_for_event(ptype: str, parsed: dict, format_status_fn) -> str | None:
- """Return status string for event type, or None if no status update needed."""
- entry = _EVENT_STATUS_MAP.get(ptype)
- if entry is not None:
- emoji, label = entry
- return format_status_fn(emoji, label)
- if ptype in ("tool_use_start", "tool_use_delta", "tool_use"):
- if parsed.get("name") == "Task":
- return format_status_fn("🤖", "Subagent working...")
- return format_status_fn("⏳", "Executing tools...")
- return None
+from .ui_updates import ThrottledTranscriptEditor
class ClaudeMessageHandler:
@@ -97,10 +51,21 @@ class ClaudeMessageHandler:
platform: MessagingPlatform,
cli_manager: SessionManagerInterface,
session_store: SessionStore,
+ *,
+ debug_platform_edits: bool = False,
+ debug_subagent_stack: bool = False,
+ log_raw_messaging_content: bool = False,
+ log_raw_cli_diagnostics: bool = False,
+ log_messaging_error_details: bool = False,
):
self.platform = platform
self.cli_manager = cli_manager
self.session_store = session_store
+ self._debug_platform_edits = debug_platform_edits
+ self._debug_subagent_stack = debug_subagent_stack
+ self._log_raw_messaging_content = log_raw_messaging_content
+ self._log_raw_cli_diagnostics = log_raw_cli_diagnostics
+ self._log_messaging_error_details = log_messaging_error_details
self._tree_queue = TreeQueueManager(
queue_update_callback=self.update_queue_positions,
node_started_callback=self.mark_node_processing,
@@ -137,16 +102,26 @@ class ClaudeMessageHandler:
Determines if this is a new conversation or reply,
creates/extends the message tree, and queues for processing.
"""
- text_preview = (incoming.text or "")[:80]
- if len(incoming.text or "") > 80:
- text_preview += "..."
- logger.info(
- "HANDLER_ENTRY: chat_id={} message_id={} reply_to={} text_preview={!r}",
- incoming.chat_id,
- incoming.message_id,
- incoming.reply_to_message_id,
- text_preview,
- )
+ raw = incoming.text or ""
+ if self._log_raw_messaging_content:
+ text_preview = raw[:80]
+ if len(raw) > 80:
+ text_preview += "..."
+ logger.info(
+ "HANDLER_ENTRY: chat_id={} message_id={} reply_to={} text_preview={!r}",
+ incoming.chat_id,
+ incoming.message_id,
+ incoming.reply_to_message_id,
+ text_preview,
+ )
+ else:
+ logger.info(
+ "HANDLER_ENTRY: chat_id={} message_id={} reply_to={} text_len={}",
+ incoming.chat_id,
+ incoming.message_id,
+ incoming.reply_to_message_id,
+ len(raw),
+ )
with logger.contextualize(
chat_id=incoming.chat_id, node_id=incoming.message_id
@@ -169,7 +144,12 @@ class ClaudeMessageHandler:
kind=message_kind_for_command(cmd_base),
)
except Exception as e:
- logger.debug(f"Failed to record incoming message_id: {e}")
+ logger.debug(
+ "Failed to record incoming message_id: {}",
+ format_exception_for_log(
+ e, log_full_message=self._log_messaging_error_details
+ ),
+ )
if await dispatch_command(self, incoming, cmd_base):
return
@@ -276,7 +256,12 @@ class ClaudeMessageHandler:
try:
queued_ids = await tree.get_queue_snapshot()
except Exception as e:
- logger.warning(f"Failed to read queue snapshot: {e}")
+ logger.warning(
+ "Failed to read queue snapshot: {}",
+ format_exception_for_log(
+ e, log_full_message=self._log_messaging_error_details
+ ),
+ )
return
if not queued_ids:
@@ -317,85 +302,11 @@ class ClaudeMessageHandler:
self,
) -> tuple[TranscriptBuffer, RenderCtx]:
"""Create transcript buffer and render context for node processing."""
- transcript = TranscriptBuffer(show_tool_results=False)
- return transcript, self.get_render_ctx()
-
- async def _handle_session_info_event(
- self,
- event_data: dict,
- tree: MessageTree | None,
- node_id: str,
- captured_session_id: str | None,
- temp_session_id: str | None,
- ) -> tuple[str | None, str | None]:
- """Handle session_info event; return updated (captured_session_id, temp_session_id)."""
- if event_data.get("type") != "session_info":
- return captured_session_id, temp_session_id
-
- real_session_id = event_data.get("session_id")
- if not real_session_id or not temp_session_id:
- return captured_session_id, temp_session_id
-
- await self.cli_manager.register_real_session_id(
- temp_session_id, real_session_id
+ transcript = TranscriptBuffer(
+ show_tool_results=False,
+ debug_subagent_stack=self._debug_subagent_stack,
)
- if tree and real_session_id:
- await tree.update_state(
- node_id,
- MessageState.IN_PROGRESS,
- session_id=real_session_id,
- )
- self.session_store.save_tree(tree.root_id, tree.to_dict())
-
- return real_session_id, None
-
- async def _process_parsed_event(
- self,
- parsed: dict,
- transcript: TranscriptBuffer,
- update_ui,
- last_status: str | None,
- had_transcript_events: bool,
- tree: MessageTree | None,
- node_id: str,
- captured_session_id: str | None,
- ) -> tuple[str | None, bool]:
- """Process a single parsed CLI event. Returns (last_status, had_transcript_events)."""
- ptype = parsed.get("type") or ""
-
- if ptype in TRANSCRIPT_EVENT_TYPES:
- transcript.apply(parsed)
- had_transcript_events = True
-
- status = _get_status_for_event(ptype, parsed, self.format_status)
- if status is not None:
- await update_ui(status)
- last_status = status
- elif ptype == "block_stop":
- await update_ui(last_status, force=True)
- elif ptype == "complete":
- if not had_transcript_events:
- transcript.apply({"type": "text_chunk", "text": "Done."})
- logger.info("HANDLER: Task complete, updating UI")
- await update_ui(self.format_status("✅", "Complete"), force=True)
- if tree and captured_session_id:
- await tree.update_state(
- node_id,
- MessageState.COMPLETED,
- session_id=captured_session_id,
- )
- self.session_store.save_tree(tree.root_id, tree.to_dict())
- elif ptype == "error":
- error_msg = parsed.get("message", "Unknown error")
- logger.error(f"HANDLER: Error event received: {error_msg}")
- logger.info("HANDLER: Updating UI with error status")
- await update_ui(self.format_status("❌", "Error"), force=True)
- if tree:
- await self._propagate_error_to_children(
- node_id, error_msg, "Parent task failed"
- )
-
- return last_status, had_transcript_events
+ return transcript, self.get_render_ctx()
async def _process_node(
self,
@@ -426,8 +337,6 @@ class ClaudeMessageHandler:
transcript, render_ctx = self._create_transcript_and_render_ctx()
- last_ui_update = 0.0
- last_displayed_text = None
had_transcript_events = False
captured_session_id = None
temp_session_id = None
@@ -439,52 +348,21 @@ class ClaudeMessageHandler:
if parent_session_id:
logger.info(f"Will fork from parent session: {parent_session_id}")
- async def update_ui(status: str | None = None, force: bool = False) -> None:
- nonlocal last_ui_update, last_displayed_text, last_status
- now = time.time()
- if not force and now - last_ui_update < 1.0:
- return
+ editor = ThrottledTranscriptEditor(
+ platform=self.platform,
+ parse_mode=self._parse_mode(),
+ get_limit_chars=self._get_limit_chars,
+ transcript=transcript,
+ render_ctx=render_ctx,
+ node_id=node_id,
+ chat_id=chat_id,
+ status_msg_id=status_msg_id,
+ debug_platform_edits=self._debug_platform_edits,
+ log_messaging_error_details=self._log_messaging_error_details,
+ )
- last_ui_update = now
- if status is not None:
- last_status = status
- try:
- display = transcript.render(
- render_ctx,
- limit_chars=self._get_limit_chars(),
- status=status,
- )
- except Exception as e:
- logger.warning(f"Transcript render failed for node {node_id}: {e}")
- return
- if display and display != last_displayed_text:
- logger.debug(
- "PLATFORM_EDIT: node_id={} chat_id={} msg_id={} force={} status={!r} chars={}",
- node_id,
- chat_id,
- status_msg_id,
- bool(force),
- status,
- len(display),
- )
- if os.getenv("DEBUG_PLATFORM_EDITS") == "1":
- logger.debug("PLATFORM_EDIT_TEXT:\n{}", display)
- else:
- head = display[:500]
- tail = display[-500:] if len(display) > 500 else ""
- logger.debug("PLATFORM_EDIT_PREVIEW_HEAD:\n{}", head)
- if tail:
- logger.debug("PLATFORM_EDIT_PREVIEW_TAIL:\n{}", tail)
- last_displayed_text = display
- try:
- await self.platform.queue_edit_message(
- chat_id,
- status_msg_id,
- display,
- parse_mode=self._parse_mode(),
- )
- except Exception as e:
- logger.warning(f"Failed to update platform for node {node_id}: {e}")
+ async def update_ui(status: str | None = None, force: bool = False) -> None:
+ await editor.update(status, force=force)
try:
try:
@@ -531,20 +409,28 @@ class ClaudeMessageHandler:
(
captured_session_id,
temp_session_id,
- ) = await self._handle_session_info_event(
- event_data, tree, node_id, captured_session_id, temp_session_id
+ ) = await handle_session_info_event(
+ event_data,
+ tree,
+ node_id,
+ captured_session_id,
+ temp_session_id,
+ cli_manager=self.cli_manager,
+ session_store=self.session_store,
)
if event_data.get("type") == "session_info":
continue
- parsed_list = parse_cli_event(event_data)
+ parsed_list = parse_cli_event(
+ event_data, log_raw_cli=self._log_raw_cli_diagnostics
+ )
logger.debug(f"HANDLER: Parsed {len(parsed_list)} events from CLI")
for parsed in parsed_list:
(
last_status,
had_transcript_events,
- ) = await self._process_parsed_event(
+ ) = await process_parsed_cli_event(
parsed,
transcript,
update_ui,
@@ -553,6 +439,10 @@ class ClaudeMessageHandler:
tree,
node_id,
captured_session_id,
+ session_store=self.session_store,
+ format_status=self.format_status,
+ propagate_error_to_children=self._propagate_error_to_children,
+ log_messaging_error_details=self._log_messaging_error_details,
)
except asyncio.CancelledError:
@@ -575,9 +465,12 @@ class ClaudeMessageHandler:
)
except Exception as e:
logger.error(
- f"HANDLER: Task failed with exception: {type(e).__name__}: {e}"
+ "HANDLER: Task failed with exception: {}",
+ format_exception_for_log(
+ e, log_full_message=self._log_messaging_error_details
+ ),
)
- error_msg = get_user_facing_error_message(e)[:200]
+ error_msg = format_user_error_preview(e)
transcript.apply({"type": "error", "message": error_msg})
await update_ui(self.format_status("💥", "Task Failed"), force=True)
if tree:
@@ -595,7 +488,13 @@ class ClaudeMessageHandler:
elif temp_session_id:
await self.cli_manager.remove_session(temp_session_id)
except Exception as e:
- logger.debug(f"Failed to remove session for node {node_id}: {e}")
+ logger.debug(
+ "Failed to remove session for node {}: {}",
+ node_id,
+ format_exception_for_log(
+ e, log_full_message=self._log_messaging_error_details
+ ),
+ )
async def _propagate_error_to_children(
self,
@@ -691,7 +590,12 @@ class ClaudeMessageHandler:
platform, chat_id, str(msg_id), direction="out", kind=kind
)
except Exception as e:
- logger.debug(f"Failed to record message_id: {e}")
+ logger.debug(
+ "Failed to record message_id: {}",
+ format_exception_for_log(
+ e, log_full_message=self._log_messaging_error_details
+ ),
+ )
def update_cancelled_nodes_ui(self, nodes: list[MessageNode]) -> None:
"""Update status messages and persist tree state for cancelled nodes."""
diff --git a/messaging/limiter.py b/messaging/limiter.py
index 6367e9e..648d4a8 100644
--- a/messaging/limiter.py
+++ b/messaging/limiter.py
@@ -6,65 +6,16 @@ using a strict sliding window algorithm and a task queue.
"""
import asyncio
-import os
-import time
from collections import deque
from collections.abc import Awaitable, Callable
from typing import Any
from loguru import logger
+from config.settings import get_settings
+from core.rate_limit import StrictSlidingWindowLimiter as SlidingWindowLimiter
-class SlidingWindowLimiter:
- """Strict sliding window limiter.
-
- Guarantees: at most `rate_limit` acquisitions in any interval of length
- `rate_window` (seconds).
-
- Implemented as an async context manager so call sites can do:
- async with limiter:
- ...
- """
-
- def __init__(self, rate_limit: int, rate_window: float) -> None:
- if rate_limit <= 0:
- raise ValueError("rate_limit must be > 0")
- if rate_window <= 0:
- raise ValueError("rate_window must be > 0")
-
- self._rate_limit = int(rate_limit)
- self._rate_window = float(rate_window)
- self._times: deque[float] = deque()
- self._lock = asyncio.Lock()
-
- async def acquire(self) -> None:
- while True:
- wait_time = 0.0
- async with self._lock:
- now = time.monotonic()
- cutoff = now - self._rate_window
-
- while self._times and self._times[0] <= cutoff:
- self._times.popleft()
-
- if len(self._times) < self._rate_limit:
- self._times.append(now)
- return
-
- oldest = self._times[0]
- wait_time = max(0.0, (oldest + self._rate_window) - now)
-
- if wait_time > 0:
- await asyncio.sleep(wait_time)
- else:
- await asyncio.sleep(0)
-
- async def __aenter__(self) -> SlidingWindowLimiter:
- await self.acquire()
- return self
-
- async def __aexit__(self, exc_type, exc, tb) -> bool:
- return False
+from .safe_diagnostics import format_exception_for_log
class MessagingRateLimiter:
@@ -82,23 +33,29 @@ class MessagingRateLimiter:
return super().__new__(cls)
@classmethod
- async def get_instance(cls) -> MessagingRateLimiter:
- """Get the singleton instance of the limiter."""
+ async def get_instance(
+ cls,
+ *,
+ rate_limit: int = 1,
+ rate_window: float = 1.0,
+ ) -> MessagingRateLimiter:
+ """Get the singleton instance of the limiter.
+
+ ``rate_limit`` and ``rate_window`` apply only when the singleton is first
+ created. Call :meth:`shutdown_instance` before changing parameters.
+ """
async with cls._lock:
if cls._instance is None:
- cls._instance = cls()
+ cls._instance = cls(rate_limit=rate_limit, rate_window=rate_window)
# Start the background worker (tracked for graceful shutdown).
cls._instance._start_worker()
return cls._instance
- def __init__(self):
+ def __init__(self, *, rate_limit: int, rate_window: float) -> None:
# Prevent double initialization in singleton
if hasattr(self, "_initialized"):
return
- rate_limit = int(os.getenv("MESSAGING_RATE_LIMIT", "1"))
- rate_window = float(os.getenv("MESSAGING_RATE_WINDOW", "2.0"))
-
self.limiter = SlidingWindowLimiter(rate_limit, rate_window)
# Custom queue state - using deque for O(1) popleft
self._queue_list: deque[str] = deque() # Deque of dedup_keys in order
@@ -189,15 +146,27 @@ class MessagingRateLimiter:
asyncio.get_event_loop().time() + wait_secs
)
else:
+ d = get_settings().log_messaging_error_details
logger.error(
- f"Error in limiter worker for key {dedup_key}: {type(e).__name__}: {e}"
+ "Error in limiter worker for key {}: {}",
+ dedup_key,
+ format_exception_for_log(e, log_full_message=d),
)
except asyncio.CancelledError:
break
except Exception as e:
- logger.error(
- f"MessagingRateLimiter worker critical error: {e}", exc_info=True
- )
+ d = get_settings().log_messaging_error_details
+ if d:
+ logger.error(
+ "MessagingRateLimiter worker critical error: {}",
+ e,
+ exc_info=True,
+ )
+ else:
+ logger.error(
+ "MessagingRateLimiter worker critical error: exc_type={}",
+ type(e).__name__,
+ )
await asyncio.sleep(1)
async def shutdown(self, timeout: float = 2.0) -> None:
@@ -223,7 +192,11 @@ class MessagingRateLimiter:
except asyncio.CancelledError:
pass
except Exception as e:
- logger.debug(f"MessagingRateLimiter worker shutdown error: {e}")
+ d = get_settings().log_messaging_error_details
+ logger.debug(
+ "MessagingRateLimiter worker shutdown error: {}",
+ format_exception_for_log(e, log_full_message=d),
+ )
finally:
self._worker_task = None
@@ -296,14 +269,29 @@ class MessagingRateLimiter:
x in error_msg for x in ["connect", "timeout", "broken"]
):
wait = 2**attempt
- logger.warning(
- f"Limiter fire_and_forget transient error (attempt {attempt + 1}): {e}. Retrying in {wait}s..."
- )
+ d = get_settings().log_messaging_error_details
+ if d:
+ logger.warning(
+ "Limiter fire_and_forget transient error (attempt {}): {}. Retrying in {}s...",
+ attempt + 1,
+ e,
+ wait,
+ )
+ else:
+ logger.warning(
+ "Limiter fire_and_forget transient error (attempt {}): exc_type={}. Retrying in {}s...",
+ attempt + 1,
+ type(e).__name__,
+ wait,
+ )
await asyncio.sleep(wait)
continue
+ d = get_settings().log_messaging_error_details
logger.error(
- f"Final error in fire_and_forget for key {dedup_key}: {type(e).__name__}: {e}"
+ "Final error in fire_and_forget for key {}: {}",
+ dedup_key,
+ format_exception_for_log(e, log_full_message=d),
)
if not future.done():
future.set_exception(e)
diff --git a/messaging/node_event_pipeline.py b/messaging/node_event_pipeline.py
new file mode 100644
index 0000000..6f233e2
--- /dev/null
+++ b/messaging/node_event_pipeline.py
@@ -0,0 +1,103 @@
+"""CLI event handling for a single queued node (transcript + session + errors)."""
+
+from __future__ import annotations
+
+from collections.abc import Awaitable, Callable
+from typing import Any
+
+from loguru import logger
+
+from .cli_event_constants import TRANSCRIPT_EVENT_TYPES, get_status_for_event
+from .platforms.base import SessionManagerInterface
+from .safe_diagnostics import text_len_hint
+from .session import SessionStore
+from .transcript import TranscriptBuffer
+from .trees.queue_manager import MessageState, MessageTree
+
+
+async def handle_session_info_event(
+ event_data: dict[str, Any],
+ tree: MessageTree | None,
+ node_id: str,
+ captured_session_id: str | None,
+ temp_session_id: str | None,
+ *,
+ cli_manager: SessionManagerInterface,
+ session_store: SessionStore,
+) -> tuple[str | None, str | None]:
+ """Handle session_info event; return updated (captured_session_id, temp_session_id)."""
+ if event_data.get("type") != "session_info":
+ return captured_session_id, temp_session_id
+
+ real_session_id = event_data.get("session_id")
+ if not real_session_id or not temp_session_id:
+ return captured_session_id, temp_session_id
+
+ await cli_manager.register_real_session_id(temp_session_id, real_session_id)
+    if tree:
+ await tree.update_state(
+ node_id,
+ MessageState.IN_PROGRESS,
+ session_id=real_session_id,
+ )
+ session_store.save_tree(tree.root_id, tree.to_dict())
+
+ return real_session_id, None
+
+
+async def process_parsed_cli_event(
+ parsed: dict[str, Any],
+ transcript: TranscriptBuffer,
+ update_ui: Callable[..., Awaitable[None]],
+ last_status: str | None,
+ had_transcript_events: bool,
+ tree: MessageTree | None,
+ node_id: str,
+ captured_session_id: str | None,
+ *,
+ session_store: SessionStore,
+ format_status: Callable[..., str],
+ propagate_error_to_children: Callable[[str, str, str], Awaitable[None]],
+ log_messaging_error_details: bool = False,
+) -> tuple[str | None, bool]:
+ """Process a single parsed CLI event. Returns (last_status, had_transcript_events)."""
+ ptype = parsed.get("type") or ""
+
+ if ptype in TRANSCRIPT_EVENT_TYPES:
+ transcript.apply(parsed)
+ had_transcript_events = True
+
+ status = get_status_for_event(ptype, parsed, format_status)
+ if status is not None:
+ await update_ui(status)
+ last_status = status
+ elif ptype == "block_stop":
+ await update_ui(last_status, force=True)
+ elif ptype == "complete":
+ if not had_transcript_events:
+ transcript.apply({"type": "text_chunk", "text": "Done."})
+ logger.info("HANDLER: Task complete, updating UI")
+ await update_ui(format_status("✅", "Complete"), force=True)
+ if tree and captured_session_id:
+ await tree.update_state(
+ node_id,
+ MessageState.COMPLETED,
+ session_id=captured_session_id,
+ )
+ session_store.save_tree(tree.root_id, tree.to_dict())
+ elif ptype == "error":
+ error_msg = parsed.get("message", "Unknown error")
+ if log_messaging_error_details:
+ logger.error("HANDLER: Error event received: {}", error_msg)
+ else:
+ em = error_msg if isinstance(error_msg, str) else str(error_msg)
+ logger.error(
+ "HANDLER: Error event received: message_chars={}",
+ text_len_hint(em),
+ )
+ logger.info("HANDLER: Updating UI with error status")
+ await update_ui(format_status("❌", "Error"), force=True)
+ if tree:
+ await propagate_error_to_children(node_id, error_msg, "Parent task failed")
+
+ return last_status, had_transcript_events
diff --git a/messaging/platforms/discord.py b/messaging/platforms/discord.py
index 201682b..53aa7dc 100644
--- a/messaging/platforms/discord.py
+++ b/messaging/platforms/discord.py
@@ -6,7 +6,6 @@ Implements MessagingPlatform for Discord using discord.py.
import asyncio
import contextlib
-import os
import tempfile
from collections.abc import Awaitable, Callable
from pathlib import Path
@@ -14,7 +13,7 @@ from typing import Any, cast
from loguru import logger
-from core.anthropic import get_user_facing_error_message
+from core.anthropic import format_user_error_preview
from ..models import IncomingMessage
from ..rendering.discord_markdown import format_status_discord
@@ -97,15 +96,18 @@ class DiscordPlatform(MessagingPlatform):
whisper_device: str = "cpu",
hf_token: str = "",
nvidia_nim_api_key: str = "",
+ messaging_rate_limit: int = 1,
+ messaging_rate_window: float = 1.0,
+ log_raw_messaging_content: bool = False,
+ log_api_error_tracebacks: bool = False,
):
if not DISCORD_AVAILABLE:
raise ImportError(
"discord.py is required. Install with: pip install discord.py"
)
- self.bot_token = bot_token or os.getenv("DISCORD_BOT_TOKEN")
- raw_channels = allowed_channel_ids or os.getenv("ALLOWED_DISCORD_CHANNELS")
- self.allowed_channel_ids = _parse_allowed_channels(raw_channels)
+ self.bot_token = bot_token
+ self.allowed_channel_ids = _parse_allowed_channels(allowed_channel_ids)
if not self.bot_token:
logger.warning("DISCORD_BOT_TOKEN not set")
@@ -130,6 +132,10 @@ class DiscordPlatform(MessagingPlatform):
self._voice_note_enabled = voice_note_enabled
self._whisper_model = whisper_model
self._whisper_device = whisper_device
+ self._messaging_rate_limit = messaging_rate_limit
+ self._messaging_rate_window = messaging_rate_window
+ self._log_raw_messaging_content = log_raw_messaging_content
+ self._log_api_error_tracebacks = log_api_error_tracebacks
async def _handle_client_message(self, message: Any) -> None:
"""Adapter entry point used by the internal discord client."""
@@ -234,23 +240,40 @@ class DiscordPlatform(MessagingPlatform):
status_message_id=status_msg_id,
)
- logger.info(
- "DISCORD_VOICE: chat_id={} message_id={} transcribed={!r}",
- channel_id,
- message_id,
- (transcribed[:80] + "..." if len(transcribed) > 80 else transcribed),
- )
+ if self._log_raw_messaging_content:
+ logger.info(
+ "DISCORD_VOICE: chat_id={} message_id={} transcribed={!r}",
+ channel_id,
+ message_id,
+ (
+ transcribed[:80] + "..."
+ if len(transcribed) > 80
+ else transcribed
+ ),
+ )
+ else:
+ logger.info(
+ "DISCORD_VOICE: chat_id={} message_id={} transcribed_len={}",
+ channel_id,
+ message_id,
+ len(transcribed),
+ )
await self._message_handler(incoming)
return True
except ValueError as e:
- await message.reply(get_user_facing_error_message(e)[:200])
+ await message.reply(format_user_error_preview(e))
return True
except ImportError as e:
- await message.reply(get_user_facing_error_message(e)[:200])
+ await message.reply(format_user_error_preview(e))
return True
except Exception as e:
- logger.error(f"Voice transcription failed: {e}")
+ if self._log_api_error_tracebacks:
+ logger.error("Voice transcription failed: {}", e)
+ else:
+ logger.error(
+ "Voice transcription failed: exc_type={}", type(e).__name__
+ )
await message.reply(
"Could not transcribe voice note. Please try again or send text."
)
@@ -285,16 +308,26 @@ class DiscordPlatform(MessagingPlatform):
else None
)
- text_preview = (message.content or "")[:80]
- if len(message.content or "") > 80:
- text_preview += "..."
- logger.info(
- "DISCORD_MSG: chat_id={} message_id={} reply_to={} text_preview={!r}",
- channel_id,
- message_id,
- reply_to,
- text_preview,
- )
+ raw_content = message.content or ""
+ if self._log_raw_messaging_content:
+ text_preview = raw_content[:80]
+ if len(raw_content) > 80:
+ text_preview += "..."
+ logger.info(
+ "DISCORD_MSG: chat_id={} message_id={} reply_to={} text_preview={!r}",
+ channel_id,
+ message_id,
+ reply_to,
+ text_preview,
+ )
+ else:
+ logger.info(
+ "DISCORD_MSG: chat_id={} message_id={} reply_to={} text_len={}",
+ channel_id,
+ message_id,
+ reply_to,
+ len(raw_content),
+ )
if not self._message_handler:
return
@@ -313,13 +346,14 @@ class DiscordPlatform(MessagingPlatform):
try:
await self._message_handler(incoming)
except Exception as e:
- logger.error(f"Error handling message: {e}")
+ if self._log_api_error_tracebacks:
+ logger.error("Error handling message: {}", e)
+ else:
+ logger.error("Error handling message: exc_type={}", type(e).__name__)
with contextlib.suppress(Exception):
await self.send_message(
channel_id,
- format_status_discord(
- "Error:", get_user_facing_error_message(e)[:200]
- ),
+ format_status_discord("Error:", format_user_error_preview(e)),
reply_to=message_id,
)
@@ -336,7 +370,10 @@ class DiscordPlatform(MessagingPlatform):
from ..limiter import MessagingRateLimiter
- self._limiter = await MessagingRateLimiter.get_instance()
+ self._limiter = await MessagingRateLimiter.get_instance(
+ rate_limit=self._messaging_rate_limit,
+ rate_window=self._messaging_rate_window,
+ )
self._start_task = asyncio.create_task(
self._client.start(self.bot_token),
diff --git a/messaging/platforms/factory.py b/messaging/platforms/factory.py
index b40fe87..772a31d 100644
--- a/messaging/platforms/factory.py
+++ b/messaging/platforms/factory.py
@@ -28,6 +28,10 @@ class MessagingPlatformOptions:
whisper_device: str = "cpu"
hf_token: str = ""
nvidia_nim_api_key: str = ""
+ messaging_rate_limit: int = 1
+ messaging_rate_window: float = 1.0
+ log_raw_messaging_content: bool = False
+ log_api_error_tracebacks: bool = False
def create_messaging_platform(
@@ -64,6 +68,10 @@ def create_messaging_platform(
whisper_device=opts.whisper_device,
hf_token=opts.hf_token,
nvidia_nim_api_key=opts.nvidia_nim_api_key,
+ messaging_rate_limit=opts.messaging_rate_limit,
+ messaging_rate_window=opts.messaging_rate_window,
+ log_raw_messaging_content=opts.log_raw_messaging_content,
+ log_api_error_tracebacks=opts.log_api_error_tracebacks,
)
if platform_type == "discord":
@@ -82,6 +90,10 @@ def create_messaging_platform(
whisper_device=opts.whisper_device,
hf_token=opts.hf_token,
nvidia_nim_api_key=opts.nvidia_nim_api_key,
+ messaging_rate_limit=opts.messaging_rate_limit,
+ messaging_rate_window=opts.messaging_rate_window,
+ log_raw_messaging_content=opts.log_raw_messaging_content,
+ log_api_error_tracebacks=opts.log_api_error_tracebacks,
)
logger.warning(
diff --git a/messaging/platforms/telegram.py b/messaging/platforms/telegram.py
index d131e44..a8f34df 100644
--- a/messaging/platforms/telegram.py
+++ b/messaging/platforms/telegram.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any
from loguru import logger
-from core.anthropic import get_user_facing_error_message
+from core.anthropic import format_user_error_preview
if TYPE_CHECKING:
from telegram import Update
@@ -68,14 +68,18 @@ class TelegramPlatform(MessagingPlatform):
whisper_device: str = "cpu",
hf_token: str = "",
nvidia_nim_api_key: str = "",
+ messaging_rate_limit: int = 1,
+ messaging_rate_window: float = 1.0,
+ log_raw_messaging_content: bool = False,
+ log_api_error_tracebacks: bool = False,
):
if not TELEGRAM_AVAILABLE:
raise ImportError(
"python-telegram-bot is required. Install with: pip install python-telegram-bot"
)
- self.bot_token = bot_token or os.getenv("TELEGRAM_BOT_TOKEN")
- self.allowed_user_id = allowed_user_id or os.getenv("ALLOWED_TELEGRAM_USER_ID")
+ self.bot_token = bot_token
+ self.allowed_user_id = allowed_user_id
if not self.bot_token:
# We don't raise here to allow instantiation for testing/conditional logic,
@@ -97,6 +101,10 @@ class TelegramPlatform(MessagingPlatform):
self._voice_note_enabled = voice_note_enabled
self._whisper_model = whisper_model
self._whisper_device = whisper_device
+ self._messaging_rate_limit = messaging_rate_limit
+ self._messaging_rate_window = messaging_rate_window
+ self._log_raw_messaging_content = log_raw_messaging_content
+ self._log_api_error_tracebacks = log_api_error_tracebacks
async def _register_pending_voice(
self, chat_id: str, voice_msg_id: str, status_msg_id: str
@@ -172,7 +180,10 @@ class TelegramPlatform(MessagingPlatform):
# Initialize rate limiter
from ..limiter import MessagingRateLimiter
- self._limiter = await MessagingRateLimiter.get_instance()
+ self._limiter = await MessagingRateLimiter.get_instance(
+ rate_limit=self._messaging_rate_limit,
+ rate_window=self._messaging_rate_window,
+ )
# Send startup notification
try:
@@ -187,7 +198,13 @@ class TelegramPlatform(MessagingPlatform):
startup_text,
)
except Exception as e:
- logger.warning(f"Could not send startup message: {e}")
+ if self._log_api_error_tracebacks:
+ logger.warning("Could not send startup message: {}", e)
+ else:
+ logger.warning(
+ "Could not send startup message: exc_type={}",
+ type(e).__name__,
+ )
logger.info("Telegram platform started (Bot API)")
@@ -506,16 +523,26 @@ class TelegramPlatform(MessagingPlatform):
if getattr(update.message, "message_thread_id", None) is not None
else None
)
- text_preview = (update.message.text or "")[:80]
- if len(update.message.text or "") > 80:
- text_preview += "..."
- logger.info(
- "TELEGRAM_MSG: chat_id={} message_id={} reply_to={} text_preview={!r}",
- chat_id,
- message_id,
- reply_to,
- text_preview,
- )
+ raw_text = update.message.text or ""
+ if self._log_raw_messaging_content:
+ text_preview = raw_text[:80]
+ if len(raw_text) > 80:
+ text_preview += "..."
+ logger.info(
+ "TELEGRAM_MSG: chat_id={} message_id={} reply_to={} text_preview={!r}",
+ chat_id,
+ message_id,
+ reply_to,
+ text_preview,
+ )
+ else:
+ logger.info(
+ "TELEGRAM_MSG: chat_id={} message_id={} reply_to={} text_len={}",
+ chat_id,
+ message_id,
+ reply_to,
+ len(raw_text),
+ )
if not self._message_handler:
return
@@ -534,11 +561,14 @@ class TelegramPlatform(MessagingPlatform):
try:
await self._message_handler(incoming)
except Exception as e:
- logger.error(f"Error handling message: {e}")
+ if self._log_api_error_tracebacks:
+ logger.error("Error handling message: {}", e)
+ else:
+ logger.error("Error handling message: exc_type={}", type(e).__name__)
with contextlib.suppress(Exception):
await self.send_message(
chat_id,
- f"❌ *{escape_md_v2('Error:')}* {escape_md_v2(get_user_facing_error_message(e)[:200])}",
+ f"❌ *{escape_md_v2('Error:')}* {escape_md_v2(format_user_error_preview(e))}",
reply_to=incoming.message_id,
message_thread_id=thread_id,
parse_mode="MarkdownV2",
@@ -631,20 +661,37 @@ class TelegramPlatform(MessagingPlatform):
status_message_id=status_msg_id,
)
- logger.info(
- "TELEGRAM_VOICE: chat_id={} message_id={} transcribed={!r}",
- chat_id,
- message_id,
- (transcribed[:80] + "..." if len(transcribed) > 80 else transcribed),
- )
+ if self._log_raw_messaging_content:
+ logger.info(
+ "TELEGRAM_VOICE: chat_id={} message_id={} transcribed={!r}",
+ chat_id,
+ message_id,
+ (
+ transcribed[:80] + "..."
+ if len(transcribed) > 80
+ else transcribed
+ ),
+ )
+ else:
+ logger.info(
+ "TELEGRAM_VOICE: chat_id={} message_id={} transcribed_len={}",
+ chat_id,
+ message_id,
+ len(transcribed),
+ )
await self._message_handler(incoming)
except ValueError as e:
- await update.message.reply_text(get_user_facing_error_message(e)[:200])
+ await update.message.reply_text(format_user_error_preview(e))
except ImportError as e:
- await update.message.reply_text(get_user_facing_error_message(e)[:200])
+ await update.message.reply_text(format_user_error_preview(e))
except Exception as e:
- logger.error(f"Voice transcription failed: {e}")
+ if self._log_api_error_tracebacks:
+ logger.error("Voice transcription failed: {}", e)
+ else:
+ logger.error(
+ "Voice transcription failed: exc_type={}", type(e).__name__
+ )
await update.message.reply_text(
"Could not transcribe voice note. Please try again or send text."
)
diff --git a/messaging/safe_diagnostics.py b/messaging/safe_diagnostics.py
new file mode 100644
index 0000000..add30c5
--- /dev/null
+++ b/messaging/safe_diagnostics.py
@@ -0,0 +1,17 @@
+"""Helpers for redacting user-derived content from log lines."""
+
+from __future__ import annotations
+
+
+def format_exception_for_log(exc: BaseException, *, log_full_message: bool) -> str:
+ """Return exception type and optionally ``str(exc)`` for operator diagnostics."""
+ if log_full_message:
+ return f"{type(exc).__name__}: {exc}"
+ return type(exc).__name__
+
+
+def text_len_hint(text: str | None) -> int:
+ """Length of text for metadata-only logging (0 when missing)."""
+ if not text:
+ return 0
+ return len(text)
diff --git a/messaging/session.py b/messaging/session.py
index a4e4e9d..bbb0cdc 100644
--- a/messaging/session.py
+++ b/messaging/session.py
@@ -5,8 +5,10 @@ Provides persistent storage for mapping platform messages to Claude CLI session
and message trees for conversation continuation.
"""
+import contextlib
import json
import os
+import tempfile
import threading
from datetime import UTC, datetime
from typing import Any
@@ -22,7 +24,12 @@ class SessionStore:
Platform-agnostic: works with any messaging platform.
"""
- def __init__(self, storage_path: str = "sessions.json"):
+ def __init__(
+ self,
+ storage_path: str = "sessions.json",
+ *,
+ message_log_cap: int | None = None,
+ ):
self.storage_path = storage_path
self._lock = threading.Lock()
self._trees: dict[str, dict] = {} # root_id -> tree data
@@ -34,11 +41,7 @@ class SessionStore:
self._dirty = False
self._save_timer: threading.Timer | None = None
self._save_debounce_secs = 0.5
- cap_raw = os.getenv("MAX_MESSAGE_LOG_ENTRIES_PER_CHAT", "").strip()
- try:
- self._message_log_cap: int | None = int(cap_raw) if cap_raw else None
- except ValueError:
- self._message_log_cap = None
+ self._message_log_cap: int | None = message_log_cap
self._load()
def _make_chat_key(self, platform: str, chat_id: str) -> str:
@@ -104,9 +107,22 @@ class SessionStore:
}
def _write_data(self, data: dict) -> None:
- """Write data dict to disk. Must be called WITHOUT holding self._lock."""
- with open(self.storage_path, "w", encoding="utf-8") as f:
- json.dump(data, f, indent=2)
+ """Atomically write data dict to disk. Must be called WITHOUT holding self._lock."""
+ abs_target = os.path.abspath(self.storage_path)
+ dir_name = os.path.dirname(abs_target) or "."
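+        # Write to a temp file in the same directory (same filesystem), then
+        # fsync and os.replace() so readers never see a partially written file.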
+ fd, tmp_path = tempfile.mkstemp(
+ dir=dir_name, prefix=".sessions.", suffix=".tmp.json"
+ )
+ try:
+ with os.fdopen(fd, "w", encoding="utf-8") as f:
+ json.dump(data, f, indent=2)
+ f.flush()
+ os.fsync(f.fileno())
+ os.replace(tmp_path, abs_target)
+ except BaseException:
+ with contextlib.suppress(OSError):
+ os.unlink(tmp_path)
+ raise
def _schedule_save(self) -> None:
"""Schedule a debounced save. Caller must hold self._lock."""
diff --git a/messaging/transcript.py b/messaging/transcript.py
index 073f40b..f1a015d 100644
--- a/messaging/transcript.py
+++ b/messaging/transcript.py
@@ -9,7 +9,6 @@ the transcript grows over time and older content must be truncated.
from __future__ import annotations
import json
-import os
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Callable, Iterable
@@ -210,7 +209,12 @@ class RenderCtx:
class TranscriptBuffer:
"""Maintains an ordered, truncatable transcript of events."""
- def __init__(self, *, show_tool_results: bool = True) -> None:
+ def __init__(
+ self,
+ *,
+ show_tool_results: bool = True,
+ debug_subagent_stack: bool = False,
+ ) -> None:
self._segments: list[Segment] = []
self._open_thinking_by_index: dict[int, ThinkingSegment] = {}
self._open_text_by_index: dict[int, TextSegment] = {}
@@ -227,7 +231,7 @@ class TranscriptBuffer:
self._subagent_stack: list[str] = []
# Parallel stack of segments for rendering nested subagents.
self._subagent_segments: list[SubagentSegment] = []
- self._debug_subagent_stack = os.getenv("DEBUG_SUBAGENT_STACK") == "1"
+ self._debug_subagent_stack = debug_subagent_stack
def _in_subagent(self) -> bool:
return bool(self._subagent_stack)
diff --git a/messaging/transcription.py b/messaging/transcription.py
index 389a9be..9a5f01f 100644
--- a/messaging/transcription.py
+++ b/messaging/transcription.py
@@ -5,32 +5,18 @@ Supports:
- NVIDIA NIM: NVIDIA NIM Whisper/Parakeet
"""
-import os
from pathlib import Path
from typing import Any
from loguru import logger
-from config.settings import get_settings
+from providers.nvidia_nim.voice import (
+ transcribe_audio_file as transcribe_nvidia_nim_audio,
+)
# Max file size in bytes (25 MB)
MAX_AUDIO_SIZE_BYTES = 25 * 1024 * 1024
-# NVIDIA NIM Whisper model mapping: (function_id, language_code)
-_NIM_MODEL_MAP: dict[str, tuple[str, str]] = {
- "nvidia/parakeet-ctc-0.6b-zh-tw": ("8473f56d-51ef-473c-bb26-efd4f5def2bf", "zh-TW"),
- "nvidia/parakeet-ctc-0.6b-zh-cn": ("9add5ef7-322e-47e0-ad7a-5653fb8d259b", "zh-CN"),
- "nvidia/parakeet-ctc-0.6b-es": ("None", "es-US"),
- "nvidia/parakeet-ctc-0.6b-vi": ("f3dff2bb-99f9-403d-a5f1-f574a757deb0", "vi-VN"),
- "nvidia/parakeet-ctc-1.1b-asr": ("1598d209-5e27-4d3c-8079-4751568b1081", "en-US"),
- "nvidia/parakeet-ctc-0.6b-asr": ("d8dd4e9b-fbf5-4fb0-9dba-8cf436c8d965", "en-US"),
- "nvidia/parakeet-1.1b-rnnt-multilingual-asr": (
- "71203149-d3b7-4460-8231-1be2543a1fca",
- "",
- ),
- "openai/whisper-large-v3": ("b702f636-f60c-4a3d-a6f4-f3568c13bd7d", "multi"),
-}
-
# Short model names -> full Hugging Face model IDs (for local Whisper)
_MODEL_MAP: dict[str, str] = {
"tiny": "openai/whisper-tiny",
@@ -51,22 +37,19 @@ def _resolve_model_id(whisper_model: str) -> str:
return _MODEL_MAP.get(whisper_model, whisper_model)
-def _get_pipeline(model_id: str, device: str, hf_token: str | None = None) -> Any:
+def _get_pipeline(model_id: str, device: str, hf_token: str = "") -> Any:
"""Lazy-load transformers Whisper pipeline. Raises ImportError if not installed."""
global _pipeline_cache
if device not in ("cpu", "cuda"):
raise ValueError(f"whisper_device must be 'cpu' or 'cuda', got {device!r}")
- resolved_token = (
- hf_token if hf_token is not None else get_settings().hf_token
- ) or ""
+ resolved_token = hf_token or ""
cache_key = (model_id, device, resolved_token)
if cache_key not in _pipeline_cache:
try:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
- if resolved_token:
- os.environ["HF_TOKEN"] = resolved_token
+ hf_auth_token = resolved_token or None
use_cuda = device == "cuda" and torch.cuda.is_available()
pipe_device = "cuda:0" if use_cuda else "cpu"
@@ -77,9 +60,10 @@ def _get_pipeline(model_id: str, device: str, hf_token: str | None = None) -> An
dtype=model_dtype,
low_cpu_mem_usage=True,
attn_implementation="sdpa",
+ token=hf_auth_token,
)
model = model.to(pipe_device)
- processor = AutoProcessor.from_pretrained(model_id)
+ processor = AutoProcessor.from_pretrained(model_id, token=hf_auth_token)
pipe = pipeline(
"automatic-speech-recognition",
@@ -119,7 +103,7 @@ def transcribe_audio(
file_path: Path to audio file (OGG, MP3, MP4, WAV, M4A supported)
mime_type: MIME type of the audio (e.g. "audio/ogg")
whisper_model: Model ID or short name (local) or NVIDIA NIM model
- whisper_device: "cpu" | "cuda" | "nvidia_nim" (defaults to WHISPER_DEVICE env var)
+ whisper_device: "cpu" | "cuda" | "nvidia_nim"
Returns:
Transcribed text
@@ -140,8 +124,8 @@ def transcribe_audio(
)
if whisper_device == "nvidia_nim":
- return _transcribe_nim(
- file_path, whisper_model, nvidia_nim_api_key=nvidia_nim_api_key
+ return transcribe_nvidia_nim_audio(
+ file_path, whisper_model, api_key=nvidia_nim_api_key
)
return _transcribe_local(
file_path, whisper_model, whisper_device, hf_token=hf_token
@@ -169,8 +153,7 @@ def _transcribe_local(
) -> str:
"""Transcribe using transformers Whisper pipeline."""
model_id = _resolve_model_id(whisper_model)
- token: str | None = hf_token if hf_token else None
- pipe = _get_pipeline(model_id, whisper_device, hf_token=token)
+ pipe = _get_pipeline(model_id, whisper_device, hf_token=hf_token)
audio = _load_audio(file_path)
result = pipe(audio, generate_kwargs={"language": "en", "task": "transcribe"})
text = result.get("text", "") or ""
@@ -179,65 +162,3 @@ def _transcribe_local(
result_text = text.strip()
logger.debug(f"Local transcription: {len(result_text)} chars")
return result_text or "(no speech detected)"
-
-
-def _transcribe_nim(
- file_path: Path, model: str, *, nvidia_nim_api_key: str = ""
-) -> str:
- """Transcribe using NVIDIA NIM Whisper API via Riva gRPC client."""
- try:
- import riva.client
- except ImportError as e:
- raise ImportError(
- "NVIDIA NIM transcription requires the voice extra. "
- "Install with: uv sync --extra voice"
- ) from e
-
- api_key = nvidia_nim_api_key or get_settings().nvidia_nim_api_key
-
- # Look up function ID and language code from model mapping
- model_config = _NIM_MODEL_MAP.get(model)
- if not model_config:
- raise ValueError(
- f"No NVIDIA NIM config found for model: {model}. "
- f"Supported models: {', '.join(_NIM_MODEL_MAP.keys())}"
- )
- function_id, language_code = model_config
-
- # Riva server configuration
- server = "grpc.nvcf.nvidia.com:443"
-
- # Auth with SSL and metadata
- auth = riva.client.Auth(
- use_ssl=True,
- uri=server,
- metadata_args=[
- ["function-id", function_id],
- ["authorization", f"Bearer {api_key}"],
- ],
- )
-
- asr_service = riva.client.ASRService(auth)
-
- # Configure recognition - language_code from model config
- config = riva.client.RecognitionConfig(
- language_code=language_code,
- max_alternatives=1,
- verbatim_transcripts=True,
- )
-
- # Read audio file
- with open(file_path, "rb") as f:
- data = f.read()
-
- # Perform offline recognition
- response = asr_service.offline_recognize(data, config)
-
- # Extract text from response - use getattr for safe attribute access
- transcript = ""
- results = getattr(response, "results", None)
- if results and results[0].alternatives:
- transcript = results[0].alternatives[0].transcript
-
- logger.debug(f"NIM transcription: {len(transcript)} chars")
- return transcript or "(no speech detected)"
diff --git a/messaging/trees/processor.py b/messaging/trees/processor.py
deleted file mode 100644
index 5fdcca6..0000000
--- a/messaging/trees/processor.py
+++ /dev/null
@@ -1,165 +0,0 @@
-"""Async queue processor for message trees.
-
-Handles the async processing lifecycle of tree nodes.
-"""
-
-import asyncio
-from collections.abc import Awaitable, Callable
-
-from loguru import logger
-
-from core.anthropic import get_user_facing_error_message
-
-from .data import MessageNode, MessageState, MessageTree
-
-
-class TreeQueueProcessor:
- """
- Handles async queue processing for a single tree.
-
- Separates the async processing logic from the data management.
- """
-
- def __init__(
- self,
- queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None = None,
- node_started_callback: Callable[[MessageTree, str], Awaitable[None]]
- | None = None,
- ):
- self._queue_update_callback = queue_update_callback
- self._node_started_callback = node_started_callback
-
- def set_queue_update_callback(
- self,
- queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None,
- ) -> None:
- """Update the callback used to refresh queue positions."""
- self._queue_update_callback = queue_update_callback
-
- def set_node_started_callback(
- self,
- node_started_callback: Callable[[MessageTree, str], Awaitable[None]] | None,
- ) -> None:
- """Update the callback used when a queued node starts processing."""
- self._node_started_callback = node_started_callback
-
- async def _notify_queue_updated(self, tree: MessageTree) -> None:
- """Invoke queue update callback if set."""
- if not self._queue_update_callback:
- return
- try:
- await self._queue_update_callback(tree)
- except Exception as e:
- logger.warning(f"Queue update callback failed: {e}")
-
- async def _notify_node_started(self, tree: MessageTree, node_id: str) -> None:
- """Invoke node started callback if set."""
- if not self._node_started_callback:
- return
- try:
- await self._node_started_callback(tree, node_id)
- except Exception as e:
- logger.warning(f"Node started callback failed: {e}")
-
- async def process_node(
- self,
- tree: MessageTree,
- node: MessageNode,
- processor: Callable[[str, MessageNode], Awaitable[None]],
- ) -> None:
- """Process a single node and then check the queue."""
- # Skip if already in terminal state (e.g. from error propagation)
- if node.state == MessageState.ERROR:
- logger.info(
- f"Skipping node {node.node_id} as it is already in state {node.state}"
- )
- # Still need to check for next messages
- await self._process_next(tree, processor)
- return
-
- try:
- await processor(node.node_id, node)
- except asyncio.CancelledError:
- logger.info(f"Task for node {node.node_id} was cancelled")
- raise
- except Exception as e:
- logger.error(f"Error processing node {node.node_id}: {e}")
- await tree.update_state(
- node.node_id,
- MessageState.ERROR,
- error_message=get_user_facing_error_message(e),
- )
- finally:
- async with tree.with_lock():
- tree.clear_current_node()
- # Check if there are more messages in the queue
- await self._process_next(tree, processor)
-
- async def _process_next(
- self,
- tree: MessageTree,
- processor: Callable[[str, MessageNode], Awaitable[None]],
- ) -> None:
- """Process the next message in queue, if any."""
- next_node_id = None
- node = None
- async with tree.with_lock():
- next_node_id = await tree.dequeue()
-
- if not next_node_id:
- tree.set_processing_state(None, False)
- logger.debug(f"Tree {tree.root_id} queue empty, marking as free")
- return
-
- tree.set_processing_state(next_node_id, True)
- logger.info(f"Processing next queued node {next_node_id}")
-
- # Process next node (outside lock)
- node = tree.get_node(next_node_id)
- if node:
- tree.set_current_task(
- asyncio.create_task(self.process_node(tree, node, processor))
- )
-
- # Notify that this node has started processing and refresh queue positions.
- if next_node_id:
- await self._notify_node_started(tree, next_node_id)
- await self._notify_queue_updated(tree)
-
- async def enqueue_and_start(
- self,
- tree: MessageTree,
- node_id: str,
- processor: Callable[[str, MessageNode], Awaitable[None]],
- ) -> bool:
- """
- Enqueue a node or start processing immediately.
-
- Args:
- tree: The message tree
- node_id: Node to process
- processor: Async function to process the node
-
- Returns:
- True if queued, False if processing immediately
- """
- async with tree.with_lock():
- if tree.is_processing:
- tree.put_queue_unlocked(node_id)
- queue_size = tree.get_queue_size()
- logger.info(f"Queued node {node_id}, position {queue_size}")
- return True
- else:
- tree.set_processing_state(node_id, True)
-
- # Process outside the lock
- node = tree.get_node(node_id)
- if node:
- tree.set_current_task(
- asyncio.create_task(self.process_node(tree, node, processor))
- )
- return False
-
- def cancel_current(self, tree: MessageTree) -> bool:
- """Cancel the currently running task in a tree."""
- return tree.cancel_current_task()
diff --git a/messaging/trees/queue_manager.py b/messaging/trees/queue_manager.py
index fbf7029..2c0cecd 100644
--- a/messaging/trees/queue_manager.py
+++ b/messaging/trees/queue_manager.py
@@ -1,30 +1,345 @@
-"""Tree-Based Message Queue Manager - Refactored.
-
-Coordinates data access, async processing, and error handling.
-Uses TreeRepository for data, TreeQueueProcessor for async logic.
-"""
+"""Tree-based message queue: index, async node processor, and public manager API."""
import asyncio
from collections.abc import Awaitable, Callable
from loguru import logger
+from config.settings import get_settings
+from core.anthropic import get_user_facing_error_message
+
from ..models import IncomingMessage
+from ..safe_diagnostics import format_exception_for_log
from .data import MessageNode, MessageState, MessageTree
-from .processor import TreeQueueProcessor
-from .repository import TreeRepository
+
+
+class TreeRepository:
+ """
+ In-memory index of trees and node-to-root mappings.
+
+ Used only by :class:`TreeQueueManager`; kept as a named type for tests.
+ """
+
+ def __init__(self) -> None:
+ self._trees: dict[str, MessageTree] = {} # root_id -> tree
+ self._node_to_tree: dict[str, str] = {} # node_id -> root_id
+
+ def get_tree(self, root_id: str) -> MessageTree | None:
+ """Get a tree by its root ID."""
+ return self._trees.get(root_id)
+
+ def get_tree_for_node(self, node_id: str) -> MessageTree | None:
+ """Get the tree containing a given node."""
+ root_id = self._node_to_tree.get(node_id)
+ if not root_id:
+ return None
+ return self._trees.get(root_id)
+
+ def get_node(self, node_id: str) -> MessageNode | None:
+ """Get a node from any tree."""
+ tree = self.get_tree_for_node(node_id)
+ return tree.get_node(node_id) if tree else None
+
+ def add_tree(self, root_id: str, tree: MessageTree) -> None:
+ """Add a new tree to the repository."""
+ self._trees[root_id] = tree
+ self._node_to_tree[root_id] = root_id
+ logger.debug("TREE_REPO: add_tree root_id={}", root_id)
+
+ def register_node(self, node_id: str, root_id: str) -> None:
+ """Register a node ID to a tree."""
+ self._node_to_tree[node_id] = root_id
+ logger.debug("TREE_REPO: register_node node_id={} root_id={}", node_id, root_id)
+
+ def has_node(self, node_id: str) -> bool:
+ """Check if a node is registered in any tree."""
+ return node_id in self._node_to_tree
+
+ def tree_count(self) -> int:
+ """Get the number of trees in the repository."""
+ return len(self._trees)
+
+ def is_tree_busy(self, root_id: str) -> bool:
+ """Check if a tree is currently processing."""
+ tree = self._trees.get(root_id)
+ return tree.is_processing if tree else False
+
+ def is_node_tree_busy(self, node_id: str) -> bool:
+ """Check if the tree containing a node is busy."""
+ tree = self.get_tree_for_node(node_id)
+ return tree.is_processing if tree else False
+
+ def get_queue_size(self, node_id: str) -> int:
+ """Get queue size for the tree containing a node."""
+ tree = self.get_tree_for_node(node_id)
+ return tree.get_queue_size() if tree else 0
+
+ def resolve_parent_node_id(self, msg_id: str) -> str | None:
+ """
+ Resolve a message ID to the actual parent node ID.
+
+ Handles the case where msg_id is a status message ID
+ (which maps to the tree but isn't an actual node).
+
+ Returns:
+ The node_id to use as parent, or None if not found
+ """
+ tree = self.get_tree_for_node(msg_id)
+ if not tree:
+ return None
+
+ if tree.has_node(msg_id):
+ return msg_id
+
+ node = tree.find_node_by_status_message(msg_id)
+ if node:
+ return node.node_id
+
+ return None
+
+ def get_pending_children(self, node_id: str) -> list[MessageNode]:
+ """
+ Get all pending child nodes (recursively) of a given node.
+
+ Used for error propagation - when a node fails, its pending
+ children should also be marked as failed.
+ """
+ tree = self.get_tree_for_node(node_id)
+ if not tree:
+ return []
+
+ pending: list[MessageNode] = []
+ stack = [node_id]
+
+ while stack:
+ current_id = stack.pop()
+ node = tree.get_node(current_id)
+ if not node:
+ continue
+ for child_id in node.children_ids:
+ child = tree.get_node(child_id)
+ if child and child.state == MessageState.PENDING:
+ pending.append(child)
+ stack.append(child_id)
+
+ return pending
+
+ def all_trees(self) -> list[MessageTree]:
+ """Get all trees in the repository."""
+ return list(self._trees.values())
+
+ def tree_ids(self) -> list[str]:
+ """Get all tree root IDs."""
+ return list(self._trees.keys())
+
+ def unregister_nodes(self, node_ids: list[str]) -> None:
+ """Remove node IDs from the node-to-tree mapping."""
+ for nid in node_ids:
+ self._node_to_tree.pop(nid, None)
+
+ def remove_tree(self, root_id: str) -> MessageTree | None:
+ """
+ Remove a tree and all its node mappings from the repository.
+
+ Returns:
+ The removed tree, or None if not found.
+ """
+ tree = self._trees.pop(root_id, None)
+ if not tree:
+ return None
+ for node in tree.all_nodes():
+ self._node_to_tree.pop(node.node_id, None)
+ logger.debug("TREE_REPO: remove_tree root_id={}", root_id)
+ return tree
+
+ def get_message_ids_for_chat(self, platform: str, chat_id: str) -> set[str]:
+ """Get all message IDs (incoming + status) for a given platform/chat."""
+ msg_ids: set[str] = set()
+ for tree in self._trees.values():
+ for node in tree.all_nodes():
+ if str(node.incoming.platform) == str(platform) and str(
+ node.incoming.chat_id
+ ) == str(chat_id):
+ if node.incoming.message_id is not None:
+ msg_ids.add(str(node.incoming.message_id))
+ if node.status_message_id:
+ msg_ids.add(str(node.status_message_id))
+ return msg_ids
+
+ def to_dict(self) -> dict:
+ """Serialize all trees."""
+ return {
+ "trees": {rid: tree.to_dict() for rid, tree in self._trees.items()},
+ "node_to_tree": self._node_to_tree.copy(),
+ }
+
+ @classmethod
+ def from_dict(cls, data: dict) -> TreeRepository:
+ """Deserialize from dictionary."""
+ repo = cls()
+ for root_id, tree_data in data.get("trees", {}).items():
+ repo._trees[root_id] = MessageTree.from_dict(tree_data)
+ repo._node_to_tree = data.get("node_to_tree", {})
+ return repo
+
+
+class TreeQueueProcessor:
+ """
+ Per-tree async queue processing (one manager owns one processor instance).
+ """
+
+ def __init__(
+ self,
+ queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None = None,
+ node_started_callback: Callable[[MessageTree, str], Awaitable[None]]
+ | None = None,
+ ) -> None:
+ self._queue_update_callback = queue_update_callback
+ self._node_started_callback = node_started_callback
+
+ def set_queue_update_callback(
+ self,
+ queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None,
+ ) -> None:
+ """Update the callback used to refresh queue positions."""
+ self._queue_update_callback = queue_update_callback
+
+ def set_node_started_callback(
+ self,
+ node_started_callback: Callable[[MessageTree, str], Awaitable[None]] | None,
+ ) -> None:
+ """Update the callback used when a queued node starts processing."""
+ self._node_started_callback = node_started_callback
+
+ async def _notify_queue_updated(self, tree: MessageTree) -> None:
+ """Invoke queue update callback if set."""
+ if not self._queue_update_callback:
+ return
+ try:
+ await self._queue_update_callback(tree)
+ except Exception as e:
+ d = get_settings().log_messaging_error_details
+ logger.warning(
+ "Queue update callback failed: {}",
+ format_exception_for_log(e, log_full_message=d),
+ )
+
+ async def _notify_node_started(self, tree: MessageTree, node_id: str) -> None:
+ """Invoke node started callback if set."""
+ if not self._node_started_callback:
+ return
+ try:
+ await self._node_started_callback(tree, node_id)
+ except Exception as e:
+ d = get_settings().log_messaging_error_details
+ logger.warning(
+ "Node started callback failed: {}",
+ format_exception_for_log(e, log_full_message=d),
+ )
+
+ async def process_node(
+ self,
+ tree: MessageTree,
+ node: MessageNode,
+ processor: Callable[[str, MessageNode], Awaitable[None]],
+ ) -> None:
+ """Process a single node and then check the queue."""
+ if node.state == MessageState.ERROR:
+ logger.info(
+ f"Skipping node {node.node_id} as it is already in state {node.state}"
+ )
+ await self._process_next(tree, processor)
+ return
+
+ try:
+ await processor(node.node_id, node)
+ except asyncio.CancelledError:
+ logger.info(f"Task for node {node.node_id} was cancelled")
+ raise
+ except Exception as e:
+ d = get_settings().log_messaging_error_details
+ logger.error(
+ "Error processing node {}: {}",
+ node.node_id,
+ format_exception_for_log(e, log_full_message=d),
+ )
+ await tree.update_state(
+ node.node_id,
+ MessageState.ERROR,
+ error_message=get_user_facing_error_message(e),
+ )
+ finally:
+ async with tree.with_lock():
+ tree.clear_current_node()
+ await self._process_next(tree, processor)
+
+ async def _process_next(
+ self,
+ tree: MessageTree,
+ processor: Callable[[str, MessageNode], Awaitable[None]],
+ ) -> None:
+ """Process the next message in queue, if any."""
+ next_node_id = None
+ async with tree.with_lock():
+ next_node_id = await tree.dequeue()
+
+ if not next_node_id:
+ tree.set_processing_state(None, False)
+ logger.debug(f"Tree {tree.root_id} queue empty, marking as free")
+ return
+
+ tree.set_processing_state(next_node_id, True)
+ logger.info(f"Processing next queued node {next_node_id}")
+
+ node = tree.get_node(next_node_id)
+ if node:
+ tree.set_current_task(
+ asyncio.create_task(self.process_node(tree, node, processor))
+ )
+
+ if next_node_id:
+ await self._notify_node_started(tree, next_node_id)
+ await self._notify_queue_updated(tree)
+
+ async def enqueue_and_start(
+ self,
+ tree: MessageTree,
+ node_id: str,
+ processor: Callable[[str, MessageNode], Awaitable[None]],
+ ) -> bool:
+ """
+ Enqueue a node or start processing immediately.
+
+ Returns:
+ True if queued, False if processing immediately
+ """
+ async with tree.with_lock():
+ if tree.is_processing:
+ tree.put_queue_unlocked(node_id)
+ queue_size = tree.get_queue_size()
+ logger.info(f"Queued node {node_id}, position {queue_size}")
+ return True
+ else:
+ tree.set_processing_state(node_id, True)
+
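+        # Create the processing task outside the lock to avoid deadlock.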
+ node = tree.get_node(node_id)
+ if node:
+ tree.set_current_task(
+ asyncio.create_task(self.process_node(tree, node, processor))
+ )
+ return False
+
+ def cancel_current(self, tree: MessageTree) -> bool:
+ """Cancel the currently running task in a tree."""
+ return tree.cancel_current_task()
class TreeQueueManager:
"""
- Manages multiple message trees. Facade that coordinates components.
+ Manages multiple message trees: index + async processing.
Each new conversation creates a new tree.
Replies to existing messages add nodes to existing trees.
-
- Components:
- - TreeRepository: Data access layer
- - TreeQueueProcessor: Async queue processing
"""
def __init__(
@@ -33,7 +348,7 @@ class TreeQueueManager:
node_started_callback: Callable[[MessageTree, str], Awaitable[None]]
| None = None,
_repository: TreeRepository | None = None,
- ):
+ ) -> None:
self._repository = _repository or TreeRepository()
self._processor = TreeQueueProcessor(
queue_update_callback=queue_update_callback,
@@ -101,7 +416,6 @@ class TreeQueueManager:
if not tree:
raise ValueError(f"Parent node {parent_node_id} not found in any tree")
- # Add node (tree has its own lock) - outside manager lock to avoid deadlock
node = await tree.add_node(
node_id=node_id,
incoming=incoming,
@@ -228,7 +542,6 @@ class TreeQueueManager:
cleanup_count = 0
async with tree.with_lock():
- # 1. Cancel running task
if tree.cancel_current_task():
current_id = tree.current_node_id
if current_id:
@@ -240,12 +553,10 @@ class TreeQueueManager:
tree.set_node_error_sync(node, "Cancelled by user")
cancelled_nodes.append(node)
- # 2. Drain queue and mark nodes as cancelled
queue_nodes = tree.drain_queue_and_mark_cancelled()
cancelled_nodes.extend(queue_nodes)
cancelled_ids = {n.node_id for n in cancelled_nodes}
- # 3. Cleanup: Mark ANY other PENDING or IN_PROGRESS nodes as ERROR
for node in tree.all_nodes():
if (
node.state in (MessageState.PENDING, MessageState.IN_PROGRESS)
@@ -269,10 +580,6 @@ class TreeQueueManager:
"""
Cancel a single node (queued or in-progress) without affecting other nodes.
- - If the node is currently running, cancels the current asyncio task.
- - If the node is queued, removes it from the queue.
- - Marks the node as ERROR with "Cancelled by user".
-
Returns:
List containing the cancelled node if it was cancellable, else empty list.
"""
@@ -351,8 +658,6 @@ class TreeQueueManager:
async def cancel_branch(self, branch_root_id: str) -> list[MessageNode]:
"""
Cancel all PENDING/IN_PROGRESS nodes in the subtree (branch_root + descendants).
-
- Does not call cli_manager.stop_all(). Returns list of cancelled nodes.
"""
tree = self._repository.get_tree_for_node(branch_root_id)
if not tree:
@@ -435,3 +740,10 @@ class TreeQueueManager:
node_started_callback=node_started_callback,
_repository=TreeRepository.from_dict(data),
)
+
+
+__all__ = [
+ "TreeQueueManager",
+ "TreeQueueProcessor",
+ "TreeRepository",
+]
diff --git a/messaging/trees/repository.py b/messaging/trees/repository.py
deleted file mode 100644
index 2dae1fb..0000000
--- a/messaging/trees/repository.py
+++ /dev/null
@@ -1,186 +0,0 @@
-"""Repository for message tree data access.
-
-Provides data access layer for managing trees and node mappings.
-"""
-
-from loguru import logger
-
-from .data import MessageNode, MessageState, MessageTree
-
-
-class TreeRepository:
- """
- Repository for message tree data access.
-
- Manages the storage and lookup of trees and node-to-tree mappings.
- """
-
- def __init__(self):
- self._trees: dict[str, MessageTree] = {} # root_id -> tree
- self._node_to_tree: dict[str, str] = {} # node_id -> root_id
-
- def get_tree(self, root_id: str) -> MessageTree | None:
- """Get a tree by its root ID."""
- return self._trees.get(root_id)
-
- def get_tree_for_node(self, node_id: str) -> MessageTree | None:
- """Get the tree containing a given node."""
- root_id = self._node_to_tree.get(node_id)
- if not root_id:
- return None
- return self._trees.get(root_id)
-
- def get_node(self, node_id: str) -> MessageNode | None:
- """Get a node from any tree."""
- tree = self.get_tree_for_node(node_id)
- return tree.get_node(node_id) if tree else None
-
- def add_tree(self, root_id: str, tree: MessageTree) -> None:
- """Add a new tree to the repository."""
- self._trees[root_id] = tree
- self._node_to_tree[root_id] = root_id
- logger.debug("TREE_REPO: add_tree root_id={}", root_id)
-
- def register_node(self, node_id: str, root_id: str) -> None:
- """Register a node ID to a tree."""
- self._node_to_tree[node_id] = root_id
- logger.debug("TREE_REPO: register_node node_id={} root_id={}", node_id, root_id)
-
- def has_node(self, node_id: str) -> bool:
- """Check if a node is registered in any tree."""
- return node_id in self._node_to_tree
-
- def tree_count(self) -> int:
- """Get the number of trees in the repository."""
- return len(self._trees)
-
- def is_tree_busy(self, root_id: str) -> bool:
- """Check if a tree is currently processing."""
- tree = self._trees.get(root_id)
- return tree.is_processing if tree else False
-
- def is_node_tree_busy(self, node_id: str) -> bool:
- """Check if the tree containing a node is busy."""
- tree = self.get_tree_for_node(node_id)
- return tree.is_processing if tree else False
-
- def get_queue_size(self, node_id: str) -> int:
- """Get queue size for the tree containing a node."""
- tree = self.get_tree_for_node(node_id)
- return tree.get_queue_size() if tree else 0
-
- def resolve_parent_node_id(self, msg_id: str) -> str | None:
- """
- Resolve a message ID to the actual parent node ID.
-
- Handles the case where msg_id is a status message ID
- (which maps to the tree but isn't an actual node).
-
- Returns:
- The node_id to use as parent, or None if not found
- """
- tree = self.get_tree_for_node(msg_id)
- if not tree:
- return None
-
- # Check if msg_id is an actual node
- if tree.has_node(msg_id):
- return msg_id
-
- # Otherwise, it might be a status message - find the owning node
- node = tree.find_node_by_status_message(msg_id)
- if node:
- return node.node_id
-
- return None
-
- def get_pending_children(self, node_id: str) -> list[MessageNode]:
- """
- Get all pending child nodes (recursively) of a given node.
-
- Used for error propagation - when a node fails, its pending
- children should also be marked as failed.
- """
- tree = self.get_tree_for_node(node_id)
- if not tree:
- return []
-
- pending: list[MessageNode] = []
- stack = [node_id]
-
- while stack:
- current_id = stack.pop()
- node = tree.get_node(current_id)
- if not node:
- continue
- for child_id in node.children_ids:
- child = tree.get_node(child_id)
- if child and child.state == MessageState.PENDING:
- pending.append(child)
- stack.append(child_id)
-
- return pending
-
- def all_trees(self) -> list[MessageTree]:
- """Get all trees in the repository."""
- return list(self._trees.values())
-
- def tree_ids(self) -> list[str]:
- """Get all tree root IDs."""
- return list(self._trees.keys())
-
- def unregister_nodes(self, node_ids: list[str]) -> None:
- """Remove node IDs from the node-to-tree mapping."""
- for nid in node_ids:
- self._node_to_tree.pop(nid, None)
-
- def remove_tree(self, root_id: str) -> MessageTree | None:
- """
- Remove a tree and all its node mappings from the repository.
-
- Returns:
- The removed tree, or None if not found.
- """
- tree = self._trees.pop(root_id, None)
- if not tree:
- return None
- for node in tree.all_nodes():
- self._node_to_tree.pop(node.node_id, None)
- logger.debug("TREE_REPO: remove_tree root_id={}", root_id)
- return tree
-
- def get_message_ids_for_chat(self, platform: str, chat_id: str) -> set[str]:
- """Get all message IDs (incoming + status) for a given platform/chat.
-
- Note: O(total_nodes) scan. Acceptable because this is only called
- from /clear (user-initiated, infrequent).
- """
- msg_ids: set[str] = set()
- for tree in self._trees.values():
- for node in tree.all_nodes():
- if str(node.incoming.platform) == str(platform) and str(
- node.incoming.chat_id
- ) == str(chat_id):
- if node.incoming.message_id is not None:
- msg_ids.add(str(node.incoming.message_id))
- if node.status_message_id:
- msg_ids.add(str(node.status_message_id))
- return msg_ids
-
- def to_dict(self) -> dict:
- """Serialize all trees."""
- return {
- "trees": {rid: tree.to_dict() for rid, tree in self._trees.items()},
- "node_to_tree": self._node_to_tree.copy(),
- }
-
- @classmethod
- def from_dict(cls, data: dict) -> TreeRepository:
- """Deserialize from dictionary."""
- from .data import MessageTree
-
- repo = cls()
- for root_id, tree_data in data.get("trees", {}).items():
- repo._trees[root_id] = MessageTree.from_dict(tree_data)
- repo._node_to_tree = data.get("node_to_tree", {})
- return repo
diff --git a/messaging/ui_updates.py b/messaging/ui_updates.py
new file mode 100644
index 0000000..929202e
--- /dev/null
+++ b/messaging/ui_updates.py
@@ -0,0 +1,101 @@
+"""Throttled platform UI updates driven by transcript rendering."""
+
+from __future__ import annotations
+
+import time
+from collections.abc import Callable
+
+from loguru import logger
+
+from .platforms.base import MessagingPlatform
+from .safe_diagnostics import format_exception_for_log
+from .transcript import RenderCtx, TranscriptBuffer
+
+
+class ThrottledTranscriptEditor:
+ """Rate-limited status message edits from a growing transcript."""
+
+ def __init__(
+ self,
+ *,
+ platform: MessagingPlatform,
+ parse_mode: str | None,
+ get_limit_chars: Callable[[], int],
+ transcript: TranscriptBuffer,
+ render_ctx: RenderCtx,
+ node_id: str,
+ chat_id: str,
+ status_msg_id: str,
+ debug_platform_edits: bool,
+ log_messaging_error_details: bool = False,
+ ) -> None:
+ self._platform = platform
+ self._parse_mode = parse_mode
+ self._get_limit_chars = get_limit_chars
+ self._transcript = transcript
+ self._render_ctx = render_ctx
+ self._node_id = node_id
+ self._chat_id = chat_id
+ self._status_msg_id = status_msg_id
+ self._debug_platform_edits = debug_platform_edits
+ self._log_messaging_error_details = log_messaging_error_details
+ self._last_ui_update = 0.0
+ self._last_displayed_text: str | None = None
+ self._last_status: str | None = None
+
+ @property
+ def last_status(self) -> str | None:
+ return self._last_status
+
+ async def update(self, status: str | None = None, *, force: bool = False) -> None:
+ """Render transcript + optional status line and edit the platform message."""
+ now = time.time()
+ if not force and now - self._last_ui_update < 1.0:
+ return
+
+ self._last_ui_update = now
+ if status is not None:
+ self._last_status = status
+ try:
+ display = self._transcript.render(
+ self._render_ctx,
+ limit_chars=self._get_limit_chars(),
+ status=status,
+ )
+ except Exception as e:
+ logger.warning(
+ "Transcript render failed for node {}: {}",
+ self._node_id,
+ format_exception_for_log(
+ e, log_full_message=self._log_messaging_error_details
+ ),
+ )
+ return
+ if display and display != self._last_displayed_text:
+ logger.debug(
+ "PLATFORM_EDIT: node_id={} chat_id={} msg_id={} force={} status={!r} chars={}",
+ self._node_id,
+ self._chat_id,
+ self._status_msg_id,
+ bool(force),
+ status,
+ len(display),
+ )
+ if self._debug_platform_edits:
+ logger.debug("PLATFORM_EDIT_TEXT:\n{}", display)
+ self._last_displayed_text = display
+ try:
+ await self._platform.queue_edit_message(
+ self._chat_id,
+ self._status_msg_id,
+ display,
+ parse_mode=self._parse_mode,
+ )
+ except Exception as e:
+ logger.warning(
+ "Failed to update platform for node {}: {}",
+ self._node_id,
+ format_exception_for_log(
+ e, log_full_message=self._log_messaging_error_details
+ ),
+ )
diff --git a/providers/anthropic_messages.py b/providers/anthropic_messages.py
index c35c1d5..b501aac 100644
--- a/providers/anthropic_messages.py
+++ b/providers/anthropic_messages.py
@@ -2,19 +2,32 @@
from __future__ import annotations
-import json
from collections.abc import AsyncIterator, Iterator
from typing import Any, Literal
import httpx
from loguru import logger
-from core.anthropic import get_user_facing_error_message
+from config.constants import (
+ ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
+ NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES,
+)
+from core.anthropic import iter_provider_stream_error_sse_events
+from core.anthropic.emitted_sse_tracker import EmittedNativeSseTracker
+from core.anthropic.native_messages_request import (
+ build_base_native_anthropic_request_body,
+)
+from core.anthropic.native_sse_block_policy import (
+ NativeSseBlockPolicyState,
+ transform_native_sse_block_event,
+)
from providers.base import BaseProvider, ProviderConfig
-from providers.error_mapping import map_error
+from providers.error_mapping import (
+ map_error,
+ user_visible_message_for_mapped_provider_error,
+)
from providers.rate_limit import GlobalRateLimiter
-ANTHROPIC_DEFAULT_MAX_TOKENS = 81920
StreamChunkMode = Literal["line", "event"]
@@ -64,25 +77,11 @@ class AnthropicMessagesTransport(BaseProvider):
) -> dict:
"""Build a native Anthropic request body."""
thinking_enabled = self._is_thinking_enabled(request, thinking_enabled)
- body = request.model_dump(exclude_none=True)
-
- body.pop("extra_body", None)
- body.pop("original_model", None)
- body.pop("resolved_provider_model", None)
-
- if "thinking" in body:
- thinking_cfg = body.pop("thinking")
- if thinking_enabled and isinstance(thinking_cfg, dict):
- thinking_payload = {"type": "enabled"}
- budget_tokens = thinking_cfg.get("budget_tokens")
- if isinstance(budget_tokens, int):
- thinking_payload["budget_tokens"] = budget_tokens
- body["thinking"] = thinking_payload
-
- if "max_tokens" not in body:
- body["max_tokens"] = ANTHROPIC_DEFAULT_MAX_TOKENS
-
- return body
+ return build_base_native_anthropic_request_body(
+ request,
+ default_max_tokens=ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
+ thinking_enabled=thinking_enabled,
+ )
async def _send_stream_request(self, body: dict) -> httpx.Response:
"""Create a streaming messages response."""
@@ -97,30 +96,68 @@ class AnthropicMessagesTransport(BaseProvider):
async def _raise_for_status(
self, response: httpx.Response, *, req_tag: str
) -> None:
- """Raise for non-200 responses after logging the upstream body."""
+ """Raise for non-200 responses after logging safe metadata (or capped body if opted in)."""
try:
response.raise_for_status()
except httpx.HTTPStatusError as error:
- response_text = await self._read_error_body(response)
- if response_text:
+ if self._config.log_api_error_tracebacks:
+ preview, truncated = await self._read_error_body_preview(
+ response, NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES
+ )
+ if preview:
+ text = preview.decode("utf-8", errors="replace")
+ logger.error(
+ "{}_ERROR:{} HTTP {} body_preview_bytes={} truncated={}: {}",
+ self._provider_name,
+ req_tag,
+ response.status_code,
+ len(preview),
+ truncated,
+ text,
+ )
+ else:
+ logger.error(
+ "{}_ERROR:{} HTTP {} (empty error body)",
+ self._provider_name,
+ req_tag,
+ response.status_code,
+ )
+ else:
+ cl = response.headers.get("content-length", "").strip()
+ extra = f" content_length_declared={cl}" if cl.isdigit() else ""
logger.error(
- "{}_ERROR:{} HTTP {}: {}",
+ "{}_ERROR:{} HTTP {}{}",
self._provider_name,
req_tag,
response.status_code,
- response_text,
+ extra,
)
raise error
- async def _read_error_body(self, response: httpx.Response) -> str:
- """Read a response body for diagnostics."""
- aread = getattr(response, "aread", None)
- if aread is None:
- return ""
- body = await aread()
- if isinstance(body, bytes):
- return body.decode("utf-8", errors="replace")
- return str(body)
+ async def _read_error_body_preview(
+ self, response: httpx.Response, max_bytes: int
+ ) -> tuple[bytes, bool]:
+ """Read at most ``max_bytes`` from the error body for logging. Returns (preview, truncated)."""
+ if max_bytes <= 0:
+ return b"", False
+ received = 0
+ parts: list[bytes] = []
+ truncated = False
+ async for chunk in response.aiter_bytes(chunk_size=65_536):
+ if received >= max_bytes:
+ truncated = True
+ break
+ remaining = max_bytes - received
+ take = chunk if len(chunk) <= remaining else chunk[:remaining]
+ if take:
+ parts.append(take)
+ received += len(take)
+            if len(chunk) > len(take):
+                truncated = True
+                break
+ return (b"".join(parts), truncated)
async def _iter_sse_lines(self, response: httpx.Response) -> AsyncIterator[str]:
"""Yield raw SSE line chunks preserving local provider behavior."""
@@ -145,6 +182,8 @@ class AnthropicMessagesTransport(BaseProvider):
def _new_stream_state(self, request: Any, *, thinking_enabled: bool) -> Any:
"""Return per-stream provider state for event transformation."""
+ if self.stream_chunk_mode == "line":
+ return NativeSseBlockPolicyState()
return None
def _transform_stream_event(
@@ -155,6 +194,10 @@ class AnthropicMessagesTransport(BaseProvider):
thinking_enabled: bool,
) -> str | None:
"""Transform or drop a grouped SSE event before yielding it downstream."""
+ if isinstance(state, NativeSseBlockPolicyState):
+ return transform_native_sse_block_event(
+ event, state, thinking_enabled=thinking_enabled
+ )
return event
def _format_error_message(self, base_message: str, request_id: str | None) -> str:
@@ -166,15 +209,11 @@ class AnthropicMessagesTransport(BaseProvider):
def _get_error_message(self, error: Exception, request_id: str | None) -> str:
"""Map an exception into a user-facing provider error message."""
mapped_error = map_error(error, rate_limiter=self._global_rate_limiter)
- if getattr(mapped_error, "status_code", None) == 405:
- base_message = (
- f"Upstream provider {self._provider_name} rejected the request method "
- "or endpoint (HTTP 405)."
- )
- else:
- base_message = get_user_facing_error_message(
- mapped_error, read_timeout_s=self._config.http_read_timeout
- )
+ base_message = user_visible_message_for_mapped_provider_error(
+ mapped_error,
+ provider_name=self._provider_name,
+ read_timeout_s=self._config.http_read_timeout,
+ )
return self._format_error_message(base_message, request_id)
def _emit_error_events(
@@ -185,12 +224,14 @@ class AnthropicMessagesTransport(BaseProvider):
error_message: str,
sent_any_event: bool,
) -> Iterator[str]:
- """Emit a native Anthropic error event."""
- error_event = {
- "type": "error",
- "error": {"type": "api_error", "message": error_message},
- }
- yield f"event: error\ndata: {json.dumps(error_event)}\n\n"
+ """Emit the same Anthropic message lifecycle used by OpenAI-compat providers."""
+ yield from iter_provider_stream_error_sse_events(
+ request=request,
+ input_tokens=input_tokens,
+ error_message=error_message,
+ sent_any_event=sent_any_event,
+ log_raw_sse_events=self._config.log_raw_sse_events,
+ )
async def _iter_stream_chunks(
self,
@@ -200,6 +241,21 @@ class AnthropicMessagesTransport(BaseProvider):
thinking_enabled: bool,
) -> AsyncIterator[str]:
"""Yield stream chunks according to the provider's observable chunk shape."""
+ if self.stream_chunk_mode == "line" and isinstance(
+ state, NativeSseBlockPolicyState
+ ):
+ async for event in self._iter_sse_events(response):
+ output_event = self._transform_stream_event(
+ event,
+ state,
+ thinking_enabled=thinking_enabled,
+ )
+ if output_event is None:
+ continue
+ for line in output_event.splitlines(keepends=True):
+ yield line
+ return
+
if self.stream_chunk_mode == "line":
async for chunk in self._iter_sse_lines(response):
yield chunk
@@ -240,15 +296,28 @@ class AnthropicMessagesTransport(BaseProvider):
response: httpx.Response | None = None
sent_any_event = False
state = self._new_stream_state(request, thinking_enabled=thinking_enabled)
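+        # Tracks emitted SSE blocks so a mid-stream failure can close any
+        # open blocks before appending the error tail.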
+ emitted_tracker = EmittedNativeSseTracker()
async with self._global_rate_limiter.concurrency_slot():
try:
- response = await self._global_rate_limiter.execute_with_retry(
- self._send_stream_request, body
- )
- if response.status_code != 200:
- await self._raise_for_status(response, req_tag=req_tag)
+ async def _validated_stream_send() -> httpx.Response:
+ """Send request; raise inside retry loop on 429 so rate limiter can backoff."""
+ send_response = await self._send_stream_request(body)
+ if send_response.status_code == 429:
+ await send_response.aclose()
+ send_response.raise_for_status()
+ if send_response.status_code != 200:
+ try:
+ await self._raise_for_status(send_response, req_tag=req_tag)
+ finally:
+ if not send_response.is_closed:
+ await send_response.aclose()
+ return send_response
+
+ response = await self._global_rate_limiter.execute_with_retry(
+ _validated_stream_send
+ )
async for chunk in self._iter_stream_chunks(
response,
@@ -256,12 +325,12 @@ class AnthropicMessagesTransport(BaseProvider):
thinking_enabled=thinking_enabled,
):
sent_any_event = True
+ emitted_tracker.feed(chunk)
yield chunk
except Exception as error:
- logger.error(
- "{}_ERROR:{} {}: {}", tag, req_tag, type(error).__name__, error
- )
+ if not isinstance(error, httpx.HTTPStatusError):
+ self._log_stream_transport_error(tag, req_tag, error)
error_message = self._get_error_message(error, request_id)
if response is not None and not response.is_closed:
@@ -273,13 +342,24 @@ class AnthropicMessagesTransport(BaseProvider):
type(error).__name__,
req_tag,
)
- for event in self._emit_error_events(
- request=request,
- input_tokens=input_tokens,
- error_message=error_message,
- sent_any_event=sent_any_event,
- ):
- yield event
+ if sent_any_event:
+ for event in emitted_tracker.iter_close_unclosed_blocks():
+ yield event
+ for event in emitted_tracker.iter_midstream_error_tail(
+ error_message,
+ request=request,
+ input_tokens=input_tokens,
+ log_raw_sse_events=self._config.log_raw_sse_events,
+ ):
+ yield event
+ else:
+ for event in self._emit_error_events(
+ request=request,
+ input_tokens=input_tokens,
+ error_message=error_message,
+ sent_any_event=False,
+ ):
+ yield event
return
finally:
if response is not None and not response.is_closed:
diff --git a/providers/base.py b/providers/base.py
index f8ff193..4f1a74b 100644
--- a/providers/base.py
+++ b/providers/base.py
@@ -6,6 +6,8 @@ from typing import Any
from pydantic import BaseModel
+from config.constants import HTTP_CONNECT_TIMEOUT_DEFAULT
+
class ProviderConfig(BaseModel):
"""Configuration for a provider.
@@ -21,9 +23,11 @@ class ProviderConfig(BaseModel):
max_concurrency: int = 5
http_read_timeout: float = 300.0
http_write_timeout: float = 10.0
- http_connect_timeout: float = 2.0
+ http_connect_timeout: float = HTTP_CONNECT_TIMEOUT_DEFAULT
enable_thinking: bool = True
proxy: str = ""
+ log_raw_sse_events: bool = False
+ log_api_error_tracebacks: bool = False
class BaseProvider(ABC):
@@ -61,6 +65,42 @@ class BaseProvider(ABC):
request_enabled = bool(enabled)
return config_enabled and request_enabled
+ def preflight_stream(
+ self, request: Any, *, thinking_enabled: bool | None = None
+ ) -> None:
+ """Eagerly validate/build the upstream request before opening an SSE stream.
+
+ Subclasses with ``_build_request_body`` (OpenAI and native) raise
+ :class:`providers.exceptions.InvalidRequestError` on conversion failures.
+ """
+ build = getattr(self, "_build_request_body", None)
+ if build is None:
+ return
+ build(request, thinking_enabled=thinking_enabled)
+
+ def _log_stream_transport_error(
+ self, tag: str, req_tag: str, error: Exception
+ ) -> None:
+ """Log streaming transport failures (metadata-only unless verbose is enabled)."""
+ from loguru import logger
+
+ if self._config.log_api_error_tracebacks:
+ logger.error(
+ "{}_ERROR:{} {}: {}", tag, req_tag, type(error).__name__, error
+ )
+ return
+ response = getattr(error, "response", None)
+ status_code = (
+ getattr(response, "status_code", None) if response is not None else None
+ )
+ logger.error(
+ "{}_ERROR:{} exc_type={} http_status={}",
+ tag,
+ req_tag,
+ type(error).__name__,
+ status_code,
+ )
+
@abstractmethod
async def cleanup(self) -> None:
"""Release any resources held by this provider."""
@@ -75,5 +115,7 @@ class BaseProvider(ABC):
thinking_enabled: bool | None = None,
) -> AsyncIterator[str]:
"""Stream response in Anthropic SSE format."""
+ # Typing: abstract async generators need a yield for AsyncIterator[str]
+ # inference; this branch is never executed.
if False:
- yield "" # Required for ty/mypy to accept abstract async generator
+ yield ""
diff --git a/providers/deepseek/__init__.py b/providers/deepseek/__init__.py
index 0ecf7d0..5ab4c79 100644
--- a/providers/deepseek/__init__.py
+++ b/providers/deepseek/__init__.py
@@ -1,5 +1,7 @@
"""DeepSeek provider exports."""
-from .client import DEEPSEEK_BASE_URL, DeepSeekProvider
+from providers.defaults import DEEPSEEK_DEFAULT_BASE
-__all__ = ["DEEPSEEK_BASE_URL", "DeepSeekProvider"]
+from .client import DeepSeekProvider
+
+__all__ = ["DEEPSEEK_DEFAULT_BASE", "DeepSeekProvider"]
diff --git a/providers/deepseek/client.py b/providers/deepseek/client.py
index 210f223..3996b61 100644
--- a/providers/deepseek/client.py
+++ b/providers/deepseek/client.py
@@ -3,7 +3,7 @@
from typing import Any
from providers.base import ProviderConfig
-from providers.defaults import DEEPSEEK_BASE_URL
+from providers.defaults import DEEPSEEK_DEFAULT_BASE
from providers.openai_compat import OpenAIChatTransport
from .request import build_request_body
@@ -16,7 +16,7 @@ class DeepSeekProvider(OpenAIChatTransport):
super().__init__(
config,
provider_name="DEEPSEEK",
- base_url=config.base_url or DEEPSEEK_BASE_URL,
+ base_url=config.base_url or DEEPSEEK_DEFAULT_BASE,
api_key=config.api_key,
)
diff --git a/providers/deepseek/request.py b/providers/deepseek/request.py
index 67cabf9..f22a9f8 100644
--- a/providers/deepseek/request.py
+++ b/providers/deepseek/request.py
@@ -5,6 +5,8 @@ from typing import Any
from loguru import logger
from core.anthropic import build_base_request_body
+from core.anthropic.conversion import OpenAIConversionError
+from providers.exceptions import InvalidRequestError
def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict:
@@ -14,10 +16,14 @@ def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict:
getattr(request_data, "model", "?"),
len(getattr(request_data, "messages", [])),
)
- body = build_base_request_body(
- request_data,
- include_reasoning_content=True,
- )
+ try:
+ body = build_base_request_body(
+ request_data,
+ include_thinking=thinking_enabled,
+ include_reasoning_content=thinking_enabled,
+ )
+ except OpenAIConversionError as exc:
+ raise InvalidRequestError(str(exc)) from exc
extra_body: dict[str, Any] = {}
request_extra = getattr(request_data, "extra_body", None)
diff --git a/providers/defaults.py b/providers/defaults.py
index 5f7fcfe..80109dc 100644
--- a/providers/defaults.py
+++ b/providers/defaults.py
@@ -1,21 +1,19 @@
-"""Default upstream base URLs and shared provider constants.
+"""Re-exports default upstream base URLs from the config provider catalog."""
-Adapters and :mod:`providers.registry` import from here to avoid duplicating
-literals and to keep ``providers.registry`` free of per-adapter eager imports.
-"""
+from config.provider_catalog import (
+ DEEPSEEK_DEFAULT_BASE,
+ LLAMACPP_DEFAULT_BASE,
+ LMSTUDIO_DEFAULT_BASE,
+ NVIDIA_NIM_DEFAULT_BASE,
+ OLLAMA_DEFAULT_BASE,
+ OPENROUTER_DEFAULT_BASE,
+)
-# OpenAI-compatible chat (NIM, DeepSeek) and local/native provider endpoints
-NVIDIA_NIM_DEFAULT_BASE = "https://integrate.api.nvidia.com/v1"
-DEEPSEEK_DEFAULT_BASE = "https://api.deepseek.com"
-OPENROUTER_DEFAULT_BASE = "https://openrouter.ai/api/v1"
-LMSTUDIO_DEFAULT_BASE = "http://localhost:1234/v1"
-LLAMACPP_DEFAULT_BASE = "http://localhost:8080/v1"
-OLLAMA_DEFAULT_BASE = "http://localhost:11434"
-
-# Backward-compatible names used by existing adapter modules
-NVIDIA_NIM_BASE_URL = NVIDIA_NIM_DEFAULT_BASE
-DEEPSEEK_BASE_URL = DEEPSEEK_DEFAULT_BASE
-OPENROUTER_BASE_URL = OPENROUTER_DEFAULT_BASE
-LMSTUDIO_DEFAULT_BASE_URL = LMSTUDIO_DEFAULT_BASE
-LLAMACPP_DEFAULT_BASE_URL = LLAMACPP_DEFAULT_BASE
-OLLAMA_DEFAULT_BASE_URL = OLLAMA_DEFAULT_BASE
+__all__ = (
+ "DEEPSEEK_DEFAULT_BASE",
+ "LLAMACPP_DEFAULT_BASE",
+ "LMSTUDIO_DEFAULT_BASE",
+ "NVIDIA_NIM_DEFAULT_BASE",
+ "OLLAMA_DEFAULT_BASE",
+ "OPENROUTER_DEFAULT_BASE",
+)
diff --git a/providers/error_mapping.py b/providers/error_mapping.py
index 3826a84..44c33a5 100644
--- a/providers/error_mapping.py
+++ b/providers/error_mapping.py
@@ -14,10 +14,30 @@ from providers.exceptions import (
from providers.rate_limit import GlobalRateLimiter
+def user_visible_message_for_mapped_provider_error(
+ mapped: Exception,
+ *,
+ provider_name: str,
+ read_timeout_s: float | None,
+) -> str:
+ """Return the user-visible string after :func:`map_error` (405 + mapped types)."""
+ if getattr(mapped, "status_code", None) == 405:
+ return (
+ f"Upstream provider {provider_name} rejected the request method "
+ "or endpoint (HTTP 405)."
+ )
+ return get_user_facing_error_message(mapped, read_timeout_s=read_timeout_s)
+
+
def map_error(
e: Exception, *, rate_limiter: GlobalRateLimiter | None = None
) -> Exception:
- """Map OpenAI or HTTPX exception to specific ProviderError."""
+ """Map OpenAI or HTTPX exception to specific ProviderError.
+
+ Streaming transports should pass their scoped limiter (``self._global_rate_limiter``)
+ so reactive 429 handling applies to the correct provider. Tests may omit
+ ``rate_limiter`` to use the process-wide singleton.
+ """
message = get_user_facing_error_message(e)
limiter = rate_limiter or GlobalRateLimiter.get_instance()
diff --git a/providers/exceptions.py b/providers/exceptions.py
index 31c6781..6901b9c 100644
--- a/providers/exceptions.py
+++ b/providers/exceptions.py
@@ -90,8 +90,20 @@ class APIError(ProviderError):
)
-class UnknownProviderTypeError(ValueError):
+class UnknownProviderTypeError(InvalidRequestError):
"""Raised when ``provider_id`` is not registered in the provider map."""
def __init__(self, message: str) -> None:
super().__init__(message)
+
+
+class ServiceUnavailableError(ProviderError):
+ """Raised when the server is not ready (e.g. app lifespan did not wire state)."""
+
+ def __init__(self, message: str, raw_error: Any = None):
+ super().__init__(
+ message,
+ status_code=503,
+ error_type="api_error",
+ raw_error=raw_error,
+ )
diff --git a/providers/llamacpp/client.py b/providers/llamacpp/client.py
index de0c268..891022d 100644
--- a/providers/llamacpp/client.py
+++ b/providers/llamacpp/client.py
@@ -2,7 +2,7 @@
from providers.anthropic_messages import AnthropicMessagesTransport
from providers.base import ProviderConfig
-from providers.defaults import LLAMACPP_DEFAULT_BASE_URL
+from providers.defaults import LLAMACPP_DEFAULT_BASE
class LlamaCppProvider(AnthropicMessagesTransport):
@@ -12,5 +12,5 @@ class LlamaCppProvider(AnthropicMessagesTransport):
super().__init__(
config,
provider_name="LLAMACPP",
- default_base_url=LLAMACPP_DEFAULT_BASE_URL,
+ default_base_url=LLAMACPP_DEFAULT_BASE,
)
diff --git a/providers/lmstudio/client.py b/providers/lmstudio/client.py
index 2961993..e32b835 100644
--- a/providers/lmstudio/client.py
+++ b/providers/lmstudio/client.py
@@ -2,7 +2,7 @@
from providers.anthropic_messages import AnthropicMessagesTransport
from providers.base import ProviderConfig
-from providers.defaults import LMSTUDIO_DEFAULT_BASE_URL
+from providers.defaults import LMSTUDIO_DEFAULT_BASE
class LMStudioProvider(AnthropicMessagesTransport):
@@ -12,5 +12,5 @@ class LMStudioProvider(AnthropicMessagesTransport):
super().__init__(
config,
provider_name="LMSTUDIO",
- default_base_url=LMSTUDIO_DEFAULT_BASE_URL,
+ default_base_url=LMSTUDIO_DEFAULT_BASE,
)
diff --git a/providers/nvidia_nim/__init__.py b/providers/nvidia_nim/__init__.py
index b0a410f..253acd1 100644
--- a/providers/nvidia_nim/__init__.py
+++ b/providers/nvidia_nim/__init__.py
@@ -1,5 +1,7 @@
"""NVIDIA NIM provider package."""
-from .client import NVIDIA_NIM_BASE_URL, NvidiaNimProvider
+from providers.defaults import NVIDIA_NIM_DEFAULT_BASE
-__all__ = ["NVIDIA_NIM_BASE_URL", "NvidiaNimProvider"]
+from .client import NvidiaNimProvider
+
+__all__ = ["NVIDIA_NIM_DEFAULT_BASE", "NvidiaNimProvider"]
diff --git a/providers/nvidia_nim/client.py b/providers/nvidia_nim/client.py
index 5cc2f31..2661203 100644
--- a/providers/nvidia_nim/client.py
+++ b/providers/nvidia_nim/client.py
@@ -8,7 +8,7 @@ from loguru import logger
from config.nim import NimSettings
from providers.base import ProviderConfig
-from providers.defaults import NVIDIA_NIM_BASE_URL
+from providers.defaults import NVIDIA_NIM_DEFAULT_BASE
from providers.openai_compat import OpenAIChatTransport
from .request import (
@@ -25,7 +25,7 @@ class NvidiaNimProvider(OpenAIChatTransport):
super().__init__(
config,
provider_name="NIM",
- base_url=config.base_url or NVIDIA_NIM_BASE_URL,
+ base_url=config.base_url or NVIDIA_NIM_DEFAULT_BASE,
api_key=config.api_key,
)
self._nim_settings = nim_settings
diff --git a/providers/nvidia_nim/request.py b/providers/nvidia_nim/request.py
index c0a83f3..a4e96de 100644
--- a/providers/nvidia_nim/request.py
+++ b/providers/nvidia_nim/request.py
@@ -1,5 +1,6 @@
"""Request builder for NVIDIA NIM provider."""
+from collections.abc import Callable
from copy import deepcopy
from typing import Any
@@ -7,6 +8,42 @@ from loguru import logger
from config.nim import NimSettings
from core.anthropic import build_base_request_body, set_if_not_none
+from core.anthropic.conversion import OpenAIConversionError
+from providers.exceptions import InvalidRequestError
+
+
+def _clone_strip_extra_body(
+ body: dict[str, Any],
+ strip: Callable[[dict[str, Any]], bool],
+) -> dict[str, Any] | None:
+ """Deep-clone ``body`` and remove fields via ``strip`` on ``extra_body`` only.
+
+ Returns ``None`` when there is no ``extra_body`` dict or ``strip`` reports no change.
+ """
+ cloned_body = deepcopy(body)
+ extra_body = cloned_body.get("extra_body")
+ if not isinstance(extra_body, dict):
+ return None
+ if not strip(extra_body):
+ return None
+ if not extra_body:
+ cloned_body.pop("extra_body", None)
+ return cloned_body
+
+
+def _strip_reasoning_budget_fields(extra_body: dict[str, Any]) -> bool:
+ removed = extra_body.pop("reasoning_budget", None) is not None
+ chat_template_kwargs = extra_body.get("chat_template_kwargs")
+ if (
+ isinstance(chat_template_kwargs, dict)
+ and chat_template_kwargs.pop("reasoning_budget", None) is not None
+ ):
+ removed = True
+ return removed
+
+
+def _strip_chat_template_field(extra_body: dict[str, Any]) -> bool:
+ return extra_body.pop("chat_template", None) is not None
def _set_extra(
@@ -23,43 +60,12 @@ def _set_extra(
def clone_body_without_reasoning_budget(body: dict[str, Any]) -> dict[str, Any] | None:
"""Clone a request body and strip only reasoning_budget fields."""
- cloned_body = deepcopy(body)
- extra_body = cloned_body.get("extra_body")
- if not isinstance(extra_body, dict):
- return None
-
- removed = extra_body.pop("reasoning_budget", None) is not None
-
- chat_template_kwargs = extra_body.get("chat_template_kwargs")
- if (
- isinstance(chat_template_kwargs, dict)
- and chat_template_kwargs.pop("reasoning_budget", None) is not None
- ):
- removed = True
-
- if not extra_body:
- cloned_body.pop("extra_body", None)
-
- if not removed:
- return None
-
- return cloned_body
+ return _clone_strip_extra_body(body, _strip_reasoning_budget_fields)
def clone_body_without_chat_template(body: dict[str, Any]) -> dict[str, Any] | None:
"""Clone a request body and strip only chat_template."""
- cloned_body = deepcopy(body)
- extra_body = cloned_body.get("extra_body")
- if not isinstance(extra_body, dict):
- return None
-
- if extra_body.pop("chat_template", None) is None:
- return None
-
- if not extra_body:
- cloned_body.pop("extra_body", None)
-
- return cloned_body
+ return _clone_strip_extra_body(body, _strip_chat_template_field)
def build_request_body(
@@ -71,10 +77,13 @@ def build_request_body(
getattr(request_data, "model", "?"),
len(getattr(request_data, "messages", [])),
)
- body = build_base_request_body(
- request_data,
- include_thinking=thinking_enabled,
- )
+ try:
+ body = build_base_request_body(
+ request_data,
+ include_thinking=thinking_enabled,
+ )
+ except OpenAIConversionError as exc:
+ raise InvalidRequestError(str(exc)) from exc
# NIM-specific max_tokens: cap against nim.max_tokens
max_tokens = body.get("max_tokens") or getattr(request_data, "max_tokens", None)
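
The two clone helpers above now share `_clone_strip_extra_body`. A minimal self-contained sketch of that shared semantics, with names mirroring the hunk rather than importing the repo module:

```python
# Sketch of the clone-and-strip contract factored out above; runnable on its own.
from copy import deepcopy


def clone_strip_extra_body(body, strip):
    """Deep-clone ``body``; apply ``strip`` to extra_body; None when unchanged."""
    cloned = deepcopy(body)
    extra = cloned.get("extra_body")
    if not isinstance(extra, dict):
        return None
    if not strip(extra):
        return None
    if not extra:  # strip emptied it: drop the container entirely
        cloned.pop("extra_body", None)
    return cloned


body = {"model": "m", "extra_body": {"reasoning_budget": 128}}
out = clone_strip_extra_body(
    body, lambda eb: eb.pop("reasoning_budget", None) is not None
)
assert out == {"model": "m"}                            # emptied container removed
assert body["extra_body"] == {"reasoning_budget": 128}  # caller's body untouched
```
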
diff --git a/providers/nvidia_nim/voice.py b/providers/nvidia_nim/voice.py
new file mode 100644
index 0000000..f26255a
--- /dev/null
+++ b/providers/nvidia_nim/voice.py
@@ -0,0 +1,95 @@
+"""NVIDIA NIM / Riva offline ASR for voice notes (provider-owned transport)."""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+from loguru import logger
+
+# NVIDIA NIM / Riva ASR model mapping: model id -> (function_id, language_code)
+_NIM_ASR_MODEL_MAP: dict[str, tuple[str, str]] = {
+ "nvidia/parakeet-ctc-0.6b-zh-tw": ("8473f56d-51ef-473c-bb26-efd4f5def2bf", "zh-TW"),
+ "nvidia/parakeet-ctc-0.6b-zh-cn": ("9add5ef7-322e-47e0-ad7a-5653fb8d259b", "zh-CN"),
+ # function-id from NVIDIA NIM API docs (parakeet-ctc-0.6b-es).
+ "nvidia/parakeet-ctc-0.6b-es": ("a9eeee8f-b509-4712-b19d-194361fa5f31", "es-US"),
+ "nvidia/parakeet-ctc-0.6b-vi": ("f3dff2bb-99f9-403d-a5f1-f574a757deb0", "vi-VN"),
+ "nvidia/parakeet-ctc-1.1b-asr": ("1598d209-5e27-4d3c-8079-4751568b1081", "en-US"),
+ "nvidia/parakeet-ctc-0.6b-asr": ("d8dd4e9b-fbf5-4fb0-9dba-8cf436c8d965", "en-US"),
+ "nvidia/parakeet-1.1b-rnnt-multilingual-asr": (
+ "71203149-d3b7-4460-8231-1be2543a1fca",
+ "",
+ ),
+ "openai/whisper-large-v3": ("b702f636-f60c-4a3d-a6f4-f3568c13bd7d", "multi"),
+}
+
+_RIVA_SERVER = "grpc.nvcf.nvidia.com:443"
+
+
+def transcribe_audio_file(
+ file_path: Path,
+ model: str,
+ *,
+ api_key: str,
+) -> str:
+ """Transcribe audio using NVIDIA NIM / Riva gRPC (offline recognition).
+
+ Args:
+ file_path: Path to an audio file whose encoded bytes Riva can decode.
+ model: Hugging Face-style NIM model id (see ``_NIM_ASR_MODEL_MAP``).
+ api_key: NVIDIA API key (Bearer token); must be non-empty.
+
+ Returns:
+ Transcript text, or ``(no speech detected)`` when empty.
+ """
+ key = (api_key or "").strip()
+ if not key:
+ raise ValueError(
+ "NVIDIA NIM transcription requires a non-empty nvidia_nim_api_key "
+ "(configure NVIDIA_NIM_API_KEY or pass api_key explicitly)."
+ )
+
+ try:
+ import riva.client
+ except ImportError as e:
+ raise ImportError(
+ "NVIDIA NIM transcription requires the voice extra. "
+ "Install with: uv sync --extra voice"
+ ) from e
+
+ model_config = _NIM_ASR_MODEL_MAP.get(model)
+ if not model_config:
+ raise ValueError(
+ f"No NVIDIA NIM config found for model: {model}. "
+ f"Supported models: {', '.join(_NIM_ASR_MODEL_MAP.keys())}"
+ )
+ function_id, language_code = model_config
+
+ auth = riva.client.Auth(
+ use_ssl=True,
+ uri=_RIVA_SERVER,
+ metadata_args=[
+ ["function-id", function_id],
+ ["authorization", f"Bearer {key}"],
+ ],
+ )
+
+ asr_service = riva.client.ASRService(auth)
+
+ config = riva.client.RecognitionConfig(
+ language_code=language_code,
+ max_alternatives=1,
+ verbatim_transcripts=True,
+ )
+
+ with open(file_path, "rb") as f:
+ data = f.read()
+
+ response = asr_service.offline_recognize(data, config)
+
+ transcript = ""
+ results = getattr(response, "results", None)
+ if results and results[0].alternatives:
+ transcript = results[0].alternatives[0].transcript
+
+ logger.debug(f"NIM transcription: {len(transcript)} chars")
+ return transcript or "(no speech detected)"
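
A hedged usage sketch for the new `transcribe_audio_file`; the file path and key below are placeholders, and the call assumes this repo plus the voice extra (`uv sync --extra voice`) installed:

```python
# Hypothetical usage; "note.ogg" and the key are placeholders.
from pathlib import Path

from providers.nvidia_nim.voice import transcribe_audio_file

try:
    text = transcribe_audio_file(
        Path("note.ogg"),
        "openai/whisper-large-v3",
        api_key="nvapi-placeholder",  # normally sourced from NVIDIA_NIM_API_KEY
    )
except (ImportError, ValueError) as exc:  # missing extra, key, or model mapping
    print(f"transcription unavailable: {exc}")
else:
    print(text)
```
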
diff --git a/providers/ollama/__init__.py b/providers/ollama/__init__.py
index 365677c..8934ab9 100644
--- a/providers/ollama/__init__.py
+++ b/providers/ollama/__init__.py
@@ -1,5 +1,7 @@
"""Ollama provider package."""
-from .client import OLLAMA_BASE_URL, OllamaProvider
+from providers.defaults import OLLAMA_DEFAULT_BASE
-__all__ = ["OLLAMA_BASE_URL", "OllamaProvider"]
+from .client import OllamaProvider
+
+__all__ = ["OLLAMA_DEFAULT_BASE", "OllamaProvider"]
diff --git a/providers/ollama/client.py b/providers/ollama/client.py
index ebff031..6696bd8 100644
--- a/providers/ollama/client.py
+++ b/providers/ollama/client.py
@@ -6,8 +6,6 @@ from providers.anthropic_messages import AnthropicMessagesTransport
from providers.base import ProviderConfig
from providers.defaults import OLLAMA_DEFAULT_BASE
-OLLAMA_BASE_URL = OLLAMA_DEFAULT_BASE
-
class OllamaProvider(AnthropicMessagesTransport):
"""Ollama provider using native Anthropic Messages API."""
@@ -16,7 +14,7 @@ class OllamaProvider(AnthropicMessagesTransport):
super().__init__(
config,
provider_name="OLLAMA",
- default_base_url=OLLAMA_BASE_URL,
+ default_base_url=OLLAMA_DEFAULT_BASE,
)
self._api_key = config.api_key or "ollama"
diff --git a/providers/open_router/__init__.py b/providers/open_router/__init__.py
index 7a0cdad..df12496 100644
--- a/providers/open_router/__init__.py
+++ b/providers/open_router/__init__.py
@@ -1,5 +1,7 @@
"""OpenRouter provider - Anthropic-compatible native transport."""
-from .client import OPENROUTER_BASE_URL, OpenRouterProvider
+from providers.defaults import OPENROUTER_DEFAULT_BASE
-__all__ = ["OPENROUTER_BASE_URL", "OpenRouterProvider"]
+from .client import OpenRouterProvider
+
+__all__ = ["OPENROUTER_DEFAULT_BASE", "OpenRouterProvider"]
diff --git a/providers/open_router/client.py b/providers/open_router/client.py
index d5c2081..032df60 100644
--- a/providers/open_router/client.py
+++ b/providers/open_router/client.py
@@ -2,34 +2,25 @@
from __future__ import annotations
-import json
-import uuid
from collections.abc import Iterator
-from dataclasses import dataclass, field
from typing import Any
-from core.anthropic import SSEBuilder, append_request_id
+from core.anthropic import append_request_id, iter_provider_stream_error_sse_events
+from core.anthropic.native_sse_block_policy import (
+ NativeSseBlockPolicyState,
+ is_terminal_openrouter_done_event,
+ parse_native_sse_event,
+ transform_native_sse_block_event,
+)
from providers.anthropic_messages import AnthropicMessagesTransport, StreamChunkMode
from providers.base import ProviderConfig
-from providers.defaults import OPENROUTER_BASE_URL
+from providers.defaults import OPENROUTER_DEFAULT_BASE
from .request import build_request_body
_ANTHROPIC_VERSION = "2023-06-01"
-@dataclass
-class _SSEFilterState:
- """Track Anthropic content block index remapping while filtering thinking."""
-
- next_index: int = 0
- index_map: dict[int, int] = field(default_factory=dict)
- dropped_indexes: set[int] = field(default_factory=set)
- open_block_types: dict[int, str] = field(default_factory=dict)
- closed_indexes: set[int] = field(default_factory=set)
- message_stopped: bool = False
-
-
class OpenRouterProvider(AnthropicMessagesTransport):
"""OpenRouter provider using the native Anthropic-compatible messages API."""
@@ -39,7 +30,7 @@ class OpenRouterProvider(AnthropicMessagesTransport):
super().__init__(
config,
provider_name="OPENROUTER",
- default_base_url=OPENROUTER_BASE_URL,
+ default_base_url=OPENROUTER_DEFAULT_BASE,
)
def _build_request_body(
@@ -60,163 +51,9 @@ class OpenRouterProvider(AnthropicMessagesTransport):
"anthropic-version": _ANTHROPIC_VERSION,
}
- @staticmethod
- def _format_sse_event(event_name: str | None, data_text: str) -> str:
- """Format an SSE event from its event name and data payload."""
- lines: list[str] = []
- if event_name:
- lines.append(f"event: {event_name}")
- lines.extend(f"data: {line}" for line in data_text.splitlines())
- return "\n".join(lines) + "\n\n"
-
- @staticmethod
- def _parse_sse_event(event: str) -> tuple[str | None, str]:
- """Extract the event name and raw data payload from an SSE event."""
- event_name = None
- data_lines: list[str] = []
- for line in event.strip().splitlines():
- if line.startswith("event:"):
- event_name = line[6:].strip()
- elif line.startswith("data:"):
- data_lines.append(line[5:].lstrip())
- return event_name, "\n".join(data_lines)
-
- @staticmethod
- def _is_terminal_done_event(event_name: str | None, data_text: str) -> bool:
- """Return whether an event is OpenAI-style terminal noise."""
- return (event_name is None or event_name in {"data", "done"}) and (
- data_text.strip().upper() == "[DONE]"
- )
-
- @staticmethod
- def _remap_index(
- payload: dict[str, Any], state: _SSEFilterState, *, create: bool
- ) -> int | None:
- """Return the downstream index for a content block event."""
- upstream_index = payload.get("index")
- if not isinstance(upstream_index, int):
- return None
- if upstream_index in state.dropped_indexes:
- return None
- mapped_index = state.index_map.get(upstream_index)
- if mapped_index is None and create:
- mapped_index = state.next_index
- state.index_map[upstream_index] = mapped_index
- state.next_index += 1
- return mapped_index
-
- def _close_open_blocks_before(
- self, state: _SSEFilterState, upstream_index: int
- ) -> str:
- """Close overlapping upstream blocks before starting a new block."""
- events: list[str] = []
- for open_upstream_index in list(state.open_block_types):
- if open_upstream_index == upstream_index:
- continue
- mapped_index = state.index_map.get(open_upstream_index)
- if mapped_index is None:
- continue
- payload = {"type": "content_block_stop", "index": mapped_index}
- events.append(
- self._format_sse_event("content_block_stop", json.dumps(payload))
- )
- state.closed_indexes.add(open_upstream_index)
- state.open_block_types.pop(open_upstream_index, None)
- return "".join(events)
-
- @staticmethod
- def _should_drop_block_type(block_type: Any, *, thinking_enabled: bool) -> bool:
- if not isinstance(block_type, str):
- return False
- if block_type.startswith("redacted_thinking"):
- return True
- return not thinking_enabled and "thinking" in block_type
-
- def _transform_sse_payload(
- self,
- event: str,
- state: _SSEFilterState,
- *,
- thinking_enabled: bool,
- ) -> str | None:
- """Normalize OpenRouter SSE events and enforce local thinking policy."""
- event_name, data_text = self._parse_sse_event(event)
- if not event_name or not data_text:
- return event
-
- try:
- payload = json.loads(data_text)
- except json.JSONDecodeError:
- return event
-
- if event_name == "content_block_start":
- block = payload.get("content_block")
- if not isinstance(block, dict):
- return event
- block_type = block.get("type")
- upstream_index = payload.get("index")
- if self._should_drop_block_type(
- block_type, thinking_enabled=thinking_enabled
- ):
- if isinstance(upstream_index, int):
- state.dropped_indexes.add(upstream_index)
- return None
-
- mapped_index = self._remap_index(payload, state, create=True)
- if mapped_index is not None:
- payload["index"] = mapped_index
- if isinstance(upstream_index, int) and isinstance(block_type, str):
- prefix = self._close_open_blocks_before(state, upstream_index)
- state.open_block_types[upstream_index] = block_type
- return prefix + self._format_sse_event(
- event_name, json.dumps(payload)
- )
- return self._format_sse_event(event_name, json.dumps(payload))
- return None if not thinking_enabled else event
-
- if event_name == "content_block_delta":
- delta = payload.get("delta")
- if not isinstance(delta, dict):
- return event
- delta_type = delta.get("type")
- if self._should_drop_block_type(
- delta_type, thinking_enabled=thinking_enabled
- ):
- return None
-
- mapped_index = self._remap_index(payload, state, create=False)
- if mapped_index is not None:
- payload["index"] = mapped_index
- return self._format_sse_event(event_name, json.dumps(payload))
- if payload.get("index") in state.dropped_indexes:
- return None
- if not thinking_enabled:
- return None
-
- if event_name == "content_block_stop":
- upstream_index = payload.get("index")
- if (
- isinstance(upstream_index, int)
- and upstream_index in state.closed_indexes
- ):
- state.closed_indexes.discard(upstream_index)
- return None
- mapped_index = self._remap_index(payload, state, create=False)
- if mapped_index is not None:
- payload["index"] = mapped_index
- if isinstance(upstream_index, int):
- state.open_block_types.pop(upstream_index, None)
- return self._format_sse_event(event_name, json.dumps(payload))
- if payload.get("index") in state.dropped_indexes:
- return None
- if not thinking_enabled:
- return None
-
- return event
-
def _new_stream_state(self, request: Any, *, thinking_enabled: bool) -> Any:
"""Create per-stream state for thinking block filtering."""
- return _SSEFilterState()
+ return NativeSseBlockPolicyState()
def _transform_stream_event(
self,
@@ -226,23 +63,17 @@ class OpenRouterProvider(AnthropicMessagesTransport):
thinking_enabled: bool,
) -> str | None:
"""Drop provider-specific terminal noise and hidden thinking events."""
- if isinstance(state, _SSEFilterState):
- event_name, data_text = self._parse_sse_event(event)
- if state.message_stopped or self._is_terminal_done_event(
+ if isinstance(state, NativeSseBlockPolicyState):
+ event_name, data_text = parse_native_sse_event(event)
+ if state.message_stopped or is_terminal_openrouter_done_event(
event_name, data_text
):
return None
if event_name == "message_stop":
state.message_stopped = True
- if thinking_enabled:
- if isinstance(state, _SSEFilterState):
- return self._transform_sse_payload(
- event, state, thinking_enabled=thinking_enabled
- )
- return event
- if isinstance(state, _SSEFilterState):
- return self._transform_sse_payload(
+ if isinstance(state, NativeSseBlockPolicyState):
+ return transform_native_sse_block_event(
event, state, thinking_enabled=thinking_enabled
)
return event
@@ -260,9 +91,10 @@ class OpenRouterProvider(AnthropicMessagesTransport):
sent_any_event: bool,
) -> Iterator[str]:
"""Emit the Anthropic SSE error shape expected by Claude clients."""
- sse = SSEBuilder(f"msg_{uuid.uuid4()}", request.model, input_tokens)
- if not sent_any_event:
- yield sse.message_start()
- yield from sse.emit_error(error_message)
- yield sse.message_delta("end_turn", 1)
- yield sse.message_stop()
+ yield from iter_provider_stream_error_sse_events(
+ request=request,
+ input_tokens=input_tokens,
+ error_message=error_message,
+ sent_any_event=sent_any_event,
+ log_raw_sse_events=self._config.log_raw_sse_events,
+ )
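
The SSE helpers this hunk deletes (now owned by `core.anthropic.native_sse_block_policy`) define the wire shape the filter still operates on. A standalone round-trip sketch reconstructed from the removed `_format_sse_event`/`_parse_sse_event`:

```python
import json


def format_sse_event(event_name: str | None, data_text: str) -> str:
    """Render one SSE event: optional `event:` line plus `data:` lines."""
    lines = []
    if event_name:
        lines.append(f"event: {event_name}")
    lines.extend(f"data: {line}" for line in data_text.splitlines())
    return "\n".join(lines) + "\n\n"


def parse_sse_event(event: str) -> tuple[str | None, str]:
    """Recover the event name and joined data payload from a raw SSE event."""
    event_name, data_lines = None, []
    for line in event.strip().splitlines():
        if line.startswith("event:"):
            event_name = line[6:].strip()
        elif line.startswith("data:"):
            data_lines.append(line[5:].lstrip())
    return event_name, "\n".join(data_lines)


payload = json.dumps({"type": "content_block_delta", "index": 0})
raw = format_sse_event("content_block_delta", payload)
assert parse_sse_event(raw) == ("content_block_delta", payload)
```
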
diff --git a/providers/open_router/request.py b/providers/open_router/request.py
index 89759ab..abd93df 100644
--- a/providers/open_router/request.py
+++ b/providers/open_router/request.py
@@ -2,136 +2,18 @@
from __future__ import annotations
-from collections.abc import Sequence
from typing import Any
from loguru import logger
-from pydantic import BaseModel
-OPENROUTER_DEFAULT_MAX_TOKENS = 81920
-
-_REQUEST_FIELDS = (
- "model",
- "messages",
- "system",
- "max_tokens",
- "stop_sequences",
- "stream",
- "temperature",
- "top_p",
- "top_k",
- "metadata",
- "tools",
- "tool_choice",
- "thinking",
- "extra_body",
- "original_model",
- "resolved_provider_model",
+from config.constants import (
+ ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS as OPENROUTER_DEFAULT_MAX_TOKENS,
)
-
-_INTERNAL_FIELDS = {
- "thinking",
- "extra_body",
- "original_model",
- "resolved_provider_model",
-}
-_THINKING_HISTORY_BLOCK_TYPES = {"thinking", "redacted_thinking"}
-
-
-def _serialize_value(value: Any) -> Any:
- """Convert Pydantic models and lightweight objects into JSON-ready values."""
- if isinstance(value, BaseModel):
- return value.model_dump(exclude_none=True)
- if isinstance(value, dict):
- return {
- key: _serialize_value(item)
- for key, item in value.items()
- if item is not None
- }
- if isinstance(value, Sequence) and not isinstance(value, str | bytes | bytearray):
- return [_serialize_value(item) for item in value]
- if value is None or isinstance(value, str | int | float | bool):
- return value
- if hasattr(value, "__dict__"):
- return {
- key: _serialize_value(item)
- for key, item in vars(value).items()
- if not key.startswith("_") and item is not None
- }
- return value
-
-
-def _dump_request_fields(request_data: Any) -> dict[str, Any]:
- """Extract the public request fields forwarded to OpenRouter."""
- if isinstance(request_data, BaseModel):
- return request_data.model_dump(exclude_none=True)
-
- dumped: dict[str, Any] = {}
- for field in _REQUEST_FIELDS:
- value = getattr(request_data, field, None)
- if value is not None:
- dumped[field] = _serialize_value(value)
- return dumped
-
-
-def _strip_unsigned_thinking_history(messages: Any) -> Any:
- """Remove assistant thinking history blocks that OpenRouter cannot replay."""
- if not isinstance(messages, list):
- return messages
-
- sanitized_messages: list[Any] = []
- for message in messages:
- if not isinstance(message, dict):
- sanitized_messages.append(message)
- continue
-
- content = message.get("content")
- if not isinstance(content, list):
- sanitized_messages.append(message)
- continue
-
- sanitized_content = [
- block
- for block in content
- if not (
- isinstance(block, dict)
- and block.get("type") in _THINKING_HISTORY_BLOCK_TYPES
- and not isinstance(block.get("signature"), str)
- )
- ]
-
- sanitized_message = dict(message)
- sanitized_message["content"] = sanitized_content or ""
- sanitized_messages.append(sanitized_message)
-
- return sanitized_messages
-
-
-def _normalize_system_prompt(system: Any) -> Any:
- """Flatten Claude SDK system blocks for OpenRouter's native endpoint."""
- if not isinstance(system, list):
- return system
-
- text_parts: list[str] = []
- for block in system:
- if not isinstance(block, dict):
- continue
- if block.get("type") == "text" and isinstance(block.get("text"), str):
- text_parts.append(block["text"])
- return "\n\n".join(text_parts).strip() if text_parts else system
-
-
-def _apply_reasoning(body: dict[str, Any], thinking_cfg: Any) -> None:
- """Map Anthropic thinking controls onto OpenRouter reasoning controls."""
- reasoning = body.setdefault("reasoning", {"enabled": True})
- if not isinstance(reasoning, dict):
- return
- reasoning.setdefault("enabled", True)
- if not isinstance(thinking_cfg, dict):
- return
- budget_tokens = thinking_cfg.get("budget_tokens")
- if isinstance(budget_tokens, int):
- reasoning.setdefault("max_tokens", budget_tokens)
+from core.anthropic.native_messages_request import (
+ OpenRouterExtraBodyError,
+ build_openrouter_native_request_body,
+)
+from providers.exceptions import InvalidRequestError
def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict:
@@ -142,27 +24,14 @@ def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict:
len(getattr(request_data, "messages", [])),
)
- dumped_request = _dump_request_fields(request_data)
- request_extra = dumped_request.pop("extra_body", None)
- thinking_cfg = dumped_request.get("thinking")
- body = {
- key: value
- for key, value in dumped_request.items()
- if key not in _INTERNAL_FIELDS
- }
-
- if isinstance(request_extra, dict):
- body.update(request_extra)
-
- body["messages"] = _strip_unsigned_thinking_history(body.get("messages"))
- if "system" in body:
- body["system"] = _normalize_system_prompt(body["system"])
- body["stream"] = True
- if body.get("max_tokens") is None:
- body["max_tokens"] = OPENROUTER_DEFAULT_MAX_TOKENS
-
- if thinking_enabled:
- _apply_reasoning(body, thinking_cfg)
+ try:
+ body = build_openrouter_native_request_body(
+ request_data,
+ thinking_enabled=thinking_enabled,
+ default_max_tokens=OPENROUTER_DEFAULT_MAX_TOKENS,
+ )
+ except OpenRouterExtraBodyError as exc:
+ raise InvalidRequestError(str(exc)) from exc
logger.debug(
"OPENROUTER_REQUEST: conversion done model={} msgs={} tools={}",
diff --git a/providers/openai_compat.py b/providers/openai_compat.py
index d729896..31c1d7e 100644
--- a/providers/openai_compat.py
+++ b/providers/openai_compat.py
@@ -21,14 +21,40 @@ from core.anthropic import (
SSEBuilder,
ThinkTagParser,
append_request_id,
- get_user_facing_error_message,
map_stop_reason,
)
from providers.base import BaseProvider, ProviderConfig
-from providers.error_mapping import map_error
+from providers.error_mapping import (
+ map_error,
+ user_visible_message_for_mapped_provider_error,
+)
from providers.rate_limit import GlobalRateLimiter
+def _iter_heuristic_tool_use_sse(
+ sse: SSEBuilder, tool_use: dict[str, Any]
+) -> Iterator[str]:
+ """Emit SSE for one heuristic tool_use block (closes open text/thinking first)."""
+ if tool_use.get("name") == "Task" and isinstance(tool_use.get("input"), dict):
+ task_input = tool_use["input"]
+ if task_input.get("run_in_background") is not False:
+ task_input["run_in_background"] = False
+ yield from sse.close_content_blocks()
+ block_idx = sse.blocks.allocate_index()
+ yield sse.content_block_start(
+ block_idx,
+ "tool_use",
+ id=tool_use["id"],
+ name=tool_use["name"],
+ )
+ yield sse.content_block_delta(
+ block_idx,
+ "input_json_delta",
+ json.dumps(tool_use["input"]),
+ )
+ yield sse.content_block_stop(block_idx)
+
+
class OpenAIChatTransport(BaseProvider):
"""Base for OpenAI-compatible ``/chat/completions`` adapters (NIM, DeepSeek, …)."""
@@ -113,6 +139,22 @@ class OpenAIChatTransport(BaseProvider):
)
return stream, retry_body
+ def _emit_tool_arg_delta(
+ self, sse: SSEBuilder, tc_index: int, args: str
+ ) -> Iterator[str]:
+ """Emit one argument fragment for a started tool block (Task buffer or raw JSON)."""
+ if not args:
+ return
+ state = sse.blocks.tool_states.get(tc_index)
+ if state is None:
+ return
+ if state.name == "Task":
+ parsed = sse.blocks.buffer_task_args(tc_index, args)
+ if parsed is not None:
+ yield sse.emit_tool_delta(tc_index, json.dumps(parsed))
+ return
+ yield sse.emit_tool_delta(tc_index, args)
+
def _process_tool_call(self, tc: dict, sse: SSEBuilder) -> Iterator[str]:
"""Process a single tool call delta and yield SSE events."""
tc_index = tc.get("index", 0)
@@ -121,34 +163,42 @@ class OpenAIChatTransport(BaseProvider):
fn_delta = tc.get("function", {})
incoming_name = fn_delta.get("name")
- arguments = fn_delta.get("arguments", "")
+ arguments = fn_delta.get("arguments", "") or ""
+
+ if tc.get("id") is not None:
+ sse.blocks.set_stream_tool_id(tc_index, tc.get("id"))
+
if incoming_name is not None:
sse.blocks.register_tool_name(tc_index, incoming_name)
state = sse.blocks.tool_states.get(tc_index)
+ resolved_id = (state.tool_id if state and state.tool_id else None) or tc.get(
+ "id"
+ )
+ resolved_name = (state.name if state else "") or ""
+
+ if not state or not state.started:
+ name_ok = bool((resolved_name or "").strip())
+ if name_ok:
+ tool_id = str(resolved_id) if resolved_id else f"tool_{uuid.uuid4()}"
+ display_name = (resolved_name or "").strip() or "tool_call"
+ yield sse.start_tool_block(tc_index, tool_id, display_name)
+ state = sse.blocks.tool_states[tc_index]
+ if state.pre_start_args:
+ pre = state.pre_start_args
+ state.pre_start_args = ""
+ yield from self._emit_tool_arg_delta(sse, tc_index, pre)
+
+ state = sse.blocks.tool_states.get(tc_index)
+ if not arguments:
+ return
if state is None or not state.started:
- name = state.name if state else ""
- if name or tc.get("id"):
- tool_id = tc.get("id") or f"tool_{uuid.uuid4()}"
- yield sse.start_tool_block(tc_index, tool_id, name)
-
- args = arguments
- if args:
- state = sse.blocks.tool_states.get(tc_index)
- if state is None or not state.started:
- tool_id = tc.get("id") or f"tool_{uuid.uuid4()}"
- name = (state.name if state else None) or "tool_call"
- yield sse.start_tool_block(tc_index, tool_id, name)
- state = sse.blocks.tool_states.get(tc_index)
-
- current_name = state.name if state else ""
- if current_name == "Task":
- parsed = sse.blocks.buffer_task_args(tc_index, args)
- if parsed is not None:
- yield sse.emit_tool_delta(tc_index, json.dumps(parsed))
+ state = sse.blocks.ensure_tool_state(tc_index)
+ if not (resolved_name or "").strip():
+ state.pre_start_args += arguments
return
- yield sse.emit_tool_delta(tc_index, args)
+ yield from self._emit_tool_arg_delta(sse, tc_index, arguments)
def _flush_task_arg_buffers(self, sse: SSEBuilder) -> Iterator[str]:
"""Emit buffered Task args as a single JSON delta (best-effort)."""
@@ -181,7 +231,12 @@ class OpenAIChatTransport(BaseProvider):
"""Shared streaming implementation."""
tag = self._provider_name
message_id = f"msg_{uuid.uuid4()}"
- sse = SSEBuilder(message_id, request.model, input_tokens)
+ sse = SSEBuilder(
+ message_id,
+ request.model,
+ input_tokens,
+ log_raw_events=self._config.log_raw_sse_events,
+ )
body = self._build_request_body(request, thinking_enabled=thinking_enabled)
thinking_enabled = self._is_thinking_enabled(request, thinking_enabled)
@@ -201,8 +256,6 @@ class OpenAIChatTransport(BaseProvider):
heuristic_parser = HeuristicToolParser()
finish_reason = None
usage_info = None
- error_occurred = False
- error_message = ""
async with self._global_rate_limiter.concurrency_slot():
try:
@@ -258,26 +311,10 @@ class OpenAIChatTransport(BaseProvider):
yield sse.emit_text_delta(filtered_text)
for tool_use in detected_tools:
- for event in sse.close_content_blocks():
- yield event
-
- block_idx = sse.blocks.allocate_index()
- if tool_use.get("name") == "Task" and isinstance(
- tool_use.get("input"), dict
+ for event in _iter_heuristic_tool_use_sse(
+ sse, tool_use
):
- tool_use["input"]["run_in_background"] = False
- yield sse.content_block_start(
- block_idx,
- "tool_use",
- id=tool_use["id"],
- name=tool_use["name"],
- )
- yield sse.content_block_delta(
- block_idx,
- "input_json_delta",
- json.dumps(tool_use["input"]),
- )
- yield sse.content_block_stop(block_idx)
+ yield event
# Handle native tool calls
if delta.tool_calls:
@@ -298,18 +335,13 @@ class OpenAIChatTransport(BaseProvider):
except (asyncio.CancelledError, GeneratorExit):
raise
except Exception as e:
- logger.error("{}_ERROR:{} {}: {}", tag, req_tag, type(e).__name__, e)
+ self._log_stream_transport_error(tag, req_tag, e)
mapped_e = map_error(e, rate_limiter=self._global_rate_limiter)
- error_occurred = True
- if getattr(mapped_e, "status_code", None) == 405:
- base_message = (
- f"Upstream provider {tag} rejected the request method "
- "or endpoint (HTTP 405)."
- )
- else:
- base_message = get_user_facing_error_message(
- mapped_e, read_timeout_s=self._config.http_read_timeout
- )
+ base_message = user_visible_message_for_mapped_provider_error(
+ mapped_e,
+ provider_name=tag,
+ read_timeout_s=self._config.http_read_timeout,
+ )
error_message = append_request_id(base_message, request_id)
logger.info(
"{}_STREAM: Emitting SSE error event for {}{}",
@@ -317,10 +349,13 @@ class OpenAIChatTransport(BaseProvider):
type(e).__name__,
req_tag,
)
- for event in sse.close_content_blocks():
+ for event in sse.close_all_blocks():
yield event
for event in sse.emit_error(error_message):
yield event
+ yield sse.message_delta("end_turn", 1)
+ yield sse.message_stop()
+ return
# Flush remaining content
remaining = think_parser.flush()
@@ -338,32 +373,16 @@ class OpenAIChatTransport(BaseProvider):
yield sse.emit_text_delta(remaining.content)
for tool_use in heuristic_parser.flush():
- for event in sse.close_content_blocks():
+ for event in _iter_heuristic_tool_use_sse(sse, tool_use):
yield event
- block_idx = sse.blocks.allocate_index()
- yield sse.content_block_start(
- block_idx,
- "tool_use",
- id=tool_use["id"],
- name=tool_use["name"],
- )
- if tool_use.get("name") == "Task" and isinstance(
- tool_use.get("input"), dict
- ):
- tool_use["input"]["run_in_background"] = False
- yield sse.content_block_delta(
- block_idx,
- "input_json_delta",
- json.dumps(tool_use["input"]),
- )
- yield sse.content_block_stop(block_idx)
-
- if (
- not error_occurred
- and sse.blocks.text_index == -1
- and not sse.blocks.tool_states
- ):
+ has_started_tool = any(s.started for s in sse.blocks.tool_states.values())
+ has_content_blocks = (
+ sse.blocks.text_index != -1
+ or sse.blocks.thinking_index != -1
+ or has_started_tool
+ )
+ if not has_content_blocks:
for event in sse.ensure_text_block():
yield event
yield sse.emit_text_delta(" ")
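
A minimal sketch of the buffering policy the `_process_tool_call` rewrite introduces: argument deltas that arrive before any tool name are parked in `pre_start_args` and flushed once the block can start. `ToolState` and `feed` are illustrative names here, not the repo's API:

```python
from dataclasses import dataclass


@dataclass
class ToolState:
    name: str = ""
    started: bool = False
    pre_start_args: str = ""


def feed(state: ToolState, name: str | None, args: str) -> list[str]:
    """Return the argument fragments that are safe to emit downstream."""
    if name:
        state.name = name
    if not state.started and state.name.strip():
        state.started = True  # block can start: flush the backlog first
        buffered, state.pre_start_args = state.pre_start_args, ""
        out = [buffered] if buffered else []
        if args:
            out.append(args)
        return out
    if not state.started:
        state.pre_start_args += args  # no name yet: hold fragments back
        return []
    return [args] if args else []


s = ToolState()
assert feed(s, None, '{"pa') == []                 # nameless delta is buffered
assert feed(s, "Read", 'th"') == ['{"pa', 'th"']   # name arrives: flush + stream
```
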
diff --git a/providers/rate_limit.py b/providers/rate_limit.py
index 0e427e9..59e4291 100644
--- a/providers/rate_limit.py
+++ b/providers/rate_limit.py
@@ -3,14 +3,16 @@
import asyncio
import random
import time
-from collections import deque
from collections.abc import AsyncIterator, Callable
from contextlib import asynccontextmanager
from typing import Any, ClassVar, TypeVar
+import httpx
import openai
from loguru import logger
+from core.rate_limit import StrictSlidingWindowLimiter
+
T = TypeVar("T")
@@ -51,10 +53,10 @@ class GlobalRateLimiter:
self._rate_limit = rate_limit
self._rate_window = float(rate_window)
self._max_concurrency = max_concurrency
- # Monotonic timestamps of the last granted slots.
- self._request_times: deque[float] = deque()
+ self._proactive_limiter = StrictSlidingWindowLimiter(
+ self._rate_limit, self._rate_window
+ )
self._blocked_until: float = 0
- self._lock = asyncio.Lock()
self._concurrency_sem = asyncio.Semaphore(max_concurrency)
self._initialized = True
@@ -107,12 +109,11 @@ class GlobalRateLimiter:
logger.info(
"Rebuilding provider rate limiter for updated scope '{}'", scope
)
- if scope not in cls._scoped_instances or existing:
- cls._scoped_instances[scope] = cls(
- rate_limit=desired_rate_limit,
- rate_window=desired_rate_window,
- max_concurrency=max_concurrency,
- )
+ cls._scoped_instances[scope] = cls(
+ rate_limit=desired_rate_limit,
+ rate_window=desired_rate_window,
+ max_concurrency=max_concurrency,
+ )
return cls._scoped_instances[scope]
@classmethod
@@ -150,27 +151,7 @@ class GlobalRateLimiter:
Guarantees: at most `self._rate_limit` acquisitions in any interval of length
`self._rate_window` (seconds).
"""
- while True:
- wait_time = 0.0
- async with self._lock:
- now = time.monotonic()
- cutoff = now - self._rate_window
-
- while self._request_times and self._request_times[0] <= cutoff:
- self._request_times.popleft()
-
- if len(self._request_times) < self._rate_limit:
- self._request_times.append(now)
- return
-
- oldest = self._request_times[0]
- wait_time = max(0.0, (oldest + self._rate_window) - now)
-
- # Sleep outside the lock so other tasks can continue to queue.
- if wait_time > 0:
- await asyncio.sleep(wait_time)
- else:
- await asyncio.sleep(0)
+ await self._proactive_limiter.acquire()
def set_blocked(self, seconds: float = 60) -> None:
"""
@@ -263,6 +244,24 @@ class GlobalRateLimiter:
)
self.set_blocked(delay)
await asyncio.sleep(delay)
+ except httpx.HTTPStatusError as e:
+ if e.response.status_code != 429:
+ raise
+ last_exc = e
+ if attempt >= max_retries:
+ logger.warning(
+ f"HTTP 429 retry exhausted after {max_retries} retries"
+ )
+ break
+
+ delay = min(base_delay * (2**attempt), max_delay)
+ delay += random.uniform(0, jitter)
+ logger.warning(
+ f"HTTP 429 from upstream, attempt {attempt + 1}/{max_retries + 1}. "
+ f"Retrying in {delay:.1f}s..."
+ )
+ self.set_blocked(delay)
+ await asyncio.sleep(delay)
assert last_exc is not None
raise last_exc
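
The deleted body of the proactive limiter spells out the invariant that `StrictSlidingWindowLimiter` now encapsulates. A self-contained asyncio sketch reconstructed from those removed lines (the actual `core.rate_limit` class may differ in detail):

```python
import asyncio
import time
from collections import deque


class SlidingWindowLimiter:
    """Grant at most `limit` acquisitions in any `window`-second interval."""

    def __init__(self, limit: int, window: float) -> None:
        self._limit, self._window = limit, window
        self._grants: deque[float] = deque()  # monotonic times of granted slots
        self._lock = asyncio.Lock()

    async def acquire(self) -> None:
        while True:
            async with self._lock:
                now = time.monotonic()
                while self._grants and self._grants[0] <= now - self._window:
                    self._grants.popleft()  # expire grants outside the window
                if len(self._grants) < self._limit:
                    self._grants.append(now)
                    return
                wait = max(0.0, self._grants[0] + self._window - now)
            await asyncio.sleep(wait)  # sleep outside the lock, then re-check
```
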
diff --git a/providers/registry.py b/providers/registry.py
index 75a6d8c..20ef4a0 100644
--- a/providers/registry.py
+++ b/providers/registry.py
@@ -3,108 +3,20 @@
from __future__ import annotations
from collections.abc import Callable, MutableMapping
-from dataclasses import dataclass
-from typing import Literal
-from config.provider_ids import SUPPORTED_PROVIDER_IDS
+from config.provider_catalog import (
+ PROVIDER_CATALOG,
+ SUPPORTED_PROVIDER_IDS,
+ ProviderDescriptor,
+)
from config.settings import Settings
from providers.base import BaseProvider, ProviderConfig
-from providers.defaults import (
- DEEPSEEK_DEFAULT_BASE,
- LLAMACPP_DEFAULT_BASE,
- LMSTUDIO_DEFAULT_BASE,
- NVIDIA_NIM_DEFAULT_BASE,
- OLLAMA_DEFAULT_BASE,
- OPENROUTER_DEFAULT_BASE,
-)
from providers.exceptions import AuthenticationError, UnknownProviderTypeError
-TransportType = Literal["openai_chat", "anthropic_messages"]
ProviderFactory = Callable[[ProviderConfig, Settings], BaseProvider]
-
-@dataclass(frozen=True, slots=True)
-class ProviderDescriptor:
- """Metadata for building :class:`ProviderConfig` and factory wiring."""
-
- provider_id: str
- transport_type: TransportType
- capabilities: tuple[str, ...]
- credential_env: str | None = None
- credential_url: str | None = None
- # If set, read API key from this attribute on ``Settings`` (e.g. nvidia_nim_api_key).
- credential_attr: str | None = None
- # If set, use this fixed key for local adapters (e.g. lm-studio, llamacpp).
- static_credential: str | None = None
- default_base_url: str | None = None
- base_url_attr: str | None = None
- proxy_attr: str | None = None
-
-
-PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = {
- "nvidia_nim": ProviderDescriptor(
- provider_id="nvidia_nim",
- transport_type="openai_chat",
- credential_env="NVIDIA_NIM_API_KEY",
- credential_url="https://build.nvidia.com/settings/api-keys",
- credential_attr="nvidia_nim_api_key",
- default_base_url=NVIDIA_NIM_DEFAULT_BASE,
- proxy_attr="nvidia_nim_proxy",
- capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"),
- ),
- "open_router": ProviderDescriptor(
- provider_id="open_router",
- transport_type="anthropic_messages",
- credential_env="OPENROUTER_API_KEY",
- credential_url="https://openrouter.ai/keys",
- credential_attr="open_router_api_key",
- default_base_url=OPENROUTER_DEFAULT_BASE,
- proxy_attr="open_router_proxy",
- capabilities=("chat", "streaming", "tools", "thinking", "native_anthropic"),
- ),
- "deepseek": ProviderDescriptor(
- provider_id="deepseek",
- transport_type="openai_chat",
- credential_env="DEEPSEEK_API_KEY",
- credential_url="https://platform.deepseek.com/api_keys",
- credential_attr="deepseek_api_key",
- default_base_url=DEEPSEEK_DEFAULT_BASE,
- capabilities=("chat", "streaming", "thinking"),
- ),
- "lmstudio": ProviderDescriptor(
- provider_id="lmstudio",
- transport_type="anthropic_messages",
- static_credential="lm-studio",
- default_base_url=LMSTUDIO_DEFAULT_BASE,
- base_url_attr="lm_studio_base_url",
- proxy_attr="lmstudio_proxy",
- capabilities=("chat", "streaming", "tools", "native_anthropic", "local"),
- ),
- "llamacpp": ProviderDescriptor(
- provider_id="llamacpp",
- transport_type="anthropic_messages",
- static_credential="llamacpp",
- default_base_url=LLAMACPP_DEFAULT_BASE,
- base_url_attr="llamacpp_base_url",
- proxy_attr="llamacpp_proxy",
- capabilities=("chat", "streaming", "tools", "native_anthropic", "local"),
- ),
- "ollama": ProviderDescriptor(
- provider_id="ollama",
- transport_type="anthropic_messages",
- static_credential="ollama",
- default_base_url=OLLAMA_DEFAULT_BASE,
- base_url_attr="ollama_base_url",
- capabilities=(
- "chat",
- "streaming",
- "tools",
- "thinking",
- "native_anthropic",
- "local",
- ),
- ),
-}
+# Backwards-compatible name for the catalog (single source: ``config.provider_catalog``).
+PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = PROVIDER_CATALOG
def _create_nvidia_nim(config: ProviderConfig, settings: Settings) -> BaseProvider:
@@ -119,25 +31,25 @@ def _create_open_router(config: ProviderConfig, _settings: Settings) -> BaseProv
return OpenRouterProvider(config)
-def _create_deepseek(config: ProviderConfig, settings: Settings) -> BaseProvider:
+def _create_deepseek(config: ProviderConfig, _settings: Settings) -> BaseProvider:
from providers.deepseek import DeepSeekProvider
return DeepSeekProvider(config)
-def _create_lmstudio(config: ProviderConfig, settings: Settings) -> BaseProvider:
+def _create_lmstudio(config: ProviderConfig, _settings: Settings) -> BaseProvider:
from providers.lmstudio import LMStudioProvider
return LMStudioProvider(config)
-def _create_llamacpp(config: ProviderConfig, settings: Settings) -> BaseProvider:
+def _create_llamacpp(config: ProviderConfig, _settings: Settings) -> BaseProvider:
from providers.llamacpp import LlamaCppProvider
return LlamaCppProvider(config)
-def _create_ollama(config: ProviderConfig, settings: Settings) -> BaseProvider:
+def _create_ollama(config: ProviderConfig, _settings: Settings) -> BaseProvider:
from providers.ollama import OllamaProvider
return OllamaProvider(config)
@@ -208,6 +120,8 @@ def build_provider_config(
http_connect_timeout=settings.http_connect_timeout,
enable_thinking=settings.enable_model_thinking,
proxy=proxy,
+ log_raw_sse_events=settings.log_raw_sse_events,
+ log_api_error_tracebacks=settings.log_api_error_tracebacks,
)
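
For context on how a catalog row is consumed, a purely illustrative sketch of descriptor-driven credential resolution; `Descriptor`, `FakeSettings`, and `resolve_credential` are made-up names and the base URL is a placeholder:

```python
from dataclasses import dataclass


@dataclass(frozen=True, slots=True)
class Descriptor:
    provider_id: str
    credential_attr: str | None = None
    static_credential: str | None = None
    default_base_url: str | None = None


def resolve_credential(desc: Descriptor, settings: object) -> str | None:
    """Static key for local adapters, else read the named settings attribute."""
    if desc.static_credential:
        return desc.static_credential
    if desc.credential_attr:
        return getattr(settings, desc.credential_attr, None)
    return None


class FakeSettings:
    nvidia_nim_api_key = "nvapi-placeholder"


nim = Descriptor(
    provider_id="nvidia_nim",
    credential_attr="nvidia_nim_api_key",
    default_base_url="https://nim.example.invalid/v1",  # placeholder URL
)
ollama = Descriptor(provider_id="ollama", static_credential="ollama")

assert resolve_credential(nim, FakeSettings()) == "nvapi-placeholder"
assert resolve_credential(ollama, FakeSettings()) == "ollama"
```
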
diff --git a/server.py b/server.py
index b025be0..958e730 100644
--- a/server.py
+++ b/server.py
@@ -1,11 +1,13 @@
"""
Claude Code Proxy - Entry Point
-Minimal entry point that imports the app from the api module.
+Minimal entry point that builds the ASGI app via :func:`api.app.create_app`.
Run with: uv run uvicorn server:app --host 0.0.0.0 --port 8082 --timeout-graceful-shutdown 5
"""
-from api.app import app, create_app
+from api.app import create_app
+
+app = create_app()
__all__ = ["app", "create_app"]
diff --git a/smoke/capabilities.py b/smoke/capabilities.py
index 5f29fff..eafae44 100644
--- a/smoke/capabilities.py
+++ b/smoke/capabilities.py
@@ -98,7 +98,7 @@ CAPABILITY_CONTRACTS: tuple[CapabilityContract, ...] = (
"providers.registry.ProviderRegistry",
"provider id and Settings",
"configured BaseProvider instance",
- "503 for missing credentials; ValueError for unknown provider",
+ "503 for missing credentials; invalid_request_error for unknown provider",
("tests/api/test_dependencies.py", "tests/providers/test_registry.py"),
("test_configured_provider_models_stream_successfully",),
),
diff --git a/smoke/lib/e2e.py b/smoke/lib/e2e.py
index e669b51..8b20e80 100644
--- a/smoke/lib/e2e.py
+++ b/smoke/lib/e2e.py
@@ -19,6 +19,14 @@ import httpx
import pytest
from config.provider_ids import SUPPORTED_PROVIDER_IDS
+from core.anthropic.stream_contracts import (
+ SSEEvent,
+ assert_anthropic_stream_contract,
+ event_index,
+ has_tool_use,
+ parse_sse_lines,
+ text_content,
+)
from messaging.handler import ClaudeMessageHandler
from messaging.models import IncomingMessage
from messaging.platforms.base import MessagingPlatform
@@ -26,13 +34,6 @@ from messaging.session import SessionStore
from smoke.lib.config import ProviderModel, SmokeConfig, auth_headers
from smoke.lib.server import RunningServer, start_server
from smoke.lib.skips import fail_missing_env
-from smoke.lib.sse import (
- SSEEvent,
- assert_anthropic_stream_contract,
- has_tool_use,
- parse_sse_lines,
- text_content,
-)
@dataclass(slots=True)
@@ -590,14 +591,14 @@ def assistant_content_from_events(events: list[SSEEvent]) -> list[dict[str, Any]
block_order: list[int] = []
for event in events:
if event.event == "content_block_start":
- index = _event_index(event)
+ index = event_index(event)
block = event.data.get("content_block", {})
if isinstance(block, dict):
blocks[index] = dict(block)
block_order.append(index)
continue
if event.event == "content_block_delta":
- index = _event_index(event)
+ index = event_index(event)
block = blocks.get(index)
delta = event.data.get("delta", {})
if not isinstance(block, dict) or not isinstance(delta, dict):
@@ -663,12 +664,6 @@ def default_cli_events(session_id: str) -> list[dict[str, Any]]:
]
-def _event_index(event: SSEEvent) -> int:
- value = event.data.get("index")
- assert isinstance(value, int), event.data
- return value
-
-
def assert_product_stream(events: list[SSEEvent]) -> None:
assert_anthropic_stream_contract(events)
assert text_content(events).strip() or has_tool_use(events), (
diff --git a/smoke/lib/http.py b/smoke/lib/http.py
index f019473..71e2cb0 100644
--- a/smoke/lib/http.py
+++ b/smoke/lib/http.py
@@ -6,9 +6,10 @@ from typing import Any
import httpx
+from core.anthropic.stream_contracts import SSEEvent, parse_sse_lines
+
from .config import SmokeConfig, auth_headers, redacted
from .server import RunningServer
-from .sse import SSEEvent, parse_sse_lines
def message_payload(
diff --git a/smoke/lib/skips.py b/smoke/lib/skips.py
index ffadb85..1cb1e29 100644
--- a/smoke/lib/skips.py
+++ b/smoke/lib/skips.py
@@ -5,7 +5,7 @@ from __future__ import annotations
import httpx
import pytest
-from smoke.lib.sse import SSEEvent
+from core.anthropic.stream_contracts import SSEEvent
UPSTREAM_UNAVAILABLE_MARKERS = (
"connection refused",
diff --git a/smoke/lib/sse.py b/smoke/lib/sse.py
deleted file mode 100644
index a294ce5..0000000
--- a/smoke/lib/sse.py
+++ /dev/null
@@ -1,29 +0,0 @@
-"""SSE parsing and Anthropic stream assertions for smoke tests.
-
-Canonical implementation lives in :mod:`core.anthropic.stream_contracts`; this
-module re-exports it for smoke and historical import paths.
-"""
-
-from __future__ import annotations
-
-from core.anthropic.stream_contracts import (
- SSEEvent,
- assert_anthropic_stream_contract,
- event_names,
- has_tool_use,
- parse_sse_lines,
- parse_sse_text,
- text_content,
- thinking_content,
-)
-
-__all__ = [
- "SSEEvent",
- "assert_anthropic_stream_contract",
- "event_names",
- "has_tool_use",
- "parse_sse_lines",
- "parse_sse_text",
- "text_content",
- "thinking_content",
-]
diff --git a/smoke/prereq/test_provider_prereq_live.py b/smoke/prereq/test_provider_prereq_live.py
index 8731278..0d85356 100644
--- a/smoke/prereq/test_provider_prereq_live.py
+++ b/smoke/prereq/test_provider_prereq_live.py
@@ -5,6 +5,10 @@ import time
import httpx
import pytest
+from core.anthropic.stream_contracts import (
+ assert_anthropic_stream_contract,
+ text_content,
+)
from smoke.lib.config import SmokeConfig, auth_headers
from smoke.lib.http import collect_message_stream, message_payload
from smoke.lib.server import start_server
@@ -12,7 +16,6 @@ from smoke.lib.skips import (
skip_if_upstream_unavailable_events,
skip_if_upstream_unavailable_exception,
)
-from smoke.lib.sse import assert_anthropic_stream_contract, text_content
pytestmark = [pytest.mark.live, pytest.mark.smoke_target("providers")]
diff --git a/smoke/prereq/test_tools_prereq_live.py b/smoke/prereq/test_tools_prereq_live.py
index 824197b..716678c 100644
--- a/smoke/prereq/test_tools_prereq_live.py
+++ b/smoke/prereq/test_tools_prereq_live.py
@@ -2,11 +2,14 @@ from __future__ import annotations
import pytest
+from core.anthropic.stream_contracts import (
+ assert_anthropic_stream_contract,
+ has_tool_use,
+)
from smoke.lib.config import SmokeConfig
from smoke.lib.http import collect_message_stream, message_payload
from smoke.lib.server import start_server
from smoke.lib.skips import skip_if_upstream_unavailable_events
-from smoke.lib.sse import assert_anthropic_stream_contract, has_tool_use
pytestmark = [pytest.mark.live, pytest.mark.smoke_target("tools")]
diff --git a/smoke/prereq/test_voice_prereq_live.py b/smoke/prereq/test_voice_prereq_live.py
index 21e4dcb..50ce9e5 100644
--- a/smoke/prereq/test_voice_prereq_live.py
+++ b/smoke/prereq/test_voice_prereq_live.py
@@ -24,12 +24,13 @@ def test_voice_transcription_backend_when_explicitly_enabled(
wav_path = tmp_path / "smoke-tone.wav"
_write_tone_wav(wav_path)
try:
- text = transcribe_audio(
- wav_path,
- "audio/wav",
- whisper_model=smoke_config.settings.whisper_model,
- whisper_device=smoke_config.settings.whisper_device,
- )
+ t_kw: dict[str, str] = {
+ "whisper_model": smoke_config.settings.whisper_model,
+ "whisper_device": smoke_config.settings.whisper_device,
+ }
+ if smoke_config.settings.whisper_device == "nvidia_nim":
+ t_kw["nvidia_nim_api_key"] = smoke_config.settings.nvidia_nim_api_key
+ text = transcribe_audio(wav_path, "audio/wav", **t_kw)
except ImportError as exc:
pytest.skip(str(exc))
assert isinstance(text, str)
diff --git a/smoke/product/test_provider_product_live.py b/smoke/product/test_provider_product_live.py
index ac61839..be5d45e 100644
--- a/smoke/product/test_provider_product_live.py
+++ b/smoke/product/test_provider_product_live.py
@@ -3,6 +3,11 @@ from __future__ import annotations
import httpx
import pytest
+from core.anthropic.stream_contracts import (
+ assert_anthropic_stream_contract,
+ parse_sse_lines,
+ text_content,
+)
from smoke.lib.config import ProviderModel, SmokeConfig, auth_headers
from smoke.lib.e2e import (
ConversationDriver,
@@ -16,11 +21,6 @@ from smoke.lib.skips import (
skip_if_upstream_unavailable_events,
skip_if_upstream_unavailable_exception,
)
-from smoke.lib.sse import (
- assert_anthropic_stream_contract,
- parse_sse_lines,
- text_content,
-)
pytestmark = [pytest.mark.live, pytest.mark.smoke_target("providers")]
diff --git a/smoke/product/test_voice_product_live.py b/smoke/product/test_voice_product_live.py
index d1c27f5..0327168 100644
--- a/smoke/product/test_voice_product_live.py
+++ b/smoke/product/test_voice_product_live.py
@@ -55,6 +55,7 @@ def test_voice_nim_backend_e2e(smoke_config: SmokeConfig, tmp_path: Path) -> Non
"audio/wav",
whisper_model=smoke_config.settings.whisper_model,
whisper_device="nvidia_nim",
+ nvidia_nim_api_key=smoke_config.settings.nvidia_nim_api_key,
)
assert isinstance(text, str)
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..10436b4
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1 @@
+"""Test suite package."""
diff --git a/tests/api/test_anthropic_request_passthrough.py b/tests/api/test_anthropic_request_passthrough.py
new file mode 100644
index 0000000..15ee3db
--- /dev/null
+++ b/tests/api/test_anthropic_request_passthrough.py
@@ -0,0 +1,184 @@
+"""Pydantic passthrough of Anthropic protocol fields (e.g. ``cache_control``)."""
+
+from __future__ import annotations
+
+from api.models.anthropic import (
+ ContentBlockServerToolUse,
+ ContentBlockText,
+ ContentBlockWebSearchToolResult,
+ Message,
+ MessagesRequest,
+ SystemContent,
+ Tool,
+)
+from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
+from core.anthropic.native_messages_request import (
+ build_base_native_anthropic_request_body,
+)
+
+
+def test_cache_control_preserved_on_parsed_user_text_system_and_tool() -> None:
+ raw = {
+ "model": "m",
+ "max_tokens": 20,
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "hi",
+ "cache_control": {"type": "ephemeral"},
+ }
+ ],
+ }
+ ],
+ "system": [
+ {
+ "type": "text",
+ "text": "be brief",
+ "cache_control": {"type": "ephemeral"},
+ }
+ ],
+ "tools": [
+ {
+ "name": "alpha",
+ "input_schema": {"type": "object"},
+ "cache_control": {"type": "ephemeral"},
+ }
+ ],
+ }
+ req = MessagesRequest.model_validate(raw)
+ block = req.messages[0].content[0]
+ assert isinstance(block, ContentBlockText)
+ assert block.model_dump()["cache_control"] == {"type": "ephemeral"}
+ s0 = req.system[0] if isinstance(req.system, list) else None
+ assert isinstance(s0, SystemContent)
+ assert s0.model_dump()["cache_control"] == {"type": "ephemeral"}
+ t0 = req.tools[0] if req.tools else None
+ assert isinstance(t0, Tool)
+ assert t0.model_dump()["cache_control"] == {"type": "ephemeral"}
+
+
+def test_build_base_native_body_includes_cache_control() -> None:
+ req = MessagesRequest(
+ model="m",
+ max_tokens=20,
+ messages=[
+ Message(
+ role="user",
+ content=[
+ ContentBlockText.model_validate(
+ {
+ "type": "text",
+ "text": "x",
+ "cache_control": {"type": "ephemeral"},
+ }
+ )
+ ],
+ )
+ ],
+ system=[
+ SystemContent.model_validate(
+ {
+ "type": "text",
+ "text": "s",
+ "cache_control": {"type": "ephemeral"},
+ }
+ )
+ ],
+ tools=[
+ Tool.model_validate(
+ {
+ "name": "n",
+ "input_schema": {"type": "object"},
+ "cache_control": {"type": "ephemeral"},
+ }
+ )
+ ],
+ )
+ body = build_base_native_anthropic_request_body(
+ req,
+ default_max_tokens=ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
+ thinking_enabled=False,
+ )
+ assert body["messages"][0]["content"][0]["cache_control"] == {"type": "ephemeral"}
+ assert body["system"][0]["cache_control"] == {"type": "ephemeral"}
+ assert body["tools"][0]["cache_control"] == {"type": "ephemeral"}
+
+
+def test_pydantic_discriminator_still_distinguishes_blocks() -> None:
+ m = Message.model_validate(
+ {
+ "role": "user",
+ "content": [{"type": "text", "text": "a", "z": 1}],
+ }
+ )
+ b = m.content[0]
+ assert isinstance(b, ContentBlockText)
+ assert b.model_dump()["z"] == 1
+
+
+def test_server_tool_assistant_blocks_round_trip_in_native_body() -> None:
+ """Local server-tool responses must parse as valid history for a follow-up request."""
+ raw = {
+ "model": "m",
+ "max_tokens": 20,
+ "messages": [
+ {
+ "role": "assistant",
+ "content": [
+ {
+ "type": "server_tool_use",
+ "id": "srvtoolu_1",
+ "name": "web_search",
+ "input": {"query": "q"},
+ },
+ {
+ "type": "web_search_tool_result",
+ "tool_use_id": "srvtoolu_1",
+ "content": [
+ {
+ "type": "web_search_result",
+ "title": "T",
+ "url": "https://example.com",
+ }
+ ],
+ },
+ ],
+ }
+ ],
+ "mcp_servers": [{"type": "url", "url": "https://example.com/mcp"}],
+ }
+ req = MessagesRequest.model_validate(raw)
+ assert len(req.messages) == 1
+ blocks = req.messages[0].content
+ assert isinstance(blocks, list)
+ assert isinstance(blocks[0], ContentBlockServerToolUse)
+ assert isinstance(blocks[1], ContentBlockWebSearchToolResult)
+ body = build_base_native_anthropic_request_body(
+ req,
+ default_max_tokens=ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
+ thinking_enabled=False,
+ )
+ assert body["mcp_servers"][0]["type"] == "url"
+ assert body["messages"][0]["content"][0]["type"] == "server_tool_use"
+ assert body["messages"][0]["content"][1]["type"] == "web_search_tool_result"
+
+
+def test_native_body_preserves_context_and_output_config() -> None:
+ raw = {
+ "model": "m",
+ "max_tokens": 20,
+ "messages": [{"role": "user", "content": "x"}],
+ "context_management": {"edits": [{"type": "clear"}]},
+ "output_config": {"some": "hint"},
+ }
+ req = MessagesRequest.model_validate(raw)
+ body = build_base_native_anthropic_request_body(
+ req,
+ default_max_tokens=ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
+ thinking_enabled=False,
+ )
+ assert body["context_management"] == raw["context_management"]
+ assert body["output_config"] == raw["output_config"]
diff --git a/tests/api/test_api.py b/tests/api/test_api.py
index e4912e7..b387048 100644
--- a/tests/api/test_api.py
+++ b/tests/api/test_api.py
@@ -3,9 +3,11 @@ from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
-from api.app import app
+from api.app import create_app
from providers.nvidia_nim import NvidiaNimProvider
+app = create_app()
+
# Mock provider
mock_provider = MagicMock(spec=NvidiaNimProvider)
@@ -171,7 +173,7 @@ def test_generic_exception_returns_500(client: TestClient):
def test_generic_exception_with_status_code(client: TestClient):
- """Generic exception with status_code attribute uses that status (getattr fallback)."""
+ """Unexpected errors always map to HTTP 500 (ignore ad-hoc status_code attrs)."""
class ExceptionWithStatus(RuntimeError):
def __init__(self, msg: str, status_code: int = 500):
@@ -191,7 +193,7 @@ def test_generic_exception_with_status_code(client: TestClient):
"stream": True,
},
)
- assert response.status_code == 502
+ assert response.status_code == 500
mock_provider.stream_response = _mock_stream_response
diff --git a/tests/api/test_app_lifespan_and_errors.py b/tests/api/test_app_lifespan_and_errors.py
index 126a869..adc81c5 100644
--- a/tests/api/test_app_lifespan_and_errors.py
+++ b/tests/api/test_app_lifespan_and_errors.py
@@ -17,6 +17,15 @@ _RUNTIME_EXTRAS = {
"nvidia_nim_api_key": "",
"claude_cli_bin": "claude",
"uses_process_anthropic_auth_token": lambda: False,
+ "messaging_rate_limit": 1,
+ "messaging_rate_window": 1.0,
+ "max_message_log_entries_per_chat": None,
+ "debug_platform_edits": False,
+ "debug_subagent_stack": False,
+ "log_api_error_tracebacks": False,
+ "log_raw_messaging_content": False,
+ "log_raw_cli_diagnostics": False,
+ "log_messaging_error_details": False,
}
@@ -81,9 +90,50 @@ def test_create_app_provider_error_handler_returns_anthropic_format():
with TestClient(app) as client:
resp = client.get("/raise_provider")
assert resp.status_code == 401
- body = resp.json()
- assert body["type"] == "error"
- assert body["error"]["type"] == "authentication_error"
+ body = resp.json()
+ assert body["type"] == "error"
+ assert body["error"]["type"] == "authentication_error"
+
+
+def test_create_app_provider_error_default_logs_exclude_provider_message():
+ """Provider errors must not log exc.message by default."""
+ from api.app import create_app
+ from providers.exceptions import AuthenticationError
+
+ app = create_app()
+ secret = "provider-upstream-secret-detail"
+
+ @app.get("/raise_provider_secret")
+ async def _raise():
+ raise AuthenticationError(secret)
+
+ api_app_mod = importlib.import_module("api.app")
+ settings = _app_settings(
+ messaging_platform="telegram",
+ telegram_bot_token=None,
+ allowed_telegram_user_id=None,
+ discord_bot_token=None,
+ allowed_discord_channels=None,
+ allowed_dir="",
+ claude_workspace="./agent_workspace",
+ host="127.0.0.1",
+ port=8082,
+ log_file="server.log",
+ log_api_error_tracebacks=False,
+ )
+ with (
+ patch.object(api_app_mod, "get_settings", return_value=settings),
+ patch.object(ProviderRegistry, "cleanup", new=AsyncMock()),
+ patch.object(api_app_mod.logger, "error") as log_err,
+ ):
+ with TestClient(app) as client:
+ resp = client.get("/raise_provider_secret")
+ assert resp.status_code == 401
+
+ blob = " ".join(str(a) for c in log_err.call_args_list for a in c.args)
+ blob += repr([c.kwargs for c in log_err.call_args_list])
+ assert secret not in blob
+ assert "authentication_error" in blob
def test_create_app_general_exception_handler_returns_500():
@@ -120,6 +170,50 @@ def test_create_app_general_exception_handler_returns_500():
assert body["error"]["type"] == "api_error"
+def test_create_app_general_exception_default_logs_exclude_exception_message():
+ """Unhandled errors must not log exception text by default (may echo user content)."""
+ from api.app import create_app
+
+ app = create_app()
+
+ secret = "user-provided-secret-token-xyzzy"
+
+ @app.get("/raise_secret")
+ async def _raise_secret():
+ raise ValueError(secret)
+
+ api_app_mod = importlib.import_module("api.app")
+ settings = _app_settings(
+ messaging_platform="telegram",
+ telegram_bot_token=None,
+ allowed_telegram_user_id=None,
+ discord_bot_token=None,
+ allowed_discord_channels=None,
+ allowed_dir="",
+ claude_workspace="./agent_workspace",
+ host="127.0.0.1",
+ port=8082,
+ log_file="server.log",
+ log_api_error_tracebacks=False,
+ )
+ with (
+ patch.object(api_app_mod, "get_settings", return_value=settings),
+ patch.object(ProviderRegistry, "cleanup", new=AsyncMock()),
+ patch.object(api_app_mod.logger, "error") as log_err,
+ ):
+ with TestClient(app, raise_server_exceptions=False) as client:
+ resp = client.get("/raise_secret")
+ assert resp.status_code == 500
+
+ flattened: list[str] = []
+ for call in log_err.call_args_list:
+ flattened.extend(str(arg) for arg in call.args)
+ flattened.append(repr(call.kwargs))
+ blob = " ".join(flattened)
+ assert secret not in blob
+ assert "ValueError" in blob
+
+
@pytest.mark.parametrize(
"messaging_enabled", [True, False], ids=["with_platform", "no_platform"]
)
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index e661873..99230a3 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -2,10 +2,12 @@ from unittest.mock import patch
from fastapi.testclient import TestClient
-from api.app import app
+from api.app import create_app
from api.dependencies import get_settings
from config.settings import Settings
+app = create_app()
+
def test_anthropic_auth_token_required_and_accepts_x_api_key():
client = TestClient(app)
diff --git a/tests/api/test_dependencies.py b/tests/api/test_dependencies.py
index 2849667..5be4c26 100644
--- a/tests/api/test_dependencies.py
+++ b/tests/api/test_dependencies.py
@@ -16,7 +16,7 @@ from api.dependencies import (
)
from config.nim import NimSettings
from providers.deepseek import DeepSeekProvider
-from providers.exceptions import UnknownProviderTypeError
+from providers.exceptions import ServiceUnavailableError, UnknownProviderTypeError
from providers.lmstudio import LMStudioProvider
from providers.nvidia_nim import NvidiaNimProvider
from providers.ollama import OllamaProvider
@@ -40,7 +40,7 @@ def _make_mock_settings(**overrides):
mock.nim = NimSettings()
mock.http_read_timeout = 300.0
mock.http_write_timeout = 10.0
- mock.http_connect_timeout = 2.0
+ mock.http_connect_timeout = 10.0
mock.enable_model_thinking = True
for key, value in overrides.items():
setattr(mock, key, value)
@@ -434,18 +434,17 @@ def test_resolve_provider_per_app_uses_separate_registries() -> None:
assert p1 is not p2
-def test_resolve_provider_lazily_installs_registry() -> None:
- """If app has no provider_registry, one is created on app.state."""
+def test_resolve_provider_missing_registry_raises_service_unavailable() -> None:
+ """HTTP apps must install app.state.provider_registry (e.g. via AppRuntime)."""
with patch("api.dependencies.get_settings") as mock_settings:
mock_settings.return_value = _make_mock_settings()
settings = _make_mock_settings()
app = SimpleNamespace(state=State())
assert getattr(app.state, "provider_registry", None) is None
- resolve_provider("nvidia_nim", app=cast(Starlette, app), settings=settings)
- reg = app.state.provider_registry
- assert reg is not None
- p2 = resolve_provider("nvidia_nim", app=cast(Starlette, app), settings=settings)
- assert p2 is reg.get("nvidia_nim", settings) # same registry instance
+ with pytest.raises(
+ ServiceUnavailableError, match="Provider registry is not configured"
+ ):
+ resolve_provider("nvidia_nim", app=cast(Starlette, app), settings=settings)
def test_resolve_provider_unrelated_value_error_is_not_unknown_provider_log() -> None:
diff --git a/tests/api/test_routes_optimizations.py b/tests/api/test_routes_optimizations.py
index e0c058c..eb8d2dd 100644
--- a/tests/api/test_routes_optimizations.py
+++ b/tests/api/test_routes_optimizations.py
@@ -3,10 +3,12 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
-from api.app import app
+from api.app import create_app
from api.dependencies import get_settings
from config.settings import Settings
+app = create_app()
+
@pytest.fixture
def client():
@@ -103,6 +105,17 @@ def test_create_message_empty_messages_returns_400(client):
assert "cannot be empty" in data.get("error", {}).get("message", "")
+def test_count_tokens_empty_messages_returns_400(client):
+ """POST /v1/messages/count_tokens with messages: [] matches messages validation."""
+ payload = {"model": "claude-3-sonnet", "messages": []}
+ response = client.post("/v1/messages/count_tokens", json=payload)
+ assert response.status_code == 400
+ data = response.json()
+ assert data.get("type") == "error"
+ assert data.get("error", {}).get("type") == "invalid_request_error"
+ assert "cannot be empty" in data.get("error", {}).get("message", "")
+
+
def test_count_tokens_endpoint(client):
payload = {
"model": "claude-3-sonnet",
diff --git a/tests/api/test_runtime_safe_logging.py b/tests/api/test_runtime_safe_logging.py
new file mode 100644
index 0000000..5c86c15
--- /dev/null
+++ b/tests/api/test_runtime_safe_logging.py
@@ -0,0 +1,70 @@
+"""Tests for safe default logging in :mod:`api.runtime`."""
+
+import importlib
+import logging
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from tests.api.test_app_lifespan_and_errors import _app_settings
+
+
+@pytest.mark.asyncio
+async def test_messaging_start_failure_default_logs_exclude_traceback(caplog):
+ api_runtime_mod = importlib.import_module("api.runtime")
+ settings = _app_settings(
+ messaging_platform="telegram",
+ telegram_bot_token="t",
+ allowed_telegram_user_id="1",
+ discord_bot_token=None,
+ allowed_discord_channels=None,
+ allowed_dir="",
+ claude_workspace="./agent_workspace",
+ host="127.0.0.1",
+ port=8082,
+ log_file="server.log",
+ log_api_error_tracebacks=False,
+ )
+ runtime = api_runtime_mod.AppRuntime(app=MagicMock(), settings=settings)
+
+ with (
+ patch(
+ "messaging.platforms.factory.create_messaging_platform",
+ side_effect=RuntimeError("SECRET_RUNTIME_DETAIL"),
+ ),
+ caplog.at_level(logging.ERROR),
+ ):
+ await runtime._start_messaging_if_configured()
+
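+    # Only the exception class name may surface; the message text stays suppressed.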
+ blob = " | ".join(r.getMessage() for r in caplog.records)
+ assert "SECRET_RUNTIME_DETAIL" not in blob
+ assert "exc_type=RuntimeError" in blob
+
+
+@pytest.mark.asyncio
+async def test_best_effort_default_logs_exclude_exception_text(caplog):
+ api_runtime_mod = importlib.import_module("api.runtime")
+
+ async def boom():
+ raise ValueError("SECRET_SHUTDOWN")
+
+ with caplog.at_level(logging.WARNING):
+ await api_runtime_mod.best_effort("test_step", boom(), log_verbose_errors=False)
+
+ blob = " | ".join(r.getMessage() for r in caplog.records)
+ assert "SECRET_SHUTDOWN" not in blob
+ assert "exc_type=ValueError" in blob
+
+
+@pytest.mark.asyncio
+async def test_best_effort_verbose_includes_exception_text(caplog):
+ api_runtime_mod = importlib.import_module("api.runtime")
+
+ async def boom():
+ raise ValueError("VISIBLE_SHUTDOWN")
+
+ with caplog.at_level(logging.WARNING):
+ await api_runtime_mod.best_effort("test_step", boom(), log_verbose_errors=True)
+
+ blob = " | ".join(r.getMessage() for r in caplog.records)
+ assert "VISIBLE_SHUTDOWN" in blob
diff --git a/tests/api/test_safe_logging.py b/tests/api/test_safe_logging.py
new file mode 100644
index 0000000..1096516
--- /dev/null
+++ b/tests/api/test_safe_logging.py
@@ -0,0 +1,208 @@
+"""Tests that API and SSE logging avoid raw sensitive payloads by default."""
+
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+from fastapi import HTTPException
+
+from api import services as services_mod
+from api.models.anthropic import Message, MessagesRequest
+from api.services import ClaudeProxyService
+from config.settings import Settings
+from core.anthropic.sse import SSEBuilder
+
+
+def test_create_message_skips_full_payload_debug_log_by_default():
+ settings = Settings()
+ assert settings.log_raw_api_payloads is False
+ mock_provider = MagicMock()
+
+ async def fake_stream(*_a, **_kw):
+ yield "event: ping\ndata: {}\n\n"
+
+ mock_provider.stream_response = fake_stream
+ service = ClaudeProxyService(settings, provider_getter=lambda _: mock_provider)
+
+ request = MessagesRequest(
+ model="claude-3-haiku-20240307",
+ max_tokens=10,
+ messages=[Message(role="user", content="secret-user-text")],
+ )
+
+ with patch.object(services_mod.logger, "debug") as mock_debug:
+ service.create_message(request)
+
+ full_payload_calls = [
+ c
+ for c in mock_debug.call_args_list
+ if c.args and str(c.args[0]) == "FULL_PAYLOAD [{}]: {}"
+ ]
+ assert not full_payload_calls
+
+
+def test_create_message_logs_full_payload_when_opt_in():
+ settings = Settings()
+ settings.log_raw_api_payloads = True
+ mock_provider = MagicMock()
+
+ async def fake_stream(*_a, **_kw):
+ yield "event: ping\ndata: {}\n\n"
+
+ mock_provider.stream_response = fake_stream
+ service = ClaudeProxyService(settings, provider_getter=lambda _: mock_provider)
+ request = MessagesRequest(
+ model="claude-3-haiku-20240307",
+ max_tokens=10,
+ messages=[Message(role="user", content="visible")],
+ )
+
+ with patch.object(services_mod.logger, "debug") as mock_debug:
+ service.create_message(request)
+
+ keys = [c.args[0] for c in mock_debug.call_args_list if c.args]
+ assert any(k == "FULL_PAYLOAD [{}]: {}" for k in keys)
+
+
+def test_sse_builder_default_debug_has_no_serialized_json_content():
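+    """Default SSE debug logging records serialized byte counts, never event JSON bodies."""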
+ with patch("core.anthropic.sse.logger.debug") as mock_debug:
+ sse = SSEBuilder("msg_x", "m", 1, log_raw_events=False)
+ sse.message_start()
+
+ assert mock_debug.call_count == 1
+ message = str(mock_debug.call_args)
+ assert "serialized_bytes=" in message
+ assert "role" not in message
+ assert "assistant" not in message
+
+
+def test_sse_builder_raw_logging_includes_event_body_when_enabled():
+ with patch("core.anthropic.sse.logger.debug") as mock_debug:
+ sse = SSEBuilder("msg_x", "m", 1, log_raw_events=True)
+ sse.message_start()
+
+ assert mock_debug.call_count == 1
+ message = str(mock_debug.call_args)
+ assert "message_start" in message
+ assert "role" in message
+
+
+def _flatten_log_calls(mock_log) -> str:
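+    """Join every call's positional args and kwargs into one searchable string."""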
+ parts: list[str] = []
+ for call in mock_log.call_args_list:
+ parts.extend(str(arg) for arg in call.args)
+ parts.append(repr(call.kwargs))
+ return " ".join(parts)
+
+
+def test_create_message_unexpected_error_default_logs_exclude_exception_text():
+ settings = Settings()
+ assert settings.log_api_error_tracebacks is False
+ secret = "upstream-secret-token-abc"
+
+ mock_provider = MagicMock()
+
+ def stream_boom(*_a, **_kw):
+ raise RuntimeError(secret)
+
+ mock_provider.stream_response = stream_boom
+ service = ClaudeProxyService(settings, provider_getter=lambda _: mock_provider)
+ request = MessagesRequest(
+ model="claude-3-haiku-20240307",
+ max_tokens=10,
+ messages=[Message(role="user", content="hi")],
+ )
+
+ with (
+ patch.object(services_mod.logger, "error") as log_err,
+ pytest.raises(HTTPException),
+ ):
+ service.create_message(request)
+
+ blob = _flatten_log_calls(log_err)
+ assert secret not in blob
+ assert "RuntimeError" in blob
+
+
+def test_create_message_unexpected_error_always_returns_500():
+ """Non-provider failures must not leak arbitrary status_code attributes."""
+
+ class WeirdError(Exception):
+ status_code = 418
+
+ settings = Settings()
+ mock_provider = MagicMock()
+
+ def stream_boom(*_a, **_kw):
+ raise WeirdError("no")
+
+ mock_provider.stream_response = stream_boom
+ service = ClaudeProxyService(settings, provider_getter=lambda _: mock_provider)
+ request = MessagesRequest(
+ model="claude-3-haiku-20240307",
+ max_tokens=10,
+ messages=[Message(role="user", content="hi")],
+ )
+
+ with pytest.raises(HTTPException) as excinfo:
+ service.create_message(request)
+
+ assert excinfo.value.status_code == 500
+
+
+def test_parse_cli_event_error_logs_metadata_by_default():
+ """CLI parser must not log raw error text unless LOG_RAW_CLI_DIAGNOSTICS is on."""
+ from messaging.event_parser import parse_cli_event
+
+ secret = "user-secret-parser-leak-xyz"
+ with patch("messaging.event_parser.logger.info") as log_info:
+ parse_cli_event(
+ {"type": "error", "error": {"message": secret}}, log_raw_cli=False
+ )
+ flat = " ".join(str(c) for c in log_info.call_args_list)
+ assert secret not in flat
+ assert "message_chars" in flat
+
+
+def test_parse_cli_event_error_logs_text_when_log_raw_cli_enabled():
+ from messaging.event_parser import parse_cli_event
+
+ secret = "visible-cli-parser-msg"
+ with patch("messaging.event_parser.logger.info") as log_info:
+ parse_cli_event(
+ {"type": "error", "error": {"message": secret}}, log_raw_cli=True
+ )
+ flat = " ".join(str(c) for c in log_info.call_args_list)
+ assert secret in flat
+
+
+def test_count_tokens_unexpected_error_default_logs_exclude_exception_text():
+ settings = Settings()
+ assert settings.log_api_error_tracebacks is False
+ secret = "count-tokens-leak-xyz"
+
+ def boom(*_a, **_kw):
+ raise ValueError(secret)
+
+ service = ClaudeProxyService(
+ settings,
+ provider_getter=lambda _: MagicMock(),
+ token_counter=boom,
+ )
+ from api.models.anthropic import TokenCountRequest
+
+ req = TokenCountRequest(
+ model="claude-3-haiku-20240307",
+ messages=[Message(role="user", content="x")],
+ )
+
+ with (
+ patch.object(services_mod.logger, "error") as log_err,
+ pytest.raises(HTTPException),
+ ):
+ service.count_tokens(req)
+
+ blob = _flatten_log_calls(log_err)
+ assert secret not in blob
+ assert "ValueError" in blob
diff --git a/tests/api/test_validation_log.py b/tests/api/test_validation_log.py
new file mode 100644
index 0000000..9408745
--- /dev/null
+++ b/tests/api/test_validation_log.py
@@ -0,0 +1,33 @@
+"""Tests for validation log summaries (metadata only)."""
+
+from api.validation_log import summarize_request_validation_body
+
+
+def test_summarize_lists_block_metadata_without_echoing_string_content():
+ body = {
+ "messages": [
+ {
+ "role": "user",
+ "content": "secret user phrase",
+ }
+ ],
+ "tools": [{"name": "web_search", "type": "web_search_20250305"}],
+ }
+ summary, tool_names = summarize_request_validation_body(body)
+ assert summary == [
+ {
+ "role": "user",
+ "content_kind": "str",
+ "content_length": 18,
+ }
+ ]
+ assert tool_names == ["web_search"]
+ blob = repr(summary) + repr(tool_names)
+ assert "secret" not in blob
+
+
+def test_summarize_handles_non_dict_messages_and_missing_tools():
+ body = {"messages": ["not_a_dict"]}
+ summary, tool_names = summarize_request_validation_body(body)
+ assert summary == [{"message_kind": "str"}]
+ assert tool_names == []
diff --git a/tests/api/test_web_server_tools.py b/tests/api/test_web_server_tools.py
index c00d538..c46ddf2 100644
--- a/tests/api/test_web_server_tools.py
+++ b/tests/api/test_web_server_tools.py
@@ -1,20 +1,65 @@
-import json
+from typing import Any
+from unittest.mock import AsyncMock, MagicMock, patch
+import httpx
import pytest
+import api.web_tools.constants as web_tool_constants
+from api.model_router import ModelRouter, ResolvedModel, RoutedMessagesRequest
from api.models.anthropic import Message, MessagesRequest, Tool
-from api.web_server_tools import (
- is_web_server_tool_request,
- stream_web_server_tool_response,
+from api.services import ClaudeProxyService
+from api.web_tools import egress as web_egress
+from api.web_tools.egress import (
+ WebFetchEgressPolicy,
+ WebFetchEgressViolation,
+ enforce_web_fetch_egress,
+)
+from api.web_tools.outbound import (
+ _drain_response_body_capped,
+ _read_response_body_capped,
+ _run_web_fetch,
+)
+from api.web_tools.request import is_web_server_tool_request
+from api.web_tools.streaming import stream_web_server_tool_response
+from config.settings import Settings
+from core.anthropic.stream_contracts import (
+ assert_anthropic_stream_contract,
+ parse_sse_text,
+ text_content,
+)
+from messaging.event_parser import parse_cli_event
+from providers.exceptions import InvalidRequestError
+
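+# Default-deny egress policy for tests: http/https only, private network targets blocked.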
+_STRICT_EGRESS = WebFetchEgressPolicy(
+ allow_private_network_targets=False,
+ allowed_schemes=frozenset({"http", "https"}),
)
-def _event_data(event: str) -> dict:
- data_line = next(line for line in event.splitlines() if line.startswith("data: "))
- return json.loads(data_line.removeprefix("data: "))
+class FixedProviderModelRouter(ModelRouter):
+ """Test double: pin ``provider_id`` for OpenAI vs native routing assertions."""
+
+ def __init__(self, settings: Settings, provider_id: str) -> None:
+ super().__init__(settings)
+ self._fixed_provider_id = provider_id
+
+ def resolve_messages_request(
+ self, request: MessagesRequest
+ ) -> RoutedMessagesRequest:
+ resolved = ResolvedModel(
+ original_model=request.model,
+ provider_id=self._fixed_provider_id,
+ provider_model=request.model,
+ provider_model_ref=f"{self._fixed_provider_id}/{request.model}",
+ thinking_enabled=False,
+ )
+ routed = request.model_copy(deep=True)
+ routed.model = resolved.provider_model
+ return RoutedMessagesRequest(request=routed, resolved=resolved)
-def test_detects_web_search_server_tool_request():
+def test_web_server_tool_not_detected_when_tool_only_listed():
+ """Listing web_search without forcing it must not skip the upstream provider."""
request = MessagesRequest(
model="claude-haiku-4-5-20251001",
max_tokens=100,
@@ -22,16 +67,243 @@ def test_detects_web_search_server_tool_request():
tools=[Tool(name="web_search", type="web_search_20250305")],
)
+ assert not is_web_server_tool_request(request)
+
+
+def test_web_server_tool_detected_when_tool_choice_forces_it():
+ request = MessagesRequest(
+ model="claude-haiku-4-5-20251001",
+ max_tokens=100,
+ messages=[Message(role="user", content="search")],
+ tools=[Tool(name="web_search", type="web_search_20250305")],
+ tool_choice={"type": "tool", "name": "web_search"},
+ )
+
assert is_web_server_tool_request(request)
+def test_web_server_tool_not_detected_when_forced_name_missing_from_tools():
+ request = MessagesRequest(
+ model="claude-haiku-4-5-20251001",
+ max_tokens=100,
+ messages=[Message(role="user", content="hi")],
+ tools=[Tool(name="other", type="function")],
+ tool_choice={"type": "tool", "name": "web_search"},
+ )
+
+ assert not is_web_server_tool_request(request)
+
+
+def test_service_rejects_forced_server_tool_on_openai_when_disabled():
+ """OpenAI Chat upstreams cannot run forced server tools without the local handler."""
+ settings = Settings()
+ assert settings.enable_web_server_tools is False
+ service = ClaudeProxyService(
+ settings,
+ provider_getter=lambda _: MagicMock(),
+ model_router=FixedProviderModelRouter(settings, "nvidia_nim"),
+ )
+ request = MessagesRequest(
+ model="claude-haiku-4-5-20251001",
+ max_tokens=100,
+ messages=[
+ Message(
+ role="user",
+ content="Perform a web search for the query: DeepSeek V4 model release 2026",
+ )
+ ],
+ tools=[Tool(name="web_search", type="web_search_20250305")],
+ tool_choice={"type": "tool", "name": "web_search"},
+ )
+ with pytest.raises(InvalidRequestError, match="ENABLE_WEB_SERVER_TOOLS"):
+ service.create_message(request)
+
+
+@pytest.mark.parametrize(
+ "url",
+ [
+ "http://127.0.0.1/",
+ "http://192.168.1.1/",
+ "http://10.0.0.1/",
+ "http://[::1]/",
+ "http://localhost/foo",
+ "http://mybox.local/",
+ "file:///etc/passwd",
+ "http://169.254.169.254/latest/meta-data/",
+ ],
+)
+def test_enforce_web_fetch_egress_blocks_internal_or_disallowed(url: str):
+ with pytest.raises(WebFetchEgressViolation):
+ enforce_web_fetch_egress(url, _STRICT_EGRESS)
+
+
+def test_enforce_web_fetch_egress_allows_global_literal_ip():
+ enforce_web_fetch_egress("http://8.8.8.8/", _STRICT_EGRESS)
+
+
+def test_enforce_web_fetch_egress_skips_private_checks_when_opted_in():
+ enforce_web_fetch_egress(
+ "http://127.0.0.1/",
+ WebFetchEgressPolicy(
+ allow_private_network_targets=True,
+ allowed_schemes=frozenset({"http", "https"}),
+ ),
+ )
+
+
+def _cm(mock_client: MagicMock) -> MagicMock:
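+    """Wrap ``mock_client`` in an async context manager double."""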
+ cm = MagicMock()
+ cm.__aenter__ = AsyncMock(return_value=mock_client)
+ cm.__aexit__ = AsyncMock(return_value=None)
+ return cm
+
+
+def _stream_cm(response: httpx.Response) -> MagicMock:
+ cm = MagicMock()
+ cm.__aenter__ = AsyncMock(return_value=response)
+ cm.__aexit__ = AsyncMock(return_value=None)
+ return cm
+
+
+def _aiohttp_response(
+ status: int,
+ *,
+ url: str = "http://8.8.8.8/",
+ location: str | None = None,
+ body: bytes = b"hello world",
+) -> MagicMock:
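+    """Build a minimal aiohttp-style response double with a one-chunk streamed body."""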
+ r = MagicMock()
+ r.status = status
+ r.url = url
+ hdrs: dict[str, str] = {}
+ if location is not None:
+ hdrs["location"] = location
+ r.headers = hdrs
+ r.get_encoding = MagicMock(return_value="utf-8")
+ r.raise_for_status = MagicMock()
+ r.request_info = MagicMock()
+ r.history = ()
+
+ async def iter_chunked(_n: int) -> Any:
+ yield body
+
+ r.content.iter_chunked = MagicMock(side_effect=iter_chunked)
+ return r
+
+
+def _aiohttp_client_session_patch(
+ *responses: MagicMock,
+) -> tuple[MagicMock, MagicMock]:
+ """Build ``ClientSession`` mock that serves ``responses`` to successive ``get`` calls."""
+ queue = list(responses)
+ n = 0
+
+ def get_side(*_a: Any, **_k: Any) -> Any:
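+        # Serve queued responses in order; repeat the last one once the queue is exhausted.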
+ nonlocal n
+ resp = queue[n] if n < len(queue) else queue[-1]
+ n += 1
+ cm = MagicMock()
+ cm.__aenter__ = AsyncMock(return_value=resp)
+ cm.__aexit__ = AsyncMock(return_value=None)
+ return cm
+
+ session = MagicMock()
+ session.get = MagicMock(side_effect=get_side)
+
+ client_cm = MagicMock()
+ client_cm.__aenter__ = AsyncMock(return_value=session)
+ client_cm.__aexit__ = AsyncMock(return_value=None)
+ return client_cm, session
+
+
+def test_enforce_web_fetch_egress_documents_connect_time_pinning():
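+    """Egress helpers must keep documenting connect-time DNS pinning."""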
+ assert enforce_web_fetch_egress.__doc__ and "resolved addresses" in (
+ enforce_web_fetch_egress.__doc__ or ""
+ )
+ assert (
+ web_egress.get_validated_stream_addrinfos_for_egress.__doc__
+ and "pinning"
+ in (web_egress.get_validated_stream_addrinfos_for_egress.__doc__ or "")
+ )
+ assert "DNS-pinned" in (_run_web_fetch.__doc__ or "")
+
+
+@pytest.mark.asyncio
+async def test_run_web_fetch_follows_redirect_when_each_hop_is_allowed():
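+    """Redirects are followed hop by hop so each target can be egress-checked."""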
+ res_redirect = _aiohttp_response(
+ 302, url="http://8.8.8.8/start", location="/final", body=b""
+ )
+ res_ok = _aiohttp_response(200, url="http://8.8.8.8/final", body=b"hello world")
+ client_cm, session = _aiohttp_client_session_patch(res_redirect, res_ok)
+ with patch("api.web_tools.outbound.ClientSession", return_value=client_cm):
+ out = await _run_web_fetch("http://8.8.8.8/start", _STRICT_EGRESS)
+
+ assert out["data"] == "hello world"
+ assert session.get.call_count == 2
+
+
+@pytest.mark.asyncio
+async def test_run_web_fetch_truncates_large_body_to_byte_cap(monkeypatch):
+ huge = b"x" * 5000
+ res_ok = _aiohttp_response(200, url="http://8.8.8.8/big", body=huge)
+ client_cm, _ = _aiohttp_client_session_patch(res_ok)
+ monkeypatch.setattr(web_tool_constants, "_MAX_WEB_FETCH_RESPONSE_BYTES", 100)
+ with patch("api.web_tools.outbound.ClientSession", return_value=client_cm):
+ out = await _run_web_fetch("http://8.8.8.8/big", _STRICT_EGRESS)
+
+ assert len(out["data"]) <= 100
+ assert out["data"] == "x" * 100
+
+
+@pytest.mark.asyncio
+async def test_run_web_fetch_redirect_to_blocked_host_raises():
+ res_redirect = _aiohttp_response(
+ 302,
+ url="http://8.8.8.8/start",
+ location="http://127.0.0.1/secret",
+ body=b"",
+ )
+ client_cm, session = _aiohttp_client_session_patch(res_redirect)
+ with (
+ patch("api.web_tools.outbound.ClientSession", return_value=client_cm),
+ pytest.raises(WebFetchEgressViolation),
+ ):
+ await _run_web_fetch("http://8.8.8.8/start", _STRICT_EGRESS)
+
+ session.get.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_run_web_fetch_redirect_without_location_raises():
+ res_bad = _aiohttp_response(302, url="http://8.8.8.8/here", body=b"")
+ client_cm, _ = _aiohttp_client_session_patch(res_bad)
+ with (
+ patch("api.web_tools.outbound.ClientSession", return_value=client_cm),
+ pytest.raises(WebFetchEgressViolation, match="missing Location"),
+ ):
+ await _run_web_fetch("http://8.8.8.8/here", _STRICT_EGRESS)
+
+
+@pytest.mark.asyncio
+async def test_run_web_fetch_excess_redirects_raises():
+ res1 = _aiohttp_response(302, url="http://8.8.8.8/a", location="/b", body=b"")
+ res2 = _aiohttp_response(302, url="http://8.8.8.8/b", location="/c", body=b"")
+ client_cm, _ = _aiohttp_client_session_patch(res1, res2)
+ with (
+ patch("api.web_tools.constants._MAX_WEB_FETCH_REDIRECTS", 1),
+ patch("api.web_tools.outbound.ClientSession", return_value=client_cm),
+ pytest.raises(WebFetchEgressViolation, match="exceeded maximum redirects"),
+ ):
+ await _run_web_fetch("http://8.8.8.8/a", _STRICT_EGRESS)
+
+
@pytest.mark.asyncio
async def test_streams_web_search_server_tool_result(monkeypatch):
async def fake_search(query: str) -> list[dict[str, str]]:
assert query == "DeepSeek V4 model release 2026"
return [{"title": "DeepSeek V4 Released", "url": "https://example.com/v4"}]
- monkeypatch.setattr("api.web_server_tools._run_web_search", fake_search)
+ monkeypatch.setattr("api.web_tools.outbound._run_web_search", fake_search)
request = MessagesRequest(
model="claude-haiku-4-5-20251001",
max_tokens=100,
@@ -47,22 +319,92 @@ async def test_streams_web_search_server_tool_result(monkeypatch):
tool_choice={"type": "tool", "name": "web_search"},
)
- events = [
- event
- async for event in stream_web_server_tool_response(request, input_tokens=42)
+ raw = "".join(
+ [
+ event
+ async for event in stream_web_server_tool_response(
+ request, input_tokens=42, web_fetch_egress=_STRICT_EGRESS
+ )
+ ]
+ )
+ events = parse_sse_text(raw)
+ assert_anthropic_stream_contract(events)
+ starts = [e for e in events if e.event == "content_block_start"]
+ assert starts[0].data["content_block"]["type"] == "server_tool_use"
+ assert starts[0].data["content_block"]["name"] == "web_search"
+ tool_use_id = starts[0].data["content_block"]["id"]
+ assert starts[1].data["content_block"]["type"] == "web_search_tool_result"
+ assert starts[1].data["content_block"]["tool_use_id"] == tool_use_id
+ assert starts[1].data["content_block"]["content"][0]["url"] == (
+ "https://example.com/v4"
+ )
+ text_deltas = [
+ e
+ for e in events
+ if e.event == "content_block_delta"
+ and e.data.get("delta", {}).get("type") == "text_delta"
]
- payloads = [_event_data(event) for event in events]
+ assert text_deltas, "summary must be streamed as text_delta"
+ assert "example.com" in text_content(events)
+ cli_text: list[str] = []
+ for ev in events:
+ cli_text.extend(
+ str(p.get("text", ""))
+ for p in parse_cli_event(ev.data)
+ if p.get("type") == "text_delta"
+ )
+ assert "example.com" in "".join(cli_text)
+ deltas = [e for e in events if e.event == "message_delta"]
+ assert deltas[-1].data["usage"]["server_tool_use"] == {"web_search_requests": 1}
- assert payloads[1]["content_block"]["type"] == "server_tool_use"
- assert payloads[1]["content_block"]["name"] == "web_search"
- assert payloads[3]["content_block"]["type"] == "web_search_tool_result"
- assert payloads[3]["content_block"]["content"][0]["url"] == "https://example.com/v4"
- assert payloads[-2]["usage"]["server_tool_use"] == {"web_search_requests": 1}
+
+@pytest.mark.asyncio
+async def test_forced_web_fetch_ignores_stale_url_from_prior_user_turns(monkeypatch):
+ """Only the latest user message supplies the URL (not earlier transcript text)."""
+ target = "https://new-only.example.com/page"
+
+ async def fake_fetch(url: str, _egress: WebFetchEgressPolicy) -> dict[str, str]:
+ assert url == target
+ return {
+ "url": url,
+ "title": "T",
+ "media_type": "text/plain",
+ "data": "x",
+ }
+
+ monkeypatch.setattr("api.web_tools.outbound._run_web_fetch", fake_fetch)
+ request = MessagesRequest(
+ model="claude-haiku-4-5-20251001",
+ max_tokens=100,
+ messages=[
+ Message(
+ role="user",
+ content="Earlier turn https://stale.com/old-article ignore this",
+ ),
+ Message(role="assistant", content="ok"),
+ Message(
+ role="user",
+ content=f"Please fetch {target} for the summary",
+ ),
+ ],
+ tools=[Tool(name="web_fetch", type="web_fetch_20250910")],
+ tool_choice={"type": "tool", "name": "web_fetch"},
+ )
+
+ raw = "".join(
+ [
+ event
+ async for event in stream_web_server_tool_response(
+ request, input_tokens=1, web_fetch_egress=_STRICT_EGRESS
+ )
+ ]
+ )
+ assert target in raw
@pytest.mark.asyncio
async def test_streams_web_fetch_server_tool_result(monkeypatch):
- async def fake_fetch(url: str) -> dict[str, str]:
+ async def fake_fetch(url: str, _egress: WebFetchEgressPolicy) -> dict[str, str]:
assert url == "https://example.com/article"
return {
"url": url,
@@ -71,7 +413,7 @@ async def test_streams_web_fetch_server_tool_result(monkeypatch):
"data": "Article body",
}
- monkeypatch.setattr("api.web_server_tools._run_web_fetch", fake_fetch)
+ monkeypatch.setattr("api.web_tools.outbound._run_web_fetch", fake_fetch)
request = MessagesRequest(
model="claude-haiku-4-5-20251001",
max_tokens=100,
@@ -82,15 +424,198 @@ async def test_streams_web_fetch_server_tool_result(monkeypatch):
tool_choice={"type": "tool", "name": "web_fetch"},
)
- events = [
- event
- async for event in stream_web_server_tool_response(request, input_tokens=42)
- ]
- payloads = [_event_data(event) for event in events]
-
- assert payloads[1]["content_block"]["type"] == "server_tool_use"
- assert payloads[3]["content_block"]["type"] == "web_fetch_tool_result"
- assert payloads[3]["content_block"]["content"]["content"]["title"] == (
+ raw = "".join(
+ [
+ event
+ async for event in stream_web_server_tool_response(
+ request, input_tokens=42, web_fetch_egress=_STRICT_EGRESS
+ )
+ ]
+ )
+ events = parse_sse_text(raw)
+ assert_anthropic_stream_contract(events)
+ starts = [e for e in events if e.event == "content_block_start"]
+ assert starts[0].data["content_block"]["type"] == "server_tool_use"
+ tool_use_id = starts[0].data["content_block"]["id"]
+ assert starts[1].data["content_block"]["type"] == "web_fetch_tool_result"
+ assert starts[1].data["content_block"]["tool_use_id"] == tool_use_id
+ assert starts[1].data["content_block"]["content"]["content"]["title"] == (
"Example Article"
)
- assert payloads[-2]["usage"]["server_tool_use"] == {"web_fetch_requests": 1}
+ assert any(
+ e.event == "content_block_delta"
+ and e.data.get("delta", {}).get("type") == "text_delta"
+ for e in events
+ )
+ assert "Article body" in text_content(events)
+ cli_text: list[str] = []
+ for ev in events:
+ cli_text.extend(
+ str(p.get("text", ""))
+ for p in parse_cli_event(ev.data)
+ if p.get("type") == "text_delta"
+ )
+ assert "Article body" in "".join(cli_text)
+ deltas = [e for e in events if e.event == "message_delta"]
+ assert deltas[-1].data["usage"]["server_tool_use"] == {"web_fetch_requests": 1}
+
+
+@pytest.mark.asyncio
+async def test_streams_web_fetch_error_summary_generic_by_default(monkeypatch):
+ secret = "sensitive-upstream-token"
+
+ async def boom(_url: str, _egress: WebFetchEgressPolicy) -> dict[str, str]:
+ raise ValueError(secret)
+
+ monkeypatch.setattr("api.web_tools.outbound._run_web_fetch", boom)
+ request = MessagesRequest(
+ model="claude-haiku-4-5-20251001",
+ max_tokens=100,
+ messages=[
+ Message(
+ role="user",
+ content="Fetch https://example.com/sensitive-path?x=1 please",
+ )
+ ],
+ tools=[Tool(name="web_fetch", type="web_fetch_20250910")],
+ tool_choice={"type": "tool", "name": "web_fetch"},
+ )
+
+ with patch("api.web_tools.outbound.logger.warning") as log_warn:
+ raw = "".join(
+ [
+ event
+ async for event in stream_web_server_tool_response(
+ request,
+ input_tokens=1,
+ web_fetch_egress=_STRICT_EGRESS,
+ verbose_client_errors=False,
+ )
+ ]
+ )
+
+ assert secret not in raw
+ assert "ValueError" not in raw
+ assert "Web tool request failed." in raw
+ err_events = parse_sse_text(raw)
+ assert_anthropic_stream_contract(err_events)
+ assert any(
+ e.event == "content_block_delta"
+ and e.data.get("delta", {}).get("type") == "text_delta"
+ for e in err_events
+ )
+ cli_err_text: list[str] = []
+ for ev in err_events:
+ cli_err_text.extend(
+ str(p.get("text", ""))
+ for p in parse_cli_event(ev.data)
+ if p.get("type") == "text_delta"
+ )
+ assert "Web tool request failed." in "".join(cli_err_text)
+ log_blob = " ".join(str(a) for c in log_warn.call_args_list for a in c.args)
+ assert secret not in log_blob
+ assert "example.com" in log_blob
+ assert "/sensitive-path" not in log_blob
+
+
+@pytest.mark.asyncio
+async def test_streams_web_fetch_error_summary_verbose_includes_exception_class(
+ monkeypatch,
+):
+ async def boom(_url: str, _egress: WebFetchEgressPolicy) -> dict[str, str]:
+ raise OSError(5, "oops")
+
+ monkeypatch.setattr("api.web_tools.outbound._run_web_fetch", boom)
+ request = MessagesRequest(
+ model="claude-haiku-4-5-20251001",
+ max_tokens=100,
+ messages=[Message(role="user", content="Fetch https://example.com/x")],
+ tools=[Tool(name="web_fetch", type="web_fetch_20250910")],
+ tool_choice={"type": "tool", "name": "web_fetch"},
+ )
+
+ raw = "".join(
+ [
+ event
+ async for event in stream_web_server_tool_response(
+ request,
+ input_tokens=1,
+ web_fetch_egress=_STRICT_EGRESS,
+ verbose_client_errors=True,
+ )
+ ]
+ )
+ assert "OSError" in raw
+
+
+@pytest.mark.asyncio
+async def test_read_response_body_capped_truncates_single_oversized_chunk():
+ cap = 500
+
+ async def aiter_bytes(chunk_size=None):
+ yield b"z" * (cap * 20)
+
+ response = MagicMock()
+ response.aiter_bytes = aiter_bytes
+
+ out = await _read_response_body_capped(response, cap)
+ assert len(out) == cap
+ assert out == b"z" * cap
+
+
+@pytest.mark.asyncio
+async def test_drain_response_body_capped_stops_after_first_chunk_when_oversized():
+ cap = 300
+ chunk_calls = {"n": 0}
+
+ async def aiter_bytes(chunk_size=None):
+ chunk_calls["n"] += 1
+ yield b"y" * (cap * 10)
+
+ response = MagicMock()
+ response.aiter_bytes = aiter_bytes
+
+ await _drain_response_body_capped(response, cap)
+ assert chunk_calls["n"] == 1
+
+
+def test_service_rejects_listed_server_tools_on_openai_chat() -> None:
+ settings = Settings()
+ service = ClaudeProxyService(
+ settings,
+ provider_getter=lambda _: MagicMock(),
+ model_router=FixedProviderModelRouter(settings, "deepseek"),
+ )
+ request = MessagesRequest(
+ model="m",
+ max_tokens=20,
+ messages=[Message(role="user", content="q")],
+ tools=[Tool(name="web_search", type="web_search_20250305")],
+ )
+ with pytest.raises(InvalidRequestError, match="OpenAI Chat upstreams"):
+ service.create_message(request)
+
+
+def test_listed_server_tools_routed_on_open_router() -> None:
+ """Native Anthropic transport may receive listed server tool definitions."""
+ settings = Settings()
+
+ async def fake_stream(*_a, **_k):
+ yield 'event: message_start\ndata: {"type":"message_start"}\n\n'
+ yield 'event: message_stop\ndata: {"type":"message_stop"}\n\n'
+
+ mock_provider = MagicMock()
+ mock_provider.stream_response = fake_stream
+ service = ClaudeProxyService(
+ settings,
+ provider_getter=lambda _: mock_provider,
+ model_router=FixedProviderModelRouter(settings, "open_router"),
+ )
+ request = MessagesRequest(
+ model="m",
+ max_tokens=20,
+ messages=[Message(role="user", content="q")],
+ tools=[Tool(name="web_search", type="web_search_20250305")],
+ )
+ service.create_message(request)
+ mock_provider.preflight_stream.assert_called()
diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py
index a22f9dd..590f51c 100644
--- a/tests/cli/test_cli.py
+++ b/tests/cli/test_cli.py
@@ -3,6 +3,7 @@
import asyncio
import json
import os
+from typing import cast
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -300,7 +301,7 @@ class TestCLISession:
mock_process = AsyncMock()
mock_process.stdout.read.side_effect = [b""] # No stdout
- mock_process.stderr.read.return_value = b"Fatal error"
+ mock_process.stderr.read.side_effect = [b"Fatal error", b""]
mock_process.wait.return_value = 1
with patch(
@@ -319,6 +320,62 @@ class TestCLISession:
assert events[1]["code"] == 1
assert events[1]["stderr"] == "Fatal error"
+ @pytest.mark.asyncio
+ async def test_start_task_stderr_while_stdout_streams(self):
+ """Stderr is drained concurrently so stdout streaming is not blocked."""
+ from cli.session import CLISession
+
+ session = CLISession("/tmp", "http://localhost:8082/v1")
+
+ mock_process = AsyncMock()
+ mock_process.stdout.read.side_effect = [
+ b'{"type": "message", "content": "Hi"}\n',
+ b"",
+ ]
+ mock_process.stderr.read.side_effect = [b"warning on stderr\n", b""]
+ mock_process.wait.return_value = 0
+
+ with patch(
+ "asyncio.create_subprocess_exec", new_callable=AsyncMock
+ ) as mock_exec:
+ mock_exec.return_value = mock_process
+
+ events = [e async for e in session.start_task("Hello")]
+
+ assert mock_process.stderr.read.await_count >= 2
+ err_events = [e for e in events if e.get("type") == "error"]
+ assert len(err_events) == 1
+ assert "warning on stderr" in err_events[0]["error"]["message"]
+ assert events[-1]["type"] == "exit"
+ assert events[-1]["code"] == 0
+
+ @pytest.mark.asyncio
+ async def test_drain_stderr_bounded_retains_cap_but_drains_to_eof(self):
+ """Oversized stderr is fully drained so the pipe cannot deadlock; capture is bounded."""
+ from cli.session import _MAX_STDERR_CAPTURE_BYTES, CLISession
+
+ total_len = _MAX_STDERR_CAPTURE_BYTES + 100_000
+ remaining: dict[str, int] = {"n": total_len}
+
+ class _FakeStderr:
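+            # Serves total_len bytes across successive reads, then EOF.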
+ async def read(self, n: int = 65536) -> bytes:
+ left = remaining["n"]
+ if left <= 0:
+ return b""
+ take = min(n, left)
+ remaining["n"] = left - take
+ return b"y" * take
+
+ class _FakeProcess:
+ stderr = _FakeStderr()
+
+ out = await CLISession._drain_stderr_bounded(
+ cast(asyncio.subprocess.Process, _FakeProcess())
+ )
+ assert len(out) == _MAX_STDERR_CAPTURE_BYTES
+ assert out == b"y" * _MAX_STDERR_CAPTURE_BYTES
+ assert remaining["n"] == 0
+
@pytest.mark.asyncio
async def test_stop_session(self):
"""Test stopping the session process."""
diff --git a/tests/config/test_config.py b/tests/config/test_config.py
index 583dc75..e62f583 100644
--- a/tests/config/test_config.py
+++ b/tests/config/test_config.py
@@ -3,6 +3,10 @@
import pytest
from pydantic import ValidationError
+from config.constants import (
+ ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
+ HTTP_CONNECT_TIMEOUT_DEFAULT,
+)
from config.nim import NimSettings
@@ -22,6 +26,7 @@ class TestSettings:
monkeypatch.delenv("MODEL", raising=False)
monkeypatch.delenv("HTTP_READ_TIMEOUT", raising=False)
+ monkeypatch.delenv("HTTP_CONNECT_TIMEOUT", raising=False)
monkeypatch.setitem(Settings.model_config, "env_file", ())
settings = Settings()
assert settings.model == "nvidia_nim/z-ai/glm4.7"
@@ -31,6 +36,12 @@ class TestSettings:
assert isinstance(settings.fast_prefix_detection, bool)
assert isinstance(settings.enable_model_thinking, bool)
assert settings.http_read_timeout == 120.0
+ assert settings.http_connect_timeout == HTTP_CONNECT_TIMEOUT_DEFAULT
+ assert settings.enable_web_server_tools is False
+ assert settings.log_raw_api_payloads is False
+ assert settings.log_raw_sse_events is False
+ assert settings.debug_platform_edits is False
+ assert settings.debug_subagent_stack is False
def test_get_settings_cached(self):
"""Test get_settings returns cached instance."""
@@ -57,10 +68,10 @@ class TestSettings:
assert len(settings.model) > 0
def test_base_url_constant(self):
- """Test NVIDIA_NIM_BASE_URL is a constant."""
- from providers.nvidia_nim import NVIDIA_NIM_BASE_URL
+ """Test NVIDIA_NIM_DEFAULT_BASE is a constant."""
+ from providers.nvidia_nim import NVIDIA_NIM_DEFAULT_BASE
- assert NVIDIA_NIM_BASE_URL == "https://integrate.api.nvidia.com/v1"
+ assert NVIDIA_NIM_DEFAULT_BASE == "https://integrate.api.nvidia.com/v1"
def test_lm_studio_base_url_from_env(self, monkeypatch):
"""LM_STUDIO_BASE_URL env var is loaded into settings."""
@@ -127,6 +138,18 @@ class TestSettings:
settings = Settings()
assert settings.http_connect_timeout == 5.0
+ def test_http_connect_timeout_default_matches_shared_constant(
+ self, monkeypatch
+ ) -> None:
+ """Default must match config.constants (and README / .env.example)."""
+ from config.settings import Settings
+
+ monkeypatch.delenv("HTTP_CONNECT_TIMEOUT", raising=False)
+ monkeypatch.setitem(Settings.model_config, "env_file", ())
+ settings = Settings()
+ assert settings.http_connect_timeout == HTTP_CONNECT_TIMEOUT_DEFAULT
+ assert HTTP_CONNECT_TIMEOUT_DEFAULT == 10.0
+
def test_enable_model_thinking_from_env(self, monkeypatch):
"""ENABLE_MODEL_THINKING env var is loaded into settings."""
from config.settings import Settings
@@ -319,6 +342,9 @@ class TestNimSettingsInvalidBounds:
class TestNimSettingsValidators:
"""Test custom field validators in NimSettings."""
+ def test_default_max_tokens_matches_shared_constant(self):
+ assert NimSettings().max_tokens == ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
+
@pytest.mark.parametrize(
"seed_val,expected",
[("", None), (None, None), ("42", 42), (42, 42)],
diff --git a/tests/config/test_logging_config.py b/tests/config/test_logging_config.py
index 11a818c..9f54df0 100644
--- a/tests/config/test_logging_config.py
+++ b/tests/config/test_logging_config.py
@@ -4,6 +4,8 @@ import json
import logging
from pathlib import Path
+from loguru import logger
+
from config.logging_config import configure_logging
@@ -57,3 +59,38 @@ def test_configure_logging_skips_when_already_configured(tmp_path):
assert "Still goes to first file" in (tmp_path / "test.log").read_text(
encoding="utf-8"
)
+
+
+def test_telegram_bot_token_redacted_in_message_field(tmp_path) -> None:
+ log_file = str(tmp_path / "redact.log")
+ configure_logging(log_file, force=True, verbose_third_party=False)
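+    # Shaped like a real Telegram bot token (<numeric bot id>:<secret>).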
+ token = "123456:ABCDEF-ghij-klm"
+ logger.info("Calling {}", f"https://api.telegram.org/bot{token}/getMe")
+ logger.complete()
+ text = Path(log_file).read_text(encoding="utf-8")
+ assert token not in text
+ assert "bot/" in text or "redacted" in text
+
+
+def test_bearer_substring_redacted_in_log_file(tmp_path) -> None:
+ log_file = str(tmp_path / "bearer.log")
+ configure_logging(log_file, force=True, verbose_third_party=False)
+ secret = "ya29.secret-token-abc"
+ logger.info("Request headers: Authorization: Bearer {}", secret)
+ logger.complete()
+ text = Path(log_file).read_text(encoding="utf-8")
+ assert secret not in text
+ assert "Bearer" in text
+
+
+def test_httpx_logger_quieted_when_not_verbose_third_party(tmp_path) -> None:
+ log_file = str(tmp_path / "quiet.log")
+ configure_logging(log_file, force=True, verbose_third_party=False)
+ assert logging.getLogger("httpx").level >= logging.WARNING
+ assert logging.getLogger("httpcore").level >= logging.WARNING
+
+
+def test_httpx_resets_to_notset_when_verbose_third_party(tmp_path) -> None:
+ log_file = str(tmp_path / "verbose.log")
+ configure_logging(log_file, force=True, verbose_third_party=True)
+ assert logging.getLogger("httpx").level == logging.NOTSET
diff --git a/tests/contracts/test_architecture_contracts.py b/tests/contracts/test_architecture_contracts.py
index 344122e..25be850 100644
--- a/tests/contracts/test_architecture_contracts.py
+++ b/tests/contracts/test_architecture_contracts.py
@@ -13,6 +13,25 @@ def test_architecture_plan_exists() -> None:
text = plan.read_text(encoding="utf-8")
assert "Intended Dependency Direction" in text
assert "Smoke Coverage Policy" in text
+ assert "providers.nvidia_nim.voice" in text
+ assert "no dedicated smoke SSE shim" in text
+
+
+def test_smoke_lib_has_no_sse_shim_module() -> None:
+ repo_root = Path(__file__).resolve().parents[2]
+ assert not (repo_root / "smoke" / "lib" / "sse.py").exists()
+
+
+def test_api_package_exports_match_plan() -> None:
+ import api
+
+ assert set(api.__all__) == {
+ "MessagesRequest",
+ "MessagesResponse",
+ "TokenCountRequest",
+ "TokenCountResponse",
+ "create_app",
+ }
def test_root_env_example_is_the_single_template_source() -> None:
diff --git a/tests/contracts/test_import_boundaries.py b/tests/contracts/test_import_boundaries.py
index 918d7d1..3bde11d 100644
--- a/tests/contracts/test_import_boundaries.py
+++ b/tests/contracts/test_import_boundaries.py
@@ -54,6 +54,15 @@ def test_core_does_not_import_product_packages() -> None:
assert offenders == []
+def test_provider_catalog_is_single_source_for_supported_ids() -> None:
+ from config.provider_catalog import PROVIDER_CATALOG, SUPPORTED_PROVIDER_IDS
+ from providers.registry import PROVIDER_DESCRIPTORS, PROVIDER_FACTORIES
+
+ assert tuple(PROVIDER_CATALOG.keys()) == SUPPORTED_PROVIDER_IDS
+ assert PROVIDER_DESCRIPTORS is PROVIDER_CATALOG
+ assert set(SUPPORTED_PROVIDER_IDS) == set(PROVIDER_FACTORIES)
+
+
def test_config_does_not_import_non_config_packages() -> None:
"""Settings and env handling must not depend on transport or protocol layers."""
repo_root = Path(__file__).resolve().parents[2]
@@ -71,14 +80,34 @@ def test_config_does_not_import_non_config_packages() -> None:
assert offenders == []
-def test_messaging_does_not_import_api_or_cli_or_providers() -> None:
- """Messaging is wired by ``api.runtime``; must not import server or provider adapters."""
+_MESSAGING_ALLOWED_PROVIDER_MODULES = frozenset({"providers.nvidia_nim.voice"})
+
+
+def test_messaging_does_not_import_disallowed_modules() -> None:
+ """Messaging is wired by ``api.runtime``; narrow provider imports only for NIM voice ASR."""
repo_root = Path(__file__).resolve().parents[2]
- offenders = _imports_matching(
- [repo_root / "messaging"],
- forbidden_prefixes=("api.", "cli.", "providers.", "smoke."),
- )
- assert offenders == []
+ offenders: list[str] = []
+ for path in (repo_root / "messaging").rglob("*.py"):
+ for imported in _imports_from(path, repo_root):
+ if imported is None:
+ continue
+ if (
+ imported == "api"
+ or imported.startswith("api.")
+ or imported == "cli"
+ or imported.startswith("cli.")
+ or imported == "smoke"
+ or imported.startswith("smoke.")
+ ):
+ rel = path.relative_to(repo_root)
+ offenders.append(f"{rel}: {imported}")
+ elif imported.startswith("providers."):
+ if imported in _MESSAGING_ALLOWED_PROVIDER_MODULES:
+ continue
+ rel = path.relative_to(repo_root)
+ offenders.append(f"{rel}: {imported}")
+
+ assert sorted(offenders) == []
def test_api_may_only_import_narrow_provider_facade() -> None:
diff --git a/tests/contracts/test_smoke_sse_reexport.py b/tests/contracts/test_smoke_sse_reexport.py
deleted file mode 100644
index 4190312..0000000
--- a/tests/contracts/test_smoke_sse_reexport.py
+++ /dev/null
@@ -1,11 +0,0 @@
-"""Ensure smoke re-exports stay aligned with :mod:`core.anthropic.stream_contracts`."""
-
-from __future__ import annotations
-
-import core.anthropic.stream_contracts as core_sc
-import smoke.lib.sse as smoke_sse
-
-
-def test_smoke_lib_sse_reexports_core_stream_contracts() -> None:
- for name in smoke_sse.__all__:
- assert getattr(smoke_sse, name) is getattr(core_sc, name)
diff --git a/tests/contracts/test_stream_contracts.py b/tests/contracts/test_stream_contracts.py
index 0125459..6898cc3 100644
--- a/tests/contracts/test_stream_contracts.py
+++ b/tests/contracts/test_stream_contracts.py
@@ -8,16 +8,14 @@ from __future__ import annotations
from collections.abc import Iterable
from core.anthropic import ContentType, HeuristicToolParser, SSEBuilder, ThinkTagParser
+from core.anthropic.sse import format_sse_event
from core.anthropic.stream_contracts import (
assert_anthropic_stream_contract,
event_names,
- has_tool_use,
parse_sse_text,
text_content,
thinking_content,
)
-from messaging.event_parser import parse_cli_event
-from messaging.transcript import RenderCtx, TranscriptBuffer
def test_interleaved_thinking_text_blocks_are_valid() -> None:
@@ -59,44 +57,49 @@ def test_mixed_reasoning_content_and_think_tags_keep_order() -> None:
assert text_content(events) == " visible done"
-def test_thinking_tool_text_and_transcript_order_contract() -> None:
- builder = SSEBuilder("msg_contract", "contract-model")
- chunks = [builder.message_start()]
- chunks.extend(builder.ensure_thinking_block())
- chunks.append(builder.emit_thinking_delta("inspect first"))
- chunks.extend(builder.close_content_blocks())
- tool_block_index = builder.blocks.allocate_index()
- chunks.append(
- builder.content_block_start(
- tool_block_index, "tool_use", id="toolu_1", name="Read"
- )
- )
- chunks.append(
- builder.content_block_delta(
- tool_block_index, "input_json_delta", '{"file":"README.md"}'
- )
- )
- chunks.append(builder.content_block_stop(tool_block_index))
- chunks.extend(builder.ensure_text_block())
- chunks.append(builder.emit_text_delta("done"))
- chunks.extend(builder.close_all_blocks())
- chunks.append(builder.message_delta("end_turn", 20))
- chunks.append(builder.message_stop())
-
+def test_redacted_thinking_block_start_stop_is_valid() -> None:
+ """Native redacted_thinking uses start/stop only (no deltas)."""
+ chunks = [
+ format_sse_event(
+ "message_start",
+ {
+ "type": "message_start",
+ "message": {
+ "id": "msg_r",
+ "type": "message",
+ "role": "assistant",
+ "content": [],
+ "model": "m",
+ "stop_reason": None,
+ "stop_sequence": None,
+ "usage": {"input_tokens": 1, "output_tokens": 1},
+ },
+ },
+ ),
+ format_sse_event(
+ "content_block_start",
+ {
+ "type": "content_block_start",
+ "index": 0,
+ "content_block": {"type": "redacted_thinking", "data": "opaque"},
+ },
+ ),
+ format_sse_event(
+ "content_block_stop",
+ {"type": "content_block_stop", "index": 0},
+ ),
+ format_sse_event(
+ "message_delta",
+ {
+ "type": "message_delta",
+ "delta": {"stop_reason": "end_turn", "stop_sequence": None},
+ "usage": {"input_tokens": 1, "output_tokens": 2},
+ },
+ ),
+ format_sse_event("message_stop", {"type": "message_stop"}),
+ ]
events = parse_sse_text("".join(chunks))
assert_anthropic_stream_contract(events)
- assert has_tool_use(events)
-
- transcript = TranscriptBuffer()
- for event in events:
- for parsed in parse_cli_event(event.data):
- transcript.apply(parsed)
- rendered = transcript.render(_render_ctx(), limit_chars=3900, status=None)
- assert (
- rendered.find("inspect first")
- < rendered.find("Tool call:")
- < rendered.find("done")
- )
def test_enable_thinking_false_suppresses_reasoning_only() -> None:
@@ -186,13 +189,3 @@ def _emit_parser_parts(
def _parse_builder_events(chunks: Iterable[str]):
return parse_sse_text("".join(chunks))
-
-
-def _render_ctx() -> RenderCtx:
- return RenderCtx(
- bold=lambda text: f"*{text}*",
- code_inline=lambda text: f"`{text}`",
- escape_code=lambda text: text,
- escape_text=lambda text: text,
- render_markdown=lambda text: text,
- )
diff --git a/tests/core/anthropic/test_native_sse_block_policy.py b/tests/core/anthropic/test_native_sse_block_policy.py
new file mode 100644
index 0000000..f1b6658
--- /dev/null
+++ b/tests/core/anthropic/test_native_sse_block_policy.py
@@ -0,0 +1,133 @@
+"""Unit tests for shared native Anthropic SSE thinking policy / block remapping."""
+
+from __future__ import annotations
+
+import json
+
+from core.anthropic.native_sse_block_policy import (
+ NativeSseBlockPolicyState,
+ format_native_sse_event,
+ transform_native_sse_block_event,
+)
+
+
+def test_thinking_start_dropped_when_disabled() -> None:
+ st = NativeSseBlockPolicyState()
+ payload = {
+ "type": "content_block_start",
+ "index": 0,
+ "content_block": {"type": "thinking", "thinking": ""},
+ }
+ ev = format_native_sse_event(
+ "content_block_start",
+ json.dumps(payload),
+ )
+ assert transform_native_sse_block_event(ev, st, thinking_enabled=False) is None
+
+
+def test_thinking_delta_dropped_when_disabled() -> None:
+ st = NativeSseBlockPolicyState()
+    # No prior content_block_start in the stream (OpenRouter-style); the transform returns None when thinking is off.
+ payload = {
+ "type": "content_block_delta",
+ "index": 0,
+ "delta": {"type": "thinking_delta", "thinking": "secret"},
+ }
+ ev = format_native_sse_event("content_block_delta", json.dumps(payload))
+ assert transform_native_sse_block_event(ev, st, thinking_enabled=False) is None
+
+
+def test_text_block_passthrough_when_thinking_disabled() -> None:
+ st = NativeSseBlockPolicyState()
+ payload = {
+ "type": "content_block_start",
+ "index": 0,
+ "content_block": {"type": "text", "text": ""},
+ }
+ ev = format_native_sse_event("content_block_start", json.dumps(payload))
+ out = transform_native_sse_block_event(ev, st, thinking_enabled=False)
+ assert out is not None
+ assert '"index": 0' in (out or "")
+
+
+def test_interleaved_thinking_signature_delta_remaps_to_reopened_block_index() -> None:
+ """After text interrupts thinking, signature_delta must follow the reopened segment index."""
+ st = NativeSseBlockPolicyState()
+
+ def run(ev: str) -> str | None:
+ return transform_native_sse_block_event(ev, st, thinking_enabled=True)
+
+ out1 = run(
+ format_native_sse_event(
+ "content_block_start",
+ json.dumps(
+ {
+ "type": "content_block_start",
+ "index": 0,
+ "content_block": {"type": "thinking", "thinking": ""},
+ }
+ ),
+ )
+ )
+ assert out1 is not None and '"index": 0' in out1
+
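+    # A text block at upstream index 1 interrupts the open thinking block.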
+ out2 = run(
+ format_native_sse_event(
+ "content_block_start",
+ json.dumps(
+ {
+ "type": "content_block_start",
+ "index": 1,
+ "content_block": {"type": "text", "text": ""},
+ }
+ ),
+ )
+ )
+ assert out2 is not None
+
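+    # Upstream reuses index 0 for more thinking; the policy reopens a fresh segment.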
+ out3 = run(
+ format_native_sse_event(
+ "content_block_delta",
+ json.dumps(
+ {
+ "type": "content_block_delta",
+ "index": 0,
+ "delta": {"type": "thinking_delta", "thinking": "plan"},
+ }
+ ),
+ )
+ )
+ assert out3 is not None
+ assert "content_block_start" in out3
+
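+    # The signature delta must land on the reopened segment, not the stale index 0.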
+ out4 = run(
+ format_native_sse_event(
+ "content_block_delta",
+ json.dumps(
+ {
+ "type": "content_block_delta",
+ "index": 0,
+ "delta": {"type": "signature_delta", "signature": "sig"},
+ }
+ ),
+ )
+ )
+ assert out4 is not None
+ assert '"index": 2' in out4
+ assert "signature_delta" in out4
+
+
+def test_startless_text_delta_synthesizes_start_when_thinking_disabled() -> None:
+ """Startless text deltas must not be dropped when thinking is disabled (OpenRouter)."""
+ st = NativeSseBlockPolicyState()
+ payload = {
+ "type": "content_block_delta",
+ "index": 0,
+ "delta": {"type": "text_delta", "text": "Hello"},
+ }
+ ev = format_native_sse_event("content_block_delta", json.dumps(payload))
+ out = transform_native_sse_block_event(ev, st, thinking_enabled=False)
+ assert out is not None
+ assert "content_block_start" in (out or "")
+ assert "Hello" in (out or "")
+ assert "text_delta" in (out or "")
diff --git a/tests/core/test_strict_sliding_window.py b/tests/core/test_strict_sliding_window.py
new file mode 100644
index 0000000..1607521
--- /dev/null
+++ b/tests/core/test_strict_sliding_window.py
@@ -0,0 +1,38 @@
+"""Direct tests for :class:`core.rate_limit.StrictSlidingWindowLimiter`."""
+
+import time
+
+import pytest
+
+from core.rate_limit import StrictSlidingWindowLimiter
+
+
+@pytest.mark.asyncio
+async def test_strict_window_allows_burst_then_blocks():
+ lim = StrictSlidingWindowLimiter(rate_limit=2, rate_window=0.2)
+ await lim.acquire()
+ await lim.acquire()
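+    # The third acquire must wait until the oldest timestamp leaves the 0.2 s window.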
+ start = time.monotonic()
+ await lim.acquire()
+ assert time.monotonic() - start >= 0.15
+
+
+@pytest.mark.asyncio
+async def test_strict_window_async_context_manager():
+ lim = StrictSlidingWindowLimiter(rate_limit=1, rate_window=0.15)
+
+ async def run():
+ async with lim:
+ pass
+
+ await run()
+ start = time.monotonic()
+ await run()
+ assert time.monotonic() - start >= 0.1
+
+
+def test_strict_window_rejects_invalid_config():
+ with pytest.raises(ValueError):
+ StrictSlidingWindowLimiter(rate_limit=0, rate_window=1.0)
+ with pytest.raises(ValueError):
+ StrictSlidingWindowLimiter(rate_limit=1, rate_window=0.0)
diff --git a/tests/messaging/test_handler.py b/tests/messaging/test_handler.py
index 7695dd0..013b096 100644
--- a/tests/messaging/test_handler.py
+++ b/tests/messaging/test_handler.py
@@ -13,6 +13,49 @@ def handler(mock_platform, mock_cli_manager, mock_session_store):
return ClaudeMessageHandler(mock_platform, mock_cli_manager, mock_session_store)
+@pytest.mark.asyncio
+async def test_handle_message_default_logs_text_len_not_content(
+ mock_platform, mock_cli_manager, mock_session_store, incoming_message_factory
+):
+ secret = "user-secret-content-never-log-default"
+ handler = ClaudeMessageHandler(
+ mock_platform,
+ mock_cli_manager,
+ mock_session_store,
+ log_raw_messaging_content=False,
+ )
+ incoming = incoming_message_factory(text=secret)
+ with (
+ patch.object(handler, "_handle_message_impl", new_callable=AsyncMock),
+ patch("messaging.handler.logger.info") as log_info,
+ ):
+ await handler.handle_message(incoming)
+ blob = " ".join(str(c) for c in log_info.call_args_list)
+ assert secret not in blob
+ assert "text_len=" in blob
+
+
+@pytest.mark.asyncio
+async def test_handle_message_raw_content_logging_includes_preview(
+ mock_platform, mock_cli_manager, mock_session_store, incoming_message_factory
+):
+ secret = "visible-preview-xyz"
+ handler = ClaudeMessageHandler(
+ mock_platform,
+ mock_cli_manager,
+ mock_session_store,
+ log_raw_messaging_content=True,
+ )
+ incoming = incoming_message_factory(text=secret)
+ with (
+ patch.object(handler, "_handle_message_impl", new_callable=AsyncMock),
+ patch("messaging.handler.logger.info") as log_info,
+ ):
+ await handler.handle_message(incoming)
+ blob = " ".join(str(c) for c in log_info.call_args_list)
+ assert secret in blob
+
+
def test_get_initial_status_new_conversation(handler):
"""New conversation always returns launching message."""
result = handler._get_initial_status(None, None)
diff --git a/tests/messaging/test_handler_format.py b/tests/messaging/test_handler_format.py
index 2cf60ee..cede9fb 100644
--- a/tests/messaging/test_handler_format.py
+++ b/tests/messaging/test_handler_format.py
@@ -17,7 +17,6 @@ def handler():
platform = MagicMock()
cli = MagicMock()
store = MagicMock()
- # Kept for backwards test structure; transcript rendering is now separate.
return (platform, cli, store)
diff --git a/tests/messaging/test_handler_markdown_and_status_edges.py b/tests/messaging/test_handler_markdown_and_status_edges.py
index 185e8b1..662c6ff 100644
--- a/tests/messaging/test_handler_markdown_and_status_edges.py
+++ b/tests/messaging/test_handler_markdown_and_status_edges.py
@@ -4,6 +4,7 @@ import pytest
from messaging.handler import ClaudeMessageHandler
from messaging.models import IncomingMessage
+from messaging.node_event_pipeline import process_parsed_cli_event
from messaging.rendering.telegram_markdown import render_markdown_to_mdv2
from messaging.trees.data import MessageNode, MessageState
@@ -333,7 +334,7 @@ async def test_handle_message_incoming_text_none_safe():
@pytest.mark.asyncio
async def test_process_parsed_event_malformed_content_continues():
- """Malformed/unknown parsed event does not crash _process_parsed_event."""
+ """Malformed/unknown parsed event does not crash process_parsed_cli_event."""
platform = MagicMock()
platform.queue_edit_message = AsyncMock()
@@ -344,7 +345,7 @@ async def test_process_parsed_event_malformed_content_continues():
transcript = MagicMock()
update_ui = AsyncMock()
- last_status, had = await handler._process_parsed_event(
+ last_status, had = await process_parsed_cli_event(
parsed={"type": "unknown_type"},
transcript=transcript,
update_ui=update_ui,
@@ -353,6 +354,9 @@ async def test_process_parsed_event_malformed_content_continues():
tree=None,
node_id="n1",
captured_session_id=None,
+ session_store=session_store,
+ format_status=handler.format_status,
+ propagate_error_to_children=AsyncMock(),
)
assert last_status is None
assert had is False
diff --git a/tests/messaging/test_limiter.py b/tests/messaging/test_limiter.py
index 565e80f..0882c03 100644
--- a/tests/messaging/test_limiter.py
+++ b/tests/messaging/test_limiter.py
@@ -1,16 +1,10 @@
import asyncio
-import os
+import contextlib
import time
import pytest
import pytest_asyncio
-# Set environment variables relative to test execution
-os.environ["MESSAGING_RATE_LIMIT"] = "1"
-os.environ["MESSAGING_RATE_WINDOW"] = "0.5"
-
-import contextlib
-
from messaging.limiter import MessagingRateLimiter
@@ -19,11 +13,8 @@ class TestMessagingRateLimiter:
@pytest_asyncio.fixture(autouse=True)
async def reset_limiter(self):
- """Reset singleton and environment before each test."""
- # Ensure the singleton worker is stopped between tests to avoid dangling tasks.
+ """Reset singleton before each test."""
await MessagingRateLimiter.shutdown_instance(timeout=0.1)
- os.environ["MESSAGING_RATE_LIMIT"] = "1"
- os.environ["MESSAGING_RATE_WINDOW"] = "0.5"
yield
@@ -32,9 +23,16 @@ class TestMessagingRateLimiter:
@pytest.mark.asyncio
async def test_singleton_pattern(self):
"""Test that get_instance returns the same object."""
- limiter1 = await MessagingRateLimiter.get_instance()
- limiter2 = await MessagingRateLimiter.get_instance()
+ limiter1 = await MessagingRateLimiter.get_instance(
+ rate_limit=1, rate_window=0.5
+ )
+ limiter2 = await MessagingRateLimiter.get_instance(
+ rate_limit=99, rate_window=99.0
+ )
assert limiter1 is limiter2
+ # First-construction wins for rate parameters
+ assert limiter1.limiter._rate_limit == 1
+ assert limiter1.limiter._rate_window == 0.5
@pytest.mark.asyncio
async def test_compaction(self):
@@ -42,13 +40,8 @@ class TestMessagingRateLimiter:
Verify multiple rapid requests with same dedup_key are compacted.
Logic ported from verify_limiter.py
"""
- # Set slow rate for testing compaction
- os.environ["MESSAGING_RATE_LIMIT"] = "1"
- os.environ["MESSAGING_RATE_WINDOW"] = "1.0"
-
- # Must reset instance to pick up new env vars
- MessagingRateLimiter._instance = None
- limiter = await MessagingRateLimiter.get_instance()
+ await MessagingRateLimiter.shutdown_instance(timeout=0.1)
+ limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=1.0)
call_counts = {}
@@ -78,10 +71,8 @@ class TestMessagingRateLimiter:
Verify that even when compacted, all futures resolve to the result of the LAST execution.
Logic ported from verify_limiter_v2.py
"""
- os.environ["MESSAGING_RATE_LIMIT"] = "1"
- os.environ["MESSAGING_RATE_WINDOW"] = "0.5"
- MessagingRateLimiter._instance = None
- limiter = await MessagingRateLimiter.get_instance()
+ await MessagingRateLimiter.shutdown_instance(timeout=0.1)
+ limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=0.5)
call_counts = {}
msg_id = "test_msg_hang"
@@ -116,8 +107,8 @@ class TestMessagingRateLimiter:
@pytest.mark.asyncio
async def test_flood_wait_handling(self):
"""Test that FloodWait exceptions pause the worker."""
- MessagingRateLimiter._instance = None
- limiter = await MessagingRateLimiter.get_instance()
+ await MessagingRateLimiter.shutdown_instance(timeout=0.1)
+ limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=1.0)
# Mock exception with .seconds attribute
class FloodWait(Exception):
@@ -157,8 +148,8 @@ class TestMessagingRateLimiter:
@pytest.mark.asyncio
async def test_flood_wait_retry_after_parsing(self):
"""Error message with 'retry after N' parses the wait seconds."""
- MessagingRateLimiter._instance = None
- limiter = await MessagingRateLimiter.get_instance()
+ await MessagingRateLimiter.shutdown_instance(timeout=0.1)
+ limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=1.0)
async def mock_flood():
raise Exception("Flood wait: retry after 2 seconds")
@@ -172,8 +163,8 @@ class TestMessagingRateLimiter:
@pytest.mark.asyncio
async def test_non_flood_exception_no_pause(self):
"""Non-flood exception doesn't trigger pause."""
- MessagingRateLimiter._instance = None
- limiter = await MessagingRateLimiter.get_instance()
+ await MessagingRateLimiter.shutdown_instance(timeout=0.1)
+ limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=1.0)
async def mock_error():
raise ValueError("some regular error")
@@ -187,8 +178,8 @@ class TestMessagingRateLimiter:
@pytest.mark.asyncio
async def test_flood_with_seconds_attribute(self):
"""Exception with .seconds attribute uses that value for pause."""
- MessagingRateLimiter._instance = None
- limiter = await MessagingRateLimiter.get_instance()
+ await MessagingRateLimiter.shutdown_instance(timeout=0.1)
+ limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=1.0)
class FloodWaitCustom(Exception):
def __init__(self):
@@ -209,10 +200,8 @@ class TestMessagingRateLimiter:
Proactive limiter should enforce a strict sliding window:
for any i, t[i+rate_limit] - t[i] >= rate_window (within tolerance).
"""
- os.environ["MESSAGING_RATE_LIMIT"] = "2"
- os.environ["MESSAGING_RATE_WINDOW"] = "0.5"
- MessagingRateLimiter._instance = None
- limiter = await MessagingRateLimiter.get_instance()
+ await MessagingRateLimiter.shutdown_instance(timeout=0.1)
+ limiter = await MessagingRateLimiter.get_instance(rate_limit=2, rate_window=0.5)
async def acquire(i: int) -> float:
async def _do() -> float:
@@ -235,8 +224,8 @@ class TestMessagingRateLimiter:
@pytest.mark.asyncio
async def test_compaction_last_task_fails_all_futures_get_exception(self):
"""When compacted task's last func fails, all futures get the exception."""
- MessagingRateLimiter._instance = None
- limiter = await MessagingRateLimiter.get_instance()
+ await MessagingRateLimiter.shutdown_instance(timeout=0.1)
+ limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=1.0)
async def ok_task():
return "ok"
@@ -255,8 +244,8 @@ class TestMessagingRateLimiter:
@pytest.mark.asyncio
async def test_fire_and_forget_failure_logged(self, caplog):
"""fire_and_forget with failing task logs error and does not re-raise."""
- MessagingRateLimiter._instance = None
- limiter = await MessagingRateLimiter.get_instance()
+ await MessagingRateLimiter.shutdown_instance(timeout=0.1)
+ limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=1.0)
async def fail_task():
raise ValueError("fire_and_forget failed")
@@ -264,4 +253,6 @@ class TestMessagingRateLimiter:
limiter.fire_and_forget(fail_task, dedup_key="fire_fail")
await asyncio.sleep(1.5)
- assert any("fire_and_forget failed" in str(r) for r in caplog.records)
+ joined = " ".join(str(r.message) for r in caplog.records)
+ assert "ValueError" in joined
+ assert "fire_and_forget failed" not in joined
diff --git a/tests/messaging/test_messaging_factory.py b/tests/messaging/test_messaging_factory.py
index 5d3c0a3..a2ed0cd 100644
--- a/tests/messaging/test_messaging_factory.py
+++ b/tests/messaging/test_messaging_factory.py
@@ -41,6 +41,10 @@ class TestCreateMessagingPlatform:
whisper_device="cuda",
hf_token="",
nvidia_nim_api_key="",
+ messaging_rate_limit=1,
+ messaging_rate_window=1.0,
+ log_raw_messaging_content=False,
+ log_api_error_tracebacks=False,
)
def test_telegram_without_token(self):
@@ -85,6 +89,10 @@ class TestCreateMessagingPlatform:
whisper_device="nvidia_nim",
hf_token="",
nvidia_nim_api_key="",
+ messaging_rate_limit=1,
+ messaging_rate_window=1.0,
+ log_raw_messaging_content=False,
+ log_api_error_tracebacks=False,
)
def test_discord_without_token(self):
diff --git a/tests/messaging/test_session_store_edge_cases.py b/tests/messaging/test_session_store_edge_cases.py
index 31683d4..8656c36 100644
--- a/tests/messaging/test_session_store_edge_cases.py
+++ b/tests/messaging/test_session_store_edge_cases.py
@@ -81,15 +81,40 @@ class TestSessionStoreSaveEdgeCases:
"""Tests for save failure handling."""
def test_save_io_error_handled(self, tmp_store):
- """Write failure in _write_data() raises (callers handle the error)."""
+ """Write failure during atomic replace is surfaced to callers."""
tmp_store.save_tree("r1", {"root_id": "r1", "nodes": {"r1": {}}})
with (
- patch("builtins.open", side_effect=OSError("disk full")),
+ patch("messaging.session.os.replace", side_effect=OSError("disk full")),
pytest.raises(OSError),
):
tmp_store._write_data(tmp_store._snapshot())
+class TestSessionStoreAtomicWrites:
+ """Atomic persistence: failed replace must not truncate the prior file."""
+
+ def test_failed_replace_keeps_prior_bytes_and_marks_dirty(self, tmp_path):
+ path = str(tmp_path / "sessions.json")
+ store = SessionStore(storage_path=path)
+ store.save_tree("r1", {"root_id": "r1", "nodes": {"r1": {}}})
+ store.flush_pending_save()
+ with open(path, encoding="utf-8") as f:
+ disk_after_first = f.read()
+
+ store.save_tree("r2", {"root_id": "r2", "nodes": {"r2": {}}})
+
+ with patch(
+ "messaging.session.os.replace", side_effect=OSError("replace failed")
+ ):
+ store.flush_pending_save()
+
+ with open(path, encoding="utf-8") as f:
+ disk_after_failed = f.read()
+ assert disk_after_failed == disk_after_first
+ assert store._dirty is True
+ assert store.get_tree("r2") is not None
+
+
class TestSessionStoreClearAll:
def test_clear_all_wipes_state_and_persists(self, tmp_path):
path = str(tmp_path / "sessions.json")
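
`TestSessionStoreAtomicWrites` encodes the classic atomic-write pattern: serialize to a sibling temp file, then `os.replace()` onto the target, so a failed replace leaves the previous bytes on disk and the store stays dirty for a later retry. A self-contained sketch of that pattern (function name is illustrative):

# Atomic-persistence sketch matching the tests above.
import json
import os
import tempfile


def write_json_atomically(path: str, data: dict) -> None:
    directory = os.path.dirname(path) or "."
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(data, f)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, path)  # the only step that ever touches `path`
    except BaseException:
        try:
            os.unlink(tmp_path)  # best-effort cleanup; target is untouched
        except OSError:
            pass
        raise
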
diff --git a/tests/messaging/test_stream_transcript_contract.py b/tests/messaging/test_stream_transcript_contract.py
new file mode 100644
index 0000000..2ef2b60
--- /dev/null
+++ b/tests/messaging/test_stream_transcript_contract.py
@@ -0,0 +1,62 @@
+"""Messaging-specific assertions built on neutral Anthropic stream contracts."""
+
+from __future__ import annotations
+
+from core.anthropic import SSEBuilder
+from core.anthropic.stream_contracts import (
+ assert_anthropic_stream_contract,
+ has_tool_use,
+ parse_sse_text,
+)
+from messaging.event_parser import parse_cli_event
+from messaging.transcript import RenderCtx, TranscriptBuffer
+
+
+def test_thinking_tool_text_and_transcript_order_contract() -> None:
+ builder = SSEBuilder("msg_contract", "contract-model")
+ chunks = [builder.message_start()]
+ chunks.extend(builder.ensure_thinking_block())
+ chunks.append(builder.emit_thinking_delta("inspect first"))
+ chunks.extend(builder.close_content_blocks())
+ tool_block_index = builder.blocks.allocate_index()
+ chunks.append(
+ builder.content_block_start(
+ tool_block_index, "tool_use", id="toolu_1", name="Read"
+ )
+ )
+ chunks.append(
+ builder.content_block_delta(
+ tool_block_index, "input_json_delta", '{"file":"README.md"}'
+ )
+ )
+ chunks.append(builder.content_block_stop(tool_block_index))
+ chunks.extend(builder.ensure_text_block())
+ chunks.append(builder.emit_text_delta("done"))
+ chunks.extend(builder.close_all_blocks())
+ chunks.append(builder.message_delta("end_turn", 20))
+ chunks.append(builder.message_stop())
+
+ events = parse_sse_text("".join(chunks))
+ assert_anthropic_stream_contract(events)
+ assert has_tool_use(events)
+
+ transcript = TranscriptBuffer()
+ for event in events:
+ for parsed in parse_cli_event(event.data):
+ transcript.apply(parsed)
+ rendered = transcript.render(_render_ctx(), limit_chars=3900, status=None)
+ assert (
+ rendered.find("inspect first")
+ < rendered.find("Tool call:")
+ < rendered.find("done")
+ )
+
+
+def _render_ctx() -> RenderCtx:
+ return RenderCtx(
+ bold=lambda s: f"*{s}*",
+ code_inline=lambda s: f"`{s}`",
+ escape_code=lambda s: s,
+ escape_text=lambda s: s,
+ render_markdown=lambda s: s,
+ )
diff --git a/tests/messaging/test_transcription.py b/tests/messaging/test_transcription.py
index 6aaaca4..023f723 100644
--- a/tests/messaging/test_transcription.py
+++ b/tests/messaging/test_transcription.py
@@ -25,7 +25,7 @@ def test_transcribe_file_too_large_raises():
path = Path(f.name)
try:
with pytest.raises(ValueError, match="too large"):
- transcribe_audio(path, "audio/ogg", whisper_device="auto")
+ transcribe_audio(path, "audio/ogg", whisper_device="cpu")
finally:
path.unlink(missing_ok=True)
@@ -87,15 +87,9 @@ def test_transcribe_invalid_device_raises():
f.write(b"fake ogg")
path = Path(f.name)
try:
- # Mock settings to return invalid device "auto"
- mock_settings = MagicMock()
- mock_settings.whisper_device = "auto"
- mock_settings.whisper_model = "base"
-
# Patch _load_audio to avoid ImportError from missing librosa
# Device validation happens in _get_pipeline before torch import
with (
- patch("messaging.transcription.get_settings", return_value=mock_settings),
patch("messaging.transcription._load_audio"),
pytest.raises(ValueError, match="whisper_device must be 'cpu' or 'cuda'"),
):
@@ -104,6 +98,24 @@ def test_transcribe_invalid_device_raises():
path.unlink(missing_ok=True)
+def test_transcribe_nim_requires_api_key():
+ """NIM path rejects empty API key without reading global settings."""
+ with tempfile.NamedTemporaryFile(suffix=".ogg", delete=False) as f:
+ f.write(b"fake ogg")
+ path = Path(f.name)
+ try:
+ with pytest.raises(ValueError, match="non-empty"):
+ transcribe_audio(
+ path,
+ "audio/ogg",
+ whisper_device="nvidia_nim",
+ whisper_model="openai/whisper-large-v3",
+ nvidia_nim_api_key="",
+ )
+ finally:
+ path.unlink(missing_ok=True)
+
+
def test_transcribe_local_import_error_raises():
"""Local backend when voice_local extra not installed raises ImportError."""
with tempfile.NamedTemporaryFile(suffix=".ogg", delete=False) as f:
@@ -120,6 +132,6 @@ def test_transcribe_local_import_error_raises():
),
pytest.raises(ImportError, match="voice_local extra"),
):
- transcribe_audio(path, "audio/ogg", whisper_device="auto")
+ transcribe_audio(path, "audio/ogg", whisper_device="cpu")
finally:
path.unlink(missing_ok=True)
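
Taken together, these transcription tests pin a device-dispatch contract: `nvidia_nim` requires a non-empty API key, anything other than `cpu`/`cuda` is rejected for the local path, and no global settings are consulted. A dispatch sketch under those assumptions; the stand-in return values replace the real local pipeline and NIM HTTP call, and the size check lives elsewhere in the shipped function.

# Dispatch sketch for the transcribe_audio contract these tests pin down.
from pathlib import Path


def dispatch_transcription(
    path: Path,
    *,
    whisper_device: str,
    whisper_model: str = "base",
    nvidia_nim_api_key: str = "",
) -> str:
    if whisper_device == "nvidia_nim":
        if not nvidia_nim_api_key:
            raise ValueError("nvidia_nim_api_key must be non-empty")
        return f"nim:{whisper_model}:{path.name}"  # stands in for the HTTP call
    if whisper_device not in ("cpu", "cuda"):
        raise ValueError("whisper_device must be 'cpu' or 'cuda'")
    return f"local:{whisper_device}:{path.name}"  # stands in for the local pipeline
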
diff --git a/tests/messaging/test_transcription_nim.py b/tests/messaging/test_transcription_nim.py
new file mode 100644
index 0000000..b15f89b
--- /dev/null
+++ b/tests/messaging/test_transcription_nim.py
@@ -0,0 +1,39 @@
+"""Tests for NVIDIA NIM voice transcription wiring."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from unittest.mock import patch
+
+from messaging.transcription import transcribe_audio
+
+
+def test_transcribe_audio_nvidia_nim_forwards_api_key(tmp_path: Path) -> None:
+ wav = tmp_path / "stub.wav"
+ wav.write_bytes(b"\x00" * 128)
+ with patch("messaging.transcription.transcribe_nvidia_nim_audio") as nim_fn:
+ nim_fn.return_value = "ok"
+ out = transcribe_audio(
+ wav,
+ "audio/wav",
+ whisper_model="openai/whisper-large-v3",
+ whisper_device="nvidia_nim",
+ nvidia_nim_api_key="test-nim-key",
+ )
+ nim_fn.assert_called_once_with(
+ wav, "openai/whisper-large-v3", api_key="test-nim-key"
+ )
+ assert out == "ok"
+
+
+def test_nim_asr_model_map_entries_are_real_function_ids() -> None:
+ from providers.nvidia_nim.voice import _NIM_ASR_MODEL_MAP
+
+ for function_id, language_code in _NIM_ASR_MODEL_MAP.values():
+ assert function_id
+ assert function_id.strip().lower() != "none"
+ # Hosted NIM function-id is a lowercase UUID string.
+ parts = function_id.split("-")
+ assert len(parts) == 5
+ assert all(p for p in parts)
+ assert language_code is not None
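
The five-part split above is a cheap shape check. If stricter validation were wanted, `uuid.UUID` both validates the format and canonicalizes case, so the "lowercase UUID" expectation in the comment could be asserted directly:

# Stricter, optional variant of the shape check: round-trip through uuid.UUID.
# str(UUID(...)) is canonical lowercase, so equality also enforces lowercase.
import uuid


def is_lowercase_uuid(function_id: str) -> bool:
    try:
        parsed = uuid.UUID(function_id)
    except ValueError:
        return False
    return str(parsed) == function_id
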
diff --git a/tests/messaging/test_tree_processor.py b/tests/messaging/test_tree_processor.py
index e6a3a03..664748d 100644
--- a/tests/messaging/test_tree_processor.py
+++ b/tests/messaging/test_tree_processor.py
@@ -6,7 +6,7 @@ import pytest
from messaging.models import IncomingMessage
from messaging.trees.data import MessageNode, MessageState, MessageTree
-from messaging.trees.processor import TreeQueueProcessor
+from messaging.trees.queue_manager import TreeQueueProcessor
@pytest.fixture
diff --git a/tests/messaging/test_tree_repository.py b/tests/messaging/test_tree_repository.py
index 2860ae6..e5bf4ef 100644
--- a/tests/messaging/test_tree_repository.py
+++ b/tests/messaging/test_tree_repository.py
@@ -4,7 +4,7 @@ import pytest
from messaging.models import IncomingMessage
from messaging.trees.data import MessageNode, MessageState, MessageTree
-from messaging.trees.repository import TreeRepository
+from messaging.trees.queue_manager import TreeRepository
@pytest.fixture
diff --git a/tests/provider_request_mocks.py b/tests/provider_request_mocks.py
new file mode 100644
index 0000000..ef2a40d
--- /dev/null
+++ b/tests/provider_request_mocks.py
@@ -0,0 +1,25 @@
+"""Shared MagicMock request objects for OpenAI-compatible provider tests."""
+
+from unittest.mock import MagicMock
+
+
+def make_openai_compat_stream_request(
+ *, model: str = "test-model", stream: bool = True
+) -> MagicMock:
+ """Minimal request stub matching :meth:`OpenAIChatTransport._build_request_body` needs."""
+ req = MagicMock()
+ req.model = model
+ req.stream = stream
+ req.messages = []
+ req.system = None
+ req.tools = None
+ req.tool_choice = None
+ req.metadata = None
+ req.max_tokens = 4096
+ req.temperature = None
+ req.top_p = None
+ req.top_k = None
+ req.stop_sequences = None
+ req.extra_body = None
+ req.thinking = None
+ return req
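
Typical use of this helper in a provider test might look like the following; the `provider` fixture and the exact `_build_request_body` signature are assumptions that vary per transport.

# Hypothetical usage of the shared stub; `provider` stands for any
# OpenAI-compatible transport fixture from the surrounding suite.
from tests.provider_request_mocks import make_openai_compat_stream_request


def test_body_uses_requested_model(provider):
    req = make_openai_compat_stream_request(model="demo-model", stream=True)
    body = provider._build_request_body(req)
    assert body["model"] == "demo-model"
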
diff --git a/tests/providers/test_anthropic_messages.py b/tests/providers/test_anthropic_messages.py
index 2d9498d..1cb57cb 100644
--- a/tests/providers/test_anthropic_messages.py
+++ b/tests/providers/test_anthropic_messages.py
@@ -6,8 +6,12 @@ from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
+from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
+from core.anthropic.sse import format_sse_event
+from core.anthropic.stream_contracts import event_index, parse_sse_text
from providers.anthropic_messages import AnthropicMessagesTransport
from providers.base import ProviderConfig
+from tests.stream_contract import assert_canonical_stream_error_envelope
class NativeProvider(AnthropicMessagesTransport):
@@ -32,8 +36,6 @@ class MockRequest:
"model": self.model,
"messages": [{"role": "user", "content": "Hello"}],
"extra_body": {"ignored": True},
- "original_model": "claude",
- "resolved_provider_model": "native/test-model",
"thinking": {"enabled": thinking_enabled},
}
@@ -42,16 +44,30 @@ class MockRequest:
class FakeResponse:
- def __init__(self, *, status_code=200, lines=None, text=""):
+ def __init__(
+ self,
+ *,
+ status_code=200,
+ lines=None,
+ text="",
+ raise_after_line_index: int | None = None,
+ ):
self.status_code = status_code
self._lines = lines or []
self._text = text
+ self._raise_after_line_index = raise_after_line_index
self.is_closed = False
self.request = httpx.Request("POST", "https://example.test/v1/messages")
+ self.headers = httpx.Headers()
async def aiter_lines(self):
- for line in self._lines:
+ for i, line in enumerate(self._lines):
yield line
+ if (
+ self._raise_after_line_index is not None
+ and i >= self._raise_after_line_index
+ ):
+ raise RuntimeError("mid-stream failure")
async def aread(self):
return self._text.encode()
@@ -67,6 +83,11 @@ class FakeResponse:
async def aclose(self):
self.is_closed = True
+ async def aiter_bytes(self, chunk_size: int = 65_536):
+ data = self._text.encode("utf-8")
+ for offset in range(0, len(data), chunk_size):
+ yield data[offset : offset + chunk_size]
+
@pytest.fixture
def provider_config():
@@ -122,10 +143,8 @@ def test_default_request_body_strips_internal_fields(provider_config):
assert body["model"] == "test-model"
assert body["thinking"] == {"type": "enabled"}
- assert body["max_tokens"] == 81920
+ assert body["max_tokens"] == ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
assert "extra_body" not in body
- assert "original_model" not in body
- assert "resolved_provider_model" not in body
def test_default_request_body_preserves_thinking_budget(provider_config):
@@ -210,7 +229,66 @@ async def test_stream_maps_non_200_to_error_event_and_closes_response(
]
assert response.is_closed
- assert len(events) == 1
- assert events[0].startswith("event: error\ndata: {")
- assert "Internal Server Error" in events[0]
- assert "REQ_123" in events[0]
+ assert_canonical_stream_error_envelope(
+ events, user_message_substr="Provider API request failed"
+ )
+ blob = "".join(events)
+ assert "REQ_123" in blob
+
+
+@pytest.mark.asyncio
+async def test_midstream_error_closes_open_block_and_uses_fresh_content_index(
+ provider_config,
+):
+ """After upstream message_start + content_block_start, synthetic errors must not reuse index 0."""
+ provider = NativeProvider(provider_config)
+ req = MockRequest()
+ mid = "msg_midstream_err"
+ msg_start = format_sse_event(
+ "message_start",
+ {
+ "type": "message_start",
+ "message": {
+ "id": mid,
+ "type": "message",
+ "role": "assistant",
+ "content": [],
+ "model": "test-model",
+ "stop_reason": None,
+ "stop_sequence": None,
+ "usage": {"input_tokens": 1, "output_tokens": 1},
+ },
+ },
+ )
+ block_start = format_sse_event(
+ "content_block_start",
+ {
+ "type": "content_block_start",
+ "index": 0,
+ "content_block": {"type": "text", "text": ""},
+ },
+ )
+ lines: list[str] = []
+ for blob in (msg_start, block_start):
+ lines.extend(blob.splitlines())
+ response = FakeResponse(lines=lines, raise_after_line_index=len(lines) - 1)
+
+ with (
+ patch.object(provider._client, "build_request", return_value=MagicMock()),
+ patch.object(
+ provider._client,
+ "send",
+ new_callable=AsyncMock,
+ return_value=response,
+ ),
+ ):
+ events = [e async for e in provider.stream_response(req)]
+
+ assert_canonical_stream_error_envelope(
+ events, user_message_substr="mid-stream failure"
+ )
+ parsed = parse_sse_text("".join(events))
+ starts = [e for e in parsed if e.event == "content_block_start"]
+ assert event_index(starts[0]) == 0
+ assert event_index(starts[-1]) == 1
+ assert {event_index(e) for e in parsed if e.event == "content_block_stop"} == {0, 1}
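
The recovery this test asserts can be sketched as: emit `content_block_stop` for every block the upstream left open, then place the synthetic error text on a freshly allocated index. `format_sse_event` is the helper imported above; the index bookkeeping shown here is an assumption about the transport's internals.

# Sketch of the mid-stream recovery: stop open blocks, then emit the error
# text on a fresh index so no upstream index is ever reused.
from core.anthropic.sse import format_sse_event


def synthesize_midstream_error(
    open_indices: list[int], next_index: int, message: str
) -> list[str]:
    events = [
        format_sse_event(
            "content_block_stop", {"type": "content_block_stop", "index": idx}
        )
        for idx in open_indices
    ]
    events.append(
        format_sse_event(
            "content_block_start",
            {
                "type": "content_block_start",
                "index": next_index,
                "content_block": {"type": "text", "text": ""},
            },
        )
    )
    events.append(
        format_sse_event(
            "content_block_delta",
            {
                "type": "content_block_delta",
                "index": next_index,
                "delta": {"type": "text_delta", "text": message},
            },
        )
    )
    events.append(
        format_sse_event(
            "content_block_stop", {"type": "content_block_stop", "index": next_index}
        )
    )
    return events
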
diff --git a/tests/providers/test_anthropic_messages_429_retry.py b/tests/providers/test_anthropic_messages_429_retry.py
new file mode 100644
index 0000000..e64bbb7
--- /dev/null
+++ b/tests/providers/test_anthropic_messages_429_retry.py
@@ -0,0 +1,125 @@
+"""Native Anthropic transport: HTTP 429 is retried inside execute_with_retry."""
+
+from contextlib import asynccontextmanager
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import httpx
+import pytest
+
+from providers.base import ProviderConfig
+from providers.rate_limit import GlobalRateLimiter
+from tests.providers.test_anthropic_messages import (
+ FakeResponse,
+ MockRequest,
+ NativeProvider,
+)
+from tests.stream_contract import assert_canonical_stream_error_envelope
+
+
+@pytest.fixture
+def provider_config():
+ return ProviderConfig(
+ api_key="test-key",
+ base_url="https://custom.test/v1/",
+ rate_limit=100,
+ rate_window=60,
+ http_read_timeout=600.0,
+ http_write_timeout=15.0,
+ http_connect_timeout=5.0,
+ )
+
+
+@pytest.mark.asyncio
+async def test_native_stream_retries_on_http_429_then_streams(provider_config):
+ """First response 429 (closed), second 200 streams; send is called twice."""
+ GlobalRateLimiter.reset_instance()
+ try:
+ provider = NativeProvider(provider_config)
+ req = MockRequest()
+ request_obj = httpx.Request("POST", "https://custom.test/v1/messages")
+ ok_lines = [
+ "event: message_start",
+ 'data: {"type":"message_start"}',
+ "",
+ ]
+ ok_response = FakeResponse(lines=ok_lines)
+ too_many = FakeResponse(status_code=429, text="rate limited")
+
+ send_calls = {"n": 0}
+
+ async def send_side_effect(*_a, **_kw):
+ send_calls["n"] += 1
+ if send_calls["n"] == 1:
+ return too_many
+ return ok_response
+
+ with (
+ patch.object(provider._client, "build_request", return_value=request_obj),
+ patch.object(
+ provider._client,
+ "send",
+ new_callable=AsyncMock,
+ side_effect=send_side_effect,
+ ),
+ patch(
+ "asyncio.sleep",
+ new_callable=AsyncMock,
+ ),
+ ):
+ events = [e async for e in provider.stream_response(req)]
+
+ assert send_calls["n"] == 2
+ assert too_many.is_closed
+ assert ok_response.is_closed
+ assert events == [
+ "event: message_start\n",
+ 'data: {"type":"message_start"}\n',
+ "\n",
+ ]
+ finally:
+ GlobalRateLimiter.reset_instance()
+
+
+@pytest.mark.asyncio
+async def test_non_429_http_error_not_retried(provider_config):
+ """HTTP 500 from upstream is not retried; single send."""
+ GlobalRateLimiter.reset_instance()
+ try:
+
+ @asynccontextmanager
+ async def _slot():
+ yield
+
+ with patch("providers.anthropic_messages.GlobalRateLimiter") as mock_gl:
+ instance = mock_gl.get_scoped_instance.return_value
+
+ async def _passthrough(fn, *args, **kwargs):
+ return await fn(*args, **kwargs)
+
+ instance.execute_with_retry = AsyncMock(side_effect=_passthrough)
+ instance.concurrency_slot.side_effect = _slot
+
+ provider = NativeProvider(provider_config)
+ req = MockRequest()
+ err = FakeResponse(status_code=500, text="Internal Server Error")
+
+ with (
+ patch.object(
+ provider._client, "build_request", return_value=MagicMock()
+ ),
+ patch.object(
+ provider._client,
+ "send",
+ new_callable=AsyncMock,
+ return_value=err,
+ ) as mock_send,
+ ):
+ events = [e async for e in provider.stream_response(req)]
+
+ mock_send.assert_awaited_once()
+ assert err.is_closed
+ assert_canonical_stream_error_envelope(
+ events, user_message_substr="Provider API request failed"
+ )
+ finally:
+ GlobalRateLimiter.reset_instance()
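
The two tests above constrain the retry policy: a 429 response is closed and the request resent, while any other status is surfaced once for error mapping. A sketch of such a loop; the shipped logic lives in `GlobalRateLimiter.execute_with_retry`, and the backoff details here are assumptions.

# Retry-policy sketch: only HTTP 429 is retried, with the failed response
# closed first; any other status is returned as-is for error mapping.
import asyncio


async def send_with_429_retry(
    send, request, *, max_attempts: int = 2, backoff_s: float = 1.0
):
    response = await send(request)
    for attempt in range(1, max_attempts):
        if response.status_code != 429:
            break
        await response.aclose()  # the test asserts the 429 response is closed
        await asyncio.sleep(backoff_s * attempt)
        response = await send(request)
    return response
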
diff --git a/tests/providers/test_converter.py b/tests/providers/test_converter.py
index 7312fbd..99aeba6 100644
--- a/tests/providers/test_converter.py
+++ b/tests/providers/test_converter.py
@@ -2,7 +2,12 @@ import json
import pytest
-from core.anthropic import AnthropicToOpenAIConverter
+from api.models.anthropic import MessagesRequest
+from core.anthropic import (
+ AnthropicToOpenAIConverter,
+ OpenAIConversionError,
+ build_base_request_body,
+)
# --- Mock Classes ---
@@ -468,3 +473,102 @@ def test_convert_multiple_tool_results():
assert len(result) == 2
assert result[0]["tool_call_id"] == "t1"
assert result[1]["tool_call_id"] == "t2"
+
+
+def test_convert_user_message_tool_result_dict_as_json():
+ content = [
+ MockBlock(
+ type="tool_result",
+ tool_use_id="t_dict",
+ content={"ok": True, "count": 2},
+ ),
+ ]
+ messages = [MockMessage("user", content)]
+ result = AnthropicToOpenAIConverter.convert_messages(messages)
+ assert result[0]["role"] == "tool"
+ assert result[0]["content"] == '{"ok": true, "count": 2}'
+
+
+def test_assistant_redacted_thinking_omitted_from_openai_chat():
+ """Opaque redacted_thinking is not materialized as content or reasoning_content."""
+ content = [
+ MockBlock(type="redacted_thinking", data="secret-opaque"),
+ MockBlock(type="text", text="Visible."),
+ ]
+ messages = [MockMessage("assistant", content)]
+ result = AnthropicToOpenAIConverter.convert_messages(
+ messages, include_thinking=True, include_reasoning_content=True
+ )
+ assert result[0]["content"] == "Visible."
+ assert "secret-opaque" not in result[0]["content"]
+ assert "reasoning_content" not in result[0]
+
+
+def test_convert_user_message_image_raises():
+ content = [
+ MockBlock(type="image", source={"type": "url", "url": "https://example.com/x"})
+ ]
+ messages = [MockMessage("user", content)]
+ with pytest.raises(OpenAIConversionError):
+ AnthropicToOpenAIConverter.convert_messages(messages)
+
+
+def test_convert_assistant_text_after_tool_use_raises():
+ content = [
+ MockBlock(type="tool_use", id="call_z", name="Read", input={}),
+ MockBlock(type="text", text="Illegal after tool"),
+ ]
+ messages = [MockMessage("assistant", content)]
+ with pytest.raises(OpenAIConversionError):
+ AnthropicToOpenAIConverter.convert_messages(messages)
+
+
+def test_openai_build_accepts_declared_native_top_level_hints() -> None:
+ """OpenAI conversion ignores known non-OpenAI hints (e.g. context_management) without 400."""
+ req = MessagesRequest.model_validate(
+ {
+ "model": "m",
+ "messages": [{"role": "user", "content": "h"}],
+ "context_management": {"edits": []},
+ "output_config": {"foo": 1},
+ "mcp_servers": [{"type": "url", "url": "https://x.com"}],
+ }
+ )
+ body = build_base_request_body(req, default_max_tokens=100)
+ assert "context_management" not in body
+ assert "output_config" not in body
+ assert "mcp_servers" not in body
+ assert body["model"] == "m"
+
+
+def test_openai_build_rejects_unknown_top_level_extras() -> None:
+ """Truly unknown keys must still be rejected (not dropped silently)."""
+ req = MessagesRequest.model_validate(
+ {
+ "model": "m",
+ "messages": [{"role": "user", "content": "h"}],
+ "experimental_client_only_passthrough": True,
+ }
+ )
+ with pytest.raises(OpenAIConversionError, match="top-level request fields"):
+ build_base_request_body(req)
+
+
+@pytest.mark.parametrize(
+ "content",
+ [
+ [MockBlock(type="server_tool_use", id="1", name="web_search", input={})],
+ [MockBlock(type="web_search_tool_result", tool_use_id="1", content=[])],
+ [
+ MockBlock(
+ type="web_fetch_tool_result",
+ tool_use_id="1",
+ content={"type": "web_fetch_result", "url": "https://a.com/x"},
+ )
+ ],
+ ],
+)
+def test_convert_assistant_server_tool_blocks_raise(content) -> None:
+ messages = [MockMessage("assistant", content)]
+ with pytest.raises(OpenAIConversionError, match="server tool"):
+ AnthropicToOpenAIConverter.convert_messages(messages)
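
The dict-valued `tool_result` case relies on nothing more exotic than `json.dumps` defaults, whose output matches the asserted string exactly:

# json.dumps with default separators produces the exact string the test expects.
import json

assert json.dumps({"ok": True, "count": 2}) == '{"ok": true, "count": 2}'
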
diff --git a/tests/providers/test_deepseek.py b/tests/providers/test_deepseek.py
index bb00e0a..7a4d997 100644
--- a/tests/providers/test_deepseek.py
+++ b/tests/providers/test_deepseek.py
@@ -4,8 +4,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
+from api.models.anthropic import ContentBlockImage, Message, MessagesRequest
from providers.base import ProviderConfig
-from providers.deepseek import DEEPSEEK_BASE_URL, DeepSeekProvider
+from providers.deepseek import DEEPSEEK_DEFAULT_BASE, DeepSeekProvider
+from providers.exceptions import InvalidRequestError
class MockMessage:
@@ -41,7 +43,7 @@ class MockRequest:
def deepseek_config():
return ProviderConfig(
api_key="test_deepseek_key",
- base_url=DEEPSEEK_BASE_URL,
+ base_url=DEEPSEEK_DEFAULT_BASE,
rate_limit=10,
rate_window=60,
enable_thinking=True,
@@ -72,7 +74,7 @@ def test_init(deepseek_config):
with patch("providers.openai_compat.AsyncOpenAI") as mock_openai:
provider = DeepSeekProvider(deepseek_config)
assert provider._api_key == "test_deepseek_key"
- assert provider._base_url == DEEPSEEK_BASE_URL
+ assert provider._base_url == DEEPSEEK_DEFAULT_BASE
mock_openai.assert_called_once()
@@ -91,7 +93,7 @@ def test_build_request_body_global_disable_blocks_request_thinking():
provider = DeepSeekProvider(
ProviderConfig(
api_key="test_deepseek_key",
- base_url=DEEPSEEK_BASE_URL,
+ base_url=DEEPSEEK_DEFAULT_BASE,
rate_limit=10,
rate_window=60,
enable_thinking=False,
@@ -152,6 +154,67 @@ def test_build_request_body_preserves_reasoning_content(deepseek_provider):
assert body["messages"][0]["reasoning_content"] == "First think"
+def test_build_request_body_disabled_thinking_omits_reasoning_and_thinking_tags():
+ """Resolved thinking policy must strip assistant thinking from OpenAI history."""
+ provider = DeepSeekProvider(
+ ProviderConfig(
+ api_key="test_deepseek_key",
+ base_url=DEEPSEEK_DEFAULT_BASE,
+ rate_limit=10,
+ rate_window=60,
+ enable_thinking=False,
+ )
+ )
+ req = MockRequest(
+ system=None,
+ model="deepseek-chat",
+ messages=[
+ MockMessage(
+ "assistant",
+ [
+ MockBlock(type="thinking", thinking="secret"),
+ MockBlock(type="text", text="hi"),
+ ],
+ )
+ ],
+ )
+
+ body = provider._build_request_body(req)
+
+ assistant = body["messages"][0]
+ assert "reasoning_content" not in assistant
+ assert "secret" not in assistant["content"]
+ assert assistant["content"] == "hi"
+
+
+def test_build_request_body_disabled_thinking_omits_redacted_blocks():
+ """redacted_thinking is not sent as OpenAI text when thinking is disabled."""
+ provider = DeepSeekProvider(
+ ProviderConfig(
+ api_key="test_deepseek_key",
+ base_url=DEEPSEEK_DEFAULT_BASE,
+ rate_limit=10,
+ rate_window=60,
+ enable_thinking=False,
+ )
+ )
+ req = MockRequest(
+ system=None,
+ model="deepseek-chat",
+ messages=[
+ MockMessage(
+ "assistant",
+ [
+ MockBlock(type="redacted_thinking", data="opaque-xyz"),
+ MockBlock(type="text", text="hi"),
+ ],
+ )
+ ],
+ )
+ body = provider._build_request_body(req)
+ assert "opaque-xyz" not in body["messages"][0]["content"]
+
+
@pytest.mark.asyncio
async def test_stream_response_reasoning_content(deepseek_provider):
"""reasoning_content deltas are emitted as thinking blocks."""
@@ -179,3 +242,37 @@ async def test_stream_response_reasoning_content(deepseek_provider):
assert any(
'"thinking_delta"' in event and "Thinking..." in event for event in events
)
+
+
+def test_preflight_stream_rejects_unsupported_user_image_for_openai_conversion():
+ """Eager preflight: image block fails before a stream would be opened."""
+ request = MessagesRequest(
+ model="deepseek/deepseek-chat",
+ max_tokens=100,
+ messages=[
+ Message(
+ role="user",
+ content=[
+ ContentBlockImage(
+ type="image",
+ source={
+ "type": "base64",
+ "media_type": "image/png",
+ "data": "YQ==",
+ },
+ )
+ ],
+ )
+ ],
+ )
+ provider = DeepSeekProvider(
+ ProviderConfig(
+ api_key="k",
+ base_url=DEEPSEEK_DEFAULT_BASE,
+ rate_limit=10,
+ rate_window=60,
+ )
+ )
+ with pytest.raises(InvalidRequestError) as exc:
+ provider.preflight_stream(request, thinking_enabled=True)
+ assert "image" in str(exc.value).lower()
diff --git a/tests/providers/test_error_mapping.py b/tests/providers/test_error_mapping.py
index 162b05d..2a180e7 100644
--- a/tests/providers/test_error_mapping.py
+++ b/tests/providers/test_error_mapping.py
@@ -1,13 +1,21 @@
"""Tests for provider error mapping and core error formatting."""
+from pathlib import Path
from unittest.mock import MagicMock, patch
import openai
import pytest
from httpx import ReadTimeout, Request, Response
-from core.anthropic import append_request_id, get_user_facing_error_message
-from providers.error_mapping import map_error
+from core.anthropic import (
+ append_request_id,
+ format_user_error_preview,
+ get_user_facing_error_message,
+)
+from providers.error_mapping import (
+ map_error,
+ user_visible_message_for_mapped_provider_error,
+)
from providers.exceptions import (
APIError,
AuthenticationError,
@@ -101,27 +109,6 @@ class TestMapError:
result = map_error(exc)
assert result is exc
- @pytest.mark.parametrize(
- "exc_cls,expected_cls",
- [
- (openai.AuthenticationError, AuthenticationError),
- (openai.RateLimitError, RateLimitError),
- (openai.BadRequestError, InvalidRequestError),
- ],
- ids=["auth", "rate_limit", "bad_request"],
- )
- def test_mapping_parametrized(self, exc_cls, expected_cls):
- """Parametrized check of openai -> provider error mapping."""
- status_map = {
- openai.AuthenticationError: 401,
- openai.RateLimitError: 429,
- openai.BadRequestError: 400,
- }
- exc = _make_openai_error(exc_cls, status_code=status_map[exc_cls])
- with patch("providers.error_mapping.GlobalRateLimiter"):
- result = map_error(exc)
- assert isinstance(result, expected_cls)
-
def test_user_facing_message_read_timeout_empty_string():
"""ReadTimeout wrapping TimeoutError should still produce readable text."""
@@ -134,3 +121,35 @@ def test_append_request_id_suffix():
"""Request id suffix should be appended deterministically."""
message = append_request_id("Provider request failed.", "req_abc123")
assert message == "Provider request failed. (request_id=req_abc123)"
+
+
+def test_user_facing_message_bad_request_prefers_mapped_text_over_sdk_string():
+ """BadRequestError should map to stable wording even when str(exc) is non-empty."""
+ exc = _make_openai_error(
+ openai.BadRequestError, message="leaky-upstream-detail", status_code=400
+ )
+ assert get_user_facing_error_message(exc) == "Invalid request sent to provider."
+
+
+def test_format_user_error_preview_truncates():
+ exc = ValueError("x" * 500)
+ preview = format_user_error_preview(exc, max_len=20)
+ assert len(preview) == 20
+ assert preview == "x" * 20
+
+
+def test_user_visible_message_for_mapped_provider_error_405():
+ mapped = APIError("ignored", status_code=405, raw_error="")
+ msg = user_visible_message_for_mapped_provider_error(
+ mapped, provider_name="ACME", read_timeout_s=30.0
+ )
+ assert "ACME" in msg and "405" in msg
+
+
+def test_streaming_transports_pass_scoped_rate_limiter_to_map_error():
+ """Guardrail: streaming adapters must scope reactive 429 handling per provider."""
+ root = Path(__file__).resolve().parents[2]
+ for name in ("anthropic_messages.py", "openai_compat.py"):
+ text = (root / "providers" / name).read_text(encoding="utf-8")
+ assert "map_error(" in text, name
+ assert "rate_limiter=self._global_rate_limiter" in text, name
diff --git a/tests/providers/test_llamacpp.py b/tests/providers/test_llamacpp.py
index 04e0874..df87069 100644
--- a/tests/providers/test_llamacpp.py
+++ b/tests/providers/test_llamacpp.py
@@ -5,8 +5,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
+from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
from providers.base import ProviderConfig
from providers.llamacpp import LlamaCppProvider
+from tests.stream_contract import assert_canonical_stream_error_envelope
class MockMessage:
@@ -221,7 +223,7 @@ async def test_stream_response_adds_max_tokens_if_missing(llamacpp_provider):
[e async for e in llamacpp_provider.stream_response(req)]
_, kwargs = mock_build.call_args
- assert kwargs["json"]["max_tokens"] == 81920
+ assert kwargs["json"]["max_tokens"] == ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
@pytest.mark.asyncio
@@ -254,10 +256,10 @@ async def test_stream_error_status_code(llamacpp_provider):
async for e in llamacpp_provider.stream_response(req, request_id="TEST_ID")
]
- assert len(events) == 1
- assert events[0].startswith("event: error\ndata: {")
- assert "Internal Server Error" in events[0]
- assert "TEST_ID" in events[0]
+ assert_canonical_stream_error_envelope(
+ events, user_message_substr="Provider API request failed"
+ )
+ assert "TEST_ID" in "".join(events)
@pytest.mark.asyncio
@@ -281,10 +283,11 @@ async def test_stream_network_error(llamacpp_provider):
async for e in llamacpp_provider.stream_response(req, request_id="TEST_ID2")
]
- assert len(events) == 1
- assert events[0].startswith("event: error\ndata: {")
- assert "Connection refused" in events[0]
- assert "TEST_ID2" in events[0]
+ blob = "".join(events)
+ assert_canonical_stream_error_envelope(
+ events, user_message_substr="Connection refused"
+ )
+ assert "TEST_ID2" in blob
@pytest.mark.asyncio
@@ -315,8 +318,36 @@ async def test_stream_error_405_mentions_upstream_provider(llamacpp_provider):
e async for e in llamacpp_provider.stream_response(req, request_id="REQ405")
]
+ blob = "".join(events)
assert (
"Upstream provider LLAMACPP rejected the request method or endpoint (HTTP 405)."
- in events[0]
+ in blob
)
- assert "REQ405" in events[0]
+ assert "REQ405" in blob
+
+
+def test_build_request_body_disabled_thinking_strips_native_thinking_history(
+ llamacpp_config,
+):
+ """With thinking disabled, prior assistant thinking/redacted blocks are omitted."""
+ config = llamacpp_config.model_copy(update={"enable_thinking": False})
+ provider = LlamaCppProvider(config)
+ messages = [
+ MockMessage("user", "Hi"),
+ MockMessage(
+ "assistant",
+ [
+ {"type": "thinking", "thinking": "p"},
+ {"type": "redacted_thinking", "data": "ZGF0YQ=="},
+ ],
+ ),
+ ]
+ req = MockRequest(
+ system=None,
+ messages=messages,
+ )
+ body = provider._build_request_body(req, thinking_enabled=False)
+ asst = body["messages"][1]
+ assert asst["content"] == ""
+ assert "thinking" not in str(body)
+ assert "redacted_thinking" not in str(body)
diff --git a/tests/providers/test_lmstudio.py b/tests/providers/test_lmstudio.py
index f8c4dee..06aca7a 100644
--- a/tests/providers/test_lmstudio.py
+++ b/tests/providers/test_lmstudio.py
@@ -5,8 +5,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
+from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
from providers.base import ProviderConfig
from providers.lmstudio import LMStudioProvider
+from tests.stream_contract import assert_canonical_stream_error_envelope
class MockMessage:
@@ -252,7 +254,7 @@ async def test_stream_response_adds_max_tokens_if_missing(lmstudio_provider):
[e async for e in lmstudio_provider.stream_response(req)]
_, kwargs = mock_build.call_args
- assert kwargs["json"]["max_tokens"] == 81920
+ assert kwargs["json"]["max_tokens"] == ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
@pytest.mark.asyncio
@@ -285,10 +287,10 @@ async def test_stream_error_status_code(lmstudio_provider):
async for e in lmstudio_provider.stream_response(req, request_id="TEST_ID")
]
- assert len(events) == 1
- assert events[0].startswith("event: error\ndata: {")
- assert "Internal Server Error" in events[0]
- assert "TEST_ID" in events[0]
+ assert_canonical_stream_error_envelope(
+ events, user_message_substr="Provider API request failed"
+ )
+ assert "TEST_ID" in "".join(events)
@pytest.mark.asyncio
@@ -312,10 +314,11 @@ async def test_stream_network_error(lmstudio_provider):
async for e in lmstudio_provider.stream_response(req, request_id="TEST_ID2")
]
- assert len(events) == 1
- assert events[0].startswith("event: error\ndata: {")
- assert "Connection refused" in events[0]
- assert "TEST_ID2" in events[0]
+ blob = "".join(events)
+ assert_canonical_stream_error_envelope(
+ events, user_message_substr="Connection refused"
+ )
+ assert "TEST_ID2" in blob
@pytest.mark.asyncio
@@ -346,8 +349,61 @@ async def test_stream_error_405_mentions_upstream_provider(lmstudio_provider):
e async for e in lmstudio_provider.stream_response(req, request_id="REQ405")
]
+ blob = "".join(events)
assert (
"Upstream provider LMSTUDIO rejected the request method or endpoint (HTTP 405)."
- in events[0]
+ in blob
)
- assert "REQ405" in events[0]
+ assert "REQ405" in blob
+
+
+def test_build_request_body_disabled_thinking_strips_native_thinking_history(
+ lmstudio_config,
+):
+ """Disabled thinking must omit prior assistant thinking/redacted blocks in JSON."""
+ provider = LMStudioProvider(
+ lmstudio_config.model_copy(update={"enable_thinking": False})
+ )
+ req = MockRequest(
+ system=None,
+ messages=[
+ MockMessage("user", "hi"),
+ MockMessage(
+ "assistant",
+ [
+ {"type": "thinking", "thinking": "t"},
+ {"type": "redacted_thinking", "data": "opaque"},
+ ],
+ ),
+ ],
+ )
+ body = provider._build_request_body(req, thinking_enabled=False)
+ assert body["messages"][1]["content"] == ""
+ assert "redacted_thinking" not in str(body)
+
+
+def test_build_request_body_preserves_signed_thinking_native_history(lmstudio_config):
+ """When thinking is enabled, signed thinking blocks are kept; unsigned stripped."""
+ provider = LMStudioProvider(lmstudio_config)
+ req = MockRequest(
+ system=None,
+ messages=[
+ MockMessage("user", "hi"),
+ MockMessage(
+ "assistant",
+ [
+ {
+ "type": "thinking",
+ "thinking": "signed",
+ "signature": "sig",
+ },
+ {"type": "redacted_thinking", "data": "opaque"},
+ ],
+ ),
+ ],
+ )
+ body = provider._build_request_body(req, thinking_enabled=True)
+ c = body["messages"][1]["content"]
+ assert isinstance(c, list)
+ assert any(isinstance(b, dict) and b.get("type") == "thinking" for b in c)
+ assert any(isinstance(b, dict) and b.get("type") == "redacted_thinking" for b in c)
diff --git a/tests/providers/test_nim_request_clone.py b/tests/providers/test_nim_request_clone.py
new file mode 100644
index 0000000..30260c1
--- /dev/null
+++ b/tests/providers/test_nim_request_clone.py
@@ -0,0 +1,40 @@
+"""Tests for NVIDIA NIM request body cloning helpers."""
+
+from copy import deepcopy
+
+from providers.nvidia_nim.request import clone_body_without_reasoning_budget
+
+
+def test_clone_body_without_reasoning_budget_strips_top_level_and_nested():
+ body: dict = {
+ "model": "x",
+ "extra_body": {
+ "reasoning_budget": 99,
+ "chat_template_kwargs": {"reasoning_budget": 42, "thinking": True},
+ "top_k": 1,
+ },
+ }
+ original_extra = deepcopy(body["extra_body"])
+ out = clone_body_without_reasoning_budget(body)
+
+ assert out is not None
+ assert out["extra_body"]["chat_template_kwargs"] == {"thinking": True}
+ assert "reasoning_budget" not in out["extra_body"]
+ assert body["extra_body"] == original_extra
+
+
+def test_clone_body_without_reasoning_budget_returns_none_when_unchanged():
+ body = {"model": "x", "extra_body": {"top_k": 3}}
+ assert clone_body_without_reasoning_budget(body) is None
+
+
+def test_clone_body_without_reasoning_budget_returns_none_without_extra_body():
+ assert clone_body_without_reasoning_budget({"model": "y"}) is None
+
+
+def test_clone_body_drops_empty_extra_body_after_strip():
+ body = {"model": "z", "extra_body": {"reasoning_budget": 7}}
+ out = clone_body_without_reasoning_budget(body)
+ assert out is not None
+ assert "extra_body" not in out
+ assert "extra_body" in body
diff --git a/tests/providers/test_nvidia_nim.py b/tests/providers/test_nvidia_nim.py
index dcacbfd..4c190ed 100644
--- a/tests/providers/test_nvidia_nim.py
+++ b/tests/providers/test_nvidia_nim.py
@@ -5,6 +5,8 @@ import openai
import pytest
from httpx import Request, Response
+from config.nim import NimSettings
+from providers.defaults import NVIDIA_NIM_DEFAULT_BASE
from providers.nvidia_nim import NvidiaNimProvider
@@ -42,7 +44,7 @@ class MockRequest:
def _make_bad_request_error(message: str) -> openai.BadRequestError:
response = Response(
status_code=400,
- request=Request("POST", "https://integrate.api.nvidia.com/v1/chat/completions"),
+ request=Request("POST", f"{NVIDIA_NIM_DEFAULT_BASE}/chat/completions"),
)
body = {"error": {"message": message, "type": "BadRequestError", "code": 400}}
return openai.BadRequestError(message, response=response, body=body)
@@ -67,8 +69,6 @@ def mock_rate_limiter():
async def test_init(provider_config):
"""Test provider initialization."""
with patch("providers.openai_compat.AsyncOpenAI") as mock_openai:
- from config.nim import NimSettings
-
provider = NvidiaNimProvider(provider_config, nim_settings=NimSettings())
assert provider._api_key == "test_key"
assert provider._base_url == "https://test.api.nvidia.com/v1"
@@ -78,7 +78,6 @@ async def test_init(provider_config):
@pytest.mark.asyncio
async def test_init_uses_configurable_timeouts():
"""Test that provider passes configurable read/write/connect timeouts to client."""
- from config.nim import NimSettings
from providers.base import ProviderConfig
config = ProviderConfig(
@@ -100,8 +99,6 @@ async def test_init_uses_configurable_timeouts():
@pytest.mark.asyncio
async def test_build_request_body(provider_config):
"""Test request body construction."""
- from config.nim import NimSettings
-
provider = NvidiaNimProvider(provider_config, nim_settings=NimSettings())
req = MockRequest()
body = provider._build_request_body(req)
@@ -124,8 +121,6 @@ async def test_build_request_body(provider_config):
async def test_build_request_body_omits_reasoning_when_globally_disabled(
provider_config,
):
- from config.nim import NimSettings
-
provider = NvidiaNimProvider(
provider_config.model_copy(update={"enable_thinking": False}),
nim_settings=NimSettings(),
@@ -142,8 +137,6 @@ async def test_build_request_body_omits_reasoning_when_globally_disabled(
async def test_build_request_body_omits_reasoning_when_request_disables_thinking(
provider_config,
):
- from config.nim import NimSettings
-
provider = NvidiaNimProvider(provider_config, nim_settings=NimSettings())
req = MockRequest()
req.thinking.enabled = False
@@ -241,8 +234,6 @@ async def test_stream_response_thinking_reasoning_content(nim_provider):
@pytest.mark.asyncio
async def test_stream_response_suppresses_thinking_when_disabled(provider_config):
- from config.nim import NimSettings
-
provider = NvidiaNimProvider(
provider_config.model_copy(update={"enable_thinking": False}),
nim_settings=NimSettings(),
@@ -285,8 +276,6 @@ def _make_bad_request_error(message: str) -> openai.BadRequestError:
@pytest.mark.asyncio
async def test_stream_response_retries_without_chat_template(provider_config):
- from config.nim import NimSettings
-
provider = NvidiaNimProvider(
provider_config,
nim_settings=NimSettings(chat_template="custom_template"),
@@ -344,8 +333,6 @@ async def test_stream_response_retries_without_chat_template(provider_config):
@pytest.mark.asyncio
async def test_stream_response_does_not_retry_unrelated_bad_request(provider_config):
- from config.nim import NimSettings
-
provider = NvidiaNimProvider(
provider_config,
nim_settings=NimSettings(chat_template="custom_template"),
@@ -361,7 +348,7 @@ async def test_stream_response_does_not_retry_unrelated_bad_request(provider_con
assert mock_create.await_count == 1
event_text = "".join(events)
- assert "unrelated bad request" in event_text
+ assert "Invalid request sent to provider" in event_text
assert "event: message_stop" in event_text
@@ -457,5 +444,5 @@ async def test_stream_response_bad_request_without_reasoning_budget_does_not_ret
events = [e async for e in nim_provider.stream_response(req)]
assert mock_create.await_count == 1
- assert any("Unsupported field: top_k" in event for event in events)
+ assert any("Invalid request sent to provider" in event for event in events)
assert any("message_stop" in event for event in events)
diff --git a/tests/providers/test_ollama.py b/tests/providers/test_ollama.py
index 1edeb3b..ffb90a7 100644
--- a/tests/providers/test_ollama.py
+++ b/tests/providers/test_ollama.py
@@ -6,7 +6,8 @@ import httpx
import pytest
from providers.base import ProviderConfig
-from providers.ollama import OLLAMA_BASE_URL, OllamaProvider
+from providers.ollama import OLLAMA_DEFAULT_BASE, OllamaProvider
+from tests.stream_contract import assert_canonical_stream_error_envelope
class MockMessage:
@@ -92,7 +93,7 @@ def test_init_uses_default_base_url():
config = ProviderConfig(api_key="ollama", base_url=None)
with patch("httpx.AsyncClient"):
provider = OllamaProvider(config)
- assert provider._base_url == OLLAMA_BASE_URL
+ assert provider._base_url == OLLAMA_DEFAULT_BASE
def test_init_uses_configurable_timeouts():
@@ -199,6 +200,30 @@ async def test_build_request_body_omits_thinking_when_disabled(ollama_config):
assert body["model"] == "llama3.1:8b"
+def test_build_request_body_disabled_thinking_strips_assistant_thinking_blocks(
+ ollama_config,
+):
+ """Prior assistant thinking/redacted blocks are removed when policy is off."""
+ provider = OllamaProvider(
+ ollama_config.model_copy(update={"enable_thinking": False})
+ )
+ req = MockRequest(
+ system=None,
+ messages=[
+ MockMessage("user", "hi"),
+ MockMessage(
+ "assistant",
+ [
+ {"type": "thinking", "thinking": "t"},
+ {"type": "redacted_thinking", "data": "opaque"},
+ ],
+ ),
+ ],
+ )
+ body = provider._build_request_body(req, thinking_enabled=False)
+ assert body["messages"][1]["content"] == ""
+
+
@pytest.mark.asyncio
async def test_stream_error_status_code(ollama_provider):
"""Non-200 status code is yielded as an SSE API error."""
@@ -228,10 +253,10 @@ async def test_stream_error_status_code(ollama_provider):
async for event in ollama_provider.stream_response(req, request_id="REQ")
]
- assert len(events) == 1
- assert events[0].startswith("event: error\ndata: {")
- assert "Internal Server Error" in events[0]
- assert "REQ" in events[0]
+ assert_canonical_stream_error_envelope(
+ events, user_message_substr="Provider API request failed"
+ )
+ assert "REQ" in "".join(events)
@pytest.mark.asyncio
diff --git a/tests/providers/test_open_router.py b/tests/providers/test_open_router.py
index d88b2c6..7a6da14 100644
--- a/tests/providers/test_open_router.py
+++ b/tests/providers/test_open_router.py
@@ -6,6 +6,12 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
+from core.anthropic.stream_contracts import (
+ assert_anthropic_stream_contract,
+ parse_sse_text,
+ text_content,
+ thinking_content,
+)
from providers.base import ProviderConfig
from providers.open_router import OpenRouterProvider
from providers.open_router.request import OPENROUTER_DEFAULT_MAX_TOKENS
@@ -30,8 +36,6 @@ class MockRequest:
self.tool_choice = None
self.metadata = None
self.extra_body = {}
- self.original_model = "claude-3-sonnet"
- self.resolved_provider_model = "open_router/stepfun/step-3.5-flash:free"
self.thinking = MagicMock()
self.thinking.enabled = True
for k, v in kwargs.items():
@@ -135,8 +139,26 @@ def test_build_request_body_is_native_anthropic(open_router_provider):
assert body["system"] == "System prompt"
assert body["reasoning"] == {"enabled": True}
assert "extra_body" not in body
- assert "original_model" not in body
- assert "resolved_provider_model" not in body
+
+
+def test_openrouter_extra_body_rejects_overriding_reserved_fields() -> None:
+ from providers.exceptions import InvalidRequestError
+ from providers.open_router.request import build_request_body
+
+ r = MockRequest()
+ r.extra_body = {"model": "hijack"}
+ with pytest.raises(InvalidRequestError, match="model"):
+ build_request_body(r, thinking_enabled=True)
+
+
+def test_openrouter_extra_body_allows_openrouter_only_keys() -> None:
+ from providers.open_router.request import build_request_body
+
+ r = MockRequest()
+ r.extra_body = {"transforms": ["no-web"], "plugins": []}
+ body = build_request_body(r, thinking_enabled=False)
+ assert body["transforms"] == ["no-web"]
+ assert body["plugins"] == []
def test_build_request_body_omits_reasoning_when_globally_disabled(
@@ -188,7 +210,6 @@ def test_build_request_body_default_max_tokens(open_router_provider):
body = open_router_provider._build_request_body(req)
assert body["max_tokens"] == OPENROUTER_DEFAULT_MAX_TOKENS
- assert body["max_tokens"] == 81920
def test_build_request_body_strips_unsigned_thinking_history(open_router_provider):
@@ -209,7 +230,32 @@ def test_build_request_body_strips_unsigned_thinking_history(open_router_provide
body = open_router_provider._build_request_body(req)
- assert body["messages"][1]["content"] == [{"type": "text", "text": "Hello"}]
+ assert body["messages"][1]["content"] == [
+ {"type": "redacted_thinking", "data": "opaque"},
+ {"type": "text", "text": "Hello"},
+ ]
+
+
+def test_build_request_body_strips_redacted_when_thinking_disabled(
+ open_router_config,
+):
+ """Disabled thinking must remove all assistant thinking history including redacted."""
+ provider = OpenRouterProvider(
+ open_router_config.model_copy(update={"enable_thinking": False})
+ )
+ req = MockRequest(
+ messages=[
+ MockMessage(
+ "assistant",
+ [
+ {"type": "redacted_thinking", "data": "opaque"},
+ {"type": "text", "text": "Hi"},
+ ],
+ )
+ ]
+ )
+ body = provider._build_request_body(req)
+ assert body["messages"][0]["content"] == [{"type": "text", "text": "Hi"}]
def test_build_request_body_preserves_signed_thinking_history(open_router_provider):
@@ -336,12 +382,12 @@ async def test_stream_response_suppresses_native_thinking_when_disabled(
assert "Answer" in event_text
text_start = next(event for event in events if "content_block_start" in event)
- payload = json.loads(text_start.split("data: ", 1)[1])
+ payload = parse_sse_text(text_start)[0].data
assert payload["index"] == 0
@pytest.mark.asyncio
-async def test_stream_response_drops_redacted_thinking_when_enabled(
+async def test_stream_response_preserves_redacted_thinking_when_enabled(
open_router_provider,
):
response = FakeResponse(
@@ -378,16 +424,216 @@ async def test_stream_response_drops_redacted_thinking_when_enabled(
):
events = [e async for e in open_router_provider.stream_response(MockRequest())]
+ event_text = "".join(events)
+ assert "redacted_thinking" in event_text
+ assert "opaque" in event_text
+ assert "Answer" in event_text
+
+ parsed = parse_sse_text(event_text)
+ first_start = next(
+ p
+ for p in parsed
+ if p.event == "content_block_start"
+ and p.data.get("content_block", {}).get("type") == "redacted_thinking"
+ )
+ assert first_start.data["index"] == 0
+
+
+@pytest.mark.asyncio
+async def test_stream_response_drops_redacted_thinking_when_disabled(
+ open_router_config,
+):
+ provider = OpenRouterProvider(
+ open_router_config.model_copy(update={"enable_thinking": False})
+ )
+ response = FakeResponse(
+ lines=[
+ "event: content_block_start",
+ 'data: {"type":"content_block_start","index":0,"content_block":{"type":"redacted_thinking","data":"opaque"}}',
+ "",
+ "event: content_block_stop",
+ 'data: {"type":"content_block_stop","index":0}',
+ "",
+ "event: content_block_start",
+ 'data: {"type":"content_block_start","index":1,"content_block":{"type":"text","text":""}}',
+ "",
+ "event: content_block_delta",
+ 'data: {"type":"content_block_delta","index":1,"delta":{"type":"text_delta","text":"Answer"}}',
+ "",
+ "event: content_block_stop",
+ 'data: {"type":"content_block_stop","index":1}',
+ "",
+ ]
+ )
+
+ with (
+ patch.object(provider._client, "build_request"),
+ patch.object(
+ provider._client,
+ "send",
+ new_callable=AsyncMock,
+ return_value=response,
+ ),
+ ):
+ events = [e async for e in provider.stream_response(MockRequest())]
+
event_text = "".join(events)
assert "redacted_thinking" not in event_text
+ assert "opaque" not in event_text
assert "Answer" in event_text
start_event = next(event for event in events if "content_block_start" in event)
- payload = json.loads(start_event.split("data: ", 1)[1])
+ payload = parse_sse_text(start_event)[0].data
assert payload["index"] == 0
assert payload["content_block"]["type"] == "text"
+@pytest.mark.asyncio
+async def test_stream_response_reopens_interleaved_thinking_after_text(
+ open_router_provider,
+):
+ """Overthinking+text+more thinking: downstream indices must not reuse closed blocks."""
+ response = FakeResponse(
+ lines=[
+ "event: message_start",
+ 'data: {"type":"message_start","message":{}}',
+ "",
+ "event: content_block_start",
+ 'data: {"type":"content_block_start","index":0,'
+ '"content_block":{"type":"thinking","thinking":"","signature":""}}',
+ "",
+ "event: content_block_delta",
+ 'data: {"type":"content_block_delta","index":0,'
+ '"delta":{"type":"thinking_delta","thinking":"first"}}',
+ "",
+ "event: content_block_start",
+ 'data: {"type":"content_block_start","index":1,'
+ '"content_block":{"type":"text","text":""}}',
+ "",
+ "event: content_block_delta",
+ 'data: {"type":"content_block_delta","index":0,'
+ '"delta":{"type":"thinking_delta","thinking":" second"}}',
+ "",
+ "event: content_block_delta",
+ 'data: {"type":"content_block_delta","index":1,'
+ '"delta":{"type":"text_delta","text":"Answer"}}',
+ "",
+ "event: content_block_stop",
+ 'data: {"type":"content_block_stop","index":1}',
+ "",
+ "event: content_block_stop",
+ 'data: {"type":"content_block_stop","index":0}',
+ "",
+ "event: message_stop",
+ 'data: {"type":"message_stop"}',
+ "",
+ ]
+ )
+
+ with (
+ patch.object(open_router_provider._client, "build_request"),
+ patch.object(
+ open_router_provider._client,
+ "send",
+ new_callable=AsyncMock,
+ return_value=response,
+ ),
+ ):
+ events = [e async for e in open_router_provider.stream_response(MockRequest())]
+
+ parsed = parse_sse_text("".join(events))
+ assert_anthropic_stream_contract(parsed)
+ assert thinking_content(parsed) == "first second"
+ assert "Answer" in text_content(parsed)
+ stop_payloads = [
+ p.data
+ for p in parsed
+ if p.event == "content_block_stop"
+ and p.data.get("type") == "content_block_stop"
+ ]
+ seen_stop_indices: set[int] = set()
+ for s in stop_payloads:
+ idx = s.get("index")
+ assert isinstance(idx, int)
+ assert idx not in seen_stop_indices, "stop reused or duplicated index"
+ seen_stop_indices.add(idx)
+ # Two distinct thinking block indices: initial + reopened segment
+ think_starts = [
+ p
+ for p in parsed
+ if p.event == "content_block_start"
+ and p.data.get("content_block", {}).get("type") == "thinking"
+ ]
+ assert len(think_starts) == 2, (
+ "reopened thinking must have its own `content_block_start`"
+ )
+
+
+@pytest.mark.asyncio
+async def test_stream_response_reopened_tool_use_preserves_tool_identity(
+ open_router_provider,
+):
+ """After overlapping close, resumed input_json_delta must keep original tool id/name."""
+ lines: list[str] = []
+ for payload in (
+ {
+ "type": "content_block_start",
+ "index": 0,
+ "content_block": {
+ "type": "tool_use",
+ "id": "toolu_real_1",
+ "name": "Read",
+ "input": {},
+ },
+ },
+ {
+ "type": "content_block_delta",
+ "index": 0,
+ "delta": {"type": "input_json_delta", "partial_json": '{"path'},
+ },
+ {
+ "type": "content_block_start",
+ "index": 1,
+ "content_block": {"type": "text", "text": ""},
+ },
+ {
+ "type": "content_block_delta",
+ "index": 0,
+ "delta": {"type": "input_json_delta", "partial_json": '":"/tmp"}'},
+ },
+ {"type": "content_block_stop", "index": 1},
+ {"type": "content_block_stop", "index": 0},
+ ):
+ event_name = payload["type"]
+ lines.extend((f"event: {event_name}", f"data: {json.dumps(payload)}", ""))
+
+ response = FakeResponse(lines=lines)
+
+ with (
+ patch.object(open_router_provider._client, "build_request"),
+ patch.object(
+ open_router_provider._client,
+ "send",
+ new_callable=AsyncMock,
+ return_value=response,
+ ),
+ ):
+ events = [e async for e in open_router_provider.stream_response(MockRequest())]
+
+ parsed = parse_sse_text("".join(events))
+ tool_starts = [
+ p
+ for p in parsed
+ if p.event == "content_block_start"
+ and p.data.get("content_block", {}).get("type") == "tool_use"
+ ]
+ assert len(tool_starts) == 2
+ for start in tool_starts:
+ block = start.data["content_block"]
+ assert block["id"] == "toolu_real_1"
+ assert block["name"] == "Read"
+
+
@pytest.mark.asyncio
async def test_stream_response_closes_overlapping_thinking_before_text(
open_router_provider,
diff --git a/tests/providers/test_provider_rate_limit.py b/tests/providers/test_provider_rate_limit.py
index 7cd00d9..7d59280 100644
--- a/tests/providers/test_provider_rate_limit.py
+++ b/tests/providers/test_provider_rate_limit.py
@@ -27,11 +27,11 @@ class TestProviderRateLimiter:
GlobalRateLimiter.reset_instance()
limiter = GlobalRateLimiter.get_instance(rate_limit=1, rate_window=0.25)
- start_time = time.time()
+ start_time = time.monotonic()
async def call_limiter():
await limiter.wait_if_blocked()
- return time.time()
+ return time.monotonic()
# 5 requests.
# R0 -> 0s
@@ -41,7 +41,7 @@ class TestProviderRateLimiter:
# R4 -> 1.00s
results = [await call_limiter() for _ in range(5)]
- total_time = time.time() - start_time
+ total_time = time.monotonic() - start_time
assert len(results) == 5
# Should take at least ~1.0s
@@ -56,7 +56,7 @@ class TestProviderRateLimiter:
GlobalRateLimiter.reset_instance()
limiter = GlobalRateLimiter.get_instance()
- start_time = time.time()
+ start_time = time.monotonic()
# Manually block for 1.5s
block_time = 1.5
@@ -71,7 +71,7 @@ class TestProviderRateLimiter:
# They should both wait for the block time
results = await asyncio.gather(call_limiter(), call_limiter())
- total_time = time.time() - start_time
+ total_time = time.monotonic() - start_time
# Both should report having waited reactively
assert all(results) is True
@@ -121,10 +121,10 @@ class TestProviderRateLimiter:
GlobalRateLimiter.reset_instance()
limiter = GlobalRateLimiter.get_instance(rate_limit=10000, rate_window=60)
- start = time.time()
+ start = time.monotonic()
for _ in range(20):
await limiter.wait_if_blocked()
- duration = time.time() - start
+ duration = time.monotonic() - start
# 20 requests with 10000 limit should be near-instant
assert duration < 1.0, f"High rate limit caused throttling: {duration:.2f}s"
@@ -252,6 +252,32 @@ class TestProviderRateLimiter:
assert result == "ok"
assert call_count == 2
+ @pytest.mark.asyncio
+ async def test_execute_with_retry_succeeds_on_httpx_429(self):
+ """HTTP 429 as httpx.HTTPStatusError then success returns result."""
+ import httpx
+
+ limiter = GlobalRateLimiter.get_instance(rate_limit=100, rate_window=60)
+
+ call_count = 0
+
+ async def fail_then_ok():
+ nonlocal call_count
+ call_count += 1
+ if call_count == 1:
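+ # Build a real httpx.Response so the retry path sees a genuine 429 status.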
+ r = httpx.Response(429, request=httpx.Request("POST", "http://x"), text="slow")
+ raise httpx.HTTPStatusError(
+ "Too Many Requests", request=r.request, response=r
+ )
+ return "ok"
+
+ result = await limiter.execute_with_retry(
+ fail_then_ok, max_retries=2, base_delay=0.01, max_delay=0.1, jitter=0
+ )
+ assert result == "ok"
+ assert call_count == 2
+
@pytest.mark.asyncio
async def test_max_concurrency_zero_raises(self):
"""max_concurrency <= 0 raises ValueError."""
diff --git a/tests/providers/test_provider_transport_logging.py b/tests/providers/test_provider_transport_logging.py
new file mode 100644
index 0000000..5076930
--- /dev/null
+++ b/tests/providers/test_provider_transport_logging.py
@@ -0,0 +1,263 @@
+"""Tests for metadata-only provider transport logging by default."""
+
+import logging
+from contextlib import asynccontextmanager
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from config.constants import NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES
+from config.nim import NimSettings
+from providers.anthropic_messages import AnthropicMessagesTransport
+from providers.base import ProviderConfig
+from providers.nvidia_nim import NvidiaNimProvider
+from tests.provider_request_mocks import make_openai_compat_stream_request
+from tests.providers.test_anthropic_messages import (
+ FakeResponse,
+ MockRequest,
+ NativeProvider,
+)
+
+
+@pytest.fixture
+def provider_config():
+ return ProviderConfig(
+ api_key="test-key",
+ base_url="https://custom.test/v1/",
+ proxy="socks5://127.0.0.1:9999",
+ rate_limit=10,
+ rate_window=60,
+ http_read_timeout=600.0,
+ http_write_timeout=15.0,
+ http_connect_timeout=5.0,
+ )
+
+
+@pytest.fixture(autouse=True)
+def mock_rate_limiter():
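+ # Neutralize rate limiting for these tests: execute_with_retry passes calls
+ # straight through and concurrency_slot becomes a no-op context manager.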
+ @asynccontextmanager
+ async def _slot():
+ yield
+
+ with patch("providers.anthropic_messages.GlobalRateLimiter") as mock:
+ instance = mock.get_scoped_instance.return_value
+
+ async def _passthrough(fn, *args, **kwargs):
+ return await fn(*args, **kwargs)
+
+ instance.execute_with_retry = AsyncMock(side_effect=_passthrough)
+ instance.concurrency_slot.side_effect = _slot
+ yield instance
+
+
+@pytest.mark.asyncio
+async def test_native_non_200_logs_exclude_body_text_by_default(
+ caplog, provider_config
+):
+ provider = NativeProvider(provider_config)
+ req = MockRequest()
+ response = FakeResponse(status_code=500, text="SECRET_UPSTREAM_BODY")
+
+ with (
+ patch.object(provider._client, "build_request", return_value=MagicMock()),
+ patch.object(
+ provider._client,
+ "send",
+ new_callable=AsyncMock,
+ return_value=response,
+ ),
+ caplog.at_level(logging.ERROR),
+ ):
+ _ = [e async for e in provider.stream_response(req)]
+
+ messages = " | ".join(r.getMessage() for r in caplog.records)
+ assert "SECRET_UPSTREAM_BODY" not in messages
+ assert "HTTP 500" in messages
+ assert "body_preview_bytes=" not in messages
+
+
+@pytest.mark.asyncio
+async def test_native_non_200_logs_body_when_verbose(caplog, provider_config):
+ provider_config.log_api_error_tracebacks = True
+ provider = NativeProvider(provider_config)
+ req = MockRequest()
+ response = FakeResponse(status_code=500, text="SECRET_UPSTREAM_BODY")
+
+ with (
+ patch.object(provider._client, "build_request", return_value=MagicMock()),
+ patch.object(
+ provider._client,
+ "send",
+ new_callable=AsyncMock,
+ return_value=response,
+ ),
+ caplog.at_level(logging.ERROR),
+ ):
+ _ = [e async for e in provider.stream_response(req)]
+
+ messages = " | ".join(r.getMessage() for r in caplog.records)
+ assert "SECRET_UPSTREAM_BODY" in messages
+ assert "truncated=False" in messages
+
+
+@pytest.mark.asyncio
+async def test_native_non_200_verbose_logs_only_capped_error_body(
+ caplog, provider_config
+):
+ provider_config.log_api_error_tracebacks = True
+ provider = NativeProvider(provider_config)
+ req = MockRequest()
+ tail = "SECRET_TAIL_NOT_LOGGED"
+ huge = f"{'A' * (NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES + 50)}{tail}"
+ response = FakeResponse(status_code=500, text=huge)
+
+ with (
+ patch.object(provider._client, "build_request", return_value=MagicMock()),
+ patch.object(
+ provider._client,
+ "send",
+ new_callable=AsyncMock,
+ return_value=response,
+ ),
+ caplog.at_level(logging.ERROR),
+ ):
+ _ = [e async for e in provider.stream_response(req)]
+
+ messages = " | ".join(r.getMessage() for r in caplog.records)
+ assert "SECRET_TAIL_NOT_LOGGED" not in messages
+ assert "truncated=True" in messages
+ assert f"body_preview_bytes={NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES}" in messages
+
+
+@pytest.mark.asyncio
+async def test_native_non_200_default_does_not_read_oversized_body(
+ caplog, provider_config
+):
+ provider = NativeProvider(provider_config)
+ req = MockRequest()
+ huge = f"{'Z' * 500_000}LEAK_MARKER"
+ response = FakeResponse(status_code=500, text=huge)
+
+ with (
+ patch.object(provider._client, "build_request", return_value=MagicMock()),
+ patch.object(
+ provider._client,
+ "send",
+ new_callable=AsyncMock,
+ return_value=response,
+ ),
+ caplog.at_level(logging.ERROR),
+ ):
+ _ = [e async for e in provider.stream_response(req)]
+
+ messages = " | ".join(r.getMessage() for r in caplog.records)
+ assert "LEAK_MARKER" not in messages
+ assert "ZZZ" not in messages
+ assert "HTTP 500" in messages
+
+
+@pytest.mark.asyncio
+async def test_native_stream_failure_logs_exclude_exception_str_by_default(
+ caplog, provider_config
+):
+ provider = NativeProvider(provider_config)
+ req = MockRequest()
+ response = FakeResponse(
+ lines=[
+ "event: ping",
+ 'data: {"type":"ping"}',
+ "",
+ ]
+ )
+
+ async def boom(_self, _response):
+ raise RuntimeError("SECRET_DETAIL")
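+ # The unreachable yield below makes `boom` an async generator, matching the
+ # signature of the patched _iter_sse_events.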
+ if False:
+ yield ""
+
+ with (
+ patch.object(provider._client, "build_request", return_value=MagicMock()),
+ patch.object(
+ provider._client,
+ "send",
+ new_callable=AsyncMock,
+ return_value=response,
+ ),
+ patch.object(AnthropicMessagesTransport, "_iter_sse_events", boom),
+ caplog.at_level(logging.ERROR),
+ ):
+ _ = [e async for e in provider.stream_response(req)]
+
+ messages = " | ".join(r.getMessage() for r in caplog.records)
+ assert "SECRET_DETAIL" not in messages
+ assert "exc_type=RuntimeError" in messages
+ assert "http_status=None" in messages
+
+
+@pytest.mark.asyncio
+async def test_openai_compat_stream_failure_default_logs_exclude_exception_str(caplog):
+ config = ProviderConfig(
+ api_key="k",
+ base_url="http://localhost:1/v1",
+ log_api_error_tracebacks=False,
+ )
+ provider = NvidiaNimProvider(config, nim_settings=NimSettings())
+ req = make_openai_compat_stream_request()
+
+ @asynccontextmanager
+ async def _noop_slot():
+ yield
+
+ with (
+ patch.object(
+ provider,
+ "_create_stream",
+ new_callable=AsyncMock,
+ side_effect=RuntimeError("SECRET_OPENAI_COMPAT"),
+ ),
+ patch.object(
+ provider._global_rate_limiter,
+ "concurrency_slot",
+ _noop_slot,
+ ),
+ caplog.at_level(logging.ERROR),
+ ):
+ _ = [e async for e in provider.stream_response(req)]
+
+ messages = " | ".join(r.getMessage() for r in caplog.records)
+ assert "SECRET_OPENAI_COMPAT" not in messages
+ assert "exc_type=RuntimeError" in messages
+
+
+@pytest.mark.asyncio
+async def test_openai_compat_stream_failure_respects_verbose_flag(caplog):
+ config = ProviderConfig(
+ api_key="k",
+ base_url="http://localhost:1/v1",
+ log_api_error_tracebacks=True,
+ )
+ provider = NvidiaNimProvider(config, nim_settings=NimSettings())
+ req = make_openai_compat_stream_request()
+
+ @asynccontextmanager
+ async def _noop_slot():
+ yield
+
+ with (
+ patch.object(
+ provider,
+ "_create_stream",
+ new_callable=AsyncMock,
+ side_effect=RuntimeError("SECRET_OPENAI_COMPAT"),
+ ),
+ patch.object(
+ provider._global_rate_limiter,
+ "concurrency_slot",
+ _noop_slot,
+ ),
+ caplog.at_level(logging.ERROR),
+ ):
+ _ = [e async for e in provider.stream_response(req)]
+
+ messages = " | ".join(r.getMessage() for r in caplog.records)
+ assert "SECRET_OPENAI_COMPAT" in messages
diff --git a/tests/providers/test_registry.py b/tests/providers/test_registry.py
index b00d780..ac8cbba 100644
--- a/tests/providers/test_registry.py
+++ b/tests/providers/test_registry.py
@@ -39,7 +39,7 @@ def _make_settings(**overrides):
mock.provider_max_concurrency = 5
mock.http_read_timeout = 300.0
mock.http_write_timeout = 10.0
- mock.http_connect_timeout = 2.0
+ mock.http_connect_timeout = 10.0
mock.enable_model_thinking = True
mock.nim = NimSettings()
for key, value in overrides.items():
diff --git a/tests/providers/test_sse_builder.py b/tests/providers/test_sse_builder.py
index 07df2f3..01426e5 100644
--- a/tests/providers/test_sse_builder.py
+++ b/tests/providers/test_sse_builder.py
@@ -1,19 +1,20 @@
"""Tests for core.anthropic.sse."""
-import json
from unittest.mock import patch
import pytest
from core.anthropic import ContentBlockManager, SSEBuilder, map_stop_reason
+from core.anthropic.sse import ToolCallState
+from core.anthropic.stream_contracts import parse_sse_text
def _parse_sse(sse_str: str) -> dict:
"""Parse an SSE event string into its data payload."""
- for line in sse_str.strip().split("\n"):
- if line.startswith("data: "):
- return json.loads(line[len("data: ") :])
- raise ValueError(f"No data line found in SSE: {sse_str}")
+ events = parse_sse_text(sse_str)
+ if len(events) != 1:
+ raise ValueError(f"expected 1 SSE event, got {len(events)} in {sse_str!r}")
+ return events[0].data
class TestMapStopReason:
@@ -61,6 +62,24 @@ class TestContentBlockManager:
assert mgr.text_started is False
assert mgr.tool_states == {}
+ def test_flush_task_arg_buffers_logs_digest_not_secret(self, caplog):
+ """Invalid Task JSON warnings must not echo argument prefixes (secrets)."""
+ mgr = ContentBlockManager()
+ mgr.tool_states[0] = ToolCallState(
+ block_index=0, tool_id="call_x", name="Task", started=True
+ )
+ mgr.tool_states[0].task_arg_buffer = (
+ '{"api_key": "sk-live-super-secret-do-not-log"}not_valid_json'
+ )
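+ # The buffer is not valid JSON, so the flush must fall back to "{}" and log
+ # only a SHA-256 digest prefix, never the raw argument text.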
+ with caplog.at_level("WARNING"):
+ out = mgr.flush_task_arg_buffers()
+ assert out == [(0, "{}")]
+ text = " | ".join(r.message for r in caplog.records)
+ assert "sk-live-super-secret" not in text
+ assert "buffer_sha256_prefix=" in text
+
class TestSSEBuilderMessageLifecycle:
"""Tests for message_start, message_delta, message_stop."""
diff --git a/tests/providers/test_streaming_errors.py b/tests/providers/test_streaming_errors.py
index 0b5f614..085212b 100644
--- a/tests/providers/test_streaming_errors.py
+++ b/tests/providers/test_streaming_errors.py
@@ -1,14 +1,20 @@
"""Tests for streaming error handling in providers/nvidia_nim/client.py."""
import json
+from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from config.nim import NimSettings
+from core.anthropic.stream_contracts import (
+ assert_anthropic_stream_contract,
+ parse_sse_text,
+)
from providers.base import ProviderConfig
from providers.nvidia_nim import NvidiaNimProvider
+from tests.provider_request_mocks import make_openai_compat_stream_request
class AsyncStreamMock:
@@ -53,22 +59,7 @@ def _make_provider_with_thinking_enabled(enabled: bool):
def _make_request(model="test-model", stream=True):
"""Create a mock request with all fields build_request_body needs."""
- req = MagicMock()
- req.model = model
- req.stream = stream
- req.messages = []
- req.system = None
- req.tools = None
- req.tool_choice = None
- req.metadata = None
- req.max_tokens = 4096
- req.temperature = None
- req.top_p = None
- req.top_k = None
- req.stop_sequences = None
- req.extra_body = None
- req.thinking = None
- return req
+ return make_openai_compat_stream_request(model=model, stream=stream)
def _make_chunk(
@@ -95,6 +86,29 @@ async def _collect_stream(provider, request):
return [e async for e in provider.stream_response(request)]
+def _assert_no_content_deltas_after_error_text(
+ events: list[str], error_substr: str
+) -> None:
+ """After the error text delta, only block close + message tail events may follow."""
+ parsed = parse_sse_text("".join(events))
+ first_error_idx = None
+ for i, ev in enumerate(parsed):
+ if ev.event != "content_block_delta":
+ continue
+ delta = ev.data.get("delta", {})
+ if delta.get("type") == "text_delta" and error_substr in str(
+ delta.get("text", "")
+ ):
+ first_error_idx = i
+ break
+ assert first_error_idx is not None, (error_substr, "".join(events))
+ for ev in parsed[first_error_idx + 1 :]:
+ assert ev.event in ("content_block_stop", "message_delta", "message_stop"), (
+ ev.event,
+ ev.data,
+ )
+
+
class TestStreamingExceptionHandling:
"""Tests for error paths during stream_response."""
@@ -128,6 +142,7 @@ class TestStreamingExceptionHandling:
assert "message_start" in event_text
assert "API failed" in event_text
assert "message_stop" in event_text
+ _assert_no_content_deltas_after_error_text(events, "API failed")
@pytest.mark.asyncio
async def test_read_timeout_with_empty_message_emits_fallback(self):
@@ -161,6 +176,7 @@ class TestStreamingExceptionHandling:
assert "timed out after" in event_text
assert "request_id=req_timeout123" in event_text
assert "message_stop" in event_text
+ _assert_no_content_deltas_after_error_text(events, "timed out after")
@pytest.mark.asyncio
async def test_error_after_partial_content(self):
@@ -191,6 +207,7 @@ class TestStreamingExceptionHandling:
assert "Hello" in event_text
assert "Connection lost" in event_text
assert "message_stop" in event_text
+ _assert_no_content_deltas_after_error_text(events, "Connection lost")
@pytest.mark.asyncio
async def test_empty_response_gets_space(self):
@@ -353,6 +370,10 @@ class TestStreamingExceptionHandling:
in event_text
)
assert "request_id=REQ405" in event_text
+ _assert_no_content_deltas_after_error_text(
+ events,
+ "Upstream provider NIM rejected the request method or endpoint (HTTP 405).",
+ )
@pytest.mark.asyncio
async def test_stream_rate_limited_retries_via_execute_with_retry(self):
@@ -406,6 +427,59 @@ class TestProcessToolCall:
assert "search" in event_text
assert "call_123" in event_text
+ def test_tool_call_id_arrives_before_name_still_emits_id_and_name(self):
+ """Split-stream tool: id (no name) then name then args; id preserved on start."""
+ provider = _make_provider()
+ from core.anthropic import SSEBuilder
+
+ sse = SSEBuilder("msg_test", "test-model")
+ t1 = {
+ "index": 0,
+ "id": "call_split",
+ "function": {"name": None, "arguments": ""},
+ }
+ t2 = {
+ "index": 0,
+ "id": "call_split",
+ "function": {"name": "Grep", "arguments": ""},
+ }
+ t3 = {
+ "index": 0,
+ "id": "call_split",
+ "function": {"name": None, "arguments": "{}"},
+ }
+ b1 = "".join(provider._process_tool_call(t1, sse))
+ b2 = "".join(provider._process_tool_call(t2, sse))
+ b3 = "".join(provider._process_tool_call(t3, sse))
+ combined = b1 + b2 + b3
+ assert "call_split" in combined
+ assert "Grep" in combined
+ assert b1 == ""
+
+ def test_tool_call_arguments_buffered_until_name(self):
+ """Argument deltas before tool name are emitted after the block starts."""
+ provider = _make_provider()
+ from core.anthropic import SSEBuilder
+
+ sse = SSEBuilder("msg_test", "test-model")
+ t1 = {
+ "index": 0,
+ "id": "call_buf",
+ "function": {"name": None, "arguments": '{"x":'},
+ }
+ t2 = {
+ "index": 0,
+ "id": "call_buf",
+ "function": {"name": "Read", "arguments": "1}"},
+ }
+ b1 = "".join(provider._process_tool_call(t1, sse))
+ b2 = "".join(provider._process_tool_call(t2, sse))
+ assert b1 == ""
+ assert "Read" in b2
+ assert "call_buf" in b2
+ assert '{"x":' in b2 or "partial_json" in b2
+
def test_tool_call_without_id_generates_uuid(self):
"""Tool call without id generates a uuid-based id."""
provider = _make_provider()
@@ -617,6 +691,7 @@ class TestStreamChunkEdgeCases:
assert "Partial" in event_text
assert "Connection reset" in event_text
assert "message_stop" in event_text
+ _assert_no_content_deltas_after_error_text(events, "Connection reset")
def test_stream_malformed_tool_args_chunked(self):
"""Chunked tool args that never form valid JSON are flushed with {}."""
@@ -642,3 +717,36 @@ class TestStreamChunkEdgeCases:
event_text = "".join(events1 + events2 + flushed)
assert "tool_use" in event_text
assert "{}" in event_text
+
+
+@pytest.mark.asyncio
+async def test_openai_compat_stream_ends_with_contract_when_tool_name_never_arrives() -> (
+ None
+):
+ """Nameless / incomplete tool-call buffer must not break Anthropic stream contract."""
+ provider = _make_provider()
+ request = _make_request()
+ tc0 = SimpleNamespace(
+ index=0,
+ id="call_inc",
+ function=SimpleNamespace(name=None, arguments="{}"),
+ )
+ stream_mock = AsyncStreamMock([_make_chunk(tool_calls=[tc0])])
+ with (
+ patch.object(
+ provider._client.chat.completions,
+ "create",
+ new_callable=AsyncMock,
+ return_value=stream_mock,
+ ),
+ patch.object(
+ provider._global_rate_limiter,
+ "wait_if_blocked",
+ new_callable=AsyncMock,
+ return_value=False,
+ ),
+ ):
+ events = await _collect_stream(provider, request)
+ text = "".join(events)
+ assert_anthropic_stream_contract(parse_sse_text(text))
+ assert "text_delta" in text
diff --git a/tests/stream_contract.py b/tests/stream_contract.py
new file mode 100644
index 0000000..9fdc658
--- /dev/null
+++ b/tests/stream_contract.py
@@ -0,0 +1,18 @@
+"""Shared assertions for canonical provider streaming error envelopes."""
+
+from core.anthropic.stream_contracts import (
+ assert_anthropic_stream_contract,
+ parse_sse_text,
+ text_content,
+)
+
+
+def assert_canonical_stream_error_envelope(
+ events: list[str], *, user_message_substr: str
+) -> None:
+ """Native transports emit message_start → text error → message_stop."""
+ blob = "".join(events)
+ assert "event: error\ndata:" not in blob
+ parsed = parse_sse_text(blob)
+ assert_anthropic_stream_contract(parsed, allow_error=False)
+ assert user_message_substr in text_content(parsed)
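+
+
+# Example call site (hypothetical names, sketched for illustration):
+#
+#     events = [e async for e in provider.stream_response(req)]
+#     assert_canonical_stream_error_envelope(events, user_message_substr="timed out")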