Major refactor: API, providers, messaging, and Anthropic protocol

Consolidates the incremental refactor work into a single change set: modular web tools (api/web_tools), native Anthropic request building and SSE block policy, OpenAI conversion and error handling, provider transports and rate limiting, messaging handler and tree queue, safe logging, smoke tests, and broad test coverage.
Alishahryar1 2026-04-26 02:55:10 -07:00
parent b9ed704095
commit f3a7528d49
139 changed files with 7460 additions and 2422 deletions


@ -47,8 +47,8 @@ OPENROUTER_PROXY=""
LMSTUDIO_PROXY=""
LLAMACPP_PROXY=""
PROVIDER_RATE_LIMIT=40
PROVIDER_RATE_WINDOW=60
PROVIDER_RATE_LIMIT=1
PROVIDER_RATE_WINDOW=3
PROVIDER_MAX_CONCURRENCY=5
@ -102,3 +102,25 @@ ENABLE_NETWORK_PROBE_MOCK=true
ENABLE_TITLE_GENERATION_SKIP=true
ENABLE_SUGGESTION_MODE_SKIP=true
ENABLE_FILEPATH_EXTRACTION_MOCK=true
# Local Anthropic web_search / web_fetch handling (performs outbound HTTP; on by default)
ENABLE_WEB_SERVER_TOOLS=true
WEB_FETCH_ALLOWED_SCHEMES=http,https
WEB_FETCH_ALLOW_PRIVATE_NETWORKS=false
# Verbose diagnostics (avoid logging raw prompts / SSE bodies in production)
DEBUG_PLATFORM_EDITS=false
DEBUG_SUBAGENT_STACK=false
# When true, also allows DEBUG-level httpx/httpcore/telegram log noise (not just payload logging).
LOG_RAW_API_PAYLOADS=false
LOG_RAW_SSE_EVENTS=false
# When true, log full exception text and tracebacks for unhandled errors (may leak request-derived data).
LOG_API_ERROR_TRACEBACKS=false
# When true, log message/transcription text previews in messaging adapters (may leak user content).
LOG_RAW_MESSAGING_CONTENT=false
# When true, log full Claude CLI stderr, non-JSON stdout lines, and parser error text.
LOG_RAW_CLI_DIAGNOSTICS=false
# When true, log full exception and CLI error message strings in messaging (may leak user content).
LOG_MESSAGING_ERROR_DETAILS=false
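The flags above are plain environment toggles. A minimal sketch of how they could surface on a pydantic-settings model — the class name and most fields here are hypothetical; only the helper name `web_fetch_allowed_scheme_set` is taken from the `api/service.py` diff below, and the real `config.settings.Settings` may differ:

```python
# Hypothetical sketch, not the project's actual Settings class.
from pydantic_settings import BaseSettings


class WebToolFlags(BaseSettings):
    """Env-driven toggles; defaults match the .env.example values above."""

    enable_web_server_tools: bool = True
    web_fetch_allowed_schemes: str = "http,https"
    web_fetch_allow_private_networks: bool = False
    log_raw_api_payloads: bool = False
    log_api_error_tracebacks: bool = False
    log_raw_messaging_content: bool = False

    def web_fetch_allowed_scheme_set(self) -> frozenset[str]:
        # Same helper name that api/service.py calls on Settings.
        return frozenset(
            s.strip().lower()
            for s in self.web_fetch_allowed_schemes.split(",")
            if s.strip()
        )
```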


@ -32,13 +32,14 @@
- **Platform-agnostic naming**: Use generic names (e.g. `PLATFORM_EDIT`) not platform-specific ones (e.g. `TELEGRAM_EDIT`) in shared code.
- **No type ignores**: Do not add `# type: ignore` or `# ty: ignore`. Fix the underlying type issue.
- **Complete migrations**: When moving modules, update imports to the new owner and remove old compatibility shims in the same change unless preserving a published interface is explicitly required.
- **Maximum Test Coverage**: There should be maximum test coverage for everything, preferably live smoke test coverage to catch bugs early.
## COGNITIVE WORKFLOW
1. **ANALYZE**: Read relevant files. Do not guess.
2. **PLAN**: Map out the logic. Identify root cause or required changes. Order changes by dependency.
3. **EXECUTE**: Fix the cause, not the symptom. Execute incrementally with clear commits.
4. **VERIFY**: Run ci checks. Confirm the fix via logs or output.
4. **VERIFY**: Run CI checks and relevant smoke tests. Confirm the fix via logs or output.
5. **SPECIFICITY**: Do exactly as much as asked; nothing more, nothing less.
6. **PROPAGATION**: Changes impact multiple files; propagate updates correctly.

PLAN.md (19 lines changed)

@ -56,13 +56,20 @@ with **runtime composition** (e.g. `api.runtime` constructs `cli` and `messaging
**Contract highlights:** `api/` may import only `providers.base`, `providers.exceptions`,
and `providers.registry` from the providers package (not per-adapter modules).
`core/` stays free of `api`, `messaging`, `cli`, `providers`, `config`, and `smoke`.
`messaging/` does not import `api`, `cli`, or `providers`. Neutral stream contract
assertions for default CI live under `core/anthropic/stream_contracts.py`;
`smoke.lib.sse` re-exports them for live smoke. Process-cached provider helpers
(`api.dependencies.get_provider` / `get_provider_for_type`) exist for scripts and
unit tests; production HTTP handlers must use `resolve_provider` with
`messaging/` does not import `api`, `cli`, or `smoke`, and may import `providers`
only via `providers.nvidia_nim.voice` (NVIDIA/Riva offline ASR). Stream contract
helpers live in `core/anthropic/stream_contracts.py`; live smoke imports that
module directly (no dedicated smoke SSE shim). NVIDIA NIM chat tuning uses the
canonical `config.nim.NimSettings` model on `Settings`; `providers.registry`
passes `settings.nim` into `NvidiaNimProvider` without a duplicate schema.
Default upstream base URLs use a single constant per endpoint in
`providers/defaults.py` (e.g. `NVIDIA_NIM_DEFAULT_BASE`). Process-cached provider
helpers (`api.dependencies.get_provider` / `get_provider_for_type`) exist for
scripts and unit tests; production HTTP handlers must use `resolve_provider` with
`request.app` so the app-scoped `ProviderRegistry` is used. The `api` package
`__all__` exposes HTTP models and `create_app` only (not those helpers).
`__all__` exposes HTTP models and `create_app` only (not `app`, not those helpers).
`api.app:create_app` is the ASGI factory (e.g. `uvicorn api.app:create_app --factory`);
`server.py` still exposes `server:app` as a module-level instance for convenience.
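Both entry points named above, as a quick sketch (the factory form is the prescribed one):

```python
# ASGI factory form (prescribed):
#   uvicorn api.app:create_app --factory
# Module-level instance form (legacy convenience):
#   uvicorn server:app
from api.app import create_app

app = create_app()  # programmatic construction, e.g. in tests
```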
## Target Boundaries


@ -1,6 +1,6 @@
"""API layer for Claude Code Proxy."""
from .app import app, create_app
from .app import create_app
from .models import (
MessagesRequest,
MessagesResponse,
@ -13,6 +13,5 @@ __all__ = [
"MessagesResponse",
"TokenCountRequest",
"TokenCountResponse",
"app",
"create_app",
]


@ -1,6 +1,6 @@
"""FastAPI application factory and configuration."""
import os
import traceback
from contextlib import asynccontextmanager
from typing import Any
@ -16,13 +16,7 @@ from providers.exceptions import ProviderError
from .routes import router
from .runtime import AppRuntime
# Opt-in to future behavior for python-telegram-bot
os.environ["PTB_TIMEDELTA"] = "1"
# Configure logging first (before any module logs)
_settings = get_settings()
configure_logging(_settings.log_file)
from .validation_log import summarize_request_validation_body
@asynccontextmanager
@ -38,6 +32,11 @@ async def lifespan(app: FastAPI):
def create_app() -> FastAPI:
"""Create and configure the FastAPI application."""
settings = get_settings()
configure_logging(
settings.log_file, verbose_third_party=settings.log_raw_api_payloads
)
app = FastAPI(
title="Claude Code Proxy",
version="2.0.0",
@ -57,33 +56,7 @@ def create_app() -> FastAPI:
except Exception as e:
body = {"_json_error": type(e).__name__}
messages = body.get("messages") if isinstance(body, dict) else None
message_summary: list[dict[str, Any]] = []
if isinstance(messages, list):
for msg in messages:
if not isinstance(msg, dict):
message_summary.append({"message_kind": type(msg).__name__})
continue
content = msg.get("content")
item: dict[str, Any] = {
"role": msg.get("role"),
"content_kind": type(content).__name__,
}
if isinstance(content, list):
item["block_types"] = [
block.get("type", "dict")
if isinstance(block, dict)
else type(block).__name__
for block in content[:12]
]
item["block_keys"] = [
sorted(str(key) for key in block)[:12]
for block in content[:5]
if isinstance(block, dict)
]
elif isinstance(content, str):
item["content_length"] = len(content)
message_summary.append(item)
message_summary, tool_names = summarize_request_validation_body(body)
logger.debug(
"Request validation failed: path={} query={} error_locs={} error_types={} message_summary={} tool_names={}",
@ -92,20 +65,27 @@ def create_app() -> FastAPI:
[list(error.get("loc", ())) for error in exc.errors()],
[str(error.get("type", "")) for error in exc.errors()],
message_summary,
[
str(tool.get("name", ""))
for tool in body.get("tools", [])
if isinstance(body, dict)
and isinstance(body.get("tools"), list)
and isinstance(tool, dict)
],
tool_names,
)
return await request_validation_exception_handler(request, exc)
@app.exception_handler(ProviderError)
async def provider_error_handler(request: Request, exc: ProviderError):
"""Handle provider-specific errors and return Anthropic format."""
logger.error(f"Provider Error: {exc.error_type} - {exc.message}")
err_settings = get_settings()
if err_settings.log_api_error_tracebacks:
logger.error(
"Provider Error: error_type={} status_code={} message={}",
exc.error_type,
exc.status_code,
exc.message,
)
else:
logger.error(
"Provider Error: error_type={} status_code={}",
exc.error_type,
exc.status_code,
)
return JSONResponse(
status_code=exc.status_code,
content=exc.to_anthropic_format(),
@ -114,10 +94,17 @@ def create_app() -> FastAPI:
@app.exception_handler(Exception)
async def general_error_handler(request: Request, exc: Exception):
"""Handle general errors and return Anthropic format."""
logger.error(f"General Error: {exc!s}")
import traceback
settings = get_settings()
if settings.log_api_error_tracebacks:
logger.error("General Error: {}", exc)
logger.error(traceback.format_exc())
else:
logger.error(
"General Error: path={} method={} exc_type={}",
request.url.path,
request.method,
type(exc).__name__,
)
return JSONResponse(
status_code=500,
content={
@ -130,7 +117,3 @@ def create_app() -> FastAPI:
)
return app
# Default app instance for uvicorn
app = create_app()


@ -135,5 +135,5 @@ def extract_filepaths_from_command(command: str, output: str) -> str:
return "<filepaths>\n</filepaths>"
except Exception:
except ValueError:
return "<filepaths>\n</filepaths>"


@ -8,7 +8,11 @@ from config.settings import Settings
from config.settings import get_settings as _get_settings
from core.anthropic import get_user_facing_error_message
from providers.base import BaseProvider
from providers.exceptions import AuthenticationError, UnknownProviderTypeError
from providers.exceptions import (
AuthenticationError,
ServiceUnavailableError,
UnknownProviderTypeError,
)
from providers.registry import PROVIDER_DESCRIPTORS, ProviderRegistry
# Process-level cache: only for :func:`get_provider_for_type` / :func:`get_provider`
@ -18,7 +22,7 @@ _providers: dict[str, BaseProvider] = {}
def get_settings() -> Settings:
"""Get application settings via dependency injection."""
"""Return cached :class:`~config.settings.Settings` (FastAPI-friendly alias)."""
return _get_settings()
@ -31,10 +35,9 @@ def resolve_provider(
"""Resolve a provider using the app-scoped registry when ``app`` is set.
When ``app`` is not ``None``, the app-owned :attr:`app.state.provider_registry`
is always used. If the registry is missing (e.g. a test app without
:class:`~api.runtime.AppRuntime` startup), a new :class:`ProviderRegistry`
is installed on ``app.state`` so the process cache is never mixed with
per-request app identity.
must exist (installed by :class:`~api.runtime.AppRuntime` during startup).
Callers that construct a bare ``FastAPI`` without lifespan must set
``app.state.provider_registry`` explicitly.
When ``app`` is ``None`` (no HTTP context), uses the process-level
:data:`_providers` cache only.
@ -42,8 +45,10 @@ def resolve_provider(
if app is not None:
reg = getattr(app.state, "provider_registry", None)
if reg is None:
reg = ProviderRegistry()
app.state.provider_registry = reg
raise ServiceUnavailableError(
"Provider registry is not configured. Ensure AppRuntime startup ran "
"or assign app.state.provider_registry for test apps."
)
return _resolve_with_registry(reg, provider_type, settings)
return _resolve_with_registry(ProviderRegistry(_providers), provider_type, settings)
@ -55,9 +60,10 @@ def _resolve_with_registry(
try:
provider = registry.get(provider_type, settings)
except AuthenticationError as e:
raise HTTPException(
status_code=503, detail=get_user_facing_error_message(e)
) from e
# Provider :class:`~providers.exceptions.AuthenticationError` messages are
# curated configuration hints (env var names, docs links), not upstream noise.
detail = str(e).strip() or get_user_facing_error_message(e)
raise HTTPException(status_code=503, detail=detail) from e
except UnknownProviderTypeError:
logger.error(
"Unknown provider_type: '{}'. Supported: {}",
@ -73,8 +79,9 @@ def _resolve_with_registry(
def get_provider_for_type(provider_type: str) -> BaseProvider:
"""Get or create a provider in the process-level cache (no ``app``/Request).
For server requests, use :func:`resolve_provider` with the active
:attr:`request.app` so the app-scoped provider registry is used.
HTTP route handlers should call :func:`resolve_provider` with the active
:attr:`request.app` (via :class:`~api.runtime.AppRuntime`) instead of this
process-wide cache.
"""
return resolve_provider(provider_type, app=None, settings=get_settings())
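Per the new contract, a bare test app must install the registry itself. A minimal sketch, assuming only the imports already shown in this diff:

```python
# Without this, resolve_provider(..., app=test_app) raises
# ServiceUnavailableError instead of silently creating a registry.
from fastapi import FastAPI

from providers.registry import ProviderRegistry

test_app = FastAPI()
test_app.state.provider_registry = ProviderRegistry()  # explicit app-scoped registry
```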


@ -65,8 +65,8 @@ def is_prefix_detection_request(request_data: MessagesRequest) -> tuple[bool, st
try:
cmd_start = content.rfind("Command:") + len("Command:")
return True, content[cmd_start:].strip()
except Exception:
pass
except TypeError:
return False, ""
return False, ""
@ -121,7 +121,6 @@ def is_filepath_extraction_request(
if not user_has_filepaths and not system_has_extract:
return False, "", ""
try:
cmd_start = content.find("Command:") + len("Command:")
output_marker = content.find("Output:", cmd_start)
if output_marker == -1:
@ -135,5 +134,3 @@ def is_filepath_extraction_request(
output = output.split(marker)[0].strip()
return True, command, output
except Exception:
return False, "", ""


@ -3,7 +3,7 @@
from enum import StrEnum
from typing import Any, Literal
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict, Field
# =============================================================================
@ -15,41 +15,68 @@ class Role(StrEnum):
system = "system"
class ContentBlockText(BaseModel):
class _AnthropicBlockBase(BaseModel):
"""Pass through provider fields (e.g. ``cache_control``) for native transports."""
model_config = ConfigDict(extra="allow")
class ContentBlockText(_AnthropicBlockBase):
type: Literal["text"]
text: str
class ContentBlockImage(BaseModel):
class ContentBlockImage(_AnthropicBlockBase):
type: Literal["image"]
source: dict[str, Any]
class ContentBlockToolUse(BaseModel):
class ContentBlockToolUse(_AnthropicBlockBase):
type: Literal["tool_use"]
id: str
name: str
input: dict[str, Any]
class ContentBlockToolResult(BaseModel):
class ContentBlockToolResult(_AnthropicBlockBase):
type: Literal["tool_result"]
tool_use_id: str
content: str | list[Any] | dict[str, Any]
class ContentBlockThinking(BaseModel):
class ContentBlockThinking(_AnthropicBlockBase):
type: Literal["thinking"]
thinking: str
signature: str | None = None
class ContentBlockRedactedThinking(BaseModel):
class ContentBlockRedactedThinking(_AnthropicBlockBase):
type: Literal["redacted_thinking"]
data: str
class SystemContent(BaseModel):
class ContentBlockServerToolUse(_AnthropicBlockBase):
"""Anthropic server-side tool invocation (e.g. ``web_search``, ``web_fetch``)."""
type: Literal["server_tool_use"]
id: str
name: str
input: dict[str, Any]
class ContentBlockWebSearchToolResult(_AnthropicBlockBase):
type: Literal["web_search_tool_result"]
tool_use_id: str
content: Any
class ContentBlockWebFetchToolResult(_AnthropicBlockBase):
type: Literal["web_fetch_tool_result"]
tool_use_id: str
content: Any
class SystemContent(_AnthropicBlockBase):
type: Literal["text"]
text: str
@ -68,12 +95,15 @@ class Message(BaseModel):
| ContentBlockToolResult
| ContentBlockThinking
| ContentBlockRedactedThinking
| ContentBlockServerToolUse
| ContentBlockWebSearchToolResult
| ContentBlockWebFetchToolResult
]
)
reasoning_content: str | None = None
class Tool(BaseModel):
class Tool(_AnthropicBlockBase):
name: str
# Anthropic server tools (e.g. web_search beta tools) include a ``type`` and
# may omit ``input_schema`` because the provider owns the schema.
@ -92,7 +122,12 @@ class ThinkingConfig(BaseModel):
# Request Models
# =============================================================================
class MessagesRequest(BaseModel):
model_config = ConfigDict(extra="allow")
model: str
# Internal routing / debug: accepted on parse but not serialized to providers.
original_model: str | None = Field(default=None, exclude=True)
resolved_provider_model: str | None = Field(default=None, exclude=True)
max_tokens: int | None = None
messages: list[Message]
system: str | list[SystemContent] | None = None
@ -105,13 +140,24 @@ class MessagesRequest(BaseModel):
tools: list[Tool] | None = None
tool_choice: dict[str, Any] | None = None
thinking: ThinkingConfig | None = None
# Native Anthropic / SDK client hints: ignored (not forwarded) for OpenAI Chat conversion.
context_management: dict[str, Any] | None = None
output_config: dict[str, Any] | None = None
mcp_servers: list[dict[str, Any]] | None = None
extra_body: dict[str, Any] | None = None
class TokenCountRequest(BaseModel):
model_config = ConfigDict(extra="allow")
model: str
original_model: str | None = Field(default=None, exclude=True)
resolved_provider_model: str | None = Field(default=None, exclude=True)
messages: list[Message]
system: str | list[SystemContent] | None = None
tools: list[Tool] | None = None
thinking: ThinkingConfig | None = None
tool_choice: dict[str, Any] | None = None
context_management: dict[str, Any] | None = None
output_config: dict[str, Any] | None = None
mcp_servers: list[dict[str, Any]] | None = None
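A small sketch of the intended round-trip for these models (field names from the diff above; values illustrative): unknown provider fields pass through via `extra="allow"`, while the excluded routing fields never reach the serialized payload.

```python
req = MessagesRequest(
    model="claude-sonnet",                      # illustrative model id
    messages=[{"role": "user", "content": "hi"}],
    original_model="claude-sonnet-latest",      # parsed, but excluded on dump
    service_tier="standard_only",               # unknown field: kept via extra="allow"
)
payload = req.model_dump(exclude_none=True)
assert "original_model" not in payload          # Field(exclude=True)
assert payload["service_tier"] == "standard_only"
```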


@ -23,15 +23,31 @@ _SHUTDOWN_TIMEOUT_S = 5.0
async def best_effort(
name: str, awaitable: Any, timeout_s: float = _SHUTDOWN_TIMEOUT_S
name: str,
awaitable: Any,
timeout_s: float = _SHUTDOWN_TIMEOUT_S,
*,
log_verbose_errors: bool = False,
) -> None:
"""Run a shutdown step with timeout; never raise to callers."""
try:
await asyncio.wait_for(awaitable, timeout=timeout_s)
except TimeoutError:
logger.warning(f"Shutdown step timed out: {name} ({timeout_s}s)")
logger.warning("Shutdown step timed out: {} ({}s)", name, timeout_s)
except Exception as e:
logger.warning(f"Shutdown step failed: {name}: {type(e).__name__}: {e}")
if log_verbose_errors:
logger.warning(
"Shutdown step failed: {}: {}: {}",
name,
type(e).__name__,
e,
)
else:
logger.warning(
"Shutdown step failed: {}: exc_type={}",
name,
type(e).__name__,
)
def warn_if_process_auth_token(settings: Settings) -> None:
@ -73,20 +89,37 @@ class AppRuntime:
self._publish_state()
async def shutdown(self) -> None:
verbose = self.settings.log_api_error_tracebacks
if self.message_handler is not None:
try:
self.message_handler.session_store.flush_pending_save()
except Exception as e:
logger.warning(f"Session store flush on shutdown: {e}")
if verbose:
logger.warning("Session store flush on shutdown: {}", e)
else:
logger.warning(
"Session store flush on shutdown: exc_type={}",
type(e).__name__,
)
logger.info("Shutdown requested, cleaning up...")
if self.messaging_platform:
await best_effort("messaging_platform.stop", self.messaging_platform.stop())
await best_effort(
"messaging_platform.stop",
self.messaging_platform.stop(),
log_verbose_errors=verbose,
)
if self.cli_manager:
await best_effort("cli_manager.stop_all", self.cli_manager.stop_all())
await best_effort(
"cli_manager.stop_all",
self.cli_manager.stop_all(),
log_verbose_errors=verbose,
)
if self._provider_registry is not None:
await best_effort(
"provider_registry.cleanup", self._provider_registry.cleanup()
"provider_registry.cleanup",
self._provider_registry.cleanup(),
log_verbose_errors=verbose,
)
await self._shutdown_limiter()
logger.info("Server shut down cleanly")
@ -110,6 +143,10 @@ class AppRuntime:
whisper_device=self.settings.whisper_device,
hf_token=self.settings.hf_token,
nvidia_nim_api_key=self.settings.nvidia_nim_api_key,
messaging_rate_limit=self.settings.messaging_rate_limit,
messaging_rate_window=self.settings.messaging_rate_window,
log_raw_messaging_content=self.settings.log_raw_messaging_content,
log_api_error_tracebacks=self.settings.log_api_error_tracebacks,
),
)
@ -117,12 +154,24 @@ class AppRuntime:
await self._start_message_handler()
except ImportError as e:
logger.warning(f"Messaging module import error: {e}")
if self.settings.log_api_error_tracebacks:
logger.warning("Messaging module import error: {}", e)
else:
logger.warning(
"Messaging module import error: exc_type={}",
type(e).__name__,
)
except Exception as e:
logger.error(f"Failed to start messaging platform: {e}")
if self.settings.log_api_error_tracebacks:
logger.error("Failed to start messaging platform: {}", e)
import traceback
logger.error(traceback.format_exc())
else:
logger.error(
"Failed to start messaging platform: exc_type={}",
type(e).__name__,
)
async def _start_message_handler(self) -> None:
from cli.manager import CLISessionManager
@ -151,10 +200,13 @@ class AppRuntime:
allowed_dirs=allowed_dirs,
plans_directory=plans_directory,
claude_bin=self.settings.claude_cli_bin,
log_raw_cli_diagnostics=self.settings.log_raw_cli_diagnostics,
log_messaging_error_details=self.settings.log_messaging_error_details,
)
session_store = SessionStore(
storage_path=os.path.join(data_path, "sessions.json")
storage_path=os.path.join(data_path, "sessions.json"),
message_log_cap=self.settings.max_message_log_entries_per_chat,
)
platform = self.messaging_platform
assert platform is not None
@ -162,6 +214,11 @@ class AppRuntime:
platform=platform,
cli_manager=self.cli_manager,
session_store=session_store,
debug_platform_edits=self.settings.debug_platform_edits,
debug_subagent_stack=self.settings.debug_subagent_stack,
log_raw_messaging_content=self.settings.log_raw_messaging_content,
log_raw_cli_diagnostics=self.settings.log_raw_cli_diagnostics,
log_messaging_error_details=self.settings.log_messaging_error_details,
)
self._restore_tree_state(session_store)
@ -201,18 +258,26 @@ class AppRuntime:
self.app.state.cli_manager = self.cli_manager
async def _shutdown_limiter(self) -> None:
verbose = self.settings.log_api_error_tracebacks
try:
from messaging.limiter import MessagingRateLimiter
except Exception as e:
if verbose:
logger.debug(
"Rate limiter shutdown skipped (import failed): {}: {}",
type(e).__name__,
e,
)
else:
logger.debug(
"Rate limiter shutdown skipped (import failed): exc_type={}",
type(e).__name__,
)
return
await best_effort(
"MessagingRateLimiter.shutdown_instance",
MessagingRateLimiter.shutdown_instance(),
timeout_s=2.0,
log_verbose_errors=verbose,
)
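A behavior sketch for the `best_effort` helper defined at the top of this file (assuming the file is `api/runtime.py`, as `api.app`'s `from .runtime import AppRuntime` suggests): both exceptions and timeouts are logged and swallowed, never re-raised to callers.

```python
import asyncio

from api.runtime import best_effort


async def flaky() -> None:
    raise RuntimeError("boom")


async def demo() -> None:
    await best_effort("flaky step", flaky())  # logs a warning, does not raise
    # Timeout is likewise swallowed after a warning.
    await best_effort("slow step", asyncio.sleep(60), timeout_s=0.1)


asyncio.run(demo())
```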


@ -4,7 +4,7 @@ from __future__ import annotations
import traceback
import uuid
from collections.abc import Callable
from collections.abc import AsyncIterator, Callable
from typing import Any
from fastapi import HTTPException
@ -13,6 +13,7 @@ from loguru import logger
from config.settings import Settings
from core.anthropic import get_token_count, get_user_facing_error_message
from core.anthropic.sse import ANTHROPIC_SSE_RESPONSE_HEADERS
from providers.base import BaseProvider
from providers.exceptions import InvalidRequestError, ProviderError
@ -20,15 +21,67 @@ from .model_router import ModelRouter
from .models.anthropic import MessagesRequest, TokenCountRequest
from .models.responses import TokenCountResponse
from .optimization_handlers import try_optimizations
from .web_server_tools import (
from .web_tools.egress import WebFetchEgressPolicy
from .web_tools.request import (
is_web_server_tool_request,
stream_web_server_tool_response,
openai_chat_upstream_server_tool_error,
)
from .web_tools.streaming import stream_web_server_tool_response
TokenCounter = Callable[[list[Any], str | list[Any] | None, list[Any] | None], int]
ProviderGetter = Callable[[str], BaseProvider]
# Providers that use ``/chat/completions`` + Anthropic-to-OpenAI conversion (not native Messages).
_OPENAI_CHAT_UPSTREAM_IDS = frozenset({"nvidia_nim", "deepseek"})
def anthropic_sse_streaming_response(
body: AsyncIterator[str],
) -> StreamingResponse:
"""Return a :class:`StreamingResponse` for Anthropic-style SSE streams."""
return StreamingResponse(
body,
media_type="text/event-stream",
headers=ANTHROPIC_SSE_RESPONSE_HEADERS,
)
def _http_status_for_unexpected_service_exception(_exc: BaseException) -> int:
"""HTTP status for uncaught non-provider failures (stable client contract)."""
return 500
def _log_unexpected_service_exception(
settings: Settings,
exc: BaseException,
*,
context: str,
request_id: str | None = None,
) -> None:
"""Log service-layer failures without echoing exception text unless opted in."""
if settings.log_api_error_tracebacks:
if request_id is not None:
logger.error("{} request_id={}: {}", context, request_id, exc)
else:
logger.error("{}: {}", context, exc)
logger.error(traceback.format_exc())
return
if request_id is not None:
logger.error(
"{} request_id={} exc_type={}",
context,
request_id,
type(exc).__name__,
)
else:
logger.error("{} exc_type={}", context, type(exc).__name__)
def _require_non_empty_messages(messages: list[Any]) -> None:
if not messages:
raise InvalidRequestError("messages cannot be empty")
class ClaudeProxyService:
"""Coordinate request optimization, model routing, token count, and providers."""
@ -48,25 +101,35 @@ class ClaudeProxyService:
def create_message(self, request_data: MessagesRequest) -> object:
"""Create a message response or streaming response."""
try:
if not request_data.messages:
raise InvalidRequestError("messages cannot be empty")
_require_non_empty_messages(request_data.messages)
routed = self._model_router.resolve_messages_request(request_data)
if is_web_server_tool_request(routed.request):
if routed.resolved.provider_id in _OPENAI_CHAT_UPSTREAM_IDS:
tool_err = openai_chat_upstream_server_tool_error(
routed.request,
web_tools_enabled=self._settings.enable_web_server_tools,
)
if tool_err is not None:
raise InvalidRequestError(tool_err)
if self._settings.enable_web_server_tools and is_web_server_tool_request(
routed.request
):
input_tokens = self._token_counter(
routed.request.messages, routed.request.system, routed.request.tools
)
logger.info("Optimization: Handling Anthropic web server tool")
return StreamingResponse(
egress = WebFetchEgressPolicy(
allow_private_network_targets=self._settings.web_fetch_allow_private_networks,
allowed_schemes=self._settings.web_fetch_allowed_scheme_set(),
)
return anthropic_sse_streaming_response(
stream_web_server_tool_response(
routed.request, input_tokens=input_tokens
routed.request,
input_tokens=input_tokens,
web_fetch_egress=egress,
verbose_client_errors=self._settings.log_api_error_tracebacks,
),
media_type="text/event-stream",
headers={
"X-Accel-Buffering": "no",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
)
optimized = try_optimizations(routed.request, self._settings)
@ -75,6 +138,10 @@ class ClaudeProxyService:
logger.debug("No optimization matched, routing to provider")
provider = self._provider_getter(routed.resolved.provider_id)
provider.preflight_stream(
routed.request,
thinking_enabled=routed.resolved.thinking_enabled,
)
request_id = f"req_{uuid.uuid4().hex[:12]}"
logger.info(
@ -83,6 +150,7 @@ class ClaudeProxyService:
routed.request.model,
len(routed.request.messages),
)
if self._settings.log_raw_api_payloads:
logger.debug(
"FULL_PAYLOAD [{}]: {}", request_id, routed.request.model_dump()
)
@ -90,27 +158,23 @@ class ClaudeProxyService:
input_tokens = self._token_counter(
routed.request.messages, routed.request.system, routed.request.tools
)
return StreamingResponse(
return anthropic_sse_streaming_response(
provider.stream_response(
routed.request,
input_tokens=input_tokens,
request_id=request_id,
thinking_enabled=routed.resolved.thinking_enabled,
),
media_type="text/event-stream",
headers={
"X-Accel-Buffering": "no",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
)
except ProviderError:
raise
except Exception as e:
logger.error(f"Error: {e!s}\n{traceback.format_exc()}")
_log_unexpected_service_exception(
self._settings, e, context="CREATE_MESSAGE_ERROR"
)
raise HTTPException(
status_code=getattr(e, "status_code", 500),
status_code=_http_status_for_unexpected_service_exception(e),
detail=get_user_facing_error_message(e),
) from e
@ -119,6 +183,7 @@ class ClaudeProxyService:
request_id = f"req_{uuid.uuid4().hex[:12]}"
with logger.contextualize(request_id=request_id):
try:
_require_non_empty_messages(request_data.messages)
routed = self._model_router.resolve_token_count_request(request_data)
tokens = self._token_counter(
routed.request.messages, routed.request.system, routed.request.tools
@ -131,13 +196,16 @@ class ClaudeProxyService:
tokens,
)
return TokenCountResponse(input_tokens=tokens)
except ProviderError:
raise
except Exception as e:
logger.error(
"COUNT_TOKENS_ERROR: request_id={} error={}\n{}",
request_id,
get_user_facing_error_message(e),
traceback.format_exc(),
_log_unexpected_service_exception(
self._settings,
e,
context="COUNT_TOKENS_ERROR",
request_id=request_id,
)
raise HTTPException(
status_code=500, detail=get_user_facing_error_message(e)
status_code=_http_status_for_unexpected_service_exception(e),
detail=get_user_facing_error_message(e),
) from e
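For reference, a hedged sketch of the header consolidation: `anthropic_sse_streaming_response` replaces the inline header trio each `StreamingResponse` previously carried (assuming `ANTHROPIC_SSE_RESPONSE_HEADERS` bundles `X-Accel-Buffering: no`, `Cache-Control: no-cache`, and `Connection: keep-alive`, as the removed lines suggest; the module path `api.service` is assumed from context).

```python
from api.service import anthropic_sse_streaming_response


async def demo_stream():
    # Any Anthropic-shaped SSE body works here.
    yield 'event: message_stop\ndata: {"type": "message_stop"}\n\n'


response = anthropic_sse_streaming_response(demo_stream())
assert response.media_type == "text/event-stream"
```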

api/validation_log.py (new file, 48 lines)

@ -0,0 +1,48 @@
"""Safe metadata summaries for HTTP 422 validation logging (no raw text content)."""
from __future__ import annotations
from typing import Any
def summarize_request_validation_body(
body: Any,
) -> tuple[list[dict[str, Any]], list[str]]:
"""Return message shape summary and tool name list for debug logs."""
messages = body.get("messages") if isinstance(body, dict) else None
message_summary: list[dict[str, Any]] = []
if isinstance(messages, list):
for msg in messages:
if not isinstance(msg, dict):
message_summary.append({"message_kind": type(msg).__name__})
continue
content = msg.get("content")
item: dict[str, Any] = {
"role": msg.get("role"),
"content_kind": type(content).__name__,
}
if isinstance(content, list):
item["block_types"] = [
block.get("type", "dict")
if isinstance(block, dict)
else type(block).__name__
for block in content[:12]
]
item["block_keys"] = [
sorted(str(key) for key in block)[:12]
for block in content[:5]
if isinstance(block, dict)
]
elif isinstance(content, str):
item["content_length"] = len(content)
message_summary.append(item)
tool_names: list[str] = []
if isinstance(body, dict) and isinstance(body.get("tools"), list):
tool_names = [
str(tool.get("name", ""))
for tool in body["tools"]
if isinstance(tool, dict)
]
return message_summary, tool_names
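A usage example with a hypothetical malformed body, showing that only shape metadata survives into logs:

```python
from api.validation_log import summarize_request_validation_body

body = {
    "messages": [
        {"role": "user", "content": [{"type": "text", "text": "secret prompt"}]},
        {"role": "assistant", "content": "ok"},
    ],
    "tools": [{"name": "web_search"}],
}
summary, tool_names = summarize_request_validation_body(body)
# summary == [
#     {"role": "user", "content_kind": "list",
#      "block_types": ["text"], "block_keys": [["text", "type"]]},
#     {"role": "assistant", "content_kind": "str", "content_length": 2},
# ]
# tool_names == ["web_search"]  -- the raw "secret prompt" text never appears
```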


@ -1,331 +1,22 @@
"""Local handlers for Anthropic web server tools.
OpenAI-compatible upstreams can emit regular function calls, but Anthropic's
web tools are server-side: the API response itself must include the tool result.
"""
"""Compatibility re-exports for :mod:`api.web_tools` (web_search / web_fetch)."""
from __future__ import annotations
import html
import json
import re
import uuid
from collections.abc import AsyncIterator
from datetime import UTC, datetime
from html.parser import HTMLParser
from typing import Any
from urllib.parse import parse_qs, unquote, urlparse
import httpx
from .models.anthropic import MessagesRequest
_REQUEST_TIMEOUT_S = 20.0
_MAX_SEARCH_RESULTS = 10
_MAX_FETCH_CHARS = 24_000
class _SearchResultParser(HTMLParser):
def __init__(self) -> None:
super().__init__()
self.results: list[dict[str, str]] = []
self._href: str | None = None
self._title_parts: list[str] = []
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
if tag != "a":
return
href = dict(attrs).get("href")
if not href or "uddg=" not in href:
return
parsed = urlparse(href)
query = parse_qs(parsed.query)
uddg = query.get("uddg", [""])[0]
if not uddg:
return
self._href = unquote(uddg)
self._title_parts = []
def handle_data(self, data: str) -> None:
if self._href is not None:
self._title_parts.append(data)
def handle_endtag(self, tag: str) -> None:
if tag != "a" or self._href is None:
return
title = " ".join("".join(self._title_parts).split())
if title and not any(result["url"] == self._href for result in self.results):
self.results.append({"title": html.unescape(title), "url": self._href})
self._href = None
self._title_parts = []
class _HTMLTextParser(HTMLParser):
def __init__(self) -> None:
super().__init__()
self.title = ""
self.text_parts: list[str] = []
self._in_title = False
self._skip_depth = 0
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
if tag in {"script", "style", "noscript"}:
self._skip_depth += 1
elif tag == "title":
self._in_title = True
def handle_endtag(self, tag: str) -> None:
if tag in {"script", "style", "noscript"} and self._skip_depth:
self._skip_depth -= 1
elif tag == "title":
self._in_title = False
def handle_data(self, data: str) -> None:
text = " ".join(data.split())
if not text:
return
if self._in_title:
self.title = f"{self.title} {text}".strip()
elif not self._skip_depth:
self.text_parts.append(text)
def _format_event(event_type: str, data: dict[str, Any]) -> str:
return f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
def _content_text(content: Any) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for item in content:
if isinstance(item, dict):
parts.append(str(item.get("text", "")))
else:
parts.append(str(getattr(item, "text", "")))
return "\n".join(part for part in parts if part)
return str(content)
def _request_text(request: MessagesRequest) -> str:
return "\n".join(_content_text(message.content) for message in request.messages)
def _web_tool_name(request: MessagesRequest) -> str | None:
for tool in request.tools or []:
name = tool.name
tool_type = tool.type or ""
if name in {"web_search", "web_fetch"} or tool_type.startswith("web_"):
return name
return None
def is_web_server_tool_request(request: MessagesRequest) -> bool:
return _web_tool_name(request) in {"web_search", "web_fetch"}
def _extract_query(text: str) -> str:
match = re.search(r"query:\s*(.+)", text, flags=re.IGNORECASE | re.DOTALL)
if match:
return match.group(1).strip().strip("\"'")
return text.strip()
def _extract_url(text: str) -> str:
match = re.search(r"https?://\S+", text)
return match.group(0).rstrip(").,]") if match else text.strip()
async def _run_web_search(query: str) -> list[dict[str, str]]:
async with httpx.AsyncClient(
timeout=_REQUEST_TIMEOUT_S,
follow_redirects=True,
headers={"User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0"},
) as client:
response = await client.get(
"https://lite.duckduckgo.com/lite/",
params={"q": query},
from api.web_tools.egress import (
WebFetchEgressPolicy,
WebFetchEgressViolation,
enforce_web_fetch_egress,
)
response.raise_for_status()
from api.web_tools.request import is_web_server_tool_request
from api.web_tools.streaming import stream_web_server_tool_response
parser = _SearchResultParser()
parser.feed(response.text)
return parser.results[:_MAX_SEARCH_RESULTS]
async def _run_web_fetch(url: str) -> dict[str, str]:
async with httpx.AsyncClient(
timeout=_REQUEST_TIMEOUT_S,
follow_redirects=True,
headers={"User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0"},
) as client:
response = await client.get(url)
response.raise_for_status()
content_type = response.headers.get("content-type", "text/plain")
title = url
data = response.text
if "html" in content_type.lower():
parser = _HTMLTextParser()
parser.feed(response.text)
title = parser.title or url
data = "\n".join(parser.text_parts)
return {
"url": str(response.url),
"title": title,
"media_type": "text/plain",
"data": data[:_MAX_FETCH_CHARS],
}
def _search_summary(query: str, results: list[dict[str, str]]) -> str:
if not results:
return f"No web search results found for: {query}"
lines = [f"Search results for: {query}"]
for index, result in enumerate(results, start=1):
lines.append(f"{index}. {result['title']}\n{result['url']}")
return "\n\n".join(lines)
async def stream_web_server_tool_response(
request: MessagesRequest, input_tokens: int
) -> AsyncIterator[str]:
tool_name = _web_tool_name(request)
if tool_name is None:
return
text = _request_text(request)
message_id = f"msg_{uuid.uuid4()}"
tool_id = f"srvtoolu_{uuid.uuid4().hex}"
output_tokens = 1
usage_key = (
"web_search_requests" if tool_name == "web_search" else "web_fetch_requests"
)
tool_input = (
{"query": _extract_query(text)}
if tool_name == "web_search"
else {"url": _extract_url(text)}
)
yield _format_event(
"message_start",
{
"type": "message_start",
"message": {
"id": message_id,
"type": "message",
"role": "assistant",
"content": [],
"model": request.model,
"stop_reason": None,
"stop_sequence": None,
"usage": {"input_tokens": input_tokens, "output_tokens": 1},
},
},
)
yield _format_event(
"content_block_start",
{
"type": "content_block_start",
"index": 0,
"content_block": {
"type": "server_tool_use",
"id": tool_id,
"name": tool_name,
"input": tool_input,
},
},
)
yield _format_event(
"content_block_stop", {"type": "content_block_stop", "index": 0}
)
try:
if tool_name == "web_search":
query = str(tool_input["query"])
results = await _run_web_search(query)
result_content: Any = [
{
"type": "web_search_result",
"title": result["title"],
"url": result["url"],
}
for result in results
__all__ = [
"WebFetchEgressPolicy",
"WebFetchEgressViolation",
"enforce_web_fetch_egress",
"httpx",
"is_web_server_tool_request",
"stream_web_server_tool_response",
]
summary = _search_summary(query, results)
result_block_type = "web_search_tool_result"
else:
fetched = await _run_web_fetch(str(tool_input["url"]))
result_content = {
"type": "web_fetch_result",
"url": fetched["url"],
"content": {
"type": "document",
"source": {
"type": "text",
"media_type": fetched["media_type"],
"data": fetched["data"],
},
"title": fetched["title"],
"citations": {"enabled": True},
},
"retrieved_at": datetime.now(UTC).isoformat(),
}
summary = fetched["data"][:_MAX_FETCH_CHARS]
result_block_type = "web_fetch_tool_result"
except Exception as error:
result_block_type = (
"web_search_tool_result"
if tool_name == "web_search"
else "web_fetch_tool_result"
)
error_type = (
"web_search_tool_result_error"
if tool_name == "web_search"
else "web_fetch_tool_error"
)
result_content = {"type": error_type, "error_code": "unavailable"}
summary = f"{tool_name} failed: {type(error).__name__}"
output_tokens = max(1, len(summary) // 4)
yield _format_event(
"content_block_start",
{
"type": "content_block_start",
"index": 1,
"content_block": {
"type": result_block_type,
"tool_use_id": tool_id,
"content": result_content,
},
},
)
yield _format_event(
"content_block_stop", {"type": "content_block_stop", "index": 1}
)
yield _format_event(
"content_block_start",
{
"type": "content_block_start",
"index": 2,
"content_block": {"type": "text", "text": summary},
},
)
yield _format_event(
"content_block_stop", {"type": "content_block_stop", "index": 2}
)
yield _format_event(
"message_delta",
{
"type": "message_delta",
"delta": {"stop_reason": "end_turn", "stop_sequence": None},
"usage": {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"server_tool_use": {usage_key: 1},
},
},
)
yield _format_event("message_stop", {"type": "message_stop"})

api/web_tools/__init__.py (new file, 17 lines)

@ -0,0 +1,17 @@
"""Submodules for Anthropic web server tool handling (search/fetch, egress, streaming)."""
from .egress import (
WebFetchEgressPolicy,
WebFetchEgressViolation,
enforce_web_fetch_egress,
)
from .request import is_web_server_tool_request
from .streaming import stream_web_server_tool_response
__all__ = [
"WebFetchEgressPolicy",
"WebFetchEgressViolation",
"enforce_web_fetch_egress",
"is_web_server_tool_request",
"stream_web_server_tool_response",
]

api/web_tools/constants.py (new file, 15 lines)

@ -0,0 +1,15 @@
"""Limits and defaults for outbound web server tool HTTP."""
_REQUEST_TIMEOUT_S = 20.0
_MAX_SEARCH_RESULTS = 10
_MAX_FETCH_CHARS = 24_000
# Hard cap on raw bytes read from HTTP responses before decode / HTML parse (memory bound).
_MAX_WEB_FETCH_RESPONSE_BYTES = 2 * 1024 * 1024
# Drain at most this many bytes from redirect responses before following Location.
_REDIRECT_RESPONSE_BODY_CAP_BYTES = 65_536
_MAX_WEB_FETCH_REDIRECTS = 10
_WEB_FETCH_REDIRECT_STATUSES = frozenset({301, 302, 303, 307, 308})
_WEB_TOOL_HTTP_HEADERS = {
"User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0",
}

api/web_tools/egress.py (new file, 99 lines)

@ -0,0 +1,99 @@
"""Egress policy for user-controlled web_fetch URLs (SSRF guard)."""
from __future__ import annotations
import ipaddress
import socket
from dataclasses import dataclass
from urllib.parse import urlparse
@dataclass(frozen=True, slots=True)
class WebFetchEgressPolicy:
"""Egress rules for user-influenced web_fetch URLs."""
allow_private_network_targets: bool
allowed_schemes: frozenset[str]
class WebFetchEgressViolation(ValueError):
"""Raised when a web_fetch URL is rejected by egress policy (SSRF guard)."""
def _port_for_url(parsed) -> int:
if parsed.port is not None:
return parsed.port
return 443 if (parsed.scheme or "").lower() == "https" else 80
def _stream_getaddrinfo_or_raise(host: str, port: int) -> list[tuple]:
try:
return socket.getaddrinfo(
host, port, type=socket.SOCK_STREAM, proto=socket.IPPROTO_TCP
)
except OSError as exc:
raise WebFetchEgressViolation(
f"Could not resolve host {host!r}: {exc}"
) from exc
def get_validated_stream_addrinfos_for_egress(
url: str, policy: WebFetchEgressPolicy
) -> list[tuple]:
"""Resolve and validate a URL for web_fetch, returning getaddrinfo rows for pinning.
Each HTTP connect pins to only these `getaddrinfo` results so a malicious DNS
server cannot rebind to a disallowed address between resolution and the TCP
connect (used by :func:`api.web_tools.outbound._run_web_fetch`).
"""
parsed = urlparse(url)
scheme = (parsed.scheme or "").lower()
if scheme not in policy.allowed_schemes:
raise WebFetchEgressViolation(
f"URL scheme {scheme!r} is not allowed for web_fetch"
)
host = parsed.hostname
if host is None or host == "":
raise WebFetchEgressViolation("web_fetch URL must include a host")
port = _port_for_url(parsed)
if policy.allow_private_network_targets:
return _stream_getaddrinfo_or_raise(host, port)
host_lower = host.lower()
if host_lower == "localhost" or host_lower.endswith(".localhost"):
raise WebFetchEgressViolation("localhost targets are not allowed for web_fetch")
if host_lower.endswith(".local"):
raise WebFetchEgressViolation(".local hostnames are not allowed for web_fetch")
try:
parsed_ip = ipaddress.ip_address(host)
except ValueError:
parsed_ip = None
if parsed_ip is not None:
if not parsed_ip.is_global:
raise WebFetchEgressViolation(
f"Non-public IP host {host!r} is not allowed for web_fetch"
)
return _stream_getaddrinfo_or_raise(host, port)
infos = _stream_getaddrinfo_or_raise(host, port)
for *_, sockaddr in infos:
addr = sockaddr[0]
try:
resolved = ipaddress.ip_address(addr)
except ValueError:
continue
if not resolved.is_global:
raise WebFetchEgressViolation(
f"Host {host!r} resolves to a non-public address ({resolved})"
)
return infos
def enforce_web_fetch_egress(url: str, policy: WebFetchEgressPolicy) -> None:
"""Validate ``url`` (scheme, host, and resolved addresses) for web_fetch."""
get_validated_stream_addrinfos_for_egress(url, policy)
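Usage sketch: validate before fetching, and treat `WebFetchEgressViolation` (a `ValueError` subclass) as the rejection signal.

```python
from api.web_tools.egress import (
    WebFetchEgressPolicy,
    WebFetchEgressViolation,
    enforce_web_fetch_egress,
)

policy = WebFetchEgressPolicy(
    allow_private_network_targets=False,
    allowed_schemes=frozenset({"http", "https"}),
)
try:
    # Link-local metadata endpoint: not a global address, so it is rejected.
    enforce_web_fetch_egress("http://169.254.169.254/latest/meta-data/", policy)
except WebFetchEgressViolation as exc:
    print(f"blocked: {exc}")
```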

api/web_tools/outbound.py (new file, 278 lines)

@ -0,0 +1,278 @@
"""Outbound HTTP for web_search / web_fetch (client, body caps, logging)."""
from __future__ import annotations
import asyncio
import socket
from collections.abc import AsyncIterator
from urllib.parse import urljoin, urlparse
import aiohttp
import httpx
from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aiohttp.abc import AbstractResolver, ResolveResult
from loguru import logger
from . import constants
from .constants import (
_MAX_FETCH_CHARS,
_MAX_SEARCH_RESULTS,
_REDIRECT_RESPONSE_BODY_CAP_BYTES,
_REQUEST_TIMEOUT_S,
_WEB_FETCH_REDIRECT_STATUSES,
_WEB_TOOL_HTTP_HEADERS,
)
from .egress import (
WebFetchEgressPolicy,
WebFetchEgressViolation,
get_validated_stream_addrinfos_for_egress,
)
from .parsers import HTMLTextParser, SearchResultParser
def _safe_public_host_for_logs(url: str) -> str:
host = urlparse(url).hostname or ""
return host[:253]
def _log_web_tool_failure(
tool_name: str,
error: BaseException,
*,
fetch_url: str | None = None,
) -> None:
exc_type = type(error).__name__
if isinstance(error, WebFetchEgressViolation):
host = _safe_public_host_for_logs(fetch_url) if fetch_url else ""
logger.warning(
"web_tool_egress_rejected tool={} exc_type={} host={!r}",
tool_name,
exc_type,
host,
)
return
if tool_name == "web_fetch" and fetch_url:
logger.warning(
"web_tool_failure tool={} exc_type={} host={!r}",
tool_name,
exc_type,
_safe_public_host_for_logs(fetch_url),
)
else:
logger.warning("web_tool_failure tool={} exc_type={}", tool_name, exc_type)
def _web_tool_client_error_summary(
tool_name: str,
error: BaseException,
*,
verbose: bool,
) -> str:
if verbose:
return f"{tool_name} failed: {type(error).__name__}"
return "Web tool request failed."
async def _iter_response_body_under_cap(
response: httpx.Response, max_bytes: int
) -> AsyncIterator[bytes]:
if max_bytes <= 0:
return
received = 0
async for chunk in response.aiter_bytes(chunk_size=65_536):
if received >= max_bytes:
break
remaining = max_bytes - received
if len(chunk) <= remaining:
received += len(chunk)
yield chunk
if received >= max_bytes:
break
else:
yield chunk[:remaining]
break
async def _drain_response_body_capped(response: httpx.Response, max_bytes: int) -> None:
async for _ in _iter_response_body_under_cap(response, max_bytes):
pass
async def _read_response_body_capped(response: httpx.Response, max_bytes: int) -> bytes:
return b"".join(
[piece async for piece in _iter_response_body_under_cap(response, max_bytes)]
)
_NUMERIC_RESOLVE_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV
_NAME_RESOLVE_FLAGS = socket.NI_NUMERICHOST | socket.NI_NUMERICSERV
def getaddrinfo_rows_to_resolve_results(
host: str, addrinfos: list[tuple]
) -> list[ResolveResult]:
"""Map :func:`socket.getaddrinfo` rows to aiohttp :class:`ResolveResult` (ThreadedResolver logic)."""
out: list[ResolveResult] = []
for family, _type, proto, _canon, sockaddr in addrinfos:
if family == socket.AF_INET6:
if len(sockaddr) < 3:
continue
if sockaddr[3]:
resolved_host, port = socket.getnameinfo(sockaddr, _NAME_RESOLVE_FLAGS)
else:
resolved_host, port = sockaddr[:2]
else:
assert family == socket.AF_INET, family
resolved_host, port = sockaddr[0], sockaddr[1]
resolved_host = str(resolved_host)
port = int(port)
out.append(
ResolveResult(
hostname=host,
host=resolved_host,
port=int(port),
family=family,
proto=proto,
flags=_NUMERIC_RESOLVE_FLAGS,
)
)
return out
class _PinnedEgressStaticResolver(AbstractResolver):
"""Return only pre-validated :class:`ResolveResult` for the outbound request."""
def __init__(self, results: list[ResolveResult]) -> None:
self._results = results
async def resolve(
self, host: str, port: int = 0, family: int = socket.AF_INET
) -> list[ResolveResult]:
return self._results
async def close(self) -> None: # pragma: no cover - aiohttp contract
return
async def _read_aiohttp_body_capped(
response: aiohttp.ClientResponse, max_bytes: int
) -> bytes:
received = 0
parts: list[bytes] = []
async for chunk in response.content.iter_chunked(65_536):
if received >= max_bytes:
break
remaining = max_bytes - received
if len(chunk) <= remaining:
received += len(chunk)
parts.append(chunk)
else:
parts.append(chunk[:remaining])
break
return b"".join(parts)
async def _drain_aiohttp_body_capped(
response: aiohttp.ClientResponse, max_bytes: int
) -> None:
if max_bytes <= 0:
return
received = 0
async for chunk in response.content.iter_chunked(65_536):
received += len(chunk)
if received >= max_bytes:
break
async def _run_web_search(query: str) -> list[dict[str, str]]:
async with (
httpx.AsyncClient(
timeout=_REQUEST_TIMEOUT_S,
follow_redirects=True,
headers=_WEB_TOOL_HTTP_HEADERS,
) as client,
client.stream(
"GET",
"https://lite.duckduckgo.com/lite/",
params={"q": query},
) as response,
):
response.raise_for_status()
body_bytes = await _read_response_body_capped(
response, constants._MAX_WEB_FETCH_RESPONSE_BYTES
)
text = body_bytes.decode("utf-8", errors="replace")
parser = SearchResultParser()
parser.feed(text)
return parser.results[:_MAX_SEARCH_RESULTS]
async def _run_web_fetch(url: str, egress: WebFetchEgressPolicy) -> dict[str, str]:
"""Fetch URL with manual redirects; each hop is DNS-pinned to validated addresses."""
current_url = url
redirect_hops = 0
timeout = ClientTimeout(total=_REQUEST_TIMEOUT_S)
while True:
addr_infos = await asyncio.to_thread(
get_validated_stream_addrinfos_for_egress, current_url, egress
)
host = urlparse(current_url).hostname or ""
results = getaddrinfo_rows_to_resolve_results(host, addr_infos)
resolver = _PinnedEgressStaticResolver(results)
connector = TCPConnector(
resolver=resolver,
force_close=True,
)
try:
async with (
ClientSession(
timeout=timeout,
headers=_WEB_TOOL_HTTP_HEADERS,
connector=connector,
) as session,
session.get(current_url, allow_redirects=False) as response,
):
if response.status in _WEB_FETCH_REDIRECT_STATUSES:
await _drain_aiohttp_body_capped(
response, _REDIRECT_RESPONSE_BODY_CAP_BYTES
)
if redirect_hops >= constants._MAX_WEB_FETCH_REDIRECTS:
raise WebFetchEgressViolation(
"web_fetch exceeded maximum redirects "
f"({constants._MAX_WEB_FETCH_REDIRECTS})"
)
location = response.headers.get("location")
if not location or not location.strip():
raise WebFetchEgressViolation(
"web_fetch redirect response missing Location header"
)
current_url = urljoin(str(response.url), location.strip())
redirect_hops += 1
continue
response.raise_for_status()
content_type = response.headers.get("content-type", "text/plain")
final_url = str(response.url)
encoding = response.get_encoding() or "utf-8"
body_bytes = await _read_aiohttp_body_capped(
response, constants._MAX_WEB_FETCH_RESPONSE_BYTES
)
finally:
await connector.close()
break
text = body_bytes.decode(encoding, errors="replace")
title = final_url
data = text
if "html" in content_type.lower():
parser = HTMLTextParser()
parser.feed(text)
title = parser.title or final_url
data = "\n".join(parser.text_parts)
return {
"url": final_url,
"title": title,
"media_type": "text/plain",
"data": data[:_MAX_FETCH_CHARS],
}
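End-to-end usage sketch (these are private helpers, shown for illustration only): each redirect hop re-validates and re-pins DNS, so a rebinding resolver cannot swap in a private address between validation and connect.

```python
import asyncio

from api.web_tools.egress import WebFetchEgressPolicy
from api.web_tools.outbound import _run_web_fetch


async def main() -> None:
    policy = WebFetchEgressPolicy(
        allow_private_network_targets=False,
        allowed_schemes=frozenset({"http", "https"}),
    )
    page = await _run_web_fetch("https://example.com/", policy)
    print(page["title"], page["url"], len(page["data"]))


asyncio.run(main())
```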

api/web_tools/parsers.py (new file, 104 lines)

@ -0,0 +1,104 @@
"""HTML parsing for web_search / web_fetch."""
from __future__ import annotations
import html
import re
from html.parser import HTMLParser
from typing import Any
from urllib.parse import parse_qs, unquote, urlparse
class SearchResultParser(HTMLParser):
"""DuckDuckGo lite HTML: extract result links and titles."""
def __init__(self) -> None:
super().__init__()
self.results: list[dict[str, str]] = []
self._href: str | None = None
self._title_parts: list[str] = []
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
if tag != "a":
return
href = dict(attrs).get("href")
if not href or "uddg=" not in href:
return
parsed = urlparse(href)
query = parse_qs(parsed.query)
uddg = query.get("uddg", [""])[0]
if not uddg:
return
self._href = unquote(uddg)
self._title_parts = []
def handle_data(self, data: str) -> None:
if self._href is not None:
self._title_parts.append(data)
def handle_endtag(self, tag: str) -> None:
if tag != "a" or self._href is None:
return
title = " ".join("".join(self._title_parts).split())
if title and not any(result["url"] == self._href for result in self.results):
self.results.append({"title": html.unescape(title), "url": self._href})
self._href = None
self._title_parts = []
class HTMLTextParser(HTMLParser):
"""Strip scripts/styles and collect visible text + title for fetch previews."""
def __init__(self) -> None:
super().__init__()
self.title = ""
self.text_parts: list[str] = []
self._in_title = False
self._skip_depth = 0
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
if tag in {"script", "style", "noscript"}:
self._skip_depth += 1
elif tag == "title":
self._in_title = True
def handle_endtag(self, tag: str) -> None:
if tag in {"script", "style", "noscript"} and self._skip_depth:
self._skip_depth -= 1
elif tag == "title":
self._in_title = False
def handle_data(self, data: str) -> None:
text = " ".join(data.split())
if not text:
return
if self._in_title:
self.title = f"{self.title} {text}".strip()
elif not self._skip_depth:
self.text_parts.append(text)
def content_text(content: Any) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for item in content:
if isinstance(item, dict):
parts.append(str(item.get("text", "")))
else:
parts.append(str(getattr(item, "text", "")))
return "\n".join(part for part in parts if part)
return str(content)
def extract_query(text: str) -> str:
match = re.search(r"query:\s*(.+)", text, flags=re.IGNORECASE | re.DOTALL)
if match:
return match.group(1).strip().strip("\"'")
return text.strip()
def extract_url(text: str) -> str:
match = re.search(r"https?://\S+", text)
return match.group(0).rstrip(").,]") if match else text.strip()
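Example input for `SearchResultParser` (the markup is illustrative, modeled on DuckDuckGo lite's `uddg=` redirect links, not a captured page):

```python
from api.web_tools.parsers import SearchResultParser

sample = (
    '<a href="//duckduckgo.com/l/?uddg=https%3A%2F%2Fexample.com%2F&rut=x">'
    "Example Domain</a>"
)
parser = SearchResultParser()
parser.feed(sample)
assert parser.results == [
    {"title": "Example Domain", "url": "https://example.com/"}
]
```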

api/web_tools/request.py (new file, 87 lines)

@ -0,0 +1,87 @@
"""Detect forced Anthropic web server tool requests."""
from __future__ import annotations
from api.models.anthropic import MessagesRequest, Tool
def request_text(request: MessagesRequest) -> str:
"""Join all user/assistant message content into one string for tool input parsing."""
from .parsers import content_text
return "\n".join(content_text(message.content) for message in request.messages)
def forced_tool_turn_text(request: MessagesRequest) -> str:
"""Text for parsing forced server-tool inputs: latest user turn only (avoids stale history)."""
if not request.messages:
return ""
from .parsers import content_text
for message in reversed(request.messages):
if message.role == "user":
return content_text(message.content)
return ""
def forced_server_tool_name(request: MessagesRequest) -> str | None:
"""Return web_search or web_fetch only when tool_choice forces that server tool."""
tc = request.tool_choice
if not isinstance(tc, dict):
return None
if tc.get("type") != "tool":
return None
name = tc.get("name")
if name in {"web_search", "web_fetch"}:
return str(name)
return None
def has_tool_named(request: MessagesRequest, name: str) -> bool:
return any(tool.name == name for tool in request.tools or [])
def is_web_server_tool_request(request: MessagesRequest) -> bool:
"""True when the client forces a web server tool via tool_choice (not merely listed)."""
forced = forced_server_tool_name(request)
if forced is None:
return False
return has_tool_named(request, forced)
def is_anthropic_server_tool_definition(tool: Tool) -> bool:
"""Whether ``tool`` refers to an Anthropic server tool (web_search / web_fetch family)."""
name = (tool.name or "").strip()
if name in ("web_search", "web_fetch"):
return True
typ = tool.type
if isinstance(typ, str):
return typ.startswith("web_search") or typ.startswith("web_fetch")
return False
def has_listed_anthropic_server_tools(request: MessagesRequest) -> bool:
"""True when tools include web_search / web_fetch-style entries (listed, forced or not)."""
return any(is_anthropic_server_tool_definition(t) for t in (request.tools or []))
def openai_chat_upstream_server_tool_error(
request: MessagesRequest, *, web_tools_enabled: bool
) -> str | None:
"""Return a user-facing error when OpenAI Chat upstream cannot satisfy server-tool semantics."""
forced = forced_server_tool_name(request)
if forced and not web_tools_enabled:
return (
f"tool_choice forces Anthropic server tool {forced!r}, but local web server tools are "
"disabled (ENABLE_WEB_SERVER_TOOLS=false). Enable them or use a native Anthropic "
"Messages transport (e.g. open_router, ollama, lmstudio)."
)
if not forced and has_listed_anthropic_server_tools(request):
return (
"OpenAI Chat upstreams (NVIDIA NIM, DeepSeek) cannot use listed Anthropic server tools "
"(web_search / web_fetch) without the local web server tool handler. Use a native "
"Anthropic transport, set ENABLE_WEB_SERVER_TOOLS=true and force the tool with "
"tool_choice, or remove these tools from the request."
)
return None
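Detection sketch for the rules above: `tool_choice` must force the server tool; merely listing it keeps the request on the normal provider path. (Assumes `input_schema` is optional for server tools, per the `Tool` model comment in `api/models/anthropic.py`.)

```python
from api.models.anthropic import MessagesRequest
from api.web_tools.request import is_web_server_tool_request

req = MessagesRequest(
    model="claude-sonnet",
    messages=[{"role": "user", "content": "url: https://example.com"}],
    tools=[{"name": "web_fetch"}],                      # listed server tool
    tool_choice={"type": "tool", "name": "web_fetch"},  # ...and forced
)
assert is_web_server_tool_request(req)

req.tool_choice = {"type": "auto"}
assert not is_web_server_tool_request(req)              # listed but not forced
```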

api/web_tools/streaming.py (new file, 206 lines)

@ -0,0 +1,206 @@
"""SSE streaming for local web_search / web_fetch server tool results."""
from __future__ import annotations
import uuid
from collections.abc import AsyncIterator
from datetime import UTC, datetime
from typing import Any
from api.models.anthropic import MessagesRequest
from core.anthropic.server_tool_sse import (
SERVER_TOOL_USE,
WEB_FETCH_TOOL_ERROR,
WEB_FETCH_TOOL_RESULT,
WEB_SEARCH_TOOL_RESULT,
WEB_SEARCH_TOOL_RESULT_ERROR,
)
from core.anthropic.sse import format_sse_event
from . import outbound
from .constants import _MAX_FETCH_CHARS
from .egress import WebFetchEgressPolicy
from .parsers import extract_query, extract_url
from .request import (
forced_server_tool_name,
forced_tool_turn_text,
has_tool_named,
)
def _search_summary(query: str, results: list[dict[str, str]]) -> str:
if not results:
return f"No web search results found for: {query}"
lines = [f"Search results for: {query}"]
for index, result in enumerate(results, start=1):
lines.append(f"{index}. {result['title']}\n{result['url']}")
return "\n\n".join(lines)
async def stream_web_server_tool_response(
request: MessagesRequest,
input_tokens: int,
*,
web_fetch_egress: WebFetchEgressPolicy,
verbose_client_errors: bool = False,
) -> AsyncIterator[str]:
"""Stream a minimal Anthropic-shaped turn for forced `web_search` / `web_fetch` (local fallback).
When `ENABLE_WEB_SERVER_TOOLS` is on, this is a proxy-side execution path, not a full
hosted Anthropic citation or encrypted-content pipeline.
"""
tool_name = forced_server_tool_name(request)
if tool_name is None or not has_tool_named(request, tool_name):
return
text = forced_tool_turn_text(request)
message_id = f"msg_{uuid.uuid4()}"
tool_id = f"srvtoolu_{uuid.uuid4().hex}"
usage_key = (
"web_search_requests" if tool_name == "web_search" else "web_fetch_requests"
)
tool_input = (
{"query": extract_query(text)}
if tool_name == "web_search"
else {"url": extract_url(text)}
)
_result_block_for_tool = {
"web_search": WEB_SEARCH_TOOL_RESULT,
"web_fetch": WEB_FETCH_TOOL_RESULT,
}
_error_payload_type_for_tool = {
"web_search": WEB_SEARCH_TOOL_RESULT_ERROR,
"web_fetch": WEB_FETCH_TOOL_ERROR,
}
yield format_sse_event(
"message_start",
{
"type": "message_start",
"message": {
"id": message_id,
"type": "message",
"role": "assistant",
"content": [],
"model": request.model,
"stop_reason": None,
"stop_sequence": None,
"usage": {"input_tokens": input_tokens, "output_tokens": 1},
},
},
)
yield format_sse_event(
"content_block_start",
{
"type": "content_block_start",
"index": 0,
"content_block": {
"type": SERVER_TOOL_USE,
"id": tool_id,
"name": tool_name,
"input": tool_input,
},
},
)
yield format_sse_event(
"content_block_stop", {"type": "content_block_stop", "index": 0}
)
try:
if tool_name == "web_search":
query = str(tool_input["query"])
results = await outbound._run_web_search(query)
result_content: Any = [
{
"type": "web_search_result",
"title": result["title"],
"url": result["url"],
}
for result in results
]
summary = _search_summary(query, results)
result_block_type = WEB_SEARCH_TOOL_RESULT
else:
fetched = await outbound._run_web_fetch(
str(tool_input["url"]), web_fetch_egress
)
result_content = {
"type": "web_fetch_result",
"url": fetched["url"],
"content": {
"type": "document",
"source": {
"type": "text",
"media_type": fetched["media_type"],
"data": fetched["data"],
},
"title": fetched["title"],
"citations": {"enabled": True},
},
"retrieved_at": datetime.now(UTC).isoformat(),
}
summary = fetched["data"][:_MAX_FETCH_CHARS]
result_block_type = WEB_FETCH_TOOL_RESULT
except Exception as error:
fetch_url = str(tool_input["url"]) if tool_name == "web_fetch" else None
outbound._log_web_tool_failure(tool_name, error, fetch_url=fetch_url)
result_block_type = _result_block_for_tool[tool_name]
result_content = {
"type": _error_payload_type_for_tool[tool_name],
"error_code": "unavailable",
}
summary = outbound._web_tool_client_error_summary(
tool_name, error, verbose=verbose_client_errors
)
output_tokens = max(1, len(summary) // 4)
yield format_sse_event(
"content_block_start",
{
"type": "content_block_start",
"index": 1,
"content_block": {
"type": result_block_type,
"tool_use_id": tool_id,
"content": result_content,
},
},
)
yield format_sse_event(
"content_block_stop", {"type": "content_block_stop", "index": 1}
)
# Model-facing summary: stream as normal text deltas (CLI/transcript code reads `text_delta`,
# not eager `text` on `content_block_start`).
yield format_sse_event(
"content_block_start",
{
"type": "content_block_start",
"index": 2,
"content_block": {"type": "text", "text": ""},
},
)
yield format_sse_event(
"content_block_delta",
{
"type": "content_block_delta",
"index": 2,
"delta": {"type": "text_delta", "text": summary},
},
)
yield format_sse_event(
"content_block_stop", {"type": "content_block_stop", "index": 2}
)
yield format_sse_event(
"message_delta",
{
"type": "message_delta",
"delta": {"stop_reason": "end_turn", "stop_sequence": None},
"usage": {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"server_tool_use": {usage_key: 1},
},
},
)
yield format_sse_event("message_stop", {"type": "message_stop"})


@ -30,7 +30,8 @@ def serve() -> None:
settings = get_settings()
try:
uvicorn.run(
"api.app:app",
"api.app:create_app",
factory=True,
host=settings.host,
port=settings.port,
log_level="debug",


@ -29,6 +29,9 @@ class CLISessionManager:
allowed_dirs: list[str] | None = None,
plans_directory: str | None = None,
claude_bin: str = "claude",
*,
log_raw_cli_diagnostics: bool = False,
log_messaging_error_details: bool = False,
):
"""
Initialize the session manager.
@ -44,6 +47,8 @@ class CLISessionManager:
self.allowed_dirs = allowed_dirs or []
self.plans_directory = plans_directory
self.claude_bin = claude_bin
self._log_raw_cli_diagnostics = log_raw_cli_diagnostics
self._log_messaging_error_details = log_messaging_error_details
self._sessions: dict[str, CLISession] = {}
self._pending_sessions: dict[str, CLISession] = {}
@ -79,6 +84,7 @@ class CLISessionManager:
allowed_dirs=self.allowed_dirs,
plans_directory=self.plans_directory,
claude_bin=self.claude_bin,
log_raw_cli_diagnostics=self._log_raw_cli_diagnostics,
)
self._pending_sessions[temp_id] = new_session
logger.info(f"Created new session: {temp_id}")
@ -130,7 +136,17 @@ class CLISessionManager:
try:
await session.stop()
except Exception as e:
logger.error(f"Error stopping session: {e}")
if self._log_messaging_error_details:
logger.error(
"Error stopping session: {}: {}",
type(e).__name__,
e,
)
else:
logger.error(
"Error stopping session: exc_type={}",
type(e).__name__,
)
self._sessions.clear()
self._pending_sessions.clear()


@ -11,6 +11,9 @@ from loguru import logger
from .process_registry import register_pid, unregister_pid
# Cap stderr capture so a runaway child cannot exhaust memory; pipe is still drained.
_MAX_STDERR_CAPTURE_BYTES = 256 * 1024
@dataclass(frozen=True, slots=True)
class ClaudeCliConfig:
@ -33,6 +36,8 @@ class CLISession:
allowed_dirs: list[str] | None = None,
plans_directory: str | None = None,
claude_bin: str = "claude",
*,
log_raw_cli_diagnostics: bool = False,
):
self.config = ClaudeCliConfig(
workspace_path=os.path.normpath(os.path.abspath(workspace_path)),
@ -46,11 +51,40 @@ class CLISession:
self.allowed_dirs = self.config.allowed_dirs
self.plans_directory = self.config.plans_directory
self.claude_bin = self.config.claude_bin
self._log_raw_cli_diagnostics = log_raw_cli_diagnostics
self.process: asyncio.subprocess.Process | None = None
self.current_session_id: str | None = None
self._is_busy = False
self._cli_lock = asyncio.Lock()
@staticmethod
async def _drain_stderr_bounded(
process: asyncio.subprocess.Process,
*,
max_bytes: int = _MAX_STDERR_CAPTURE_BYTES,
) -> bytes:
"""Read stderr concurrently with stdout to avoid subprocess pipe deadlocks.
Retains at most ``max_bytes`` for logging; any excess is discarded, but
the pipe is read until EOF so a noisy child cannot fill the buffer and
block forever.
"""
if not process.stderr:
return b""
parts: list[bytes] = []
received = 0
while True:
chunk = await process.stderr.read(65_536)
if not chunk:
break
if received < max_bytes:
take = min(len(chunk), max_bytes - received)
if take:
parts.append(chunk[:take])
received += take
# If already at cap, keep reading and discarding until EOF.
return b"".join(parts)
@property
def is_busy(self) -> bool:
"""Check if a task is currently running."""
@ -140,6 +174,11 @@ class CLISession:
session_id_extracted = False
buffer = bytearray()
stderr_task: asyncio.Task[bytes] | None = None
if self.process.stderr:
stderr_task = asyncio.create_task(
self._drain_stderr_bounded(self.process)
)
try:
while True:
@ -179,21 +218,25 @@ class CLISession:
except asyncio.CancelledError:
# Cancelling the handler task should not leave a Claude CLI
# subprocess running in the background.
try:
await asyncio.shield(self.stop())
finally:
raise
finally:
stderr_bytes = b""
if stderr_task is not None:
stderr_bytes = await stderr_task
stderr_text = None
if self.process.stderr:
stderr_output = await self.process.stderr.read()
if stderr_output:
stderr_text = stderr_output.decode(
"utf-8", errors="replace"
).strip()
logger.error(f"Claude CLI Stderr: {stderr_text}")
# Yield stderr as error event so it shows in UI
if stderr_bytes:
stderr_text = stderr_bytes.decode("utf-8", errors="replace").strip()
if stderr_text:
if self._log_raw_cli_diagnostics:
logger.error("Claude CLI stderr: {}", stderr_text)
else:
logger.error(
"Claude CLI stderr: bytes={} text_chars={}",
len(stderr_bytes),
len(stderr_text),
)
logger.info("CLI_SESSION: Yielding error event from stderr")
yield {"type": "error", "error": {"message": stderr_text}}
@ -230,7 +273,10 @@ class CLISession:
yield event
except json.JSONDecodeError:
logger.debug(f"Non-JSON output: {line_str}")
if self._log_raw_cli_diagnostics:
logger.debug("Non-JSON output: {}", line_str)
else:
logger.debug("Non-JSON CLI line: char_len={}", len(line_str))
yield {"type": "raw", "content": line_str}
def _extract_session_id(self, event: Any) -> str | None:
@ -273,6 +319,16 @@ class CLISession:
unregister_pid(self.process.pid)
return True
except Exception as e:
logger.error(f"Error stopping process: {e}")
if self._log_raw_cli_diagnostics:
logger.error(
"Error stopping process: {}: {}",
type(e).__name__,
e,
)
else:
logger.error(
"Error stopping process: exc_type={}",
type(e).__name__,
)
return False
return False
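Illustrative sketch (not part of the committed diff): the bounded-drain pattern above generalized to any pipe, with the same 256 KiB cap and 64 KiB chunk size. The shell command is an arbitrary stand-in and assumes a POSIX shell.

import asyncio

_CAP = 256 * 1024  # mirrors _MAX_STDERR_CAPTURE_BYTES

async def drain_bounded(reader: asyncio.StreamReader, cap: int = _CAP) -> bytes:
    # Keep at most `cap` bytes, but always read to EOF so the pipe never fills.
    parts: list[bytes] = []
    kept = 0
    while chunk := await reader.read(65_536):
        take = min(len(chunk), cap - kept)
        if take > 0:
            parts.append(chunk[:take])
            kept += take
    return b"".join(parts)

async def main() -> None:
    proc = await asyncio.create_subprocess_shell(
        "echo out; echo err 1>&2",
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    stderr_task = asyncio.create_task(drain_bounded(proc.stderr))  # start before stdout reads
    stdout = await proc.stdout.read()  # consumed concurrently with the stderr drain
    stderr = await stderr_task
    await proc.wait()
    print(stdout, stderr)  # b'out\n' b'err\n'

asyncio.run(main())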

config/constants.py Normal file (+10)

@ -0,0 +1,10 @@
"""Shared defaults used by config models and provider adapters."""
# HTTP client connect timeout (seconds). Keep aligned with README.md and .env.example.
HTTP_CONNECT_TIMEOUT_DEFAULT = 10.0
# Anthropic Messages API default when the client omits max_tokens.
ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS = 81920
# Max bytes read from a non-200 native messages response when verbose error logging is on.
NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES = 4096


@ -8,6 +8,7 @@ included at top level for easy grep/filter.
import json
import logging
import re
from pathlib import Path
from loguru import logger
@ -17,6 +18,22 @@ _configured = False
# Context keys we promote to top-level JSON for traceability
_CONTEXT_KEYS = ("request_id", "node_id", "chat_id")
_TELEGRAM_BOT_RE = re.compile(
r"(https?://api\.telegram\.org/)bot([0-9]+:[A-Za-z0-9_-]+)(/?)",
re.IGNORECASE,
)
# Authorization: Bearer <token> (HTTP client / proxy debug lines)
_AUTH_BEARER_RE = re.compile(
r"(\bAuthorization\s*:\s*Bearer\s+)([^\s'\"]+)",
re.IGNORECASE,
)
def _redact_sensitive_substrings(message: str) -> str:
"""Remove obvious API tokens and secrets before JSON log line emission."""
text = _TELEGRAM_BOT_RE.sub(r"\1bot<redacted>\3", message)
return _AUTH_BEARER_RE.sub(r"\1<redacted>", text)
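Illustrative sketch (not part of the committed diff): what the scrubber does to a typical noisy HTTP debug line. The bot token and bearer value are fake, shaped only to match the patterns above.

line = (
    "POST https://api.telegram.org/bot12345:AAE-fake_token/sendMessage "
    "Authorization: Bearer sk-fake-value"
)
print(_redact_sensitive_substrings(line))
# POST https://api.telegram.org/bot<redacted>/sendMessage Authorization: Bearer <redacted>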
def _serialize_with_context(record) -> str:
"""Format record as JSON with context vars at top level.
@ -26,7 +43,7 @@ def _serialize_with_context(record) -> str:
out = {
"time": str(record["time"]),
"level": record["level"].name,
"message": record["message"],
"message": _redact_sensitive_substrings(str(record["message"])),
"module": record["name"],
"function": record["function"],
"line": record["line"],
@ -57,11 +74,16 @@ class InterceptHandler(logging.Handler):
)
def configure_logging(log_file: str, *, force: bool = False) -> None:
def configure_logging(
log_file: str, *, force: bool = False, verbose_third_party: bool = False
) -> None:
"""Configure loguru with JSON output to log_file and intercept stdlib logging.
Idempotent: skips if already configured (e.g. hot reload).
Use force=True to reconfigure (e.g. in tests with a different log path).
When ``verbose_third_party`` is false, noisy HTTP and Telegram loggers are capped
at WARNING unless explicitly configured otherwise.
"""
global _configured
if _configured and not force:
@ -88,3 +110,16 @@ def configure_logging(log_file: str, *, force: bool = False) -> None:
intercept = InterceptHandler()
logging.root.handlers = [intercept]
logging.root.setLevel(logging.DEBUG)
third_party = (
"httpx",
"httpcore",
"httpcore.http11",
"httpcore.connection",
"telegram",
"telegram.ext",
)
for name in third_party:
logging.getLogger(name).setLevel(
logging.WARNING if not verbose_third_party else logging.NOTSET
)


@ -2,6 +2,8 @@
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
class NimSettings(BaseModel):
"""Fixed NVIDIA NIM settings (not configurable via env)."""
@ -14,7 +16,9 @@ class NimSettings(BaseModel):
)
top_k: int = -1
max_tokens: int = Field(
81920, ge=1, description="Maximum number of tokens in output."
ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
ge=1,
description="Maximum number of tokens in output.",
)
presence_penalty: float = Field(0.0, ge=-2.0, le=2.0)
frequency_penalty: float = Field(0.0, ge=-2.0, le=2.0)
@ -68,7 +72,7 @@ class NimSettings(BaseModel):
return field_defaults.get(key, 1.0)
try:
val = float(v)
except Exception as err:
except (TypeError, ValueError) as err:
raise ValueError(
f"{info.field_name} must be a float. Got {type(v).__name__}."
) from err
@ -78,15 +82,15 @@ class NimSettings(BaseModel):
@classmethod
def validate_int_fields(cls, v, info: ValidationInfo):
field_defaults = {
"max_tokens": 81920,
"max_tokens": ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
"min_tokens": 0,
}
if v is None or v == "":
key = info.field_name or "max_tokens"
return field_defaults.get(key, 81920)
return field_defaults.get(key, ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS)
try:
val = int(v)
except Exception as err:
except (TypeError, ValueError) as err:
raise ValueError(
f"{info.field_name} must be an int. Got {type(v).__name__}."
) from err
@ -99,7 +103,7 @@ class NimSettings(BaseModel):
return None
try:
return int(v)
except Exception as err:
except (TypeError, ValueError) as err:
raise ValueError(
f"{info.field_name} must be an int or empty/None."
) from err

config/provider_catalog.py Normal file (+108)

@ -0,0 +1,108 @@
"""Neutral provider catalog: IDs, credentials, defaults, proxy and capability metadata.
Adapter factories live in :mod:`providers.registry`; this module stays free of
provider implementation imports (see contract tests).
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal
TransportType = Literal["openai_chat", "anthropic_messages"]
# Default upstream base URLs (also re-exported via :mod:`providers.defaults`)
NVIDIA_NIM_DEFAULT_BASE = "https://integrate.api.nvidia.com/v1"
DEEPSEEK_DEFAULT_BASE = "https://api.deepseek.com"
OPENROUTER_DEFAULT_BASE = "https://openrouter.ai/api/v1"
LMSTUDIO_DEFAULT_BASE = "http://localhost:1234/v1"
LLAMACPP_DEFAULT_BASE = "http://localhost:8080/v1"
OLLAMA_DEFAULT_BASE = "http://localhost:11434"
@dataclass(frozen=True, slots=True)
class ProviderDescriptor:
"""Metadata for building :class:`~providers.base.ProviderConfig` and factory wiring."""
provider_id: str
transport_type: TransportType
capabilities: tuple[str, ...]
credential_env: str | None = None
credential_url: str | None = None
credential_attr: str | None = None
static_credential: str | None = None
default_base_url: str | None = None
base_url_attr: str | None = None
proxy_attr: str | None = None
PROVIDER_CATALOG: dict[str, ProviderDescriptor] = {
"nvidia_nim": ProviderDescriptor(
provider_id="nvidia_nim",
transport_type="openai_chat",
credential_env="NVIDIA_NIM_API_KEY",
credential_url="https://build.nvidia.com/settings/api-keys",
credential_attr="nvidia_nim_api_key",
default_base_url=NVIDIA_NIM_DEFAULT_BASE,
proxy_attr="nvidia_nim_proxy",
capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"),
),
"open_router": ProviderDescriptor(
provider_id="open_router",
transport_type="anthropic_messages",
credential_env="OPENROUTER_API_KEY",
credential_url="https://openrouter.ai/keys",
credential_attr="open_router_api_key",
default_base_url=OPENROUTER_DEFAULT_BASE,
proxy_attr="open_router_proxy",
capabilities=("chat", "streaming", "tools", "thinking", "native_anthropic"),
),
"deepseek": ProviderDescriptor(
provider_id="deepseek",
transport_type="openai_chat",
credential_env="DEEPSEEK_API_KEY",
credential_url="https://platform.deepseek.com/api_keys",
credential_attr="deepseek_api_key",
default_base_url=DEEPSEEK_DEFAULT_BASE,
capabilities=("chat", "streaming", "thinking"),
),
"lmstudio": ProviderDescriptor(
provider_id="lmstudio",
transport_type="anthropic_messages",
static_credential="lm-studio",
default_base_url=LMSTUDIO_DEFAULT_BASE,
base_url_attr="lm_studio_base_url",
proxy_attr="lmstudio_proxy",
capabilities=("chat", "streaming", "tools", "native_anthropic", "local"),
),
"llamacpp": ProviderDescriptor(
provider_id="llamacpp",
transport_type="anthropic_messages",
static_credential="llamacpp",
default_base_url=LLAMACPP_DEFAULT_BASE,
base_url_attr="llamacpp_base_url",
proxy_attr="llamacpp_proxy",
capabilities=("chat", "streaming", "tools", "native_anthropic", "local"),
),
"ollama": ProviderDescriptor(
provider_id="ollama",
transport_type="anthropic_messages",
static_credential="ollama",
default_base_url=OLLAMA_DEFAULT_BASE,
base_url_attr="ollama_base_url",
capabilities=(
"chat",
"streaming",
"tools",
"thinking",
"native_anthropic",
"local",
),
),
}
# Order matches docs / historical error text; must match PROVIDER_CATALOG keys.
SUPPORTED_PROVIDER_IDS: tuple[str, ...] = tuple(PROVIDER_CATALOG.keys())
if len(set(SUPPORTED_PROVIDER_IDS)) != len(SUPPORTED_PROVIDER_IDS):
raise AssertionError("Duplicate provider ids in PROVIDER_CATALOG key order")
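Illustrative sketch (not part of the committed diff): how factory wiring is likely to consume a descriptor. resolve_credential is a hypothetical helper; the resolution order (static credential first, then a Settings attribute named by credential_attr) is an assumption read off the fields above.

from typing import Any

desc = PROVIDER_CATALOG["open_router"]
assert desc.transport_type == "anthropic_messages"
assert "native_anthropic" in desc.capabilities

def resolve_credential(desc: ProviderDescriptor, settings: Any) -> str | None:
    # Hypothetical: local transports carry a static token; hosted ones read Settings.
    if desc.static_credential is not None:
        return desc.static_credential
    if desc.credential_attr is not None:
        return getattr(settings, desc.credential_attr, None)
    return None

assert resolve_credential(PROVIDER_CATALOG["ollama"], settings=None) == "ollama"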


@ -1,18 +1,7 @@
"""Canonical list of model provider type prefixes (provider_id values).
`providers.registry.PROVIDER_DESCRIPTORS` is the full metadata source; this
module holds the id set for config validation and must stay in sync
(registries assert in `providers.registry`).
"""
"""Canonical provider id tuple (re-exported from the provider catalog)."""
from __future__ import annotations
# Order matches docs / historical error text; must match PROVIDER_DESCRIPTORS keys.
SUPPORTED_PROVIDER_IDS: tuple[str, ...] = (
"nvidia_nim",
"open_router",
"deepseek",
"lmstudio",
"llamacpp",
"ollama",
)
from .provider_catalog import SUPPORTED_PROVIDER_IDS
__all__ = ("SUPPORTED_PROVIDER_IDS",)


@ -10,6 +10,7 @@ from dotenv import dotenv_values
from pydantic import Field, field_validator, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from .constants import HTTP_CONNECT_TIMEOUT_DEFAULT
from .nim import NimSettings
from .provider_ids import SUPPORTED_PROVIDER_IDS
@ -105,6 +106,12 @@ class Settings(BaseSettings):
messaging_platform: str = Field(
default="discord", validation_alias="MESSAGING_PLATFORM"
)
messaging_rate_limit: int = Field(
default=1, validation_alias="MESSAGING_RATE_LIMIT"
)
messaging_rate_window: float = Field(
default=1.0, validation_alias="MESSAGING_RATE_WINDOW"
)
# ==================== NVIDIA NIM Config ====================
nvidia_nim_api_key: str = ""
@ -173,7 +180,8 @@ class Settings(BaseSettings):
default=10.0, validation_alias="HTTP_WRITE_TIMEOUT"
)
http_connect_timeout: float = Field(
default=2.0, validation_alias="HTTP_CONNECT_TIMEOUT"
default=HTTP_CONNECT_TIMEOUT_DEFAULT,
validation_alias="HTTP_CONNECT_TIMEOUT",
)
# ==================== Fast Prefix Detection ====================
@ -185,6 +193,51 @@ class Settings(BaseSettings):
enable_suggestion_mode_skip: bool = True
enable_filepath_extraction_mock: bool = True
# ==================== Local web server tools (web_search / web_fetch) ====================
# Off by default: these tools perform outbound HTTP from the proxy (SSRF risk).
enable_web_server_tools: bool = Field(
default=False, validation_alias="ENABLE_WEB_SERVER_TOOLS"
)
# Comma-separated URL schemes allowed for web_fetch (default: http,https).
web_fetch_allowed_schemes: str = Field(
default="http,https", validation_alias="WEB_FETCH_ALLOWED_SCHEMES"
)
# When true, skip private/loopback/link-local IP blocking for web_fetch (lab only).
web_fetch_allow_private_networks: bool = Field(
default=False, validation_alias="WEB_FETCH_ALLOW_PRIVATE_NETWORKS"
)
# ==================== Debug / diagnostic logging (avoid sensitive content) ====================
# When false (default), API and SSE helpers log only metadata (counts, lengths, ids).
log_raw_api_payloads: bool = Field(
default=False, validation_alias="LOG_RAW_API_PAYLOADS"
)
log_raw_sse_events: bool = Field(
default=False, validation_alias="LOG_RAW_SSE_EVENTS"
)
# When false (default), unhandled exceptions log only type + route metadata (no message/traceback).
log_api_error_tracebacks: bool = Field(
default=False, validation_alias="LOG_API_ERROR_TRACEBACKS"
)
# When false (default), messaging logs omit text/transcription previews (metadata only).
log_raw_messaging_content: bool = Field(
default=False, validation_alias="LOG_RAW_MESSAGING_CONTENT"
)
# When true, log full Claude CLI stderr, non-JSON lines, and parser error text.
log_raw_cli_diagnostics: bool = Field(
default=False, validation_alias="LOG_RAW_CLI_DIAGNOSTICS"
)
# When true, log exception text / CLI error strings in messaging (may leak user content).
log_messaging_error_details: bool = Field(
default=False, validation_alias="LOG_MESSAGING_ERROR_DETAILS"
)
debug_platform_edits: bool = Field(
default=False, validation_alias="DEBUG_PLATFORM_EDITS"
)
debug_subagent_stack: bool = Field(
default=False, validation_alias="DEBUG_SUBAGENT_STACK"
)
# ==================== NIM Settings ====================
nim: NimSettings = Field(default_factory=NimSettings)
@ -215,6 +268,9 @@ class Settings(BaseSettings):
claude_workspace: str = "./agent_workspace"
allowed_dir: str = ""
claude_cli_bin: str = Field(default="claude", validation_alias="CLAUDE_CLI_BIN")
max_message_log_entries_per_chat: int | None = Field(
default=None, validation_alias="MAX_MESSAGE_LOG_ENTRIES_PER_CHAT"
)
# ==================== Server ====================
host: str = "0.0.0.0"
@ -254,6 +310,13 @@ class Settings(BaseSettings):
return None
return v
@field_validator("max_message_log_entries_per_chat", mode="before")
@classmethod
def parse_optional_log_cap(cls, v: Any) -> Any:
if v == "" or v is None:
return None
return v
@field_validator("whisper_device")
@classmethod
def validate_whisper_device(cls, v: str) -> str:
@ -272,6 +335,33 @@ class Settings(BaseSettings):
)
return v
@field_validator("messaging_rate_limit")
@classmethod
def validate_messaging_rate_limit(cls, v: int) -> int:
if v <= 0:
raise ValueError("messaging_rate_limit must be > 0")
return v
@field_validator("messaging_rate_window")
@classmethod
def validate_messaging_rate_window(cls, v: float) -> float:
if v <= 0:
raise ValueError("messaging_rate_window must be > 0")
return float(v)
@field_validator("web_fetch_allowed_schemes")
@classmethod
def validate_web_fetch_allowed_schemes(cls, v: str) -> str:
schemes = [part.strip().lower() for part in v.split(",") if part.strip()]
if not schemes:
raise ValueError("web_fetch_allowed_schemes must list at least one scheme")
for scheme in schemes:
if not scheme.isascii() or not scheme.isalpha():
raise ValueError(
f"Invalid URL scheme in web_fetch_allowed_schemes: {scheme!r}"
)
return ",".join(schemes)
@field_validator("ollama_base_url")
@classmethod
def validate_ollama_base_url(cls, v: str) -> str:
@ -329,12 +419,12 @@ class Settings(BaseSettings):
@property
def provider_type(self) -> str:
"""Extract provider type from the default model string."""
return self.model.split("/", 1)[0]
return Settings.parse_provider_type(self.model)
@property
def model_name(self) -> str:
"""Extract the actual model name from the default model string."""
return self.model.split("/", 1)[1]
return Settings.parse_model_name(self.model)
def resolve_model(self, claude_model_name: str) -> str:
"""Resolve a Claude model name to the configured provider/model string.
@ -362,6 +452,14 @@ class Settings(BaseSettings):
return self.enable_sonnet_thinking
return self.enable_model_thinking
def web_fetch_allowed_scheme_set(self) -> frozenset[str]:
"""Return normalized schemes allowed for web_fetch."""
return frozenset(
part.strip().lower()
for part in self.web_fetch_allowed_schemes.split(",")
if part.strip()
)
@staticmethod
def parse_provider_type(model_string: str) -> str:
"""Extract provider type from any 'provider/model' string."""


@ -1,9 +1,19 @@
"""Anthropic protocol helpers shared across API, providers, and integrations."""
from .content import extract_text_from_content, get_block_attr, get_block_type
from .conversion import AnthropicToOpenAIConverter, build_base_request_body
from .errors import append_request_id, get_user_facing_error_message
from .sse import ContentBlockManager, SSEBuilder, map_stop_reason
from .conversion import (
AnthropicToOpenAIConverter,
OpenAIConversionError,
build_base_request_body,
)
from .errors import (
append_request_id,
format_user_error_preview,
get_user_facing_error_message,
)
from .native_messages_request import sanitize_native_messages_thinking_policy
from .provider_stream_error import iter_provider_stream_error_sse_events
from .sse import ContentBlockManager, SSEBuilder, format_sse_event, map_stop_reason
from .thinking import ContentChunk, ContentType, ThinkTagParser
from .tokens import get_token_count
from .tools import HeuristicToolParser
@ -15,15 +25,20 @@ __all__ = [
"ContentChunk",
"ContentType",
"HeuristicToolParser",
"OpenAIConversionError",
"SSEBuilder",
"ThinkTagParser",
"append_request_id",
"build_base_request_body",
"extract_text_from_content",
"format_sse_event",
"format_user_error_preview",
"get_block_attr",
"get_block_type",
"get_token_count",
"get_user_facing_error_message",
"iter_provider_stream_error_sse_events",
"map_stop_reason",
"sanitize_native_messages_thinking_policy",
"set_if_not_none",
]


@ -3,10 +3,34 @@
import json
from typing import Any
from pydantic import BaseModel
from .content import get_block_attr, get_block_type
from .utils import set_if_not_none
class OpenAIConversionError(Exception):
"""Raised when Anthropic content cannot be converted to OpenAI chat without data loss."""
def _openai_reject_native_only_top_level_fields(request_data: Any) -> None:
"""OpenAI chat providers may only convert known top-level request fields.
First-class model fields (e.g. ``context_management``) are not forwarded to
the OpenAI API but are allowed so clients do not hit spurious 400s.
Unknown extra keys (``__pydantic_extra__``) are still rejected.
"""
if not isinstance(request_data, BaseModel):
return
extra = getattr(request_data, "__pydantic_extra__", None)
if not extra:
return
raise OpenAIConversionError(
"OpenAI chat conversion does not support these top-level request fields: "
f"{sorted(str(k) for k in extra)}. Use a native Anthropic transport provider."
)
def _tool_name(tool: Any) -> str:
return str(getattr(tool, "name", "") or "")
@ -18,6 +42,27 @@ def _tool_input_schema(tool: Any) -> dict[str, Any]:
return {"type": "object", "properties": {}}
def _serialize_tool_result_content(tool_content: Any) -> str:
"""Serialize tool_result content for OpenAI ``role: tool`` messages (stable JSON for structured values)."""
if tool_content is None:
return ""
if isinstance(tool_content, str):
return tool_content
if isinstance(tool_content, dict):
return json.dumps(tool_content, ensure_ascii=False)
if isinstance(tool_content, list):
parts: list[str] = []
for item in tool_content:
if isinstance(item, dict) and item.get("type") == "text":
parts.append(str(item.get("text", "")))
elif isinstance(item, dict):
parts.append(json.dumps(item, ensure_ascii=False))
else:
parts.append(str(item))
return "\n".join(parts)
return str(tool_content)
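Illustrative sketch (not part of the committed diff): how mixed tool_result content flattens; the values are invented.

blocks = [
    {"type": "text", "text": "2 files changed"},
    {"type": "json", "data": {"ok": True}},  # non-text dict -> stable JSON
    42,                                      # anything else -> str()
]
assert _serialize_tool_result_content(blocks) == (
    '2 files changed\n{"type": "json", "data": {"ok": true}}\n42'
)
assert _serialize_tool_result_content(None) == ""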
class AnthropicToOpenAIConverter:
"""Convert Anthropic message format to OpenAI-compatible format."""
@ -64,20 +109,38 @@ class AnthropicToOpenAIConverter:
content_parts: list[str] = []
thinking_parts: list[str] = []
tool_calls: list[dict[str, Any]] = []
seen_tool_use = False
for block in content:
block_type = get_block_type(block)
if block_type == "text":
if seen_tool_use:
raise OpenAIConversionError(
"OpenAI chat conversion does not support assistant text after "
"tool_use in the same message; split the transcript or use a "
"native Anthropic provider."
)
content_parts.append(get_block_attr(block, "text", ""))
elif block_type == "thinking":
if not include_thinking:
continue
if seen_tool_use:
raise OpenAIConversionError(
"OpenAI chat conversion does not support assistant thinking after "
"tool_use in the same message; split the transcript or use a "
"native Anthropic provider."
)
thinking = get_block_attr(block, "thinking", "")
content_parts.append(f"<think>\n{thinking}\n</think>")
if include_reasoning_content:
thinking_parts.append(thinking)
elif block_type == "redacted_thinking":
# Opaque provider continuation data; do not materialize as model-visible text
# or reasoning_content for OpenAI chat upstreams.
continue
elif block_type == "tool_use":
seen_tool_use = True
tool_input = get_block_attr(block, "input", {})
tool_calls.append(
{
@ -91,6 +154,19 @@ class AnthropicToOpenAIConverter:
},
}
)
elif block_type == "image":
raise OpenAIConversionError(
"Assistant image blocks are not supported for OpenAI chat conversion."
)
elif block_type in (
"server_tool_use",
"web_search_tool_result",
"web_fetch_tool_result",
):
raise OpenAIConversionError(
"OpenAI chat conversion does not support Anthropic server tool blocks "
f"({block_type!r} in an assistant message). Use a native Anthropic transport provider."
)
content_str = "\n\n".join(content_parts)
if not content_str and not tool_calls:
@ -122,21 +198,21 @@ class AnthropicToOpenAIConverter:
if block_type == "text":
text_parts.append(get_block_attr(block, "text", ""))
elif block_type == "image":
raise OpenAIConversionError(
"User message image blocks are not supported for OpenAI chat "
"conversion; use a vision-capable native Anthropic provider or "
"extend the converter."
)
elif block_type == "tool_result":
flush_text()
tool_content = get_block_attr(block, "content", "")
if isinstance(tool_content, list):
tool_content = "\n".join(
item.get("text", str(item))
if isinstance(item, dict)
else str(item)
for item in tool_content
)
serialized = _serialize_tool_result_content(tool_content)
result.append(
{
"role": "tool",
"tool_call_id": get_block_attr(block, "tool_use_id"),
"content": str(tool_content) if tool_content else "",
"content": serialized if serialized else "",
}
)
@ -199,6 +275,7 @@ def build_base_request_body(
include_reasoning_content: bool = False,
) -> dict[str, Any]:
"""Build the common parts of an OpenAI-format request body."""
_openai_reject_native_only_top_level_fields(request_data)
messages = AnthropicToOpenAIConverter.convert_messages(
request_data.messages,
include_thinking=include_thinking,


@ -0,0 +1,97 @@
"""Track content-block state for native Anthropic SSE strings we emit to clients."""
from __future__ import annotations
import uuid
from collections.abc import Iterator
from contextlib import suppress
from typing import Any
from core.anthropic.sse import SSEBuilder, format_sse_event
from core.anthropic.stream_contracts import SSEEvent, event_index, parse_sse_lines
class EmittedNativeSseTracker:
"""Parse emitted SSE frames so mid-stream errors can close blocks and pick a fresh index."""
def __init__(self) -> None:
self._buf = ""
self._open_stack: list[int] = []
self._max_index = -1
self.message_id: str | None = None
self.model: str = ""
def feed(self, chunk: str) -> None:
"""Record SSE frames completed by ``chunk`` (handles splitting across reads)."""
self._buf += chunk
while True:
sep = self._buf.find("\n\n")
if sep < 0:
break
frame = self._buf[:sep]
self._buf = self._buf[sep + 2 :]
if not frame.strip():
continue
for event in parse_sse_lines(frame.splitlines()):
self._observe(event)
def _observe(self, event: SSEEvent) -> None:
if event.event == "message_start":
message = event.data.get("message")
if isinstance(message, dict):
mid = message.get("id")
if isinstance(mid, str) and mid:
self.message_id = mid
model = message.get("model")
if isinstance(model, str) and model:
self.model = model
return
if event.event == "content_block_start":
idx = event_index(event)
self._max_index = max(self._max_index, idx)
self._open_stack.append(idx)
return
if event.event == "content_block_stop":
idx = event_index(event)
if self._open_stack and self._open_stack[-1] == idx:
self._open_stack.pop()
else:
with suppress(ValueError):
self._open_stack.remove(idx)
def next_content_index(self) -> int:
"""Next unused content block index based on emitted starts."""
return self._max_index + 1
def iter_close_unclosed_blocks(self) -> Iterator[str]:
"""Yield ``content_block_stop`` events for blocks that were started but not stopped."""
while self._open_stack:
idx = self._open_stack.pop()
yield format_sse_event(
"content_block_stop",
{"type": "content_block_stop", "index": idx},
)
def iter_midstream_error_tail(
self,
error_message: str,
*,
request: Any,
input_tokens: int,
log_raw_sse_events: bool,
) -> Iterator[str]:
"""Close dangling blocks, emit a text error block at a fresh index, then message tail."""
mid = self.message_id or f"msg_{uuid.uuid4()}"
model = self.model or (getattr(request, "model", "") or "")
sse = SSEBuilder(
mid,
model,
input_tokens,
log_raw_events=log_raw_sse_events,
)
sse.blocks.next_index = self.next_content_index()
yield from sse.emit_error(error_message)
yield sse.message_delta("end_turn", 1)
yield sse.message_stop()
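Illustrative sketch (not part of the committed diff): a minimal feed, assuming parse_sse_lines accepts frames in the same event:/data: shape used throughout this change.

tracker = EmittedNativeSseTracker()
tracker.feed(
    'event: message_start\n'
    'data: {"type": "message_start", "message": {"id": "msg_1", "model": "m"}}\n\n'
    'event: content_block_start\n'
    'data: {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}}\n\n'
)
assert tracker.message_id == "msg_1"
assert tracker.next_content_index() == 1
# Block 0 started but never stopped; the error tail closes it first.
assert list(tracker.iter_close_unclosed_blocks())  # one content_block_stop frame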


@ -9,11 +9,12 @@ def get_user_facing_error_message(
*,
read_timeout_s: float | None = None,
) -> str:
"""Return a readable, non-empty error message for users."""
message = str(e).strip()
if message:
return message
"""Return a readable, non-empty error message for users.
Known transport and OpenAI SDK exception types are mapped to stable wording
before falling back to ``str(e)``, so empty or noisy SDK messages do not skip
the mapped path.
"""
if isinstance(e, httpx.ReadTimeout):
if read_timeout_s is not None:
return f"Provider request timed out after {read_timeout_s:g}s."
@ -25,13 +26,20 @@ def get_user_facing_error_message(
return f"Provider request timed out after {read_timeout_s:g}s."
return "Request timed out."
if isinstance(e, openai.RateLimitError):
return "Provider rate limit reached. Please retry shortly."
if isinstance(e, openai.AuthenticationError):
return "Provider authentication failed. Check API key."
if isinstance(e, openai.BadRequestError):
return "Invalid request sent to provider."
name = type(e).__name__
status_code = getattr(e, "status_code", None)
if isinstance(e, openai.RateLimitError) or name == "RateLimitError":
return "Provider rate limit reached. Please retry shortly."
if isinstance(e, openai.AuthenticationError) or name == "AuthenticationError":
return "Provider authentication failed. Check API key."
if isinstance(e, openai.BadRequestError) or name == "InvalidRequestError":
return "Invalid request sent to provider."
if name == "OverloadedError":
return "Provider is currently overloaded. Please retry."
@ -42,9 +50,18 @@ def get_user_facing_error_message(
if name.endswith("ProviderError") or name == "ProviderError":
return "Provider request failed."
message = str(e).strip()
if message:
return message
return "Provider request failed unexpectedly."
def format_user_error_preview(exc: Exception, *, max_len: int = 200) -> str:
"""Truncate a user-facing error string for short chat replies."""
return get_user_facing_error_message(exc)[:max_len]
def append_request_id(message: str, request_id: str | None) -> str:
"""Append request_id suffix when available."""
base = message.strip() or "Provider request failed unexpectedly."


@ -0,0 +1,260 @@
"""Native Anthropic Messages request body construction (JSON-ready dicts).
Provider adapters supply policy via parameters (defaults, OpenRouter post-steps).
"""
from __future__ import annotations
from collections.abc import Sequence
from typing import Any
from pydantic import BaseModel
_REQUEST_FIELDS = (
"model",
"messages",
"system",
"max_tokens",
"stop_sequences",
"stream",
"temperature",
"top_p",
"top_k",
"metadata",
"tools",
"tool_choice",
"thinking",
"context_management",
"output_config",
"mcp_servers",
"extra_body",
)
# Keys that would override routed canonical request fields if merged from ``extra_body``.
_OPENROUTER_EXTRA_BODY_FORBIDDEN_KEYS = frozenset(
{
"model",
"messages",
"system",
"tools",
"tool_choice",
"stream",
"max_tokens",
"temperature",
"top_p",
"top_k",
"metadata",
"stop_sequences",
"context_management",
"output_config",
"mcp_servers",
}
)
class OpenRouterExtraBodyError(ValueError):
"""``extra_body`` contained reserved keys that would override canonical fields."""
def validate_openrouter_extra_body(extra: Any) -> None:
"""Reject ``extra_body`` keys that must not override routed request fields."""
if not isinstance(extra, dict) or not extra:
return
bad = _OPENROUTER_EXTRA_BODY_FORBIDDEN_KEYS & extra.keys()
if bad:
raise OpenRouterExtraBodyError(
f"extra_body must not override canonical request fields: {sorted(bad)}"
)
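Illustrative sketch (not part of the committed diff): the guard in action. Treating "provider" as a harmless pass-through routing key is an illustrative assumption, not a documented OpenRouter contract.

validate_openrouter_extra_body({"provider": {"order": ["anthropic"]}})  # allowed
validate_openrouter_extra_body(None)  # non-dict input is ignored

try:
    validate_openrouter_extra_body({"model": "x", "stream": False})
except OpenRouterExtraBodyError as err:
    assert "model" in str(err) and "stream" in str(err)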
_INTERNAL_FIELDS = {
"thinking",
"extra_body",
}
def _serialize_value(value: Any) -> Any:
"""Convert Pydantic models and lightweight objects into JSON-ready values."""
if isinstance(value, BaseModel):
return value.model_dump(exclude_none=True)
if isinstance(value, dict):
return {
key: _serialize_value(item)
for key, item in value.items()
if item is not None
}
if isinstance(value, Sequence) and not isinstance(value, str | bytes | bytearray):
return [_serialize_value(item) for item in value]
if value is None or isinstance(value, str | int | float | bool):
return value
if hasattr(value, "__dict__"):
return {
key: _serialize_value(item)
for key, item in vars(value).items()
if not key.startswith("_") and item is not None
}
return value
def _dump_request_fields(request_data: Any) -> dict[str, Any]:
"""Extract the public request fields (OpenRouter-style explicit field list)."""
if isinstance(request_data, BaseModel):
return request_data.model_dump(exclude_none=True)
dumped: dict[str, Any] = {}
for field in _REQUEST_FIELDS:
value = getattr(request_data, field, None)
if value is not None:
dumped[field] = _serialize_value(value)
return dumped
def sanitize_native_messages_thinking_policy(
messages: Any, *, thinking_enabled: bool
) -> Any:
"""Filter assistant message thinking blocks for upstream native Anthropic JSON.
When ``thinking_enabled`` is false, remove ``thinking`` and ``redacted_thinking``
history so disabled policy is not undermined by prior turns.
When true, keep ``redacted_thinking`` and signed ``thinking``; remove only
unsigned plain ``thinking`` blocks (not replayable).
"""
if not isinstance(messages, list):
return messages
sanitized_messages: list[Any] = []
for message in messages:
if not isinstance(message, dict):
sanitized_messages.append(message)
continue
if message.get("role") != "assistant":
sanitized_messages.append(message)
continue
content = message.get("content")
if not isinstance(content, list):
sanitized_messages.append(message)
continue
if not thinking_enabled:
sanitized_content = [
block
for block in content
if not (
isinstance(block, dict)
and block.get("type") in ("thinking", "redacted_thinking")
)
]
else:
sanitized_content = [
block
for block in content
if not (
isinstance(block, dict)
and block.get("type") == "thinking"
and not isinstance(block.get("signature"), str)
)
]
sanitized_message = dict(message)
sanitized_message["content"] = sanitized_content or ""
sanitized_messages.append(sanitized_message)
return sanitized_messages
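Illustrative sketch (not part of the committed diff): the policy on a small history; the signature string is a placeholder.

history = [
    {
        "role": "assistant",
        "content": [
            {"type": "thinking", "thinking": "draft"},                       # unsigned
            {"type": "thinking", "thinking": "kept", "signature": "sig=="},  # signed
            {"type": "redacted_thinking", "data": "opaque"},
            {"type": "text", "text": "answer"},
        ],
    }
]
kept = sanitize_native_messages_thinking_policy(history, thinking_enabled=True)
assert [b["type"] for b in kept[0]["content"]] == ["thinking", "redacted_thinking", "text"]

dropped = sanitize_native_messages_thinking_policy(history, thinking_enabled=False)
assert [b["type"] for b in dropped[0]["content"]] == ["text"]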
def _normalize_system_prompt_for_openrouter(system: Any) -> Any:
"""Flatten Claude SDK system blocks for OpenRouter's native endpoint."""
if not isinstance(system, list):
return system
text_parts: list[str] = []
for block in system:
if not isinstance(block, dict):
continue
if block.get("type") == "text" and isinstance(block.get("text"), str):
text_parts.append(block["text"])
return "\n\n".join(text_parts).strip() if text_parts else system
def _apply_openrouter_reasoning_policy(body: dict[str, Any], thinking_cfg: Any) -> None:
"""Map Anthropic thinking controls onto OpenRouter reasoning controls."""
reasoning = body.setdefault("reasoning", {"enabled": True})
if not isinstance(reasoning, dict):
return
reasoning.setdefault("enabled", True)
if not isinstance(thinking_cfg, dict):
return
budget_tokens = thinking_cfg.get("budget_tokens")
if isinstance(budget_tokens, int):
reasoning.setdefault("max_tokens", budget_tokens)
def build_base_native_anthropic_request_body(
request: Any,
*,
default_max_tokens: int,
thinking_enabled: bool,
) -> dict[str, Any]:
"""Serialize a Pydantic messages request to a generic native Anthropic body."""
body = request.model_dump(exclude_none=True)
body.pop("extra_body", None)
if "thinking" in body:
thinking_cfg = body.pop("thinking")
if thinking_enabled and isinstance(thinking_cfg, dict):
thinking_payload: dict[str, Any] = {"type": "enabled"}
budget_tokens = thinking_cfg.get("budget_tokens")
if isinstance(budget_tokens, int):
thinking_payload["budget_tokens"] = budget_tokens
body["thinking"] = thinking_payload
if "max_tokens" not in body:
body["max_tokens"] = default_max_tokens
if "messages" in body:
body["messages"] = sanitize_native_messages_thinking_policy(
body["messages"],
thinking_enabled=thinking_enabled,
)
return body
def build_openrouter_native_request_body(
request_data: Any,
*,
thinking_enabled: bool,
default_max_tokens: int,
) -> dict[str, Any]:
"""Build an Anthropic-format request body for OpenRouter (policy hooks built-in)."""
dumped_request = _dump_request_fields(request_data)
request_extra = dumped_request.pop("extra_body", None)
thinking_cfg = dumped_request.get("thinking")
body: dict[str, Any] = {
key: value
for key, value in dumped_request.items()
if key not in _INTERNAL_FIELDS
}
if isinstance(request_extra, dict):
validate_openrouter_extra_body(request_extra)
body.update(request_extra)
body["messages"] = sanitize_native_messages_thinking_policy(
body.get("messages"),
thinking_enabled=thinking_enabled,
)
if "system" in body:
body["system"] = _normalize_system_prompt_for_openrouter(body["system"])
body["stream"] = True
if body.get("max_tokens") is None:
body["max_tokens"] = default_max_tokens
if thinking_enabled:
_apply_openrouter_reasoning_policy(body, thinking_cfg)
return body


@ -0,0 +1,313 @@
"""Shared native Anthropic SSE thinking policy, block remapping, and overlap repair.
Used by :class:`OpenRouterProvider` and line-mode
:class:`providers.anthropic_messages.AnthropicMessagesTransport` providers.
"""
from __future__ import annotations
import copy
import json
from dataclasses import dataclass, field
from typing import Any
__all__ = [
"NativeSseBlockPolicyState",
"format_native_sse_event",
"is_terminal_openrouter_done_event",
"parse_native_sse_event",
"transform_native_sse_block_event",
]
@dataclass
class _UpstreamBlockState:
"""Per-upstream content block: segment index and liveness in the model stream."""
block_type: str
down_index: int
open: bool
last_start_block: dict[str, Any] | None = None
@dataclass
class NativeSseBlockPolicyState:
"""Track per-upstream content blocks and remapped Anthropic ``index`` field."""
next_index: int = 0
by_upstream: dict[int, _UpstreamBlockState] = field(default_factory=dict)
dropped_indexes: set[int] = field(default_factory=set)
pending_suppressed_stops: set[int] = field(default_factory=set)
message_stopped: bool = False
def format_native_sse_event(event_name: str | None, data_text: str) -> str:
"""Format an SSE event from its event name and data payload."""
lines: list[str] = []
if event_name:
lines.append(f"event: {event_name}")
lines.extend(f"data: {line}" for line in data_text.splitlines())
return "\n".join(lines) + "\n\n"
def parse_native_sse_event(event: str) -> tuple[str | None, str]:
"""Extract the event name and raw data payload from an SSE event."""
event_name = None
data_lines: list[str] = []
for line in event.strip().splitlines():
if line.startswith("event:"):
event_name = line[6:].strip()
elif line.startswith("data:"):
data_lines.append(line[5:].lstrip())
return event_name, "\n".join(data_lines)
def is_terminal_openrouter_done_event(event_name: str | None, data_text: str) -> bool:
"""Return whether an event is OpenAI-style terminal noise (``[DONE]``)."""
return (event_name is None or event_name in {"data", "done"}) and (
data_text.strip().upper() == "[DONE]"
)
def _delta_type_to_block_kind(delta_type: Any) -> str | None:
"""Map a content_block_delta type to a content block kind (text/thinking/tool_use)."""
if not isinstance(delta_type, str):
return None
if delta_type in {"thinking_delta", "signature_delta"}:
return "thinking"
if delta_type == "text_delta":
return "text"
if delta_type == "input_json_delta":
return "tool_use"
return None
def _synthetic_start_content_block(
block_kind: str,
*,
upstream_index: int,
stored_tool_block: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Build a `content_block` for a `content_block_start` with empty streaming fields."""
if block_kind == "tool_use":
if (
isinstance(stored_tool_block, dict)
and stored_tool_block.get("type") == "tool_use"
):
tool_id = stored_tool_block.get("id")
name = stored_tool_block.get("name")
inp = stored_tool_block.get("input")
return {
"type": "tool_use",
"id": tool_id
if isinstance(tool_id, str) and tool_id
else f"toolu_or_{upstream_index}",
"name": name if isinstance(name, str) else "",
"input": inp if isinstance(inp, dict) else {},
}
return {
"type": "tool_use",
"id": f"toolu_or_{upstream_index}",
"name": "",
"input": {},
}
if block_kind == "thinking":
return {"type": "thinking", "thinking": ""}
if block_kind == "text":
return {"type": "text", "text": ""}
return {"type": "text", "text": ""}
def _should_drop_block_type(block_type: Any, *, thinking_enabled: bool) -> bool:
if not isinstance(block_type, str):
return False
if block_type.startswith("redacted_thinking"):
return not thinking_enabled
return not thinking_enabled and "thinking" in block_type
def _synthetic_close_other_open_blocks(
state: NativeSseBlockPolicyState, current_upstream: int
) -> str:
"""Close every open block except `current_upstream` and track duplicate upstream stops."""
out: list[str] = []
for upstream, seg in list(state.by_upstream.items()):
if upstream == current_upstream or not seg.open:
continue
out.append(
format_native_sse_event(
"content_block_stop",
json.dumps(
{
"type": "content_block_stop",
"index": seg.down_index,
}
),
)
)
seg.open = False
state.pending_suppressed_stops.add(upstream)
return "".join(out)
def _allocate_new_segment(
state: NativeSseBlockPolicyState,
upstream_index: int,
block_type: str,
*,
last_start_block: dict[str, Any] | None = None,
) -> int:
"""Assign a new downstream `index` for a segment and record upstream state."""
new_idx = state.next_index
state.next_index += 1
state.by_upstream[upstream_index] = _UpstreamBlockState(
block_type=block_type,
down_index=new_idx,
open=True,
last_start_block=last_start_block,
)
return new_idx
def transform_native_sse_block_event(
event: str,
state: NativeSseBlockPolicyState,
*,
thinking_enabled: bool,
) -> str | None:
"""Normalize native Anthropic SSE events and enforce local thinking policy."""
event_name, data_text = parse_native_sse_event(event)
if not event_name or not data_text:
return event
try:
payload = json.loads(data_text)
except json.JSONDecodeError:
return event
if event_name == "content_block_start":
block = payload.get("content_block")
if not isinstance(block, dict):
return event
block_type = block.get("type")
upstream_index = payload.get("index")
if not isinstance(upstream_index, int):
return event
if _should_drop_block_type(block_type, thinking_enabled=thinking_enabled):
state.dropped_indexes.add(upstream_index)
return None
if not isinstance(block_type, str):
return event
prefix = _synthetic_close_other_open_blocks(state, upstream_index)
stored = copy.deepcopy(block)
new_idx = _allocate_new_segment(
state,
upstream_index,
block_type=block_type,
last_start_block=stored,
)
payload["index"] = new_idx
return prefix + format_native_sse_event(event_name, json.dumps(payload))
if event_name == "content_block_delta":
delta = payload.get("delta")
if not isinstance(delta, dict):
return event
delta_type = delta.get("type")
upstream_index = payload.get("index")
if not isinstance(upstream_index, int):
return event
if upstream_index in state.dropped_indexes:
return None
if _should_drop_block_type(delta_type, thinking_enabled=thinking_enabled):
return None
block_kind = _delta_type_to_block_kind(delta_type)
if block_kind is None:
return event
seg = state.by_upstream.get(upstream_index)
if seg and seg.open:
payload["index"] = seg.down_index
return format_native_sse_event(event_name, json.dumps(payload))
if seg is not None and not seg.open:
# More deltas for an upstream block after a synthetic (or other) close:
# reopen with a new downstream `index` and emit a synthetic `content_block_start` first.
state.pending_suppressed_stops.discard(upstream_index)
carry = seg.last_start_block
new_idx = _allocate_new_segment(
state,
upstream_index,
block_type=block_kind,
last_start_block=carry,
)
stored_tool = (
carry
if isinstance(carry, dict) and carry.get("type") == "tool_use"
else None
)
start_payload = {
"type": "content_block_start",
"index": new_idx,
"content_block": _synthetic_start_content_block(
block_kind,
upstream_index=upstream_index,
stored_tool_block=stored_tool,
),
}
prefix = format_native_sse_event(
"content_block_start", json.dumps(start_payload)
)
payload["index"] = new_idx
return prefix + format_native_sse_event(event_name, json.dumps(payload))
# Delta with no prior `content_block_start` in this stream
if block_kind in ("text", "tool_use"):
synthetic_block = _synthetic_start_content_block(
block_kind,
upstream_index=upstream_index,
)
new_idx = _allocate_new_segment(
state,
upstream_index,
block_type=block_kind,
last_start_block=copy.deepcopy(synthetic_block),
)
start_payload = {
"type": "content_block_start",
"index": new_idx,
"content_block": synthetic_block,
}
prefix = format_native_sse_event(
"content_block_start", json.dumps(start_payload)
)
payload["index"] = new_idx
return prefix + format_native_sse_event(event_name, json.dumps(payload))
# thinking: pass through raw (unusual upstream shape)
return event
if event_name == "content_block_stop":
upstream_index = payload.get("index")
if not isinstance(upstream_index, int):
return event
if upstream_index in state.dropped_indexes:
return None
if upstream_index in state.pending_suppressed_stops:
state.pending_suppressed_stops.discard(upstream_index)
return None
seg = state.by_upstream.get(upstream_index)
if seg is not None and seg.open:
payload["index"] = seg.down_index
seg.open = False
return format_native_sse_event(event_name, json.dumps(payload))
if seg is not None:
# Spurious or duplicate `content_block_stop` for a closed block.
return None
if not thinking_enabled:
return None
return event
return event
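Illustrative sketch (not part of the committed diff): the policy end to end for a thinking block when thinking is disabled locally. The start and its stop both vanish, and the upstream index is remembered so later deltas are dropped too.

import json

state = NativeSseBlockPolicyState()
start = format_native_sse_event(
    "content_block_start",
    json.dumps(
        {
            "type": "content_block_start",
            "index": 0,
            "content_block": {"type": "thinking", "thinking": ""},
        }
    ),
)
stop = format_native_sse_event(
    "content_block_stop", json.dumps({"type": "content_block_stop", "index": 0})
)
assert transform_native_sse_block_event(start, state, thinking_enabled=False) is None
assert transform_native_sse_block_event(stop, state, thinking_enabled=False) is None
assert state.dropped_indexes == {0}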


@ -0,0 +1,34 @@
"""Canonical Anthropic-style SSE sequence for provider-side streaming errors."""
from __future__ import annotations
import uuid
from collections.abc import Iterator
from typing import Any
from core.anthropic.sse import SSEBuilder
def iter_provider_stream_error_sse_events(
*,
request: Any,
input_tokens: int,
error_message: str,
sent_any_event: bool,
log_raw_sse_events: bool,
message_id: str | None = None,
) -> Iterator[str]:
"""Yield message_start (if needed), a text block with the error, then message_delta/stop."""
mid = message_id or f"msg_{uuid.uuid4()}"
model = getattr(request, "model", "") or ""
sse = SSEBuilder(
mid,
model,
input_tokens,
log_raw_events=log_raw_sse_events,
)
if not sent_any_event:
yield sse.message_start()
yield from sse.emit_error(error_message)
yield sse.message_delta("end_turn", 1)
yield sse.message_stop()
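Illustrative sketch (not part of the committed diff): usage with a bare stand-in request, since only .model is read here; the frame-shape assertions assume SSEBuilder emits the event:/data: format shown elsewhere in this change.

from types import SimpleNamespace

frames = list(
    iter_provider_stream_error_sse_events(
        request=SimpleNamespace(model="m"),
        input_tokens=7,
        error_message="upstream reset",
        sent_any_event=False,
        log_raw_sse_events=False,
    )
)
# message_start first (nothing was sent yet), then the error text block,
# then message_delta / message_stop.
assert frames[0].startswith("event: message_start")
assert frames[-1].startswith("event: message_stop")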


@ -0,0 +1,14 @@
"""SSE content_block ``type`` values for Anthropic web server tools (local handlers).
Shared by :mod:`api.web_tools` and stream contract tests to avoid drift.
"""
from __future__ import annotations
from typing import Final
SERVER_TOOL_USE: Final = "server_tool_use"
WEB_SEARCH_TOOL_RESULT: Final = "web_search_tool_result"
WEB_FETCH_TOOL_RESULT: Final = "web_fetch_tool_result"
WEB_SEARCH_TOOL_RESULT_ERROR: Final = "web_search_tool_result_error"
WEB_FETCH_TOOL_ERROR: Final = "web_fetch_tool_error"


@ -1,5 +1,6 @@
"""SSE event builder for Anthropic-format streaming responses."""
import hashlib
import json
from collections.abc import Iterator
from dataclasses import dataclass, field
@ -14,6 +15,13 @@ except Exception:
ENCODER = None
# Standard headers for Anthropic-style ``text/event-stream`` responses from this proxy.
ANTHROPIC_SSE_RESPONSE_HEADERS: dict[str, str] = {
"X-Accel-Buffering": "no",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
}
STOP_REASON_MAP = {
"stop": "end_turn",
"length": "max_tokens",
@ -29,6 +37,11 @@ def map_stop_reason(openai_reason: str | None) -> str:
)
def format_sse_event(event_type: str, data: dict) -> str:
"""Format one Anthropic-style SSE event (no logging)."""
return f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
@dataclass
class ToolCallState:
"""State for a single streaming tool call."""
@ -40,6 +53,7 @@ class ToolCallState:
started: bool = False
task_arg_buffer: str = ""
task_args_emitted: bool = False
pre_start_args: str = ""
@dataclass
@ -58,7 +72,25 @@ class ContentBlockManager:
self.next_index += 1
return idx
def ensure_tool_state(self, index: int) -> ToolCallState:
"""Create tool stream state for ``index`` when the first tool delta arrives."""
if index not in self.tool_states:
self.tool_states[index] = ToolCallState(block_index=-1, tool_id="", name="")
return self.tool_states[index]
def set_stream_tool_id(self, index: int, tool_id: str | None) -> None:
"""Record OpenAI tool call id before ``content_block_start`` (split-stream providers)."""
if not tool_id:
return
state = self.ensure_tool_state(index)
state.tool_id = str(tool_id)
def register_tool_name(self, index: int, name: str) -> None:
"""Record tool name fragments as they arrive from chunked OpenAI streams.
Names may be split across deltas; later chunks can extend (``ab`` + ``c``)
or repeat prefixes, so we merge conservatively.
"""
if index not in self.tool_states:
self.tool_states[index] = ToolCallState(
block_index=-1, tool_id="", name=name
@ -82,8 +114,7 @@ class ContentBlockManager:
except Exception:
return None
if args_json.get("run_in_background") is not False:
args_json["run_in_background"] = False
_normalize_task_run_in_background(args_json)
state.task_args_emitted = True
state.task_arg_buffer = ""
@ -98,16 +129,17 @@ class ContentBlockManager:
out = "{}"
try:
args_json = json.loads(state.task_arg_buffer)
if args_json.get("run_in_background") is not False:
args_json["run_in_background"] = False
_normalize_task_run_in_background(args_json)
out = json.dumps(args_json)
except Exception as e:
prefix = state.task_arg_buffer[:120]
except (json.JSONDecodeError, TypeError, ValueError) as e:
digest = hashlib.sha256(
state.task_arg_buffer.encode("utf-8", errors="replace")
).hexdigest()[:16]
logger.warning(
"Task args invalid JSON (id={} len={} prefix={!r}): {}",
"Task args invalid JSON (id={} len={} buffer_sha256_prefix={}): {}",
state.tool_id or "unknown",
len(state.task_arg_buffer),
prefix,
digest,
e,
)
@ -117,20 +149,41 @@ class ContentBlockManager:
return results
def _normalize_task_run_in_background(args_json: dict) -> None:
"""Force Claude Code Task subagents to run in foreground (single shared rule)."""
if args_json.get("run_in_background") is not False:
args_json["run_in_background"] = False
class SSEBuilder:
"""Builder for Anthropic SSE streaming events."""
def __init__(self, message_id: str, model: str, input_tokens: int = 0):
def __init__(
self,
message_id: str,
model: str,
input_tokens: int = 0,
*,
log_raw_events: bool = False,
):
self.message_id = message_id
self.model = model
self.input_tokens = input_tokens
self._log_raw_events = log_raw_events
self.blocks = ContentBlockManager()
self._accumulated_text_parts: list[str] = []
self._accumulated_reasoning_parts: list[str] = []
def _format_event(self, event_type: str, data: dict) -> str:
event_str = f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
event_str = format_sse_event(event_type, data)
if self._log_raw_events:
logger.debug("SSE_EVENT: {} - {}", event_type, event_str.strip())
else:
logger.debug(
"SSE_EVENT: event_type={} serialized_bytes={}",
event_type,
len(event_str.encode("utf-8")),
)
return event_str
def message_start(self) -> str:
@ -289,10 +342,7 @@ class SSEBuilder:
yield self.stop_text_block()
def close_all_blocks(self) -> Iterator[str]:
if self.blocks.thinking_started:
yield self.stop_thinking_block()
if self.blocks.text_started:
yield self.stop_text_block()
yield from self.close_content_blocks()
for tool_index, state in list(self.blocks.tool_states.items()):
if state.started:
yield self.stop_tool_block(tool_index)


@ -1,7 +1,6 @@
"""Neutral SSE parsing and Anthropic stream shape assertions.
Used by default CI contract tests and re-exported from ``smoke.lib.sse`` for
opt-in smoke scenarios.
Used by default CI contract tests and by opt-in live smoke scenarios.
"""
from __future__ import annotations
@ -11,6 +10,36 @@ from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any
from .server_tool_sse import (
SERVER_TOOL_USE,
WEB_FETCH_TOOL_RESULT,
WEB_SEARCH_TOOL_RESULT,
)
# Content blocks that only use content_block_start/stop (no deltas), including
# Anthropic server tools and eager text emitted in a single start event.
_NO_DELTA_BLOCK_KINDS = frozenset(
{
SERVER_TOOL_USE,
WEB_SEARCH_TOOL_RESULT,
WEB_FETCH_TOOL_RESULT,
"text_eager",
"redacted_thinking",
}
)
_ALLOWED_BLOCK_START_TYPES = frozenset(
{
"text",
"thinking",
"tool_use",
"redacted_thinking",
SERVER_TOOL_USE,
WEB_SEARCH_TOOL_RESULT,
WEB_FETCH_TOOL_RESULT,
}
)
@dataclass(frozen=True, slots=True)
class SSEEvent:
@ -86,35 +115,46 @@ def assert_anthropic_stream_contract(
raise AssertionError(f"unexpected SSE error event: {event.data}")
if event.event == "content_block_start":
index = _event_index(event)
index = event_index(event)
block = event.data.get("content_block", {})
assert isinstance(block, dict), event.data
block_type = str(block.get("type", ""))
assert block_type in {"text", "thinking", "tool_use"}, event.data
assert block_type in _ALLOWED_BLOCK_START_TYPES, event.data
assert index not in open_blocks, f"block {index} started twice"
assert index not in seen_blocks, f"block {index} reused after stop"
open_blocks[index] = block_type
if block_type == "text" and str(block.get("text", "")).strip():
storage = "text_eager"
else:
storage = block_type
open_blocks[index] = storage
seen_blocks.add(index)
continue
if event.event == "content_block_delta":
index = _event_index(event)
index = event_index(event)
assert index in open_blocks, f"delta for unopened block {index}"
kind = open_blocks[index]
assert kind not in _NO_DELTA_BLOCK_KINDS, (
f"unexpected delta for start/stop-only block {kind} at index {index}"
)
delta = event.data.get("delta", {})
assert isinstance(delta, dict), event.data
delta_type = str(delta.get("type", ""))
if kind == "thinking":
assert delta_type in (
"thinking_delta",
"signature_delta",
), f"block {index} is {kind}, got {delta_type}"
continue
expected = {
"text": "text_delta",
"thinking": "thinking_delta",
"tool_use": "input_json_delta",
}[open_blocks[index]]
assert delta_type == expected, (
f"block {index} is {open_blocks[index]}, got {delta_type}"
)
}[kind]
assert delta_type == expected, f"block {index} is {kind}, got {delta_type}"
continue
if event.event == "content_block_stop":
index = _event_index(event)
index = event_index(event)
assert index in open_blocks, f"stop for unopened block {index}"
open_blocks.pop(index)
@ -129,6 +169,12 @@ def event_names(events: list[SSEEvent]) -> list[str]:
def text_content(events: list[SSEEvent]) -> str:
parts: list[str] = []
for event in events:
if event.event == "content_block_start":
block = event.data.get("content_block", {})
if isinstance(block, dict) and block.get("type") == "text":
eager = str(block.get("text", ""))
if eager:
parts.append(eager)
delta = event.data.get("delta", {})
if isinstance(delta, dict) and delta.get("type") == "text_delta":
parts.append(str(delta.get("text", "")))
@ -152,7 +198,8 @@ def has_tool_use(events: list[SSEEvent]) -> bool:
return False
def _event_index(event: SSEEvent) -> int:
def event_index(event: SSEEvent) -> int:
"""Return the content block ``index`` field from an SSE payload (strict)."""
value = event.data.get("index")
assert isinstance(value, int), event.data
return value
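A hand-built stream that satisfies the contract above (hypothetical usage; field names follow the SSEEvent attributes used in this module, and the checker is assumed to accept just the event list):

events = [
    SSEEvent(event="content_block_start",
             data={"index": 0, "content_block": {"type": "text", "text": ""}}),
    SSEEvent(event="content_block_delta",
             data={"index": 0, "delta": {"type": "text_delta", "text": "hi"}}),
    SSEEvent(event="content_block_stop", data={"index": 0}),
]
# Raises AssertionError on lifecycle violations: delta before start, reused
# index, wrong delta type for the block kind, or deltas on start/stop-only blocks.
assert_anthropic_stream_contract(events)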

View file

@ -68,6 +68,27 @@ def get_token_count(
total_tokens += len(ENCODER.encode(json.dumps(content)))
total_tokens += len(ENCODER.encode(str(tool_use_id)))
total_tokens += 8
elif b_type in (
"server_tool_use",
"web_search_tool_result",
"web_fetch_tool_result",
):
if hasattr(block, "model_dump"):
blob: object = block.model_dump()
else:
blob = block
try:
total_tokens += len(
ENCODER.encode(
json.dumps(blob, default=str, ensure_ascii=False)
)
)
except (TypeError, ValueError, OverflowError) as e:
logger.debug(
"Block encode fallback b_type={} err={}", b_type, e
)
total_tokens += len(ENCODER.encode(str(blob)))
total_tokens += 12
else:
logger.debug(
"Unexpected block type %r, falling back to json/str encoding",

core/rate_limit.py (new file)
View file

@ -0,0 +1,60 @@
"""Shared strict sliding-window rate limiting primitives."""
from __future__ import annotations
import asyncio
import time
from collections import deque
class StrictSlidingWindowLimiter:
"""Strict sliding window limiter.
Guarantees: at most ``rate_limit`` acquisitions in any interval of length
``rate_window`` (seconds).
Implemented as an async context manager so call sites can do::
async with limiter:
...
"""
def __init__(self, rate_limit: int, rate_window: float) -> None:
if rate_limit <= 0:
raise ValueError("rate_limit must be > 0")
if rate_window <= 0:
raise ValueError("rate_window must be > 0")
self._rate_limit = int(rate_limit)
self._rate_window = float(rate_window)
self._times: deque[float] = deque()
self._lock = asyncio.Lock()
async def acquire(self) -> None:
while True:
wait_time = 0.0
async with self._lock:
now = time.monotonic()
cutoff = now - self._rate_window
while self._times and self._times[0] <= cutoff:
self._times.popleft()
if len(self._times) < self._rate_limit:
self._times.append(now)
return
oldest = self._times[0]
wait_time = max(0.0, (oldest + self._rate_window) - now)
if wait_time > 0:
await asyncio.sleep(wait_time)
else:
await asyncio.sleep(0)
async def __aenter__(self) -> StrictSlidingWindowLimiter:
await self.acquire()
return self
async def __aexit__(self, exc_type, exc, tb) -> bool:
return False
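Usage is a plain async context manager; for example, capping work at two acquisitions per second:

import asyncio

async def main() -> None:
    limiter = StrictSlidingWindowLimiter(rate_limit=2, rate_window=1.0)
    for i in range(5):
        async with limiter:  # acquisitions beyond 2/s block until the window frees
            print("send", i)

asyncio.run(main())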

View file

@ -0,0 +1,67 @@
"""CLI event types and status-line mapping for transcript / UI updates."""
from collections.abc import Callable
from typing import Any
# Status message prefixes used to filter our own messages (ignore echo)
STATUS_MESSAGE_PREFIXES = (
"",
"💭",
"🔧",
"",
"",
"🚀",
"🤖",
"📋",
"📊",
"🔄",
)
# Event types that update the transcript (frozenset for O(1) membership)
TRANSCRIPT_EVENT_TYPES = frozenset(
{
"thinking_start",
"thinking_delta",
"thinking_chunk",
"thinking_stop",
"text_start",
"text_delta",
"text_chunk",
"text_stop",
"tool_use_start",
"tool_use_delta",
"tool_use_stop",
"tool_use",
"tool_result",
"block_stop",
"error",
}
)
# Event type -> (emoji, label) for status updates (O(1) lookup)
_EVENT_STATUS_MAP: dict[str, tuple[str, str]] = {
"thinking_start": ("🧠", "Claude is thinking..."),
"thinking_delta": ("🧠", "Claude is thinking..."),
"thinking_chunk": ("🧠", "Claude is thinking..."),
"text_start": ("🧠", "Claude is working..."),
"text_delta": ("🧠", "Claude is working..."),
"text_chunk": ("🧠", "Claude is working..."),
"tool_result": ("", "Executing tools..."),
}
def get_status_for_event(
ptype: str,
parsed: dict[str, Any],
format_status_fn: Callable[..., str],
) -> str | None:
"""Return status string for event type, or None if no status update needed."""
entry = _EVENT_STATUS_MAP.get(ptype)
if entry is not None:
emoji, label = entry
return format_status_fn(emoji, label)
if ptype in ("tool_use_start", "tool_use_delta", "tool_use"):
if parsed.get("name") == "Task":
return format_status_fn("🤖", "Subagent working...")
return format_status_fn("", "Executing tools...")
return None
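A quick usage check with a trivial formatter standing in for the platform's format_status:

def fmt(emoji: str, label: str) -> str:
    return f"{emoji} {label}".strip()

assert get_status_for_event("thinking_delta", {}, fmt) == "🧠 Claude is thinking..."
assert get_status_for_event("tool_use_start", {"name": "Task"}, fmt) == "🤖 Subagent working..."
assert get_status_for_event("text_stop", {}, fmt) is None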

View file

@ -9,12 +9,14 @@ from typing import Any
from loguru import logger
def parse_cli_event(event: Any) -> list[dict]:
def parse_cli_event(event: Any, *, log_raw_cli: bool = False) -> list[dict]:
"""
Parse a CLI event and return a structured result.
Args:
event: Raw event dictionary from CLI
log_raw_cli: When True, log full error text from the CLI. Default is
metadata-only (lengths / exit codes) to avoid leaking user content.
Returns:
List of parsed event dicts. Empty list if not recognized.
@ -140,7 +142,11 @@ def parse_cli_event(event: Any) -> list[dict]:
if etype == "error":
err = event.get("error")
msg = err.get("message") if isinstance(err, dict) else str(err)
logger.info(f"CLI_PARSER: Parsed error event: {msg}")
if log_raw_cli:
logger.info("CLI_PARSER: Parsed error event: {}", msg)
else:
mlen = len(msg) if isinstance(msg, str) else 0
logger.info("CLI_PARSER: Parsed error event: message_chars={}", mlen)
return [{"type": "error", "message": msg}]
elif etype == "exit":
code = event.get("code", 0)
@ -151,7 +157,19 @@ def parse_cli_event(event: Any) -> list[dict]:
else:
# Non-zero exit is an error
error_msg = stderr if stderr else f"Process exited with code {code}"
logger.warning(f"CLI_PARSER: Error exit (code={code}): {error_msg}")
if log_raw_cli:
logger.warning(
"CLI_PARSER: Error exit (code={}): {}",
code,
error_msg,
)
else:
em = error_msg if isinstance(error_msg, str) else str(error_msg)
logger.warning(
"CLI_PARSER: Error exit (code={}): message_chars={}",
code,
len(em),
)
return [
{"type": "error", "message": error_msg},
{"type": "complete", "status": "failed"},

View file

@ -7,13 +7,12 @@ Uses tree-based queuing for message ordering.
"""
import asyncio
import os
import time
from loguru import logger
from core.anthropic import get_user_facing_error_message
from core.anthropic import format_user_error_preview, get_user_facing_error_message
from .cli_event_constants import STATUS_MESSAGE_PREFIXES
from .command_dispatcher import (
dispatch_command,
message_kind_for_command,
@ -21,8 +20,10 @@ from .command_dispatcher import (
)
from .event_parser import parse_cli_event
from .models import IncomingMessage
from .node_event_pipeline import handle_session_info_event, process_parsed_cli_event
from .platforms.base import MessagingPlatform, SessionManagerInterface
from .rendering.profiles import build_rendering_profile
from .safe_diagnostics import format_exception_for_log
from .session import SessionStore
from .transcript import RenderCtx, TranscriptBuffer
from .trees.queue_manager import (
@ -31,54 +32,7 @@ from .trees.queue_manager import (
MessageTree,
TreeQueueManager,
)
# Status message prefixes used to filter our own messages (ignore echo)
STATUS_MESSAGE_PREFIXES = ("", "💭", "🔧", "", "", "🚀", "🤖", "📋", "📊", "🔄")
# Event types that update the transcript (frozenset for O(1) membership)
TRANSCRIPT_EVENT_TYPES = frozenset(
{
"thinking_start",
"thinking_delta",
"thinking_chunk",
"thinking_stop",
"text_start",
"text_delta",
"text_chunk",
"text_stop",
"tool_use_start",
"tool_use_delta",
"tool_use_stop",
"tool_use",
"tool_result",
"block_stop",
"error",
}
)
# Event type -> (emoji, label) for status updates (O(1) lookup)
_EVENT_STATUS_MAP = {
"thinking_start": ("🧠", "Claude is thinking..."),
"thinking_delta": ("🧠", "Claude is thinking..."),
"thinking_chunk": ("🧠", "Claude is thinking..."),
"text_start": ("🧠", "Claude is working..."),
"text_delta": ("🧠", "Claude is working..."),
"text_chunk": ("🧠", "Claude is working..."),
"tool_result": ("", "Executing tools..."),
}
def _get_status_for_event(ptype: str, parsed: dict, format_status_fn) -> str | None:
"""Return status string for event type, or None if no status update needed."""
entry = _EVENT_STATUS_MAP.get(ptype)
if entry is not None:
emoji, label = entry
return format_status_fn(emoji, label)
if ptype in ("tool_use_start", "tool_use_delta", "tool_use"):
if parsed.get("name") == "Task":
return format_status_fn("🤖", "Subagent working...")
return format_status_fn("", "Executing tools...")
return None
from .ui_updates import ThrottledTranscriptEditor
class ClaudeMessageHandler:
@ -97,10 +51,21 @@ class ClaudeMessageHandler:
platform: MessagingPlatform,
cli_manager: SessionManagerInterface,
session_store: SessionStore,
*,
debug_platform_edits: bool = False,
debug_subagent_stack: bool = False,
log_raw_messaging_content: bool = False,
log_raw_cli_diagnostics: bool = False,
log_messaging_error_details: bool = False,
):
self.platform = platform
self.cli_manager = cli_manager
self.session_store = session_store
self._debug_platform_edits = debug_platform_edits
self._debug_subagent_stack = debug_subagent_stack
self._log_raw_messaging_content = log_raw_messaging_content
self._log_raw_cli_diagnostics = log_raw_cli_diagnostics
self._log_messaging_error_details = log_messaging_error_details
self._tree_queue = TreeQueueManager(
queue_update_callback=self.update_queue_positions,
node_started_callback=self.mark_node_processing,
@ -137,8 +102,10 @@ class ClaudeMessageHandler:
Determines if this is a new conversation or reply,
creates/extends the message tree, and queues for processing.
"""
text_preview = (incoming.text or "")[:80]
if len(incoming.text or "") > 80:
raw = incoming.text or ""
if self._log_raw_messaging_content:
text_preview = raw[:80]
if len(raw) > 80:
text_preview += "..."
logger.info(
"HANDLER_ENTRY: chat_id={} message_id={} reply_to={} text_preview={!r}",
@ -147,6 +114,14 @@ class ClaudeMessageHandler:
incoming.reply_to_message_id,
text_preview,
)
else:
logger.info(
"HANDLER_ENTRY: chat_id={} message_id={} reply_to={} text_len={}",
incoming.chat_id,
incoming.message_id,
incoming.reply_to_message_id,
len(raw),
)
with logger.contextualize(
chat_id=incoming.chat_id, node_id=incoming.message_id
@ -169,7 +144,12 @@ class ClaudeMessageHandler:
kind=message_kind_for_command(cmd_base),
)
except Exception as e:
logger.debug(f"Failed to record incoming message_id: {e}")
logger.debug(
"Failed to record incoming message_id: {}",
format_exception_for_log(
e, log_full_message=self._log_messaging_error_details
),
)
if await dispatch_command(self, incoming, cmd_base):
return
@ -276,7 +256,12 @@ class ClaudeMessageHandler:
try:
queued_ids = await tree.get_queue_snapshot()
except Exception as e:
logger.warning(f"Failed to read queue snapshot: {e}")
logger.warning(
"Failed to read queue snapshot: {}",
format_exception_for_log(
e, log_full_message=self._log_messaging_error_details
),
)
return
if not queued_ids:
@ -317,86 +302,12 @@ class ClaudeMessageHandler:
self,
) -> tuple[TranscriptBuffer, RenderCtx]:
"""Create transcript buffer and render context for node processing."""
transcript = TranscriptBuffer(show_tool_results=False)
transcript = TranscriptBuffer(
show_tool_results=False,
debug_subagent_stack=self._debug_subagent_stack,
)
return transcript, self.get_render_ctx()
async def _handle_session_info_event(
self,
event_data: dict,
tree: MessageTree | None,
node_id: str,
captured_session_id: str | None,
temp_session_id: str | None,
) -> tuple[str | None, str | None]:
"""Handle session_info event; return updated (captured_session_id, temp_session_id)."""
if event_data.get("type") != "session_info":
return captured_session_id, temp_session_id
real_session_id = event_data.get("session_id")
if not real_session_id or not temp_session_id:
return captured_session_id, temp_session_id
await self.cli_manager.register_real_session_id(
temp_session_id, real_session_id
)
if tree and real_session_id:
await tree.update_state(
node_id,
MessageState.IN_PROGRESS,
session_id=real_session_id,
)
self.session_store.save_tree(tree.root_id, tree.to_dict())
return real_session_id, None
async def _process_parsed_event(
self,
parsed: dict,
transcript: TranscriptBuffer,
update_ui,
last_status: str | None,
had_transcript_events: bool,
tree: MessageTree | None,
node_id: str,
captured_session_id: str | None,
) -> tuple[str | None, bool]:
"""Process a single parsed CLI event. Returns (last_status, had_transcript_events)."""
ptype = parsed.get("type") or ""
if ptype in TRANSCRIPT_EVENT_TYPES:
transcript.apply(parsed)
had_transcript_events = True
status = _get_status_for_event(ptype, parsed, self.format_status)
if status is not None:
await update_ui(status)
last_status = status
elif ptype == "block_stop":
await update_ui(last_status, force=True)
elif ptype == "complete":
if not had_transcript_events:
transcript.apply({"type": "text_chunk", "text": "Done."})
logger.info("HANDLER: Task complete, updating UI")
await update_ui(self.format_status("", "Complete"), force=True)
if tree and captured_session_id:
await tree.update_state(
node_id,
MessageState.COMPLETED,
session_id=captured_session_id,
)
self.session_store.save_tree(tree.root_id, tree.to_dict())
elif ptype == "error":
error_msg = parsed.get("message", "Unknown error")
logger.error(f"HANDLER: Error event received: {error_msg}")
logger.info("HANDLER: Updating UI with error status")
await update_ui(self.format_status("", "Error"), force=True)
if tree:
await self._propagate_error_to_children(
node_id, error_msg, "Parent task failed"
)
return last_status, had_transcript_events
async def _process_node(
self,
node_id: str,
@ -426,8 +337,6 @@ class ClaudeMessageHandler:
transcript, render_ctx = self._create_transcript_and_render_ctx()
last_ui_update = 0.0
last_displayed_text = None
had_transcript_events = False
captured_session_id = None
temp_session_id = None
@ -439,52 +348,21 @@ class ClaudeMessageHandler:
if parent_session_id:
logger.info(f"Will fork from parent session: {parent_session_id}")
async def update_ui(status: str | None = None, force: bool = False) -> None:
nonlocal last_ui_update, last_displayed_text, last_status
now = time.time()
if not force and now - last_ui_update < 1.0:
return
last_ui_update = now
if status is not None:
last_status = status
try:
display = transcript.render(
render_ctx,
limit_chars=self._get_limit_chars(),
status=status,
)
except Exception as e:
logger.warning(f"Transcript render failed for node {node_id}: {e}")
return
if display and display != last_displayed_text:
logger.debug(
"PLATFORM_EDIT: node_id={} chat_id={} msg_id={} force={} status={!r} chars={}",
node_id,
chat_id,
status_msg_id,
bool(force),
status,
len(display),
)
if os.getenv("DEBUG_PLATFORM_EDITS") == "1":
logger.debug("PLATFORM_EDIT_TEXT:\n{}", display)
else:
head = display[:500]
tail = display[-500:] if len(display) > 500 else ""
logger.debug("PLATFORM_EDIT_PREVIEW_HEAD:\n{}", head)
if tail:
logger.debug("PLATFORM_EDIT_PREVIEW_TAIL:\n{}", tail)
last_displayed_text = display
try:
await self.platform.queue_edit_message(
chat_id,
status_msg_id,
display,
editor = ThrottledTranscriptEditor(
platform=self.platform,
parse_mode=self._parse_mode(),
get_limit_chars=self._get_limit_chars,
transcript=transcript,
render_ctx=render_ctx,
node_id=node_id,
chat_id=chat_id,
status_msg_id=status_msg_id,
debug_platform_edits=self._debug_platform_edits,
log_messaging_error_details=self._log_messaging_error_details,
)
except Exception as e:
logger.warning(f"Failed to update platform for node {node_id}: {e}")
async def update_ui(status: str | None = None, force: bool = False) -> None:
await editor.update(status, force=force)
try:
try:
@ -531,20 +409,28 @@ class ClaudeMessageHandler:
(
captured_session_id,
temp_session_id,
) = await self._handle_session_info_event(
event_data, tree, node_id, captured_session_id, temp_session_id
) = await handle_session_info_event(
event_data,
tree,
node_id,
captured_session_id,
temp_session_id,
cli_manager=self.cli_manager,
session_store=self.session_store,
)
if event_data.get("type") == "session_info":
continue
parsed_list = parse_cli_event(event_data)
parsed_list = parse_cli_event(
event_data, log_raw_cli=self._log_raw_cli_diagnostics
)
logger.debug(f"HANDLER: Parsed {len(parsed_list)} events from CLI")
for parsed in parsed_list:
(
last_status,
had_transcript_events,
) = await self._process_parsed_event(
) = await process_parsed_cli_event(
parsed,
transcript,
update_ui,
@ -553,6 +439,10 @@ class ClaudeMessageHandler:
tree,
node_id,
captured_session_id,
session_store=self.session_store,
format_status=self.format_status,
propagate_error_to_children=self._propagate_error_to_children,
log_messaging_error_details=self._log_messaging_error_details,
)
except asyncio.CancelledError:
@ -575,9 +465,12 @@ class ClaudeMessageHandler:
)
except Exception as e:
logger.error(
f"HANDLER: Task failed with exception: {type(e).__name__}: {e}"
"HANDLER: Task failed with exception: {}",
format_exception_for_log(
e, log_full_message=self._log_messaging_error_details
),
)
error_msg = get_user_facing_error_message(e)[:200]
error_msg = format_user_error_preview(e)
transcript.apply({"type": "error", "message": error_msg})
await update_ui(self.format_status("💥", "Task Failed"), force=True)
if tree:
@ -595,7 +488,13 @@ class ClaudeMessageHandler:
elif temp_session_id:
await self.cli_manager.remove_session(temp_session_id)
except Exception as e:
logger.debug(f"Failed to remove session for node {node_id}: {e}")
logger.debug(
"Failed to remove session for node {}: {}",
node_id,
format_exception_for_log(
e, log_full_message=self._log_messaging_error_details
),
)
async def _propagate_error_to_children(
self,
@ -691,7 +590,12 @@ class ClaudeMessageHandler:
platform, chat_id, str(msg_id), direction="out", kind=kind
)
except Exception as e:
logger.debug(f"Failed to record message_id: {e}")
logger.debug(
"Failed to record message_id: {}",
format_exception_for_log(
e, log_full_message=self._log_messaging_error_details
),
)
def update_cancelled_nodes_ui(self, nodes: list[MessageNode]) -> None:
"""Update status messages and persist tree state for cancelled nodes."""

View file

@ -6,65 +6,16 @@ using a strict sliding window algorithm and a task queue.
"""
import asyncio
import os
import time
from collections import deque
from collections.abc import Awaitable, Callable
from typing import Any
from loguru import logger
from config.settings import get_settings
from core.rate_limit import StrictSlidingWindowLimiter as SlidingWindowLimiter
class SlidingWindowLimiter:
"""Strict sliding window limiter.
Guarantees: at most `rate_limit` acquisitions in any interval of length
`rate_window` (seconds).
Implemented as an async context manager so call sites can do:
async with limiter:
...
"""
def __init__(self, rate_limit: int, rate_window: float) -> None:
if rate_limit <= 0:
raise ValueError("rate_limit must be > 0")
if rate_window <= 0:
raise ValueError("rate_window must be > 0")
self._rate_limit = int(rate_limit)
self._rate_window = float(rate_window)
self._times: deque[float] = deque()
self._lock = asyncio.Lock()
async def acquire(self) -> None:
while True:
wait_time = 0.0
async with self._lock:
now = time.monotonic()
cutoff = now - self._rate_window
while self._times and self._times[0] <= cutoff:
self._times.popleft()
if len(self._times) < self._rate_limit:
self._times.append(now)
return
oldest = self._times[0]
wait_time = max(0.0, (oldest + self._rate_window) - now)
if wait_time > 0:
await asyncio.sleep(wait_time)
else:
await asyncio.sleep(0)
async def __aenter__(self) -> SlidingWindowLimiter:
await self.acquire()
return self
async def __aexit__(self, exc_type, exc, tb) -> bool:
return False
from .safe_diagnostics import format_exception_for_log
class MessagingRateLimiter:
@ -82,23 +33,29 @@ class MessagingRateLimiter:
return super().__new__(cls)
@classmethod
async def get_instance(cls) -> MessagingRateLimiter:
"""Get the singleton instance of the limiter."""
async def get_instance(
cls,
*,
rate_limit: int = 1,
rate_window: float = 1.0,
) -> MessagingRateLimiter:
"""Get the singleton instance of the limiter.
``rate_limit`` and ``rate_window`` apply only when the singleton is first
created. Call :meth:`shutdown_instance` before changing parameters.
"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
cls._instance = cls(rate_limit=rate_limit, rate_window=rate_window)
# Start the background worker (tracked for graceful shutdown).
cls._instance._start_worker()
return cls._instance
def __init__(self):
def __init__(self, *, rate_limit: int, rate_window: float) -> None:
# Prevent double initialization in singleton
if hasattr(self, "_initialized"):
return
rate_limit = int(os.getenv("MESSAGING_RATE_LIMIT", "1"))
rate_window = float(os.getenv("MESSAGING_RATE_WINDOW", "2.0"))
self.limiter = SlidingWindowLimiter(rate_limit, rate_window)
# Custom queue state - using deque for O(1) popleft
self._queue_list: deque[str] = deque() # Deque of dedup_keys in order
@ -189,14 +146,26 @@ class MessagingRateLimiter:
asyncio.get_event_loop().time() + wait_secs
)
else:
d = get_settings().log_messaging_error_details
logger.error(
f"Error in limiter worker for key {dedup_key}: {type(e).__name__}: {e}"
"Error in limiter worker for key {}: {}",
dedup_key,
format_exception_for_log(e, log_full_message=d),
)
except asyncio.CancelledError:
break
except Exception as e:
d = get_settings().log_messaging_error_details
if d:
logger.error(
f"MessagingRateLimiter worker critical error: {e}", exc_info=True
"MessagingRateLimiter worker critical error: {}",
e,
exc_info=True,
)
else:
logger.error(
"MessagingRateLimiter worker critical error: exc_type={}",
type(e).__name__,
)
await asyncio.sleep(1)
@ -223,7 +192,11 @@ class MessagingRateLimiter:
except asyncio.CancelledError:
pass
except Exception as e:
logger.debug(f"MessagingRateLimiter worker shutdown error: {e}")
d = get_settings().log_messaging_error_details
logger.debug(
"MessagingRateLimiter worker shutdown error: {}",
format_exception_for_log(e, log_full_message=d),
)
finally:
self._worker_task = None
@ -296,14 +269,29 @@ class MessagingRateLimiter:
x in error_msg for x in ["connect", "timeout", "broken"]
):
wait = 2**attempt
d = get_settings().log_messaging_error_details
if d:
logger.warning(
f"Limiter fire_and_forget transient error (attempt {attempt + 1}): {e}. Retrying in {wait}s..."
"Limiter fire_and_forget transient error (attempt {}): {}. Retrying in {}s...",
attempt + 1,
e,
wait,
)
else:
logger.warning(
"Limiter fire_and_forget transient error (attempt {}): exc_type={}. Retrying in {}s...",
attempt + 1,
type(e).__name__,
wait,
)
await asyncio.sleep(wait)
continue
d = get_settings().log_messaging_error_details
logger.error(
f"Final error in fire_and_forget for key {dedup_key}: {type(e).__name__}: {e}"
"Final error in fire_and_forget for key {}: {}",
dedup_key,
format_exception_for_log(e, log_full_message=d),
)
if not future.done():
future.set_exception(e)
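Note the transient-error path above retries with exponential backoff (2**attempt seconds: 1 s, 2 s, 4 s, ...). Acquiring the singleton is a two-liner, and later parameters are ignored until shutdown_instance is called:

async def example() -> None:
    limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=2.0)
    same = await MessagingRateLimiter.get_instance(rate_limit=99, rate_window=0.1)
    assert same is limiter  # parameters bind on first creation only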

View file

@ -0,0 +1,103 @@
"""CLI event handling for a single queued node (transcript + session + errors)."""
from __future__ import annotations
from collections.abc import Awaitable, Callable
from typing import Any
from loguru import logger
from .cli_event_constants import TRANSCRIPT_EVENT_TYPES, get_status_for_event
from .platforms.base import SessionManagerInterface
from .safe_diagnostics import text_len_hint
from .session import SessionStore
from .transcript import TranscriptBuffer
from .trees.queue_manager import MessageState, MessageTree
async def handle_session_info_event(
event_data: dict[str, Any],
tree: MessageTree | None,
node_id: str,
captured_session_id: str | None,
temp_session_id: str | None,
*,
cli_manager: SessionManagerInterface,
session_store: SessionStore,
) -> tuple[str | None, str | None]:
"""Handle session_info event; return updated (captured_session_id, temp_session_id)."""
if event_data.get("type") != "session_info":
return captured_session_id, temp_session_id
real_session_id = event_data.get("session_id")
if not real_session_id or not temp_session_id:
return captured_session_id, temp_session_id
await cli_manager.register_real_session_id(temp_session_id, real_session_id)
if tree and real_session_id:
await tree.update_state(
node_id,
MessageState.IN_PROGRESS,
session_id=real_session_id,
)
session_store.save_tree(tree.root_id, tree.to_dict())
return real_session_id, None
async def process_parsed_cli_event(
parsed: dict[str, Any],
transcript: TranscriptBuffer,
update_ui: Callable[..., Awaitable[None]],
last_status: str | None,
had_transcript_events: bool,
tree: MessageTree | None,
node_id: str,
captured_session_id: str | None,
*,
session_store: SessionStore,
format_status: Callable[..., str],
propagate_error_to_children: Callable[[str, str, str], Awaitable[None]],
log_messaging_error_details: bool = False,
) -> tuple[str | None, bool]:
"""Process a single parsed CLI event. Returns (last_status, had_transcript_events)."""
ptype = parsed.get("type") or ""
if ptype in TRANSCRIPT_EVENT_TYPES:
transcript.apply(parsed)
had_transcript_events = True
status = get_status_for_event(ptype, parsed, format_status)
if status is not None:
await update_ui(status)
last_status = status
elif ptype == "block_stop":
await update_ui(last_status, force=True)
elif ptype == "complete":
if not had_transcript_events:
transcript.apply({"type": "text_chunk", "text": "Done."})
logger.info("HANDLER: Task complete, updating UI")
await update_ui(format_status("", "Complete"), force=True)
if tree and captured_session_id:
await tree.update_state(
node_id,
MessageState.COMPLETED,
session_id=captured_session_id,
)
session_store.save_tree(tree.root_id, tree.to_dict())
elif ptype == "error":
error_msg = parsed.get("message", "Unknown error")
if log_messaging_error_details:
logger.error("HANDLER: Error event received: {}", error_msg)
else:
em = error_msg if isinstance(error_msg, str) else str(error_msg)
logger.error(
"HANDLER: Error event received: message_chars={}",
text_len_hint(em),
)
logger.info("HANDLER: Updating UI with error status")
await update_ui(format_status("", "Error"), force=True)
if tree:
await propagate_error_to_children(node_id, error_msg, "Parent task failed")
return last_status, had_transcript_events
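A simplified driver loop matching how ClaudeMessageHandler feeds this pipeline (sketch; the surrounding names are assumed in scope, and the real loop also handles session_info events and cancellation):

last_status: str | None = None
had_events = False
for parsed in parse_cli_event(event_data):
    last_status, had_events = await process_parsed_cli_event(
        parsed,
        transcript,
        update_ui,
        last_status,
        had_events,
        tree,
        node_id,
        captured_session_id,
        session_store=session_store,
        format_status=format_status,
        propagate_error_to_children=propagate_error_to_children,
    )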

View file

@ -6,7 +6,6 @@ Implements MessagingPlatform for Discord using discord.py.
import asyncio
import contextlib
import os
import tempfile
from collections.abc import Awaitable, Callable
from pathlib import Path
@ -14,7 +13,7 @@ from typing import Any, cast
from loguru import logger
from core.anthropic import get_user_facing_error_message
from core.anthropic import format_user_error_preview
from ..models import IncomingMessage
from ..rendering.discord_markdown import format_status_discord
@ -97,15 +96,18 @@ class DiscordPlatform(MessagingPlatform):
whisper_device: str = "cpu",
hf_token: str = "",
nvidia_nim_api_key: str = "",
messaging_rate_limit: int = 1,
messaging_rate_window: float = 1.0,
log_raw_messaging_content: bool = False,
log_api_error_tracebacks: bool = False,
):
if not DISCORD_AVAILABLE:
raise ImportError(
"discord.py is required. Install with: pip install discord.py"
)
self.bot_token = bot_token or os.getenv("DISCORD_BOT_TOKEN")
raw_channels = allowed_channel_ids or os.getenv("ALLOWED_DISCORD_CHANNELS")
self.allowed_channel_ids = _parse_allowed_channels(raw_channels)
self.bot_token = bot_token
self.allowed_channel_ids = _parse_allowed_channels(allowed_channel_ids)
if not self.bot_token:
logger.warning("DISCORD_BOT_TOKEN not set")
@ -130,6 +132,10 @@ class DiscordPlatform(MessagingPlatform):
self._voice_note_enabled = voice_note_enabled
self._whisper_model = whisper_model
self._whisper_device = whisper_device
self._messaging_rate_limit = messaging_rate_limit
self._messaging_rate_window = messaging_rate_window
self._log_raw_messaging_content = log_raw_messaging_content
self._log_api_error_tracebacks = log_api_error_tracebacks
async def _handle_client_message(self, message: Any) -> None:
"""Adapter entry point used by the internal discord client."""
@ -234,23 +240,40 @@ class DiscordPlatform(MessagingPlatform):
status_message_id=status_msg_id,
)
if self._log_raw_messaging_content:
logger.info(
"DISCORD_VOICE: chat_id={} message_id={} transcribed={!r}",
channel_id,
message_id,
(transcribed[:80] + "..." if len(transcribed) > 80 else transcribed),
(
transcribed[:80] + "..."
if len(transcribed) > 80
else transcribed
),
)
else:
logger.info(
"DISCORD_VOICE: chat_id={} message_id={} transcribed_len={}",
channel_id,
message_id,
len(transcribed),
)
await self._message_handler(incoming)
return True
except ValueError as e:
await message.reply(get_user_facing_error_message(e)[:200])
await message.reply(format_user_error_preview(e))
return True
except ImportError as e:
await message.reply(get_user_facing_error_message(e)[:200])
await message.reply(format_user_error_preview(e))
return True
except Exception as e:
logger.error(f"Voice transcription failed: {e}")
if self._log_api_error_tracebacks:
logger.error("Voice transcription failed: {}", e)
else:
logger.error(
"Voice transcription failed: exc_type={}", type(e).__name__
)
await message.reply(
"Could not transcribe voice note. Please try again or send text."
)
@ -285,8 +308,10 @@ class DiscordPlatform(MessagingPlatform):
else None
)
text_preview = (message.content or "")[:80]
if len(message.content or "") > 80:
raw_content = message.content or ""
if self._log_raw_messaging_content:
text_preview = raw_content[:80]
if len(raw_content) > 80:
text_preview += "..."
logger.info(
"DISCORD_MSG: chat_id={} message_id={} reply_to={} text_preview={!r}",
@ -295,6 +320,14 @@ class DiscordPlatform(MessagingPlatform):
reply_to,
text_preview,
)
else:
logger.info(
"DISCORD_MSG: chat_id={} message_id={} reply_to={} text_len={}",
channel_id,
message_id,
reply_to,
len(raw_content),
)
if not self._message_handler:
return
@ -313,13 +346,14 @@ class DiscordPlatform(MessagingPlatform):
try:
await self._message_handler(incoming)
except Exception as e:
logger.error(f"Error handling message: {e}")
if self._log_api_error_tracebacks:
logger.error("Error handling message: {}", e)
else:
logger.error("Error handling message: exc_type={}", type(e).__name__)
with contextlib.suppress(Exception):
await self.send_message(
channel_id,
format_status_discord(
"Error:", get_user_facing_error_message(e)[:200]
),
format_status_discord("Error:", format_user_error_preview(e)),
reply_to=message_id,
)
@ -336,7 +370,10 @@ class DiscordPlatform(MessagingPlatform):
from ..limiter import MessagingRateLimiter
self._limiter = await MessagingRateLimiter.get_instance()
self._limiter = await MessagingRateLimiter.get_instance(
rate_limit=self._messaging_rate_limit,
rate_window=self._messaging_rate_window,
)
self._start_task = asyncio.create_task(
self._client.start(self.bot_token),

View file

@ -28,6 +28,10 @@ class MessagingPlatformOptions:
whisper_device: str = "cpu"
hf_token: str = ""
nvidia_nim_api_key: str = ""
messaging_rate_limit: int = 1
messaging_rate_window: float = 1.0
log_raw_messaging_content: bool = False
log_api_error_tracebacks: bool = False
def create_messaging_platform(
@ -64,6 +68,10 @@ def create_messaging_platform(
whisper_device=opts.whisper_device,
hf_token=opts.hf_token,
nvidia_nim_api_key=opts.nvidia_nim_api_key,
messaging_rate_limit=opts.messaging_rate_limit,
messaging_rate_window=opts.messaging_rate_window,
log_raw_messaging_content=opts.log_raw_messaging_content,
log_api_error_tracebacks=opts.log_api_error_tracebacks,
)
if platform_type == "discord":
@ -82,6 +90,10 @@ def create_messaging_platform(
whisper_device=opts.whisper_device,
hf_token=opts.hf_token,
nvidia_nim_api_key=opts.nvidia_nim_api_key,
messaging_rate_limit=opts.messaging_rate_limit,
messaging_rate_window=opts.messaging_rate_window,
log_raw_messaging_content=opts.log_raw_messaging_content,
log_api_error_tracebacks=opts.log_api_error_tracebacks,
)
logger.warning(

View file

@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any
from loguru import logger
from core.anthropic import get_user_facing_error_message
from core.anthropic import format_user_error_preview
if TYPE_CHECKING:
from telegram import Update
@ -68,14 +68,18 @@ class TelegramPlatform(MessagingPlatform):
whisper_device: str = "cpu",
hf_token: str = "",
nvidia_nim_api_key: str = "",
messaging_rate_limit: int = 1,
messaging_rate_window: float = 1.0,
log_raw_messaging_content: bool = False,
log_api_error_tracebacks: bool = False,
):
if not TELEGRAM_AVAILABLE:
raise ImportError(
"python-telegram-bot is required. Install with: pip install python-telegram-bot"
)
self.bot_token = bot_token or os.getenv("TELEGRAM_BOT_TOKEN")
self.allowed_user_id = allowed_user_id or os.getenv("ALLOWED_TELEGRAM_USER_ID")
self.bot_token = bot_token
self.allowed_user_id = allowed_user_id
if not self.bot_token:
# We don't raise here to allow instantiation for testing/conditional logic,
@ -97,6 +101,10 @@ class TelegramPlatform(MessagingPlatform):
self._voice_note_enabled = voice_note_enabled
self._whisper_model = whisper_model
self._whisper_device = whisper_device
self._messaging_rate_limit = messaging_rate_limit
self._messaging_rate_window = messaging_rate_window
self._log_raw_messaging_content = log_raw_messaging_content
self._log_api_error_tracebacks = log_api_error_tracebacks
async def _register_pending_voice(
self, chat_id: str, voice_msg_id: str, status_msg_id: str
@ -172,7 +180,10 @@ class TelegramPlatform(MessagingPlatform):
# Initialize rate limiter
from ..limiter import MessagingRateLimiter
self._limiter = await MessagingRateLimiter.get_instance()
self._limiter = await MessagingRateLimiter.get_instance(
rate_limit=self._messaging_rate_limit,
rate_window=self._messaging_rate_window,
)
# Send startup notification
try:
@ -187,7 +198,13 @@ class TelegramPlatform(MessagingPlatform):
startup_text,
)
except Exception as e:
logger.warning(f"Could not send startup message: {e}")
if self._log_api_error_tracebacks:
logger.warning("Could not send startup message: {}", e)
else:
logger.warning(
"Could not send startup message: exc_type={}",
type(e).__name__,
)
logger.info("Telegram platform started (Bot API)")
@ -506,8 +523,10 @@ class TelegramPlatform(MessagingPlatform):
if getattr(update.message, "message_thread_id", None) is not None
else None
)
text_preview = (update.message.text or "")[:80]
if len(update.message.text or "") > 80:
raw_text = update.message.text or ""
if self._log_raw_messaging_content:
text_preview = raw_text[:80]
if len(raw_text) > 80:
text_preview += "..."
logger.info(
"TELEGRAM_MSG: chat_id={} message_id={} reply_to={} text_preview={!r}",
@ -516,6 +535,14 @@ class TelegramPlatform(MessagingPlatform):
reply_to,
text_preview,
)
else:
logger.info(
"TELEGRAM_MSG: chat_id={} message_id={} reply_to={} text_len={}",
chat_id,
message_id,
reply_to,
len(raw_text),
)
if not self._message_handler:
return
@ -534,11 +561,14 @@ class TelegramPlatform(MessagingPlatform):
try:
await self._message_handler(incoming)
except Exception as e:
logger.error(f"Error handling message: {e}")
if self._log_api_error_tracebacks:
logger.error("Error handling message: {}", e)
else:
logger.error("Error handling message: exc_type={}", type(e).__name__)
with contextlib.suppress(Exception):
await self.send_message(
chat_id,
f"❌ *{escape_md_v2('Error:')}* {escape_md_v2(get_user_facing_error_message(e)[:200])}",
f"❌ *{escape_md_v2('Error:')}* {escape_md_v2(format_user_error_preview(e))}",
reply_to=incoming.message_id,
message_thread_id=thread_id,
parse_mode="MarkdownV2",
@ -631,20 +661,37 @@ class TelegramPlatform(MessagingPlatform):
status_message_id=status_msg_id,
)
if self._log_raw_messaging_content:
logger.info(
"TELEGRAM_VOICE: chat_id={} message_id={} transcribed={!r}",
chat_id,
message_id,
(transcribed[:80] + "..." if len(transcribed) > 80 else transcribed),
(
transcribed[:80] + "..."
if len(transcribed) > 80
else transcribed
),
)
else:
logger.info(
"TELEGRAM_VOICE: chat_id={} message_id={} transcribed_len={}",
chat_id,
message_id,
len(transcribed),
)
await self._message_handler(incoming)
except ValueError as e:
await update.message.reply_text(get_user_facing_error_message(e)[:200])
await update.message.reply_text(format_user_error_preview(e))
except ImportError as e:
await update.message.reply_text(get_user_facing_error_message(e)[:200])
await update.message.reply_text(format_user_error_preview(e))
except Exception as e:
logger.error(f"Voice transcription failed: {e}")
if self._log_api_error_tracebacks:
logger.error("Voice transcription failed: {}", e)
else:
logger.error(
"Voice transcription failed: exc_type={}", type(e).__name__
)
await update.message.reply_text(
"Could not transcribe voice note. Please try again or send text."
)

View file

@ -0,0 +1,17 @@
"""Helpers for redacting user-derived content from log lines."""
from __future__ import annotations
def format_exception_for_log(exc: BaseException, *, log_full_message: bool) -> str:
"""Return exception type and optionally ``str(exc)`` for operator diagnostics."""
if log_full_message:
return f"{type(exc).__name__}: {exc}"
return type(exc).__name__
def text_len_hint(text: str | None) -> int:
"""Length of text for metadata-only logging (0 when missing)."""
if not text:
return 0
return len(text)
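Usage sketch:

try:
    raise ValueError("user text: secret")
except ValueError as e:
    print(format_exception_for_log(e, log_full_message=False))  # ValueError
    print(format_exception_for_log(e, log_full_message=True))   # ValueError: user text: secret

assert text_len_hint(None) == 0
assert text_len_hint("abc") == 3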

View file

@ -5,8 +5,10 @@ Provides persistent storage for mapping platform messages to Claude CLI session
and message trees for conversation continuation.
"""
import contextlib
import json
import os
import tempfile
import threading
from datetime import UTC, datetime
from typing import Any
@ -22,7 +24,12 @@ class SessionStore:
Platform-agnostic: works with any messaging platform.
"""
def __init__(self, storage_path: str = "sessions.json"):
def __init__(
self,
storage_path: str = "sessions.json",
*,
message_log_cap: int | None = None,
):
self.storage_path = storage_path
self._lock = threading.Lock()
self._trees: dict[str, dict] = {} # root_id -> tree data
@ -34,11 +41,7 @@ class SessionStore:
self._dirty = False
self._save_timer: threading.Timer | None = None
self._save_debounce_secs = 0.5
cap_raw = os.getenv("MAX_MESSAGE_LOG_ENTRIES_PER_CHAT", "").strip()
try:
self._message_log_cap: int | None = int(cap_raw) if cap_raw else None
except ValueError:
self._message_log_cap = None
self._message_log_cap: int | None = message_log_cap
self._load()
def _make_chat_key(self, platform: str, chat_id: str) -> str:
@ -104,9 +107,22 @@ class SessionStore:
}
def _write_data(self, data: dict) -> None:
"""Write data dict to disk. Must be called WITHOUT holding self._lock."""
with open(self.storage_path, "w", encoding="utf-8") as f:
"""Atomically write data dict to disk. Must be called WITHOUT holding self._lock."""
abs_target = os.path.abspath(self.storage_path)
dir_name = os.path.dirname(abs_target) or "."
fd, tmp_path = tempfile.mkstemp(
dir=dir_name, prefix=".sessions.", suffix=".tmp.json"
)
try:
with os.fdopen(fd, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2)
f.flush()
os.fsync(f.fileno())
os.replace(tmp_path, abs_target)
except BaseException:
with contextlib.suppress(OSError):
os.unlink(tmp_path)
raise
def _schedule_save(self) -> None:
"""Schedule a debounced save. Caller must hold self._lock."""

View file

@ -9,7 +9,6 @@ the transcript grows over time and older content must be truncated.
from __future__ import annotations
import json
import os
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Callable, Iterable
@ -210,7 +209,12 @@ class RenderCtx:
class TranscriptBuffer:
"""Maintains an ordered, truncatable transcript of events."""
def __init__(self, *, show_tool_results: bool = True) -> None:
def __init__(
self,
*,
show_tool_results: bool = True,
debug_subagent_stack: bool = False,
) -> None:
self._segments: list[Segment] = []
self._open_thinking_by_index: dict[int, ThinkingSegment] = {}
self._open_text_by_index: dict[int, TextSegment] = {}
@ -227,7 +231,7 @@ class TranscriptBuffer:
self._subagent_stack: list[str] = []
# Parallel stack of segments for rendering nested subagents.
self._subagent_segments: list[SubagentSegment] = []
self._debug_subagent_stack = os.getenv("DEBUG_SUBAGENT_STACK") == "1"
self._debug_subagent_stack = debug_subagent_stack
def _in_subagent(self) -> bool:
return bool(self._subagent_stack)

View file

@ -5,32 +5,18 @@ Supports:
- NVIDIA NIM: NVIDIA NIM Whisper/Parakeet
"""
import os
from pathlib import Path
from typing import Any
from loguru import logger
from config.settings import get_settings
from providers.nvidia_nim.voice import (
transcribe_audio_file as transcribe_nvidia_nim_audio,
)
# Max file size in bytes (25 MB)
MAX_AUDIO_SIZE_BYTES = 25 * 1024 * 1024
# NVIDIA NIM Whisper model mapping: (function_id, language_code)
_NIM_MODEL_MAP: dict[str, tuple[str, str]] = {
"nvidia/parakeet-ctc-0.6b-zh-tw": ("8473f56d-51ef-473c-bb26-efd4f5def2bf", "zh-TW"),
"nvidia/parakeet-ctc-0.6b-zh-cn": ("9add5ef7-322e-47e0-ad7a-5653fb8d259b", "zh-CN"),
"nvidia/parakeet-ctc-0.6b-es": ("None", "es-US"),
"nvidia/parakeet-ctc-0.6b-vi": ("f3dff2bb-99f9-403d-a5f1-f574a757deb0", "vi-VN"),
"nvidia/parakeet-ctc-1.1b-asr": ("1598d209-5e27-4d3c-8079-4751568b1081", "en-US"),
"nvidia/parakeet-ctc-0.6b-asr": ("d8dd4e9b-fbf5-4fb0-9dba-8cf436c8d965", "en-US"),
"nvidia/parakeet-1.1b-rnnt-multilingual-asr": (
"71203149-d3b7-4460-8231-1be2543a1fca",
"",
),
"openai/whisper-large-v3": ("b702f636-f60c-4a3d-a6f4-f3568c13bd7d", "multi"),
}
# Short model names -> full Hugging Face model IDs (for local Whisper)
_MODEL_MAP: dict[str, str] = {
"tiny": "openai/whisper-tiny",
@ -51,22 +37,19 @@ def _resolve_model_id(whisper_model: str) -> str:
return _MODEL_MAP.get(whisper_model, whisper_model)
def _get_pipeline(model_id: str, device: str, hf_token: str | None = None) -> Any:
def _get_pipeline(model_id: str, device: str, hf_token: str = "") -> Any:
"""Lazy-load transformers Whisper pipeline. Raises ImportError if not installed."""
global _pipeline_cache
if device not in ("cpu", "cuda"):
raise ValueError(f"whisper_device must be 'cpu' or 'cuda', got {device!r}")
resolved_token = (
hf_token if hf_token is not None else get_settings().hf_token
) or ""
resolved_token = hf_token or ""
cache_key = (model_id, device, resolved_token)
if cache_key not in _pipeline_cache:
try:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
if resolved_token:
os.environ["HF_TOKEN"] = resolved_token
hf_auth_token = resolved_token or None
use_cuda = device == "cuda" and torch.cuda.is_available()
pipe_device = "cuda:0" if use_cuda else "cpu"
@ -77,9 +60,10 @@ def _get_pipeline(model_id: str, device: str, hf_token: str | None = None) -> An
dtype=model_dtype,
low_cpu_mem_usage=True,
attn_implementation="sdpa",
token=hf_auth_token,
)
model = model.to(pipe_device)
processor = AutoProcessor.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id, token=hf_auth_token)
pipe = pipeline(
"automatic-speech-recognition",
@ -119,7 +103,7 @@ def transcribe_audio(
file_path: Path to audio file (OGG, MP3, MP4, WAV, M4A supported)
mime_type: MIME type of the audio (e.g. "audio/ogg")
whisper_model: Model ID or short name (local) or NVIDIA NIM model
whisper_device: "cpu" | "cuda" | "nvidia_nim" (defaults to WHISPER_DEVICE env var)
whisper_device: "cpu" | "cuda" | "nvidia_nim"
Returns:
Transcribed text
@ -140,8 +124,8 @@ def transcribe_audio(
)
if whisper_device == "nvidia_nim":
return _transcribe_nim(
file_path, whisper_model, nvidia_nim_api_key=nvidia_nim_api_key
return transcribe_nvidia_nim_audio(
file_path, whisper_model, api_key=nvidia_nim_api_key
)
return _transcribe_local(
file_path, whisper_model, whisper_device, hf_token=hf_token
@ -169,8 +153,7 @@ def _transcribe_local(
) -> str:
"""Transcribe using transformers Whisper pipeline."""
model_id = _resolve_model_id(whisper_model)
token: str | None = hf_token if hf_token else None
pipe = _get_pipeline(model_id, whisper_device, hf_token=token)
pipe = _get_pipeline(model_id, whisper_device, hf_token=hf_token)
audio = _load_audio(file_path)
result = pipe(audio, generate_kwargs={"language": "en", "task": "transcribe"})
text = result.get("text", "") or ""
@ -179,65 +162,3 @@ def _transcribe_local(
result_text = text.strip()
logger.debug(f"Local transcription: {len(result_text)} chars")
return result_text or "(no speech detected)"
def _transcribe_nim(
file_path: Path, model: str, *, nvidia_nim_api_key: str = ""
) -> str:
"""Transcribe using NVIDIA NIM Whisper API via Riva gRPC client."""
try:
import riva.client
except ImportError as e:
raise ImportError(
"NVIDIA NIM transcription requires the voice extra. "
"Install with: uv sync --extra voice"
) from e
api_key = nvidia_nim_api_key or get_settings().nvidia_nim_api_key
# Look up function ID and language code from model mapping
model_config = _NIM_MODEL_MAP.get(model)
if not model_config:
raise ValueError(
f"No NVIDIA NIM config found for model: {model}. "
f"Supported models: {', '.join(_NIM_MODEL_MAP.keys())}"
)
function_id, language_code = model_config
# Riva server configuration
server = "grpc.nvcf.nvidia.com:443"
# Auth with SSL and metadata
auth = riva.client.Auth(
use_ssl=True,
uri=server,
metadata_args=[
["function-id", function_id],
["authorization", f"Bearer {api_key}"],
],
)
asr_service = riva.client.ASRService(auth)
# Configure recognition - language_code from model config
config = riva.client.RecognitionConfig(
language_code=language_code,
max_alternatives=1,
verbatim_transcripts=True,
)
# Read audio file
with open(file_path, "rb") as f:
data = f.read()
# Perform offline recognition
response = asr_service.offline_recognize(data, config)
# Extract text from response - use getattr for safe attribute access
transcript = ""
results = getattr(response, "results", None)
if results and results[0].alternatives:
transcript = results[0].alternatives[0].transcript
logger.debug(f"NIM transcription: {len(transcript)} chars")
return transcript or "(no speech detected)"
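A hypothetical local-transcription call, assuming keyword parameters matching the docstring above:

from pathlib import Path

text = transcribe_audio(
    Path("note.ogg"),
    "audio/ogg",
    whisper_model="tiny",   # short name resolved to openai/whisper-tiny via _MODEL_MAP
    whisper_device="cpu",   # "cuda" and "nvidia_nim" route elsewhere
)
print(text)  # "(no speech detected)" when the pipeline returns nothing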

View file

@ -1,165 +0,0 @@
"""Async queue processor for message trees.
Handles the async processing lifecycle of tree nodes.
"""
import asyncio
from collections.abc import Awaitable, Callable
from loguru import logger
from core.anthropic import get_user_facing_error_message
from .data import MessageNode, MessageState, MessageTree
class TreeQueueProcessor:
"""
Handles async queue processing for a single tree.
Separates the async processing logic from the data management.
"""
def __init__(
self,
queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None = None,
node_started_callback: Callable[[MessageTree, str], Awaitable[None]]
| None = None,
):
self._queue_update_callback = queue_update_callback
self._node_started_callback = node_started_callback
def set_queue_update_callback(
self,
queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None,
) -> None:
"""Update the callback used to refresh queue positions."""
self._queue_update_callback = queue_update_callback
def set_node_started_callback(
self,
node_started_callback: Callable[[MessageTree, str], Awaitable[None]] | None,
) -> None:
"""Update the callback used when a queued node starts processing."""
self._node_started_callback = node_started_callback
async def _notify_queue_updated(self, tree: MessageTree) -> None:
"""Invoke queue update callback if set."""
if not self._queue_update_callback:
return
try:
await self._queue_update_callback(tree)
except Exception as e:
logger.warning(f"Queue update callback failed: {e}")
async def _notify_node_started(self, tree: MessageTree, node_id: str) -> None:
"""Invoke node started callback if set."""
if not self._node_started_callback:
return
try:
await self._node_started_callback(tree, node_id)
except Exception as e:
logger.warning(f"Node started callback failed: {e}")
async def process_node(
self,
tree: MessageTree,
node: MessageNode,
processor: Callable[[str, MessageNode], Awaitable[None]],
) -> None:
"""Process a single node and then check the queue."""
# Skip if already in terminal state (e.g. from error propagation)
if node.state == MessageState.ERROR:
logger.info(
f"Skipping node {node.node_id} as it is already in state {node.state}"
)
# Still need to check for next messages
await self._process_next(tree, processor)
return
try:
await processor(node.node_id, node)
except asyncio.CancelledError:
logger.info(f"Task for node {node.node_id} was cancelled")
raise
except Exception as e:
logger.error(f"Error processing node {node.node_id}: {e}")
await tree.update_state(
node.node_id,
MessageState.ERROR,
error_message=get_user_facing_error_message(e),
)
finally:
async with tree.with_lock():
tree.clear_current_node()
# Check if there are more messages in the queue
await self._process_next(tree, processor)
async def _process_next(
self,
tree: MessageTree,
processor: Callable[[str, MessageNode], Awaitable[None]],
) -> None:
"""Process the next message in queue, if any."""
next_node_id = None
node = None
async with tree.with_lock():
next_node_id = await tree.dequeue()
if not next_node_id:
tree.set_processing_state(None, False)
logger.debug(f"Tree {tree.root_id} queue empty, marking as free")
return
tree.set_processing_state(next_node_id, True)
logger.info(f"Processing next queued node {next_node_id}")
# Process next node (outside lock)
node = tree.get_node(next_node_id)
if node:
tree.set_current_task(
asyncio.create_task(self.process_node(tree, node, processor))
)
# Notify that this node has started processing and refresh queue positions.
if next_node_id:
await self._notify_node_started(tree, next_node_id)
await self._notify_queue_updated(tree)
async def enqueue_and_start(
self,
tree: MessageTree,
node_id: str,
processor: Callable[[str, MessageNode], Awaitable[None]],
) -> bool:
"""
Enqueue a node or start processing immediately.
Args:
tree: The message tree
node_id: Node to process
processor: Async function to process the node
Returns:
True if queued, False if processing immediately
"""
async with tree.with_lock():
if tree.is_processing:
tree.put_queue_unlocked(node_id)
queue_size = tree.get_queue_size()
logger.info(f"Queued node {node_id}, position {queue_size}")
return True
else:
tree.set_processing_state(node_id, True)
# Process outside the lock
node = tree.get_node(node_id)
if node:
tree.set_current_task(
asyncio.create_task(self.process_node(tree, node, processor))
)
return False
def cancel_current(self, tree: MessageTree) -> bool:
"""Cancel the currently running task in a tree."""
return tree.cancel_current_task()

View file

@ -1,30 +1,345 @@
"""Tree-Based Message Queue Manager - Refactored.
Coordinates data access, async processing, and error handling.
Uses TreeRepository for data, TreeQueueProcessor for async logic.
"""
"""Tree-based message queue: index, async node processor, and public manager API."""
import asyncio
from collections.abc import Awaitable, Callable
from loguru import logger
from config.settings import get_settings
from core.anthropic import get_user_facing_error_message
from ..models import IncomingMessage
from ..safe_diagnostics import format_exception_for_log
from .data import MessageNode, MessageState, MessageTree
from .processor import TreeQueueProcessor
from .repository import TreeRepository
class TreeRepository:
"""
In-memory index of trees and node-to-root mappings.
Used only by :class:`TreeQueueManager`; kept as a named type for tests.
"""
def __init__(self) -> None:
self._trees: dict[str, MessageTree] = {} # root_id -> tree
self._node_to_tree: dict[str, str] = {} # node_id -> root_id
def get_tree(self, root_id: str) -> MessageTree | None:
"""Get a tree by its root ID."""
return self._trees.get(root_id)
def get_tree_for_node(self, node_id: str) -> MessageTree | None:
"""Get the tree containing a given node."""
root_id = self._node_to_tree.get(node_id)
if not root_id:
return None
return self._trees.get(root_id)
def get_node(self, node_id: str) -> MessageNode | None:
"""Get a node from any tree."""
tree = self.get_tree_for_node(node_id)
return tree.get_node(node_id) if tree else None
def add_tree(self, root_id: str, tree: MessageTree) -> None:
"""Add a new tree to the repository."""
self._trees[root_id] = tree
self._node_to_tree[root_id] = root_id
logger.debug("TREE_REPO: add_tree root_id={}", root_id)
def register_node(self, node_id: str, root_id: str) -> None:
"""Register a node ID to a tree."""
self._node_to_tree[node_id] = root_id
logger.debug("TREE_REPO: register_node node_id={} root_id={}", node_id, root_id)
def has_node(self, node_id: str) -> bool:
"""Check if a node is registered in any tree."""
return node_id in self._node_to_tree
def tree_count(self) -> int:
"""Get the number of trees in the repository."""
return len(self._trees)
def is_tree_busy(self, root_id: str) -> bool:
"""Check if a tree is currently processing."""
tree = self._trees.get(root_id)
return tree.is_processing if tree else False
def is_node_tree_busy(self, node_id: str) -> bool:
"""Check if the tree containing a node is busy."""
tree = self.get_tree_for_node(node_id)
return tree.is_processing if tree else False
def get_queue_size(self, node_id: str) -> int:
"""Get queue size for the tree containing a node."""
tree = self.get_tree_for_node(node_id)
return tree.get_queue_size() if tree else 0
def resolve_parent_node_id(self, msg_id: str) -> str | None:
"""
Resolve a message ID to the actual parent node ID.
Handles the case where msg_id is a status message ID
(which maps to the tree but isn't an actual node).
Returns:
The node_id to use as parent, or None if not found
"""
tree = self.get_tree_for_node(msg_id)
if not tree:
return None
if tree.has_node(msg_id):
return msg_id
node = tree.find_node_by_status_message(msg_id)
if node:
return node.node_id
return None
def get_pending_children(self, node_id: str) -> list[MessageNode]:
"""
Get all pending child nodes (recursively) of a given node.
Used for error propagation - when a node fails, its pending
children should also be marked as failed.
"""
tree = self.get_tree_for_node(node_id)
if not tree:
return []
pending: list[MessageNode] = []
stack = [node_id]
while stack:
current_id = stack.pop()
node = tree.get_node(current_id)
if not node:
continue
for child_id in node.children_ids:
child = tree.get_node(child_id)
if child and child.state == MessageState.PENDING:
pending.append(child)
stack.append(child_id)
return pending
def all_trees(self) -> list[MessageTree]:
"""Get all trees in the repository."""
return list(self._trees.values())
def tree_ids(self) -> list[str]:
"""Get all tree root IDs."""
return list(self._trees.keys())
def unregister_nodes(self, node_ids: list[str]) -> None:
"""Remove node IDs from the node-to-tree mapping."""
for nid in node_ids:
self._node_to_tree.pop(nid, None)
def remove_tree(self, root_id: str) -> MessageTree | None:
"""
Remove a tree and all its node mappings from the repository.
Returns:
The removed tree, or None if not found.
"""
tree = self._trees.pop(root_id, None)
if not tree:
return None
for node in tree.all_nodes():
self._node_to_tree.pop(node.node_id, None)
logger.debug("TREE_REPO: remove_tree root_id={}", root_id)
return tree
def get_message_ids_for_chat(self, platform: str, chat_id: str) -> set[str]:
"""Get all message IDs (incoming + status) for a given platform/chat."""
msg_ids: set[str] = set()
for tree in self._trees.values():
for node in tree.all_nodes():
if str(node.incoming.platform) == str(platform) and str(
node.incoming.chat_id
) == str(chat_id):
if node.incoming.message_id is not None:
msg_ids.add(str(node.incoming.message_id))
if node.status_message_id:
msg_ids.add(str(node.status_message_id))
return msg_ids
def to_dict(self) -> dict:
"""Serialize all trees."""
return {
"trees": {rid: tree.to_dict() for rid, tree in self._trees.items()},
"node_to_tree": self._node_to_tree.copy(),
}
@classmethod
def from_dict(cls, data: dict) -> TreeRepository:
"""Deserialize from dictionary."""
repo = cls()
for root_id, tree_data in data.get("trees", {}).items():
repo._trees[root_id] = MessageTree.from_dict(tree_data)
repo._node_to_tree = data.get("node_to_tree", {})
return repo
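A minimal usage sketch of the repository's index contract, with the import path assumed; `FakeTree` is a stand-in because the repository treats the tree as opaque apart from `all_nodes()` during removal:
from messaging.tree_queue import TreeRepository  # import path assumed

class FakeTree:
    """Stand-in for MessageTree; only all_nodes() is touched by remove_tree()."""
    def all_nodes(self):
        return []

repo = TreeRepository()
repo.add_tree("root-1", FakeTree())       # root maps to itself
repo.register_node("reply-a", "root-1")   # replies share the root's tree

assert repo.has_node("reply-a")
assert repo.get_tree_for_node("reply-a") is repo.get_tree("root-1")
assert repo.tree_count() == 1

repo.remove_tree("root-1")                # drops the tree and its nodes' mappings
assert repo.get_tree("root-1") is None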
class TreeQueueProcessor:
"""
Per-tree async queue processing (one manager owns one processor instance).
"""
def __init__(
self,
queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None = None,
node_started_callback: Callable[[MessageTree, str], Awaitable[None]]
| None = None,
) -> None:
self._queue_update_callback = queue_update_callback
self._node_started_callback = node_started_callback
def set_queue_update_callback(
self,
queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None,
) -> None:
"""Update the callback used to refresh queue positions."""
self._queue_update_callback = queue_update_callback
def set_node_started_callback(
self,
node_started_callback: Callable[[MessageTree, str], Awaitable[None]] | None,
) -> None:
"""Update the callback used when a queued node starts processing."""
self._node_started_callback = node_started_callback
async def _notify_queue_updated(self, tree: MessageTree) -> None:
"""Invoke queue update callback if set."""
if not self._queue_update_callback:
return
try:
await self._queue_update_callback(tree)
except Exception as e:
d = get_settings().log_messaging_error_details
logger.warning(
"Queue update callback failed: {}",
format_exception_for_log(e, log_full_message=d),
)
async def _notify_node_started(self, tree: MessageTree, node_id: str) -> None:
"""Invoke node started callback if set."""
if not self._node_started_callback:
return
try:
await self._node_started_callback(tree, node_id)
except Exception as e:
d = get_settings().log_messaging_error_details
logger.warning(
"Node started callback failed: {}",
format_exception_for_log(e, log_full_message=d),
)
async def process_node(
self,
tree: MessageTree,
node: MessageNode,
processor: Callable[[str, MessageNode], Awaitable[None]],
) -> None:
"""Process a single node and then check the queue."""
if node.state == MessageState.ERROR:
logger.info(
f"Skipping node {node.node_id} as it is already in state {node.state}"
)
await self._process_next(tree, processor)
return
try:
await processor(node.node_id, node)
except asyncio.CancelledError:
logger.info(f"Task for node {node.node_id} was cancelled")
raise
except Exception as e:
d = get_settings().log_messaging_error_details
logger.error(
"Error processing node {}: {}",
node.node_id,
format_exception_for_log(e, log_full_message=d),
)
await tree.update_state(
node.node_id,
MessageState.ERROR,
error_message=get_user_facing_error_message(e),
)
finally:
async with tree.with_lock():
tree.clear_current_node()
await self._process_next(tree, processor)
async def _process_next(
self,
tree: MessageTree,
processor: Callable[[str, MessageNode], Awaitable[None]],
) -> None:
"""Process the next message in queue, if any."""
next_node_id = None
async with tree.with_lock():
next_node_id = await tree.dequeue()
if not next_node_id:
tree.set_processing_state(None, False)
logger.debug(f"Tree {tree.root_id} queue empty, marking as free")
return
tree.set_processing_state(next_node_id, True)
logger.info(f"Processing next queued node {next_node_id}")
node = tree.get_node(next_node_id)
if node:
tree.set_current_task(
asyncio.create_task(self.process_node(tree, node, processor))
)
if next_node_id:
await self._notify_node_started(tree, next_node_id)
await self._notify_queue_updated(tree)
async def enqueue_and_start(
self,
tree: MessageTree,
node_id: str,
processor: Callable[[str, MessageNode], Awaitable[None]],
) -> bool:
"""
Enqueue a node or start processing immediately.
Returns:
True if queued, False if processing immediately
"""
async with tree.with_lock():
if tree.is_processing:
tree.put_queue_unlocked(node_id)
queue_size = tree.get_queue_size()
logger.info(f"Queued node {node_id}, position {queue_size}")
return True
else:
tree.set_processing_state(node_id, True)
node = tree.get_node(node_id)
if node:
tree.set_current_task(
asyncio.create_task(self.process_node(tree, node, processor))
)
return False
def cancel_current(self, tree: MessageTree) -> bool:
"""Cancel the currently running task in a tree."""
return tree.cancel_current_task()
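`enqueue_and_start` plus `_process_next` implement a per-tree single-flight loop: one running node, everything else queued behind it. A standalone asyncio sketch of the same shape (illustrative names, not the real MessageTree API):
import asyncio

class SingleFlightQueue:
    """Illustrative single-flight worker: run one job, queue the rest."""
    def __init__(self):
        self._lock = asyncio.Lock()
        self._queue: asyncio.Queue[str] = asyncio.Queue()
        self._busy = False

    async def submit(self, job_id: str, run) -> bool:
        async with self._lock:
            if self._busy:
                self._queue.put_nowait(job_id)   # queued behind the running job
                return True
            self._busy = True
        asyncio.create_task(self._run(job_id, run))
        return False

    async def _run(self, job_id: str, run) -> None:
        try:
            await run(job_id)
        finally:
            async with self._lock:
                if self._queue.empty():
                    self._busy = False           # tree goes idle
                    return
                next_id = self._queue.get_nowait()
            asyncio.create_task(self._run(next_id, run))

async def main():
    q = SingleFlightQueue()
    async def work(jid):
        await asyncio.sleep(0.01)
        print("done", jid)
    print(await q.submit("a", work))  # False: started immediately
    print(await q.submit("b", work))  # True: queued
    await asyncio.sleep(0.1)

asyncio.run(main())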
class TreeQueueManager:
"""
Manages multiple message trees: a facade over the repository index and async queue processing.
Each new conversation creates a new tree.
Replies to existing messages add nodes to existing trees.
Components:
- TreeRepository: Data access layer
- TreeQueueProcessor: Async queue processing
"""
def __init__(
@ -33,7 +348,7 @@ class TreeQueueManager:
node_started_callback: Callable[[MessageTree, str], Awaitable[None]]
| None = None,
_repository: TreeRepository | None = None,
):
) -> None:
self._repository = _repository or TreeRepository()
self._processor = TreeQueueProcessor(
queue_update_callback=queue_update_callback,
@ -101,7 +416,6 @@ class TreeQueueManager:
if not tree:
raise ValueError(f"Parent node {parent_node_id} not found in any tree")
# Add node (tree has its own lock) - outside manager lock to avoid deadlock
node = await tree.add_node(
node_id=node_id,
incoming=incoming,
@ -228,7 +542,6 @@ class TreeQueueManager:
cleanup_count = 0
async with tree.with_lock():
# 1. Cancel running task
if tree.cancel_current_task():
current_id = tree.current_node_id
if current_id:
@ -240,12 +553,10 @@ class TreeQueueManager:
tree.set_node_error_sync(node, "Cancelled by user")
cancelled_nodes.append(node)
# 2. Drain queue and mark nodes as cancelled
queue_nodes = tree.drain_queue_and_mark_cancelled()
cancelled_nodes.extend(queue_nodes)
cancelled_ids = {n.node_id for n in cancelled_nodes}
# 3. Cleanup: Mark ANY other PENDING or IN_PROGRESS nodes as ERROR
for node in tree.all_nodes():
if (
node.state in (MessageState.PENDING, MessageState.IN_PROGRESS)
@ -269,10 +580,6 @@ class TreeQueueManager:
"""
Cancel a single node (queued or in-progress) without affecting other nodes.
- If the node is currently running, cancels the current asyncio task.
- If the node is queued, removes it from the queue.
- Marks the node as ERROR with "Cancelled by user".
Returns:
List containing the cancelled node if it was cancellable, else empty list.
"""
@ -351,8 +658,6 @@ class TreeQueueManager:
async def cancel_branch(self, branch_root_id: str) -> list[MessageNode]:
"""
Cancel all PENDING/IN_PROGRESS nodes in the subtree (branch_root + descendants).
Does not call cli_manager.stop_all(). Returns list of cancelled nodes.
"""
tree = self._repository.get_tree_for_node(branch_root_id)
if not tree:
@ -435,3 +740,10 @@ class TreeQueueManager:
node_started_callback=node_started_callback,
_repository=TreeRepository.from_dict(data),
)
__all__ = [
"TreeQueueManager",
"TreeQueueProcessor",
"TreeRepository",
]

View file

@ -1,186 +0,0 @@
"""Repository for message tree data access.
Provides data access layer for managing trees and node mappings.
"""
from loguru import logger
from .data import MessageNode, MessageState, MessageTree
class TreeRepository:
"""
Repository for message tree data access.
Manages the storage and lookup of trees and node-to-tree mappings.
"""
def __init__(self):
self._trees: dict[str, MessageTree] = {} # root_id -> tree
self._node_to_tree: dict[str, str] = {} # node_id -> root_id
def get_tree(self, root_id: str) -> MessageTree | None:
"""Get a tree by its root ID."""
return self._trees.get(root_id)
def get_tree_for_node(self, node_id: str) -> MessageTree | None:
"""Get the tree containing a given node."""
root_id = self._node_to_tree.get(node_id)
if not root_id:
return None
return self._trees.get(root_id)
def get_node(self, node_id: str) -> MessageNode | None:
"""Get a node from any tree."""
tree = self.get_tree_for_node(node_id)
return tree.get_node(node_id) if tree else None
def add_tree(self, root_id: str, tree: MessageTree) -> None:
"""Add a new tree to the repository."""
self._trees[root_id] = tree
self._node_to_tree[root_id] = root_id
logger.debug("TREE_REPO: add_tree root_id={}", root_id)
def register_node(self, node_id: str, root_id: str) -> None:
"""Register a node ID to a tree."""
self._node_to_tree[node_id] = root_id
logger.debug("TREE_REPO: register_node node_id={} root_id={}", node_id, root_id)
def has_node(self, node_id: str) -> bool:
"""Check if a node is registered in any tree."""
return node_id in self._node_to_tree
def tree_count(self) -> int:
"""Get the number of trees in the repository."""
return len(self._trees)
def is_tree_busy(self, root_id: str) -> bool:
"""Check if a tree is currently processing."""
tree = self._trees.get(root_id)
return tree.is_processing if tree else False
def is_node_tree_busy(self, node_id: str) -> bool:
"""Check if the tree containing a node is busy."""
tree = self.get_tree_for_node(node_id)
return tree.is_processing if tree else False
def get_queue_size(self, node_id: str) -> int:
"""Get queue size for the tree containing a node."""
tree = self.get_tree_for_node(node_id)
return tree.get_queue_size() if tree else 0
def resolve_parent_node_id(self, msg_id: str) -> str | None:
"""
Resolve a message ID to the actual parent node ID.
Handles the case where msg_id is a status message ID
(which maps to the tree but isn't an actual node).
Returns:
The node_id to use as parent, or None if not found
"""
tree = self.get_tree_for_node(msg_id)
if not tree:
return None
# Check if msg_id is an actual node
if tree.has_node(msg_id):
return msg_id
# Otherwise, it might be a status message - find the owning node
node = tree.find_node_by_status_message(msg_id)
if node:
return node.node_id
return None
def get_pending_children(self, node_id: str) -> list[MessageNode]:
"""
Get all pending child nodes (recursively) of a given node.
Used for error propagation - when a node fails, its pending
children should also be marked as failed.
"""
tree = self.get_tree_for_node(node_id)
if not tree:
return []
pending: list[MessageNode] = []
stack = [node_id]
while stack:
current_id = stack.pop()
node = tree.get_node(current_id)
if not node:
continue
for child_id in node.children_ids:
child = tree.get_node(child_id)
if child and child.state == MessageState.PENDING:
pending.append(child)
stack.append(child_id)
return pending
def all_trees(self) -> list[MessageTree]:
"""Get all trees in the repository."""
return list(self._trees.values())
def tree_ids(self) -> list[str]:
"""Get all tree root IDs."""
return list(self._trees.keys())
def unregister_nodes(self, node_ids: list[str]) -> None:
"""Remove node IDs from the node-to-tree mapping."""
for nid in node_ids:
self._node_to_tree.pop(nid, None)
def remove_tree(self, root_id: str) -> MessageTree | None:
"""
Remove a tree and all its node mappings from the repository.
Returns:
The removed tree, or None if not found.
"""
tree = self._trees.pop(root_id, None)
if not tree:
return None
for node in tree.all_nodes():
self._node_to_tree.pop(node.node_id, None)
logger.debug("TREE_REPO: remove_tree root_id={}", root_id)
return tree
def get_message_ids_for_chat(self, platform: str, chat_id: str) -> set[str]:
"""Get all message IDs (incoming + status) for a given platform/chat.
Note: O(total_nodes) scan. Acceptable because this is only called
from /clear (user-initiated, infrequent).
"""
msg_ids: set[str] = set()
for tree in self._trees.values():
for node in tree.all_nodes():
if str(node.incoming.platform) == str(platform) and str(
node.incoming.chat_id
) == str(chat_id):
if node.incoming.message_id is not None:
msg_ids.add(str(node.incoming.message_id))
if node.status_message_id:
msg_ids.add(str(node.status_message_id))
return msg_ids
def to_dict(self) -> dict:
"""Serialize all trees."""
return {
"trees": {rid: tree.to_dict() for rid, tree in self._trees.items()},
"node_to_tree": self._node_to_tree.copy(),
}
@classmethod
def from_dict(cls, data: dict) -> TreeRepository:
"""Deserialize from dictionary."""
from .data import MessageTree
repo = cls()
for root_id, tree_data in data.get("trees", {}).items():
repo._trees[root_id] = MessageTree.from_dict(tree_data)
repo._node_to_tree = data.get("node_to_tree", {})
return repo

messaging/ui_updates.py Normal file
View file

@ -0,0 +1,101 @@
"""Throttled platform UI updates driven by transcript rendering."""
from __future__ import annotations
import time
from collections.abc import Callable
from loguru import logger
from .platforms.base import MessagingPlatform
from .safe_diagnostics import format_exception_for_log
from .transcript import RenderCtx, TranscriptBuffer
class ThrottledTranscriptEditor:
"""Rate-limited status message edits from a growing transcript."""
def __init__(
self,
*,
platform: MessagingPlatform,
parse_mode: str | None,
get_limit_chars: Callable[[], int],
transcript: TranscriptBuffer,
render_ctx: RenderCtx,
node_id: str,
chat_id: str,
status_msg_id: str,
debug_platform_edits: bool,
log_messaging_error_details: bool = False,
) -> None:
self._platform = platform
self._parse_mode = parse_mode
self._get_limit_chars = get_limit_chars
self._transcript = transcript
self._render_ctx = render_ctx
self._node_id = node_id
self._chat_id = chat_id
self._status_msg_id = status_msg_id
self._debug_platform_edits = debug_platform_edits
self._log_messaging_error_details = log_messaging_error_details
self._last_ui_update = 0.0
self._last_displayed_text: str | None = None
self._last_status: str | None = None
@property
def last_status(self) -> str | None:
return self._last_status
async def update(self, status: str | None = None, *, force: bool = False) -> None:
"""Render transcript + optional status line and edit the platform message."""
now = time.time()
if not force and now - self._last_ui_update < 1.0:
return
self._last_ui_update = now
if status is not None:
self._last_status = status
try:
display = self._transcript.render(
self._render_ctx,
limit_chars=self._get_limit_chars(),
status=status,
)
except Exception as e:
logger.warning(
"Transcript render failed for node {}: {}",
self._node_id,
format_exception_for_log(
e, log_full_message=self._log_messaging_error_details
),
)
return
if display and display != self._last_displayed_text:
logger.debug(
"PLATFORM_EDIT: node_id={} chat_id={} msg_id={} force={} status={!r} chars={}",
self._node_id,
self._chat_id,
self._status_msg_id,
bool(force),
status,
len(display),
)
if self._debug_platform_edits:
logger.debug("PLATFORM_EDIT_TEXT:\n{}", display)
self._last_displayed_text = display
try:
await self._platform.queue_edit_message(
self._chat_id,
self._status_msg_id,
display,
parse_mode=self._parse_mode,
)
except Exception as e:
logger.warning(
"Failed to update platform for node {}: {}",
self._node_id,
format_exception_for_log(
e, log_full_message=self._log_messaging_error_details
),
)
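`update()` coalesces platform edits to roughly one per second unless forced, and skips edits whose rendered text is unchanged. A standalone sketch of that throttle-and-dedupe decision (illustrative, not the real class):
import time

class EditThrottle:
    """Illustrative: suppress edits within 1s unless forced, skip no-op edits."""
    def __init__(self, min_interval: float = 1.0):
        self._min_interval = min_interval
        self._last_ts = 0.0
        self._last_text = None

    def should_edit(self, text: str, *, force: bool = False) -> bool:
        now = time.time()
        if not force and now - self._last_ts < self._min_interval:
            return False
        self._last_ts = now
        if text == self._last_text:
            return False          # identical render: no platform call needed
        self._last_text = text
        return True

t = EditThrottle()
print(t.should_edit("a"))              # True (first render)
print(t.should_edit("b"))              # False (inside throttle window)
print(t.should_edit("b", force=True))  # True (forced, text changed)
print(t.should_edit("b", force=True))  # False (forced but identical)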

View file

@ -2,19 +2,32 @@
from __future__ import annotations
import json
from collections.abc import AsyncIterator, Iterator
from typing import Any, Literal
import httpx
from loguru import logger
from core.anthropic import get_user_facing_error_message
from config.constants import (
ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES,
)
from core.anthropic import iter_provider_stream_error_sse_events
from core.anthropic.emitted_sse_tracker import EmittedNativeSseTracker
from core.anthropic.native_messages_request import (
build_base_native_anthropic_request_body,
)
from core.anthropic.native_sse_block_policy import (
NativeSseBlockPolicyState,
transform_native_sse_block_event,
)
from providers.base import BaseProvider, ProviderConfig
from providers.error_mapping import map_error
from providers.error_mapping import (
map_error,
user_visible_message_for_mapped_provider_error,
)
from providers.rate_limit import GlobalRateLimiter
ANTHROPIC_DEFAULT_MAX_TOKENS = 81920
StreamChunkMode = Literal["line", "event"]
@ -64,25 +77,11 @@ class AnthropicMessagesTransport(BaseProvider):
) -> dict:
"""Build a native Anthropic request body."""
thinking_enabled = self._is_thinking_enabled(request, thinking_enabled)
body = request.model_dump(exclude_none=True)
body.pop("extra_body", None)
body.pop("original_model", None)
body.pop("resolved_provider_model", None)
if "thinking" in body:
thinking_cfg = body.pop("thinking")
if thinking_enabled and isinstance(thinking_cfg, dict):
thinking_payload = {"type": "enabled"}
budget_tokens = thinking_cfg.get("budget_tokens")
if isinstance(budget_tokens, int):
thinking_payload["budget_tokens"] = budget_tokens
body["thinking"] = thinking_payload
if "max_tokens" not in body:
body["max_tokens"] = ANTHROPIC_DEFAULT_MAX_TOKENS
return body
return build_base_native_anthropic_request_body(
request,
default_max_tokens=ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
thinking_enabled=thinking_enabled,
)
async def _send_stream_request(self, body: dict) -> httpx.Response:
"""Create a streaming messages response."""
@ -97,30 +96,68 @@ class AnthropicMessagesTransport(BaseProvider):
async def _raise_for_status(
self, response: httpx.Response, *, req_tag: str
) -> None:
"""Raise for non-200 responses after logging the upstream body."""
"""Raise for non-200 responses after logging safe metadata (or capped body if opted in)."""
try:
response.raise_for_status()
except httpx.HTTPStatusError as error:
response_text = await self._read_error_body(response)
if response_text:
if self._config.log_api_error_tracebacks:
preview, truncated = await self._read_error_body_preview(
response, NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES
)
if preview:
text = preview.decode("utf-8", errors="replace")
logger.error(
"{}_ERROR:{} HTTP {}: {}",
"{}_ERROR:{} HTTP {} body_preview_bytes={} truncated={}: {}",
self._provider_name,
req_tag,
response.status_code,
response_text,
len(preview),
truncated,
text,
)
else:
logger.error(
"{}_ERROR:{} HTTP {} (empty error body)",
self._provider_name,
req_tag,
response.status_code,
)
else:
cl = response.headers.get("content-length", "").strip()
extra = f" content_length_declared={cl}" if cl.isdigit() else ""
logger.error(
"{}_ERROR:{} HTTP {}{}",
self._provider_name,
req_tag,
response.status_code,
extra,
)
raise error
async def _read_error_body(self, response: httpx.Response) -> str:
"""Read a response body for diagnostics."""
aread = getattr(response, "aread", None)
if aread is None:
return ""
body = await aread()
if isinstance(body, bytes):
return body.decode("utf-8", errors="replace")
return str(body)
async def _read_error_body_preview(
self, response: httpx.Response, max_bytes: int
) -> tuple[bytes, bool]:
"""Read at most ``max_bytes`` from the error body for logging. Returns (preview, truncated)."""
if max_bytes <= 0:
return b"", False
received = 0
parts: list[bytes] = []
truncated = False
async for chunk in response.aiter_bytes(chunk_size=65_536):
if received >= max_bytes:
truncated = True
break
remaining = max_bytes - received
take = chunk if len(chunk) <= remaining else chunk[:remaining]
if take:
parts.append(take)
received += len(take)
if len(chunk) > len(take):
truncated = True
break
if received >= max_bytes:
break
return (b"".join(parts), truncated)
async def _iter_sse_lines(self, response: httpx.Response) -> AsyncIterator[str]:
"""Yield raw SSE line chunks preserving local provider behavior."""
@ -145,6 +182,8 @@ class AnthropicMessagesTransport(BaseProvider):
def _new_stream_state(self, request: Any, *, thinking_enabled: bool) -> Any:
"""Return per-stream provider state for event transformation."""
if self.stream_chunk_mode == "line":
return NativeSseBlockPolicyState()
return None
def _transform_stream_event(
@ -155,6 +194,10 @@ class AnthropicMessagesTransport(BaseProvider):
thinking_enabled: bool,
) -> str | None:
"""Transform or drop a grouped SSE event before yielding it downstream."""
if isinstance(state, NativeSseBlockPolicyState):
return transform_native_sse_block_event(
event, state, thinking_enabled=thinking_enabled
)
return event
def _format_error_message(self, base_message: str, request_id: str | None) -> str:
@ -166,14 +209,10 @@ class AnthropicMessagesTransport(BaseProvider):
def _get_error_message(self, error: Exception, request_id: str | None) -> str:
"""Map an exception into a user-facing provider error message."""
mapped_error = map_error(error, rate_limiter=self._global_rate_limiter)
if getattr(mapped_error, "status_code", None) == 405:
base_message = (
f"Upstream provider {self._provider_name} rejected the request method "
"or endpoint (HTTP 405)."
)
else:
base_message = get_user_facing_error_message(
mapped_error, read_timeout_s=self._config.http_read_timeout
base_message = user_visible_message_for_mapped_provider_error(
mapped_error,
provider_name=self._provider_name,
read_timeout_s=self._config.http_read_timeout,
)
return self._format_error_message(base_message, request_id)
@ -185,12 +224,14 @@ class AnthropicMessagesTransport(BaseProvider):
error_message: str,
sent_any_event: bool,
) -> Iterator[str]:
"""Emit a native Anthropic error event."""
error_event = {
"type": "error",
"error": {"type": "api_error", "message": error_message},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n"
"""Emit the same Anthropic message lifecycle used by OpenAI-compat providers."""
yield from iter_provider_stream_error_sse_events(
request=request,
input_tokens=input_tokens,
error_message=error_message,
sent_any_event=sent_any_event,
log_raw_sse_events=self._config.log_raw_sse_events,
)
async def _iter_stream_chunks(
self,
@ -200,6 +241,21 @@ class AnthropicMessagesTransport(BaseProvider):
thinking_enabled: bool,
) -> AsyncIterator[str]:
"""Yield stream chunks according to the provider's observable chunk shape."""
if self.stream_chunk_mode == "line" and isinstance(
state, NativeSseBlockPolicyState
):
async for event in self._iter_sse_events(response):
output_event = self._transform_stream_event(
event,
state,
thinking_enabled=thinking_enabled,
)
if output_event is None:
continue
for line in output_event.splitlines(keepends=True):
yield line
return
if self.stream_chunk_mode == "line":
async for chunk in self._iter_sse_lines(response):
yield chunk
@ -240,15 +296,28 @@ class AnthropicMessagesTransport(BaseProvider):
response: httpx.Response | None = None
sent_any_event = False
state = self._new_stream_state(request, thinking_enabled=thinking_enabled)
emitted_tracker = EmittedNativeSseTracker()
async with self._global_rate_limiter.concurrency_slot():
try:
response = await self._global_rate_limiter.execute_with_retry(
self._send_stream_request, body
)
if response.status_code != 200:
await self._raise_for_status(response, req_tag=req_tag)
async def _validated_stream_send() -> httpx.Response:
"""Send request; raise inside retry loop on 429 so rate limiter can backoff."""
send_response = await self._send_stream_request(body)
if send_response.status_code == 429:
await send_response.aclose()
send_response.raise_for_status()
if send_response.status_code != 200:
try:
await self._raise_for_status(send_response, req_tag=req_tag)
finally:
if not send_response.is_closed:
await send_response.aclose()
return send_response
response = await self._global_rate_limiter.execute_with_retry(
_validated_stream_send
)
async for chunk in self._iter_stream_chunks(
response,
@ -256,12 +325,12 @@ class AnthropicMessagesTransport(BaseProvider):
thinking_enabled=thinking_enabled,
):
sent_any_event = True
emitted_tracker.feed(chunk)
yield chunk
except Exception as error:
logger.error(
"{}_ERROR:{} {}: {}", tag, req_tag, type(error).__name__, error
)
if not isinstance(error, httpx.HTTPStatusError):
self._log_stream_transport_error(tag, req_tag, error)
error_message = self._get_error_message(error, request_id)
if response is not None and not response.is_closed:
@ -273,11 +342,22 @@ class AnthropicMessagesTransport(BaseProvider):
type(error).__name__,
req_tag,
)
if sent_any_event:
for event in emitted_tracker.iter_close_unclosed_blocks():
yield event
for event in emitted_tracker.iter_midstream_error_tail(
error_message,
request=request,
input_tokens=input_tokens,
log_raw_sse_events=self._config.log_raw_sse_events,
):
yield event
else:
for event in self._emit_error_events(
request=request,
input_tokens=input_tokens,
error_message=error_message,
sent_any_event=sent_any_event,
sent_any_event=False,
):
yield event
return
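The 429 branch closes the response and re-raises inside the retried callable, so the limiter can back off before any stream is exposed downstream. A self-contained sketch of that loop shape (`execute_with_retry` here is an illustrative stand-in for the GlobalRateLimiter method):
import asyncio

class TooManyRequests(Exception):
    pass

async def execute_with_retry(send, *, attempts: int = 3, base_delay: float = 0.05):
    """Illustrative backoff loop standing in for GlobalRateLimiter.execute_with_retry."""
    for attempt in range(attempts):
        try:
            return await send()
        except TooManyRequests:
            if attempt == attempts - 1:
                raise
            await asyncio.sleep(base_delay * (2 ** attempt))  # exponential backoff

async def main():
    calls = {"n": 0}
    async def flaky_send():
        calls["n"] += 1
        if calls["n"] < 3:
            raise TooManyRequests()   # mirrors aclose() + raise_for_status() on 429
        return "HTTP 200 stream"
    print(await execute_with_retry(flaky_send), "after", calls["n"], "attempts")

asyncio.run(main())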

View file

@ -6,6 +6,8 @@ from typing import Any
from pydantic import BaseModel
from config.constants import HTTP_CONNECT_TIMEOUT_DEFAULT
class ProviderConfig(BaseModel):
"""Configuration for a provider.
@ -21,9 +23,11 @@ class ProviderConfig(BaseModel):
max_concurrency: int = 5
http_read_timeout: float = 300.0
http_write_timeout: float = 10.0
http_connect_timeout: float = 2.0
http_connect_timeout: float = HTTP_CONNECT_TIMEOUT_DEFAULT
enable_thinking: bool = True
proxy: str = ""
log_raw_sse_events: bool = False
log_api_error_tracebacks: bool = False
class BaseProvider(ABC):
@ -61,6 +65,42 @@ class BaseProvider(ABC):
request_enabled = bool(enabled)
return config_enabled and request_enabled
def preflight_stream(
self, request: Any, *, thinking_enabled: bool | None = None
) -> None:
"""Eagerly validate/build the upstream request before opening an SSE stream.
Subclasses with ``_build_request_body`` (OpenAI and native) raise
:class:`providers.exceptions.InvalidRequestError` on conversion failures.
"""
build = getattr(self, "_build_request_body", None)
if build is None:
return
build(request, thinking_enabled=thinking_enabled)
def _log_stream_transport_error(
self, tag: str, req_tag: str, error: Exception
) -> None:
"""Log streaming transport failures (metadata-only unless verbose is enabled)."""
from loguru import logger
if self._config.log_api_error_tracebacks:
logger.error(
"{}_ERROR:{} {}: {}", tag, req_tag, type(error).__name__, error
)
return
response = getattr(error, "response", None)
status_code = (
getattr(response, "status_code", None) if response is not None else None
)
logger.error(
"{}_ERROR:{} exc_type={} http_status={}",
tag,
req_tag,
type(error).__name__,
status_code,
)
@abstractmethod
async def cleanup(self) -> None:
"""Release any resources held by this provider."""
@ -75,5 +115,7 @@ class BaseProvider(ABC):
thinking_enabled: bool | None = None,
) -> AsyncIterator[str]:
"""Stream response in Anthropic SSE format."""
# Typing: abstract async generators need a yield for AsyncIterator[str]
# inference; this branch is never executed.
if False:
yield "" # Required for ty/mypy to accept abstract async generator
yield ""

View file

@ -1,5 +1,7 @@
"""DeepSeek provider exports."""
from .client import DEEPSEEK_BASE_URL, DeepSeekProvider
from providers.defaults import DEEPSEEK_DEFAULT_BASE
__all__ = ["DEEPSEEK_BASE_URL", "DeepSeekProvider"]
from .client import DeepSeekProvider
__all__ = ["DEEPSEEK_DEFAULT_BASE", "DeepSeekProvider"]

View file

@ -3,7 +3,7 @@
from typing import Any
from providers.base import ProviderConfig
from providers.defaults import DEEPSEEK_BASE_URL
from providers.defaults import DEEPSEEK_DEFAULT_BASE
from providers.openai_compat import OpenAIChatTransport
from .request import build_request_body
@ -16,7 +16,7 @@ class DeepSeekProvider(OpenAIChatTransport):
super().__init__(
config,
provider_name="DEEPSEEK",
base_url=config.base_url or DEEPSEEK_BASE_URL,
base_url=config.base_url or DEEPSEEK_DEFAULT_BASE,
api_key=config.api_key,
)

View file

@ -5,6 +5,8 @@ from typing import Any
from loguru import logger
from core.anthropic import build_base_request_body
from core.anthropic.conversion import OpenAIConversionError
from providers.exceptions import InvalidRequestError
def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict:
@ -14,10 +16,14 @@ def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict:
getattr(request_data, "model", "?"),
len(getattr(request_data, "messages", [])),
)
try:
body = build_base_request_body(
request_data,
include_reasoning_content=True,
include_thinking=thinking_enabled,
include_reasoning_content=thinking_enabled,
)
except OpenAIConversionError as exc:
raise InvalidRequestError(str(exc)) from exc
extra_body: dict[str, Any] = {}
request_extra = getattr(request_data, "extra_body", None)

View file

@ -1,21 +1,19 @@
"""Default upstream base URLs and shared provider constants.
"""Re-exports default upstream base URLs from the config provider catalog."""
Adapters and :mod:`providers.registry` import from here to avoid duplicating
literals and to keep ``providers.registry`` free of per-adapter eager imports.
"""
from config.provider_catalog import (
DEEPSEEK_DEFAULT_BASE,
LLAMACPP_DEFAULT_BASE,
LMSTUDIO_DEFAULT_BASE,
NVIDIA_NIM_DEFAULT_BASE,
OLLAMA_DEFAULT_BASE,
OPENROUTER_DEFAULT_BASE,
)
# OpenAI-compatible chat (NIM, DeepSeek) and local/native provider endpoints
NVIDIA_NIM_DEFAULT_BASE = "https://integrate.api.nvidia.com/v1"
DEEPSEEK_DEFAULT_BASE = "https://api.deepseek.com"
OPENROUTER_DEFAULT_BASE = "https://openrouter.ai/api/v1"
LMSTUDIO_DEFAULT_BASE = "http://localhost:1234/v1"
LLAMACPP_DEFAULT_BASE = "http://localhost:8080/v1"
OLLAMA_DEFAULT_BASE = "http://localhost:11434"
# Backward-compatible names used by existing adapter modules
NVIDIA_NIM_BASE_URL = NVIDIA_NIM_DEFAULT_BASE
DEEPSEEK_BASE_URL = DEEPSEEK_DEFAULT_BASE
OPENROUTER_BASE_URL = OPENROUTER_DEFAULT_BASE
LMSTUDIO_DEFAULT_BASE_URL = LMSTUDIO_DEFAULT_BASE
LLAMACPP_DEFAULT_BASE_URL = LLAMACPP_DEFAULT_BASE
OLLAMA_DEFAULT_BASE_URL = OLLAMA_DEFAULT_BASE
__all__ = (
"DEEPSEEK_DEFAULT_BASE",
"LLAMACPP_DEFAULT_BASE",
"LMSTUDIO_DEFAULT_BASE",
"NVIDIA_NIM_DEFAULT_BASE",
"OLLAMA_DEFAULT_BASE",
"OPENROUTER_DEFAULT_BASE",
)

View file

@ -14,10 +14,30 @@ from providers.exceptions import (
from providers.rate_limit import GlobalRateLimiter
def user_visible_message_for_mapped_provider_error(
mapped: Exception,
*,
provider_name: str,
read_timeout_s: float | None,
) -> str:
"""Return the user-visible string after :func:`map_error` (405 + mapped types)."""
if getattr(mapped, "status_code", None) == 405:
return (
f"Upstream provider {provider_name} rejected the request method "
"or endpoint (HTTP 405)."
)
return get_user_facing_error_message(mapped, read_timeout_s=read_timeout_s)
def map_error(
e: Exception, *, rate_limiter: GlobalRateLimiter | None = None
) -> Exception:
"""Map OpenAI or HTTPX exception to specific ProviderError."""
"""Map OpenAI or HTTPX exception to specific ProviderError.
Streaming transports should pass their scoped limiter (``self._global_rate_limiter``)
so reactive 429 handling applies to the correct provider. Tests may omit
``rate_limiter`` to use the process-wide singleton.
"""
message = get_user_facing_error_message(e)
limiter = rate_limiter or GlobalRateLimiter.get_instance()
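A standalone sketch of the two-step flow the transports now share: map the raw exception first, then derive the user-visible string, with the 405 case handled in the second step (all types here are stand-ins for the real map_error/limiter machinery):
class ProviderError(Exception):
    def __init__(self, message, status_code=None):
        super().__init__(message)
        self.status_code = status_code

def toy_map_error(e: Exception) -> Exception:
    # Stand-in: the real mapping inspects httpx/OpenAI errors and notifies the limiter.
    status = getattr(getattr(e, "response", None), "status_code", None)
    return ProviderError(str(e), status_code=status) if status else e

def toy_user_visible_message(mapped: Exception, *, provider_name: str) -> str:
    if getattr(mapped, "status_code", None) == 405:
        return (f"Upstream provider {provider_name} rejected the request method "
                "or endpoint (HTTP 405).")
    return str(mapped)

class FakeResponse:
    status_code = 405

err = Exception("method not allowed")
err.response = FakeResponse()
print(toy_user_visible_message(toy_map_error(err), provider_name="NIM"))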

View file

@ -90,8 +90,20 @@ class APIError(ProviderError):
)
class UnknownProviderTypeError(ValueError):
class UnknownProviderTypeError(InvalidRequestError):
"""Raised when ``provider_id`` is not registered in the provider map."""
def __init__(self, message: str) -> None:
super().__init__(message)
class ServiceUnavailableError(ProviderError):
"""Raised when the server is not ready (e.g. app lifespan did not wire state)."""
def __init__(self, message: str, raw_error: Any = None):
super().__init__(
message,
status_code=503,
error_type="api_error",
raw_error=raw_error,
)
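Because UnknownProviderTypeError now subclasses InvalidRequestError, a single isinstance check covers both; a small sketch of how an API layer might map these to statuses (the classify() mapping itself is illustrative):
from providers.exceptions import (  # assumed to be importable from this package
    InvalidRequestError,
    ServiceUnavailableError,
    UnknownProviderTypeError,
)

def classify(exc: Exception) -> int:
    """Map provider exceptions to an HTTP status, as an API layer might."""
    if isinstance(exc, InvalidRequestError):   # also catches UnknownProviderTypeError
        return 400
    if isinstance(exc, ServiceUnavailableError):
        # 503, assuming ProviderError stores status_code (it is read via getattr elsewhere)
        return exc.status_code
    return 500

print(classify(UnknownProviderTypeError("no such provider: foo")))  # 400
print(classify(ServiceUnavailableError("state not wired")))         # 503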

View file

@ -2,7 +2,7 @@
from providers.anthropic_messages import AnthropicMessagesTransport
from providers.base import ProviderConfig
from providers.defaults import LLAMACPP_DEFAULT_BASE_URL
from providers.defaults import LLAMACPP_DEFAULT_BASE
class LlamaCppProvider(AnthropicMessagesTransport):
@ -12,5 +12,5 @@ class LlamaCppProvider(AnthropicMessagesTransport):
super().__init__(
config,
provider_name="LLAMACPP",
default_base_url=LLAMACPP_DEFAULT_BASE_URL,
default_base_url=LLAMACPP_DEFAULT_BASE,
)

View file

@ -2,7 +2,7 @@
from providers.anthropic_messages import AnthropicMessagesTransport
from providers.base import ProviderConfig
from providers.defaults import LMSTUDIO_DEFAULT_BASE_URL
from providers.defaults import LMSTUDIO_DEFAULT_BASE
class LMStudioProvider(AnthropicMessagesTransport):
@ -12,5 +12,5 @@ class LMStudioProvider(AnthropicMessagesTransport):
super().__init__(
config,
provider_name="LMSTUDIO",
default_base_url=LMSTUDIO_DEFAULT_BASE_URL,
default_base_url=LMSTUDIO_DEFAULT_BASE,
)

View file

@ -1,5 +1,7 @@
"""NVIDIA NIM provider package."""
from .client import NVIDIA_NIM_BASE_URL, NvidiaNimProvider
from providers.defaults import NVIDIA_NIM_DEFAULT_BASE
__all__ = ["NVIDIA_NIM_BASE_URL", "NvidiaNimProvider"]
from .client import NvidiaNimProvider
__all__ = ["NVIDIA_NIM_DEFAULT_BASE", "NvidiaNimProvider"]

View file

@ -8,7 +8,7 @@ from loguru import logger
from config.nim import NimSettings
from providers.base import ProviderConfig
from providers.defaults import NVIDIA_NIM_BASE_URL
from providers.defaults import NVIDIA_NIM_DEFAULT_BASE
from providers.openai_compat import OpenAIChatTransport
from .request import (
@ -25,7 +25,7 @@ class NvidiaNimProvider(OpenAIChatTransport):
super().__init__(
config,
provider_name="NIM",
base_url=config.base_url or NVIDIA_NIM_BASE_URL,
base_url=config.base_url or NVIDIA_NIM_DEFAULT_BASE,
api_key=config.api_key,
)
self._nim_settings = nim_settings

View file

@ -1,5 +1,6 @@
"""Request builder for NVIDIA NIM provider."""
from collections.abc import Callable
from copy import deepcopy
from typing import Any
@ -7,6 +8,42 @@ from loguru import logger
from config.nim import NimSettings
from core.anthropic import build_base_request_body, set_if_not_none
from core.anthropic.conversion import OpenAIConversionError
from providers.exceptions import InvalidRequestError
def _clone_strip_extra_body(
body: dict[str, Any],
strip: Callable[[dict[str, Any]], bool],
) -> dict[str, Any] | None:
"""Deep-clone ``body`` and remove fields via ``strip`` on ``extra_body`` only.
Returns ``None`` when there is no ``extra_body`` dict or ``strip`` reports no change.
"""
cloned_body = deepcopy(body)
extra_body = cloned_body.get("extra_body")
if not isinstance(extra_body, dict):
return None
if not strip(extra_body):
return None
if not extra_body:
cloned_body.pop("extra_body", None)
return cloned_body
def _strip_reasoning_budget_fields(extra_body: dict[str, Any]) -> bool:
removed = extra_body.pop("reasoning_budget", None) is not None
chat_template_kwargs = extra_body.get("chat_template_kwargs")
if (
isinstance(chat_template_kwargs, dict)
and chat_template_kwargs.pop("reasoning_budget", None) is not None
):
removed = True
return removed
def _strip_chat_template_field(extra_body: dict[str, Any]) -> bool:
return extra_body.pop("chat_template", None) is not None
def _set_extra(
@ -23,43 +60,12 @@ def _set_extra(
def clone_body_without_reasoning_budget(body: dict[str, Any]) -> dict[str, Any] | None:
"""Clone a request body and strip only reasoning_budget fields."""
cloned_body = deepcopy(body)
extra_body = cloned_body.get("extra_body")
if not isinstance(extra_body, dict):
return None
removed = extra_body.pop("reasoning_budget", None) is not None
chat_template_kwargs = extra_body.get("chat_template_kwargs")
if (
isinstance(chat_template_kwargs, dict)
and chat_template_kwargs.pop("reasoning_budget", None) is not None
):
removed = True
if not extra_body:
cloned_body.pop("extra_body", None)
if not removed:
return None
return cloned_body
return _clone_strip_extra_body(body, _strip_reasoning_budget_fields)
def clone_body_without_chat_template(body: dict[str, Any]) -> dict[str, Any] | None:
"""Clone a request body and strip only chat_template."""
cloned_body = deepcopy(body)
extra_body = cloned_body.get("extra_body")
if not isinstance(extra_body, dict):
return None
if extra_body.pop("chat_template", None) is None:
return None
if not extra_body:
cloned_body.pop("extra_body", None)
return cloned_body
return _clone_strip_extra_body(body, _strip_chat_template_field)
def build_request_body(
@ -71,10 +77,13 @@ def build_request_body(
getattr(request_data, "model", "?"),
len(getattr(request_data, "messages", [])),
)
try:
body = build_base_request_body(
request_data,
include_thinking=thinking_enabled,
)
except OpenAIConversionError as exc:
raise InvalidRequestError(str(exc)) from exc
# NIM-specific max_tokens: cap against nim.max_tokens
max_tokens = body.get("max_tokens") or getattr(request_data, "max_tokens", None)
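A standalone mirror of the clone-and-strip contract used by both retry helpers: the original body is never mutated, and None means "nothing stripped", so the caller knows a retry body would be pointless:
from copy import deepcopy

def clone_strip_extra_body(body, strip):
    """Illustrative copy of _clone_strip_extra_body's contract."""
    cloned = deepcopy(body)
    extra = cloned.get("extra_body")
    if not isinstance(extra, dict):
        return None
    if not strip(extra):
        return None                      # nothing removed: no retry body needed
    if not extra:
        cloned.pop("extra_body", None)   # drop the now-empty dict entirely
    return cloned

original = {"model": "m", "extra_body": {"chat_template": "x"}}
retry = clone_strip_extra_body(
    original, lambda eb: eb.pop("chat_template", None) is not None
)
print(retry)                             # {'model': 'm'}
print(original["extra_body"])            # untouched: {'chat_template': 'x'}
print(clone_strip_extra_body({"model": "m"}, lambda eb: True))  # None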

View file

@ -0,0 +1,95 @@
"""NVIDIA NIM / Riva offline ASR for voice notes (provider-owned transport)."""
from __future__ import annotations
from pathlib import Path
from loguru import logger
# NVIDIA NIM Whisper model mapping: (function_id, language_code)
_NIM_ASR_MODEL_MAP: dict[str, tuple[str, str]] = {
"nvidia/parakeet-ctc-0.6b-zh-tw": ("8473f56d-51ef-473c-bb26-efd4f5def2bf", "zh-TW"),
"nvidia/parakeet-ctc-0.6b-zh-cn": ("9add5ef7-322e-47e0-ad7a-5653fb8d259b", "zh-CN"),
# function-id from NVIDIA NIM API docs (parakeet-ctc-0.6b-es).
"nvidia/parakeet-ctc-0.6b-es": ("a9eeee8f-b509-4712-b19d-194361fa5f31", "es-US"),
"nvidia/parakeet-ctc-0.6b-vi": ("f3dff2bb-99f9-403d-a5f1-f574a757deb0", "vi-VN"),
"nvidia/parakeet-ctc-1.1b-asr": ("1598d209-5e27-4d3c-8079-4751568b1081", "en-US"),
"nvidia/parakeet-ctc-0.6b-asr": ("d8dd4e9b-fbf5-4fb0-9dba-8cf436c8d965", "en-US"),
"nvidia/parakeet-1.1b-rnnt-multilingual-asr": (
"71203149-d3b7-4460-8231-1be2543a1fca",
"",
),
"openai/whisper-large-v3": ("b702f636-f60c-4a3d-a6f4-f3568c13bd7d", "multi"),
}
_RIVA_SERVER = "grpc.nvcf.nvidia.com:443"
def transcribe_audio_file(
file_path: Path,
model: str,
*,
api_key: str,
) -> str:
"""Transcribe audio using NVIDIA NIM / Riva gRPC (offline recognition).
Args:
file_path: Path to encoded audio bytes readable by Riva.
model: Hugging Face-style NIM model id (see ``_NIM_ASR_MODEL_MAP``).
api_key: NVIDIA API key (Bearer token); must be non-empty.
Returns:
Transcript text, or ``(no speech detected)`` when empty.
"""
key = (api_key or "").strip()
if not key:
raise ValueError(
"NVIDIA NIM transcription requires a non-empty nvidia_nim_api_key "
"(configure NVIDIA_NIM_API_KEY or pass api_key explicitly)."
)
try:
import riva.client
except ImportError as e:
raise ImportError(
"NVIDIA NIM transcription requires the voice extra. "
"Install with: uv sync --extra voice"
) from e
model_config = _NIM_ASR_MODEL_MAP.get(model)
if not model_config:
raise ValueError(
f"No NVIDIA NIM config found for model: {model}. "
f"Supported models: {', '.join(_NIM_ASR_MODEL_MAP.keys())}"
)
function_id, language_code = model_config
auth = riva.client.Auth(
use_ssl=True,
uri=_RIVA_SERVER,
metadata_args=[
["function-id", function_id],
["authorization", f"Bearer {key}"],
],
)
asr_service = riva.client.ASRService(auth)
config = riva.client.RecognitionConfig(
language_code=language_code,
max_alternatives=1,
verbatim_transcripts=True,
)
with open(file_path, "rb") as f:
data = f.read()
response = asr_service.offline_recognize(data, config)
transcript = ""
results = getattr(response, "results", None)
if results and results[0].alternatives:
transcript = results[0].alternatives[0].transcript
logger.debug(f"NIM transcription: {len(transcript)} chars")
return transcript or "(no speech detected)"
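A usage sketch (module path assumed). The empty-key guard runs before the riva import, so the error path below works even without the voice extra installed:
from pathlib import Path
from providers.nvidia_nim.asr import transcribe_audio_file  # module path assumed

try:
    transcribe_audio_file(Path("note.ogg"), "openai/whisper-large-v3", api_key="")
except ValueError as e:
    print(e)  # non-empty nvidia_nim_api_key required

# A real call needs NVIDIA_NIM_API_KEY and `uv sync --extra voice`:
# text = transcribe_audio_file(Path("note.ogg"), "openai/whisper-large-v3",
#                              api_key="nvapi-...")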

View file

@ -1,5 +1,7 @@
"""Ollama provider package."""
from .client import OLLAMA_BASE_URL, OllamaProvider
from providers.defaults import OLLAMA_DEFAULT_BASE
__all__ = ["OLLAMA_BASE_URL", "OllamaProvider"]
from .client import OllamaProvider
__all__ = ["OLLAMA_DEFAULT_BASE", "OllamaProvider"]

View file

@ -6,8 +6,6 @@ from providers.anthropic_messages import AnthropicMessagesTransport
from providers.base import ProviderConfig
from providers.defaults import OLLAMA_DEFAULT_BASE
OLLAMA_BASE_URL = OLLAMA_DEFAULT_BASE
class OllamaProvider(AnthropicMessagesTransport):
"""Ollama provider using native Anthropic Messages API."""
@ -16,7 +14,7 @@ class OllamaProvider(AnthropicMessagesTransport):
super().__init__(
config,
provider_name="OLLAMA",
default_base_url=OLLAMA_BASE_URL,
default_base_url=OLLAMA_DEFAULT_BASE,
)
self._api_key = config.api_key or "ollama"

View file

@ -1,5 +1,7 @@
"""OpenRouter provider - Anthropic-compatible native transport."""
from .client import OPENROUTER_BASE_URL, OpenRouterProvider
from providers.defaults import OPENROUTER_DEFAULT_BASE
__all__ = ["OPENROUTER_BASE_URL", "OpenRouterProvider"]
from .client import OpenRouterProvider
__all__ = ["OPENROUTER_DEFAULT_BASE", "OpenRouterProvider"]

View file

@ -2,34 +2,25 @@
from __future__ import annotations
import json
import uuid
from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import Any
from core.anthropic import SSEBuilder, append_request_id
from core.anthropic import append_request_id, iter_provider_stream_error_sse_events
from core.anthropic.native_sse_block_policy import (
NativeSseBlockPolicyState,
is_terminal_openrouter_done_event,
parse_native_sse_event,
transform_native_sse_block_event,
)
from providers.anthropic_messages import AnthropicMessagesTransport, StreamChunkMode
from providers.base import ProviderConfig
from providers.defaults import OPENROUTER_BASE_URL
from providers.defaults import OPENROUTER_DEFAULT_BASE
from .request import build_request_body
_ANTHROPIC_VERSION = "2023-06-01"
@dataclass
class _SSEFilterState:
"""Track Anthropic content block index remapping while filtering thinking."""
next_index: int = 0
index_map: dict[int, int] = field(default_factory=dict)
dropped_indexes: set[int] = field(default_factory=set)
open_block_types: dict[int, str] = field(default_factory=dict)
closed_indexes: set[int] = field(default_factory=set)
message_stopped: bool = False
class OpenRouterProvider(AnthropicMessagesTransport):
"""OpenRouter provider using the native Anthropic-compatible messages API."""
@ -39,7 +30,7 @@ class OpenRouterProvider(AnthropicMessagesTransport):
super().__init__(
config,
provider_name="OPENROUTER",
default_base_url=OPENROUTER_BASE_URL,
default_base_url=OPENROUTER_DEFAULT_BASE,
)
def _build_request_body(
@ -60,163 +51,9 @@ class OpenRouterProvider(AnthropicMessagesTransport):
"anthropic-version": _ANTHROPIC_VERSION,
}
@staticmethod
def _format_sse_event(event_name: str | None, data_text: str) -> str:
"""Format an SSE event from its event name and data payload."""
lines: list[str] = []
if event_name:
lines.append(f"event: {event_name}")
lines.extend(f"data: {line}" for line in data_text.splitlines())
return "\n".join(lines) + "\n\n"
@staticmethod
def _parse_sse_event(event: str) -> tuple[str | None, str]:
"""Extract the event name and raw data payload from an SSE event."""
event_name = None
data_lines: list[str] = []
for line in event.strip().splitlines():
if line.startswith("event:"):
event_name = line[6:].strip()
elif line.startswith("data:"):
data_lines.append(line[5:].lstrip())
return event_name, "\n".join(data_lines)
@staticmethod
def _is_terminal_done_event(event_name: str | None, data_text: str) -> bool:
"""Return whether an event is OpenAI-style terminal noise."""
return (event_name is None or event_name in {"data", "done"}) and (
data_text.strip().upper() == "[DONE]"
)
@staticmethod
def _remap_index(
payload: dict[str, Any], state: _SSEFilterState, *, create: bool
) -> int | None:
"""Return the downstream index for a content block event."""
upstream_index = payload.get("index")
if not isinstance(upstream_index, int):
return None
if upstream_index in state.dropped_indexes:
return None
mapped_index = state.index_map.get(upstream_index)
if mapped_index is None and create:
mapped_index = state.next_index
state.index_map[upstream_index] = mapped_index
state.next_index += 1
return mapped_index
def _close_open_blocks_before(
self, state: _SSEFilterState, upstream_index: int
) -> str:
"""Close overlapping upstream blocks before starting a new block."""
events: list[str] = []
for open_upstream_index in list(state.open_block_types):
if open_upstream_index == upstream_index:
continue
mapped_index = state.index_map.get(open_upstream_index)
if mapped_index is None:
continue
payload = {"type": "content_block_stop", "index": mapped_index}
events.append(
self._format_sse_event("content_block_stop", json.dumps(payload))
)
state.closed_indexes.add(open_upstream_index)
state.open_block_types.pop(open_upstream_index, None)
return "".join(events)
@staticmethod
def _should_drop_block_type(block_type: Any, *, thinking_enabled: bool) -> bool:
if not isinstance(block_type, str):
return False
if block_type.startswith("redacted_thinking"):
return True
return not thinking_enabled and "thinking" in block_type
def _transform_sse_payload(
self,
event: str,
state: _SSEFilterState,
*,
thinking_enabled: bool,
) -> str | None:
"""Normalize OpenRouter SSE events and enforce local thinking policy."""
event_name, data_text = self._parse_sse_event(event)
if not event_name or not data_text:
return event
try:
payload = json.loads(data_text)
except json.JSONDecodeError:
return event
if event_name == "content_block_start":
block = payload.get("content_block")
if not isinstance(block, dict):
return event
block_type = block.get("type")
upstream_index = payload.get("index")
if self._should_drop_block_type(
block_type, thinking_enabled=thinking_enabled
):
if isinstance(upstream_index, int):
state.dropped_indexes.add(upstream_index)
return None
mapped_index = self._remap_index(payload, state, create=True)
if mapped_index is not None:
payload["index"] = mapped_index
if isinstance(upstream_index, int) and isinstance(block_type, str):
prefix = self._close_open_blocks_before(state, upstream_index)
state.open_block_types[upstream_index] = block_type
return prefix + self._format_sse_event(
event_name, json.dumps(payload)
)
return self._format_sse_event(event_name, json.dumps(payload))
return None if not thinking_enabled else event
if event_name == "content_block_delta":
delta = payload.get("delta")
if not isinstance(delta, dict):
return event
delta_type = delta.get("type")
if self._should_drop_block_type(
delta_type, thinking_enabled=thinking_enabled
):
return None
mapped_index = self._remap_index(payload, state, create=False)
if mapped_index is not None:
payload["index"] = mapped_index
return self._format_sse_event(event_name, json.dumps(payload))
if payload.get("index") in state.dropped_indexes:
return None
if not thinking_enabled:
return None
if event_name == "content_block_stop":
upstream_index = payload.get("index")
if (
isinstance(upstream_index, int)
and upstream_index in state.closed_indexes
):
state.closed_indexes.discard(upstream_index)
return None
mapped_index = self._remap_index(payload, state, create=False)
if mapped_index is not None:
payload["index"] = mapped_index
if isinstance(upstream_index, int):
state.open_block_types.pop(upstream_index, None)
return self._format_sse_event(event_name, json.dumps(payload))
if payload.get("index") in state.dropped_indexes:
return None
if not thinking_enabled:
return None
return event
def _new_stream_state(self, request: Any, *, thinking_enabled: bool) -> Any:
"""Create per-stream state for thinking block filtering."""
return _SSEFilterState()
return NativeSseBlockPolicyState()
def _transform_stream_event(
self,
@ -226,23 +63,17 @@ class OpenRouterProvider(AnthropicMessagesTransport):
thinking_enabled: bool,
) -> str | None:
"""Drop provider-specific terminal noise and hidden thinking events."""
if isinstance(state, _SSEFilterState):
event_name, data_text = self._parse_sse_event(event)
if state.message_stopped or self._is_terminal_done_event(
if isinstance(state, NativeSseBlockPolicyState):
event_name, data_text = parse_native_sse_event(event)
if state.message_stopped or is_terminal_openrouter_done_event(
event_name, data_text
):
return None
if event_name == "message_stop":
state.message_stopped = True
if thinking_enabled:
if isinstance(state, _SSEFilterState):
return self._transform_sse_payload(
event, state, thinking_enabled=thinking_enabled
)
return event
if isinstance(state, _SSEFilterState):
return self._transform_sse_payload(
if isinstance(state, NativeSseBlockPolicyState):
return transform_native_sse_block_event(
event, state, thinking_enabled=thinking_enabled
)
return event
@ -260,9 +91,10 @@ class OpenRouterProvider(AnthropicMessagesTransport):
sent_any_event: bool,
) -> Iterator[str]:
"""Emit the Anthropic SSE error shape expected by Claude clients."""
sse = SSEBuilder(f"msg_{uuid.uuid4()}", request.model, input_tokens)
if not sent_any_event:
yield sse.message_start()
yield from sse.emit_error(error_message)
yield sse.message_delta("end_turn", 1)
yield sse.message_stop()
yield from iter_provider_stream_error_sse_events(
request=request,
input_tokens=input_tokens,
error_message=error_message,
sent_any_event=sent_any_event,
log_raw_sse_events=self._config.log_raw_sse_events,
)
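The provider now delegates SSE reshaping to the shared block policy; for reference, a standalone sketch of the parse/format round-trip such events go through (mirroring the removed static helpers, not the core.anthropic API):
import json

def parse_sse_event(event: str):
    """Split one SSE event into (event_name, data_text)."""
    name, data_lines = None, []
    for line in event.strip().splitlines():
        if line.startswith("event:"):
            name = line[6:].strip()
        elif line.startswith("data:"):
            data_lines.append(line[5:].lstrip())
    return name, "\n".join(data_lines)

def format_sse_event(name, data_text: str) -> str:
    lines = [f"event: {name}"] if name else []
    lines.extend(f"data: {line}" for line in data_text.splitlines())
    return "\n".join(lines) + "\n\n"

raw = 'event: content_block_stop\ndata: {"type": "content_block_stop", "index": 2}\n\n'
name, data = parse_sse_event(raw)
payload = json.loads(data)
payload["index"] = 0                      # index remapping happens here
print(format_sse_event(name, json.dumps(payload)), end="")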

View file

@ -2,136 +2,18 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Any
from loguru import logger
from pydantic import BaseModel
OPENROUTER_DEFAULT_MAX_TOKENS = 81920
_REQUEST_FIELDS = (
"model",
"messages",
"system",
"max_tokens",
"stop_sequences",
"stream",
"temperature",
"top_p",
"top_k",
"metadata",
"tools",
"tool_choice",
"thinking",
"extra_body",
"original_model",
"resolved_provider_model",
from config.constants import (
ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS as OPENROUTER_DEFAULT_MAX_TOKENS,
)
_INTERNAL_FIELDS = {
"thinking",
"extra_body",
"original_model",
"resolved_provider_model",
}
_THINKING_HISTORY_BLOCK_TYPES = {"thinking", "redacted_thinking"}
def _serialize_value(value: Any) -> Any:
"""Convert Pydantic models and lightweight objects into JSON-ready values."""
if isinstance(value, BaseModel):
return value.model_dump(exclude_none=True)
if isinstance(value, dict):
return {
key: _serialize_value(item)
for key, item in value.items()
if item is not None
}
if isinstance(value, Sequence) and not isinstance(value, str | bytes | bytearray):
return [_serialize_value(item) for item in value]
if value is None or isinstance(value, str | int | float | bool):
return value
if hasattr(value, "__dict__"):
return {
key: _serialize_value(item)
for key, item in vars(value).items()
if not key.startswith("_") and item is not None
}
return value
def _dump_request_fields(request_data: Any) -> dict[str, Any]:
"""Extract the public request fields forwarded to OpenRouter."""
if isinstance(request_data, BaseModel):
return request_data.model_dump(exclude_none=True)
dumped: dict[str, Any] = {}
for field in _REQUEST_FIELDS:
value = getattr(request_data, field, None)
if value is not None:
dumped[field] = _serialize_value(value)
return dumped
def _strip_unsigned_thinking_history(messages: Any) -> Any:
"""Remove assistant thinking history blocks that OpenRouter cannot replay."""
if not isinstance(messages, list):
return messages
sanitized_messages: list[Any] = []
for message in messages:
if not isinstance(message, dict):
sanitized_messages.append(message)
continue
content = message.get("content")
if not isinstance(content, list):
sanitized_messages.append(message)
continue
sanitized_content = [
block
for block in content
if not (
isinstance(block, dict)
and block.get("type") in _THINKING_HISTORY_BLOCK_TYPES
and not isinstance(block.get("signature"), str)
from core.anthropic.native_messages_request import (
OpenRouterExtraBodyError,
build_openrouter_native_request_body,
)
]
sanitized_message = dict(message)
sanitized_message["content"] = sanitized_content or ""
sanitized_messages.append(sanitized_message)
return sanitized_messages
def _normalize_system_prompt(system: Any) -> Any:
"""Flatten Claude SDK system blocks for OpenRouter's native endpoint."""
if not isinstance(system, list):
return system
text_parts: list[str] = []
for block in system:
if not isinstance(block, dict):
continue
if block.get("type") == "text" and isinstance(block.get("text"), str):
text_parts.append(block["text"])
return "\n\n".join(text_parts).strip() if text_parts else system
def _apply_reasoning(body: dict[str, Any], thinking_cfg: Any) -> None:
"""Map Anthropic thinking controls onto OpenRouter reasoning controls."""
reasoning = body.setdefault("reasoning", {"enabled": True})
if not isinstance(reasoning, dict):
return
reasoning.setdefault("enabled", True)
if not isinstance(thinking_cfg, dict):
return
budget_tokens = thinking_cfg.get("budget_tokens")
if isinstance(budget_tokens, int):
reasoning.setdefault("max_tokens", budget_tokens)
from providers.exceptions import InvalidRequestError
def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict:
@ -142,27 +24,14 @@ def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict:
len(getattr(request_data, "messages", [])),
)
dumped_request = _dump_request_fields(request_data)
request_extra = dumped_request.pop("extra_body", None)
thinking_cfg = dumped_request.get("thinking")
body = {
key: value
for key, value in dumped_request.items()
if key not in _INTERNAL_FIELDS
}
if isinstance(request_extra, dict):
body.update(request_extra)
body["messages"] = _strip_unsigned_thinking_history(body.get("messages"))
if "system" in body:
body["system"] = _normalize_system_prompt(body["system"])
body["stream"] = True
if body.get("max_tokens") is None:
body["max_tokens"] = OPENROUTER_DEFAULT_MAX_TOKENS
if thinking_enabled:
_apply_reasoning(body, thinking_cfg)
try:
body = build_openrouter_native_request_body(
request_data,
thinking_enabled=thinking_enabled,
default_max_tokens=OPENROUTER_DEFAULT_MAX_TOKENS,
)
except OpenRouterExtraBodyError as exc:
raise InvalidRequestError(str(exc)) from exc
logger.debug(
"OPENROUTER_REQUEST: conversion done model={} msgs={} tools={}",

View file

@ -21,14 +21,40 @@ from core.anthropic import (
SSEBuilder,
ThinkTagParser,
append_request_id,
get_user_facing_error_message,
map_stop_reason,
)
from providers.base import BaseProvider, ProviderConfig
from providers.error_mapping import map_error
from providers.error_mapping import (
map_error,
user_visible_message_for_mapped_provider_error,
)
from providers.rate_limit import GlobalRateLimiter
def _iter_heuristic_tool_use_sse(
sse: SSEBuilder, tool_use: dict[str, Any]
) -> Iterator[str]:
"""Emit SSE for one heuristic tool_use block (closes open text/thinking first)."""
if tool_use.get("name") == "Task" and isinstance(tool_use.get("input"), dict):
task_input = tool_use["input"]
if task_input.get("run_in_background") is not False:
task_input["run_in_background"] = False
yield from sse.close_content_blocks()
block_idx = sse.blocks.allocate_index()
yield sse.content_block_start(
block_idx,
"tool_use",
id=tool_use["id"],
name=tool_use["name"],
)
yield sse.content_block_delta(
block_idx,
"input_json_delta",
json.dumps(tool_use["input"]),
)
yield sse.content_block_stop(block_idx)
class OpenAIChatTransport(BaseProvider):
"""Base for OpenAI-compatible ``/chat/completions`` adapters (NIM, DeepSeek, …)."""
@ -113,6 +139,22 @@ class OpenAIChatTransport(BaseProvider):
)
return stream, retry_body
def _emit_tool_arg_delta(
self, sse: SSEBuilder, tc_index: int, args: str
) -> Iterator[str]:
"""Emit one argument fragment for a started tool block (Task buffer or raw JSON)."""
if not args:
return
state = sse.blocks.tool_states.get(tc_index)
if state is None:
return
if state.name == "Task":
parsed = sse.blocks.buffer_task_args(tc_index, args)
if parsed is not None:
yield sse.emit_tool_delta(tc_index, json.dumps(parsed))
return
yield sse.emit_tool_delta(tc_index, args)
def _process_tool_call(self, tc: dict, sse: SSEBuilder) -> Iterator[str]:
"""Process a single tool call delta and yield SSE events."""
tc_index = tc.get("index", 0)
@ -121,34 +163,42 @@ class OpenAIChatTransport(BaseProvider):
fn_delta = tc.get("function", {})
incoming_name = fn_delta.get("name")
arguments = fn_delta.get("arguments", "")
arguments = fn_delta.get("arguments", "") or ""
if tc.get("id") is not None:
sse.blocks.set_stream_tool_id(tc_index, tc.get("id"))
if incoming_name is not None:
sse.blocks.register_tool_name(tc_index, incoming_name)
state = sse.blocks.tool_states.get(tc_index)
if state is None or not state.started:
name = state.name if state else ""
if name or tc.get("id"):
tool_id = tc.get("id") or f"tool_{uuid.uuid4()}"
yield sse.start_tool_block(tc_index, tool_id, name)
resolved_id = (state.tool_id if state and state.tool_id else None) or tc.get(
"id"
)
resolved_name = (state.name if state else "") or ""
args = arguments
if args:
state = sse.blocks.tool_states.get(tc_index)
if state is None or not state.started:
tool_id = tc.get("id") or f"tool_{uuid.uuid4()}"
name = (state.name if state else None) or "tool_call"
yield sse.start_tool_block(tc_index, tool_id, name)
state = sse.blocks.tool_states.get(tc_index)
if not state or not state.started:
name_ok = bool((resolved_name or "").strip())
if name_ok:
tool_id = str(resolved_id) if resolved_id else f"tool_{uuid.uuid4()}"
display_name = (resolved_name or "").strip() or "tool_call"
yield sse.start_tool_block(tc_index, tool_id, display_name)
state = sse.blocks.tool_states[tc_index]
if state.pre_start_args:
pre = state.pre_start_args
state.pre_start_args = ""
yield from self._emit_tool_arg_delta(sse, tc_index, pre)
current_name = state.name if state else ""
if current_name == "Task":
parsed = sse.blocks.buffer_task_args(tc_index, args)
if parsed is not None:
yield sse.emit_tool_delta(tc_index, json.dumps(parsed))
state = sse.blocks.tool_states.get(tc_index)
if not arguments:
return
if state is None or not state.started:
state = sse.blocks.ensure_tool_state(tc_index)
if not (resolved_name or "").strip():
state.pre_start_args += arguments
return
yield sse.emit_tool_delta(tc_index, args)
yield from self._emit_tool_arg_delta(sse, tc_index, arguments)
def _flush_task_arg_buffers(self, sse: SSEBuilder) -> Iterator[str]:
"""Emit buffered Task args as a single JSON delta (best-effort)."""
@ -181,7 +231,12 @@ class OpenAIChatTransport(BaseProvider):
"""Shared streaming implementation."""
tag = self._provider_name
message_id = f"msg_{uuid.uuid4()}"
sse = SSEBuilder(message_id, request.model, input_tokens)
sse = SSEBuilder(
message_id,
request.model,
input_tokens,
log_raw_events=self._config.log_raw_sse_events,
)
body = self._build_request_body(request, thinking_enabled=thinking_enabled)
thinking_enabled = self._is_thinking_enabled(request, thinking_enabled)
@ -201,8 +256,6 @@ class OpenAIChatTransport(BaseProvider):
heuristic_parser = HeuristicToolParser()
finish_reason = None
usage_info = None
error_occurred = False
error_message = ""
async with self._global_rate_limiter.concurrency_slot():
try:
@ -258,26 +311,10 @@ class OpenAIChatTransport(BaseProvider):
yield sse.emit_text_delta(filtered_text)
for tool_use in detected_tools:
for event in sse.close_content_blocks():
yield event
block_idx = sse.blocks.allocate_index()
if tool_use.get("name") == "Task" and isinstance(
tool_use.get("input"), dict
for event in _iter_heuristic_tool_use_sse(
sse, tool_use
):
tool_use["input"]["run_in_background"] = False
yield sse.content_block_start(
block_idx,
"tool_use",
id=tool_use["id"],
name=tool_use["name"],
)
yield sse.content_block_delta(
block_idx,
"input_json_delta",
json.dumps(tool_use["input"]),
)
yield sse.content_block_stop(block_idx)
yield event
# Handle native tool calls
if delta.tool_calls:
@ -298,17 +335,12 @@ class OpenAIChatTransport(BaseProvider):
except (asyncio.CancelledError, GeneratorExit):
raise
except Exception as e:
logger.error("{}_ERROR:{} {}: {}", tag, req_tag, type(e).__name__, e)
self._log_stream_transport_error(tag, req_tag, e)
mapped_e = map_error(e, rate_limiter=self._global_rate_limiter)
error_occurred = True
if getattr(mapped_e, "status_code", None) == 405:
base_message = (
f"Upstream provider {tag} rejected the request method "
"or endpoint (HTTP 405)."
)
else:
base_message = get_user_facing_error_message(
mapped_e, read_timeout_s=self._config.http_read_timeout
base_message = user_visible_message_for_mapped_provider_error(
mapped_e,
provider_name=tag,
read_timeout_s=self._config.http_read_timeout,
)
error_message = append_request_id(base_message, request_id)
logger.info(
@ -317,10 +349,13 @@ class OpenAIChatTransport(BaseProvider):
type(e).__name__,
req_tag,
)
for event in sse.close_content_blocks():
for event in sse.close_all_blocks():
yield event
for event in sse.emit_error(error_message):
yield event
yield sse.message_delta("end_turn", 1)
yield sse.message_stop()
return
# Flush remaining content
remaining = think_parser.flush()
@ -338,32 +373,16 @@ class OpenAIChatTransport(BaseProvider):
yield sse.emit_text_delta(remaining.content)
for tool_use in heuristic_parser.flush():
for event in sse.close_content_blocks():
for event in _iter_heuristic_tool_use_sse(sse, tool_use):
yield event
block_idx = sse.blocks.allocate_index()
yield sse.content_block_start(
block_idx,
"tool_use",
id=tool_use["id"],
name=tool_use["name"],
has_started_tool = any(s.started for s in sse.blocks.tool_states.values())
has_content_blocks = (
sse.blocks.text_index != -1
or sse.blocks.thinking_index != -1
or has_started_tool
)
if tool_use.get("name") == "Task" and isinstance(
tool_use.get("input"), dict
):
tool_use["input"]["run_in_background"] = False
yield sse.content_block_delta(
block_idx,
"input_json_delta",
json.dumps(tool_use["input"]),
)
yield sse.content_block_stop(block_idx)
if (
not error_occurred
and sse.blocks.text_index == -1
and not sse.blocks.tool_states
):
if not has_content_blocks:
for event in sse.ensure_text_block():
yield event
yield sse.emit_text_delta(" ")


@ -3,14 +3,16 @@
import asyncio
import random
import time
from collections import deque
from collections.abc import AsyncIterator, Callable
from contextlib import asynccontextmanager
from typing import Any, ClassVar, TypeVar
import httpx
import openai
from loguru import logger
from core.rate_limit import StrictSlidingWindowLimiter
T = TypeVar("T")
@ -51,10 +53,10 @@ class GlobalRateLimiter:
self._rate_limit = rate_limit
self._rate_window = float(rate_window)
self._max_concurrency = max_concurrency
# Monotonic timestamps of the last granted slots.
self._request_times: deque[float] = deque()
self._proactive_limiter = StrictSlidingWindowLimiter(
self._rate_limit, self._rate_window
)
self._blocked_until: float = 0
self._lock = asyncio.Lock()
self._concurrency_sem = asyncio.Semaphore(max_concurrency)
self._initialized = True
@ -107,7 +109,6 @@ class GlobalRateLimiter:
logger.info(
"Rebuilding provider rate limiter for updated scope '{}'", scope
)
if scope not in cls._scoped_instances or existing:
cls._scoped_instances[scope] = cls(
rate_limit=desired_rate_limit,
rate_window=desired_rate_window,
@ -150,27 +151,7 @@ class GlobalRateLimiter:
Guarantees: at most `self._rate_limit` acquisitions in any interval of length
`self._rate_window` (seconds).
"""
while True:
wait_time = 0.0
async with self._lock:
now = time.monotonic()
cutoff = now - self._rate_window
while self._request_times and self._request_times[0] <= cutoff:
self._request_times.popleft()
if len(self._request_times) < self._rate_limit:
self._request_times.append(now)
return
oldest = self._request_times[0]
wait_time = max(0.0, (oldest + self._rate_window) - now)
# Sleep outside the lock so other tasks can continue to queue.
if wait_time > 0:
await asyncio.sleep(wait_time)
else:
await asyncio.sleep(0)
await self._proactive_limiter.acquire()
def set_blocked(self, seconds: float = 60) -> None:
"""
@ -263,6 +244,24 @@ class GlobalRateLimiter:
)
self.set_blocked(delay)
await asyncio.sleep(delay)
except httpx.HTTPStatusError as e:
if e.response.status_code != 429:
raise
last_exc = e
if attempt >= max_retries:
logger.warning(
f"HTTP 429 retry exhausted after {max_retries} retries"
)
break
delay = min(base_delay * (2**attempt), max_delay)
delay += random.uniform(0, jitter)
logger.warning(
f"HTTP 429 from upstream, attempt {attempt + 1}/{max_retries + 1}. "
f"Retrying in {delay:.1f}s..."
)
self.set_blocked(delay)
await asyncio.sleep(delay)
assert last_exc is not None
raise last_exc
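The 429 path above retries with capped exponential backoff plus uniform jitter. A one-function sketch of the delay schedule (parameter defaults are illustrative, not the provider's actual values): with base_delay=1s and max_delay=30s, the pre-jitter delays run 1s, 2s, 4s, 8s, 16s, 30s, 30s, and so on.

import random

def backoff_delay(attempt: int, base_delay: float = 1.0,
                  max_delay: float = 30.0, jitter: float = 1.0) -> float:
    """Delay before retry number `attempt` (0-based), mirroring the 429 handler."""
    return min(base_delay * (2 ** attempt), max_delay) + random.uniform(0, jitter)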


@ -3,108 +3,20 @@
from __future__ import annotations
from collections.abc import Callable, MutableMapping
from dataclasses import dataclass
from typing import Literal
from config.provider_ids import SUPPORTED_PROVIDER_IDS
from config.provider_catalog import (
PROVIDER_CATALOG,
SUPPORTED_PROVIDER_IDS,
ProviderDescriptor,
)
from config.settings import Settings
from providers.base import BaseProvider, ProviderConfig
from providers.defaults import (
DEEPSEEK_DEFAULT_BASE,
LLAMACPP_DEFAULT_BASE,
LMSTUDIO_DEFAULT_BASE,
NVIDIA_NIM_DEFAULT_BASE,
OLLAMA_DEFAULT_BASE,
OPENROUTER_DEFAULT_BASE,
)
from providers.exceptions import AuthenticationError, UnknownProviderTypeError
TransportType = Literal["openai_chat", "anthropic_messages"]
ProviderFactory = Callable[[ProviderConfig, Settings], BaseProvider]
@dataclass(frozen=True, slots=True)
class ProviderDescriptor:
"""Metadata for building :class:`ProviderConfig` and factory wiring."""
provider_id: str
transport_type: TransportType
capabilities: tuple[str, ...]
credential_env: str | None = None
credential_url: str | None = None
# If set, read API key from this attribute on ``Settings`` (e.g. nvidia_nim_api_key).
credential_attr: str | None = None
# If set, use this fixed key for local adapters (e.g. lm-studio, llamacpp).
static_credential: str | None = None
default_base_url: str | None = None
base_url_attr: str | None = None
proxy_attr: str | None = None
PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = {
"nvidia_nim": ProviderDescriptor(
provider_id="nvidia_nim",
transport_type="openai_chat",
credential_env="NVIDIA_NIM_API_KEY",
credential_url="https://build.nvidia.com/settings/api-keys",
credential_attr="nvidia_nim_api_key",
default_base_url=NVIDIA_NIM_DEFAULT_BASE,
proxy_attr="nvidia_nim_proxy",
capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"),
),
"open_router": ProviderDescriptor(
provider_id="open_router",
transport_type="anthropic_messages",
credential_env="OPENROUTER_API_KEY",
credential_url="https://openrouter.ai/keys",
credential_attr="open_router_api_key",
default_base_url=OPENROUTER_DEFAULT_BASE,
proxy_attr="open_router_proxy",
capabilities=("chat", "streaming", "tools", "thinking", "native_anthropic"),
),
"deepseek": ProviderDescriptor(
provider_id="deepseek",
transport_type="openai_chat",
credential_env="DEEPSEEK_API_KEY",
credential_url="https://platform.deepseek.com/api_keys",
credential_attr="deepseek_api_key",
default_base_url=DEEPSEEK_DEFAULT_BASE,
capabilities=("chat", "streaming", "thinking"),
),
"lmstudio": ProviderDescriptor(
provider_id="lmstudio",
transport_type="anthropic_messages",
static_credential="lm-studio",
default_base_url=LMSTUDIO_DEFAULT_BASE,
base_url_attr="lm_studio_base_url",
proxy_attr="lmstudio_proxy",
capabilities=("chat", "streaming", "tools", "native_anthropic", "local"),
),
"llamacpp": ProviderDescriptor(
provider_id="llamacpp",
transport_type="anthropic_messages",
static_credential="llamacpp",
default_base_url=LLAMACPP_DEFAULT_BASE,
base_url_attr="llamacpp_base_url",
proxy_attr="llamacpp_proxy",
capabilities=("chat", "streaming", "tools", "native_anthropic", "local"),
),
"ollama": ProviderDescriptor(
provider_id="ollama",
transport_type="anthropic_messages",
static_credential="ollama",
default_base_url=OLLAMA_DEFAULT_BASE,
base_url_attr="ollama_base_url",
capabilities=(
"chat",
"streaming",
"tools",
"thinking",
"native_anthropic",
"local",
),
),
}
# Backwards-compatible name for the catalog (single source: ``config.provider_catalog``).
PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = PROVIDER_CATALOG
def _create_nvidia_nim(config: ProviderConfig, settings: Settings) -> BaseProvider:
@ -119,25 +31,25 @@ def _create_open_router(config: ProviderConfig, _settings: Settings) -> BaseProv
return OpenRouterProvider(config)
def _create_deepseek(config: ProviderConfig, settings: Settings) -> BaseProvider:
def _create_deepseek(config: ProviderConfig, _settings: Settings) -> BaseProvider:
from providers.deepseek import DeepSeekProvider
return DeepSeekProvider(config)
def _create_lmstudio(config: ProviderConfig, settings: Settings) -> BaseProvider:
def _create_lmstudio(config: ProviderConfig, _settings: Settings) -> BaseProvider:
from providers.lmstudio import LMStudioProvider
return LMStudioProvider(config)
def _create_llamacpp(config: ProviderConfig, settings: Settings) -> BaseProvider:
def _create_llamacpp(config: ProviderConfig, _settings: Settings) -> BaseProvider:
from providers.llamacpp import LlamaCppProvider
return LlamaCppProvider(config)
def _create_ollama(config: ProviderConfig, settings: Settings) -> BaseProvider:
def _create_ollama(config: ProviderConfig, _settings: Settings) -> BaseProvider:
from providers.ollama import OllamaProvider
return OllamaProvider(config)
@ -208,6 +120,8 @@ def build_provider_config(
http_connect_timeout=settings.http_connect_timeout,
enable_thinking=settings.enable_model_thinking,
proxy=proxy,
log_raw_sse_events=settings.log_raw_sse_events,
log_api_error_tracebacks=settings.log_api_error_tracebacks,
)
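The descriptors stay declarative; resolving them against Settings is mechanical. A sketch of the resolution rules implied by the catalog fields above (helper names are illustrative, not functions from this diff):

def resolve_credential(desc, settings):
    """Pick the API key source a descriptor declares, in priority order."""
    if desc.static_credential is not None:  # local adapters (lm-studio, llamacpp)
        return desc.static_credential
    if desc.credential_attr is not None:  # e.g. settings.nvidia_nim_api_key
        return getattr(settings, desc.credential_attr, None)
    return None

def resolve_base_url(desc, settings):
    """Prefer a Settings override, else the descriptor's default base URL."""
    if desc.base_url_attr:
        override = getattr(settings, desc.base_url_attr, None)
        if override:
            return override
    return desc.default_base_url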


@ -1,11 +1,13 @@
"""
Claude Code Proxy - Entry Point
Minimal entry point that imports the app from the api module.
Minimal entry point that builds the ASGI app via :func:`api.app.create_app`.
Run with: uv run uvicorn server:app --host 0.0.0.0 --port 8082 --timeout-graceful-shutdown 5
"""
from api.app import app, create_app
from api.app import create_app
app = create_app()
__all__ = ["app", "create_app"]


@ -98,7 +98,7 @@ CAPABILITY_CONTRACTS: tuple[CapabilityContract, ...] = (
"providers.registry.ProviderRegistry",
"provider id and Settings",
"configured BaseProvider instance",
"503 for missing credentials; ValueError for unknown provider",
"503 for missing credentials; invalid_request_error for unknown provider",
("tests/api/test_dependencies.py", "tests/providers/test_registry.py"),
("test_configured_provider_models_stream_successfully",),
),


@ -19,6 +19,14 @@ import httpx
import pytest
from config.provider_ids import SUPPORTED_PROVIDER_IDS
from core.anthropic.stream_contracts import (
SSEEvent,
assert_anthropic_stream_contract,
event_index,
has_tool_use,
parse_sse_lines,
text_content,
)
from messaging.handler import ClaudeMessageHandler
from messaging.models import IncomingMessage
from messaging.platforms.base import MessagingPlatform
@ -26,13 +34,6 @@ from messaging.session import SessionStore
from smoke.lib.config import ProviderModel, SmokeConfig, auth_headers
from smoke.lib.server import RunningServer, start_server
from smoke.lib.skips import fail_missing_env
from smoke.lib.sse import (
SSEEvent,
assert_anthropic_stream_contract,
has_tool_use,
parse_sse_lines,
text_content,
)
@dataclass(slots=True)
@ -590,14 +591,14 @@ def assistant_content_from_events(events: list[SSEEvent]) -> list[dict[str, Any]
block_order: list[int] = []
for event in events:
if event.event == "content_block_start":
index = _event_index(event)
index = event_index(event)
block = event.data.get("content_block", {})
if isinstance(block, dict):
blocks[index] = dict(block)
block_order.append(index)
continue
if event.event == "content_block_delta":
index = _event_index(event)
index = event_index(event)
block = blocks.get(index)
delta = event.data.get("delta", {})
if not isinstance(block, dict) or not isinstance(delta, dict):
@ -663,12 +664,6 @@ def default_cli_events(session_id: str) -> list[dict[str, Any]]:
]
def _event_index(event: SSEEvent) -> int:
value = event.data.get("index")
assert isinstance(value, int), event.data
return value
def assert_product_stream(events: list[SSEEvent]) -> None:
assert_anthropic_stream_contract(events)
assert text_content(events).strip() or has_tool_use(events), (


@ -6,9 +6,10 @@ from typing import Any
import httpx
from core.anthropic.stream_contracts import SSEEvent, parse_sse_lines
from .config import SmokeConfig, auth_headers, redacted
from .server import RunningServer
from .sse import SSEEvent, parse_sse_lines
def message_payload(


@ -5,7 +5,7 @@ from __future__ import annotations
import httpx
import pytest
from smoke.lib.sse import SSEEvent
from core.anthropic.stream_contracts import SSEEvent
UPSTREAM_UNAVAILABLE_MARKERS = (
"connection refused",


@ -1,29 +0,0 @@
"""SSE parsing and Anthropic stream assertions for smoke tests.
Canonical implementation lives in :mod:`core.anthropic.stream_contracts`; this
module re-exports it for smoke and historical import paths.
"""
from __future__ import annotations
from core.anthropic.stream_contracts import (
SSEEvent,
assert_anthropic_stream_contract,
event_names,
has_tool_use,
parse_sse_lines,
parse_sse_text,
text_content,
thinking_content,
)
__all__ = [
"SSEEvent",
"assert_anthropic_stream_contract",
"event_names",
"has_tool_use",
"parse_sse_lines",
"parse_sse_text",
"text_content",
"thinking_content",
]


@ -5,6 +5,10 @@ import time
import httpx
import pytest
from core.anthropic.stream_contracts import (
assert_anthropic_stream_contract,
text_content,
)
from smoke.lib.config import SmokeConfig, auth_headers
from smoke.lib.http import collect_message_stream, message_payload
from smoke.lib.server import start_server
@ -12,7 +16,6 @@ from smoke.lib.skips import (
skip_if_upstream_unavailable_events,
skip_if_upstream_unavailable_exception,
)
from smoke.lib.sse import assert_anthropic_stream_contract, text_content
pytestmark = [pytest.mark.live, pytest.mark.smoke_target("providers")]


@ -2,11 +2,14 @@ from __future__ import annotations
import pytest
from core.anthropic.stream_contracts import (
assert_anthropic_stream_contract,
has_tool_use,
)
from smoke.lib.config import SmokeConfig
from smoke.lib.http import collect_message_stream, message_payload
from smoke.lib.server import start_server
from smoke.lib.skips import skip_if_upstream_unavailable_events
from smoke.lib.sse import assert_anthropic_stream_contract, has_tool_use
pytestmark = [pytest.mark.live, pytest.mark.smoke_target("tools")]


@ -24,12 +24,13 @@ def test_voice_transcription_backend_when_explicitly_enabled(
wav_path = tmp_path / "smoke-tone.wav"
_write_tone_wav(wav_path)
try:
text = transcribe_audio(
wav_path,
"audio/wav",
whisper_model=smoke_config.settings.whisper_model,
whisper_device=smoke_config.settings.whisper_device,
)
t_kw: dict[str, str] = {
"whisper_model": smoke_config.settings.whisper_model,
"whisper_device": smoke_config.settings.whisper_device,
}
if smoke_config.settings.whisper_device == "nvidia_nim":
t_kw["nvidia_nim_api_key"] = smoke_config.settings.nvidia_nim_api_key
text = transcribe_audio(wav_path, "audio/wav", **t_kw)
except ImportError as exc:
pytest.skip(str(exc))
assert isinstance(text, str)


@ -3,6 +3,11 @@ from __future__ import annotations
import httpx
import pytest
from core.anthropic.stream_contracts import (
assert_anthropic_stream_contract,
parse_sse_lines,
text_content,
)
from smoke.lib.config import ProviderModel, SmokeConfig, auth_headers
from smoke.lib.e2e import (
ConversationDriver,
@ -16,11 +21,6 @@ from smoke.lib.skips import (
skip_if_upstream_unavailable_events,
skip_if_upstream_unavailable_exception,
)
from smoke.lib.sse import (
assert_anthropic_stream_contract,
parse_sse_lines,
text_content,
)
pytestmark = [pytest.mark.live, pytest.mark.smoke_target("providers")]


@ -55,6 +55,7 @@ def test_voice_nim_backend_e2e(smoke_config: SmokeConfig, tmp_path: Path) -> Non
"audio/wav",
whisper_model=smoke_config.settings.whisper_model,
whisper_device="nvidia_nim",
nvidia_nim_api_key=smoke_config.settings.nvidia_nim_api_key,
)
assert isinstance(text, str)

tests/__init__.py (new file)

@ -0,0 +1 @@
"""Test suite package."""


@ -0,0 +1,184 @@
"""Pydantic passthrough of Anthropic protocol fields (e.g. ``cache_control``)."""
from __future__ import annotations
from api.models.anthropic import (
ContentBlockServerToolUse,
ContentBlockText,
ContentBlockWebSearchToolResult,
Message,
MessagesRequest,
SystemContent,
Tool,
)
from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
from core.anthropic.native_messages_request import (
build_base_native_anthropic_request_body,
)
def test_cache_control_preserved_on_parsed_user_text_system_and_tool() -> None:
raw = {
"model": "m",
"max_tokens": 20,
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "hi",
"cache_control": {"type": "ephemeral"},
}
],
}
],
"system": [
{
"type": "text",
"text": "be brief",
"cache_control": {"type": "ephemeral"},
}
],
"tools": [
{
"name": "alpha",
"input_schema": {"type": "object"},
"cache_control": {"type": "ephemeral"},
}
],
}
req = MessagesRequest.model_validate(raw)
block = req.messages[0].content[0]
assert isinstance(block, ContentBlockText)
assert block.model_dump()["cache_control"] == {"type": "ephemeral"}
s0 = req.system[0] if isinstance(req.system, list) else None
assert isinstance(s0, SystemContent)
assert s0.model_dump()["cache_control"] == {"type": "ephemeral"}
t0 = req.tools[0] if req.tools else None
assert isinstance(t0, Tool)
assert t0.model_dump()["cache_control"] == {"type": "ephemeral"}
def test_build_base_native_body_includes_cache_control() -> None:
req = MessagesRequest(
model="m",
max_tokens=20,
messages=[
Message(
role="user",
content=[
ContentBlockText.model_validate(
{
"type": "text",
"text": "x",
"cache_control": {"type": "ephemeral"},
}
)
],
)
],
system=[
SystemContent.model_validate(
{
"type": "text",
"text": "s",
"cache_control": {"type": "ephemeral"},
}
)
],
tools=[
Tool.model_validate(
{
"name": "n",
"input_schema": {"type": "object"},
"cache_control": {"type": "ephemeral"},
}
)
],
)
body = build_base_native_anthropic_request_body(
req,
default_max_tokens=ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
thinking_enabled=False,
)
assert body["messages"][0]["content"][0]["cache_control"] == {"type": "ephemeral"}
assert body["system"][0]["cache_control"] == {"type": "ephemeral"}
assert body["tools"][0]["cache_control"] == {"type": "ephemeral"}
def test_pydantic_discriminator_still_distinguishes_blocks() -> None:
m = Message.model_validate(
{
"role": "user",
"content": [{"type": "text", "text": "a", "z": 1}],
}
)
b = m.content[0]
assert isinstance(b, ContentBlockText)
assert b.model_dump()["z"] == 1
def test_server_tool_assistant_blocks_round_trip_in_native_body() -> None:
"""Local server-tool responses must parse as valid history for a follow-up request."""
raw = {
"model": "m",
"max_tokens": 20,
"messages": [
{
"role": "assistant",
"content": [
{
"type": "server_tool_use",
"id": "srvtoolu_1",
"name": "web_search",
"input": {"query": "q"},
},
{
"type": "web_search_tool_result",
"tool_use_id": "srvtoolu_1",
"content": [
{
"type": "web_search_result",
"title": "T",
"url": "https://example.com",
}
],
},
],
}
],
"mcp_servers": [{"type": "url", "url": "https://example.com/mcp"}],
}
req = MessagesRequest.model_validate(raw)
assert len(req.messages) == 1
blocks = req.messages[0].content
assert isinstance(blocks, list)
assert isinstance(blocks[0], ContentBlockServerToolUse)
assert isinstance(blocks[1], ContentBlockWebSearchToolResult)
body = build_base_native_anthropic_request_body(
req,
default_max_tokens=ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
thinking_enabled=False,
)
assert body["mcp_servers"][0]["type"] == "url"
assert body["messages"][0]["content"][0]["type"] == "server_tool_use"
assert body["messages"][0]["content"][1]["type"] == "web_search_tool_result"
def test_native_body_preserves_context_and_output_config() -> None:
raw = {
"model": "m",
"max_tokens": 20,
"messages": [{"role": "user", "content": "x"}],
"context_management": {"edits": [{"type": "clear"}]},
"output_config": {"some": "hint"},
}
req = MessagesRequest.model_validate(raw)
body = build_base_native_anthropic_request_body(
req,
default_max_tokens=ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
thinking_enabled=False,
)
assert body["context_management"] == raw["context_management"]
assert body["output_config"] == raw["output_config"]


@ -3,9 +3,11 @@ from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from api.app import app
from api.app import create_app
from providers.nvidia_nim import NvidiaNimProvider
app = create_app()
# Mock provider
mock_provider = MagicMock(spec=NvidiaNimProvider)
@ -171,7 +173,7 @@ def test_generic_exception_returns_500(client: TestClient):
def test_generic_exception_with_status_code(client: TestClient):
"""Generic exception with status_code attribute uses that status (getattr fallback)."""
"""Unexpected errors always map to HTTP 500 (ignore ad-hoc status_code attrs)."""
class ExceptionWithStatus(RuntimeError):
def __init__(self, msg: str, status_code: int = 500):
@ -191,7 +193,7 @@ def test_generic_exception_with_status_code(client: TestClient):
"stream": True,
},
)
assert response.status_code == 502
assert response.status_code == 500
mock_provider.stream_response = _mock_stream_response


@ -17,6 +17,15 @@ _RUNTIME_EXTRAS = {
"nvidia_nim_api_key": "",
"claude_cli_bin": "claude",
"uses_process_anthropic_auth_token": lambda: False,
"messaging_rate_limit": 1,
"messaging_rate_window": 1.0,
"max_message_log_entries_per_chat": None,
"debug_platform_edits": False,
"debug_subagent_stack": False,
"log_api_error_tracebacks": False,
"log_raw_messaging_content": False,
"log_raw_cli_diagnostics": False,
"log_messaging_error_details": False,
}
@ -86,6 +95,47 @@ def test_create_app_provider_error_handler_returns_anthropic_format():
assert body["error"]["type"] == "authentication_error"
def test_create_app_provider_error_default_logs_exclude_provider_message():
"""Provider errors must not log exc.message by default."""
from api.app import create_app
from providers.exceptions import AuthenticationError
app = create_app()
secret = "provider-upstream-secret-detail"
@app.get("/raise_provider_secret")
async def _raise():
raise AuthenticationError(secret)
api_app_mod = importlib.import_module("api.app")
settings = _app_settings(
messaging_platform="telegram",
telegram_bot_token=None,
allowed_telegram_user_id=None,
discord_bot_token=None,
allowed_discord_channels=None,
allowed_dir="",
claude_workspace="./agent_workspace",
host="127.0.0.1",
port=8082,
log_file="server.log",
log_api_error_tracebacks=False,
)
with (
patch.object(api_app_mod, "get_settings", return_value=settings),
patch.object(ProviderRegistry, "cleanup", new=AsyncMock()),
patch.object(api_app_mod.logger, "error") as log_err,
):
with TestClient(app) as client:
resp = client.get("/raise_provider_secret")
assert resp.status_code == 401
blob = " ".join(str(a) for c in log_err.call_args_list for a in c.args)
blob += repr([c.kwargs for c in log_err.call_args_list])
assert secret not in blob
assert "authentication_error" in blob
def test_create_app_general_exception_handler_returns_500():
from api.app import create_app
@ -120,6 +170,50 @@ def test_create_app_general_exception_handler_returns_500():
assert body["error"]["type"] == "api_error"
def test_create_app_general_exception_default_logs_exclude_exception_message():
"""Unhandled errors must not log exception text by default (may echo user content)."""
from api.app import create_app
app = create_app()
secret = "user-provided-secret-token-xyzzy"
@app.get("/raise_secret")
async def _raise_secret():
raise ValueError(secret)
api_app_mod = importlib.import_module("api.app")
settings = _app_settings(
messaging_platform="telegram",
telegram_bot_token=None,
allowed_telegram_user_id=None,
discord_bot_token=None,
allowed_discord_channels=None,
allowed_dir="",
claude_workspace="./agent_workspace",
host="127.0.0.1",
port=8082,
log_file="server.log",
log_api_error_tracebacks=False,
)
with (
patch.object(api_app_mod, "get_settings", return_value=settings),
patch.object(ProviderRegistry, "cleanup", new=AsyncMock()),
patch.object(api_app_mod.logger, "error") as log_err,
):
with TestClient(app, raise_server_exceptions=False) as client:
resp = client.get("/raise_secret")
assert resp.status_code == 500
flattened: list[str] = []
for call in log_err.call_args_list:
flattened.extend(str(arg) for arg in call.args)
flattened.append(repr(call.kwargs))
blob = " ".join(flattened)
assert secret not in blob
assert "ValueError" in blob
@pytest.mark.parametrize(
"messaging_enabled", [True, False], ids=["with_platform", "no_platform"]
)


@ -2,10 +2,12 @@ from unittest.mock import patch
from fastapi.testclient import TestClient
from api.app import app
from api.app import create_app
from api.dependencies import get_settings
from config.settings import Settings
app = create_app()
def test_anthropic_auth_token_required_and_accepts_x_api_key():
client = TestClient(app)


@ -16,7 +16,7 @@ from api.dependencies import (
)
from config.nim import NimSettings
from providers.deepseek import DeepSeekProvider
from providers.exceptions import UnknownProviderTypeError
from providers.exceptions import ServiceUnavailableError, UnknownProviderTypeError
from providers.lmstudio import LMStudioProvider
from providers.nvidia_nim import NvidiaNimProvider
from providers.ollama import OllamaProvider
@ -40,7 +40,7 @@ def _make_mock_settings(**overrides):
mock.nim = NimSettings()
mock.http_read_timeout = 300.0
mock.http_write_timeout = 10.0
mock.http_connect_timeout = 2.0
mock.http_connect_timeout = 10.0
mock.enable_model_thinking = True
for key, value in overrides.items():
setattr(mock, key, value)
@ -434,18 +434,17 @@ def test_resolve_provider_per_app_uses_separate_registries() -> None:
assert p1 is not p2
def test_resolve_provider_lazily_installs_registry() -> None:
"""If app has no provider_registry, one is created on app.state."""
def test_resolve_provider_missing_registry_raises_service_unavailable() -> None:
"""HTTP apps must install app.state.provider_registry (e.g. via AppRuntime)."""
with patch("api.dependencies.get_settings") as mock_settings:
mock_settings.return_value = _make_mock_settings()
settings = _make_mock_settings()
app = SimpleNamespace(state=State())
assert getattr(app.state, "provider_registry", None) is None
with pytest.raises(
ServiceUnavailableError, match="Provider registry is not configured"
):
resolve_provider("nvidia_nim", app=cast(Starlette, app), settings=settings)
reg = app.state.provider_registry
assert reg is not None
p2 = resolve_provider("nvidia_nim", app=cast(Starlette, app), settings=settings)
assert p2 is reg.get("nvidia_nim", settings) # same registry instance
def test_resolve_provider_unrelated_value_error_is_not_unknown_provider_log() -> None:


@ -3,10 +3,12 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from api.app import app
from api.app import create_app
from api.dependencies import get_settings
from config.settings import Settings
app = create_app()
@pytest.fixture
def client():
@ -103,6 +105,17 @@ def test_create_message_empty_messages_returns_400(client):
assert "cannot be empty" in data.get("error", {}).get("message", "")
def test_count_tokens_empty_messages_returns_400(client):
"""POST /v1/messages/count_tokens with messages: [] matches messages validation."""
payload = {"model": "claude-3-sonnet", "messages": []}
response = client.post("/v1/messages/count_tokens", json=payload)
assert response.status_code == 400
data = response.json()
assert data.get("type") == "error"
assert data.get("error", {}).get("type") == "invalid_request_error"
assert "cannot be empty" in data.get("error", {}).get("message", "")
def test_count_tokens_endpoint(client):
payload = {
"model": "claude-3-sonnet",


@ -0,0 +1,70 @@
"""Tests for safe default logging in :mod:`api.runtime`."""
import importlib
import logging
from unittest.mock import MagicMock, patch
import pytest
from tests.api.test_app_lifespan_and_errors import _app_settings
@pytest.mark.asyncio
async def test_messaging_start_failure_default_logs_exclude_traceback(caplog):
api_runtime_mod = importlib.import_module("api.runtime")
settings = _app_settings(
messaging_platform="telegram",
telegram_bot_token="t",
allowed_telegram_user_id="1",
discord_bot_token=None,
allowed_discord_channels=None,
allowed_dir="",
claude_workspace="./agent_workspace",
host="127.0.0.1",
port=8082,
log_file="server.log",
log_api_error_tracebacks=False,
)
runtime = api_runtime_mod.AppRuntime(app=MagicMock(), settings=settings)
with (
patch(
"messaging.platforms.factory.create_messaging_platform",
side_effect=RuntimeError("SECRET_RUNTIME_DETAIL"),
),
caplog.at_level(logging.ERROR),
):
await runtime._start_messaging_if_configured()
blob = " | ".join(r.getMessage() for r in caplog.records)
assert "SECRET_RUNTIME_DETAIL" not in blob
assert "exc_type=RuntimeError" in blob
@pytest.mark.asyncio
async def test_best_effort_default_logs_exclude_exception_text(caplog):
api_runtime_mod = importlib.import_module("api.runtime")
async def boom():
raise ValueError("SECRET_SHUTDOWN")
with caplog.at_level(logging.WARNING):
await api_runtime_mod.best_effort("test_step", boom(), log_verbose_errors=False)
blob = " | ".join(r.getMessage() for r in caplog.records)
assert "SECRET_SHUTDOWN" not in blob
assert "exc_type=ValueError" in blob
@pytest.mark.asyncio
async def test_best_effort_verbose_includes_exception_text(caplog):
api_runtime_mod = importlib.import_module("api.runtime")
async def boom():
raise ValueError("VISIBLE_SHUTDOWN")
with caplog.at_level(logging.WARNING):
await api_runtime_mod.best_effort("test_step", boom(), log_verbose_errors=True)
blob = " | ".join(r.getMessage() for r in caplog.records)
assert "VISIBLE_SHUTDOWN" in blob


@ -0,0 +1,208 @@
"""Tests that API and SSE logging avoid raw sensitive payloads by default."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from fastapi import HTTPException
from api import services as services_mod
from api.models.anthropic import Message, MessagesRequest
from api.services import ClaudeProxyService
from config.settings import Settings
from core.anthropic.sse import SSEBuilder
def test_create_message_skips_full_payload_debug_log_by_default():
settings = Settings()
assert settings.log_raw_api_payloads is False
mock_provider = MagicMock()
async def fake_stream(*_a, **_kw):
yield "event: ping\ndata: {}\n\n"
mock_provider.stream_response = fake_stream
service = ClaudeProxyService(settings, provider_getter=lambda _: mock_provider)
request = MessagesRequest(
model="claude-3-haiku-20240307",
max_tokens=10,
messages=[Message(role="user", content="secret-user-text")],
)
with patch.object(services_mod.logger, "debug") as mock_debug:
service.create_message(request)
full_payload_calls = [
c
for c in mock_debug.call_args_list
if c.args and str(c.args[0]) == "FULL_PAYLOAD [{}]: {}"
]
assert not full_payload_calls
def test_create_message_logs_full_payload_when_opt_in():
settings = Settings()
settings.log_raw_api_payloads = True
mock_provider = MagicMock()
async def fake_stream(*_a, **_kw):
yield "event: ping\ndata: {}\n\n"
mock_provider.stream_response = fake_stream
service = ClaudeProxyService(settings, provider_getter=lambda _: mock_provider)
request = MessagesRequest(
model="claude-3-haiku-20240307",
max_tokens=10,
messages=[Message(role="user", content="visible")],
)
with patch.object(services_mod.logger, "debug") as mock_debug:
service.create_message(request)
keys = [c.args[0] for c in mock_debug.call_args_list if c.args]
assert any(k == "FULL_PAYLOAD [{}]: {}" for k in keys)
def test_sse_builder_default_debug_has_no_serialized_json_content():
with patch("core.anthropic.sse.logger.debug") as mock_debug:
sse = SSEBuilder("msg_x", "m", 1, log_raw_events=False)
sse.message_start()
assert mock_debug.call_count == 1
message = str(mock_debug.call_args)
assert "serialized_bytes=" in message
assert "role" not in message
assert "assistant" not in message
def test_sse_builder_raw_logging_includes_event_body_when_enabled():
with patch("core.anthropic.sse.logger.debug") as mock_debug:
sse = SSEBuilder("msg_x", "m", 1, log_raw_events=True)
sse.message_start()
assert mock_debug.call_count == 1
message = str(mock_debug.call_args)
assert "message_start" in message
assert "role" in message
def _flatten_log_calls(mock_log) -> str:
parts: list[str] = []
for call in mock_log.call_args_list:
parts.extend(str(arg) for arg in call.args)
parts.append(repr(call.kwargs))
return " ".join(parts)
def test_create_message_unexpected_error_default_logs_exclude_exception_text():
settings = Settings()
assert settings.log_api_error_tracebacks is False
secret = "upstream-secret-token-abc"
mock_provider = MagicMock()
def stream_boom(*_a, **_kw):
raise RuntimeError(secret)
mock_provider.stream_response = stream_boom
service = ClaudeProxyService(settings, provider_getter=lambda _: mock_provider)
request = MessagesRequest(
model="claude-3-haiku-20240307",
max_tokens=10,
messages=[Message(role="user", content="hi")],
)
with (
patch.object(services_mod.logger, "error") as log_err,
pytest.raises(HTTPException),
):
service.create_message(request)
blob = _flatten_log_calls(log_err)
assert secret not in blob
assert "RuntimeError" in blob
def test_create_message_unexpected_error_always_returns_500():
"""Non-provider failures must not leak arbitrary status_code attributes."""
class WeirdError(Exception):
status_code = 418
settings = Settings()
mock_provider = MagicMock()
def stream_boom(*_a, **_kw):
raise WeirdError("no")
mock_provider.stream_response = stream_boom
service = ClaudeProxyService(settings, provider_getter=lambda _: mock_provider)
request = MessagesRequest(
model="claude-3-haiku-20240307",
max_tokens=10,
messages=[Message(role="user", content="hi")],
)
with pytest.raises(HTTPException) as excinfo:
service.create_message(request)
assert excinfo.value.status_code == 500
def test_parse_cli_event_error_logs_metadata_by_default():
"""CLI parser must not log raw error text unless LOG_RAW_CLI_DIAGNOSTICS is on."""
from messaging.event_parser import parse_cli_event
secret = "user-secret-parser-leak-xyz"
with patch("messaging.event_parser.logger.info") as log_info:
parse_cli_event(
{"type": "error", "error": {"message": secret}}, log_raw_cli=False
)
flat = " ".join(str(c) for c in log_info.call_args_list)
assert secret not in flat
assert "message_chars" in flat
def test_parse_cli_event_error_logs_text_when_log_raw_cli_enabled():
from messaging.event_parser import parse_cli_event
secret = "visible-cli-parser-msg"
with patch("messaging.event_parser.logger.info") as log_info:
parse_cli_event(
{"type": "error", "error": {"message": secret}}, log_raw_cli=True
)
flat = " ".join(str(c) for c in log_info.call_args_list)
assert secret in flat
def test_count_tokens_unexpected_error_default_logs_exclude_exception_text():
settings = Settings()
assert settings.log_api_error_tracebacks is False
secret = "count-tokens-leak-xyz"
def boom(*_a, **_kw):
raise ValueError(secret)
service = ClaudeProxyService(
settings,
provider_getter=lambda _: MagicMock(),
token_counter=boom,
)
from api.models.anthropic import TokenCountRequest
req = TokenCountRequest(
model="claude-3-haiku-20240307",
messages=[Message(role="user", content="x")],
)
with (
patch.object(services_mod.logger, "error") as log_err,
pytest.raises(HTTPException),
):
service.count_tokens(req)
blob = _flatten_log_calls(log_err)
assert secret not in blob
assert "ValueError" in blob


@ -0,0 +1,33 @@
"""Tests for validation log summaries (metadata only)."""
from api.validation_log import summarize_request_validation_body
def test_summarize_lists_block_metadata_without_echoing_string_content():
body = {
"messages": [
{
"role": "user",
"content": "secret user phrase",
}
],
"tools": [{"name": "web_search", "type": "web_search_20250305"}],
}
summary, tool_names = summarize_request_validation_body(body)
assert summary == [
{
"role": "user",
"content_kind": "str",
"content_length": 18,
}
]
assert tool_names == ["web_search"]
blob = repr(summary) + repr(tool_names)
assert "secret" not in blob
def test_summarize_handles_non_dict_messages_and_missing_tools():
body = {"messages": ["not_a_dict"]}
summary, tool_names = summarize_request_validation_body(body)
assert summary == [{"message_kind": "str"}]
assert tool_names == []

Some files were not shown because too many files have changed in this diff.