From f3a7528d4976ca33d1d0515b28d8285a608915e5 Mon Sep 17 00:00:00 2001 From: Alishahryar1 Date: Sun, 26 Apr 2026 02:55:10 -0700 Subject: [PATCH] Major refactor: API, providers, messaging, and Anthropic protocol Consolidates the incremental refactor work into a single change set: modular web tools (api/web_tools), native Anthropic request building and SSE block policy, OpenAI conversion and error handling, provider transports and rate limiting, messaging handler and tree queue, safe logging, smoke tests, and broad test coverage. --- .env.example | 26 +- AGENTS.md | 3 +- PLAN.md | 19 +- api/__init__.py | 3 +- api/app.py | 85 +-- api/command_utils.py | 2 +- api/dependencies.py | 33 +- api/detection.py | 31 +- api/models/anthropic.py | 64 +- api/runtime.py | 99 ++- api/services.py | 132 +++- api/validation_log.py | 48 ++ api/web_server_tools.py | 341 +--------- api/web_tools/__init__.py | 17 + api/web_tools/constants.py | 15 + api/web_tools/egress.py | 99 +++ api/web_tools/outbound.py | 278 +++++++++ api/web_tools/parsers.py | 104 ++++ api/web_tools/request.py | 87 +++ api/web_tools/streaming.py | 206 ++++++ cli/entrypoints.py | 3 +- cli/manager.py | 18 +- cli/session.py | 90 ++- config/constants.py | 10 + config/logging_config.py | 39 +- config/nim.py | 16 +- config/provider_catalog.py | 108 ++++ config/provider_ids.py | 19 +- config/settings.py | 104 +++- core/anthropic/__init__.py | 21 +- core/anthropic/conversion.py | 93 ++- core/anthropic/emitted_sse_tracker.py | 97 +++ core/anthropic/errors.py | 31 +- core/anthropic/native_messages_request.py | 260 ++++++++ core/anthropic/native_sse_block_policy.py | 313 ++++++++++ core/anthropic/provider_stream_error.py | 34 + core/anthropic/server_tool_sse.py | 14 + core/anthropic/sse.py | 80 ++- core/anthropic/stream_contracts.py | 73 ++- core/anthropic/tokens.py | 21 + core/rate_limit.py | 60 ++ messaging/cli_event_constants.py | 67 ++ messaging/event_parser.py | 24 +- messaging/handler.py | 296 +++------ messaging/limiter.py | 124 ++-- messaging/node_event_pipeline.py | 103 +++ messaging/platforms/discord.py | 95 ++- messaging/platforms/factory.py | 12 + messaging/platforms/telegram.py | 99 ++- messaging/safe_diagnostics.py | 17 + messaging/session.py | 34 +- messaging/transcript.py | 10 +- messaging/transcription.py | 103 +-- messaging/trees/processor.py | 165 ----- messaging/trees/queue_manager.py | 358 ++++++++++- messaging/trees/repository.py | 186 ------ messaging/ui_updates.py | 101 +++ providers/anthropic_messages.py | 214 +++++-- providers/base.py | 46 +- providers/deepseek/__init__.py | 6 +- providers/deepseek/client.py | 4 +- providers/deepseek/request.py | 14 +- providers/defaults.py | 36 +- providers/error_mapping.py | 22 +- providers/exceptions.py | 14 +- providers/llamacpp/client.py | 4 +- providers/lmstudio/client.py | 4 +- providers/nvidia_nim/__init__.py | 6 +- providers/nvidia_nim/client.py | 4 +- providers/nvidia_nim/request.py | 83 +-- providers/nvidia_nim/voice.py | 95 +++ providers/ollama/__init__.py | 6 +- providers/ollama/client.py | 4 +- providers/open_router/__init__.py | 6 +- providers/open_router/client.py | 212 +------ providers/open_router/request.py | 161 +---- providers/openai_compat.py | 181 +++--- providers/rate_limit.py | 61 +- providers/registry.py | 112 +--- server.py | 6 +- smoke/capabilities.py | 2 +- smoke/lib/e2e.py | 25 +- smoke/lib/http.py | 3 +- smoke/lib/skips.py | 2 +- smoke/lib/sse.py | 29 - smoke/prereq/test_provider_prereq_live.py | 5 +- smoke/prereq/test_tools_prereq_live.py | 5 +- smoke/prereq/test_voice_prereq_live.py | 13 +- smoke/product/test_provider_product_live.py | 10 +- smoke/product/test_voice_product_live.py | 1 + tests/__init__.py | 1 + .../api/test_anthropic_request_passthrough.py | 184 ++++++ tests/api/test_api.py | 8 +- tests/api/test_app_lifespan_and_errors.py | 100 ++- tests/api/test_auth.py | 4 +- tests/api/test_dependencies.py | 17 +- tests/api/test_routes_optimizations.py | 15 +- tests/api/test_runtime_safe_logging.py | 70 +++ tests/api/test_safe_logging.py | 208 +++++++ tests/api/test_validation_log.py | 33 + tests/api/test_web_server_tools.py | 585 +++++++++++++++++- tests/cli/test_cli.py | 59 +- tests/config/test_config.py | 32 +- tests/config/test_logging_config.py | 37 ++ .../contracts/test_architecture_contracts.py | 19 + tests/contracts/test_import_boundaries.py | 43 +- tests/contracts/test_smoke_sse_reexport.py | 11 - tests/contracts/test_stream_contracts.py | 91 ++- .../anthropic/test_native_sse_block_policy.py | 133 ++++ tests/core/test_strict_sliding_window.py | 38 ++ tests/messaging/test_handler.py | 43 ++ tests/messaging/test_handler_format.py | 1 - .../test_handler_markdown_and_status_edges.py | 8 +- tests/messaging/test_limiter.py | 73 +-- tests/messaging/test_messaging_factory.py | 8 + .../test_session_store_edge_cases.py | 29 +- .../test_stream_transcript_contract.py | 62 ++ tests/messaging/test_transcription.py | 28 +- tests/messaging/test_transcription_nim.py | 39 ++ tests/messaging/test_tree_processor.py | 2 +- tests/messaging/test_tree_repository.py | 2 +- tests/provider_request_mocks.py | 25 + tests/providers/test_anthropic_messages.py | 100 ++- .../test_anthropic_messages_429_retry.py | 125 ++++ tests/providers/test_converter.py | 106 +++- tests/providers/test_deepseek.py | 105 +++- tests/providers/test_error_mapping.py | 65 +- tests/providers/test_llamacpp.py | 53 +- tests/providers/test_lmstudio.py | 78 ++- tests/providers/test_nim_request_clone.py | 40 ++ tests/providers/test_nvidia_nim.py | 23 +- tests/providers/test_ollama.py | 37 +- tests/providers/test_open_router.py | 264 +++++++- tests/providers/test_provider_rate_limit.py | 40 +- .../test_provider_transport_logging.py | 263 ++++++++ tests/providers/test_registry.py | 2 +- tests/providers/test_sse_builder.py | 29 +- tests/providers/test_streaming_errors.py | 140 ++++- tests/stream_contract.py | 18 + 139 files changed, 7460 insertions(+), 2422 deletions(-) create mode 100644 api/validation_log.py create mode 100644 api/web_tools/__init__.py create mode 100644 api/web_tools/constants.py create mode 100644 api/web_tools/egress.py create mode 100644 api/web_tools/outbound.py create mode 100644 api/web_tools/parsers.py create mode 100644 api/web_tools/request.py create mode 100644 api/web_tools/streaming.py create mode 100644 config/constants.py create mode 100644 config/provider_catalog.py create mode 100644 core/anthropic/emitted_sse_tracker.py create mode 100644 core/anthropic/native_messages_request.py create mode 100644 core/anthropic/native_sse_block_policy.py create mode 100644 core/anthropic/provider_stream_error.py create mode 100644 core/anthropic/server_tool_sse.py create mode 100644 core/rate_limit.py create mode 100644 messaging/cli_event_constants.py create mode 100644 messaging/node_event_pipeline.py create mode 100644 messaging/safe_diagnostics.py delete mode 100644 messaging/trees/processor.py delete mode 100644 messaging/trees/repository.py create mode 100644 messaging/ui_updates.py create mode 100644 providers/nvidia_nim/voice.py delete mode 100644 smoke/lib/sse.py create mode 100644 tests/__init__.py create mode 100644 tests/api/test_anthropic_request_passthrough.py create mode 100644 tests/api/test_runtime_safe_logging.py create mode 100644 tests/api/test_safe_logging.py create mode 100644 tests/api/test_validation_log.py delete mode 100644 tests/contracts/test_smoke_sse_reexport.py create mode 100644 tests/core/anthropic/test_native_sse_block_policy.py create mode 100644 tests/core/test_strict_sliding_window.py create mode 100644 tests/messaging/test_stream_transcript_contract.py create mode 100644 tests/messaging/test_transcription_nim.py create mode 100644 tests/provider_request_mocks.py create mode 100644 tests/providers/test_anthropic_messages_429_retry.py create mode 100644 tests/providers/test_nim_request_clone.py create mode 100644 tests/providers/test_provider_transport_logging.py create mode 100644 tests/stream_contract.py diff --git a/.env.example b/.env.example index 42d51a3..e9e1990 100644 --- a/.env.example +++ b/.env.example @@ -47,8 +47,8 @@ OPENROUTER_PROXY="" LMSTUDIO_PROXY="" LLAMACPP_PROXY="" -PROVIDER_RATE_LIMIT=40 -PROVIDER_RATE_WINDOW=60 +PROVIDER_RATE_LIMIT=1 +PROVIDER_RATE_WINDOW=3 PROVIDER_MAX_CONCURRENCY=5 @@ -102,3 +102,25 @@ ENABLE_NETWORK_PROBE_MOCK=true ENABLE_TITLE_GENERATION_SKIP=true ENABLE_SUGGESTION_MODE_SKIP=true ENABLE_FILEPATH_EXTRACTION_MOCK=true + + +# Local Anthropic web_search / web_fetch handling (performs outbound HTTP; on by default) +ENABLE_WEB_SERVER_TOOLS=true +WEB_FETCH_ALLOWED_SCHEMES=http,https +WEB_FETCH_ALLOW_PRIVATE_NETWORKS=false + + +# Verbose diagnostics (avoid logging raw prompts / SSE bodies in production) +DEBUG_PLATFORM_EDITS=false +DEBUG_SUBAGENT_STACK=false +# When true, also allows DEBUG-level httpx/httpcore/telegram log noise (not just payload logging). +LOG_RAW_API_PAYLOADS=false +LOG_RAW_SSE_EVENTS=false +# When true, log full exception text and tracebacks for unhandled errors (may leak request-derived data). +LOG_API_ERROR_TRACEBACKS=false +# When true, log message/transcription text previews in messaging adapters (may leak user content). +LOG_RAW_MESSAGING_CONTENT=false +# When true, log full Claude CLI stderr, non-JSON stdout lines, and parser error text. +LOG_RAW_CLI_DIAGNOSTICS=false +# When true, log full exception and CLI error message strings in messaging (may leak user content). +LOG_MESSAGING_ERROR_DETAILS=false diff --git a/AGENTS.md b/AGENTS.md index f4335d2..265b3a8 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -32,13 +32,14 @@ - **Platform-agnostic naming**: Use generic names (e.g. `PLATFORM_EDIT`) not platform-specific ones (e.g. `TELEGRAM_EDIT`) in shared code. - **No type ignores**: Do not add `# type: ignore` or `# ty: ignore`. Fix the underlying type issue. - **Complete migrations**: When moving modules, update imports to the new owner and remove old compatibility shims in the same change unless preserving a published interface is explicitly required. +- **Maximum Test Coverage**: There should be maximum test coverage for everything, preferably live smoke test coverage to catch bugs early ## COGNITIVE WORKFLOW 1. **ANALYZE**: Read relevant files. Do not guess. 2. **PLAN**: Map out the logic. Identify root cause or required changes. Order changes by dependency. 3. **EXECUTE**: Fix the cause, not the symptom. Execute incrementally with clear commits. -4. **VERIFY**: Run ci checks. Confirm the fix via logs or output. +4. **VERIFY**: Run ci checks and relevant smoke tests. Confirm the fix via logs or output. 5. **SPECIFICITY**: Do exactly as much as asked; nothing more, nothing less. 6. **PROPAGATION**: Changes impact multiple files; propagate updates correctly. diff --git a/PLAN.md b/PLAN.md index 9f4607b..f5011f6 100644 --- a/PLAN.md +++ b/PLAN.md @@ -56,13 +56,20 @@ with **runtime composition** (e.g. `api.runtime` constructs `cli` and `messaging **Contract highlights:** `api/` may import only `providers.base`, `providers.exceptions`, and `providers.registry` from the providers package (not per-adapter modules). `core/` stays free of `api`, `messaging`, `cli`, `providers`, `config`, and `smoke`. -`messaging/` does not import `api`, `cli`, or `providers`. Neutral stream contract -assertions for default CI live under `core/anthropic/stream_contracts.py`; -`smoke.lib.sse` re-exports them for live smoke. Process-cached provider helpers -(`api.dependencies.get_provider` / `get_provider_for_type`) exist for scripts and -unit tests; production HTTP handlers must use `resolve_provider` with +`messaging/` does not import `api`, `cli`, or `smoke`, and may import `providers` +only via `providers.nvidia_nim.voice` (NVIDIA/Riva offline ASR). Stream contract +helpers live in `core/anthropic/stream_contracts.py`; live smoke imports that +module directly (no dedicated smoke SSE shim). NVIDIA NIM chat tuning uses the +canonical `config.nim.NimSettings` model on `Settings`; `providers.registry` +passes `settings.nim` into `NvidiaNimProvider` without a duplicate schema. +Default upstream base URLs use a single constant per endpoint in +`providers/defaults.py` (e.g. `NVIDIA_NIM_DEFAULT_BASE`). Process-cached provider +helpers (`api.dependencies.get_provider` / `get_provider_for_type`) exist for +scripts and unit tests; production HTTP handlers must use `resolve_provider` with `request.app` so the app-scoped `ProviderRegistry` is used. The `api` package -`__all__` exposes HTTP models and `create_app` only (not those helpers). +`__all__` exposes HTTP models and `create_app` only (not `app`, not those helpers). +`api.app:create_app` is the ASGI factory (e.g. `uvicorn api.app:create_app --factory`); +`server.py` still exposes `server:app` as a module-level instance for convenience. ## Target Boundaries diff --git a/api/__init__.py b/api/__init__.py index 03094ba..5f6a6e4 100644 --- a/api/__init__.py +++ b/api/__init__.py @@ -1,6 +1,6 @@ """API layer for Claude Code Proxy.""" -from .app import app, create_app +from .app import create_app from .models import ( MessagesRequest, MessagesResponse, @@ -13,6 +13,5 @@ __all__ = [ "MessagesResponse", "TokenCountRequest", "TokenCountResponse", - "app", "create_app", ] diff --git a/api/app.py b/api/app.py index b47deb3..5b70133 100644 --- a/api/app.py +++ b/api/app.py @@ -1,6 +1,6 @@ """FastAPI application factory and configuration.""" -import os +import traceback from contextlib import asynccontextmanager from typing import Any @@ -16,13 +16,7 @@ from providers.exceptions import ProviderError from .routes import router from .runtime import AppRuntime - -# Opt-in to future behavior for python-telegram-bot -os.environ["PTB_TIMEDELTA"] = "1" - -# Configure logging first (before any module logs) -_settings = get_settings() -configure_logging(_settings.log_file) +from .validation_log import summarize_request_validation_body @asynccontextmanager @@ -38,6 +32,11 @@ async def lifespan(app: FastAPI): def create_app() -> FastAPI: """Create and configure the FastAPI application.""" + settings = get_settings() + configure_logging( + settings.log_file, verbose_third_party=settings.log_raw_api_payloads + ) + app = FastAPI( title="Claude Code Proxy", version="2.0.0", @@ -57,33 +56,7 @@ def create_app() -> FastAPI: except Exception as e: body = {"_json_error": type(e).__name__} - messages = body.get("messages") if isinstance(body, dict) else None - message_summary: list[dict[str, Any]] = [] - if isinstance(messages, list): - for msg in messages: - if not isinstance(msg, dict): - message_summary.append({"message_kind": type(msg).__name__}) - continue - content = msg.get("content") - item: dict[str, Any] = { - "role": msg.get("role"), - "content_kind": type(content).__name__, - } - if isinstance(content, list): - item["block_types"] = [ - block.get("type", "dict") - if isinstance(block, dict) - else type(block).__name__ - for block in content[:12] - ] - item["block_keys"] = [ - sorted(str(key) for key in block)[:12] - for block in content[:5] - if isinstance(block, dict) - ] - elif isinstance(content, str): - item["content_length"] = len(content) - message_summary.append(item) + message_summary, tool_names = summarize_request_validation_body(body) logger.debug( "Request validation failed: path={} query={} error_locs={} error_types={} message_summary={} tool_names={}", @@ -92,20 +65,27 @@ def create_app() -> FastAPI: [list(error.get("loc", ())) for error in exc.errors()], [str(error.get("type", "")) for error in exc.errors()], message_summary, - [ - str(tool.get("name", "")) - for tool in body.get("tools", []) - if isinstance(body, dict) - and isinstance(body.get("tools"), list) - and isinstance(tool, dict) - ], + tool_names, ) return await request_validation_exception_handler(request, exc) @app.exception_handler(ProviderError) async def provider_error_handler(request: Request, exc: ProviderError): """Handle provider-specific errors and return Anthropic format.""" - logger.error(f"Provider Error: {exc.error_type} - {exc.message}") + err_settings = get_settings() + if err_settings.log_api_error_tracebacks: + logger.error( + "Provider Error: error_type={} status_code={} message={}", + exc.error_type, + exc.status_code, + exc.message, + ) + else: + logger.error( + "Provider Error: error_type={} status_code={}", + exc.error_type, + exc.status_code, + ) return JSONResponse( status_code=exc.status_code, content=exc.to_anthropic_format(), @@ -114,10 +94,17 @@ def create_app() -> FastAPI: @app.exception_handler(Exception) async def general_error_handler(request: Request, exc: Exception): """Handle general errors and return Anthropic format.""" - logger.error(f"General Error: {exc!s}") - import traceback - - logger.error(traceback.format_exc()) + settings = get_settings() + if settings.log_api_error_tracebacks: + logger.error("General Error: {}", exc) + logger.error(traceback.format_exc()) + else: + logger.error( + "General Error: path={} method={} exc_type={}", + request.url.path, + request.method, + type(exc).__name__, + ) return JSONResponse( status_code=500, content={ @@ -130,7 +117,3 @@ def create_app() -> FastAPI: ) return app - - -# Default app instance for uvicorn -app = create_app() diff --git a/api/command_utils.py b/api/command_utils.py index ea5251b..76963cd 100644 --- a/api/command_utils.py +++ b/api/command_utils.py @@ -135,5 +135,5 @@ def extract_filepaths_from_command(command: str, output: str) -> str: return "\n" - except Exception: + except ValueError: return "\n" diff --git a/api/dependencies.py b/api/dependencies.py index 3527eb8..c3e5743 100644 --- a/api/dependencies.py +++ b/api/dependencies.py @@ -8,7 +8,11 @@ from config.settings import Settings from config.settings import get_settings as _get_settings from core.anthropic import get_user_facing_error_message from providers.base import BaseProvider -from providers.exceptions import AuthenticationError, UnknownProviderTypeError +from providers.exceptions import ( + AuthenticationError, + ServiceUnavailableError, + UnknownProviderTypeError, +) from providers.registry import PROVIDER_DESCRIPTORS, ProviderRegistry # Process-level cache: only for :func:`get_provider_for_type` / :func:`get_provider` @@ -18,7 +22,7 @@ _providers: dict[str, BaseProvider] = {} def get_settings() -> Settings: - """Get application settings via dependency injection.""" + """Return cached :class:`~config.settings.Settings` (FastAPI-friendly alias).""" return _get_settings() @@ -31,10 +35,9 @@ def resolve_provider( """Resolve a provider using the app-scoped registry when ``app`` is set. When ``app`` is not ``None``, the app-owned :attr:`app.state.provider_registry` - is always used. If the registry is missing (e.g. a test app without - :class:`~api.runtime.AppRuntime` startup), a new :class:`ProviderRegistry` - is installed on ``app.state`` so the process cache is never mixed with - per-request app identity. + must exist (installed by :class:`~api.runtime.AppRuntime` during startup). + Callers that construct a bare ``FastAPI`` without lifespan must set + ``app.state.provider_registry`` explicitly. When ``app`` is ``None`` (no HTTP context), uses the process-level :data:`_providers` cache only. @@ -42,8 +45,10 @@ def resolve_provider( if app is not None: reg = getattr(app.state, "provider_registry", None) if reg is None: - reg = ProviderRegistry() - app.state.provider_registry = reg + raise ServiceUnavailableError( + "Provider registry is not configured. Ensure AppRuntime startup ran " + "or assign app.state.provider_registry for test apps." + ) return _resolve_with_registry(reg, provider_type, settings) return _resolve_with_registry(ProviderRegistry(_providers), provider_type, settings) @@ -55,9 +60,10 @@ def _resolve_with_registry( try: provider = registry.get(provider_type, settings) except AuthenticationError as e: - raise HTTPException( - status_code=503, detail=get_user_facing_error_message(e) - ) from e + # Provider :class:`~providers.exceptions.AuthenticationError` messages are + # curated configuration hints (env var names, docs links), not upstream noise. + detail = str(e).strip() or get_user_facing_error_message(e) + raise HTTPException(status_code=503, detail=detail) from e except UnknownProviderTypeError: logger.error( "Unknown provider_type: '{}'. Supported: {}", @@ -73,8 +79,9 @@ def _resolve_with_registry( def get_provider_for_type(provider_type: str) -> BaseProvider: """Get or create a provider in the process-level cache (no ``app``/Request). - For server requests, use :func:`resolve_provider` with the active - :attr:`request.app` so the app-scoped provider registry is used. + HTTP route handlers should call :func:`resolve_provider` with the active + :attr:`request.app` (via :class:`~api.runtime.AppRuntime`) instead of this + process-wide cache. """ return resolve_provider(provider_type, app=None, settings=get_settings()) diff --git a/api/detection.py b/api/detection.py index a299151..7977c1b 100644 --- a/api/detection.py +++ b/api/detection.py @@ -65,8 +65,8 @@ def is_prefix_detection_request(request_data: MessagesRequest) -> tuple[bool, st try: cmd_start = content.rfind("Command:") + len("Command:") return True, content[cmd_start:].strip() - except Exception: - pass + except TypeError: + return False, "" return False, "" @@ -121,19 +121,16 @@ def is_filepath_extraction_request( if not user_has_filepaths and not system_has_extract: return False, "", "" - try: - cmd_start = content.find("Command:") + len("Command:") - output_marker = content.find("Output:", cmd_start) - if output_marker == -1: - return False, "", "" - - command = content[cmd_start:output_marker].strip() - output = content[output_marker + len("Output:") :].strip() - - for marker in ["<", "\n\n"]: - if marker in output: - output = output.split(marker)[0].strip() - - return True, command, output - except Exception: + cmd_start = content.find("Command:") + len("Command:") + output_marker = content.find("Output:", cmd_start) + if output_marker == -1: return False, "", "" + + command = content[cmd_start:output_marker].strip() + output = content[output_marker + len("Output:") :].strip() + + for marker in ["<", "\n\n"]: + if marker in output: + output = output.split(marker)[0].strip() + + return True, command, output diff --git a/api/models/anthropic.py b/api/models/anthropic.py index 12b6533..5bda060 100644 --- a/api/models/anthropic.py +++ b/api/models/anthropic.py @@ -3,7 +3,7 @@ from enum import StrEnum from typing import Any, Literal -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict, Field # ============================================================================= @@ -15,41 +15,68 @@ class Role(StrEnum): system = "system" -class ContentBlockText(BaseModel): +class _AnthropicBlockBase(BaseModel): + """Pass through provider fields (e.g. ``cache_control``) for native transports.""" + + model_config = ConfigDict(extra="allow") + + +class ContentBlockText(_AnthropicBlockBase): type: Literal["text"] text: str -class ContentBlockImage(BaseModel): +class ContentBlockImage(_AnthropicBlockBase): type: Literal["image"] source: dict[str, Any] -class ContentBlockToolUse(BaseModel): +class ContentBlockToolUse(_AnthropicBlockBase): type: Literal["tool_use"] id: str name: str input: dict[str, Any] -class ContentBlockToolResult(BaseModel): +class ContentBlockToolResult(_AnthropicBlockBase): type: Literal["tool_result"] tool_use_id: str content: str | list[Any] | dict[str, Any] -class ContentBlockThinking(BaseModel): +class ContentBlockThinking(_AnthropicBlockBase): type: Literal["thinking"] thinking: str signature: str | None = None -class ContentBlockRedactedThinking(BaseModel): +class ContentBlockRedactedThinking(_AnthropicBlockBase): type: Literal["redacted_thinking"] data: str -class SystemContent(BaseModel): +class ContentBlockServerToolUse(_AnthropicBlockBase): + """Anthropic server-side tool invocation (e.g. ``web_search``, ``web_fetch``).""" + + type: Literal["server_tool_use"] + id: str + name: str + input: dict[str, Any] + + +class ContentBlockWebSearchToolResult(_AnthropicBlockBase): + type: Literal["web_search_tool_result"] + tool_use_id: str + content: Any + + +class ContentBlockWebFetchToolResult(_AnthropicBlockBase): + type: Literal["web_fetch_tool_result"] + tool_use_id: str + content: Any + + +class SystemContent(_AnthropicBlockBase): type: Literal["text"] text: str @@ -68,12 +95,15 @@ class Message(BaseModel): | ContentBlockToolResult | ContentBlockThinking | ContentBlockRedactedThinking + | ContentBlockServerToolUse + | ContentBlockWebSearchToolResult + | ContentBlockWebFetchToolResult ] ) reasoning_content: str | None = None -class Tool(BaseModel): +class Tool(_AnthropicBlockBase): name: str # Anthropic server tools (e.g. web_search beta tools) include a ``type`` and # may omit ``input_schema`` because the provider owns the schema. @@ -92,7 +122,12 @@ class ThinkingConfig(BaseModel): # Request Models # ============================================================================= class MessagesRequest(BaseModel): + model_config = ConfigDict(extra="allow") + model: str + # Internal routing / debug: accepted on parse but not serialized to providers. + original_model: str | None = Field(default=None, exclude=True) + resolved_provider_model: str | None = Field(default=None, exclude=True) max_tokens: int | None = None messages: list[Message] system: str | list[SystemContent] | None = None @@ -105,13 +140,24 @@ class MessagesRequest(BaseModel): tools: list[Tool] | None = None tool_choice: dict[str, Any] | None = None thinking: ThinkingConfig | None = None + # Native Anthropic / SDK client hints: ignored (not forwarded) for OpenAI Chat conversion. + context_management: dict[str, Any] | None = None + output_config: dict[str, Any] | None = None + mcp_servers: list[dict[str, Any]] | None = None extra_body: dict[str, Any] | None = None class TokenCountRequest(BaseModel): + model_config = ConfigDict(extra="allow") + model: str + original_model: str | None = Field(default=None, exclude=True) + resolved_provider_model: str | None = Field(default=None, exclude=True) messages: list[Message] system: str | list[SystemContent] | None = None tools: list[Tool] | None = None thinking: ThinkingConfig | None = None tool_choice: dict[str, Any] | None = None + context_management: dict[str, Any] | None = None + output_config: dict[str, Any] | None = None + mcp_servers: list[dict[str, Any]] | None = None diff --git a/api/runtime.py b/api/runtime.py index 56cfe07..56df000 100644 --- a/api/runtime.py +++ b/api/runtime.py @@ -23,15 +23,31 @@ _SHUTDOWN_TIMEOUT_S = 5.0 async def best_effort( - name: str, awaitable: Any, timeout_s: float = _SHUTDOWN_TIMEOUT_S + name: str, + awaitable: Any, + timeout_s: float = _SHUTDOWN_TIMEOUT_S, + *, + log_verbose_errors: bool = False, ) -> None: """Run a shutdown step with timeout; never raise to callers.""" try: await asyncio.wait_for(awaitable, timeout=timeout_s) except TimeoutError: - logger.warning(f"Shutdown step timed out: {name} ({timeout_s}s)") + logger.warning("Shutdown step timed out: {} ({}s)", name, timeout_s) except Exception as e: - logger.warning(f"Shutdown step failed: {name}: {type(e).__name__}: {e}") + if log_verbose_errors: + logger.warning( + "Shutdown step failed: {}: {}: {}", + name, + type(e).__name__, + e, + ) + else: + logger.warning( + "Shutdown step failed: {}: exc_type={}", + name, + type(e).__name__, + ) def warn_if_process_auth_token(settings: Settings) -> None: @@ -73,20 +89,37 @@ class AppRuntime: self._publish_state() async def shutdown(self) -> None: + verbose = self.settings.log_api_error_tracebacks if self.message_handler is not None: try: self.message_handler.session_store.flush_pending_save() except Exception as e: - logger.warning(f"Session store flush on shutdown: {e}") + if verbose: + logger.warning("Session store flush on shutdown: {}", e) + else: + logger.warning( + "Session store flush on shutdown: exc_type={}", + type(e).__name__, + ) logger.info("Shutdown requested, cleaning up...") if self.messaging_platform: - await best_effort("messaging_platform.stop", self.messaging_platform.stop()) + await best_effort( + "messaging_platform.stop", + self.messaging_platform.stop(), + log_verbose_errors=verbose, + ) if self.cli_manager: - await best_effort("cli_manager.stop_all", self.cli_manager.stop_all()) + await best_effort( + "cli_manager.stop_all", + self.cli_manager.stop_all(), + log_verbose_errors=verbose, + ) if self._provider_registry is not None: await best_effort( - "provider_registry.cleanup", self._provider_registry.cleanup() + "provider_registry.cleanup", + self._provider_registry.cleanup(), + log_verbose_errors=verbose, ) await self._shutdown_limiter() logger.info("Server shut down cleanly") @@ -110,6 +143,10 @@ class AppRuntime: whisper_device=self.settings.whisper_device, hf_token=self.settings.hf_token, nvidia_nim_api_key=self.settings.nvidia_nim_api_key, + messaging_rate_limit=self.settings.messaging_rate_limit, + messaging_rate_window=self.settings.messaging_rate_window, + log_raw_messaging_content=self.settings.log_raw_messaging_content, + log_api_error_tracebacks=self.settings.log_api_error_tracebacks, ), ) @@ -117,12 +154,24 @@ class AppRuntime: await self._start_message_handler() except ImportError as e: - logger.warning(f"Messaging module import error: {e}") + if self.settings.log_api_error_tracebacks: + logger.warning("Messaging module import error: {}", e) + else: + logger.warning( + "Messaging module import error: exc_type={}", + type(e).__name__, + ) except Exception as e: - logger.error(f"Failed to start messaging platform: {e}") - import traceback + if self.settings.log_api_error_tracebacks: + logger.error("Failed to start messaging platform: {}", e) + import traceback - logger.error(traceback.format_exc()) + logger.error(traceback.format_exc()) + else: + logger.error( + "Failed to start messaging platform: exc_type={}", + type(e).__name__, + ) async def _start_message_handler(self) -> None: from cli.manager import CLISessionManager @@ -151,10 +200,13 @@ class AppRuntime: allowed_dirs=allowed_dirs, plans_directory=plans_directory, claude_bin=self.settings.claude_cli_bin, + log_raw_cli_diagnostics=self.settings.log_raw_cli_diagnostics, + log_messaging_error_details=self.settings.log_messaging_error_details, ) session_store = SessionStore( - storage_path=os.path.join(data_path, "sessions.json") + storage_path=os.path.join(data_path, "sessions.json"), + message_log_cap=self.settings.max_message_log_entries_per_chat, ) platform = self.messaging_platform assert platform is not None @@ -162,6 +214,11 @@ class AppRuntime: platform=platform, cli_manager=self.cli_manager, session_store=session_store, + debug_platform_edits=self.settings.debug_platform_edits, + debug_subagent_stack=self.settings.debug_subagent_stack, + log_raw_messaging_content=self.settings.log_raw_messaging_content, + log_raw_cli_diagnostics=self.settings.log_raw_cli_diagnostics, + log_messaging_error_details=self.settings.log_messaging_error_details, ) self._restore_tree_state(session_store) @@ -201,18 +258,26 @@ class AppRuntime: self.app.state.cli_manager = self.cli_manager async def _shutdown_limiter(self) -> None: + verbose = self.settings.log_api_error_tracebacks try: from messaging.limiter import MessagingRateLimiter except Exception as e: - logger.debug( - "Rate limiter shutdown skipped (import failed): {}: {}", - type(e).__name__, - e, - ) + if verbose: + logger.debug( + "Rate limiter shutdown skipped (import failed): {}: {}", + type(e).__name__, + e, + ) + else: + logger.debug( + "Rate limiter shutdown skipped (import failed): exc_type={}", + type(e).__name__, + ) return await best_effort( "MessagingRateLimiter.shutdown_instance", MessagingRateLimiter.shutdown_instance(), timeout_s=2.0, + log_verbose_errors=verbose, ) diff --git a/api/services.py b/api/services.py index d0bbcaa..219287c 100644 --- a/api/services.py +++ b/api/services.py @@ -4,7 +4,7 @@ from __future__ import annotations import traceback import uuid -from collections.abc import Callable +from collections.abc import AsyncIterator, Callable from typing import Any from fastapi import HTTPException @@ -13,6 +13,7 @@ from loguru import logger from config.settings import Settings from core.anthropic import get_token_count, get_user_facing_error_message +from core.anthropic.sse import ANTHROPIC_SSE_RESPONSE_HEADERS from providers.base import BaseProvider from providers.exceptions import InvalidRequestError, ProviderError @@ -20,15 +21,67 @@ from .model_router import ModelRouter from .models.anthropic import MessagesRequest, TokenCountRequest from .models.responses import TokenCountResponse from .optimization_handlers import try_optimizations -from .web_server_tools import ( +from .web_tools.egress import WebFetchEgressPolicy +from .web_tools.request import ( is_web_server_tool_request, - stream_web_server_tool_response, + openai_chat_upstream_server_tool_error, ) +from .web_tools.streaming import stream_web_server_tool_response TokenCounter = Callable[[list[Any], str | list[Any] | None, list[Any] | None], int] ProviderGetter = Callable[[str], BaseProvider] +# Providers that use ``/chat/completions`` + Anthropic-to-OpenAI conversion (not native Messages). +_OPENAI_CHAT_UPSTREAM_IDS = frozenset({"nvidia_nim", "deepseek"}) + + +def anthropic_sse_streaming_response( + body: AsyncIterator[str], +) -> StreamingResponse: + """Return a :class:`StreamingResponse` for Anthropic-style SSE streams.""" + return StreamingResponse( + body, + media_type="text/event-stream", + headers=ANTHROPIC_SSE_RESPONSE_HEADERS, + ) + + +def _http_status_for_unexpected_service_exception(_exc: BaseException) -> int: + """HTTP status for uncaught non-provider failures (stable client contract).""" + return 500 + + +def _log_unexpected_service_exception( + settings: Settings, + exc: BaseException, + *, + context: str, + request_id: str | None = None, +) -> None: + """Log service-layer failures without echoing exception text unless opted in.""" + if settings.log_api_error_tracebacks: + if request_id is not None: + logger.error("{} request_id={}: {}", context, request_id, exc) + else: + logger.error("{}: {}", context, exc) + logger.error(traceback.format_exc()) + return + if request_id is not None: + logger.error( + "{} request_id={} exc_type={}", + context, + request_id, + type(exc).__name__, + ) + else: + logger.error("{} exc_type={}", context, type(exc).__name__) + + +def _require_non_empty_messages(messages: list[Any]) -> None: + if not messages: + raise InvalidRequestError("messages cannot be empty") + class ClaudeProxyService: """Coordinate request optimization, model routing, token count, and providers.""" @@ -48,25 +101,35 @@ class ClaudeProxyService: def create_message(self, request_data: MessagesRequest) -> object: """Create a message response or streaming response.""" try: - if not request_data.messages: - raise InvalidRequestError("messages cannot be empty") + _require_non_empty_messages(request_data.messages) routed = self._model_router.resolve_messages_request(request_data) - if is_web_server_tool_request(routed.request): + if routed.resolved.provider_id in _OPENAI_CHAT_UPSTREAM_IDS: + tool_err = openai_chat_upstream_server_tool_error( + routed.request, + web_tools_enabled=self._settings.enable_web_server_tools, + ) + if tool_err is not None: + raise InvalidRequestError(tool_err) + + if self._settings.enable_web_server_tools and is_web_server_tool_request( + routed.request + ): input_tokens = self._token_counter( routed.request.messages, routed.request.system, routed.request.tools ) logger.info("Optimization: Handling Anthropic web server tool") - return StreamingResponse( + egress = WebFetchEgressPolicy( + allow_private_network_targets=self._settings.web_fetch_allow_private_networks, + allowed_schemes=self._settings.web_fetch_allowed_scheme_set(), + ) + return anthropic_sse_streaming_response( stream_web_server_tool_response( - routed.request, input_tokens=input_tokens + routed.request, + input_tokens=input_tokens, + web_fetch_egress=egress, + verbose_client_errors=self._settings.log_api_error_tracebacks, ), - media_type="text/event-stream", - headers={ - "X-Accel-Buffering": "no", - "Cache-Control": "no-cache", - "Connection": "keep-alive", - }, ) optimized = try_optimizations(routed.request, self._settings) @@ -75,6 +138,10 @@ class ClaudeProxyService: logger.debug("No optimization matched, routing to provider") provider = self._provider_getter(routed.resolved.provider_id) + provider.preflight_stream( + routed.request, + thinking_enabled=routed.resolved.thinking_enabled, + ) request_id = f"req_{uuid.uuid4().hex[:12]}" logger.info( @@ -83,34 +150,31 @@ class ClaudeProxyService: routed.request.model, len(routed.request.messages), ) - logger.debug( - "FULL_PAYLOAD [{}]: {}", request_id, routed.request.model_dump() - ) + if self._settings.log_raw_api_payloads: + logger.debug( + "FULL_PAYLOAD [{}]: {}", request_id, routed.request.model_dump() + ) input_tokens = self._token_counter( routed.request.messages, routed.request.system, routed.request.tools ) - return StreamingResponse( + return anthropic_sse_streaming_response( provider.stream_response( routed.request, input_tokens=input_tokens, request_id=request_id, thinking_enabled=routed.resolved.thinking_enabled, ), - media_type="text/event-stream", - headers={ - "X-Accel-Buffering": "no", - "Cache-Control": "no-cache", - "Connection": "keep-alive", - }, ) except ProviderError: raise except Exception as e: - logger.error(f"Error: {e!s}\n{traceback.format_exc()}") + _log_unexpected_service_exception( + self._settings, e, context="CREATE_MESSAGE_ERROR" + ) raise HTTPException( - status_code=getattr(e, "status_code", 500), + status_code=_http_status_for_unexpected_service_exception(e), detail=get_user_facing_error_message(e), ) from e @@ -119,6 +183,7 @@ class ClaudeProxyService: request_id = f"req_{uuid.uuid4().hex[:12]}" with logger.contextualize(request_id=request_id): try: + _require_non_empty_messages(request_data.messages) routed = self._model_router.resolve_token_count_request(request_data) tokens = self._token_counter( routed.request.messages, routed.request.system, routed.request.tools @@ -131,13 +196,16 @@ class ClaudeProxyService: tokens, ) return TokenCountResponse(input_tokens=tokens) + except ProviderError: + raise except Exception as e: - logger.error( - "COUNT_TOKENS_ERROR: request_id={} error={}\n{}", - request_id, - get_user_facing_error_message(e), - traceback.format_exc(), + _log_unexpected_service_exception( + self._settings, + e, + context="COUNT_TOKENS_ERROR", + request_id=request_id, ) raise HTTPException( - status_code=500, detail=get_user_facing_error_message(e) + status_code=_http_status_for_unexpected_service_exception(e), + detail=get_user_facing_error_message(e), ) from e diff --git a/api/validation_log.py b/api/validation_log.py new file mode 100644 index 0000000..9ccdff0 --- /dev/null +++ b/api/validation_log.py @@ -0,0 +1,48 @@ +"""Safe metadata summaries for HTTP 422 validation logging (no raw text content).""" + +from __future__ import annotations + +from typing import Any + + +def summarize_request_validation_body( + body: Any, +) -> tuple[list[dict[str, Any]], list[str]]: + """Return message shape summary and tool name list for debug logs.""" + messages = body.get("messages") if isinstance(body, dict) else None + message_summary: list[dict[str, Any]] = [] + if isinstance(messages, list): + for msg in messages: + if not isinstance(msg, dict): + message_summary.append({"message_kind": type(msg).__name__}) + continue + content = msg.get("content") + item: dict[str, Any] = { + "role": msg.get("role"), + "content_kind": type(content).__name__, + } + if isinstance(content, list): + item["block_types"] = [ + block.get("type", "dict") + if isinstance(block, dict) + else type(block).__name__ + for block in content[:12] + ] + item["block_keys"] = [ + sorted(str(key) for key in block)[:12] + for block in content[:5] + if isinstance(block, dict) + ] + elif isinstance(content, str): + item["content_length"] = len(content) + message_summary.append(item) + + tool_names: list[str] = [] + if isinstance(body, dict) and isinstance(body.get("tools"), list): + tool_names = [ + str(tool.get("name", "")) + for tool in body["tools"] + if isinstance(tool, dict) + ] + + return message_summary, tool_names diff --git a/api/web_server_tools.py b/api/web_server_tools.py index 5daced4..cedaf95 100644 --- a/api/web_server_tools.py +++ b/api/web_server_tools.py @@ -1,331 +1,22 @@ -"""Local handlers for Anthropic web server tools. - -OpenAI-compatible upstreams can emit regular function calls, but Anthropic's -web tools are server-side: the API response itself must include the tool result. -""" +"""Compatibility re-exports for :mod:`api.web_tools` (web_search / web_fetch).""" from __future__ import annotations -import html -import json -import re -import uuid -from collections.abc import AsyncIterator -from datetime import UTC, datetime -from html.parser import HTMLParser -from typing import Any -from urllib.parse import parse_qs, unquote, urlparse - import httpx -from .models.anthropic import MessagesRequest +from api.web_tools.egress import ( + WebFetchEgressPolicy, + WebFetchEgressViolation, + enforce_web_fetch_egress, +) +from api.web_tools.request import is_web_server_tool_request +from api.web_tools.streaming import stream_web_server_tool_response -_REQUEST_TIMEOUT_S = 20.0 -_MAX_SEARCH_RESULTS = 10 -_MAX_FETCH_CHARS = 24_000 - - -class _SearchResultParser(HTMLParser): - def __init__(self) -> None: - super().__init__() - self.results: list[dict[str, str]] = [] - self._href: str | None = None - self._title_parts: list[str] = [] - - def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None: - if tag != "a": - return - href = dict(attrs).get("href") - if not href or "uddg=" not in href: - return - parsed = urlparse(href) - query = parse_qs(parsed.query) - uddg = query.get("uddg", [""])[0] - if not uddg: - return - self._href = unquote(uddg) - self._title_parts = [] - - def handle_data(self, data: str) -> None: - if self._href is not None: - self._title_parts.append(data) - - def handle_endtag(self, tag: str) -> None: - if tag != "a" or self._href is None: - return - title = " ".join("".join(self._title_parts).split()) - if title and not any(result["url"] == self._href for result in self.results): - self.results.append({"title": html.unescape(title), "url": self._href}) - self._href = None - self._title_parts = [] - - -class _HTMLTextParser(HTMLParser): - def __init__(self) -> None: - super().__init__() - self.title = "" - self.text_parts: list[str] = [] - self._in_title = False - self._skip_depth = 0 - - def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None: - if tag in {"script", "style", "noscript"}: - self._skip_depth += 1 - elif tag == "title": - self._in_title = True - - def handle_endtag(self, tag: str) -> None: - if tag in {"script", "style", "noscript"} and self._skip_depth: - self._skip_depth -= 1 - elif tag == "title": - self._in_title = False - - def handle_data(self, data: str) -> None: - text = " ".join(data.split()) - if not text: - return - if self._in_title: - self.title = f"{self.title} {text}".strip() - elif not self._skip_depth: - self.text_parts.append(text) - - -def _format_event(event_type: str, data: dict[str, Any]) -> str: - return f"event: {event_type}\ndata: {json.dumps(data)}\n\n" - - -def _content_text(content: Any) -> str: - if isinstance(content, str): - return content - if isinstance(content, list): - parts = [] - for item in content: - if isinstance(item, dict): - parts.append(str(item.get("text", ""))) - else: - parts.append(str(getattr(item, "text", ""))) - return "\n".join(part for part in parts if part) - return str(content) - - -def _request_text(request: MessagesRequest) -> str: - return "\n".join(_content_text(message.content) for message in request.messages) - - -def _web_tool_name(request: MessagesRequest) -> str | None: - for tool in request.tools or []: - name = tool.name - tool_type = tool.type or "" - if name in {"web_search", "web_fetch"} or tool_type.startswith("web_"): - return name - return None - - -def is_web_server_tool_request(request: MessagesRequest) -> bool: - return _web_tool_name(request) in {"web_search", "web_fetch"} - - -def _extract_query(text: str) -> str: - match = re.search(r"query:\s*(.+)", text, flags=re.IGNORECASE | re.DOTALL) - if match: - return match.group(1).strip().strip("\"'") - return text.strip() - - -def _extract_url(text: str) -> str: - match = re.search(r"https?://\S+", text) - return match.group(0).rstrip(").,]") if match else text.strip() - - -async def _run_web_search(query: str) -> list[dict[str, str]]: - async with httpx.AsyncClient( - timeout=_REQUEST_TIMEOUT_S, - follow_redirects=True, - headers={"User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0"}, - ) as client: - response = await client.get( - "https://lite.duckduckgo.com/lite/", - params={"q": query}, - ) - response.raise_for_status() - - parser = _SearchResultParser() - parser.feed(response.text) - return parser.results[:_MAX_SEARCH_RESULTS] - - -async def _run_web_fetch(url: str) -> dict[str, str]: - async with httpx.AsyncClient( - timeout=_REQUEST_TIMEOUT_S, - follow_redirects=True, - headers={"User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0"}, - ) as client: - response = await client.get(url) - response.raise_for_status() - - content_type = response.headers.get("content-type", "text/plain") - title = url - data = response.text - if "html" in content_type.lower(): - parser = _HTMLTextParser() - parser.feed(response.text) - title = parser.title or url - data = "\n".join(parser.text_parts) - return { - "url": str(response.url), - "title": title, - "media_type": "text/plain", - "data": data[:_MAX_FETCH_CHARS], - } - - -def _search_summary(query: str, results: list[dict[str, str]]) -> str: - if not results: - return f"No web search results found for: {query}" - lines = [f"Search results for: {query}"] - for index, result in enumerate(results, start=1): - lines.append(f"{index}. {result['title']}\n{result['url']}") - return "\n\n".join(lines) - - -async def stream_web_server_tool_response( - request: MessagesRequest, input_tokens: int -) -> AsyncIterator[str]: - tool_name = _web_tool_name(request) - if tool_name is None: - return - - text = _request_text(request) - message_id = f"msg_{uuid.uuid4()}" - tool_id = f"srvtoolu_{uuid.uuid4().hex}" - output_tokens = 1 - usage_key = ( - "web_search_requests" if tool_name == "web_search" else "web_fetch_requests" - ) - tool_input = ( - {"query": _extract_query(text)} - if tool_name == "web_search" - else {"url": _extract_url(text)} - ) - - yield _format_event( - "message_start", - { - "type": "message_start", - "message": { - "id": message_id, - "type": "message", - "role": "assistant", - "content": [], - "model": request.model, - "stop_reason": None, - "stop_sequence": None, - "usage": {"input_tokens": input_tokens, "output_tokens": 1}, - }, - }, - ) - yield _format_event( - "content_block_start", - { - "type": "content_block_start", - "index": 0, - "content_block": { - "type": "server_tool_use", - "id": tool_id, - "name": tool_name, - "input": tool_input, - }, - }, - ) - yield _format_event( - "content_block_stop", {"type": "content_block_stop", "index": 0} - ) - - try: - if tool_name == "web_search": - query = str(tool_input["query"]) - results = await _run_web_search(query) - result_content: Any = [ - { - "type": "web_search_result", - "title": result["title"], - "url": result["url"], - } - for result in results - ] - summary = _search_summary(query, results) - result_block_type = "web_search_tool_result" - else: - fetched = await _run_web_fetch(str(tool_input["url"])) - result_content = { - "type": "web_fetch_result", - "url": fetched["url"], - "content": { - "type": "document", - "source": { - "type": "text", - "media_type": fetched["media_type"], - "data": fetched["data"], - }, - "title": fetched["title"], - "citations": {"enabled": True}, - }, - "retrieved_at": datetime.now(UTC).isoformat(), - } - summary = fetched["data"][:_MAX_FETCH_CHARS] - result_block_type = "web_fetch_tool_result" - except Exception as error: - result_block_type = ( - "web_search_tool_result" - if tool_name == "web_search" - else "web_fetch_tool_result" - ) - error_type = ( - "web_search_tool_result_error" - if tool_name == "web_search" - else "web_fetch_tool_error" - ) - result_content = {"type": error_type, "error_code": "unavailable"} - summary = f"{tool_name} failed: {type(error).__name__}" - - output_tokens = max(1, len(summary) // 4) - - yield _format_event( - "content_block_start", - { - "type": "content_block_start", - "index": 1, - "content_block": { - "type": result_block_type, - "tool_use_id": tool_id, - "content": result_content, - }, - }, - ) - yield _format_event( - "content_block_stop", {"type": "content_block_stop", "index": 1} - ) - yield _format_event( - "content_block_start", - { - "type": "content_block_start", - "index": 2, - "content_block": {"type": "text", "text": summary}, - }, - ) - yield _format_event( - "content_block_stop", {"type": "content_block_stop", "index": 2} - ) - yield _format_event( - "message_delta", - { - "type": "message_delta", - "delta": {"stop_reason": "end_turn", "stop_sequence": None}, - "usage": { - "input_tokens": input_tokens, - "output_tokens": output_tokens, - "server_tool_use": {usage_key: 1}, - }, - }, - ) - yield _format_event("message_stop", {"type": "message_stop"}) +__all__ = [ + "WebFetchEgressPolicy", + "WebFetchEgressViolation", + "enforce_web_fetch_egress", + "httpx", + "is_web_server_tool_request", + "stream_web_server_tool_response", +] diff --git a/api/web_tools/__init__.py b/api/web_tools/__init__.py new file mode 100644 index 0000000..e0fd14c --- /dev/null +++ b/api/web_tools/__init__.py @@ -0,0 +1,17 @@ +"""Submodules for Anthropic web server tool handling (search/fetch, egress, streaming).""" + +from .egress import ( + WebFetchEgressPolicy, + WebFetchEgressViolation, + enforce_web_fetch_egress, +) +from .request import is_web_server_tool_request +from .streaming import stream_web_server_tool_response + +__all__ = [ + "WebFetchEgressPolicy", + "WebFetchEgressViolation", + "enforce_web_fetch_egress", + "is_web_server_tool_request", + "stream_web_server_tool_response", +] diff --git a/api/web_tools/constants.py b/api/web_tools/constants.py new file mode 100644 index 0000000..e7b2c01 --- /dev/null +++ b/api/web_tools/constants.py @@ -0,0 +1,15 @@ +"""Limits and defaults for outbound web server tool HTTP.""" + +_REQUEST_TIMEOUT_S = 20.0 +_MAX_SEARCH_RESULTS = 10 +_MAX_FETCH_CHARS = 24_000 +# Hard cap on raw bytes read from HTTP responses before decode / HTML parse (memory bound). +_MAX_WEB_FETCH_RESPONSE_BYTES = 2 * 1024 * 1024 +# Drain at most this many bytes from redirect responses before following Location. +_REDIRECT_RESPONSE_BODY_CAP_BYTES = 65_536 +_MAX_WEB_FETCH_REDIRECTS = 10 +_WEB_FETCH_REDIRECT_STATUSES = frozenset({301, 302, 303, 307, 308}) + +_WEB_TOOL_HTTP_HEADERS = { + "User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0", +} diff --git a/api/web_tools/egress.py b/api/web_tools/egress.py new file mode 100644 index 0000000..30b29d5 --- /dev/null +++ b/api/web_tools/egress.py @@ -0,0 +1,99 @@ +"""Egress policy for user-controlled web_fetch URLs (SSRF guard).""" + +from __future__ import annotations + +import ipaddress +import socket +from dataclasses import dataclass +from urllib.parse import urlparse + + +@dataclass(frozen=True, slots=True) +class WebFetchEgressPolicy: + """Egress rules for user-influenced web_fetch URLs.""" + + allow_private_network_targets: bool + allowed_schemes: frozenset[str] + + +class WebFetchEgressViolation(ValueError): + """Raised when a web_fetch URL is rejected by egress policy (SSRF guard).""" + + +def _port_for_url(parsed) -> int: + if parsed.port is not None: + return parsed.port + return 443 if (parsed.scheme or "").lower() == "https" else 80 + + +def _stream_getaddrinfo_or_raise(host: str, port: int) -> list[tuple]: + try: + return socket.getaddrinfo( + host, port, type=socket.SOCK_STREAM, proto=socket.IPPROTO_TCP + ) + except OSError as exc: + raise WebFetchEgressViolation( + f"Could not resolve host {host!r}: {exc}" + ) from exc + + +def get_validated_stream_addrinfos_for_egress( + url: str, policy: WebFetchEgressPolicy +) -> list[tuple]: + """Resolve and validate a URL for web_fetch, returning getaddrinfo rows for pinning. + + Each HTTP connect pins to only these `getaddrinfo` results so a malicious DNS + server cannot rebind to a disallowed address between resolution and the TCP + connect (used by :func:`api.web_tools.outbound._run_web_fetch`). + """ + parsed = urlparse(url) + scheme = (parsed.scheme or "").lower() + if scheme not in policy.allowed_schemes: + raise WebFetchEgressViolation( + f"URL scheme {scheme!r} is not allowed for web_fetch" + ) + + host = parsed.hostname + if host is None or host == "": + raise WebFetchEgressViolation("web_fetch URL must include a host") + + port = _port_for_url(parsed) + + if policy.allow_private_network_targets: + return _stream_getaddrinfo_or_raise(host, port) + + host_lower = host.lower() + if host_lower == "localhost" or host_lower.endswith(".localhost"): + raise WebFetchEgressViolation("localhost targets are not allowed for web_fetch") + if host_lower.endswith(".local"): + raise WebFetchEgressViolation(".local hostnames are not allowed for web_fetch") + + try: + parsed_ip = ipaddress.ip_address(host) + except ValueError: + parsed_ip = None + + if parsed_ip is not None: + if not parsed_ip.is_global: + raise WebFetchEgressViolation( + f"Non-public IP host {host!r} is not allowed for web_fetch" + ) + return _stream_getaddrinfo_or_raise(host, port) + + infos = _stream_getaddrinfo_or_raise(host, port) + for *_, sockaddr in infos: + addr = sockaddr[0] + try: + resolved = ipaddress.ip_address(addr) + except ValueError: + continue + if not resolved.is_global: + raise WebFetchEgressViolation( + f"Host {host!r} resolves to a non-public address ({resolved})" + ) + return infos + + +def enforce_web_fetch_egress(url: str, policy: WebFetchEgressPolicy) -> None: + """Validate ``url`` (scheme, host, and resolved addresses) for web_fetch.""" + get_validated_stream_addrinfos_for_egress(url, policy) diff --git a/api/web_tools/outbound.py b/api/web_tools/outbound.py new file mode 100644 index 0000000..4f93f9b --- /dev/null +++ b/api/web_tools/outbound.py @@ -0,0 +1,278 @@ +"""Outbound HTTP for web_search / web_fetch (client, body caps, logging).""" + +from __future__ import annotations + +import asyncio +import socket +from collections.abc import AsyncIterator +from urllib.parse import urljoin, urlparse + +import aiohttp +import httpx +from aiohttp import ClientSession, ClientTimeout, TCPConnector +from aiohttp.abc import AbstractResolver, ResolveResult +from loguru import logger + +from . import constants +from .constants import ( + _MAX_FETCH_CHARS, + _MAX_SEARCH_RESULTS, + _REDIRECT_RESPONSE_BODY_CAP_BYTES, + _REQUEST_TIMEOUT_S, + _WEB_FETCH_REDIRECT_STATUSES, + _WEB_TOOL_HTTP_HEADERS, +) +from .egress import ( + WebFetchEgressPolicy, + WebFetchEgressViolation, + get_validated_stream_addrinfos_for_egress, +) +from .parsers import HTMLTextParser, SearchResultParser + + +def _safe_public_host_for_logs(url: str) -> str: + host = urlparse(url).hostname or "" + return host[:253] + + +def _log_web_tool_failure( + tool_name: str, + error: BaseException, + *, + fetch_url: str | None = None, +) -> None: + exc_type = type(error).__name__ + if isinstance(error, WebFetchEgressViolation): + host = _safe_public_host_for_logs(fetch_url) if fetch_url else "" + logger.warning( + "web_tool_egress_rejected tool={} exc_type={} host={!r}", + tool_name, + exc_type, + host, + ) + return + if tool_name == "web_fetch" and fetch_url: + logger.warning( + "web_tool_failure tool={} exc_type={} host={!r}", + tool_name, + exc_type, + _safe_public_host_for_logs(fetch_url), + ) + else: + logger.warning("web_tool_failure tool={} exc_type={}", tool_name, exc_type) + + +def _web_tool_client_error_summary( + tool_name: str, + error: BaseException, + *, + verbose: bool, +) -> str: + if verbose: + return f"{tool_name} failed: {type(error).__name__}" + return "Web tool request failed." + + +async def _iter_response_body_under_cap( + response: httpx.Response, max_bytes: int +) -> AsyncIterator[bytes]: + if max_bytes <= 0: + return + received = 0 + async for chunk in response.aiter_bytes(chunk_size=65_536): + if received >= max_bytes: + break + remaining = max_bytes - received + if len(chunk) <= remaining: + received += len(chunk) + yield chunk + if received >= max_bytes: + break + else: + yield chunk[:remaining] + break + + +async def _drain_response_body_capped(response: httpx.Response, max_bytes: int) -> None: + async for _ in _iter_response_body_under_cap(response, max_bytes): + pass + + +async def _read_response_body_capped(response: httpx.Response, max_bytes: int) -> bytes: + return b"".join( + [piece async for piece in _iter_response_body_under_cap(response, max_bytes)] + ) + + +_NUMERIC_RESOLVE_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV +_NAME_RESOLVE_FLAGS = socket.NI_NUMERICHOST | socket.NI_NUMERICSERV + + +def getaddrinfo_rows_to_resolve_results( + host: str, addrinfos: list[tuple] +) -> list[ResolveResult]: + """Map :func:`socket.getaddrinfo` rows to aiohttp :class:`ResolveResult` (ThreadedResolver logic).""" + out: list[ResolveResult] = [] + for family, _type, proto, _canon, sockaddr in addrinfos: + if family == socket.AF_INET6: + if len(sockaddr) < 3: + continue + if sockaddr[3]: + resolved_host, port = socket.getnameinfo(sockaddr, _NAME_RESOLVE_FLAGS) + else: + resolved_host, port = sockaddr[:2] + else: + assert family == socket.AF_INET, family + resolved_host, port = sockaddr[0], sockaddr[1] + resolved_host = str(resolved_host) + port = int(port) + out.append( + ResolveResult( + hostname=host, + host=resolved_host, + port=int(port), + family=family, + proto=proto, + flags=_NUMERIC_RESOLVE_FLAGS, + ) + ) + return out + + +class _PinnedEgressStaticResolver(AbstractResolver): + """Return only pre-validated :class:`ResolveResult` for the outbound request.""" + + def __init__(self, results: list[ResolveResult]) -> None: + self._results = results + + async def resolve( + self, host: str, port: int = 0, family: int = socket.AF_INET + ) -> list[ResolveResult]: + return self._results + + async def close(self) -> None: # pragma: no cover - aiohttp contract + return + + +async def _read_aiohttp_body_capped( + response: aiohttp.ClientResponse, max_bytes: int +) -> bytes: + received = 0 + parts: list[bytes] = [] + async for chunk in response.content.iter_chunked(65_536): + if received >= max_bytes: + break + remaining = max_bytes - received + if len(chunk) <= remaining: + received += len(chunk) + parts.append(chunk) + else: + parts.append(chunk[:remaining]) + break + return b"".join(parts) + + +async def _drain_aiohttp_body_capped( + response: aiohttp.ClientResponse, max_bytes: int +) -> None: + if max_bytes <= 0: + return + received = 0 + async for chunk in response.content.iter_chunked(65_536): + received += len(chunk) + if received >= max_bytes: + break + + +async def _run_web_search(query: str) -> list[dict[str, str]]: + async with ( + httpx.AsyncClient( + timeout=_REQUEST_TIMEOUT_S, + follow_redirects=True, + headers=_WEB_TOOL_HTTP_HEADERS, + ) as client, + client.stream( + "GET", + "https://lite.duckduckgo.com/lite/", + params={"q": query}, + ) as response, + ): + response.raise_for_status() + body_bytes = await _read_response_body_capped( + response, constants._MAX_WEB_FETCH_RESPONSE_BYTES + ) + text = body_bytes.decode("utf-8", errors="replace") + parser = SearchResultParser() + parser.feed(text) + return parser.results[:_MAX_SEARCH_RESULTS] + + +async def _run_web_fetch(url: str, egress: WebFetchEgressPolicy) -> dict[str, str]: + """Fetch URL with manual redirects; each hop is DNS-pinned to validated addresses.""" + current_url = url + redirect_hops = 0 + timeout = ClientTimeout(total=_REQUEST_TIMEOUT_S) + + while True: + addr_infos = await asyncio.to_thread( + get_validated_stream_addrinfos_for_egress, current_url, egress + ) + host = urlparse(current_url).hostname or "" + results = getaddrinfo_rows_to_resolve_results(host, addr_infos) + resolver = _PinnedEgressStaticResolver(results) + connector = TCPConnector( + resolver=resolver, + force_close=True, + ) + try: + async with ( + ClientSession( + timeout=timeout, + headers=_WEB_TOOL_HTTP_HEADERS, + connector=connector, + ) as session, + session.get(current_url, allow_redirects=False) as response, + ): + if response.status in _WEB_FETCH_REDIRECT_STATUSES: + await _drain_aiohttp_body_capped( + response, _REDIRECT_RESPONSE_BODY_CAP_BYTES + ) + if redirect_hops >= constants._MAX_WEB_FETCH_REDIRECTS: + raise WebFetchEgressViolation( + "web_fetch exceeded maximum redirects " + f"({constants._MAX_WEB_FETCH_REDIRECTS})" + ) + location = response.headers.get("location") + if not location or not location.strip(): + raise WebFetchEgressViolation( + "web_fetch redirect response missing Location header" + ) + current_url = urljoin(str(response.url), location.strip()) + redirect_hops += 1 + continue + response.raise_for_status() + content_type = response.headers.get("content-type", "text/plain") + final_url = str(response.url) + encoding = response.get_encoding() or "utf-8" + body_bytes = await _read_aiohttp_body_capped( + response, constants._MAX_WEB_FETCH_RESPONSE_BYTES + ) + finally: + await connector.close() + + break + + text = body_bytes.decode(encoding, errors="replace") + title = final_url + data = text + if "html" in content_type.lower(): + parser = HTMLTextParser() + parser.feed(text) + title = parser.title or final_url + data = "\n".join(parser.text_parts) + return { + "url": final_url, + "title": title, + "media_type": "text/plain", + "data": data[:_MAX_FETCH_CHARS], + } diff --git a/api/web_tools/parsers.py b/api/web_tools/parsers.py new file mode 100644 index 0000000..198b412 --- /dev/null +++ b/api/web_tools/parsers.py @@ -0,0 +1,104 @@ +"""HTML parsing for web_search / web_fetch.""" + +from __future__ import annotations + +import html +import re +from html.parser import HTMLParser +from typing import Any +from urllib.parse import parse_qs, unquote, urlparse + + +class SearchResultParser(HTMLParser): + """DuckDuckGo lite HTML: extract result links and titles.""" + + def __init__(self) -> None: + super().__init__() + self.results: list[dict[str, str]] = [] + self._href: str | None = None + self._title_parts: list[str] = [] + + def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None: + if tag != "a": + return + href = dict(attrs).get("href") + if not href or "uddg=" not in href: + return + parsed = urlparse(href) + query = parse_qs(parsed.query) + uddg = query.get("uddg", [""])[0] + if not uddg: + return + self._href = unquote(uddg) + self._title_parts = [] + + def handle_data(self, data: str) -> None: + if self._href is not None: + self._title_parts.append(data) + + def handle_endtag(self, tag: str) -> None: + if tag != "a" or self._href is None: + return + title = " ".join("".join(self._title_parts).split()) + if title and not any(result["url"] == self._href for result in self.results): + self.results.append({"title": html.unescape(title), "url": self._href}) + self._href = None + self._title_parts = [] + + +class HTMLTextParser(HTMLParser): + """Strip scripts/styles and collect visible text + title for fetch previews.""" + + def __init__(self) -> None: + super().__init__() + self.title = "" + self.text_parts: list[str] = [] + self._in_title = False + self._skip_depth = 0 + + def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None: + if tag in {"script", "style", "noscript"}: + self._skip_depth += 1 + elif tag == "title": + self._in_title = True + + def handle_endtag(self, tag: str) -> None: + if tag in {"script", "style", "noscript"} and self._skip_depth: + self._skip_depth -= 1 + elif tag == "title": + self._in_title = False + + def handle_data(self, data: str) -> None: + text = " ".join(data.split()) + if not text: + return + if self._in_title: + self.title = f"{self.title} {text}".strip() + elif not self._skip_depth: + self.text_parts.append(text) + + +def content_text(content: Any) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [] + for item in content: + if isinstance(item, dict): + parts.append(str(item.get("text", ""))) + else: + parts.append(str(getattr(item, "text", ""))) + return "\n".join(part for part in parts if part) + return str(content) + + +def extract_query(text: str) -> str: + match = re.search(r"query:\s*(.+)", text, flags=re.IGNORECASE | re.DOTALL) + if match: + return match.group(1).strip().strip("\"'") + return text.strip() + + +def extract_url(text: str) -> str: + match = re.search(r"https?://\S+", text) + return match.group(0).rstrip(").,]") if match else text.strip() diff --git a/api/web_tools/request.py b/api/web_tools/request.py new file mode 100644 index 0000000..b794685 --- /dev/null +++ b/api/web_tools/request.py @@ -0,0 +1,87 @@ +"""Detect forced Anthropic web server tool requests.""" + +from __future__ import annotations + +from api.models.anthropic import MessagesRequest, Tool + + +def request_text(request: MessagesRequest) -> str: + """Join all user/assistant message content into one string for tool input parsing.""" + from .parsers import content_text + + return "\n".join(content_text(message.content) for message in request.messages) + + +def forced_tool_turn_text(request: MessagesRequest) -> str: + """Text for parsing forced server-tool inputs: latest user turn only (avoids stale history).""" + if not request.messages: + return "" + + from .parsers import content_text + + for message in reversed(request.messages): + if message.role == "user": + return content_text(message.content) + return "" + + +def forced_server_tool_name(request: MessagesRequest) -> str | None: + """Return web_search or web_fetch only when tool_choice forces that server tool.""" + tc = request.tool_choice + if not isinstance(tc, dict): + return None + if tc.get("type") != "tool": + return None + name = tc.get("name") + if name in {"web_search", "web_fetch"}: + return str(name) + return None + + +def has_tool_named(request: MessagesRequest, name: str) -> bool: + return any(tool.name == name for tool in request.tools or []) + + +def is_web_server_tool_request(request: MessagesRequest) -> bool: + """True when the client forces a web server tool via tool_choice (not merely listed).""" + forced = forced_server_tool_name(request) + if forced is None: + return False + return has_tool_named(request, forced) + + +def is_anthropic_server_tool_definition(tool: Tool) -> bool: + """Whether ``tool`` refers to an Anthropic server tool (web_search / web_fetch family).""" + name = (tool.name or "").strip() + if name in ("web_search", "web_fetch"): + return True + typ = tool.type + if isinstance(typ, str): + return typ.startswith("web_search") or typ.startswith("web_fetch") + return False + + +def has_listed_anthropic_server_tools(request: MessagesRequest) -> bool: + """True when tools include web_search / web_fetch-style entries (listed, forced or not).""" + return any(is_anthropic_server_tool_definition(t) for t in (request.tools or [])) + + +def openai_chat_upstream_server_tool_error( + request: MessagesRequest, *, web_tools_enabled: bool +) -> str | None: + """Return a user-facing error when OpenAI Chat upstream cannot satisfy server-tool semantics.""" + forced = forced_server_tool_name(request) + if forced and not web_tools_enabled: + return ( + f"tool_choice forces Anthropic server tool {forced!r}, but local web server tools are " + "disabled (ENABLE_WEB_SERVER_TOOLS=false). Enable them or use a native Anthropic " + "Messages transport (e.g. open_router, ollama, lmstudio)." + ) + if not forced and has_listed_anthropic_server_tools(request): + return ( + "OpenAI Chat upstreams (NVIDIA NIM, DeepSeek) cannot use listed Anthropic server tools " + "(web_search / web_fetch) without the local web server tool handler. Use a native " + "Anthropic transport, set ENABLE_WEB_SERVER_TOOLS=true and force the tool with " + "tool_choice, or remove these tools from the request." + ) + return None diff --git a/api/web_tools/streaming.py b/api/web_tools/streaming.py new file mode 100644 index 0000000..eb9b414 --- /dev/null +++ b/api/web_tools/streaming.py @@ -0,0 +1,206 @@ +"""SSE streaming for local web_search / web_fetch server tool results.""" + +from __future__ import annotations + +import uuid +from collections.abc import AsyncIterator +from datetime import UTC, datetime +from typing import Any + +from api.models.anthropic import MessagesRequest +from core.anthropic.server_tool_sse import ( + SERVER_TOOL_USE, + WEB_FETCH_TOOL_ERROR, + WEB_FETCH_TOOL_RESULT, + WEB_SEARCH_TOOL_RESULT, + WEB_SEARCH_TOOL_RESULT_ERROR, +) +from core.anthropic.sse import format_sse_event + +from . import outbound +from .constants import _MAX_FETCH_CHARS +from .egress import WebFetchEgressPolicy +from .parsers import extract_query, extract_url +from .request import ( + forced_server_tool_name, + forced_tool_turn_text, + has_tool_named, +) + + +def _search_summary(query: str, results: list[dict[str, str]]) -> str: + if not results: + return f"No web search results found for: {query}" + lines = [f"Search results for: {query}"] + for index, result in enumerate(results, start=1): + lines.append(f"{index}. {result['title']}\n{result['url']}") + return "\n\n".join(lines) + + +async def stream_web_server_tool_response( + request: MessagesRequest, + input_tokens: int, + *, + web_fetch_egress: WebFetchEgressPolicy, + verbose_client_errors: bool = False, +) -> AsyncIterator[str]: + """Stream a minimal Anthropic-shaped turn for forced `web_search` / `web_fetch` (local fallback). + + When `ENABLE_WEB_SERVER_TOOLS` is on, this is a proxy-side execution path — not a full + hosted Anthropic citation or encrypted-content pipeline. + """ + tool_name = forced_server_tool_name(request) + if tool_name is None or not has_tool_named(request, tool_name): + return + + text = forced_tool_turn_text(request) + message_id = f"msg_{uuid.uuid4()}" + tool_id = f"srvtoolu_{uuid.uuid4().hex}" + usage_key = ( + "web_search_requests" if tool_name == "web_search" else "web_fetch_requests" + ) + tool_input = ( + {"query": extract_query(text)} + if tool_name == "web_search" + else {"url": extract_url(text)} + ) + _result_block_for_tool = { + "web_search": WEB_SEARCH_TOOL_RESULT, + "web_fetch": WEB_FETCH_TOOL_RESULT, + } + _error_payload_type_for_tool = { + "web_search": WEB_SEARCH_TOOL_RESULT_ERROR, + "web_fetch": WEB_FETCH_TOOL_ERROR, + } + + yield format_sse_event( + "message_start", + { + "type": "message_start", + "message": { + "id": message_id, + "type": "message", + "role": "assistant", + "content": [], + "model": request.model, + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": input_tokens, "output_tokens": 1}, + }, + }, + ) + yield format_sse_event( + "content_block_start", + { + "type": "content_block_start", + "index": 0, + "content_block": { + "type": SERVER_TOOL_USE, + "id": tool_id, + "name": tool_name, + "input": tool_input, + }, + }, + ) + yield format_sse_event( + "content_block_stop", {"type": "content_block_stop", "index": 0} + ) + + try: + if tool_name == "web_search": + query = str(tool_input["query"]) + results = await outbound._run_web_search(query) + result_content: Any = [ + { + "type": "web_search_result", + "title": result["title"], + "url": result["url"], + } + for result in results + ] + summary = _search_summary(query, results) + result_block_type = WEB_SEARCH_TOOL_RESULT + else: + fetched = await outbound._run_web_fetch( + str(tool_input["url"]), web_fetch_egress + ) + result_content = { + "type": "web_fetch_result", + "url": fetched["url"], + "content": { + "type": "document", + "source": { + "type": "text", + "media_type": fetched["media_type"], + "data": fetched["data"], + }, + "title": fetched["title"], + "citations": {"enabled": True}, + }, + "retrieved_at": datetime.now(UTC).isoformat(), + } + summary = fetched["data"][:_MAX_FETCH_CHARS] + result_block_type = WEB_FETCH_TOOL_RESULT + except Exception as error: + fetch_url = str(tool_input["url"]) if tool_name == "web_fetch" else None + outbound._log_web_tool_failure(tool_name, error, fetch_url=fetch_url) + result_block_type = _result_block_for_tool[tool_name] + result_content = { + "type": _error_payload_type_for_tool[tool_name], + "error_code": "unavailable", + } + summary = outbound._web_tool_client_error_summary( + tool_name, error, verbose=verbose_client_errors + ) + + output_tokens = max(1, len(summary) // 4) + + yield format_sse_event( + "content_block_start", + { + "type": "content_block_start", + "index": 1, + "content_block": { + "type": result_block_type, + "tool_use_id": tool_id, + "content": result_content, + }, + }, + ) + yield format_sse_event( + "content_block_stop", {"type": "content_block_stop", "index": 1} + ) + # Model-facing summary: stream as normal text deltas (CLI/transcript code reads `text_delta`, + # not eager `text` on `content_block_start`). + yield format_sse_event( + "content_block_start", + { + "type": "content_block_start", + "index": 2, + "content_block": {"type": "text", "text": ""}, + }, + ) + yield format_sse_event( + "content_block_delta", + { + "type": "content_block_delta", + "index": 2, + "delta": {"type": "text_delta", "text": summary}, + }, + ) + yield format_sse_event( + "content_block_stop", {"type": "content_block_stop", "index": 2} + ) + yield format_sse_event( + "message_delta", + { + "type": "message_delta", + "delta": {"stop_reason": "end_turn", "stop_sequence": None}, + "usage": { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "server_tool_use": {usage_key: 1}, + }, + }, + ) + yield format_sse_event("message_stop", {"type": "message_stop"}) diff --git a/cli/entrypoints.py b/cli/entrypoints.py index ba61a9e..c161de9 100644 --- a/cli/entrypoints.py +++ b/cli/entrypoints.py @@ -30,7 +30,8 @@ def serve() -> None: settings = get_settings() try: uvicorn.run( - "api.app:app", + "api.app:create_app", + factory=True, host=settings.host, port=settings.port, log_level="debug", diff --git a/cli/manager.py b/cli/manager.py index b267633..cb27419 100644 --- a/cli/manager.py +++ b/cli/manager.py @@ -29,6 +29,9 @@ class CLISessionManager: allowed_dirs: list[str] | None = None, plans_directory: str | None = None, claude_bin: str = "claude", + *, + log_raw_cli_diagnostics: bool = False, + log_messaging_error_details: bool = False, ): """ Initialize the session manager. @@ -44,6 +47,8 @@ class CLISessionManager: self.allowed_dirs = allowed_dirs or [] self.plans_directory = plans_directory self.claude_bin = claude_bin + self._log_raw_cli_diagnostics = log_raw_cli_diagnostics + self._log_messaging_error_details = log_messaging_error_details self._sessions: dict[str, CLISession] = {} self._pending_sessions: dict[str, CLISession] = {} @@ -79,6 +84,7 @@ class CLISessionManager: allowed_dirs=self.allowed_dirs, plans_directory=self.plans_directory, claude_bin=self.claude_bin, + log_raw_cli_diagnostics=self._log_raw_cli_diagnostics, ) self._pending_sessions[temp_id] = new_session logger.info(f"Created new session: {temp_id}") @@ -130,7 +136,17 @@ class CLISessionManager: try: await session.stop() except Exception as e: - logger.error(f"Error stopping session: {e}") + if self._log_messaging_error_details: + logger.error( + "Error stopping session: {}: {}", + type(e).__name__, + e, + ) + else: + logger.error( + "Error stopping session: exc_type={}", + type(e).__name__, + ) self._sessions.clear() self._pending_sessions.clear() diff --git a/cli/session.py b/cli/session.py index db18f6c..d007975 100644 --- a/cli/session.py +++ b/cli/session.py @@ -11,6 +11,9 @@ from loguru import logger from .process_registry import register_pid, unregister_pid +# Cap stderr capture so a runaway child cannot exhaust memory; pipe is still drained. +_MAX_STDERR_CAPTURE_BYTES = 256 * 1024 + @dataclass(frozen=True, slots=True) class ClaudeCliConfig: @@ -33,6 +36,8 @@ class CLISession: allowed_dirs: list[str] | None = None, plans_directory: str | None = None, claude_bin: str = "claude", + *, + log_raw_cli_diagnostics: bool = False, ): self.config = ClaudeCliConfig( workspace_path=os.path.normpath(os.path.abspath(workspace_path)), @@ -46,11 +51,40 @@ class CLISession: self.allowed_dirs = self.config.allowed_dirs self.plans_directory = self.config.plans_directory self.claude_bin = self.config.claude_bin + self._log_raw_cli_diagnostics = log_raw_cli_diagnostics self.process: asyncio.subprocess.Process | None = None self.current_session_id: str | None = None self._is_busy = False self._cli_lock = asyncio.Lock() + @staticmethod + async def _drain_stderr_bounded( + process: asyncio.subprocess.Process, + *, + max_bytes: int = _MAX_STDERR_CAPTURE_BYTES, + ) -> bytes: + """Read stderr concurrently with stdout to avoid subprocess pipe deadlocks. + + Retains at most ``max_bytes`` for logging; any excess is discarded, but + the pipe is read until EOF so a noisy child cannot fill the buffer and + block forever. + """ + if not process.stderr: + return b"" + parts: list[bytes] = [] + received = 0 + while True: + chunk = await process.stderr.read(65_536) + if not chunk: + break + if received < max_bytes: + take = min(len(chunk), max_bytes - received) + if take: + parts.append(chunk[:take]) + received += take + # If already at cap, keep reading and discarding until EOF. + return b"".join(parts) + @property def is_busy(self) -> bool: """Check if a task is currently running.""" @@ -140,6 +174,11 @@ class CLISession: session_id_extracted = False buffer = bytearray() + stderr_task: asyncio.Task[bytes] | None = None + if self.process.stderr: + stderr_task = asyncio.create_task( + self._drain_stderr_bounded(self.process) + ) try: while True: @@ -179,23 +218,27 @@ class CLISession: except asyncio.CancelledError: # Cancelling the handler task should not leave a Claude CLI # subprocess running in the background. - try: - await asyncio.shield(self.stop()) - finally: - raise + await asyncio.shield(self.stop()) + raise + finally: + stderr_bytes = b"" + if stderr_task is not None: + stderr_bytes = await stderr_task stderr_text = None - if self.process.stderr: - stderr_output = await self.process.stderr.read() - if stderr_output: - stderr_text = stderr_output.decode( - "utf-8", errors="replace" - ).strip() - logger.error(f"Claude CLI Stderr: {stderr_text}") - # Yield stderr as error event so it shows in UI - if stderr_text: - logger.info("CLI_SESSION: Yielding error event from stderr") - yield {"type": "error", "error": {"message": stderr_text}} + if stderr_bytes: + stderr_text = stderr_bytes.decode("utf-8", errors="replace").strip() + if stderr_text: + if self._log_raw_cli_diagnostics: + logger.error("Claude CLI stderr: {}", stderr_text) + else: + logger.error( + "Claude CLI stderr: bytes={} text_chars={}", + len(stderr_bytes), + len(stderr_text), + ) + logger.info("CLI_SESSION: Yielding error event from stderr") + yield {"type": "error", "error": {"message": stderr_text}} return_code = await self.process.wait() logger.info( @@ -230,7 +273,10 @@ class CLISession: yield event except json.JSONDecodeError: - logger.debug(f"Non-JSON output: {line_str}") + if self._log_raw_cli_diagnostics: + logger.debug("Non-JSON output: {}", line_str) + else: + logger.debug("Non-JSON CLI line: char_len={}", len(line_str)) yield {"type": "raw", "content": line_str} def _extract_session_id(self, event: Any) -> str | None: @@ -273,6 +319,16 @@ class CLISession: unregister_pid(self.process.pid) return True except Exception as e: - logger.error(f"Error stopping process: {e}") + if self._log_raw_cli_diagnostics: + logger.error( + "Error stopping process: {}: {}", + type(e).__name__, + e, + ) + else: + logger.error( + "Error stopping process: exc_type={}", + type(e).__name__, + ) return False return False diff --git a/config/constants.py b/config/constants.py new file mode 100644 index 0000000..987027a --- /dev/null +++ b/config/constants.py @@ -0,0 +1,10 @@ +"""Shared defaults used by config models and provider adapters.""" + +# HTTP client connect timeout (seconds). Keep aligned with README.md and .env.example. +HTTP_CONNECT_TIMEOUT_DEFAULT = 10.0 + +# Anthropic Messages API default when the client omits max_tokens. +ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS = 81920 + +# Max bytes read from a non-200 native messages response when verbose error logging is on. +NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES = 4096 diff --git a/config/logging_config.py b/config/logging_config.py index e1a31e4..7c5d8fd 100644 --- a/config/logging_config.py +++ b/config/logging_config.py @@ -8,6 +8,7 @@ included at top level for easy grep/filter. import json import logging +import re from pathlib import Path from loguru import logger @@ -17,6 +18,22 @@ _configured = False # Context keys we promote to top-level JSON for traceability _CONTEXT_KEYS = ("request_id", "node_id", "chat_id") +_TELEGRAM_BOT_RE = re.compile( + r"(https?://api\.telegram\.org/)bot([0-9]+:[A-Za-z0-9_-]+)(/?)", + re.IGNORECASE, +) +# Authorization: Bearer (HTTP client / proxy debug lines) +_AUTH_BEARER_RE = re.compile( + r"(\bAuthorization\s*:\s*Bearer\s+)([^\s'\"]+)", + re.IGNORECASE, +) + + +def _redact_sensitive_substrings(message: str) -> str: + """Remove obvious API tokens and secrets before JSON log line emission.""" + text = _TELEGRAM_BOT_RE.sub(r"\1bot\3", message) + return _AUTH_BEARER_RE.sub(r"\1", text) + def _serialize_with_context(record) -> str: """Format record as JSON with context vars at top level. @@ -26,7 +43,7 @@ def _serialize_with_context(record) -> str: out = { "time": str(record["time"]), "level": record["level"].name, - "message": record["message"], + "message": _redact_sensitive_substrings(str(record["message"])), "module": record["name"], "function": record["function"], "line": record["line"], @@ -57,11 +74,16 @@ class InterceptHandler(logging.Handler): ) -def configure_logging(log_file: str, *, force: bool = False) -> None: +def configure_logging( + log_file: str, *, force: bool = False, verbose_third_party: bool = False +) -> None: """Configure loguru with JSON output to log_file and intercept stdlib logging. Idempotent: skips if already configured (e.g. hot reload). Use force=True to reconfigure (e.g. in tests with a different log path). + + When ``verbose_third_party`` is false, noisy HTTP and Telegram loggers are capped + at WARNING unless explicitly configured otherwise. """ global _configured if _configured and not force: @@ -88,3 +110,16 @@ def configure_logging(log_file: str, *, force: bool = False) -> None: intercept = InterceptHandler() logging.root.handlers = [intercept] logging.root.setLevel(logging.DEBUG) + + third_party = ( + "httpx", + "httpcore", + "httpcore.http11", + "httpcore.connection", + "telegram", + "telegram.ext", + ) + for name in third_party: + logging.getLogger(name).setLevel( + logging.WARNING if not verbose_third_party else logging.NOTSET + ) diff --git a/config/nim.py b/config/nim.py index b3a9c3b..5bf7761 100644 --- a/config/nim.py +++ b/config/nim.py @@ -2,6 +2,8 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator +from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS + class NimSettings(BaseModel): """Fixed NVIDIA NIM settings (not configurable via env).""" @@ -14,7 +16,9 @@ class NimSettings(BaseModel): ) top_k: int = -1 max_tokens: int = Field( - 81920, ge=1, description="Maximum number of tokens in output." + ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS, + ge=1, + description="Maximum number of tokens in output.", ) presence_penalty: float = Field(0.0, ge=-2.0, le=2.0) frequency_penalty: float = Field(0.0, ge=-2.0, le=2.0) @@ -68,7 +72,7 @@ class NimSettings(BaseModel): return field_defaults.get(key, 1.0) try: val = float(v) - except Exception as err: + except (TypeError, ValueError) as err: raise ValueError( f"{info.field_name} must be a float. Got {type(v).__name__}." ) from err @@ -78,15 +82,15 @@ class NimSettings(BaseModel): @classmethod def validate_int_fields(cls, v, info: ValidationInfo): field_defaults = { - "max_tokens": 81920, + "max_tokens": ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS, "min_tokens": 0, } if v is None or v == "": key = info.field_name or "max_tokens" - return field_defaults.get(key, 81920) + return field_defaults.get(key, ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS) try: val = int(v) - except Exception as err: + except (TypeError, ValueError) as err: raise ValueError( f"{info.field_name} must be an int. Got {type(v).__name__}." ) from err @@ -99,7 +103,7 @@ class NimSettings(BaseModel): return None try: return int(v) - except Exception as err: + except (TypeError, ValueError) as err: raise ValueError( f"{info.field_name} must be an int or empty/None." ) from err diff --git a/config/provider_catalog.py b/config/provider_catalog.py new file mode 100644 index 0000000..b0b731e --- /dev/null +++ b/config/provider_catalog.py @@ -0,0 +1,108 @@ +"""Neutral provider catalog: IDs, credentials, defaults, proxy and capability metadata. + +Adapter factories live in :mod:`providers.registry`; this module stays free of +provider implementation imports (see contract tests). +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +TransportType = Literal["openai_chat", "anthropic_messages"] + +# Default upstream base URLs (also re-exported via :mod:`providers.defaults`) +NVIDIA_NIM_DEFAULT_BASE = "https://integrate.api.nvidia.com/v1" +DEEPSEEK_DEFAULT_BASE = "https://api.deepseek.com" +OPENROUTER_DEFAULT_BASE = "https://openrouter.ai/api/v1" +LMSTUDIO_DEFAULT_BASE = "http://localhost:1234/v1" +LLAMACPP_DEFAULT_BASE = "http://localhost:8080/v1" +OLLAMA_DEFAULT_BASE = "http://localhost:11434" + + +@dataclass(frozen=True, slots=True) +class ProviderDescriptor: + """Metadata for building :class:`~providers.base.ProviderConfig` and factory wiring.""" + + provider_id: str + transport_type: TransportType + capabilities: tuple[str, ...] + credential_env: str | None = None + credential_url: str | None = None + credential_attr: str | None = None + static_credential: str | None = None + default_base_url: str | None = None + base_url_attr: str | None = None + proxy_attr: str | None = None + + +PROVIDER_CATALOG: dict[str, ProviderDescriptor] = { + "nvidia_nim": ProviderDescriptor( + provider_id="nvidia_nim", + transport_type="openai_chat", + credential_env="NVIDIA_NIM_API_KEY", + credential_url="https://build.nvidia.com/settings/api-keys", + credential_attr="nvidia_nim_api_key", + default_base_url=NVIDIA_NIM_DEFAULT_BASE, + proxy_attr="nvidia_nim_proxy", + capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"), + ), + "open_router": ProviderDescriptor( + provider_id="open_router", + transport_type="anthropic_messages", + credential_env="OPENROUTER_API_KEY", + credential_url="https://openrouter.ai/keys", + credential_attr="open_router_api_key", + default_base_url=OPENROUTER_DEFAULT_BASE, + proxy_attr="open_router_proxy", + capabilities=("chat", "streaming", "tools", "thinking", "native_anthropic"), + ), + "deepseek": ProviderDescriptor( + provider_id="deepseek", + transport_type="openai_chat", + credential_env="DEEPSEEK_API_KEY", + credential_url="https://platform.deepseek.com/api_keys", + credential_attr="deepseek_api_key", + default_base_url=DEEPSEEK_DEFAULT_BASE, + capabilities=("chat", "streaming", "thinking"), + ), + "lmstudio": ProviderDescriptor( + provider_id="lmstudio", + transport_type="anthropic_messages", + static_credential="lm-studio", + default_base_url=LMSTUDIO_DEFAULT_BASE, + base_url_attr="lm_studio_base_url", + proxy_attr="lmstudio_proxy", + capabilities=("chat", "streaming", "tools", "native_anthropic", "local"), + ), + "llamacpp": ProviderDescriptor( + provider_id="llamacpp", + transport_type="anthropic_messages", + static_credential="llamacpp", + default_base_url=LLAMACPP_DEFAULT_BASE, + base_url_attr="llamacpp_base_url", + proxy_attr="llamacpp_proxy", + capabilities=("chat", "streaming", "tools", "native_anthropic", "local"), + ), + "ollama": ProviderDescriptor( + provider_id="ollama", + transport_type="anthropic_messages", + static_credential="ollama", + default_base_url=OLLAMA_DEFAULT_BASE, + base_url_attr="ollama_base_url", + capabilities=( + "chat", + "streaming", + "tools", + "thinking", + "native_anthropic", + "local", + ), + ), +} + +# Order matches docs / historical error text; must match PROVIDER_CATALOG keys. +SUPPORTED_PROVIDER_IDS: tuple[str, ...] = tuple(PROVIDER_CATALOG.keys()) + +if len(set(SUPPORTED_PROVIDER_IDS)) != len(SUPPORTED_PROVIDER_IDS): + raise AssertionError("Duplicate provider ids in PROVIDER_CATALOG key order") diff --git a/config/provider_ids.py b/config/provider_ids.py index e3e50bf..fd08ab7 100644 --- a/config/provider_ids.py +++ b/config/provider_ids.py @@ -1,18 +1,7 @@ -"""Canonical list of model provider type prefixes (provider_id values). - -`providers.registry.PROVIDER_DESCRIPTORS` is the full metadata source; this -module holds the id set for config validation and must stay in sync -(registries assert in `providers.registry`). -""" +"""Canonical provider id tuple (re-exported from the provider catalog).""" from __future__ import annotations -# Order matches docs / historical error text; must match PROVIDER_DESCRIPTORS keys. -SUPPORTED_PROVIDER_IDS: tuple[str, ...] = ( - "nvidia_nim", - "open_router", - "deepseek", - "lmstudio", - "llamacpp", - "ollama", -) +from .provider_catalog import SUPPORTED_PROVIDER_IDS + +__all__ = ("SUPPORTED_PROVIDER_IDS",) diff --git a/config/settings.py b/config/settings.py index f7129fb..c4afb1b 100644 --- a/config/settings.py +++ b/config/settings.py @@ -10,6 +10,7 @@ from dotenv import dotenv_values from pydantic import Field, field_validator, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict +from .constants import HTTP_CONNECT_TIMEOUT_DEFAULT from .nim import NimSettings from .provider_ids import SUPPORTED_PROVIDER_IDS @@ -105,6 +106,12 @@ class Settings(BaseSettings): messaging_platform: str = Field( default="discord", validation_alias="MESSAGING_PLATFORM" ) + messaging_rate_limit: int = Field( + default=1, validation_alias="MESSAGING_RATE_LIMIT" + ) + messaging_rate_window: float = Field( + default=1.0, validation_alias="MESSAGING_RATE_WINDOW" + ) # ==================== NVIDIA NIM Config ==================== nvidia_nim_api_key: str = "" @@ -173,7 +180,8 @@ class Settings(BaseSettings): default=10.0, validation_alias="HTTP_WRITE_TIMEOUT" ) http_connect_timeout: float = Field( - default=2.0, validation_alias="HTTP_CONNECT_TIMEOUT" + default=HTTP_CONNECT_TIMEOUT_DEFAULT, + validation_alias="HTTP_CONNECT_TIMEOUT", ) # ==================== Fast Prefix Detection ==================== @@ -185,6 +193,51 @@ class Settings(BaseSettings): enable_suggestion_mode_skip: bool = True enable_filepath_extraction_mock: bool = True + # ==================== Local web server tools (web_search / web_fetch) ==================== + # Off by default: these tools perform outbound HTTP from the proxy (SSRF risk). + enable_web_server_tools: bool = Field( + default=False, validation_alias="ENABLE_WEB_SERVER_TOOLS" + ) + # Comma-separated URL schemes allowed for web_fetch (default: http,https). + web_fetch_allowed_schemes: str = Field( + default="http,https", validation_alias="WEB_FETCH_ALLOWED_SCHEMES" + ) + # When true, skip private/loopback/link-local IP blocking for web_fetch (lab only). + web_fetch_allow_private_networks: bool = Field( + default=False, validation_alias="WEB_FETCH_ALLOW_PRIVATE_NETWORKS" + ) + + # ==================== Debug / diagnostic logging (avoid sensitive content) ==================== + # When false (default), API and SSE helpers log only metadata (counts, lengths, ids). + log_raw_api_payloads: bool = Field( + default=False, validation_alias="LOG_RAW_API_PAYLOADS" + ) + log_raw_sse_events: bool = Field( + default=False, validation_alias="LOG_RAW_SSE_EVENTS" + ) + # When false (default), unhandled exceptions log only type + route metadata (no message/traceback). + log_api_error_tracebacks: bool = Field( + default=False, validation_alias="LOG_API_ERROR_TRACEBACKS" + ) + # When false (default), messaging logs omit text/transcription previews (metadata only). + log_raw_messaging_content: bool = Field( + default=False, validation_alias="LOG_RAW_MESSAGING_CONTENT" + ) + # When true, log full Claude CLI stderr, non-JSON lines, and parser error text. + log_raw_cli_diagnostics: bool = Field( + default=False, validation_alias="LOG_RAW_CLI_DIAGNOSTICS" + ) + # When true, log exception text / CLI error strings in messaging (may leak user content). + log_messaging_error_details: bool = Field( + default=False, validation_alias="LOG_MESSAGING_ERROR_DETAILS" + ) + debug_platform_edits: bool = Field( + default=False, validation_alias="DEBUG_PLATFORM_EDITS" + ) + debug_subagent_stack: bool = Field( + default=False, validation_alias="DEBUG_SUBAGENT_STACK" + ) + # ==================== NIM Settings ==================== nim: NimSettings = Field(default_factory=NimSettings) @@ -215,6 +268,9 @@ class Settings(BaseSettings): claude_workspace: str = "./agent_workspace" allowed_dir: str = "" claude_cli_bin: str = Field(default="claude", validation_alias="CLAUDE_CLI_BIN") + max_message_log_entries_per_chat: int | None = Field( + default=None, validation_alias="MAX_MESSAGE_LOG_ENTRIES_PER_CHAT" + ) # ==================== Server ==================== host: str = "0.0.0.0" @@ -254,6 +310,13 @@ class Settings(BaseSettings): return None return v + @field_validator("max_message_log_entries_per_chat", mode="before") + @classmethod + def parse_optional_log_cap(cls, v: Any) -> Any: + if v == "" or v is None: + return None + return v + @field_validator("whisper_device") @classmethod def validate_whisper_device(cls, v: str) -> str: @@ -272,6 +335,33 @@ class Settings(BaseSettings): ) return v + @field_validator("messaging_rate_limit") + @classmethod + def validate_messaging_rate_limit(cls, v: int) -> int: + if v <= 0: + raise ValueError("messaging_rate_limit must be > 0") + return v + + @field_validator("messaging_rate_window") + @classmethod + def validate_messaging_rate_window(cls, v: float) -> float: + if v <= 0: + raise ValueError("messaging_rate_window must be > 0") + return float(v) + + @field_validator("web_fetch_allowed_schemes") + @classmethod + def validate_web_fetch_allowed_schemes(cls, v: str) -> str: + schemes = [part.strip().lower() for part in v.split(",") if part.strip()] + if not schemes: + raise ValueError("web_fetch_allowed_schemes must list at least one scheme") + for scheme in schemes: + if not scheme.isascii() or not scheme.isalpha(): + raise ValueError( + f"Invalid URL scheme in web_fetch_allowed_schemes: {scheme!r}" + ) + return ",".join(schemes) + @field_validator("ollama_base_url") @classmethod def validate_ollama_base_url(cls, v: str) -> str: @@ -329,12 +419,12 @@ class Settings(BaseSettings): @property def provider_type(self) -> str: """Extract provider type from the default model string.""" - return self.model.split("/", 1)[0] + return Settings.parse_provider_type(self.model) @property def model_name(self) -> str: """Extract the actual model name from the default model string.""" - return self.model.split("/", 1)[1] + return Settings.parse_model_name(self.model) def resolve_model(self, claude_model_name: str) -> str: """Resolve a Claude model name to the configured provider/model string. @@ -362,6 +452,14 @@ class Settings(BaseSettings): return self.enable_sonnet_thinking return self.enable_model_thinking + def web_fetch_allowed_scheme_set(self) -> frozenset[str]: + """Return normalized schemes allowed for web_fetch.""" + return frozenset( + part.strip().lower() + for part in self.web_fetch_allowed_schemes.split(",") + if part.strip() + ) + @staticmethod def parse_provider_type(model_string: str) -> str: """Extract provider type from any 'provider/model' string.""" diff --git a/core/anthropic/__init__.py b/core/anthropic/__init__.py index 9af5987..eaea8a1 100644 --- a/core/anthropic/__init__.py +++ b/core/anthropic/__init__.py @@ -1,9 +1,19 @@ """Anthropic protocol helpers shared across API, providers, and integrations.""" from .content import extract_text_from_content, get_block_attr, get_block_type -from .conversion import AnthropicToOpenAIConverter, build_base_request_body -from .errors import append_request_id, get_user_facing_error_message -from .sse import ContentBlockManager, SSEBuilder, map_stop_reason +from .conversion import ( + AnthropicToOpenAIConverter, + OpenAIConversionError, + build_base_request_body, +) +from .errors import ( + append_request_id, + format_user_error_preview, + get_user_facing_error_message, +) +from .native_messages_request import sanitize_native_messages_thinking_policy +from .provider_stream_error import iter_provider_stream_error_sse_events +from .sse import ContentBlockManager, SSEBuilder, format_sse_event, map_stop_reason from .thinking import ContentChunk, ContentType, ThinkTagParser from .tokens import get_token_count from .tools import HeuristicToolParser @@ -15,15 +25,20 @@ __all__ = [ "ContentChunk", "ContentType", "HeuristicToolParser", + "OpenAIConversionError", "SSEBuilder", "ThinkTagParser", "append_request_id", "build_base_request_body", "extract_text_from_content", + "format_sse_event", + "format_user_error_preview", "get_block_attr", "get_block_type", "get_token_count", "get_user_facing_error_message", + "iter_provider_stream_error_sse_events", "map_stop_reason", + "sanitize_native_messages_thinking_policy", "set_if_not_none", ] diff --git a/core/anthropic/conversion.py b/core/anthropic/conversion.py index 5979ba1..c484ceb 100644 --- a/core/anthropic/conversion.py +++ b/core/anthropic/conversion.py @@ -3,10 +3,34 @@ import json from typing import Any +from pydantic import BaseModel + from .content import get_block_attr, get_block_type from .utils import set_if_not_none +class OpenAIConversionError(Exception): + """Raised when Anthropic content cannot be converted to OpenAI chat without data loss.""" + + +def _openai_reject_native_only_top_level_fields(request_data: Any) -> None: + """OpenAI chat providers may only convert known top-level request fields. + + First-class model fields (e.g. ``context_management``) are not forwarded to + the OpenAI API but are allowed so clients do not hit spurious 400s. + Unknown extra keys (``__pydantic_extra__``) are still rejected. + """ + if not isinstance(request_data, BaseModel): + return + extra = getattr(request_data, "__pydantic_extra__", None) + if not extra: + return + raise OpenAIConversionError( + "OpenAI chat conversion does not support these top-level request fields: " + f"{sorted(str(k) for k in extra)}. Use a native Anthropic transport provider." + ) + + def _tool_name(tool: Any) -> str: return str(getattr(tool, "name", "") or "") @@ -18,6 +42,27 @@ def _tool_input_schema(tool: Any) -> dict[str, Any]: return {"type": "object", "properties": {}} +def _serialize_tool_result_content(tool_content: Any) -> str: + """Serialize tool_result content for OpenAI ``role: tool`` messages (stable JSON for structured values).""" + if tool_content is None: + return "" + if isinstance(tool_content, str): + return tool_content + if isinstance(tool_content, dict): + return json.dumps(tool_content, ensure_ascii=False) + if isinstance(tool_content, list): + parts: list[str] = [] + for item in tool_content: + if isinstance(item, dict) and item.get("type") == "text": + parts.append(str(item.get("text", ""))) + elif isinstance(item, dict): + parts.append(json.dumps(item, ensure_ascii=False)) + else: + parts.append(str(item)) + return "\n".join(parts) + return str(tool_content) + + class AnthropicToOpenAIConverter: """Convert Anthropic message format to OpenAI-compatible format.""" @@ -64,20 +109,38 @@ class AnthropicToOpenAIConverter: content_parts: list[str] = [] thinking_parts: list[str] = [] tool_calls: list[dict[str, Any]] = [] + seen_tool_use = False for block in content: block_type = get_block_type(block) if block_type == "text": + if seen_tool_use: + raise OpenAIConversionError( + "OpenAI chat conversion does not support assistant text after " + "tool_use in the same message; split the transcript or use a " + "native Anthropic provider." + ) content_parts.append(get_block_attr(block, "text", "")) elif block_type == "thinking": if not include_thinking: continue + if seen_tool_use: + raise OpenAIConversionError( + "OpenAI chat conversion does not support assistant thinking after " + "tool_use in the same message; split the transcript or use a " + "native Anthropic provider." + ) thinking = get_block_attr(block, "thinking", "") content_parts.append(f"\n{thinking}\n") if include_reasoning_content: thinking_parts.append(thinking) + elif block_type == "redacted_thinking": + # Opaque provider continuation data; do not materialize as model-visible text + # or reasoning_content for OpenAI chat upstreams. + continue elif block_type == "tool_use": + seen_tool_use = True tool_input = get_block_attr(block, "input", {}) tool_calls.append( { @@ -91,6 +154,19 @@ class AnthropicToOpenAIConverter: }, } ) + elif block_type == "image": + raise OpenAIConversionError( + "Assistant image blocks are not supported for OpenAI chat conversion." + ) + elif block_type in ( + "server_tool_use", + "web_search_tool_result", + "web_fetch_tool_result", + ): + raise OpenAIConversionError( + "OpenAI chat conversion does not support Anthropic server tool blocks " + f"({block_type!r} in an assistant message). Use a native Anthropic transport provider." + ) content_str = "\n\n".join(content_parts) if not content_str and not tool_calls: @@ -122,21 +198,21 @@ class AnthropicToOpenAIConverter: if block_type == "text": text_parts.append(get_block_attr(block, "text", "")) + elif block_type == "image": + raise OpenAIConversionError( + "User message image blocks are not supported for OpenAI chat " + "conversion; use a vision-capable native Anthropic provider or " + "extend the converter." + ) elif block_type == "tool_result": flush_text() tool_content = get_block_attr(block, "content", "") - if isinstance(tool_content, list): - tool_content = "\n".join( - item.get("text", str(item)) - if isinstance(item, dict) - else str(item) - for item in tool_content - ) + serialized = _serialize_tool_result_content(tool_content) result.append( { "role": "tool", "tool_call_id": get_block_attr(block, "tool_use_id"), - "content": str(tool_content) if tool_content else "", + "content": serialized if serialized else "", } ) @@ -199,6 +275,7 @@ def build_base_request_body( include_reasoning_content: bool = False, ) -> dict[str, Any]: """Build the common parts of an OpenAI-format request body.""" + _openai_reject_native_only_top_level_fields(request_data) messages = AnthropicToOpenAIConverter.convert_messages( request_data.messages, include_thinking=include_thinking, diff --git a/core/anthropic/emitted_sse_tracker.py b/core/anthropic/emitted_sse_tracker.py new file mode 100644 index 0000000..079dd51 --- /dev/null +++ b/core/anthropic/emitted_sse_tracker.py @@ -0,0 +1,97 @@ +"""Track content-block state for native Anthropic SSE strings we emit to clients.""" + +from __future__ import annotations + +import uuid +from collections.abc import Iterator +from contextlib import suppress +from typing import Any + +from core.anthropic.sse import SSEBuilder, format_sse_event +from core.anthropic.stream_contracts import SSEEvent, event_index, parse_sse_lines + + +class EmittedNativeSseTracker: + """Parse emitted SSE frames so mid-stream errors can close blocks and pick a fresh index.""" + + def __init__(self) -> None: + self._buf = "" + self._open_stack: list[int] = [] + self._max_index = -1 + self.message_id: str | None = None + self.model: str = "" + + def feed(self, chunk: str) -> None: + """Record SSE frames completed by ``chunk`` (handles splitting across reads).""" + self._buf += chunk + while True: + sep = self._buf.find("\n\n") + if sep < 0: + break + frame = self._buf[:sep] + self._buf = self._buf[sep + 2 :] + if not frame.strip(): + continue + for event in parse_sse_lines(frame.splitlines()): + self._observe(event) + + def _observe(self, event: SSEEvent) -> None: + if event.event == "message_start": + message = event.data.get("message") + if isinstance(message, dict): + mid = message.get("id") + if isinstance(mid, str) and mid: + self.message_id = mid + model = message.get("model") + if isinstance(model, str) and model: + self.model = model + return + + if event.event == "content_block_start": + idx = event_index(event) + self._max_index = max(self._max_index, idx) + self._open_stack.append(idx) + return + + if event.event == "content_block_stop": + idx = event_index(event) + if self._open_stack and self._open_stack[-1] == idx: + self._open_stack.pop() + else: + with suppress(ValueError): + self._open_stack.remove(idx) + + def next_content_index(self) -> int: + """Next unused content block index based on emitted starts.""" + return self._max_index + 1 + + def iter_close_unclosed_blocks(self) -> Iterator[str]: + """Yield ``content_block_stop`` events for blocks that were started but not stopped.""" + while self._open_stack: + idx = self._open_stack.pop() + yield format_sse_event( + "content_block_stop", + {"type": "content_block_stop", "index": idx}, + ) + + def iter_midstream_error_tail( + self, + error_message: str, + *, + request: Any, + input_tokens: int, + log_raw_sse_events: bool, + ) -> Iterator[str]: + """Close dangling blocks, emit a text error block at a fresh index, then message tail.""" + mid = self.message_id or f"msg_{uuid.uuid4()}" + model = self.model or (getattr(request, "model", "") or "") + sse = SSEBuilder( + mid, + model, + input_tokens, + log_raw_events=log_raw_sse_events, + ) + sse.blocks.next_index = self.next_content_index() + yield from sse.emit_error(error_message) + yield sse.message_delta("end_turn", 1) + yield sse.message_stop() diff --git a/core/anthropic/errors.py b/core/anthropic/errors.py index 38e3d12..afe093d 100644 --- a/core/anthropic/errors.py +++ b/core/anthropic/errors.py @@ -9,11 +9,12 @@ def get_user_facing_error_message( *, read_timeout_s: float | None = None, ) -> str: - """Return a readable, non-empty error message for users.""" - message = str(e).strip() - if message: - return message + """Return a readable, non-empty error message for users. + Known transport and OpenAI SDK exception types are mapped to stable wording + before falling back to ``str(e)``, so empty or noisy SDK messages do not skip + the mapped path. + """ if isinstance(e, httpx.ReadTimeout): if read_timeout_s is not None: return f"Provider request timed out after {read_timeout_s:g}s." @@ -25,13 +26,20 @@ def get_user_facing_error_message( return f"Provider request timed out after {read_timeout_s:g}s." return "Request timed out." + if isinstance(e, openai.RateLimitError): + return "Provider rate limit reached. Please retry shortly." + if isinstance(e, openai.AuthenticationError): + return "Provider authentication failed. Check API key." + if isinstance(e, openai.BadRequestError): + return "Invalid request sent to provider." + name = type(e).__name__ status_code = getattr(e, "status_code", None) - if isinstance(e, openai.RateLimitError) or name == "RateLimitError": + if name == "RateLimitError": return "Provider rate limit reached. Please retry shortly." - if isinstance(e, openai.AuthenticationError) or name == "AuthenticationError": + if name == "AuthenticationError": return "Provider authentication failed. Check API key." - if isinstance(e, openai.BadRequestError) or name == "InvalidRequestError": + if name == "InvalidRequestError": return "Invalid request sent to provider." if name == "OverloadedError": return "Provider is currently overloaded. Please retry." @@ -42,9 +50,18 @@ def get_user_facing_error_message( if name.endswith("ProviderError") or name == "ProviderError": return "Provider request failed." + message = str(e).strip() + if message: + return message + return "Provider request failed unexpectedly." +def format_user_error_preview(exc: Exception, *, max_len: int = 200) -> str: + """Truncate a user-facing error string for short chat replies.""" + return get_user_facing_error_message(exc)[:max_len] + + def append_request_id(message: str, request_id: str | None) -> str: """Append request_id suffix when available.""" base = message.strip() or "Provider request failed unexpectedly." diff --git a/core/anthropic/native_messages_request.py b/core/anthropic/native_messages_request.py new file mode 100644 index 0000000..ed24c6c --- /dev/null +++ b/core/anthropic/native_messages_request.py @@ -0,0 +1,260 @@ +"""Native Anthropic Messages request body construction (JSON-ready dicts). + +Provider adapters supply policy via parameters (defaults, OpenRouter post-steps). +""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + +from pydantic import BaseModel + +_REQUEST_FIELDS = ( + "model", + "messages", + "system", + "max_tokens", + "stop_sequences", + "stream", + "temperature", + "top_p", + "top_k", + "metadata", + "tools", + "tool_choice", + "thinking", + "context_management", + "output_config", + "mcp_servers", + "extra_body", +) + +# Keys that would override routed canonical request fields if merged from ``extra_body``. +_OPENROUTER_EXTRA_BODY_FORBIDDEN_KEYS = frozenset( + { + "model", + "messages", + "system", + "tools", + "tool_choice", + "stream", + "max_tokens", + "temperature", + "top_p", + "top_k", + "metadata", + "stop_sequences", + "context_management", + "output_config", + "mcp_servers", + } +) + + +class OpenRouterExtraBodyError(ValueError): + """``extra_body`` contained reserved keys that would override canonical fields.""" + + +def validate_openrouter_extra_body(extra: Any) -> None: + """Reject ``extra_body`` keys that must not override routed request fields.""" + if not isinstance(extra, dict) or not extra: + return + bad = _OPENROUTER_EXTRA_BODY_FORBIDDEN_KEYS & extra.keys() + if bad: + raise OpenRouterExtraBodyError( + f"extra_body must not override canonical request fields: {sorted(bad)}" + ) + + +_INTERNAL_FIELDS = { + "thinking", + "extra_body", +} + + +def _serialize_value(value: Any) -> Any: + """Convert Pydantic models and lightweight objects into JSON-ready values.""" + if isinstance(value, BaseModel): + return value.model_dump(exclude_none=True) + if isinstance(value, dict): + return { + key: _serialize_value(item) + for key, item in value.items() + if item is not None + } + if isinstance(value, Sequence) and not isinstance(value, str | bytes | bytearray): + return [_serialize_value(item) for item in value] + if value is None or isinstance(value, str | int | float | bool): + return value + if hasattr(value, "__dict__"): + return { + key: _serialize_value(item) + for key, item in vars(value).items() + if not key.startswith("_") and item is not None + } + return value + + +def _dump_request_fields(request_data: Any) -> dict[str, Any]: + """Extract the public request fields (OpenRouter-style explicit field list).""" + if isinstance(request_data, BaseModel): + return request_data.model_dump(exclude_none=True) + + dumped: dict[str, Any] = {} + for field in _REQUEST_FIELDS: + value = getattr(request_data, field, None) + if value is not None: + dumped[field] = _serialize_value(value) + return dumped + + +def sanitize_native_messages_thinking_policy( + messages: Any, *, thinking_enabled: bool +) -> Any: + """Filter assistant message thinking blocks for upstream native Anthropic JSON. + + When ``thinking_enabled`` is false, remove ``thinking`` and ``redacted_thinking`` + history so disabled policy is not undermined by prior turns. + + When true, keep ``redacted_thinking`` and signed ``thinking``; remove only + unsigned plain ``thinking`` blocks (not replayable). + """ + if not isinstance(messages, list): + return messages + + sanitized_messages: list[Any] = [] + for message in messages: + if not isinstance(message, dict): + sanitized_messages.append(message) + continue + + if message.get("role") != "assistant": + sanitized_messages.append(message) + continue + + content = message.get("content") + if not isinstance(content, list): + sanitized_messages.append(message) + continue + + if not thinking_enabled: + sanitized_content = [ + block + for block in content + if not ( + isinstance(block, dict) + and block.get("type") in ("thinking", "redacted_thinking") + ) + ] + else: + sanitized_content = [ + block + for block in content + if not ( + isinstance(block, dict) + and block.get("type") == "thinking" + and not isinstance(block.get("signature"), str) + ) + ] + + sanitized_message = dict(message) + sanitized_message["content"] = sanitized_content or "" + sanitized_messages.append(sanitized_message) + + return sanitized_messages + + +def _normalize_system_prompt_for_openrouter(system: Any) -> Any: + """Flatten Claude SDK system blocks for OpenRouter's native endpoint.""" + if not isinstance(system, list): + return system + + text_parts: list[str] = [] + for block in system: + if not isinstance(block, dict): + continue + if block.get("type") == "text" and isinstance(block.get("text"), str): + text_parts.append(block["text"]) + return "\n\n".join(text_parts).strip() if text_parts else system + + +def _apply_openrouter_reasoning_policy(body: dict[str, Any], thinking_cfg: Any) -> None: + """Map Anthropic thinking controls onto OpenRouter reasoning controls.""" + reasoning = body.setdefault("reasoning", {"enabled": True}) + if not isinstance(reasoning, dict): + return + reasoning.setdefault("enabled", True) + if not isinstance(thinking_cfg, dict): + return + budget_tokens = thinking_cfg.get("budget_tokens") + if isinstance(budget_tokens, int): + reasoning.setdefault("max_tokens", budget_tokens) + + +def build_base_native_anthropic_request_body( + request: Any, + *, + default_max_tokens: int, + thinking_enabled: bool, +) -> dict[str, Any]: + """Serialize a Pydantic messages request to a generic native Anthropic body.""" + body = request.model_dump(exclude_none=True) + + body.pop("extra_body", None) + + if "thinking" in body: + thinking_cfg = body.pop("thinking") + if thinking_enabled and isinstance(thinking_cfg, dict): + thinking_payload: dict[str, Any] = {"type": "enabled"} + budget_tokens = thinking_cfg.get("budget_tokens") + if isinstance(budget_tokens, int): + thinking_payload["budget_tokens"] = budget_tokens + body["thinking"] = thinking_payload + + if "max_tokens" not in body: + body["max_tokens"] = default_max_tokens + + if "messages" in body: + body["messages"] = sanitize_native_messages_thinking_policy( + body["messages"], + thinking_enabled=thinking_enabled, + ) + + return body + + +def build_openrouter_native_request_body( + request_data: Any, + *, + thinking_enabled: bool, + default_max_tokens: int, +) -> dict[str, Any]: + """Build an Anthropic-format request body for OpenRouter (policy hooks built-in).""" + dumped_request = _dump_request_fields(request_data) + request_extra = dumped_request.pop("extra_body", None) + thinking_cfg = dumped_request.get("thinking") + body: dict[str, Any] = { + key: value + for key, value in dumped_request.items() + if key not in _INTERNAL_FIELDS + } + + if isinstance(request_extra, dict): + validate_openrouter_extra_body(request_extra) + body.update(request_extra) + + body["messages"] = sanitize_native_messages_thinking_policy( + body.get("messages"), + thinking_enabled=thinking_enabled, + ) + if "system" in body: + body["system"] = _normalize_system_prompt_for_openrouter(body["system"]) + body["stream"] = True + if body.get("max_tokens") is None: + body["max_tokens"] = default_max_tokens + + if thinking_enabled: + _apply_openrouter_reasoning_policy(body, thinking_cfg) + + return body diff --git a/core/anthropic/native_sse_block_policy.py b/core/anthropic/native_sse_block_policy.py new file mode 100644 index 0000000..7a3fd81 --- /dev/null +++ b/core/anthropic/native_sse_block_policy.py @@ -0,0 +1,313 @@ +"""Shared native Anthropic SSE thinking policy, block remapping, and overlap repair. + +Used by :class:`OpenRouterProvider` and line-mode +:class:`providers.anthropic_messages.AnthropicMessagesTransport` providers. +""" + +from __future__ import annotations + +import copy +import json +from dataclasses import dataclass, field +from typing import Any + +__all__ = [ + "NativeSseBlockPolicyState", + "format_native_sse_event", + "is_terminal_openrouter_done_event", + "parse_native_sse_event", + "transform_native_sse_block_event", +] + + +@dataclass +class _UpstreamBlockState: + """Per-upstream content block: segment index and liveness in the model stream.""" + + block_type: str + down_index: int + open: bool + last_start_block: dict[str, Any] | None = None + + +@dataclass +class NativeSseBlockPolicyState: + """Track per-upstream content blocks and remapped Anthropic ``index`` field.""" + + next_index: int = 0 + by_upstream: dict[int, _UpstreamBlockState] = field(default_factory=dict) + dropped_indexes: set[int] = field(default_factory=set) + pending_suppressed_stops: set[int] = field(default_factory=set) + message_stopped: bool = False + + +def format_native_sse_event(event_name: str | None, data_text: str) -> str: + """Format an SSE event from its event name and data payload.""" + lines: list[str] = [] + if event_name: + lines.append(f"event: {event_name}") + lines.extend(f"data: {line}" for line in data_text.splitlines()) + return "\n".join(lines) + "\n\n" + + +def parse_native_sse_event(event: str) -> tuple[str | None, str]: + """Extract the event name and raw data payload from an SSE event.""" + event_name = None + data_lines: list[str] = [] + for line in event.strip().splitlines(): + if line.startswith("event:"): + event_name = line[6:].strip() + elif line.startswith("data:"): + data_lines.append(line[5:].lstrip()) + return event_name, "\n".join(data_lines) + + +def is_terminal_openrouter_done_event(event_name: str | None, data_text: str) -> bool: + """Return whether an event is OpenAI-style terminal noise (``[DONE]``).""" + return (event_name is None or event_name in {"data", "done"}) and ( + data_text.strip().upper() == "[DONE]" + ) + + +def _delta_type_to_block_kind(delta_type: Any) -> str | None: + """Map a content_block_delta type to a content block kind (text/thinking/tool_use).""" + if not isinstance(delta_type, str): + return None + if delta_type in {"thinking_delta", "signature_delta"}: + return "thinking" + if delta_type == "text_delta": + return "text" + if delta_type == "input_json_delta": + return "tool_use" + return None + + +def _synthetic_start_content_block( + block_kind: str, + *, + upstream_index: int, + stored_tool_block: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Build a `content_block` for a `content_block_start` with empty streaming fields.""" + if block_kind == "tool_use": + if ( + isinstance(stored_tool_block, dict) + and stored_tool_block.get("type") == "tool_use" + ): + tool_id = stored_tool_block.get("id") + name = stored_tool_block.get("name") + inp = stored_tool_block.get("input") + return { + "type": "tool_use", + "id": tool_id + if isinstance(tool_id, str) and tool_id + else f"toolu_or_{upstream_index}", + "name": name if isinstance(name, str) else "", + "input": inp if isinstance(inp, dict) else {}, + } + return { + "type": "tool_use", + "id": f"toolu_or_{upstream_index}", + "name": "", + "input": {}, + } + if block_kind == "thinking": + return {"type": "thinking", "thinking": ""} + if block_kind == "text": + return {"type": "text", "text": ""} + return {"type": "text", "text": ""} + + +def _should_drop_block_type(block_type: Any, *, thinking_enabled: bool) -> bool: + if not isinstance(block_type, str): + return False + if block_type.startswith("redacted_thinking"): + return not thinking_enabled + return not thinking_enabled and "thinking" in block_type + + +def _synthetic_close_other_open_blocks( + state: NativeSseBlockPolicyState, current_upstream: int +) -> str: + """Close every open block except `current_upstream` and track duplicate upstream stops.""" + out: list[str] = [] + for upstream, seg in list(state.by_upstream.items()): + if upstream == current_upstream or not seg.open: + continue + out.append( + format_native_sse_event( + "content_block_stop", + json.dumps( + { + "type": "content_block_stop", + "index": seg.down_index, + } + ), + ) + ) + seg.open = False + state.pending_suppressed_stops.add(upstream) + return "".join(out) + + +def _allocate_new_segment( + state: NativeSseBlockPolicyState, + upstream_index: int, + block_type: str, + *, + last_start_block: dict[str, Any] | None = None, +) -> int: + """Assign a new downstream `index` for a segment and record upstream state.""" + new_idx = state.next_index + state.next_index += 1 + state.by_upstream[upstream_index] = _UpstreamBlockState( + block_type=block_type, + down_index=new_idx, + open=True, + last_start_block=last_start_block, + ) + return new_idx + + +def transform_native_sse_block_event( + event: str, + state: NativeSseBlockPolicyState, + *, + thinking_enabled: bool, +) -> str | None: + """Normalize native Anthropic SSE events and enforce local thinking policy.""" + event_name, data_text = parse_native_sse_event(event) + if not event_name or not data_text: + return event + + try: + payload = json.loads(data_text) + except json.JSONDecodeError: + return event + + if event_name == "content_block_start": + block = payload.get("content_block") + if not isinstance(block, dict): + return event + block_type = block.get("type") + upstream_index = payload.get("index") + if not isinstance(upstream_index, int): + return event + if _should_drop_block_type(block_type, thinking_enabled=thinking_enabled): + state.dropped_indexes.add(upstream_index) + return None + + if not isinstance(block_type, str): + return event + prefix = _synthetic_close_other_open_blocks(state, upstream_index) + stored = copy.deepcopy(block) + new_idx = _allocate_new_segment( + state, + upstream_index, + block_type=block_type, + last_start_block=stored, + ) + payload["index"] = new_idx + return prefix + format_native_sse_event(event_name, json.dumps(payload)) + + if event_name == "content_block_delta": + delta = payload.get("delta") + if not isinstance(delta, dict): + return event + delta_type = delta.get("type") + upstream_index = payload.get("index") + if not isinstance(upstream_index, int): + return event + if upstream_index in state.dropped_indexes: + return None + if _should_drop_block_type(delta_type, thinking_enabled=thinking_enabled): + return None + + block_kind = _delta_type_to_block_kind(delta_type) + if block_kind is None: + return event + + seg = state.by_upstream.get(upstream_index) + if seg and seg.open: + payload["index"] = seg.down_index + return format_native_sse_event(event_name, json.dumps(payload)) + + if seg is not None and not seg.open: + # More deltas for an upstream block after a synthetic (or other) close: + # reopen with a new downstream `index` and emit a synthetic `content_block_start` first. + state.pending_suppressed_stops.discard(upstream_index) + carry = seg.last_start_block + new_idx = _allocate_new_segment( + state, + upstream_index, + block_type=block_kind, + last_start_block=carry, + ) + stored_tool = ( + carry + if isinstance(carry, dict) and carry.get("type") == "tool_use" + else None + ) + start_payload = { + "type": "content_block_start", + "index": new_idx, + "content_block": _synthetic_start_content_block( + block_kind, + upstream_index=upstream_index, + stored_tool_block=stored_tool, + ), + } + prefix = format_native_sse_event( + "content_block_start", json.dumps(start_payload) + ) + payload["index"] = new_idx + return prefix + format_native_sse_event(event_name, json.dumps(payload)) + + # Delta with no prior `content_block_start` in this stream + if block_kind in ("text", "tool_use"): + synthetic_block = _synthetic_start_content_block( + block_kind, + upstream_index=upstream_index, + ) + new_idx = _allocate_new_segment( + state, + upstream_index, + block_type=block_kind, + last_start_block=copy.deepcopy(synthetic_block), + ) + start_payload = { + "type": "content_block_start", + "index": new_idx, + "content_block": synthetic_block, + } + prefix = format_native_sse_event( + "content_block_start", json.dumps(start_payload) + ) + payload["index"] = new_idx + return prefix + format_native_sse_event(event_name, json.dumps(payload)) + # thinking: pass through raw (unusual upstream shape) + return event + + if event_name == "content_block_stop": + upstream_index = payload.get("index") + if not isinstance(upstream_index, int): + return event + if upstream_index in state.dropped_indexes: + return None + if upstream_index in state.pending_suppressed_stops: + state.pending_suppressed_stops.discard(upstream_index) + return None + + seg = state.by_upstream.get(upstream_index) + if seg is not None and seg.open: + payload["index"] = seg.down_index + seg.open = False + return format_native_sse_event(event_name, json.dumps(payload)) + if seg is not None: + # Spurious or duplicate `content_block_stop` for a closed block. + return None + if not thinking_enabled: + return None + return event + + return event diff --git a/core/anthropic/provider_stream_error.py b/core/anthropic/provider_stream_error.py new file mode 100644 index 0000000..5534c50 --- /dev/null +++ b/core/anthropic/provider_stream_error.py @@ -0,0 +1,34 @@ +"""Canonical Anthropic-style SSE sequence for provider-side streaming errors.""" + +from __future__ import annotations + +import uuid +from collections.abc import Iterator +from typing import Any + +from core.anthropic.sse import SSEBuilder + + +def iter_provider_stream_error_sse_events( + *, + request: Any, + input_tokens: int, + error_message: str, + sent_any_event: bool, + log_raw_sse_events: bool, + message_id: str | None = None, +) -> Iterator[str]: + """Yield message_start (if needed), a text block with the error, then message_delta/stop.""" + mid = message_id or f"msg_{uuid.uuid4()}" + model = getattr(request, "model", "") or "" + sse = SSEBuilder( + mid, + model, + input_tokens, + log_raw_events=log_raw_sse_events, + ) + if not sent_any_event: + yield sse.message_start() + yield from sse.emit_error(error_message) + yield sse.message_delta("end_turn", 1) + yield sse.message_stop() diff --git a/core/anthropic/server_tool_sse.py b/core/anthropic/server_tool_sse.py new file mode 100644 index 0000000..1c8727d --- /dev/null +++ b/core/anthropic/server_tool_sse.py @@ -0,0 +1,14 @@ +"""SSE content_block ``type`` values for Anthropic web server tools (local handlers). + +Shared by :mod:`api.web_tools` and stream contract tests to avoid drift. +""" + +from __future__ import annotations + +from typing import Final + +SERVER_TOOL_USE: Final = "server_tool_use" +WEB_SEARCH_TOOL_RESULT: Final = "web_search_tool_result" +WEB_FETCH_TOOL_RESULT: Final = "web_fetch_tool_result" +WEB_SEARCH_TOOL_RESULT_ERROR: Final = "web_search_tool_result_error" +WEB_FETCH_TOOL_ERROR: Final = "web_fetch_tool_error" diff --git a/core/anthropic/sse.py b/core/anthropic/sse.py index 7998839..3bf5ebe 100644 --- a/core/anthropic/sse.py +++ b/core/anthropic/sse.py @@ -1,5 +1,6 @@ """SSE event builder for Anthropic-format streaming responses.""" +import hashlib import json from collections.abc import Iterator from dataclasses import dataclass, field @@ -14,6 +15,13 @@ except Exception: ENCODER = None +# Standard headers for Anthropic-style ``text/event-stream`` responses from this proxy. +ANTHROPIC_SSE_RESPONSE_HEADERS: dict[str, str] = { + "X-Accel-Buffering": "no", + "Cache-Control": "no-cache", + "Connection": "keep-alive", +} + STOP_REASON_MAP = { "stop": "end_turn", "length": "max_tokens", @@ -29,6 +37,11 @@ def map_stop_reason(openai_reason: str | None) -> str: ) +def format_sse_event(event_type: str, data: dict) -> str: + """Format one Anthropic-style SSE event (no logging).""" + return f"event: {event_type}\ndata: {json.dumps(data)}\n\n" + + @dataclass class ToolCallState: """State for a single streaming tool call.""" @@ -40,6 +53,7 @@ class ToolCallState: started: bool = False task_arg_buffer: str = "" task_args_emitted: bool = False + pre_start_args: str = "" @dataclass @@ -58,7 +72,25 @@ class ContentBlockManager: self.next_index += 1 return idx + def ensure_tool_state(self, index: int) -> ToolCallState: + """Create tool stream state for ``index`` when the first tool delta arrives.""" + if index not in self.tool_states: + self.tool_states[index] = ToolCallState(block_index=-1, tool_id="", name="") + return self.tool_states[index] + + def set_stream_tool_id(self, index: int, tool_id: str | None) -> None: + """Record OpenAI tool call id before ``content_block_start`` (split-stream providers).""" + if not tool_id: + return + state = self.ensure_tool_state(index) + state.tool_id = str(tool_id) + def register_tool_name(self, index: int, name: str) -> None: + """Record tool name fragments as they arrive from chunked OpenAI streams. + + Names may be split across deltas; later chunks can extend (``ab`` + ``c``) + or repeat prefixes, so we merge conservatively. + """ if index not in self.tool_states: self.tool_states[index] = ToolCallState( block_index=-1, tool_id="", name=name @@ -82,8 +114,7 @@ class ContentBlockManager: except Exception: return None - if args_json.get("run_in_background") is not False: - args_json["run_in_background"] = False + _normalize_task_run_in_background(args_json) state.task_args_emitted = True state.task_arg_buffer = "" @@ -98,16 +129,17 @@ class ContentBlockManager: out = "{}" try: args_json = json.loads(state.task_arg_buffer) - if args_json.get("run_in_background") is not False: - args_json["run_in_background"] = False + _normalize_task_run_in_background(args_json) out = json.dumps(args_json) - except Exception as e: - prefix = state.task_arg_buffer[:120] + except (json.JSONDecodeError, TypeError, ValueError) as e: + digest = hashlib.sha256( + state.task_arg_buffer.encode("utf-8", errors="replace") + ).hexdigest()[:16] logger.warning( - "Task args invalid JSON (id={} len={} prefix={!r}): {}", + "Task args invalid JSON (id={} len={} buffer_sha256_prefix={}): {}", state.tool_id or "unknown", len(state.task_arg_buffer), - prefix, + digest, e, ) @@ -117,20 +149,41 @@ class ContentBlockManager: return results +def _normalize_task_run_in_background(args_json: dict) -> None: + """Force Claude Code Task subagents to run in foreground (single shared rule).""" + if args_json.get("run_in_background") is not False: + args_json["run_in_background"] = False + + class SSEBuilder: """Builder for Anthropic SSE streaming events.""" - def __init__(self, message_id: str, model: str, input_tokens: int = 0): + def __init__( + self, + message_id: str, + model: str, + input_tokens: int = 0, + *, + log_raw_events: bool = False, + ): self.message_id = message_id self.model = model self.input_tokens = input_tokens + self._log_raw_events = log_raw_events self.blocks = ContentBlockManager() self._accumulated_text_parts: list[str] = [] self._accumulated_reasoning_parts: list[str] = [] def _format_event(self, event_type: str, data: dict) -> str: - event_str = f"event: {event_type}\ndata: {json.dumps(data)}\n\n" - logger.debug("SSE_EVENT: {} - {}", event_type, event_str.strip()) + event_str = format_sse_event(event_type, data) + if self._log_raw_events: + logger.debug("SSE_EVENT: {} - {}", event_type, event_str.strip()) + else: + logger.debug( + "SSE_EVENT: event_type={} serialized_bytes={}", + event_type, + len(event_str.encode("utf-8")), + ) return event_str def message_start(self) -> str: @@ -289,10 +342,7 @@ class SSEBuilder: yield self.stop_text_block() def close_all_blocks(self) -> Iterator[str]: - if self.blocks.thinking_started: - yield self.stop_thinking_block() - if self.blocks.text_started: - yield self.stop_text_block() + yield from self.close_content_blocks() for tool_index, state in list(self.blocks.tool_states.items()): if state.started: yield self.stop_tool_block(tool_index) diff --git a/core/anthropic/stream_contracts.py b/core/anthropic/stream_contracts.py index 8058392..ba2b4d8 100644 --- a/core/anthropic/stream_contracts.py +++ b/core/anthropic/stream_contracts.py @@ -1,7 +1,6 @@ """Neutral SSE parsing and Anthropic stream shape assertions. -Used by default CI contract tests and re-exported from ``smoke.lib.sse`` for -opt-in smoke scenarios. +Used by default CI contract tests and by opt-in live smoke scenarios. """ from __future__ import annotations @@ -11,6 +10,36 @@ from collections.abc import Iterable from dataclasses import dataclass from typing import Any +from .server_tool_sse import ( + SERVER_TOOL_USE, + WEB_FETCH_TOOL_RESULT, + WEB_SEARCH_TOOL_RESULT, +) + +# Content blocks that only use content_block_start/stop (no deltas), including +# Anthropic server tools and eager text emitted in a single start event. +_NO_DELTA_BLOCK_KINDS = frozenset( + { + SERVER_TOOL_USE, + WEB_SEARCH_TOOL_RESULT, + WEB_FETCH_TOOL_RESULT, + "text_eager", + "redacted_thinking", + } +) + +_ALLOWED_BLOCK_START_TYPES = frozenset( + { + "text", + "thinking", + "tool_use", + "redacted_thinking", + SERVER_TOOL_USE, + WEB_SEARCH_TOOL_RESULT, + WEB_FETCH_TOOL_RESULT, + } +) + @dataclass(frozen=True, slots=True) class SSEEvent: @@ -86,35 +115,46 @@ def assert_anthropic_stream_contract( raise AssertionError(f"unexpected SSE error event: {event.data}") if event.event == "content_block_start": - index = _event_index(event) + index = event_index(event) block = event.data.get("content_block", {}) assert isinstance(block, dict), event.data block_type = str(block.get("type", "")) - assert block_type in {"text", "thinking", "tool_use"}, event.data + assert block_type in _ALLOWED_BLOCK_START_TYPES, event.data assert index not in open_blocks, f"block {index} started twice" assert index not in seen_blocks, f"block {index} reused after stop" - open_blocks[index] = block_type + if block_type == "text" and str(block.get("text", "")).strip(): + storage = "text_eager" + else: + storage = block_type + open_blocks[index] = storage seen_blocks.add(index) continue if event.event == "content_block_delta": - index = _event_index(event) + index = event_index(event) assert index in open_blocks, f"delta for unopened block {index}" + kind = open_blocks[index] + assert kind not in _NO_DELTA_BLOCK_KINDS, ( + f"unexpected delta for start/stop-only block {kind} at index {index}" + ) delta = event.data.get("delta", {}) assert isinstance(delta, dict), event.data delta_type = str(delta.get("type", "")) + if kind == "thinking": + assert delta_type in ( + "thinking_delta", + "signature_delta", + ), f"block {index} is {kind}, got {delta_type}" + continue expected = { "text": "text_delta", - "thinking": "thinking_delta", "tool_use": "input_json_delta", - }[open_blocks[index]] - assert delta_type == expected, ( - f"block {index} is {open_blocks[index]}, got {delta_type}" - ) + }[kind] + assert delta_type == expected, f"block {index} is {kind}, got {delta_type}" continue if event.event == "content_block_stop": - index = _event_index(event) + index = event_index(event) assert index in open_blocks, f"stop for unopened block {index}" open_blocks.pop(index) @@ -129,6 +169,12 @@ def event_names(events: list[SSEEvent]) -> list[str]: def text_content(events: list[SSEEvent]) -> str: parts: list[str] = [] for event in events: + if event.event == "content_block_start": + block = event.data.get("content_block", {}) + if isinstance(block, dict) and block.get("type") == "text": + eager = str(block.get("text", "")) + if eager: + parts.append(eager) delta = event.data.get("delta", {}) if isinstance(delta, dict) and delta.get("type") == "text_delta": parts.append(str(delta.get("text", ""))) @@ -152,7 +198,8 @@ def has_tool_use(events: list[SSEEvent]) -> bool: return False -def _event_index(event: SSEEvent) -> int: +def event_index(event: SSEEvent) -> int: + """Return the content block ``index`` field from an SSE payload (strict).""" value = event.data.get("index") assert isinstance(value, int), event.data return value diff --git a/core/anthropic/tokens.py b/core/anthropic/tokens.py index dfa1115..081d28b 100644 --- a/core/anthropic/tokens.py +++ b/core/anthropic/tokens.py @@ -68,6 +68,27 @@ def get_token_count( total_tokens += len(ENCODER.encode(json.dumps(content))) total_tokens += len(ENCODER.encode(str(tool_use_id))) total_tokens += 8 + elif b_type in ( + "server_tool_use", + "web_search_tool_result", + "web_fetch_tool_result", + ): + if hasattr(block, "model_dump"): + blob: object = block.model_dump() + else: + blob = block + try: + total_tokens += len( + ENCODER.encode( + json.dumps(blob, default=str, ensure_ascii=False) + ) + ) + except (TypeError, ValueError, OverflowError) as e: + logger.debug( + "Block encode fallback b_type={} err={}", b_type, e + ) + total_tokens += len(ENCODER.encode(str(blob))) + total_tokens += 12 else: logger.debug( "Unexpected block type %r, falling back to json/str encoding", diff --git a/core/rate_limit.py b/core/rate_limit.py new file mode 100644 index 0000000..8d9bd81 --- /dev/null +++ b/core/rate_limit.py @@ -0,0 +1,60 @@ +"""Shared strict sliding-window rate limiting primitives.""" + +from __future__ import annotations + +import asyncio +import time +from collections import deque + + +class StrictSlidingWindowLimiter: + """Strict sliding window limiter. + + Guarantees: at most ``rate_limit`` acquisitions in any interval of length + ``rate_window`` (seconds). + + Implemented as an async context manager so call sites can do:: + + async with limiter: + ... + """ + + def __init__(self, rate_limit: int, rate_window: float) -> None: + if rate_limit <= 0: + raise ValueError("rate_limit must be > 0") + if rate_window <= 0: + raise ValueError("rate_window must be > 0") + + self._rate_limit = int(rate_limit) + self._rate_window = float(rate_window) + self._times: deque[float] = deque() + self._lock = asyncio.Lock() + + async def acquire(self) -> None: + while True: + wait_time = 0.0 + async with self._lock: + now = time.monotonic() + cutoff = now - self._rate_window + + while self._times and self._times[0] <= cutoff: + self._times.popleft() + + if len(self._times) < self._rate_limit: + self._times.append(now) + return + + oldest = self._times[0] + wait_time = max(0.0, (oldest + self._rate_window) - now) + + if wait_time > 0: + await asyncio.sleep(wait_time) + else: + await asyncio.sleep(0) + + async def __aenter__(self) -> StrictSlidingWindowLimiter: + await self.acquire() + return self + + async def __aexit__(self, exc_type, exc, tb) -> bool: + return False diff --git a/messaging/cli_event_constants.py b/messaging/cli_event_constants.py new file mode 100644 index 0000000..3d531ec --- /dev/null +++ b/messaging/cli_event_constants.py @@ -0,0 +1,67 @@ +"""CLI event types and status-line mapping for transcript / UI updates.""" + +from collections.abc import Callable +from typing import Any + +# Status message prefixes used to filter our own messages (ignore echo) +STATUS_MESSAGE_PREFIXES = ( + "⏳", + "💭", + "🔧", + "✅", + "❌", + "🚀", + "🤖", + "📋", + "📊", + "🔄", +) + +# Event types that update the transcript (frozenset for O(1) membership) +TRANSCRIPT_EVENT_TYPES = frozenset( + { + "thinking_start", + "thinking_delta", + "thinking_chunk", + "thinking_stop", + "text_start", + "text_delta", + "text_chunk", + "text_stop", + "tool_use_start", + "tool_use_delta", + "tool_use_stop", + "tool_use", + "tool_result", + "block_stop", + "error", + } +) + +# Event type -> (emoji, label) for status updates (O(1) lookup) +_EVENT_STATUS_MAP: dict[str, tuple[str, str]] = { + "thinking_start": ("🧠", "Claude is thinking..."), + "thinking_delta": ("🧠", "Claude is thinking..."), + "thinking_chunk": ("🧠", "Claude is thinking..."), + "text_start": ("🧠", "Claude is working..."), + "text_delta": ("🧠", "Claude is working..."), + "text_chunk": ("🧠", "Claude is working..."), + "tool_result": ("⏳", "Executing tools..."), +} + + +def get_status_for_event( + ptype: str, + parsed: dict[str, Any], + format_status_fn: Callable[..., str], +) -> str | None: + """Return status string for event type, or None if no status update needed.""" + entry = _EVENT_STATUS_MAP.get(ptype) + if entry is not None: + emoji, label = entry + return format_status_fn(emoji, label) + if ptype in ("tool_use_start", "tool_use_delta", "tool_use"): + if parsed.get("name") == "Task": + return format_status_fn("🤖", "Subagent working...") + return format_status_fn("⏳", "Executing tools...") + return None diff --git a/messaging/event_parser.py b/messaging/event_parser.py index 2499248..87eb82a 100644 --- a/messaging/event_parser.py +++ b/messaging/event_parser.py @@ -9,12 +9,14 @@ from typing import Any from loguru import logger -def parse_cli_event(event: Any) -> list[dict]: +def parse_cli_event(event: Any, *, log_raw_cli: bool = False) -> list[dict]: """ Parse a CLI event and return a structured result. Args: event: Raw event dictionary from CLI + log_raw_cli: When True, log full error text from the CLI. Default is + metadata-only (lengths / exit codes) to avoid leaking user content. Returns: List of parsed event dicts. Empty list if not recognized. @@ -140,7 +142,11 @@ def parse_cli_event(event: Any) -> list[dict]: if etype == "error": err = event.get("error") msg = err.get("message") if isinstance(err, dict) else str(err) - logger.info(f"CLI_PARSER: Parsed error event: {msg}") + if log_raw_cli: + logger.info("CLI_PARSER: Parsed error event: {}", msg) + else: + mlen = len(msg) if isinstance(msg, str) else 0 + logger.info("CLI_PARSER: Parsed error event: message_chars={}", mlen) return [{"type": "error", "message": msg}] elif etype == "exit": code = event.get("code", 0) @@ -151,7 +157,19 @@ def parse_cli_event(event: Any) -> list[dict]: else: # Non-zero exit is an error error_msg = stderr if stderr else f"Process exited with code {code}" - logger.warning(f"CLI_PARSER: Error exit (code={code}): {error_msg}") + if log_raw_cli: + logger.warning( + "CLI_PARSER: Error exit (code={}): {}", + code, + error_msg, + ) + else: + em = error_msg if isinstance(error_msg, str) else str(error_msg) + logger.warning( + "CLI_PARSER: Error exit (code={}): message_chars={}", + code, + len(em), + ) return [ {"type": "error", "message": error_msg}, {"type": "complete", "status": "failed"}, diff --git a/messaging/handler.py b/messaging/handler.py index 949cafb..7702e11 100644 --- a/messaging/handler.py +++ b/messaging/handler.py @@ -7,13 +7,12 @@ Uses tree-based queuing for message ordering. """ import asyncio -import os -import time from loguru import logger -from core.anthropic import get_user_facing_error_message +from core.anthropic import format_user_error_preview, get_user_facing_error_message +from .cli_event_constants import STATUS_MESSAGE_PREFIXES from .command_dispatcher import ( dispatch_command, message_kind_for_command, @@ -21,8 +20,10 @@ from .command_dispatcher import ( ) from .event_parser import parse_cli_event from .models import IncomingMessage +from .node_event_pipeline import handle_session_info_event, process_parsed_cli_event from .platforms.base import MessagingPlatform, SessionManagerInterface from .rendering.profiles import build_rendering_profile +from .safe_diagnostics import format_exception_for_log from .session import SessionStore from .transcript import RenderCtx, TranscriptBuffer from .trees.queue_manager import ( @@ -31,54 +32,7 @@ from .trees.queue_manager import ( MessageTree, TreeQueueManager, ) - -# Status message prefixes used to filter our own messages (ignore echo) -STATUS_MESSAGE_PREFIXES = ("⏳", "💭", "🔧", "✅", "❌", "🚀", "🤖", "📋", "📊", "🔄") - -# Event types that update the transcript (frozenset for O(1) membership) -TRANSCRIPT_EVENT_TYPES = frozenset( - { - "thinking_start", - "thinking_delta", - "thinking_chunk", - "thinking_stop", - "text_start", - "text_delta", - "text_chunk", - "text_stop", - "tool_use_start", - "tool_use_delta", - "tool_use_stop", - "tool_use", - "tool_result", - "block_stop", - "error", - } -) - -# Event type -> (emoji, label) for status updates (O(1) lookup) -_EVENT_STATUS_MAP = { - "thinking_start": ("🧠", "Claude is thinking..."), - "thinking_delta": ("🧠", "Claude is thinking..."), - "thinking_chunk": ("🧠", "Claude is thinking..."), - "text_start": ("🧠", "Claude is working..."), - "text_delta": ("🧠", "Claude is working..."), - "text_chunk": ("🧠", "Claude is working..."), - "tool_result": ("⏳", "Executing tools..."), -} - - -def _get_status_for_event(ptype: str, parsed: dict, format_status_fn) -> str | None: - """Return status string for event type, or None if no status update needed.""" - entry = _EVENT_STATUS_MAP.get(ptype) - if entry is not None: - emoji, label = entry - return format_status_fn(emoji, label) - if ptype in ("tool_use_start", "tool_use_delta", "tool_use"): - if parsed.get("name") == "Task": - return format_status_fn("🤖", "Subagent working...") - return format_status_fn("⏳", "Executing tools...") - return None +from .ui_updates import ThrottledTranscriptEditor class ClaudeMessageHandler: @@ -97,10 +51,21 @@ class ClaudeMessageHandler: platform: MessagingPlatform, cli_manager: SessionManagerInterface, session_store: SessionStore, + *, + debug_platform_edits: bool = False, + debug_subagent_stack: bool = False, + log_raw_messaging_content: bool = False, + log_raw_cli_diagnostics: bool = False, + log_messaging_error_details: bool = False, ): self.platform = platform self.cli_manager = cli_manager self.session_store = session_store + self._debug_platform_edits = debug_platform_edits + self._debug_subagent_stack = debug_subagent_stack + self._log_raw_messaging_content = log_raw_messaging_content + self._log_raw_cli_diagnostics = log_raw_cli_diagnostics + self._log_messaging_error_details = log_messaging_error_details self._tree_queue = TreeQueueManager( queue_update_callback=self.update_queue_positions, node_started_callback=self.mark_node_processing, @@ -137,16 +102,26 @@ class ClaudeMessageHandler: Determines if this is a new conversation or reply, creates/extends the message tree, and queues for processing. """ - text_preview = (incoming.text or "")[:80] - if len(incoming.text or "") > 80: - text_preview += "..." - logger.info( - "HANDLER_ENTRY: chat_id={} message_id={} reply_to={} text_preview={!r}", - incoming.chat_id, - incoming.message_id, - incoming.reply_to_message_id, - text_preview, - ) + raw = incoming.text or "" + if self._log_raw_messaging_content: + text_preview = raw[:80] + if len(raw) > 80: + text_preview += "..." + logger.info( + "HANDLER_ENTRY: chat_id={} message_id={} reply_to={} text_preview={!r}", + incoming.chat_id, + incoming.message_id, + incoming.reply_to_message_id, + text_preview, + ) + else: + logger.info( + "HANDLER_ENTRY: chat_id={} message_id={} reply_to={} text_len={}", + incoming.chat_id, + incoming.message_id, + incoming.reply_to_message_id, + len(raw), + ) with logger.contextualize( chat_id=incoming.chat_id, node_id=incoming.message_id @@ -169,7 +144,12 @@ class ClaudeMessageHandler: kind=message_kind_for_command(cmd_base), ) except Exception as e: - logger.debug(f"Failed to record incoming message_id: {e}") + logger.debug( + "Failed to record incoming message_id: {}", + format_exception_for_log( + e, log_full_message=self._log_messaging_error_details + ), + ) if await dispatch_command(self, incoming, cmd_base): return @@ -276,7 +256,12 @@ class ClaudeMessageHandler: try: queued_ids = await tree.get_queue_snapshot() except Exception as e: - logger.warning(f"Failed to read queue snapshot: {e}") + logger.warning( + "Failed to read queue snapshot: {}", + format_exception_for_log( + e, log_full_message=self._log_messaging_error_details + ), + ) return if not queued_ids: @@ -317,85 +302,11 @@ class ClaudeMessageHandler: self, ) -> tuple[TranscriptBuffer, RenderCtx]: """Create transcript buffer and render context for node processing.""" - transcript = TranscriptBuffer(show_tool_results=False) - return transcript, self.get_render_ctx() - - async def _handle_session_info_event( - self, - event_data: dict, - tree: MessageTree | None, - node_id: str, - captured_session_id: str | None, - temp_session_id: str | None, - ) -> tuple[str | None, str | None]: - """Handle session_info event; return updated (captured_session_id, temp_session_id).""" - if event_data.get("type") != "session_info": - return captured_session_id, temp_session_id - - real_session_id = event_data.get("session_id") - if not real_session_id or not temp_session_id: - return captured_session_id, temp_session_id - - await self.cli_manager.register_real_session_id( - temp_session_id, real_session_id + transcript = TranscriptBuffer( + show_tool_results=False, + debug_subagent_stack=self._debug_subagent_stack, ) - if tree and real_session_id: - await tree.update_state( - node_id, - MessageState.IN_PROGRESS, - session_id=real_session_id, - ) - self.session_store.save_tree(tree.root_id, tree.to_dict()) - - return real_session_id, None - - async def _process_parsed_event( - self, - parsed: dict, - transcript: TranscriptBuffer, - update_ui, - last_status: str | None, - had_transcript_events: bool, - tree: MessageTree | None, - node_id: str, - captured_session_id: str | None, - ) -> tuple[str | None, bool]: - """Process a single parsed CLI event. Returns (last_status, had_transcript_events).""" - ptype = parsed.get("type") or "" - - if ptype in TRANSCRIPT_EVENT_TYPES: - transcript.apply(parsed) - had_transcript_events = True - - status = _get_status_for_event(ptype, parsed, self.format_status) - if status is not None: - await update_ui(status) - last_status = status - elif ptype == "block_stop": - await update_ui(last_status, force=True) - elif ptype == "complete": - if not had_transcript_events: - transcript.apply({"type": "text_chunk", "text": "Done."}) - logger.info("HANDLER: Task complete, updating UI") - await update_ui(self.format_status("✅", "Complete"), force=True) - if tree and captured_session_id: - await tree.update_state( - node_id, - MessageState.COMPLETED, - session_id=captured_session_id, - ) - self.session_store.save_tree(tree.root_id, tree.to_dict()) - elif ptype == "error": - error_msg = parsed.get("message", "Unknown error") - logger.error(f"HANDLER: Error event received: {error_msg}") - logger.info("HANDLER: Updating UI with error status") - await update_ui(self.format_status("❌", "Error"), force=True) - if tree: - await self._propagate_error_to_children( - node_id, error_msg, "Parent task failed" - ) - - return last_status, had_transcript_events + return transcript, self.get_render_ctx() async def _process_node( self, @@ -426,8 +337,6 @@ class ClaudeMessageHandler: transcript, render_ctx = self._create_transcript_and_render_ctx() - last_ui_update = 0.0 - last_displayed_text = None had_transcript_events = False captured_session_id = None temp_session_id = None @@ -439,52 +348,21 @@ class ClaudeMessageHandler: if parent_session_id: logger.info(f"Will fork from parent session: {parent_session_id}") - async def update_ui(status: str | None = None, force: bool = False) -> None: - nonlocal last_ui_update, last_displayed_text, last_status - now = time.time() - if not force and now - last_ui_update < 1.0: - return + editor = ThrottledTranscriptEditor( + platform=self.platform, + parse_mode=self._parse_mode(), + get_limit_chars=self._get_limit_chars, + transcript=transcript, + render_ctx=render_ctx, + node_id=node_id, + chat_id=chat_id, + status_msg_id=status_msg_id, + debug_platform_edits=self._debug_platform_edits, + log_messaging_error_details=self._log_messaging_error_details, + ) - last_ui_update = now - if status is not None: - last_status = status - try: - display = transcript.render( - render_ctx, - limit_chars=self._get_limit_chars(), - status=status, - ) - except Exception as e: - logger.warning(f"Transcript render failed for node {node_id}: {e}") - return - if display and display != last_displayed_text: - logger.debug( - "PLATFORM_EDIT: node_id={} chat_id={} msg_id={} force={} status={!r} chars={}", - node_id, - chat_id, - status_msg_id, - bool(force), - status, - len(display), - ) - if os.getenv("DEBUG_PLATFORM_EDITS") == "1": - logger.debug("PLATFORM_EDIT_TEXT:\n{}", display) - else: - head = display[:500] - tail = display[-500:] if len(display) > 500 else "" - logger.debug("PLATFORM_EDIT_PREVIEW_HEAD:\n{}", head) - if tail: - logger.debug("PLATFORM_EDIT_PREVIEW_TAIL:\n{}", tail) - last_displayed_text = display - try: - await self.platform.queue_edit_message( - chat_id, - status_msg_id, - display, - parse_mode=self._parse_mode(), - ) - except Exception as e: - logger.warning(f"Failed to update platform for node {node_id}: {e}") + async def update_ui(status: str | None = None, force: bool = False) -> None: + await editor.update(status, force=force) try: try: @@ -531,20 +409,28 @@ class ClaudeMessageHandler: ( captured_session_id, temp_session_id, - ) = await self._handle_session_info_event( - event_data, tree, node_id, captured_session_id, temp_session_id + ) = await handle_session_info_event( + event_data, + tree, + node_id, + captured_session_id, + temp_session_id, + cli_manager=self.cli_manager, + session_store=self.session_store, ) if event_data.get("type") == "session_info": continue - parsed_list = parse_cli_event(event_data) + parsed_list = parse_cli_event( + event_data, log_raw_cli=self._log_raw_cli_diagnostics + ) logger.debug(f"HANDLER: Parsed {len(parsed_list)} events from CLI") for parsed in parsed_list: ( last_status, had_transcript_events, - ) = await self._process_parsed_event( + ) = await process_parsed_cli_event( parsed, transcript, update_ui, @@ -553,6 +439,10 @@ class ClaudeMessageHandler: tree, node_id, captured_session_id, + session_store=self.session_store, + format_status=self.format_status, + propagate_error_to_children=self._propagate_error_to_children, + log_messaging_error_details=self._log_messaging_error_details, ) except asyncio.CancelledError: @@ -575,9 +465,12 @@ class ClaudeMessageHandler: ) except Exception as e: logger.error( - f"HANDLER: Task failed with exception: {type(e).__name__}: {e}" + "HANDLER: Task failed with exception: {}", + format_exception_for_log( + e, log_full_message=self._log_messaging_error_details + ), ) - error_msg = get_user_facing_error_message(e)[:200] + error_msg = format_user_error_preview(e) transcript.apply({"type": "error", "message": error_msg}) await update_ui(self.format_status("💥", "Task Failed"), force=True) if tree: @@ -595,7 +488,13 @@ class ClaudeMessageHandler: elif temp_session_id: await self.cli_manager.remove_session(temp_session_id) except Exception as e: - logger.debug(f"Failed to remove session for node {node_id}: {e}") + logger.debug( + "Failed to remove session for node {}: {}", + node_id, + format_exception_for_log( + e, log_full_message=self._log_messaging_error_details + ), + ) async def _propagate_error_to_children( self, @@ -691,7 +590,12 @@ class ClaudeMessageHandler: platform, chat_id, str(msg_id), direction="out", kind=kind ) except Exception as e: - logger.debug(f"Failed to record message_id: {e}") + logger.debug( + "Failed to record message_id: {}", + format_exception_for_log( + e, log_full_message=self._log_messaging_error_details + ), + ) def update_cancelled_nodes_ui(self, nodes: list[MessageNode]) -> None: """Update status messages and persist tree state for cancelled nodes.""" diff --git a/messaging/limiter.py b/messaging/limiter.py index 6367e9e..648d4a8 100644 --- a/messaging/limiter.py +++ b/messaging/limiter.py @@ -6,65 +6,16 @@ using a strict sliding window algorithm and a task queue. """ import asyncio -import os -import time from collections import deque from collections.abc import Awaitable, Callable from typing import Any from loguru import logger +from config.settings import get_settings +from core.rate_limit import StrictSlidingWindowLimiter as SlidingWindowLimiter -class SlidingWindowLimiter: - """Strict sliding window limiter. - - Guarantees: at most `rate_limit` acquisitions in any interval of length - `rate_window` (seconds). - - Implemented as an async context manager so call sites can do: - async with limiter: - ... - """ - - def __init__(self, rate_limit: int, rate_window: float) -> None: - if rate_limit <= 0: - raise ValueError("rate_limit must be > 0") - if rate_window <= 0: - raise ValueError("rate_window must be > 0") - - self._rate_limit = int(rate_limit) - self._rate_window = float(rate_window) - self._times: deque[float] = deque() - self._lock = asyncio.Lock() - - async def acquire(self) -> None: - while True: - wait_time = 0.0 - async with self._lock: - now = time.monotonic() - cutoff = now - self._rate_window - - while self._times and self._times[0] <= cutoff: - self._times.popleft() - - if len(self._times) < self._rate_limit: - self._times.append(now) - return - - oldest = self._times[0] - wait_time = max(0.0, (oldest + self._rate_window) - now) - - if wait_time > 0: - await asyncio.sleep(wait_time) - else: - await asyncio.sleep(0) - - async def __aenter__(self) -> SlidingWindowLimiter: - await self.acquire() - return self - - async def __aexit__(self, exc_type, exc, tb) -> bool: - return False +from .safe_diagnostics import format_exception_for_log class MessagingRateLimiter: @@ -82,23 +33,29 @@ class MessagingRateLimiter: return super().__new__(cls) @classmethod - async def get_instance(cls) -> MessagingRateLimiter: - """Get the singleton instance of the limiter.""" + async def get_instance( + cls, + *, + rate_limit: int = 1, + rate_window: float = 1.0, + ) -> MessagingRateLimiter: + """Get the singleton instance of the limiter. + + ``rate_limit`` and ``rate_window`` apply only when the singleton is first + created. Call :meth:`shutdown_instance` before changing parameters. + """ async with cls._lock: if cls._instance is None: - cls._instance = cls() + cls._instance = cls(rate_limit=rate_limit, rate_window=rate_window) # Start the background worker (tracked for graceful shutdown). cls._instance._start_worker() return cls._instance - def __init__(self): + def __init__(self, *, rate_limit: int, rate_window: float) -> None: # Prevent double initialization in singleton if hasattr(self, "_initialized"): return - rate_limit = int(os.getenv("MESSAGING_RATE_LIMIT", "1")) - rate_window = float(os.getenv("MESSAGING_RATE_WINDOW", "2.0")) - self.limiter = SlidingWindowLimiter(rate_limit, rate_window) # Custom queue state - using deque for O(1) popleft self._queue_list: deque[str] = deque() # Deque of dedup_keys in order @@ -189,15 +146,27 @@ class MessagingRateLimiter: asyncio.get_event_loop().time() + wait_secs ) else: + d = get_settings().log_messaging_error_details logger.error( - f"Error in limiter worker for key {dedup_key}: {type(e).__name__}: {e}" + "Error in limiter worker for key {}: {}", + dedup_key, + format_exception_for_log(e, log_full_message=d), ) except asyncio.CancelledError: break except Exception as e: - logger.error( - f"MessagingRateLimiter worker critical error: {e}", exc_info=True - ) + d = get_settings().log_messaging_error_details + if d: + logger.error( + "MessagingRateLimiter worker critical error: {}", + e, + exc_info=True, + ) + else: + logger.error( + "MessagingRateLimiter worker critical error: exc_type={}", + type(e).__name__, + ) await asyncio.sleep(1) async def shutdown(self, timeout: float = 2.0) -> None: @@ -223,7 +192,11 @@ class MessagingRateLimiter: except asyncio.CancelledError: pass except Exception as e: - logger.debug(f"MessagingRateLimiter worker shutdown error: {e}") + d = get_settings().log_messaging_error_details + logger.debug( + "MessagingRateLimiter worker shutdown error: {}", + format_exception_for_log(e, log_full_message=d), + ) finally: self._worker_task = None @@ -296,14 +269,29 @@ class MessagingRateLimiter: x in error_msg for x in ["connect", "timeout", "broken"] ): wait = 2**attempt - logger.warning( - f"Limiter fire_and_forget transient error (attempt {attempt + 1}): {e}. Retrying in {wait}s..." - ) + d = get_settings().log_messaging_error_details + if d: + logger.warning( + "Limiter fire_and_forget transient error (attempt {}): {}. Retrying in {}s...", + attempt + 1, + e, + wait, + ) + else: + logger.warning( + "Limiter fire_and_forget transient error (attempt {}): exc_type={}. Retrying in {}s...", + attempt + 1, + type(e).__name__, + wait, + ) await asyncio.sleep(wait) continue + d = get_settings().log_messaging_error_details logger.error( - f"Final error in fire_and_forget for key {dedup_key}: {type(e).__name__}: {e}" + "Final error in fire_and_forget for key {}: {}", + dedup_key, + format_exception_for_log(e, log_full_message=d), ) if not future.done(): future.set_exception(e) diff --git a/messaging/node_event_pipeline.py b/messaging/node_event_pipeline.py new file mode 100644 index 0000000..6f233e2 --- /dev/null +++ b/messaging/node_event_pipeline.py @@ -0,0 +1,103 @@ +"""CLI event handling for a single queued node (transcript + session + errors).""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import Any + +from loguru import logger + +from .cli_event_constants import TRANSCRIPT_EVENT_TYPES, get_status_for_event +from .platforms.base import SessionManagerInterface +from .safe_diagnostics import text_len_hint +from .session import SessionStore +from .transcript import TranscriptBuffer +from .trees.queue_manager import MessageState, MessageTree + + +async def handle_session_info_event( + event_data: dict[str, Any], + tree: MessageTree | None, + node_id: str, + captured_session_id: str | None, + temp_session_id: str | None, + *, + cli_manager: SessionManagerInterface, + session_store: SessionStore, +) -> tuple[str | None, str | None]: + """Handle session_info event; return updated (captured_session_id, temp_session_id).""" + if event_data.get("type") != "session_info": + return captured_session_id, temp_session_id + + real_session_id = event_data.get("session_id") + if not real_session_id or not temp_session_id: + return captured_session_id, temp_session_id + + await cli_manager.register_real_session_id(temp_session_id, real_session_id) + if tree and real_session_id: + await tree.update_state( + node_id, + MessageState.IN_PROGRESS, + session_id=real_session_id, + ) + session_store.save_tree(tree.root_id, tree.to_dict()) + + return real_session_id, None + + +async def process_parsed_cli_event( + parsed: dict[str, Any], + transcript: TranscriptBuffer, + update_ui: Callable[..., Awaitable[None]], + last_status: str | None, + had_transcript_events: bool, + tree: MessageTree | None, + node_id: str, + captured_session_id: str | None, + *, + session_store: SessionStore, + format_status: Callable[..., str], + propagate_error_to_children: Callable[[str, str, str], Awaitable[None]], + log_messaging_error_details: bool = False, +) -> tuple[str | None, bool]: + """Process a single parsed CLI event. Returns (last_status, had_transcript_events).""" + ptype = parsed.get("type") or "" + + if ptype in TRANSCRIPT_EVENT_TYPES: + transcript.apply(parsed) + had_transcript_events = True + + status = get_status_for_event(ptype, parsed, format_status) + if status is not None: + await update_ui(status) + last_status = status + elif ptype == "block_stop": + await update_ui(last_status, force=True) + elif ptype == "complete": + if not had_transcript_events: + transcript.apply({"type": "text_chunk", "text": "Done."}) + logger.info("HANDLER: Task complete, updating UI") + await update_ui(format_status("✅", "Complete"), force=True) + if tree and captured_session_id: + await tree.update_state( + node_id, + MessageState.COMPLETED, + session_id=captured_session_id, + ) + session_store.save_tree(tree.root_id, tree.to_dict()) + elif ptype == "error": + error_msg = parsed.get("message", "Unknown error") + if log_messaging_error_details: + logger.error("HANDLER: Error event received: {}", error_msg) + else: + em = error_msg if isinstance(error_msg, str) else str(error_msg) + logger.error( + "HANDLER: Error event received: message_chars={}", + text_len_hint(em), + ) + logger.info("HANDLER: Updating UI with error status") + await update_ui(format_status("❌", "Error"), force=True) + if tree: + await propagate_error_to_children(node_id, error_msg, "Parent task failed") + + return last_status, had_transcript_events diff --git a/messaging/platforms/discord.py b/messaging/platforms/discord.py index 201682b..53aa7dc 100644 --- a/messaging/platforms/discord.py +++ b/messaging/platforms/discord.py @@ -6,7 +6,6 @@ Implements MessagingPlatform for Discord using discord.py. import asyncio import contextlib -import os import tempfile from collections.abc import Awaitable, Callable from pathlib import Path @@ -14,7 +13,7 @@ from typing import Any, cast from loguru import logger -from core.anthropic import get_user_facing_error_message +from core.anthropic import format_user_error_preview from ..models import IncomingMessage from ..rendering.discord_markdown import format_status_discord @@ -97,15 +96,18 @@ class DiscordPlatform(MessagingPlatform): whisper_device: str = "cpu", hf_token: str = "", nvidia_nim_api_key: str = "", + messaging_rate_limit: int = 1, + messaging_rate_window: float = 1.0, + log_raw_messaging_content: bool = False, + log_api_error_tracebacks: bool = False, ): if not DISCORD_AVAILABLE: raise ImportError( "discord.py is required. Install with: pip install discord.py" ) - self.bot_token = bot_token or os.getenv("DISCORD_BOT_TOKEN") - raw_channels = allowed_channel_ids or os.getenv("ALLOWED_DISCORD_CHANNELS") - self.allowed_channel_ids = _parse_allowed_channels(raw_channels) + self.bot_token = bot_token + self.allowed_channel_ids = _parse_allowed_channels(allowed_channel_ids) if not self.bot_token: logger.warning("DISCORD_BOT_TOKEN not set") @@ -130,6 +132,10 @@ class DiscordPlatform(MessagingPlatform): self._voice_note_enabled = voice_note_enabled self._whisper_model = whisper_model self._whisper_device = whisper_device + self._messaging_rate_limit = messaging_rate_limit + self._messaging_rate_window = messaging_rate_window + self._log_raw_messaging_content = log_raw_messaging_content + self._log_api_error_tracebacks = log_api_error_tracebacks async def _handle_client_message(self, message: Any) -> None: """Adapter entry point used by the internal discord client.""" @@ -234,23 +240,40 @@ class DiscordPlatform(MessagingPlatform): status_message_id=status_msg_id, ) - logger.info( - "DISCORD_VOICE: chat_id={} message_id={} transcribed={!r}", - channel_id, - message_id, - (transcribed[:80] + "..." if len(transcribed) > 80 else transcribed), - ) + if self._log_raw_messaging_content: + logger.info( + "DISCORD_VOICE: chat_id={} message_id={} transcribed={!r}", + channel_id, + message_id, + ( + transcribed[:80] + "..." + if len(transcribed) > 80 + else transcribed + ), + ) + else: + logger.info( + "DISCORD_VOICE: chat_id={} message_id={} transcribed_len={}", + channel_id, + message_id, + len(transcribed), + ) await self._message_handler(incoming) return True except ValueError as e: - await message.reply(get_user_facing_error_message(e)[:200]) + await message.reply(format_user_error_preview(e)) return True except ImportError as e: - await message.reply(get_user_facing_error_message(e)[:200]) + await message.reply(format_user_error_preview(e)) return True except Exception as e: - logger.error(f"Voice transcription failed: {e}") + if self._log_api_error_tracebacks: + logger.error("Voice transcription failed: {}", e) + else: + logger.error( + "Voice transcription failed: exc_type={}", type(e).__name__ + ) await message.reply( "Could not transcribe voice note. Please try again or send text." ) @@ -285,16 +308,26 @@ class DiscordPlatform(MessagingPlatform): else None ) - text_preview = (message.content or "")[:80] - if len(message.content or "") > 80: - text_preview += "..." - logger.info( - "DISCORD_MSG: chat_id={} message_id={} reply_to={} text_preview={!r}", - channel_id, - message_id, - reply_to, - text_preview, - ) + raw_content = message.content or "" + if self._log_raw_messaging_content: + text_preview = raw_content[:80] + if len(raw_content) > 80: + text_preview += "..." + logger.info( + "DISCORD_MSG: chat_id={} message_id={} reply_to={} text_preview={!r}", + channel_id, + message_id, + reply_to, + text_preview, + ) + else: + logger.info( + "DISCORD_MSG: chat_id={} message_id={} reply_to={} text_len={}", + channel_id, + message_id, + reply_to, + len(raw_content), + ) if not self._message_handler: return @@ -313,13 +346,14 @@ class DiscordPlatform(MessagingPlatform): try: await self._message_handler(incoming) except Exception as e: - logger.error(f"Error handling message: {e}") + if self._log_api_error_tracebacks: + logger.error("Error handling message: {}", e) + else: + logger.error("Error handling message: exc_type={}", type(e).__name__) with contextlib.suppress(Exception): await self.send_message( channel_id, - format_status_discord( - "Error:", get_user_facing_error_message(e)[:200] - ), + format_status_discord("Error:", format_user_error_preview(e)), reply_to=message_id, ) @@ -336,7 +370,10 @@ class DiscordPlatform(MessagingPlatform): from ..limiter import MessagingRateLimiter - self._limiter = await MessagingRateLimiter.get_instance() + self._limiter = await MessagingRateLimiter.get_instance( + rate_limit=self._messaging_rate_limit, + rate_window=self._messaging_rate_window, + ) self._start_task = asyncio.create_task( self._client.start(self.bot_token), diff --git a/messaging/platforms/factory.py b/messaging/platforms/factory.py index b40fe87..772a31d 100644 --- a/messaging/platforms/factory.py +++ b/messaging/platforms/factory.py @@ -28,6 +28,10 @@ class MessagingPlatformOptions: whisper_device: str = "cpu" hf_token: str = "" nvidia_nim_api_key: str = "" + messaging_rate_limit: int = 1 + messaging_rate_window: float = 1.0 + log_raw_messaging_content: bool = False + log_api_error_tracebacks: bool = False def create_messaging_platform( @@ -64,6 +68,10 @@ def create_messaging_platform( whisper_device=opts.whisper_device, hf_token=opts.hf_token, nvidia_nim_api_key=opts.nvidia_nim_api_key, + messaging_rate_limit=opts.messaging_rate_limit, + messaging_rate_window=opts.messaging_rate_window, + log_raw_messaging_content=opts.log_raw_messaging_content, + log_api_error_tracebacks=opts.log_api_error_tracebacks, ) if platform_type == "discord": @@ -82,6 +90,10 @@ def create_messaging_platform( whisper_device=opts.whisper_device, hf_token=opts.hf_token, nvidia_nim_api_key=opts.nvidia_nim_api_key, + messaging_rate_limit=opts.messaging_rate_limit, + messaging_rate_window=opts.messaging_rate_window, + log_raw_messaging_content=opts.log_raw_messaging_content, + log_api_error_tracebacks=opts.log_api_error_tracebacks, ) logger.warning( diff --git a/messaging/platforms/telegram.py b/messaging/platforms/telegram.py index d131e44..a8f34df 100644 --- a/messaging/platforms/telegram.py +++ b/messaging/platforms/telegram.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any from loguru import logger -from core.anthropic import get_user_facing_error_message +from core.anthropic import format_user_error_preview if TYPE_CHECKING: from telegram import Update @@ -68,14 +68,18 @@ class TelegramPlatform(MessagingPlatform): whisper_device: str = "cpu", hf_token: str = "", nvidia_nim_api_key: str = "", + messaging_rate_limit: int = 1, + messaging_rate_window: float = 1.0, + log_raw_messaging_content: bool = False, + log_api_error_tracebacks: bool = False, ): if not TELEGRAM_AVAILABLE: raise ImportError( "python-telegram-bot is required. Install with: pip install python-telegram-bot" ) - self.bot_token = bot_token or os.getenv("TELEGRAM_BOT_TOKEN") - self.allowed_user_id = allowed_user_id or os.getenv("ALLOWED_TELEGRAM_USER_ID") + self.bot_token = bot_token + self.allowed_user_id = allowed_user_id if not self.bot_token: # We don't raise here to allow instantiation for testing/conditional logic, @@ -97,6 +101,10 @@ class TelegramPlatform(MessagingPlatform): self._voice_note_enabled = voice_note_enabled self._whisper_model = whisper_model self._whisper_device = whisper_device + self._messaging_rate_limit = messaging_rate_limit + self._messaging_rate_window = messaging_rate_window + self._log_raw_messaging_content = log_raw_messaging_content + self._log_api_error_tracebacks = log_api_error_tracebacks async def _register_pending_voice( self, chat_id: str, voice_msg_id: str, status_msg_id: str @@ -172,7 +180,10 @@ class TelegramPlatform(MessagingPlatform): # Initialize rate limiter from ..limiter import MessagingRateLimiter - self._limiter = await MessagingRateLimiter.get_instance() + self._limiter = await MessagingRateLimiter.get_instance( + rate_limit=self._messaging_rate_limit, + rate_window=self._messaging_rate_window, + ) # Send startup notification try: @@ -187,7 +198,13 @@ class TelegramPlatform(MessagingPlatform): startup_text, ) except Exception as e: - logger.warning(f"Could not send startup message: {e}") + if self._log_api_error_tracebacks: + logger.warning("Could not send startup message: {}", e) + else: + logger.warning( + "Could not send startup message: exc_type={}", + type(e).__name__, + ) logger.info("Telegram platform started (Bot API)") @@ -506,16 +523,26 @@ class TelegramPlatform(MessagingPlatform): if getattr(update.message, "message_thread_id", None) is not None else None ) - text_preview = (update.message.text or "")[:80] - if len(update.message.text or "") > 80: - text_preview += "..." - logger.info( - "TELEGRAM_MSG: chat_id={} message_id={} reply_to={} text_preview={!r}", - chat_id, - message_id, - reply_to, - text_preview, - ) + raw_text = update.message.text or "" + if self._log_raw_messaging_content: + text_preview = raw_text[:80] + if len(raw_text) > 80: + text_preview += "..." + logger.info( + "TELEGRAM_MSG: chat_id={} message_id={} reply_to={} text_preview={!r}", + chat_id, + message_id, + reply_to, + text_preview, + ) + else: + logger.info( + "TELEGRAM_MSG: chat_id={} message_id={} reply_to={} text_len={}", + chat_id, + message_id, + reply_to, + len(raw_text), + ) if not self._message_handler: return @@ -534,11 +561,14 @@ class TelegramPlatform(MessagingPlatform): try: await self._message_handler(incoming) except Exception as e: - logger.error(f"Error handling message: {e}") + if self._log_api_error_tracebacks: + logger.error("Error handling message: {}", e) + else: + logger.error("Error handling message: exc_type={}", type(e).__name__) with contextlib.suppress(Exception): await self.send_message( chat_id, - f"❌ *{escape_md_v2('Error:')}* {escape_md_v2(get_user_facing_error_message(e)[:200])}", + f"❌ *{escape_md_v2('Error:')}* {escape_md_v2(format_user_error_preview(e))}", reply_to=incoming.message_id, message_thread_id=thread_id, parse_mode="MarkdownV2", @@ -631,20 +661,37 @@ class TelegramPlatform(MessagingPlatform): status_message_id=status_msg_id, ) - logger.info( - "TELEGRAM_VOICE: chat_id={} message_id={} transcribed={!r}", - chat_id, - message_id, - (transcribed[:80] + "..." if len(transcribed) > 80 else transcribed), - ) + if self._log_raw_messaging_content: + logger.info( + "TELEGRAM_VOICE: chat_id={} message_id={} transcribed={!r}", + chat_id, + message_id, + ( + transcribed[:80] + "..." + if len(transcribed) > 80 + else transcribed + ), + ) + else: + logger.info( + "TELEGRAM_VOICE: chat_id={} message_id={} transcribed_len={}", + chat_id, + message_id, + len(transcribed), + ) await self._message_handler(incoming) except ValueError as e: - await update.message.reply_text(get_user_facing_error_message(e)[:200]) + await update.message.reply_text(format_user_error_preview(e)) except ImportError as e: - await update.message.reply_text(get_user_facing_error_message(e)[:200]) + await update.message.reply_text(format_user_error_preview(e)) except Exception as e: - logger.error(f"Voice transcription failed: {e}") + if self._log_api_error_tracebacks: + logger.error("Voice transcription failed: {}", e) + else: + logger.error( + "Voice transcription failed: exc_type={}", type(e).__name__ + ) await update.message.reply_text( "Could not transcribe voice note. Please try again or send text." ) diff --git a/messaging/safe_diagnostics.py b/messaging/safe_diagnostics.py new file mode 100644 index 0000000..add30c5 --- /dev/null +++ b/messaging/safe_diagnostics.py @@ -0,0 +1,17 @@ +"""Helpers for redacting user-derived content from log lines.""" + +from __future__ import annotations + + +def format_exception_for_log(exc: BaseException, *, log_full_message: bool) -> str: + """Return exception type and optionally ``str(exc)`` for operator diagnostics.""" + if log_full_message: + return f"{type(exc).__name__}: {exc}" + return type(exc).__name__ + + +def text_len_hint(text: str | None) -> int: + """Length of text for metadata-only logging (0 when missing).""" + if not text: + return 0 + return len(text) diff --git a/messaging/session.py b/messaging/session.py index a4e4e9d..bbb0cdc 100644 --- a/messaging/session.py +++ b/messaging/session.py @@ -5,8 +5,10 @@ Provides persistent storage for mapping platform messages to Claude CLI session and message trees for conversation continuation. """ +import contextlib import json import os +import tempfile import threading from datetime import UTC, datetime from typing import Any @@ -22,7 +24,12 @@ class SessionStore: Platform-agnostic: works with any messaging platform. """ - def __init__(self, storage_path: str = "sessions.json"): + def __init__( + self, + storage_path: str = "sessions.json", + *, + message_log_cap: int | None = None, + ): self.storage_path = storage_path self._lock = threading.Lock() self._trees: dict[str, dict] = {} # root_id -> tree data @@ -34,11 +41,7 @@ class SessionStore: self._dirty = False self._save_timer: threading.Timer | None = None self._save_debounce_secs = 0.5 - cap_raw = os.getenv("MAX_MESSAGE_LOG_ENTRIES_PER_CHAT", "").strip() - try: - self._message_log_cap: int | None = int(cap_raw) if cap_raw else None - except ValueError: - self._message_log_cap = None + self._message_log_cap: int | None = message_log_cap self._load() def _make_chat_key(self, platform: str, chat_id: str) -> str: @@ -104,9 +107,22 @@ class SessionStore: } def _write_data(self, data: dict) -> None: - """Write data dict to disk. Must be called WITHOUT holding self._lock.""" - with open(self.storage_path, "w", encoding="utf-8") as f: - json.dump(data, f, indent=2) + """Atomically write data dict to disk. Must be called WITHOUT holding self._lock.""" + abs_target = os.path.abspath(self.storage_path) + dir_name = os.path.dirname(abs_target) or "." + fd, tmp_path = tempfile.mkstemp( + dir=dir_name, prefix=".sessions.", suffix=".tmp.json" + ) + try: + with os.fdopen(fd, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp_path, abs_target) + except BaseException: + with contextlib.suppress(OSError): + os.unlink(tmp_path) + raise def _schedule_save(self) -> None: """Schedule a debounced save. Caller must hold self._lock.""" diff --git a/messaging/transcript.py b/messaging/transcript.py index 073f40b..f1a015d 100644 --- a/messaging/transcript.py +++ b/messaging/transcript.py @@ -9,7 +9,6 @@ the transcript grows over time and older content must be truncated. from __future__ import annotations import json -import os from abc import ABC, abstractmethod from collections import deque from collections.abc import Callable, Iterable @@ -210,7 +209,12 @@ class RenderCtx: class TranscriptBuffer: """Maintains an ordered, truncatable transcript of events.""" - def __init__(self, *, show_tool_results: bool = True) -> None: + def __init__( + self, + *, + show_tool_results: bool = True, + debug_subagent_stack: bool = False, + ) -> None: self._segments: list[Segment] = [] self._open_thinking_by_index: dict[int, ThinkingSegment] = {} self._open_text_by_index: dict[int, TextSegment] = {} @@ -227,7 +231,7 @@ class TranscriptBuffer: self._subagent_stack: list[str] = [] # Parallel stack of segments for rendering nested subagents. self._subagent_segments: list[SubagentSegment] = [] - self._debug_subagent_stack = os.getenv("DEBUG_SUBAGENT_STACK") == "1" + self._debug_subagent_stack = debug_subagent_stack def _in_subagent(self) -> bool: return bool(self._subagent_stack) diff --git a/messaging/transcription.py b/messaging/transcription.py index 389a9be..9a5f01f 100644 --- a/messaging/transcription.py +++ b/messaging/transcription.py @@ -5,32 +5,18 @@ Supports: - NVIDIA NIM: NVIDIA NIM Whisper/Parakeet """ -import os from pathlib import Path from typing import Any from loguru import logger -from config.settings import get_settings +from providers.nvidia_nim.voice import ( + transcribe_audio_file as transcribe_nvidia_nim_audio, +) # Max file size in bytes (25 MB) MAX_AUDIO_SIZE_BYTES = 25 * 1024 * 1024 -# NVIDIA NIM Whisper model mapping: (function_id, language_code) -_NIM_MODEL_MAP: dict[str, tuple[str, str]] = { - "nvidia/parakeet-ctc-0.6b-zh-tw": ("8473f56d-51ef-473c-bb26-efd4f5def2bf", "zh-TW"), - "nvidia/parakeet-ctc-0.6b-zh-cn": ("9add5ef7-322e-47e0-ad7a-5653fb8d259b", "zh-CN"), - "nvidia/parakeet-ctc-0.6b-es": ("None", "es-US"), - "nvidia/parakeet-ctc-0.6b-vi": ("f3dff2bb-99f9-403d-a5f1-f574a757deb0", "vi-VN"), - "nvidia/parakeet-ctc-1.1b-asr": ("1598d209-5e27-4d3c-8079-4751568b1081", "en-US"), - "nvidia/parakeet-ctc-0.6b-asr": ("d8dd4e9b-fbf5-4fb0-9dba-8cf436c8d965", "en-US"), - "nvidia/parakeet-1.1b-rnnt-multilingual-asr": ( - "71203149-d3b7-4460-8231-1be2543a1fca", - "", - ), - "openai/whisper-large-v3": ("b702f636-f60c-4a3d-a6f4-f3568c13bd7d", "multi"), -} - # Short model names -> full Hugging Face model IDs (for local Whisper) _MODEL_MAP: dict[str, str] = { "tiny": "openai/whisper-tiny", @@ -51,22 +37,19 @@ def _resolve_model_id(whisper_model: str) -> str: return _MODEL_MAP.get(whisper_model, whisper_model) -def _get_pipeline(model_id: str, device: str, hf_token: str | None = None) -> Any: +def _get_pipeline(model_id: str, device: str, hf_token: str = "") -> Any: """Lazy-load transformers Whisper pipeline. Raises ImportError if not installed.""" global _pipeline_cache if device not in ("cpu", "cuda"): raise ValueError(f"whisper_device must be 'cpu' or 'cuda', got {device!r}") - resolved_token = ( - hf_token if hf_token is not None else get_settings().hf_token - ) or "" + resolved_token = hf_token or "" cache_key = (model_id, device, resolved_token) if cache_key not in _pipeline_cache: try: import torch from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline - if resolved_token: - os.environ["HF_TOKEN"] = resolved_token + hf_auth_token = resolved_token or None use_cuda = device == "cuda" and torch.cuda.is_available() pipe_device = "cuda:0" if use_cuda else "cpu" @@ -77,9 +60,10 @@ def _get_pipeline(model_id: str, device: str, hf_token: str | None = None) -> An dtype=model_dtype, low_cpu_mem_usage=True, attn_implementation="sdpa", + token=hf_auth_token, ) model = model.to(pipe_device) - processor = AutoProcessor.from_pretrained(model_id) + processor = AutoProcessor.from_pretrained(model_id, token=hf_auth_token) pipe = pipeline( "automatic-speech-recognition", @@ -119,7 +103,7 @@ def transcribe_audio( file_path: Path to audio file (OGG, MP3, MP4, WAV, M4A supported) mime_type: MIME type of the audio (e.g. "audio/ogg") whisper_model: Model ID or short name (local) or NVIDIA NIM model - whisper_device: "cpu" | "cuda" | "nvidia_nim" (defaults to WHISPER_DEVICE env var) + whisper_device: "cpu" | "cuda" | "nvidia_nim" Returns: Transcribed text @@ -140,8 +124,8 @@ def transcribe_audio( ) if whisper_device == "nvidia_nim": - return _transcribe_nim( - file_path, whisper_model, nvidia_nim_api_key=nvidia_nim_api_key + return transcribe_nvidia_nim_audio( + file_path, whisper_model, api_key=nvidia_nim_api_key ) return _transcribe_local( file_path, whisper_model, whisper_device, hf_token=hf_token @@ -169,8 +153,7 @@ def _transcribe_local( ) -> str: """Transcribe using transformers Whisper pipeline.""" model_id = _resolve_model_id(whisper_model) - token: str | None = hf_token if hf_token else None - pipe = _get_pipeline(model_id, whisper_device, hf_token=token) + pipe = _get_pipeline(model_id, whisper_device, hf_token=hf_token) audio = _load_audio(file_path) result = pipe(audio, generate_kwargs={"language": "en", "task": "transcribe"}) text = result.get("text", "") or "" @@ -179,65 +162,3 @@ def _transcribe_local( result_text = text.strip() logger.debug(f"Local transcription: {len(result_text)} chars") return result_text or "(no speech detected)" - - -def _transcribe_nim( - file_path: Path, model: str, *, nvidia_nim_api_key: str = "" -) -> str: - """Transcribe using NVIDIA NIM Whisper API via Riva gRPC client.""" - try: - import riva.client - except ImportError as e: - raise ImportError( - "NVIDIA NIM transcription requires the voice extra. " - "Install with: uv sync --extra voice" - ) from e - - api_key = nvidia_nim_api_key or get_settings().nvidia_nim_api_key - - # Look up function ID and language code from model mapping - model_config = _NIM_MODEL_MAP.get(model) - if not model_config: - raise ValueError( - f"No NVIDIA NIM config found for model: {model}. " - f"Supported models: {', '.join(_NIM_MODEL_MAP.keys())}" - ) - function_id, language_code = model_config - - # Riva server configuration - server = "grpc.nvcf.nvidia.com:443" - - # Auth with SSL and metadata - auth = riva.client.Auth( - use_ssl=True, - uri=server, - metadata_args=[ - ["function-id", function_id], - ["authorization", f"Bearer {api_key}"], - ], - ) - - asr_service = riva.client.ASRService(auth) - - # Configure recognition - language_code from model config - config = riva.client.RecognitionConfig( - language_code=language_code, - max_alternatives=1, - verbatim_transcripts=True, - ) - - # Read audio file - with open(file_path, "rb") as f: - data = f.read() - - # Perform offline recognition - response = asr_service.offline_recognize(data, config) - - # Extract text from response - use getattr for safe attribute access - transcript = "" - results = getattr(response, "results", None) - if results and results[0].alternatives: - transcript = results[0].alternatives[0].transcript - - logger.debug(f"NIM transcription: {len(transcript)} chars") - return transcript or "(no speech detected)" diff --git a/messaging/trees/processor.py b/messaging/trees/processor.py deleted file mode 100644 index 5fdcca6..0000000 --- a/messaging/trees/processor.py +++ /dev/null @@ -1,165 +0,0 @@ -"""Async queue processor for message trees. - -Handles the async processing lifecycle of tree nodes. -""" - -import asyncio -from collections.abc import Awaitable, Callable - -from loguru import logger - -from core.anthropic import get_user_facing_error_message - -from .data import MessageNode, MessageState, MessageTree - - -class TreeQueueProcessor: - """ - Handles async queue processing for a single tree. - - Separates the async processing logic from the data management. - """ - - def __init__( - self, - queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None = None, - node_started_callback: Callable[[MessageTree, str], Awaitable[None]] - | None = None, - ): - self._queue_update_callback = queue_update_callback - self._node_started_callback = node_started_callback - - def set_queue_update_callback( - self, - queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None, - ) -> None: - """Update the callback used to refresh queue positions.""" - self._queue_update_callback = queue_update_callback - - def set_node_started_callback( - self, - node_started_callback: Callable[[MessageTree, str], Awaitable[None]] | None, - ) -> None: - """Update the callback used when a queued node starts processing.""" - self._node_started_callback = node_started_callback - - async def _notify_queue_updated(self, tree: MessageTree) -> None: - """Invoke queue update callback if set.""" - if not self._queue_update_callback: - return - try: - await self._queue_update_callback(tree) - except Exception as e: - logger.warning(f"Queue update callback failed: {e}") - - async def _notify_node_started(self, tree: MessageTree, node_id: str) -> None: - """Invoke node started callback if set.""" - if not self._node_started_callback: - return - try: - await self._node_started_callback(tree, node_id) - except Exception as e: - logger.warning(f"Node started callback failed: {e}") - - async def process_node( - self, - tree: MessageTree, - node: MessageNode, - processor: Callable[[str, MessageNode], Awaitable[None]], - ) -> None: - """Process a single node and then check the queue.""" - # Skip if already in terminal state (e.g. from error propagation) - if node.state == MessageState.ERROR: - logger.info( - f"Skipping node {node.node_id} as it is already in state {node.state}" - ) - # Still need to check for next messages - await self._process_next(tree, processor) - return - - try: - await processor(node.node_id, node) - except asyncio.CancelledError: - logger.info(f"Task for node {node.node_id} was cancelled") - raise - except Exception as e: - logger.error(f"Error processing node {node.node_id}: {e}") - await tree.update_state( - node.node_id, - MessageState.ERROR, - error_message=get_user_facing_error_message(e), - ) - finally: - async with tree.with_lock(): - tree.clear_current_node() - # Check if there are more messages in the queue - await self._process_next(tree, processor) - - async def _process_next( - self, - tree: MessageTree, - processor: Callable[[str, MessageNode], Awaitable[None]], - ) -> None: - """Process the next message in queue, if any.""" - next_node_id = None - node = None - async with tree.with_lock(): - next_node_id = await tree.dequeue() - - if not next_node_id: - tree.set_processing_state(None, False) - logger.debug(f"Tree {tree.root_id} queue empty, marking as free") - return - - tree.set_processing_state(next_node_id, True) - logger.info(f"Processing next queued node {next_node_id}") - - # Process next node (outside lock) - node = tree.get_node(next_node_id) - if node: - tree.set_current_task( - asyncio.create_task(self.process_node(tree, node, processor)) - ) - - # Notify that this node has started processing and refresh queue positions. - if next_node_id: - await self._notify_node_started(tree, next_node_id) - await self._notify_queue_updated(tree) - - async def enqueue_and_start( - self, - tree: MessageTree, - node_id: str, - processor: Callable[[str, MessageNode], Awaitable[None]], - ) -> bool: - """ - Enqueue a node or start processing immediately. - - Args: - tree: The message tree - node_id: Node to process - processor: Async function to process the node - - Returns: - True if queued, False if processing immediately - """ - async with tree.with_lock(): - if tree.is_processing: - tree.put_queue_unlocked(node_id) - queue_size = tree.get_queue_size() - logger.info(f"Queued node {node_id}, position {queue_size}") - return True - else: - tree.set_processing_state(node_id, True) - - # Process outside the lock - node = tree.get_node(node_id) - if node: - tree.set_current_task( - asyncio.create_task(self.process_node(tree, node, processor)) - ) - return False - - def cancel_current(self, tree: MessageTree) -> bool: - """Cancel the currently running task in a tree.""" - return tree.cancel_current_task() diff --git a/messaging/trees/queue_manager.py b/messaging/trees/queue_manager.py index fbf7029..2c0cecd 100644 --- a/messaging/trees/queue_manager.py +++ b/messaging/trees/queue_manager.py @@ -1,30 +1,345 @@ -"""Tree-Based Message Queue Manager - Refactored. - -Coordinates data access, async processing, and error handling. -Uses TreeRepository for data, TreeQueueProcessor for async logic. -""" +"""Tree-based message queue: index, async node processor, and public manager API.""" import asyncio from collections.abc import Awaitable, Callable from loguru import logger +from config.settings import get_settings +from core.anthropic import get_user_facing_error_message + from ..models import IncomingMessage +from ..safe_diagnostics import format_exception_for_log from .data import MessageNode, MessageState, MessageTree -from .processor import TreeQueueProcessor -from .repository import TreeRepository + + +class TreeRepository: + """ + In-memory index of trees and node-to-root mappings. + + Used only by :class:`TreeQueueManager`; kept as a named type for tests. + """ + + def __init__(self) -> None: + self._trees: dict[str, MessageTree] = {} # root_id -> tree + self._node_to_tree: dict[str, str] = {} # node_id -> root_id + + def get_tree(self, root_id: str) -> MessageTree | None: + """Get a tree by its root ID.""" + return self._trees.get(root_id) + + def get_tree_for_node(self, node_id: str) -> MessageTree | None: + """Get the tree containing a given node.""" + root_id = self._node_to_tree.get(node_id) + if not root_id: + return None + return self._trees.get(root_id) + + def get_node(self, node_id: str) -> MessageNode | None: + """Get a node from any tree.""" + tree = self.get_tree_for_node(node_id) + return tree.get_node(node_id) if tree else None + + def add_tree(self, root_id: str, tree: MessageTree) -> None: + """Add a new tree to the repository.""" + self._trees[root_id] = tree + self._node_to_tree[root_id] = root_id + logger.debug("TREE_REPO: add_tree root_id={}", root_id) + + def register_node(self, node_id: str, root_id: str) -> None: + """Register a node ID to a tree.""" + self._node_to_tree[node_id] = root_id + logger.debug("TREE_REPO: register_node node_id={} root_id={}", node_id, root_id) + + def has_node(self, node_id: str) -> bool: + """Check if a node is registered in any tree.""" + return node_id in self._node_to_tree + + def tree_count(self) -> int: + """Get the number of trees in the repository.""" + return len(self._trees) + + def is_tree_busy(self, root_id: str) -> bool: + """Check if a tree is currently processing.""" + tree = self._trees.get(root_id) + return tree.is_processing if tree else False + + def is_node_tree_busy(self, node_id: str) -> bool: + """Check if the tree containing a node is busy.""" + tree = self.get_tree_for_node(node_id) + return tree.is_processing if tree else False + + def get_queue_size(self, node_id: str) -> int: + """Get queue size for the tree containing a node.""" + tree = self.get_tree_for_node(node_id) + return tree.get_queue_size() if tree else 0 + + def resolve_parent_node_id(self, msg_id: str) -> str | None: + """ + Resolve a message ID to the actual parent node ID. + + Handles the case where msg_id is a status message ID + (which maps to the tree but isn't an actual node). + + Returns: + The node_id to use as parent, or None if not found + """ + tree = self.get_tree_for_node(msg_id) + if not tree: + return None + + if tree.has_node(msg_id): + return msg_id + + node = tree.find_node_by_status_message(msg_id) + if node: + return node.node_id + + return None + + def get_pending_children(self, node_id: str) -> list[MessageNode]: + """ + Get all pending child nodes (recursively) of a given node. + + Used for error propagation - when a node fails, its pending + children should also be marked as failed. + """ + tree = self.get_tree_for_node(node_id) + if not tree: + return [] + + pending: list[MessageNode] = [] + stack = [node_id] + + while stack: + current_id = stack.pop() + node = tree.get_node(current_id) + if not node: + continue + for child_id in node.children_ids: + child = tree.get_node(child_id) + if child and child.state == MessageState.PENDING: + pending.append(child) + stack.append(child_id) + + return pending + + def all_trees(self) -> list[MessageTree]: + """Get all trees in the repository.""" + return list(self._trees.values()) + + def tree_ids(self) -> list[str]: + """Get all tree root IDs.""" + return list(self._trees.keys()) + + def unregister_nodes(self, node_ids: list[str]) -> None: + """Remove node IDs from the node-to-tree mapping.""" + for nid in node_ids: + self._node_to_tree.pop(nid, None) + + def remove_tree(self, root_id: str) -> MessageTree | None: + """ + Remove a tree and all its node mappings from the repository. + + Returns: + The removed tree, or None if not found. + """ + tree = self._trees.pop(root_id, None) + if not tree: + return None + for node in tree.all_nodes(): + self._node_to_tree.pop(node.node_id, None) + logger.debug("TREE_REPO: remove_tree root_id={}", root_id) + return tree + + def get_message_ids_for_chat(self, platform: str, chat_id: str) -> set[str]: + """Get all message IDs (incoming + status) for a given platform/chat.""" + msg_ids: set[str] = set() + for tree in self._trees.values(): + for node in tree.all_nodes(): + if str(node.incoming.platform) == str(platform) and str( + node.incoming.chat_id + ) == str(chat_id): + if node.incoming.message_id is not None: + msg_ids.add(str(node.incoming.message_id)) + if node.status_message_id: + msg_ids.add(str(node.status_message_id)) + return msg_ids + + def to_dict(self) -> dict: + """Serialize all trees.""" + return { + "trees": {rid: tree.to_dict() for rid, tree in self._trees.items()}, + "node_to_tree": self._node_to_tree.copy(), + } + + @classmethod + def from_dict(cls, data: dict) -> TreeRepository: + """Deserialize from dictionary.""" + repo = cls() + for root_id, tree_data in data.get("trees", {}).items(): + repo._trees[root_id] = MessageTree.from_dict(tree_data) + repo._node_to_tree = data.get("node_to_tree", {}) + return repo + + +class TreeQueueProcessor: + """ + Per-tree async queue processing (one manager owns one processor instance). + """ + + def __init__( + self, + queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None = None, + node_started_callback: Callable[[MessageTree, str], Awaitable[None]] + | None = None, + ) -> None: + self._queue_update_callback = queue_update_callback + self._node_started_callback = node_started_callback + + def set_queue_update_callback( + self, + queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None, + ) -> None: + """Update the callback used to refresh queue positions.""" + self._queue_update_callback = queue_update_callback + + def set_node_started_callback( + self, + node_started_callback: Callable[[MessageTree, str], Awaitable[None]] | None, + ) -> None: + """Update the callback used when a queued node starts processing.""" + self._node_started_callback = node_started_callback + + async def _notify_queue_updated(self, tree: MessageTree) -> None: + """Invoke queue update callback if set.""" + if not self._queue_update_callback: + return + try: + await self._queue_update_callback(tree) + except Exception as e: + d = get_settings().log_messaging_error_details + logger.warning( + "Queue update callback failed: {}", + format_exception_for_log(e, log_full_message=d), + ) + + async def _notify_node_started(self, tree: MessageTree, node_id: str) -> None: + """Invoke node started callback if set.""" + if not self._node_started_callback: + return + try: + await self._node_started_callback(tree, node_id) + except Exception as e: + d = get_settings().log_messaging_error_details + logger.warning( + "Node started callback failed: {}", + format_exception_for_log(e, log_full_message=d), + ) + + async def process_node( + self, + tree: MessageTree, + node: MessageNode, + processor: Callable[[str, MessageNode], Awaitable[None]], + ) -> None: + """Process a single node and then check the queue.""" + if node.state == MessageState.ERROR: + logger.info( + f"Skipping node {node.node_id} as it is already in state {node.state}" + ) + await self._process_next(tree, processor) + return + + try: + await processor(node.node_id, node) + except asyncio.CancelledError: + logger.info(f"Task for node {node.node_id} was cancelled") + raise + except Exception as e: + d = get_settings().log_messaging_error_details + logger.error( + "Error processing node {}: {}", + node.node_id, + format_exception_for_log(e, log_full_message=d), + ) + await tree.update_state( + node.node_id, + MessageState.ERROR, + error_message=get_user_facing_error_message(e), + ) + finally: + async with tree.with_lock(): + tree.clear_current_node() + await self._process_next(tree, processor) + + async def _process_next( + self, + tree: MessageTree, + processor: Callable[[str, MessageNode], Awaitable[None]], + ) -> None: + """Process the next message in queue, if any.""" + next_node_id = None + async with tree.with_lock(): + next_node_id = await tree.dequeue() + + if not next_node_id: + tree.set_processing_state(None, False) + logger.debug(f"Tree {tree.root_id} queue empty, marking as free") + return + + tree.set_processing_state(next_node_id, True) + logger.info(f"Processing next queued node {next_node_id}") + + node = tree.get_node(next_node_id) + if node: + tree.set_current_task( + asyncio.create_task(self.process_node(tree, node, processor)) + ) + + if next_node_id: + await self._notify_node_started(tree, next_node_id) + await self._notify_queue_updated(tree) + + async def enqueue_and_start( + self, + tree: MessageTree, + node_id: str, + processor: Callable[[str, MessageNode], Awaitable[None]], + ) -> bool: + """ + Enqueue a node or start processing immediately. + + Returns: + True if queued, False if processing immediately + """ + async with tree.with_lock(): + if tree.is_processing: + tree.put_queue_unlocked(node_id) + queue_size = tree.get_queue_size() + logger.info(f"Queued node {node_id}, position {queue_size}") + return True + else: + tree.set_processing_state(node_id, True) + + node = tree.get_node(node_id) + if node: + tree.set_current_task( + asyncio.create_task(self.process_node(tree, node, processor)) + ) + return False + + def cancel_current(self, tree: MessageTree) -> bool: + """Cancel the currently running task in a tree.""" + return tree.cancel_current_task() class TreeQueueManager: """ - Manages multiple message trees. Facade that coordinates components. + Manages multiple message trees: index + async processing. Each new conversation creates a new tree. Replies to existing messages add nodes to existing trees. - - Components: - - TreeRepository: Data access layer - - TreeQueueProcessor: Async queue processing """ def __init__( @@ -33,7 +348,7 @@ class TreeQueueManager: node_started_callback: Callable[[MessageTree, str], Awaitable[None]] | None = None, _repository: TreeRepository | None = None, - ): + ) -> None: self._repository = _repository or TreeRepository() self._processor = TreeQueueProcessor( queue_update_callback=queue_update_callback, @@ -101,7 +416,6 @@ class TreeQueueManager: if not tree: raise ValueError(f"Parent node {parent_node_id} not found in any tree") - # Add node (tree has its own lock) - outside manager lock to avoid deadlock node = await tree.add_node( node_id=node_id, incoming=incoming, @@ -228,7 +542,6 @@ class TreeQueueManager: cleanup_count = 0 async with tree.with_lock(): - # 1. Cancel running task if tree.cancel_current_task(): current_id = tree.current_node_id if current_id: @@ -240,12 +553,10 @@ class TreeQueueManager: tree.set_node_error_sync(node, "Cancelled by user") cancelled_nodes.append(node) - # 2. Drain queue and mark nodes as cancelled queue_nodes = tree.drain_queue_and_mark_cancelled() cancelled_nodes.extend(queue_nodes) cancelled_ids = {n.node_id for n in cancelled_nodes} - # 3. Cleanup: Mark ANY other PENDING or IN_PROGRESS nodes as ERROR for node in tree.all_nodes(): if ( node.state in (MessageState.PENDING, MessageState.IN_PROGRESS) @@ -269,10 +580,6 @@ class TreeQueueManager: """ Cancel a single node (queued or in-progress) without affecting other nodes. - - If the node is currently running, cancels the current asyncio task. - - If the node is queued, removes it from the queue. - - Marks the node as ERROR with "Cancelled by user". - Returns: List containing the cancelled node if it was cancellable, else empty list. """ @@ -351,8 +658,6 @@ class TreeQueueManager: async def cancel_branch(self, branch_root_id: str) -> list[MessageNode]: """ Cancel all PENDING/IN_PROGRESS nodes in the subtree (branch_root + descendants). - - Does not call cli_manager.stop_all(). Returns list of cancelled nodes. """ tree = self._repository.get_tree_for_node(branch_root_id) if not tree: @@ -435,3 +740,10 @@ class TreeQueueManager: node_started_callback=node_started_callback, _repository=TreeRepository.from_dict(data), ) + + +__all__ = [ + "TreeQueueManager", + "TreeQueueProcessor", + "TreeRepository", +] diff --git a/messaging/trees/repository.py b/messaging/trees/repository.py deleted file mode 100644 index 2dae1fb..0000000 --- a/messaging/trees/repository.py +++ /dev/null @@ -1,186 +0,0 @@ -"""Repository for message tree data access. - -Provides data access layer for managing trees and node mappings. -""" - -from loguru import logger - -from .data import MessageNode, MessageState, MessageTree - - -class TreeRepository: - """ - Repository for message tree data access. - - Manages the storage and lookup of trees and node-to-tree mappings. - """ - - def __init__(self): - self._trees: dict[str, MessageTree] = {} # root_id -> tree - self._node_to_tree: dict[str, str] = {} # node_id -> root_id - - def get_tree(self, root_id: str) -> MessageTree | None: - """Get a tree by its root ID.""" - return self._trees.get(root_id) - - def get_tree_for_node(self, node_id: str) -> MessageTree | None: - """Get the tree containing a given node.""" - root_id = self._node_to_tree.get(node_id) - if not root_id: - return None - return self._trees.get(root_id) - - def get_node(self, node_id: str) -> MessageNode | None: - """Get a node from any tree.""" - tree = self.get_tree_for_node(node_id) - return tree.get_node(node_id) if tree else None - - def add_tree(self, root_id: str, tree: MessageTree) -> None: - """Add a new tree to the repository.""" - self._trees[root_id] = tree - self._node_to_tree[root_id] = root_id - logger.debug("TREE_REPO: add_tree root_id={}", root_id) - - def register_node(self, node_id: str, root_id: str) -> None: - """Register a node ID to a tree.""" - self._node_to_tree[node_id] = root_id - logger.debug("TREE_REPO: register_node node_id={} root_id={}", node_id, root_id) - - def has_node(self, node_id: str) -> bool: - """Check if a node is registered in any tree.""" - return node_id in self._node_to_tree - - def tree_count(self) -> int: - """Get the number of trees in the repository.""" - return len(self._trees) - - def is_tree_busy(self, root_id: str) -> bool: - """Check if a tree is currently processing.""" - tree = self._trees.get(root_id) - return tree.is_processing if tree else False - - def is_node_tree_busy(self, node_id: str) -> bool: - """Check if the tree containing a node is busy.""" - tree = self.get_tree_for_node(node_id) - return tree.is_processing if tree else False - - def get_queue_size(self, node_id: str) -> int: - """Get queue size for the tree containing a node.""" - tree = self.get_tree_for_node(node_id) - return tree.get_queue_size() if tree else 0 - - def resolve_parent_node_id(self, msg_id: str) -> str | None: - """ - Resolve a message ID to the actual parent node ID. - - Handles the case where msg_id is a status message ID - (which maps to the tree but isn't an actual node). - - Returns: - The node_id to use as parent, or None if not found - """ - tree = self.get_tree_for_node(msg_id) - if not tree: - return None - - # Check if msg_id is an actual node - if tree.has_node(msg_id): - return msg_id - - # Otherwise, it might be a status message - find the owning node - node = tree.find_node_by_status_message(msg_id) - if node: - return node.node_id - - return None - - def get_pending_children(self, node_id: str) -> list[MessageNode]: - """ - Get all pending child nodes (recursively) of a given node. - - Used for error propagation - when a node fails, its pending - children should also be marked as failed. - """ - tree = self.get_tree_for_node(node_id) - if not tree: - return [] - - pending: list[MessageNode] = [] - stack = [node_id] - - while stack: - current_id = stack.pop() - node = tree.get_node(current_id) - if not node: - continue - for child_id in node.children_ids: - child = tree.get_node(child_id) - if child and child.state == MessageState.PENDING: - pending.append(child) - stack.append(child_id) - - return pending - - def all_trees(self) -> list[MessageTree]: - """Get all trees in the repository.""" - return list(self._trees.values()) - - def tree_ids(self) -> list[str]: - """Get all tree root IDs.""" - return list(self._trees.keys()) - - def unregister_nodes(self, node_ids: list[str]) -> None: - """Remove node IDs from the node-to-tree mapping.""" - for nid in node_ids: - self._node_to_tree.pop(nid, None) - - def remove_tree(self, root_id: str) -> MessageTree | None: - """ - Remove a tree and all its node mappings from the repository. - - Returns: - The removed tree, or None if not found. - """ - tree = self._trees.pop(root_id, None) - if not tree: - return None - for node in tree.all_nodes(): - self._node_to_tree.pop(node.node_id, None) - logger.debug("TREE_REPO: remove_tree root_id={}", root_id) - return tree - - def get_message_ids_for_chat(self, platform: str, chat_id: str) -> set[str]: - """Get all message IDs (incoming + status) for a given platform/chat. - - Note: O(total_nodes) scan. Acceptable because this is only called - from /clear (user-initiated, infrequent). - """ - msg_ids: set[str] = set() - for tree in self._trees.values(): - for node in tree.all_nodes(): - if str(node.incoming.platform) == str(platform) and str( - node.incoming.chat_id - ) == str(chat_id): - if node.incoming.message_id is not None: - msg_ids.add(str(node.incoming.message_id)) - if node.status_message_id: - msg_ids.add(str(node.status_message_id)) - return msg_ids - - def to_dict(self) -> dict: - """Serialize all trees.""" - return { - "trees": {rid: tree.to_dict() for rid, tree in self._trees.items()}, - "node_to_tree": self._node_to_tree.copy(), - } - - @classmethod - def from_dict(cls, data: dict) -> TreeRepository: - """Deserialize from dictionary.""" - from .data import MessageTree - - repo = cls() - for root_id, tree_data in data.get("trees", {}).items(): - repo._trees[root_id] = MessageTree.from_dict(tree_data) - repo._node_to_tree = data.get("node_to_tree", {}) - return repo diff --git a/messaging/ui_updates.py b/messaging/ui_updates.py new file mode 100644 index 0000000..929202e --- /dev/null +++ b/messaging/ui_updates.py @@ -0,0 +1,101 @@ +"""Throttled platform UI updates driven by transcript rendering.""" + +from __future__ import annotations + +import time +from collections.abc import Callable + +from loguru import logger + +from .platforms.base import MessagingPlatform +from .safe_diagnostics import format_exception_for_log +from .transcript import RenderCtx, TranscriptBuffer + + +class ThrottledTranscriptEditor: + """Rate-limited status message edits from a growing transcript.""" + + def __init__( + self, + *, + platform: MessagingPlatform, + parse_mode: str | None, + get_limit_chars: Callable[[], int], + transcript: TranscriptBuffer, + render_ctx: RenderCtx, + node_id: str, + chat_id: str, + status_msg_id: str, + debug_platform_edits: bool, + log_messaging_error_details: bool = False, + ) -> None: + self._platform = platform + self._parse_mode = parse_mode + self._get_limit_chars = get_limit_chars + self._transcript = transcript + self._render_ctx = render_ctx + self._node_id = node_id + self._chat_id = chat_id + self._status_msg_id = status_msg_id + self._debug_platform_edits = debug_platform_edits + self._log_messaging_error_details = log_messaging_error_details + self._last_ui_update = 0.0 + self._last_displayed_text: str | None = None + self._last_status: str | None = None + + @property + def last_status(self) -> str | None: + return self._last_status + + async def update(self, status: str | None = None, *, force: bool = False) -> None: + """Render transcript + optional status line and edit the platform message.""" + now = time.time() + if not force and now - self._last_ui_update < 1.0: + return + + self._last_ui_update = now + if status is not None: + self._last_status = status + try: + display = self._transcript.render( + self._render_ctx, + limit_chars=self._get_limit_chars(), + status=status, + ) + except Exception as e: + logger.warning( + "Transcript render failed for node {}: {}", + self._node_id, + format_exception_for_log( + e, log_full_message=self._log_messaging_error_details + ), + ) + return + if display and display != self._last_displayed_text: + logger.debug( + "PLATFORM_EDIT: node_id={} chat_id={} msg_id={} force={} status={!r} chars={}", + self._node_id, + self._chat_id, + self._status_msg_id, + bool(force), + status, + len(display), + ) + if self._debug_platform_edits: + logger.debug("PLATFORM_EDIT_TEXT:\n{}", display) + self._last_displayed_text = display + try: + await self._platform.queue_edit_message( + self._chat_id, + self._status_msg_id, + display, + parse_mode=self._parse_mode, + ) + except Exception as e: + logger.warning( + "Failed to update platform for node {}: {}", + self._node_id, + format_exception_for_log( + e, log_full_message=self._log_messaging_error_details + ), + ) diff --git a/providers/anthropic_messages.py b/providers/anthropic_messages.py index c35c1d5..b501aac 100644 --- a/providers/anthropic_messages.py +++ b/providers/anthropic_messages.py @@ -2,19 +2,32 @@ from __future__ import annotations -import json from collections.abc import AsyncIterator, Iterator from typing import Any, Literal import httpx from loguru import logger -from core.anthropic import get_user_facing_error_message +from config.constants import ( + ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS, + NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES, +) +from core.anthropic import iter_provider_stream_error_sse_events +from core.anthropic.emitted_sse_tracker import EmittedNativeSseTracker +from core.anthropic.native_messages_request import ( + build_base_native_anthropic_request_body, +) +from core.anthropic.native_sse_block_policy import ( + NativeSseBlockPolicyState, + transform_native_sse_block_event, +) from providers.base import BaseProvider, ProviderConfig -from providers.error_mapping import map_error +from providers.error_mapping import ( + map_error, + user_visible_message_for_mapped_provider_error, +) from providers.rate_limit import GlobalRateLimiter -ANTHROPIC_DEFAULT_MAX_TOKENS = 81920 StreamChunkMode = Literal["line", "event"] @@ -64,25 +77,11 @@ class AnthropicMessagesTransport(BaseProvider): ) -> dict: """Build a native Anthropic request body.""" thinking_enabled = self._is_thinking_enabled(request, thinking_enabled) - body = request.model_dump(exclude_none=True) - - body.pop("extra_body", None) - body.pop("original_model", None) - body.pop("resolved_provider_model", None) - - if "thinking" in body: - thinking_cfg = body.pop("thinking") - if thinking_enabled and isinstance(thinking_cfg, dict): - thinking_payload = {"type": "enabled"} - budget_tokens = thinking_cfg.get("budget_tokens") - if isinstance(budget_tokens, int): - thinking_payload["budget_tokens"] = budget_tokens - body["thinking"] = thinking_payload - - if "max_tokens" not in body: - body["max_tokens"] = ANTHROPIC_DEFAULT_MAX_TOKENS - - return body + return build_base_native_anthropic_request_body( + request, + default_max_tokens=ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS, + thinking_enabled=thinking_enabled, + ) async def _send_stream_request(self, body: dict) -> httpx.Response: """Create a streaming messages response.""" @@ -97,30 +96,68 @@ class AnthropicMessagesTransport(BaseProvider): async def _raise_for_status( self, response: httpx.Response, *, req_tag: str ) -> None: - """Raise for non-200 responses after logging the upstream body.""" + """Raise for non-200 responses after logging safe metadata (or capped body if opted in).""" try: response.raise_for_status() except httpx.HTTPStatusError as error: - response_text = await self._read_error_body(response) - if response_text: + if self._config.log_api_error_tracebacks: + preview, truncated = await self._read_error_body_preview( + response, NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES + ) + if preview: + text = preview.decode("utf-8", errors="replace") + logger.error( + "{}_ERROR:{} HTTP {} body_preview_bytes={} truncated={}: {}", + self._provider_name, + req_tag, + response.status_code, + len(preview), + truncated, + text, + ) + else: + logger.error( + "{}_ERROR:{} HTTP {} (empty error body)", + self._provider_name, + req_tag, + response.status_code, + ) + else: + cl = response.headers.get("content-length", "").strip() + extra = f" content_length_declared={cl}" if cl.isdigit() else "" logger.error( - "{}_ERROR:{} HTTP {}: {}", + "{}_ERROR:{} HTTP {}{}", self._provider_name, req_tag, response.status_code, - response_text, + extra, ) raise error - async def _read_error_body(self, response: httpx.Response) -> str: - """Read a response body for diagnostics.""" - aread = getattr(response, "aread", None) - if aread is None: - return "" - body = await aread() - if isinstance(body, bytes): - return body.decode("utf-8", errors="replace") - return str(body) + async def _read_error_body_preview( + self, response: httpx.Response, max_bytes: int + ) -> tuple[bytes, bool]: + """Read at most ``max_bytes`` from the error body for logging. Returns (preview, truncated).""" + if max_bytes <= 0: + return b"", False + received = 0 + parts: list[bytes] = [] + truncated = False + async for chunk in response.aiter_bytes(chunk_size=65_536): + if received >= max_bytes: + truncated = True + break + remaining = max_bytes - received + take = chunk if len(chunk) <= remaining else chunk[:remaining] + if take: + parts.append(take) + received += len(take) + if len(chunk) > len(take): + truncated = True + break + if received >= max_bytes: + break + return (b"".join(parts), truncated) async def _iter_sse_lines(self, response: httpx.Response) -> AsyncIterator[str]: """Yield raw SSE line chunks preserving local provider behavior.""" @@ -145,6 +182,8 @@ class AnthropicMessagesTransport(BaseProvider): def _new_stream_state(self, request: Any, *, thinking_enabled: bool) -> Any: """Return per-stream provider state for event transformation.""" + if self.stream_chunk_mode == "line": + return NativeSseBlockPolicyState() return None def _transform_stream_event( @@ -155,6 +194,10 @@ class AnthropicMessagesTransport(BaseProvider): thinking_enabled: bool, ) -> str | None: """Transform or drop a grouped SSE event before yielding it downstream.""" + if isinstance(state, NativeSseBlockPolicyState): + return transform_native_sse_block_event( + event, state, thinking_enabled=thinking_enabled + ) return event def _format_error_message(self, base_message: str, request_id: str | None) -> str: @@ -166,15 +209,11 @@ class AnthropicMessagesTransport(BaseProvider): def _get_error_message(self, error: Exception, request_id: str | None) -> str: """Map an exception into a user-facing provider error message.""" mapped_error = map_error(error, rate_limiter=self._global_rate_limiter) - if getattr(mapped_error, "status_code", None) == 405: - base_message = ( - f"Upstream provider {self._provider_name} rejected the request method " - "or endpoint (HTTP 405)." - ) - else: - base_message = get_user_facing_error_message( - mapped_error, read_timeout_s=self._config.http_read_timeout - ) + base_message = user_visible_message_for_mapped_provider_error( + mapped_error, + provider_name=self._provider_name, + read_timeout_s=self._config.http_read_timeout, + ) return self._format_error_message(base_message, request_id) def _emit_error_events( @@ -185,12 +224,14 @@ class AnthropicMessagesTransport(BaseProvider): error_message: str, sent_any_event: bool, ) -> Iterator[str]: - """Emit a native Anthropic error event.""" - error_event = { - "type": "error", - "error": {"type": "api_error", "message": error_message}, - } - yield f"event: error\ndata: {json.dumps(error_event)}\n\n" + """Emit the same Anthropic message lifecycle used by OpenAI-compat providers.""" + yield from iter_provider_stream_error_sse_events( + request=request, + input_tokens=input_tokens, + error_message=error_message, + sent_any_event=sent_any_event, + log_raw_sse_events=self._config.log_raw_sse_events, + ) async def _iter_stream_chunks( self, @@ -200,6 +241,21 @@ class AnthropicMessagesTransport(BaseProvider): thinking_enabled: bool, ) -> AsyncIterator[str]: """Yield stream chunks according to the provider's observable chunk shape.""" + if self.stream_chunk_mode == "line" and isinstance( + state, NativeSseBlockPolicyState + ): + async for event in self._iter_sse_events(response): + output_event = self._transform_stream_event( + event, + state, + thinking_enabled=thinking_enabled, + ) + if output_event is None: + continue + for line in output_event.splitlines(keepends=True): + yield line + return + if self.stream_chunk_mode == "line": async for chunk in self._iter_sse_lines(response): yield chunk @@ -240,15 +296,28 @@ class AnthropicMessagesTransport(BaseProvider): response: httpx.Response | None = None sent_any_event = False state = self._new_stream_state(request, thinking_enabled=thinking_enabled) + emitted_tracker = EmittedNativeSseTracker() async with self._global_rate_limiter.concurrency_slot(): try: - response = await self._global_rate_limiter.execute_with_retry( - self._send_stream_request, body - ) - if response.status_code != 200: - await self._raise_for_status(response, req_tag=req_tag) + async def _validated_stream_send() -> httpx.Response: + """Send request; raise inside retry loop on 429 so rate limiter can backoff.""" + send_response = await self._send_stream_request(body) + if send_response.status_code == 429: + await send_response.aclose() + send_response.raise_for_status() + if send_response.status_code != 200: + try: + await self._raise_for_status(send_response, req_tag=req_tag) + finally: + if not send_response.is_closed: + await send_response.aclose() + return send_response + + response = await self._global_rate_limiter.execute_with_retry( + _validated_stream_send + ) async for chunk in self._iter_stream_chunks( response, @@ -256,12 +325,12 @@ class AnthropicMessagesTransport(BaseProvider): thinking_enabled=thinking_enabled, ): sent_any_event = True + emitted_tracker.feed(chunk) yield chunk except Exception as error: - logger.error( - "{}_ERROR:{} {}: {}", tag, req_tag, type(error).__name__, error - ) + if not isinstance(error, httpx.HTTPStatusError): + self._log_stream_transport_error(tag, req_tag, error) error_message = self._get_error_message(error, request_id) if response is not None and not response.is_closed: @@ -273,13 +342,24 @@ class AnthropicMessagesTransport(BaseProvider): type(error).__name__, req_tag, ) - for event in self._emit_error_events( - request=request, - input_tokens=input_tokens, - error_message=error_message, - sent_any_event=sent_any_event, - ): - yield event + if sent_any_event: + for event in emitted_tracker.iter_close_unclosed_blocks(): + yield event + for event in emitted_tracker.iter_midstream_error_tail( + error_message, + request=request, + input_tokens=input_tokens, + log_raw_sse_events=self._config.log_raw_sse_events, + ): + yield event + else: + for event in self._emit_error_events( + request=request, + input_tokens=input_tokens, + error_message=error_message, + sent_any_event=False, + ): + yield event return finally: if response is not None and not response.is_closed: diff --git a/providers/base.py b/providers/base.py index f8ff193..4f1a74b 100644 --- a/providers/base.py +++ b/providers/base.py @@ -6,6 +6,8 @@ from typing import Any from pydantic import BaseModel +from config.constants import HTTP_CONNECT_TIMEOUT_DEFAULT + class ProviderConfig(BaseModel): """Configuration for a provider. @@ -21,9 +23,11 @@ class ProviderConfig(BaseModel): max_concurrency: int = 5 http_read_timeout: float = 300.0 http_write_timeout: float = 10.0 - http_connect_timeout: float = 2.0 + http_connect_timeout: float = HTTP_CONNECT_TIMEOUT_DEFAULT enable_thinking: bool = True proxy: str = "" + log_raw_sse_events: bool = False + log_api_error_tracebacks: bool = False class BaseProvider(ABC): @@ -61,6 +65,42 @@ class BaseProvider(ABC): request_enabled = bool(enabled) return config_enabled and request_enabled + def preflight_stream( + self, request: Any, *, thinking_enabled: bool | None = None + ) -> None: + """Eagerly validate/build the upstream request before opening an SSE stream. + + Subclasses with ``_build_request_body`` (OpenAI and native) raise + :class:`providers.exceptions.InvalidRequestError` on conversion failures. + """ + build = getattr(self, "_build_request_body", None) + if build is None: + return + build(request, thinking_enabled=thinking_enabled) + + def _log_stream_transport_error( + self, tag: str, req_tag: str, error: Exception + ) -> None: + """Log streaming transport failures (metadata-only unless verbose is enabled).""" + from loguru import logger + + if self._config.log_api_error_tracebacks: + logger.error( + "{}_ERROR:{} {}: {}", tag, req_tag, type(error).__name__, error + ) + return + response = getattr(error, "response", None) + status_code = ( + getattr(response, "status_code", None) if response is not None else None + ) + logger.error( + "{}_ERROR:{} exc_type={} http_status={}", + tag, + req_tag, + type(error).__name__, + status_code, + ) + @abstractmethod async def cleanup(self) -> None: """Release any resources held by this provider.""" @@ -75,5 +115,7 @@ class BaseProvider(ABC): thinking_enabled: bool | None = None, ) -> AsyncIterator[str]: """Stream response in Anthropic SSE format.""" + # Typing: abstract async generators need a yield for AsyncIterator[str] + # inference; this branch is never executed. if False: - yield "" # Required for ty/mypy to accept abstract async generator + yield "" diff --git a/providers/deepseek/__init__.py b/providers/deepseek/__init__.py index 0ecf7d0..5ab4c79 100644 --- a/providers/deepseek/__init__.py +++ b/providers/deepseek/__init__.py @@ -1,5 +1,7 @@ """DeepSeek provider exports.""" -from .client import DEEPSEEK_BASE_URL, DeepSeekProvider +from providers.defaults import DEEPSEEK_DEFAULT_BASE -__all__ = ["DEEPSEEK_BASE_URL", "DeepSeekProvider"] +from .client import DeepSeekProvider + +__all__ = ["DEEPSEEK_DEFAULT_BASE", "DeepSeekProvider"] diff --git a/providers/deepseek/client.py b/providers/deepseek/client.py index 210f223..3996b61 100644 --- a/providers/deepseek/client.py +++ b/providers/deepseek/client.py @@ -3,7 +3,7 @@ from typing import Any from providers.base import ProviderConfig -from providers.defaults import DEEPSEEK_BASE_URL +from providers.defaults import DEEPSEEK_DEFAULT_BASE from providers.openai_compat import OpenAIChatTransport from .request import build_request_body @@ -16,7 +16,7 @@ class DeepSeekProvider(OpenAIChatTransport): super().__init__( config, provider_name="DEEPSEEK", - base_url=config.base_url or DEEPSEEK_BASE_URL, + base_url=config.base_url or DEEPSEEK_DEFAULT_BASE, api_key=config.api_key, ) diff --git a/providers/deepseek/request.py b/providers/deepseek/request.py index 67cabf9..f22a9f8 100644 --- a/providers/deepseek/request.py +++ b/providers/deepseek/request.py @@ -5,6 +5,8 @@ from typing import Any from loguru import logger from core.anthropic import build_base_request_body +from core.anthropic.conversion import OpenAIConversionError +from providers.exceptions import InvalidRequestError def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict: @@ -14,10 +16,14 @@ def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict: getattr(request_data, "model", "?"), len(getattr(request_data, "messages", [])), ) - body = build_base_request_body( - request_data, - include_reasoning_content=True, - ) + try: + body = build_base_request_body( + request_data, + include_thinking=thinking_enabled, + include_reasoning_content=thinking_enabled, + ) + except OpenAIConversionError as exc: + raise InvalidRequestError(str(exc)) from exc extra_body: dict[str, Any] = {} request_extra = getattr(request_data, "extra_body", None) diff --git a/providers/defaults.py b/providers/defaults.py index 5f7fcfe..80109dc 100644 --- a/providers/defaults.py +++ b/providers/defaults.py @@ -1,21 +1,19 @@ -"""Default upstream base URLs and shared provider constants. +"""Re-exports default upstream base URLs from the config provider catalog.""" -Adapters and :mod:`providers.registry` import from here to avoid duplicating -literals and to keep ``providers.registry`` free of per-adapter eager imports. -""" +from config.provider_catalog import ( + DEEPSEEK_DEFAULT_BASE, + LLAMACPP_DEFAULT_BASE, + LMSTUDIO_DEFAULT_BASE, + NVIDIA_NIM_DEFAULT_BASE, + OLLAMA_DEFAULT_BASE, + OPENROUTER_DEFAULT_BASE, +) -# OpenAI-compatible chat (NIM, DeepSeek) and local/native provider endpoints -NVIDIA_NIM_DEFAULT_BASE = "https://integrate.api.nvidia.com/v1" -DEEPSEEK_DEFAULT_BASE = "https://api.deepseek.com" -OPENROUTER_DEFAULT_BASE = "https://openrouter.ai/api/v1" -LMSTUDIO_DEFAULT_BASE = "http://localhost:1234/v1" -LLAMACPP_DEFAULT_BASE = "http://localhost:8080/v1" -OLLAMA_DEFAULT_BASE = "http://localhost:11434" - -# Backward-compatible names used by existing adapter modules -NVIDIA_NIM_BASE_URL = NVIDIA_NIM_DEFAULT_BASE -DEEPSEEK_BASE_URL = DEEPSEEK_DEFAULT_BASE -OPENROUTER_BASE_URL = OPENROUTER_DEFAULT_BASE -LMSTUDIO_DEFAULT_BASE_URL = LMSTUDIO_DEFAULT_BASE -LLAMACPP_DEFAULT_BASE_URL = LLAMACPP_DEFAULT_BASE -OLLAMA_DEFAULT_BASE_URL = OLLAMA_DEFAULT_BASE +__all__ = ( + "DEEPSEEK_DEFAULT_BASE", + "LLAMACPP_DEFAULT_BASE", + "LMSTUDIO_DEFAULT_BASE", + "NVIDIA_NIM_DEFAULT_BASE", + "OLLAMA_DEFAULT_BASE", + "OPENROUTER_DEFAULT_BASE", +) diff --git a/providers/error_mapping.py b/providers/error_mapping.py index 3826a84..44c33a5 100644 --- a/providers/error_mapping.py +++ b/providers/error_mapping.py @@ -14,10 +14,30 @@ from providers.exceptions import ( from providers.rate_limit import GlobalRateLimiter +def user_visible_message_for_mapped_provider_error( + mapped: Exception, + *, + provider_name: str, + read_timeout_s: float | None, +) -> str: + """Return the user-visible string after :func:`map_error` (405 + mapped types).""" + if getattr(mapped, "status_code", None) == 405: + return ( + f"Upstream provider {provider_name} rejected the request method " + "or endpoint (HTTP 405)." + ) + return get_user_facing_error_message(mapped, read_timeout_s=read_timeout_s) + + def map_error( e: Exception, *, rate_limiter: GlobalRateLimiter | None = None ) -> Exception: - """Map OpenAI or HTTPX exception to specific ProviderError.""" + """Map OpenAI or HTTPX exception to specific ProviderError. + + Streaming transports should pass their scoped limiter (``self._global_rate_limiter``) + so reactive 429 handling applies to the correct provider. Tests may omit + ``rate_limiter`` to use the process-wide singleton. + """ message = get_user_facing_error_message(e) limiter = rate_limiter or GlobalRateLimiter.get_instance() diff --git a/providers/exceptions.py b/providers/exceptions.py index 31c6781..6901b9c 100644 --- a/providers/exceptions.py +++ b/providers/exceptions.py @@ -90,8 +90,20 @@ class APIError(ProviderError): ) -class UnknownProviderTypeError(ValueError): +class UnknownProviderTypeError(InvalidRequestError): """Raised when ``provider_id`` is not registered in the provider map.""" def __init__(self, message: str) -> None: super().__init__(message) + + +class ServiceUnavailableError(ProviderError): + """Raised when the server is not ready (e.g. app lifespan did not wire state).""" + + def __init__(self, message: str, raw_error: Any = None): + super().__init__( + message, + status_code=503, + error_type="api_error", + raw_error=raw_error, + ) diff --git a/providers/llamacpp/client.py b/providers/llamacpp/client.py index de0c268..891022d 100644 --- a/providers/llamacpp/client.py +++ b/providers/llamacpp/client.py @@ -2,7 +2,7 @@ from providers.anthropic_messages import AnthropicMessagesTransport from providers.base import ProviderConfig -from providers.defaults import LLAMACPP_DEFAULT_BASE_URL +from providers.defaults import LLAMACPP_DEFAULT_BASE class LlamaCppProvider(AnthropicMessagesTransport): @@ -12,5 +12,5 @@ class LlamaCppProvider(AnthropicMessagesTransport): super().__init__( config, provider_name="LLAMACPP", - default_base_url=LLAMACPP_DEFAULT_BASE_URL, + default_base_url=LLAMACPP_DEFAULT_BASE, ) diff --git a/providers/lmstudio/client.py b/providers/lmstudio/client.py index 2961993..e32b835 100644 --- a/providers/lmstudio/client.py +++ b/providers/lmstudio/client.py @@ -2,7 +2,7 @@ from providers.anthropic_messages import AnthropicMessagesTransport from providers.base import ProviderConfig -from providers.defaults import LMSTUDIO_DEFAULT_BASE_URL +from providers.defaults import LMSTUDIO_DEFAULT_BASE class LMStudioProvider(AnthropicMessagesTransport): @@ -12,5 +12,5 @@ class LMStudioProvider(AnthropicMessagesTransport): super().__init__( config, provider_name="LMSTUDIO", - default_base_url=LMSTUDIO_DEFAULT_BASE_URL, + default_base_url=LMSTUDIO_DEFAULT_BASE, ) diff --git a/providers/nvidia_nim/__init__.py b/providers/nvidia_nim/__init__.py index b0a410f..253acd1 100644 --- a/providers/nvidia_nim/__init__.py +++ b/providers/nvidia_nim/__init__.py @@ -1,5 +1,7 @@ """NVIDIA NIM provider package.""" -from .client import NVIDIA_NIM_BASE_URL, NvidiaNimProvider +from providers.defaults import NVIDIA_NIM_DEFAULT_BASE -__all__ = ["NVIDIA_NIM_BASE_URL", "NvidiaNimProvider"] +from .client import NvidiaNimProvider + +__all__ = ["NVIDIA_NIM_DEFAULT_BASE", "NvidiaNimProvider"] diff --git a/providers/nvidia_nim/client.py b/providers/nvidia_nim/client.py index 5cc2f31..2661203 100644 --- a/providers/nvidia_nim/client.py +++ b/providers/nvidia_nim/client.py @@ -8,7 +8,7 @@ from loguru import logger from config.nim import NimSettings from providers.base import ProviderConfig -from providers.defaults import NVIDIA_NIM_BASE_URL +from providers.defaults import NVIDIA_NIM_DEFAULT_BASE from providers.openai_compat import OpenAIChatTransport from .request import ( @@ -25,7 +25,7 @@ class NvidiaNimProvider(OpenAIChatTransport): super().__init__( config, provider_name="NIM", - base_url=config.base_url or NVIDIA_NIM_BASE_URL, + base_url=config.base_url or NVIDIA_NIM_DEFAULT_BASE, api_key=config.api_key, ) self._nim_settings = nim_settings diff --git a/providers/nvidia_nim/request.py b/providers/nvidia_nim/request.py index c0a83f3..a4e96de 100644 --- a/providers/nvidia_nim/request.py +++ b/providers/nvidia_nim/request.py @@ -1,5 +1,6 @@ """Request builder for NVIDIA NIM provider.""" +from collections.abc import Callable from copy import deepcopy from typing import Any @@ -7,6 +8,42 @@ from loguru import logger from config.nim import NimSettings from core.anthropic import build_base_request_body, set_if_not_none +from core.anthropic.conversion import OpenAIConversionError +from providers.exceptions import InvalidRequestError + + +def _clone_strip_extra_body( + body: dict[str, Any], + strip: Callable[[dict[str, Any]], bool], +) -> dict[str, Any] | None: + """Deep-clone ``body`` and remove fields via ``strip`` on ``extra_body`` only. + + Returns ``None`` when there is no ``extra_body`` dict or ``strip`` reports no change. + """ + cloned_body = deepcopy(body) + extra_body = cloned_body.get("extra_body") + if not isinstance(extra_body, dict): + return None + if not strip(extra_body): + return None + if not extra_body: + cloned_body.pop("extra_body", None) + return cloned_body + + +def _strip_reasoning_budget_fields(extra_body: dict[str, Any]) -> bool: + removed = extra_body.pop("reasoning_budget", None) is not None + chat_template_kwargs = extra_body.get("chat_template_kwargs") + if ( + isinstance(chat_template_kwargs, dict) + and chat_template_kwargs.pop("reasoning_budget", None) is not None + ): + removed = True + return removed + + +def _strip_chat_template_field(extra_body: dict[str, Any]) -> bool: + return extra_body.pop("chat_template", None) is not None def _set_extra( @@ -23,43 +60,12 @@ def _set_extra( def clone_body_without_reasoning_budget(body: dict[str, Any]) -> dict[str, Any] | None: """Clone a request body and strip only reasoning_budget fields.""" - cloned_body = deepcopy(body) - extra_body = cloned_body.get("extra_body") - if not isinstance(extra_body, dict): - return None - - removed = extra_body.pop("reasoning_budget", None) is not None - - chat_template_kwargs = extra_body.get("chat_template_kwargs") - if ( - isinstance(chat_template_kwargs, dict) - and chat_template_kwargs.pop("reasoning_budget", None) is not None - ): - removed = True - - if not extra_body: - cloned_body.pop("extra_body", None) - - if not removed: - return None - - return cloned_body + return _clone_strip_extra_body(body, _strip_reasoning_budget_fields) def clone_body_without_chat_template(body: dict[str, Any]) -> dict[str, Any] | None: """Clone a request body and strip only chat_template.""" - cloned_body = deepcopy(body) - extra_body = cloned_body.get("extra_body") - if not isinstance(extra_body, dict): - return None - - if extra_body.pop("chat_template", None) is None: - return None - - if not extra_body: - cloned_body.pop("extra_body", None) - - return cloned_body + return _clone_strip_extra_body(body, _strip_chat_template_field) def build_request_body( @@ -71,10 +77,13 @@ def build_request_body( getattr(request_data, "model", "?"), len(getattr(request_data, "messages", [])), ) - body = build_base_request_body( - request_data, - include_thinking=thinking_enabled, - ) + try: + body = build_base_request_body( + request_data, + include_thinking=thinking_enabled, + ) + except OpenAIConversionError as exc: + raise InvalidRequestError(str(exc)) from exc # NIM-specific max_tokens: cap against nim.max_tokens max_tokens = body.get("max_tokens") or getattr(request_data, "max_tokens", None) diff --git a/providers/nvidia_nim/voice.py b/providers/nvidia_nim/voice.py new file mode 100644 index 0000000..f26255a --- /dev/null +++ b/providers/nvidia_nim/voice.py @@ -0,0 +1,95 @@ +"""NVIDIA NIM / Riva offline ASR for voice notes (provider-owned transport).""" + +from __future__ import annotations + +from pathlib import Path + +from loguru import logger + +# NVIDIA NIM Whisper model mapping: (function_id, language_code) +_NIM_ASR_MODEL_MAP: dict[str, tuple[str, str]] = { + "nvidia/parakeet-ctc-0.6b-zh-tw": ("8473f56d-51ef-473c-bb26-efd4f5def2bf", "zh-TW"), + "nvidia/parakeet-ctc-0.6b-zh-cn": ("9add5ef7-322e-47e0-ad7a-5653fb8d259b", "zh-CN"), + # function-id from NVIDIA NIM API docs (parakeet-ctc-0.6b-es). + "nvidia/parakeet-ctc-0.6b-es": ("a9eeee8f-b509-4712-b19d-194361fa5f31", "es-US"), + "nvidia/parakeet-ctc-0.6b-vi": ("f3dff2bb-99f9-403d-a5f1-f574a757deb0", "vi-VN"), + "nvidia/parakeet-ctc-1.1b-asr": ("1598d209-5e27-4d3c-8079-4751568b1081", "en-US"), + "nvidia/parakeet-ctc-0.6b-asr": ("d8dd4e9b-fbf5-4fb0-9dba-8cf436c8d965", "en-US"), + "nvidia/parakeet-1.1b-rnnt-multilingual-asr": ( + "71203149-d3b7-4460-8231-1be2543a1fca", + "", + ), + "openai/whisper-large-v3": ("b702f636-f60c-4a3d-a6f4-f3568c13bd7d", "multi"), +} + +_RIVA_SERVER = "grpc.nvcf.nvidia.com:443" + + +def transcribe_audio_file( + file_path: Path, + model: str, + *, + api_key: str, +) -> str: + """Transcribe audio using NVIDIA NIM / Riva gRPC (offline recognition). + + Args: + file_path: Path to encoded audio bytes readable by Riva. + model: Hugging Face-style NIM model id (see ``_NIM_ASR_MODEL_MAP``). + api_key: NVIDIA API key (Bearer token); must be non-empty. + + Returns: + Transcript text, or ``(no speech detected)`` when empty. + """ + key = (api_key or "").strip() + if not key: + raise ValueError( + "NVIDIA NIM transcription requires a non-empty nvidia_nim_api_key " + "(configure NVIDIA_NIM_API_KEY or pass api_key explicitly)." + ) + + try: + import riva.client + except ImportError as e: + raise ImportError( + "NVIDIA NIM transcription requires the voice extra. " + "Install with: uv sync --extra voice" + ) from e + + model_config = _NIM_ASR_MODEL_MAP.get(model) + if not model_config: + raise ValueError( + f"No NVIDIA NIM config found for model: {model}. " + f"Supported models: {', '.join(_NIM_ASR_MODEL_MAP.keys())}" + ) + function_id, language_code = model_config + + auth = riva.client.Auth( + use_ssl=True, + uri=_RIVA_SERVER, + metadata_args=[ + ["function-id", function_id], + ["authorization", f"Bearer {key}"], + ], + ) + + asr_service = riva.client.ASRService(auth) + + config = riva.client.RecognitionConfig( + language_code=language_code, + max_alternatives=1, + verbatim_transcripts=True, + ) + + with open(file_path, "rb") as f: + data = f.read() + + response = asr_service.offline_recognize(data, config) + + transcript = "" + results = getattr(response, "results", None) + if results and results[0].alternatives: + transcript = results[0].alternatives[0].transcript + + logger.debug(f"NIM transcription: {len(transcript)} chars") + return transcript or "(no speech detected)" diff --git a/providers/ollama/__init__.py b/providers/ollama/__init__.py index 365677c..8934ab9 100644 --- a/providers/ollama/__init__.py +++ b/providers/ollama/__init__.py @@ -1,5 +1,7 @@ """Ollama provider package.""" -from .client import OLLAMA_BASE_URL, OllamaProvider +from providers.defaults import OLLAMA_DEFAULT_BASE -__all__ = ["OLLAMA_BASE_URL", "OllamaProvider"] +from .client import OllamaProvider + +__all__ = ["OLLAMA_DEFAULT_BASE", "OllamaProvider"] diff --git a/providers/ollama/client.py b/providers/ollama/client.py index ebff031..6696bd8 100644 --- a/providers/ollama/client.py +++ b/providers/ollama/client.py @@ -6,8 +6,6 @@ from providers.anthropic_messages import AnthropicMessagesTransport from providers.base import ProviderConfig from providers.defaults import OLLAMA_DEFAULT_BASE -OLLAMA_BASE_URL = OLLAMA_DEFAULT_BASE - class OllamaProvider(AnthropicMessagesTransport): """Ollama provider using native Anthropic Messages API.""" @@ -16,7 +14,7 @@ class OllamaProvider(AnthropicMessagesTransport): super().__init__( config, provider_name="OLLAMA", - default_base_url=OLLAMA_BASE_URL, + default_base_url=OLLAMA_DEFAULT_BASE, ) self._api_key = config.api_key or "ollama" diff --git a/providers/open_router/__init__.py b/providers/open_router/__init__.py index 7a0cdad..df12496 100644 --- a/providers/open_router/__init__.py +++ b/providers/open_router/__init__.py @@ -1,5 +1,7 @@ """OpenRouter provider - Anthropic-compatible native transport.""" -from .client import OPENROUTER_BASE_URL, OpenRouterProvider +from providers.defaults import OPENROUTER_DEFAULT_BASE -__all__ = ["OPENROUTER_BASE_URL", "OpenRouterProvider"] +from .client import OpenRouterProvider + +__all__ = ["OPENROUTER_DEFAULT_BASE", "OpenRouterProvider"] diff --git a/providers/open_router/client.py b/providers/open_router/client.py index d5c2081..032df60 100644 --- a/providers/open_router/client.py +++ b/providers/open_router/client.py @@ -2,34 +2,25 @@ from __future__ import annotations -import json -import uuid from collections.abc import Iterator -from dataclasses import dataclass, field from typing import Any -from core.anthropic import SSEBuilder, append_request_id +from core.anthropic import append_request_id, iter_provider_stream_error_sse_events +from core.anthropic.native_sse_block_policy import ( + NativeSseBlockPolicyState, + is_terminal_openrouter_done_event, + parse_native_sse_event, + transform_native_sse_block_event, +) from providers.anthropic_messages import AnthropicMessagesTransport, StreamChunkMode from providers.base import ProviderConfig -from providers.defaults import OPENROUTER_BASE_URL +from providers.defaults import OPENROUTER_DEFAULT_BASE from .request import build_request_body _ANTHROPIC_VERSION = "2023-06-01" -@dataclass -class _SSEFilterState: - """Track Anthropic content block index remapping while filtering thinking.""" - - next_index: int = 0 - index_map: dict[int, int] = field(default_factory=dict) - dropped_indexes: set[int] = field(default_factory=set) - open_block_types: dict[int, str] = field(default_factory=dict) - closed_indexes: set[int] = field(default_factory=set) - message_stopped: bool = False - - class OpenRouterProvider(AnthropicMessagesTransport): """OpenRouter provider using the native Anthropic-compatible messages API.""" @@ -39,7 +30,7 @@ class OpenRouterProvider(AnthropicMessagesTransport): super().__init__( config, provider_name="OPENROUTER", - default_base_url=OPENROUTER_BASE_URL, + default_base_url=OPENROUTER_DEFAULT_BASE, ) def _build_request_body( @@ -60,163 +51,9 @@ class OpenRouterProvider(AnthropicMessagesTransport): "anthropic-version": _ANTHROPIC_VERSION, } - @staticmethod - def _format_sse_event(event_name: str | None, data_text: str) -> str: - """Format an SSE event from its event name and data payload.""" - lines: list[str] = [] - if event_name: - lines.append(f"event: {event_name}") - lines.extend(f"data: {line}" for line in data_text.splitlines()) - return "\n".join(lines) + "\n\n" - - @staticmethod - def _parse_sse_event(event: str) -> tuple[str | None, str]: - """Extract the event name and raw data payload from an SSE event.""" - event_name = None - data_lines: list[str] = [] - for line in event.strip().splitlines(): - if line.startswith("event:"): - event_name = line[6:].strip() - elif line.startswith("data:"): - data_lines.append(line[5:].lstrip()) - return event_name, "\n".join(data_lines) - - @staticmethod - def _is_terminal_done_event(event_name: str | None, data_text: str) -> bool: - """Return whether an event is OpenAI-style terminal noise.""" - return (event_name is None or event_name in {"data", "done"}) and ( - data_text.strip().upper() == "[DONE]" - ) - - @staticmethod - def _remap_index( - payload: dict[str, Any], state: _SSEFilterState, *, create: bool - ) -> int | None: - """Return the downstream index for a content block event.""" - upstream_index = payload.get("index") - if not isinstance(upstream_index, int): - return None - if upstream_index in state.dropped_indexes: - return None - mapped_index = state.index_map.get(upstream_index) - if mapped_index is None and create: - mapped_index = state.next_index - state.index_map[upstream_index] = mapped_index - state.next_index += 1 - return mapped_index - - def _close_open_blocks_before( - self, state: _SSEFilterState, upstream_index: int - ) -> str: - """Close overlapping upstream blocks before starting a new block.""" - events: list[str] = [] - for open_upstream_index in list(state.open_block_types): - if open_upstream_index == upstream_index: - continue - mapped_index = state.index_map.get(open_upstream_index) - if mapped_index is None: - continue - payload = {"type": "content_block_stop", "index": mapped_index} - events.append( - self._format_sse_event("content_block_stop", json.dumps(payload)) - ) - state.closed_indexes.add(open_upstream_index) - state.open_block_types.pop(open_upstream_index, None) - return "".join(events) - - @staticmethod - def _should_drop_block_type(block_type: Any, *, thinking_enabled: bool) -> bool: - if not isinstance(block_type, str): - return False - if block_type.startswith("redacted_thinking"): - return True - return not thinking_enabled and "thinking" in block_type - - def _transform_sse_payload( - self, - event: str, - state: _SSEFilterState, - *, - thinking_enabled: bool, - ) -> str | None: - """Normalize OpenRouter SSE events and enforce local thinking policy.""" - event_name, data_text = self._parse_sse_event(event) - if not event_name or not data_text: - return event - - try: - payload = json.loads(data_text) - except json.JSONDecodeError: - return event - - if event_name == "content_block_start": - block = payload.get("content_block") - if not isinstance(block, dict): - return event - block_type = block.get("type") - upstream_index = payload.get("index") - if self._should_drop_block_type( - block_type, thinking_enabled=thinking_enabled - ): - if isinstance(upstream_index, int): - state.dropped_indexes.add(upstream_index) - return None - - mapped_index = self._remap_index(payload, state, create=True) - if mapped_index is not None: - payload["index"] = mapped_index - if isinstance(upstream_index, int) and isinstance(block_type, str): - prefix = self._close_open_blocks_before(state, upstream_index) - state.open_block_types[upstream_index] = block_type - return prefix + self._format_sse_event( - event_name, json.dumps(payload) - ) - return self._format_sse_event(event_name, json.dumps(payload)) - return None if not thinking_enabled else event - - if event_name == "content_block_delta": - delta = payload.get("delta") - if not isinstance(delta, dict): - return event - delta_type = delta.get("type") - if self._should_drop_block_type( - delta_type, thinking_enabled=thinking_enabled - ): - return None - - mapped_index = self._remap_index(payload, state, create=False) - if mapped_index is not None: - payload["index"] = mapped_index - return self._format_sse_event(event_name, json.dumps(payload)) - if payload.get("index") in state.dropped_indexes: - return None - if not thinking_enabled: - return None - - if event_name == "content_block_stop": - upstream_index = payload.get("index") - if ( - isinstance(upstream_index, int) - and upstream_index in state.closed_indexes - ): - state.closed_indexes.discard(upstream_index) - return None - mapped_index = self._remap_index(payload, state, create=False) - if mapped_index is not None: - payload["index"] = mapped_index - if isinstance(upstream_index, int): - state.open_block_types.pop(upstream_index, None) - return self._format_sse_event(event_name, json.dumps(payload)) - if payload.get("index") in state.dropped_indexes: - return None - if not thinking_enabled: - return None - - return event - def _new_stream_state(self, request: Any, *, thinking_enabled: bool) -> Any: """Create per-stream state for thinking block filtering.""" - return _SSEFilterState() + return NativeSseBlockPolicyState() def _transform_stream_event( self, @@ -226,23 +63,17 @@ class OpenRouterProvider(AnthropicMessagesTransport): thinking_enabled: bool, ) -> str | None: """Drop provider-specific terminal noise and hidden thinking events.""" - if isinstance(state, _SSEFilterState): - event_name, data_text = self._parse_sse_event(event) - if state.message_stopped or self._is_terminal_done_event( + if isinstance(state, NativeSseBlockPolicyState): + event_name, data_text = parse_native_sse_event(event) + if state.message_stopped or is_terminal_openrouter_done_event( event_name, data_text ): return None if event_name == "message_stop": state.message_stopped = True - if thinking_enabled: - if isinstance(state, _SSEFilterState): - return self._transform_sse_payload( - event, state, thinking_enabled=thinking_enabled - ) - return event - if isinstance(state, _SSEFilterState): - return self._transform_sse_payload( + if isinstance(state, NativeSseBlockPolicyState): + return transform_native_sse_block_event( event, state, thinking_enabled=thinking_enabled ) return event @@ -260,9 +91,10 @@ class OpenRouterProvider(AnthropicMessagesTransport): sent_any_event: bool, ) -> Iterator[str]: """Emit the Anthropic SSE error shape expected by Claude clients.""" - sse = SSEBuilder(f"msg_{uuid.uuid4()}", request.model, input_tokens) - if not sent_any_event: - yield sse.message_start() - yield from sse.emit_error(error_message) - yield sse.message_delta("end_turn", 1) - yield sse.message_stop() + yield from iter_provider_stream_error_sse_events( + request=request, + input_tokens=input_tokens, + error_message=error_message, + sent_any_event=sent_any_event, + log_raw_sse_events=self._config.log_raw_sse_events, + ) diff --git a/providers/open_router/request.py b/providers/open_router/request.py index 89759ab..abd93df 100644 --- a/providers/open_router/request.py +++ b/providers/open_router/request.py @@ -2,136 +2,18 @@ from __future__ import annotations -from collections.abc import Sequence from typing import Any from loguru import logger -from pydantic import BaseModel -OPENROUTER_DEFAULT_MAX_TOKENS = 81920 - -_REQUEST_FIELDS = ( - "model", - "messages", - "system", - "max_tokens", - "stop_sequences", - "stream", - "temperature", - "top_p", - "top_k", - "metadata", - "tools", - "tool_choice", - "thinking", - "extra_body", - "original_model", - "resolved_provider_model", +from config.constants import ( + ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS as OPENROUTER_DEFAULT_MAX_TOKENS, ) - -_INTERNAL_FIELDS = { - "thinking", - "extra_body", - "original_model", - "resolved_provider_model", -} -_THINKING_HISTORY_BLOCK_TYPES = {"thinking", "redacted_thinking"} - - -def _serialize_value(value: Any) -> Any: - """Convert Pydantic models and lightweight objects into JSON-ready values.""" - if isinstance(value, BaseModel): - return value.model_dump(exclude_none=True) - if isinstance(value, dict): - return { - key: _serialize_value(item) - for key, item in value.items() - if item is not None - } - if isinstance(value, Sequence) and not isinstance(value, str | bytes | bytearray): - return [_serialize_value(item) for item in value] - if value is None or isinstance(value, str | int | float | bool): - return value - if hasattr(value, "__dict__"): - return { - key: _serialize_value(item) - for key, item in vars(value).items() - if not key.startswith("_") and item is not None - } - return value - - -def _dump_request_fields(request_data: Any) -> dict[str, Any]: - """Extract the public request fields forwarded to OpenRouter.""" - if isinstance(request_data, BaseModel): - return request_data.model_dump(exclude_none=True) - - dumped: dict[str, Any] = {} - for field in _REQUEST_FIELDS: - value = getattr(request_data, field, None) - if value is not None: - dumped[field] = _serialize_value(value) - return dumped - - -def _strip_unsigned_thinking_history(messages: Any) -> Any: - """Remove assistant thinking history blocks that OpenRouter cannot replay.""" - if not isinstance(messages, list): - return messages - - sanitized_messages: list[Any] = [] - for message in messages: - if not isinstance(message, dict): - sanitized_messages.append(message) - continue - - content = message.get("content") - if not isinstance(content, list): - sanitized_messages.append(message) - continue - - sanitized_content = [ - block - for block in content - if not ( - isinstance(block, dict) - and block.get("type") in _THINKING_HISTORY_BLOCK_TYPES - and not isinstance(block.get("signature"), str) - ) - ] - - sanitized_message = dict(message) - sanitized_message["content"] = sanitized_content or "" - sanitized_messages.append(sanitized_message) - - return sanitized_messages - - -def _normalize_system_prompt(system: Any) -> Any: - """Flatten Claude SDK system blocks for OpenRouter's native endpoint.""" - if not isinstance(system, list): - return system - - text_parts: list[str] = [] - for block in system: - if not isinstance(block, dict): - continue - if block.get("type") == "text" and isinstance(block.get("text"), str): - text_parts.append(block["text"]) - return "\n\n".join(text_parts).strip() if text_parts else system - - -def _apply_reasoning(body: dict[str, Any], thinking_cfg: Any) -> None: - """Map Anthropic thinking controls onto OpenRouter reasoning controls.""" - reasoning = body.setdefault("reasoning", {"enabled": True}) - if not isinstance(reasoning, dict): - return - reasoning.setdefault("enabled", True) - if not isinstance(thinking_cfg, dict): - return - budget_tokens = thinking_cfg.get("budget_tokens") - if isinstance(budget_tokens, int): - reasoning.setdefault("max_tokens", budget_tokens) +from core.anthropic.native_messages_request import ( + OpenRouterExtraBodyError, + build_openrouter_native_request_body, +) +from providers.exceptions import InvalidRequestError def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict: @@ -142,27 +24,14 @@ def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict: len(getattr(request_data, "messages", [])), ) - dumped_request = _dump_request_fields(request_data) - request_extra = dumped_request.pop("extra_body", None) - thinking_cfg = dumped_request.get("thinking") - body = { - key: value - for key, value in dumped_request.items() - if key not in _INTERNAL_FIELDS - } - - if isinstance(request_extra, dict): - body.update(request_extra) - - body["messages"] = _strip_unsigned_thinking_history(body.get("messages")) - if "system" in body: - body["system"] = _normalize_system_prompt(body["system"]) - body["stream"] = True - if body.get("max_tokens") is None: - body["max_tokens"] = OPENROUTER_DEFAULT_MAX_TOKENS - - if thinking_enabled: - _apply_reasoning(body, thinking_cfg) + try: + body = build_openrouter_native_request_body( + request_data, + thinking_enabled=thinking_enabled, + default_max_tokens=OPENROUTER_DEFAULT_MAX_TOKENS, + ) + except OpenRouterExtraBodyError as exc: + raise InvalidRequestError(str(exc)) from exc logger.debug( "OPENROUTER_REQUEST: conversion done model={} msgs={} tools={}", diff --git a/providers/openai_compat.py b/providers/openai_compat.py index d729896..31c1d7e 100644 --- a/providers/openai_compat.py +++ b/providers/openai_compat.py @@ -21,14 +21,40 @@ from core.anthropic import ( SSEBuilder, ThinkTagParser, append_request_id, - get_user_facing_error_message, map_stop_reason, ) from providers.base import BaseProvider, ProviderConfig -from providers.error_mapping import map_error +from providers.error_mapping import ( + map_error, + user_visible_message_for_mapped_provider_error, +) from providers.rate_limit import GlobalRateLimiter +def _iter_heuristic_tool_use_sse( + sse: SSEBuilder, tool_use: dict[str, Any] +) -> Iterator[str]: + """Emit SSE for one heuristic tool_use block (closes open text/thinking first).""" + if tool_use.get("name") == "Task" and isinstance(tool_use.get("input"), dict): + task_input = tool_use["input"] + if task_input.get("run_in_background") is not False: + task_input["run_in_background"] = False + yield from sse.close_content_blocks() + block_idx = sse.blocks.allocate_index() + yield sse.content_block_start( + block_idx, + "tool_use", + id=tool_use["id"], + name=tool_use["name"], + ) + yield sse.content_block_delta( + block_idx, + "input_json_delta", + json.dumps(tool_use["input"]), + ) + yield sse.content_block_stop(block_idx) + + class OpenAIChatTransport(BaseProvider): """Base for OpenAI-compatible ``/chat/completions`` adapters (NIM, DeepSeek, …).""" @@ -113,6 +139,22 @@ class OpenAIChatTransport(BaseProvider): ) return stream, retry_body + def _emit_tool_arg_delta( + self, sse: SSEBuilder, tc_index: int, args: str + ) -> Iterator[str]: + """Emit one argument fragment for a started tool block (Task buffer or raw JSON).""" + if not args: + return + state = sse.blocks.tool_states.get(tc_index) + if state is None: + return + if state.name == "Task": + parsed = sse.blocks.buffer_task_args(tc_index, args) + if parsed is not None: + yield sse.emit_tool_delta(tc_index, json.dumps(parsed)) + return + yield sse.emit_tool_delta(tc_index, args) + def _process_tool_call(self, tc: dict, sse: SSEBuilder) -> Iterator[str]: """Process a single tool call delta and yield SSE events.""" tc_index = tc.get("index", 0) @@ -121,34 +163,42 @@ class OpenAIChatTransport(BaseProvider): fn_delta = tc.get("function", {}) incoming_name = fn_delta.get("name") - arguments = fn_delta.get("arguments", "") + arguments = fn_delta.get("arguments", "") or "" + + if tc.get("id") is not None: + sse.blocks.set_stream_tool_id(tc_index, tc.get("id")) + if incoming_name is not None: sse.blocks.register_tool_name(tc_index, incoming_name) state = sse.blocks.tool_states.get(tc_index) + resolved_id = (state.tool_id if state and state.tool_id else None) or tc.get( + "id" + ) + resolved_name = (state.name if state else "") or "" + + if not state or not state.started: + name_ok = bool((resolved_name or "").strip()) + if name_ok: + tool_id = str(resolved_id) if resolved_id else f"tool_{uuid.uuid4()}" + display_name = (resolved_name or "").strip() or "tool_call" + yield sse.start_tool_block(tc_index, tool_id, display_name) + state = sse.blocks.tool_states[tc_index] + if state.pre_start_args: + pre = state.pre_start_args + state.pre_start_args = "" + yield from self._emit_tool_arg_delta(sse, tc_index, pre) + + state = sse.blocks.tool_states.get(tc_index) + if not arguments: + return if state is None or not state.started: - name = state.name if state else "" - if name or tc.get("id"): - tool_id = tc.get("id") or f"tool_{uuid.uuid4()}" - yield sse.start_tool_block(tc_index, tool_id, name) - - args = arguments - if args: - state = sse.blocks.tool_states.get(tc_index) - if state is None or not state.started: - tool_id = tc.get("id") or f"tool_{uuid.uuid4()}" - name = (state.name if state else None) or "tool_call" - yield sse.start_tool_block(tc_index, tool_id, name) - state = sse.blocks.tool_states.get(tc_index) - - current_name = state.name if state else "" - if current_name == "Task": - parsed = sse.blocks.buffer_task_args(tc_index, args) - if parsed is not None: - yield sse.emit_tool_delta(tc_index, json.dumps(parsed)) + state = sse.blocks.ensure_tool_state(tc_index) + if not (resolved_name or "").strip(): + state.pre_start_args += arguments return - yield sse.emit_tool_delta(tc_index, args) + yield from self._emit_tool_arg_delta(sse, tc_index, arguments) def _flush_task_arg_buffers(self, sse: SSEBuilder) -> Iterator[str]: """Emit buffered Task args as a single JSON delta (best-effort).""" @@ -181,7 +231,12 @@ class OpenAIChatTransport(BaseProvider): """Shared streaming implementation.""" tag = self._provider_name message_id = f"msg_{uuid.uuid4()}" - sse = SSEBuilder(message_id, request.model, input_tokens) + sse = SSEBuilder( + message_id, + request.model, + input_tokens, + log_raw_events=self._config.log_raw_sse_events, + ) body = self._build_request_body(request, thinking_enabled=thinking_enabled) thinking_enabled = self._is_thinking_enabled(request, thinking_enabled) @@ -201,8 +256,6 @@ class OpenAIChatTransport(BaseProvider): heuristic_parser = HeuristicToolParser() finish_reason = None usage_info = None - error_occurred = False - error_message = "" async with self._global_rate_limiter.concurrency_slot(): try: @@ -258,26 +311,10 @@ class OpenAIChatTransport(BaseProvider): yield sse.emit_text_delta(filtered_text) for tool_use in detected_tools: - for event in sse.close_content_blocks(): - yield event - - block_idx = sse.blocks.allocate_index() - if tool_use.get("name") == "Task" and isinstance( - tool_use.get("input"), dict + for event in _iter_heuristic_tool_use_sse( + sse, tool_use ): - tool_use["input"]["run_in_background"] = False - yield sse.content_block_start( - block_idx, - "tool_use", - id=tool_use["id"], - name=tool_use["name"], - ) - yield sse.content_block_delta( - block_idx, - "input_json_delta", - json.dumps(tool_use["input"]), - ) - yield sse.content_block_stop(block_idx) + yield event # Handle native tool calls if delta.tool_calls: @@ -298,18 +335,13 @@ class OpenAIChatTransport(BaseProvider): except asyncio.CancelledError, GeneratorExit: raise except Exception as e: - logger.error("{}_ERROR:{} {}: {}", tag, req_tag, type(e).__name__, e) + self._log_stream_transport_error(tag, req_tag, e) mapped_e = map_error(e, rate_limiter=self._global_rate_limiter) - error_occurred = True - if getattr(mapped_e, "status_code", None) == 405: - base_message = ( - f"Upstream provider {tag} rejected the request method " - "or endpoint (HTTP 405)." - ) - else: - base_message = get_user_facing_error_message( - mapped_e, read_timeout_s=self._config.http_read_timeout - ) + base_message = user_visible_message_for_mapped_provider_error( + mapped_e, + provider_name=tag, + read_timeout_s=self._config.http_read_timeout, + ) error_message = append_request_id(base_message, request_id) logger.info( "{}_STREAM: Emitting SSE error event for {}{}", @@ -317,10 +349,13 @@ class OpenAIChatTransport(BaseProvider): type(e).__name__, req_tag, ) - for event in sse.close_content_blocks(): + for event in sse.close_all_blocks(): yield event for event in sse.emit_error(error_message): yield event + yield sse.message_delta("end_turn", 1) + yield sse.message_stop() + return # Flush remaining content remaining = think_parser.flush() @@ -338,32 +373,16 @@ class OpenAIChatTransport(BaseProvider): yield sse.emit_text_delta(remaining.content) for tool_use in heuristic_parser.flush(): - for event in sse.close_content_blocks(): + for event in _iter_heuristic_tool_use_sse(sse, tool_use): yield event - block_idx = sse.blocks.allocate_index() - yield sse.content_block_start( - block_idx, - "tool_use", - id=tool_use["id"], - name=tool_use["name"], - ) - if tool_use.get("name") == "Task" and isinstance( - tool_use.get("input"), dict - ): - tool_use["input"]["run_in_background"] = False - yield sse.content_block_delta( - block_idx, - "input_json_delta", - json.dumps(tool_use["input"]), - ) - yield sse.content_block_stop(block_idx) - - if ( - not error_occurred - and sse.blocks.text_index == -1 - and not sse.blocks.tool_states - ): + has_started_tool = any(s.started for s in sse.blocks.tool_states.values()) + has_content_blocks = ( + sse.blocks.text_index != -1 + or sse.blocks.thinking_index != -1 + or has_started_tool + ) + if not has_content_blocks: for event in sse.ensure_text_block(): yield event yield sse.emit_text_delta(" ") diff --git a/providers/rate_limit.py b/providers/rate_limit.py index 0e427e9..59e4291 100644 --- a/providers/rate_limit.py +++ b/providers/rate_limit.py @@ -3,14 +3,16 @@ import asyncio import random import time -from collections import deque from collections.abc import AsyncIterator, Callable from contextlib import asynccontextmanager from typing import Any, ClassVar, TypeVar +import httpx import openai from loguru import logger +from core.rate_limit import StrictSlidingWindowLimiter + T = TypeVar("T") @@ -51,10 +53,10 @@ class GlobalRateLimiter: self._rate_limit = rate_limit self._rate_window = float(rate_window) self._max_concurrency = max_concurrency - # Monotonic timestamps of the last granted slots. - self._request_times: deque[float] = deque() + self._proactive_limiter = StrictSlidingWindowLimiter( + self._rate_limit, self._rate_window + ) self._blocked_until: float = 0 - self._lock = asyncio.Lock() self._concurrency_sem = asyncio.Semaphore(max_concurrency) self._initialized = True @@ -107,12 +109,11 @@ class GlobalRateLimiter: logger.info( "Rebuilding provider rate limiter for updated scope '{}'", scope ) - if scope not in cls._scoped_instances or existing: - cls._scoped_instances[scope] = cls( - rate_limit=desired_rate_limit, - rate_window=desired_rate_window, - max_concurrency=max_concurrency, - ) + cls._scoped_instances[scope] = cls( + rate_limit=desired_rate_limit, + rate_window=desired_rate_window, + max_concurrency=max_concurrency, + ) return cls._scoped_instances[scope] @classmethod @@ -150,27 +151,7 @@ class GlobalRateLimiter: Guarantees: at most `self._rate_limit` acquisitions in any interval of length `self._rate_window` (seconds). """ - while True: - wait_time = 0.0 - async with self._lock: - now = time.monotonic() - cutoff = now - self._rate_window - - while self._request_times and self._request_times[0] <= cutoff: - self._request_times.popleft() - - if len(self._request_times) < self._rate_limit: - self._request_times.append(now) - return - - oldest = self._request_times[0] - wait_time = max(0.0, (oldest + self._rate_window) - now) - - # Sleep outside the lock so other tasks can continue to queue. - if wait_time > 0: - await asyncio.sleep(wait_time) - else: - await asyncio.sleep(0) + await self._proactive_limiter.acquire() def set_blocked(self, seconds: float = 60) -> None: """ @@ -263,6 +244,24 @@ class GlobalRateLimiter: ) self.set_blocked(delay) await asyncio.sleep(delay) + except httpx.HTTPStatusError as e: + if e.response.status_code != 429: + raise + last_exc = e + if attempt >= max_retries: + logger.warning( + f"HTTP 429 retry exhausted after {max_retries} retries" + ) + break + + delay = min(base_delay * (2**attempt), max_delay) + delay += random.uniform(0, jitter) + logger.warning( + f"HTTP 429 from upstream, attempt {attempt + 1}/{max_retries + 1}. " + f"Retrying in {delay:.1f}s..." + ) + self.set_blocked(delay) + await asyncio.sleep(delay) assert last_exc is not None raise last_exc diff --git a/providers/registry.py b/providers/registry.py index 75a6d8c..20ef4a0 100644 --- a/providers/registry.py +++ b/providers/registry.py @@ -3,108 +3,20 @@ from __future__ import annotations from collections.abc import Callable, MutableMapping -from dataclasses import dataclass -from typing import Literal -from config.provider_ids import SUPPORTED_PROVIDER_IDS +from config.provider_catalog import ( + PROVIDER_CATALOG, + SUPPORTED_PROVIDER_IDS, + ProviderDescriptor, +) from config.settings import Settings from providers.base import BaseProvider, ProviderConfig -from providers.defaults import ( - DEEPSEEK_DEFAULT_BASE, - LLAMACPP_DEFAULT_BASE, - LMSTUDIO_DEFAULT_BASE, - NVIDIA_NIM_DEFAULT_BASE, - OLLAMA_DEFAULT_BASE, - OPENROUTER_DEFAULT_BASE, -) from providers.exceptions import AuthenticationError, UnknownProviderTypeError -TransportType = Literal["openai_chat", "anthropic_messages"] ProviderFactory = Callable[[ProviderConfig, Settings], BaseProvider] - -@dataclass(frozen=True, slots=True) -class ProviderDescriptor: - """Metadata for building :class:`ProviderConfig` and factory wiring.""" - - provider_id: str - transport_type: TransportType - capabilities: tuple[str, ...] - credential_env: str | None = None - credential_url: str | None = None - # If set, read API key from this attribute on ``Settings`` (e.g. nvidia_nim_api_key). - credential_attr: str | None = None - # If set, use this fixed key for local adapters (e.g. lm-studio, llamacpp). - static_credential: str | None = None - default_base_url: str | None = None - base_url_attr: str | None = None - proxy_attr: str | None = None - - -PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = { - "nvidia_nim": ProviderDescriptor( - provider_id="nvidia_nim", - transport_type="openai_chat", - credential_env="NVIDIA_NIM_API_KEY", - credential_url="https://build.nvidia.com/settings/api-keys", - credential_attr="nvidia_nim_api_key", - default_base_url=NVIDIA_NIM_DEFAULT_BASE, - proxy_attr="nvidia_nim_proxy", - capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"), - ), - "open_router": ProviderDescriptor( - provider_id="open_router", - transport_type="anthropic_messages", - credential_env="OPENROUTER_API_KEY", - credential_url="https://openrouter.ai/keys", - credential_attr="open_router_api_key", - default_base_url=OPENROUTER_DEFAULT_BASE, - proxy_attr="open_router_proxy", - capabilities=("chat", "streaming", "tools", "thinking", "native_anthropic"), - ), - "deepseek": ProviderDescriptor( - provider_id="deepseek", - transport_type="openai_chat", - credential_env="DEEPSEEK_API_KEY", - credential_url="https://platform.deepseek.com/api_keys", - credential_attr="deepseek_api_key", - default_base_url=DEEPSEEK_DEFAULT_BASE, - capabilities=("chat", "streaming", "thinking"), - ), - "lmstudio": ProviderDescriptor( - provider_id="lmstudio", - transport_type="anthropic_messages", - static_credential="lm-studio", - default_base_url=LMSTUDIO_DEFAULT_BASE, - base_url_attr="lm_studio_base_url", - proxy_attr="lmstudio_proxy", - capabilities=("chat", "streaming", "tools", "native_anthropic", "local"), - ), - "llamacpp": ProviderDescriptor( - provider_id="llamacpp", - transport_type="anthropic_messages", - static_credential="llamacpp", - default_base_url=LLAMACPP_DEFAULT_BASE, - base_url_attr="llamacpp_base_url", - proxy_attr="llamacpp_proxy", - capabilities=("chat", "streaming", "tools", "native_anthropic", "local"), - ), - "ollama": ProviderDescriptor( - provider_id="ollama", - transport_type="anthropic_messages", - static_credential="ollama", - default_base_url=OLLAMA_DEFAULT_BASE, - base_url_attr="ollama_base_url", - capabilities=( - "chat", - "streaming", - "tools", - "thinking", - "native_anthropic", - "local", - ), - ), -} +# Backwards-compatible name for the catalog (single source: ``config.provider_catalog``). +PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = PROVIDER_CATALOG def _create_nvidia_nim(config: ProviderConfig, settings: Settings) -> BaseProvider: @@ -119,25 +31,25 @@ def _create_open_router(config: ProviderConfig, _settings: Settings) -> BaseProv return OpenRouterProvider(config) -def _create_deepseek(config: ProviderConfig, settings: Settings) -> BaseProvider: +def _create_deepseek(config: ProviderConfig, _settings: Settings) -> BaseProvider: from providers.deepseek import DeepSeekProvider return DeepSeekProvider(config) -def _create_lmstudio(config: ProviderConfig, settings: Settings) -> BaseProvider: +def _create_lmstudio(config: ProviderConfig, _settings: Settings) -> BaseProvider: from providers.lmstudio import LMStudioProvider return LMStudioProvider(config) -def _create_llamacpp(config: ProviderConfig, settings: Settings) -> BaseProvider: +def _create_llamacpp(config: ProviderConfig, _settings: Settings) -> BaseProvider: from providers.llamacpp import LlamaCppProvider return LlamaCppProvider(config) -def _create_ollama(config: ProviderConfig, settings: Settings) -> BaseProvider: +def _create_ollama(config: ProviderConfig, _settings: Settings) -> BaseProvider: from providers.ollama import OllamaProvider return OllamaProvider(config) @@ -208,6 +120,8 @@ def build_provider_config( http_connect_timeout=settings.http_connect_timeout, enable_thinking=settings.enable_model_thinking, proxy=proxy, + log_raw_sse_events=settings.log_raw_sse_events, + log_api_error_tracebacks=settings.log_api_error_tracebacks, ) diff --git a/server.py b/server.py index b025be0..958e730 100644 --- a/server.py +++ b/server.py @@ -1,11 +1,13 @@ """ Claude Code Proxy - Entry Point -Minimal entry point that imports the app from the api module. +Minimal entry point that builds the ASGI app via :func:`api.app.create_app`. Run with: uv run uvicorn server:app --host 0.0.0.0 --port 8082 --timeout-graceful-shutdown 5 """ -from api.app import app, create_app +from api.app import create_app + +app = create_app() __all__ = ["app", "create_app"] diff --git a/smoke/capabilities.py b/smoke/capabilities.py index 5f29fff..eafae44 100644 --- a/smoke/capabilities.py +++ b/smoke/capabilities.py @@ -98,7 +98,7 @@ CAPABILITY_CONTRACTS: tuple[CapabilityContract, ...] = ( "providers.registry.ProviderRegistry", "provider id and Settings", "configured BaseProvider instance", - "503 for missing credentials; ValueError for unknown provider", + "503 for missing credentials; invalid_request_error for unknown provider", ("tests/api/test_dependencies.py", "tests/providers/test_registry.py"), ("test_configured_provider_models_stream_successfully",), ), diff --git a/smoke/lib/e2e.py b/smoke/lib/e2e.py index e669b51..8b20e80 100644 --- a/smoke/lib/e2e.py +++ b/smoke/lib/e2e.py @@ -19,6 +19,14 @@ import httpx import pytest from config.provider_ids import SUPPORTED_PROVIDER_IDS +from core.anthropic.stream_contracts import ( + SSEEvent, + assert_anthropic_stream_contract, + event_index, + has_tool_use, + parse_sse_lines, + text_content, +) from messaging.handler import ClaudeMessageHandler from messaging.models import IncomingMessage from messaging.platforms.base import MessagingPlatform @@ -26,13 +34,6 @@ from messaging.session import SessionStore from smoke.lib.config import ProviderModel, SmokeConfig, auth_headers from smoke.lib.server import RunningServer, start_server from smoke.lib.skips import fail_missing_env -from smoke.lib.sse import ( - SSEEvent, - assert_anthropic_stream_contract, - has_tool_use, - parse_sse_lines, - text_content, -) @dataclass(slots=True) @@ -590,14 +591,14 @@ def assistant_content_from_events(events: list[SSEEvent]) -> list[dict[str, Any] block_order: list[int] = [] for event in events: if event.event == "content_block_start": - index = _event_index(event) + index = event_index(event) block = event.data.get("content_block", {}) if isinstance(block, dict): blocks[index] = dict(block) block_order.append(index) continue if event.event == "content_block_delta": - index = _event_index(event) + index = event_index(event) block = blocks.get(index) delta = event.data.get("delta", {}) if not isinstance(block, dict) or not isinstance(delta, dict): @@ -663,12 +664,6 @@ def default_cli_events(session_id: str) -> list[dict[str, Any]]: ] -def _event_index(event: SSEEvent) -> int: - value = event.data.get("index") - assert isinstance(value, int), event.data - return value - - def assert_product_stream(events: list[SSEEvent]) -> None: assert_anthropic_stream_contract(events) assert text_content(events).strip() or has_tool_use(events), ( diff --git a/smoke/lib/http.py b/smoke/lib/http.py index f019473..71e2cb0 100644 --- a/smoke/lib/http.py +++ b/smoke/lib/http.py @@ -6,9 +6,10 @@ from typing import Any import httpx +from core.anthropic.stream_contracts import SSEEvent, parse_sse_lines + from .config import SmokeConfig, auth_headers, redacted from .server import RunningServer -from .sse import SSEEvent, parse_sse_lines def message_payload( diff --git a/smoke/lib/skips.py b/smoke/lib/skips.py index ffadb85..1cb1e29 100644 --- a/smoke/lib/skips.py +++ b/smoke/lib/skips.py @@ -5,7 +5,7 @@ from __future__ import annotations import httpx import pytest -from smoke.lib.sse import SSEEvent +from core.anthropic.stream_contracts import SSEEvent UPSTREAM_UNAVAILABLE_MARKERS = ( "connection refused", diff --git a/smoke/lib/sse.py b/smoke/lib/sse.py deleted file mode 100644 index a294ce5..0000000 --- a/smoke/lib/sse.py +++ /dev/null @@ -1,29 +0,0 @@ -"""SSE parsing and Anthropic stream assertions for smoke tests. - -Canonical implementation lives in :mod:`core.anthropic.stream_contracts`; this -module re-exports it for smoke and historical import paths. -""" - -from __future__ import annotations - -from core.anthropic.stream_contracts import ( - SSEEvent, - assert_anthropic_stream_contract, - event_names, - has_tool_use, - parse_sse_lines, - parse_sse_text, - text_content, - thinking_content, -) - -__all__ = [ - "SSEEvent", - "assert_anthropic_stream_contract", - "event_names", - "has_tool_use", - "parse_sse_lines", - "parse_sse_text", - "text_content", - "thinking_content", -] diff --git a/smoke/prereq/test_provider_prereq_live.py b/smoke/prereq/test_provider_prereq_live.py index 8731278..0d85356 100644 --- a/smoke/prereq/test_provider_prereq_live.py +++ b/smoke/prereq/test_provider_prereq_live.py @@ -5,6 +5,10 @@ import time import httpx import pytest +from core.anthropic.stream_contracts import ( + assert_anthropic_stream_contract, + text_content, +) from smoke.lib.config import SmokeConfig, auth_headers from smoke.lib.http import collect_message_stream, message_payload from smoke.lib.server import start_server @@ -12,7 +16,6 @@ from smoke.lib.skips import ( skip_if_upstream_unavailable_events, skip_if_upstream_unavailable_exception, ) -from smoke.lib.sse import assert_anthropic_stream_contract, text_content pytestmark = [pytest.mark.live, pytest.mark.smoke_target("providers")] diff --git a/smoke/prereq/test_tools_prereq_live.py b/smoke/prereq/test_tools_prereq_live.py index 824197b..716678c 100644 --- a/smoke/prereq/test_tools_prereq_live.py +++ b/smoke/prereq/test_tools_prereq_live.py @@ -2,11 +2,14 @@ from __future__ import annotations import pytest +from core.anthropic.stream_contracts import ( + assert_anthropic_stream_contract, + has_tool_use, +) from smoke.lib.config import SmokeConfig from smoke.lib.http import collect_message_stream, message_payload from smoke.lib.server import start_server from smoke.lib.skips import skip_if_upstream_unavailable_events -from smoke.lib.sse import assert_anthropic_stream_contract, has_tool_use pytestmark = [pytest.mark.live, pytest.mark.smoke_target("tools")] diff --git a/smoke/prereq/test_voice_prereq_live.py b/smoke/prereq/test_voice_prereq_live.py index 21e4dcb..50ce9e5 100644 --- a/smoke/prereq/test_voice_prereq_live.py +++ b/smoke/prereq/test_voice_prereq_live.py @@ -24,12 +24,13 @@ def test_voice_transcription_backend_when_explicitly_enabled( wav_path = tmp_path / "smoke-tone.wav" _write_tone_wav(wav_path) try: - text = transcribe_audio( - wav_path, - "audio/wav", - whisper_model=smoke_config.settings.whisper_model, - whisper_device=smoke_config.settings.whisper_device, - ) + t_kw: dict[str, str] = { + "whisper_model": smoke_config.settings.whisper_model, + "whisper_device": smoke_config.settings.whisper_device, + } + if smoke_config.settings.whisper_device == "nvidia_nim": + t_kw["nvidia_nim_api_key"] = smoke_config.settings.nvidia_nim_api_key + text = transcribe_audio(wav_path, "audio/wav", **t_kw) except ImportError as exc: pytest.skip(str(exc)) assert isinstance(text, str) diff --git a/smoke/product/test_provider_product_live.py b/smoke/product/test_provider_product_live.py index ac61839..be5d45e 100644 --- a/smoke/product/test_provider_product_live.py +++ b/smoke/product/test_provider_product_live.py @@ -3,6 +3,11 @@ from __future__ import annotations import httpx import pytest +from core.anthropic.stream_contracts import ( + assert_anthropic_stream_contract, + parse_sse_lines, + text_content, +) from smoke.lib.config import ProviderModel, SmokeConfig, auth_headers from smoke.lib.e2e import ( ConversationDriver, @@ -16,11 +21,6 @@ from smoke.lib.skips import ( skip_if_upstream_unavailable_events, skip_if_upstream_unavailable_exception, ) -from smoke.lib.sse import ( - assert_anthropic_stream_contract, - parse_sse_lines, - text_content, -) pytestmark = [pytest.mark.live, pytest.mark.smoke_target("providers")] diff --git a/smoke/product/test_voice_product_live.py b/smoke/product/test_voice_product_live.py index d1c27f5..0327168 100644 --- a/smoke/product/test_voice_product_live.py +++ b/smoke/product/test_voice_product_live.py @@ -55,6 +55,7 @@ def test_voice_nim_backend_e2e(smoke_config: SmokeConfig, tmp_path: Path) -> Non "audio/wav", whisper_model=smoke_config.settings.whisper_model, whisper_device="nvidia_nim", + nvidia_nim_api_key=smoke_config.settings.nvidia_nim_api_key, ) assert isinstance(text, str) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..10436b4 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test suite package.""" diff --git a/tests/api/test_anthropic_request_passthrough.py b/tests/api/test_anthropic_request_passthrough.py new file mode 100644 index 0000000..15ee3db --- /dev/null +++ b/tests/api/test_anthropic_request_passthrough.py @@ -0,0 +1,184 @@ +"""Pydantic passthrough of Anthropic protocol fields (e.g. ``cache_control``).""" + +from __future__ import annotations + +from api.models.anthropic import ( + ContentBlockServerToolUse, + ContentBlockText, + ContentBlockWebSearchToolResult, + Message, + MessagesRequest, + SystemContent, + Tool, +) +from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS +from core.anthropic.native_messages_request import ( + build_base_native_anthropic_request_body, +) + + +def test_cache_control_preserved_on_parsed_user_text_system_and_tool() -> None: + raw = { + "model": "m", + "max_tokens": 20, + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "hi", + "cache_control": {"type": "ephemeral"}, + } + ], + } + ], + "system": [ + { + "type": "text", + "text": "be brief", + "cache_control": {"type": "ephemeral"}, + } + ], + "tools": [ + { + "name": "alpha", + "input_schema": {"type": "object"}, + "cache_control": {"type": "ephemeral"}, + } + ], + } + req = MessagesRequest.model_validate(raw) + block = req.messages[0].content[0] + assert isinstance(block, ContentBlockText) + assert block.model_dump()["cache_control"] == {"type": "ephemeral"} + s0 = req.system[0] if isinstance(req.system, list) else None + assert isinstance(s0, SystemContent) + assert s0.model_dump()["cache_control"] == {"type": "ephemeral"} + t0 = req.tools[0] if req.tools else None + assert isinstance(t0, Tool) + assert t0.model_dump()["cache_control"] == {"type": "ephemeral"} + + +def test_build_base_native_body_includes_cache_control() -> None: + req = MessagesRequest( + model="m", + max_tokens=20, + messages=[ + Message( + role="user", + content=[ + ContentBlockText.model_validate( + { + "type": "text", + "text": "x", + "cache_control": {"type": "ephemeral"}, + } + ) + ], + ) + ], + system=[ + SystemContent.model_validate( + { + "type": "text", + "text": "s", + "cache_control": {"type": "ephemeral"}, + } + ) + ], + tools=[ + Tool.model_validate( + { + "name": "n", + "input_schema": {"type": "object"}, + "cache_control": {"type": "ephemeral"}, + } + ) + ], + ) + body = build_base_native_anthropic_request_body( + req, + default_max_tokens=ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS, + thinking_enabled=False, + ) + assert body["messages"][0]["content"][0]["cache_control"] == {"type": "ephemeral"} + assert body["system"][0]["cache_control"] == {"type": "ephemeral"} + assert body["tools"][0]["cache_control"] == {"type": "ephemeral"} + + +def test_pydantic_discriminator_still_distinguishes_blocks() -> None: + m = Message.model_validate( + { + "role": "user", + "content": [{"type": "text", "text": "a", "z": 1}], + } + ) + b = m.content[0] + assert isinstance(b, ContentBlockText) + assert b.model_dump()["z"] == 1 + + +def test_server_tool_assistant_blocks_round_trip_in_native_body() -> None: + """Local server-tool responses must parse as valid history for a follow-up request.""" + raw = { + "model": "m", + "max_tokens": 20, + "messages": [ + { + "role": "assistant", + "content": [ + { + "type": "server_tool_use", + "id": "srvtoolu_1", + "name": "web_search", + "input": {"query": "q"}, + }, + { + "type": "web_search_tool_result", + "tool_use_id": "srvtoolu_1", + "content": [ + { + "type": "web_search_result", + "title": "T", + "url": "https://example.com", + } + ], + }, + ], + } + ], + "mcp_servers": [{"type": "url", "url": "https://example.com/mcp"}], + } + req = MessagesRequest.model_validate(raw) + assert len(req.messages) == 1 + blocks = req.messages[0].content + assert isinstance(blocks, list) + assert isinstance(blocks[0], ContentBlockServerToolUse) + assert isinstance(blocks[1], ContentBlockWebSearchToolResult) + body = build_base_native_anthropic_request_body( + req, + default_max_tokens=ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS, + thinking_enabled=False, + ) + assert body["mcp_servers"][0]["type"] == "url" + assert body["messages"][0]["content"][0]["type"] == "server_tool_use" + assert body["messages"][0]["content"][1]["type"] == "web_search_tool_result" + + +def test_native_body_preserves_context_and_output_config() -> None: + raw = { + "model": "m", + "max_tokens": 20, + "messages": [{"role": "user", "content": "x"}], + "context_management": {"edits": [{"type": "clear"}]}, + "output_config": {"some": "hint"}, + } + req = MessagesRequest.model_validate(raw) + body = build_base_native_anthropic_request_body( + req, + default_max_tokens=ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS, + thinking_enabled=False, + ) + assert body["context_management"] == raw["context_management"] + assert body["output_config"] == raw["output_config"] diff --git a/tests/api/test_api.py b/tests/api/test_api.py index e4912e7..b387048 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -3,9 +3,11 @@ from unittest.mock import MagicMock, patch import pytest from fastapi.testclient import TestClient -from api.app import app +from api.app import create_app from providers.nvidia_nim import NvidiaNimProvider +app = create_app() + # Mock provider mock_provider = MagicMock(spec=NvidiaNimProvider) @@ -171,7 +173,7 @@ def test_generic_exception_returns_500(client: TestClient): def test_generic_exception_with_status_code(client: TestClient): - """Generic exception with status_code attribute uses that status (getattr fallback).""" + """Unexpected errors always map to HTTP 500 (ignore ad-hoc status_code attrs).""" class ExceptionWithStatus(RuntimeError): def __init__(self, msg: str, status_code: int = 500): @@ -191,7 +193,7 @@ def test_generic_exception_with_status_code(client: TestClient): "stream": True, }, ) - assert response.status_code == 502 + assert response.status_code == 500 mock_provider.stream_response = _mock_stream_response diff --git a/tests/api/test_app_lifespan_and_errors.py b/tests/api/test_app_lifespan_and_errors.py index 126a869..adc81c5 100644 --- a/tests/api/test_app_lifespan_and_errors.py +++ b/tests/api/test_app_lifespan_and_errors.py @@ -17,6 +17,15 @@ _RUNTIME_EXTRAS = { "nvidia_nim_api_key": "", "claude_cli_bin": "claude", "uses_process_anthropic_auth_token": lambda: False, + "messaging_rate_limit": 1, + "messaging_rate_window": 1.0, + "max_message_log_entries_per_chat": None, + "debug_platform_edits": False, + "debug_subagent_stack": False, + "log_api_error_tracebacks": False, + "log_raw_messaging_content": False, + "log_raw_cli_diagnostics": False, + "log_messaging_error_details": False, } @@ -81,9 +90,50 @@ def test_create_app_provider_error_handler_returns_anthropic_format(): with TestClient(app) as client: resp = client.get("/raise_provider") assert resp.status_code == 401 - body = resp.json() - assert body["type"] == "error" - assert body["error"]["type"] == "authentication_error" + body = resp.json() + assert body["type"] == "error" + assert body["error"]["type"] == "authentication_error" + + +def test_create_app_provider_error_default_logs_exclude_provider_message(): + """Provider errors must not log exc.message by default.""" + from api.app import create_app + from providers.exceptions import AuthenticationError + + app = create_app() + secret = "provider-upstream-secret-detail" + + @app.get("/raise_provider_secret") + async def _raise(): + raise AuthenticationError(secret) + + api_app_mod = importlib.import_module("api.app") + settings = _app_settings( + messaging_platform="telegram", + telegram_bot_token=None, + allowed_telegram_user_id=None, + discord_bot_token=None, + allowed_discord_channels=None, + allowed_dir="", + claude_workspace="./agent_workspace", + host="127.0.0.1", + port=8082, + log_file="server.log", + log_api_error_tracebacks=False, + ) + with ( + patch.object(api_app_mod, "get_settings", return_value=settings), + patch.object(ProviderRegistry, "cleanup", new=AsyncMock()), + patch.object(api_app_mod.logger, "error") as log_err, + ): + with TestClient(app) as client: + resp = client.get("/raise_provider_secret") + assert resp.status_code == 401 + + blob = " ".join(str(a) for c in log_err.call_args_list for a in c.args) + blob += repr([c.kwargs for c in log_err.call_args_list]) + assert secret not in blob + assert "authentication_error" in blob def test_create_app_general_exception_handler_returns_500(): @@ -120,6 +170,50 @@ def test_create_app_general_exception_handler_returns_500(): assert body["error"]["type"] == "api_error" +def test_create_app_general_exception_default_logs_exclude_exception_message(): + """Unhandled errors must not log exception text by default (may echo user content).""" + from api.app import create_app + + app = create_app() + + secret = "user-provided-secret-token-xyzzy" + + @app.get("/raise_secret") + async def _raise_secret(): + raise ValueError(secret) + + api_app_mod = importlib.import_module("api.app") + settings = _app_settings( + messaging_platform="telegram", + telegram_bot_token=None, + allowed_telegram_user_id=None, + discord_bot_token=None, + allowed_discord_channels=None, + allowed_dir="", + claude_workspace="./agent_workspace", + host="127.0.0.1", + port=8082, + log_file="server.log", + log_api_error_tracebacks=False, + ) + with ( + patch.object(api_app_mod, "get_settings", return_value=settings), + patch.object(ProviderRegistry, "cleanup", new=AsyncMock()), + patch.object(api_app_mod.logger, "error") as log_err, + ): + with TestClient(app, raise_server_exceptions=False) as client: + resp = client.get("/raise_secret") + assert resp.status_code == 500 + + flattened: list[str] = [] + for call in log_err.call_args_list: + flattened.extend(str(arg) for arg in call.args) + flattened.append(repr(call.kwargs)) + blob = " ".join(flattened) + assert secret not in blob + assert "ValueError" in blob + + @pytest.mark.parametrize( "messaging_enabled", [True, False], ids=["with_platform", "no_platform"] ) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index e661873..99230a3 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -2,10 +2,12 @@ from unittest.mock import patch from fastapi.testclient import TestClient -from api.app import app +from api.app import create_app from api.dependencies import get_settings from config.settings import Settings +app = create_app() + def test_anthropic_auth_token_required_and_accepts_x_api_key(): client = TestClient(app) diff --git a/tests/api/test_dependencies.py b/tests/api/test_dependencies.py index 2849667..5be4c26 100644 --- a/tests/api/test_dependencies.py +++ b/tests/api/test_dependencies.py @@ -16,7 +16,7 @@ from api.dependencies import ( ) from config.nim import NimSettings from providers.deepseek import DeepSeekProvider -from providers.exceptions import UnknownProviderTypeError +from providers.exceptions import ServiceUnavailableError, UnknownProviderTypeError from providers.lmstudio import LMStudioProvider from providers.nvidia_nim import NvidiaNimProvider from providers.ollama import OllamaProvider @@ -40,7 +40,7 @@ def _make_mock_settings(**overrides): mock.nim = NimSettings() mock.http_read_timeout = 300.0 mock.http_write_timeout = 10.0 - mock.http_connect_timeout = 2.0 + mock.http_connect_timeout = 10.0 mock.enable_model_thinking = True for key, value in overrides.items(): setattr(mock, key, value) @@ -434,18 +434,17 @@ def test_resolve_provider_per_app_uses_separate_registries() -> None: assert p1 is not p2 -def test_resolve_provider_lazily_installs_registry() -> None: - """If app has no provider_registry, one is created on app.state.""" +def test_resolve_provider_missing_registry_raises_service_unavailable() -> None: + """HTTP apps must install app.state.provider_registry (e.g. via AppRuntime).""" with patch("api.dependencies.get_settings") as mock_settings: mock_settings.return_value = _make_mock_settings() settings = _make_mock_settings() app = SimpleNamespace(state=State()) assert getattr(app.state, "provider_registry", None) is None - resolve_provider("nvidia_nim", app=cast(Starlette, app), settings=settings) - reg = app.state.provider_registry - assert reg is not None - p2 = resolve_provider("nvidia_nim", app=cast(Starlette, app), settings=settings) - assert p2 is reg.get("nvidia_nim", settings) # same registry instance + with pytest.raises( + ServiceUnavailableError, match="Provider registry is not configured" + ): + resolve_provider("nvidia_nim", app=cast(Starlette, app), settings=settings) def test_resolve_provider_unrelated_value_error_is_not_unknown_provider_log() -> None: diff --git a/tests/api/test_routes_optimizations.py b/tests/api/test_routes_optimizations.py index e0c058c..eb8d2dd 100644 --- a/tests/api/test_routes_optimizations.py +++ b/tests/api/test_routes_optimizations.py @@ -3,10 +3,12 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi.testclient import TestClient -from api.app import app +from api.app import create_app from api.dependencies import get_settings from config.settings import Settings +app = create_app() + @pytest.fixture def client(): @@ -103,6 +105,17 @@ def test_create_message_empty_messages_returns_400(client): assert "cannot be empty" in data.get("error", {}).get("message", "") +def test_count_tokens_empty_messages_returns_400(client): + """POST /v1/messages/count_tokens with messages: [] matches messages validation.""" + payload = {"model": "claude-3-sonnet", "messages": []} + response = client.post("/v1/messages/count_tokens", json=payload) + assert response.status_code == 400 + data = response.json() + assert data.get("type") == "error" + assert data.get("error", {}).get("type") == "invalid_request_error" + assert "cannot be empty" in data.get("error", {}).get("message", "") + + def test_count_tokens_endpoint(client): payload = { "model": "claude-3-sonnet", diff --git a/tests/api/test_runtime_safe_logging.py b/tests/api/test_runtime_safe_logging.py new file mode 100644 index 0000000..5c86c15 --- /dev/null +++ b/tests/api/test_runtime_safe_logging.py @@ -0,0 +1,70 @@ +"""Tests for safe default logging in :mod:`api.runtime`.""" + +import importlib +import logging +from unittest.mock import MagicMock, patch + +import pytest + +from tests.api.test_app_lifespan_and_errors import _app_settings + + +@pytest.mark.asyncio +async def test_messaging_start_failure_default_logs_exclude_traceback(caplog): + api_runtime_mod = importlib.import_module("api.runtime") + settings = _app_settings( + messaging_platform="telegram", + telegram_bot_token="t", + allowed_telegram_user_id="1", + discord_bot_token=None, + allowed_discord_channels=None, + allowed_dir="", + claude_workspace="./agent_workspace", + host="127.0.0.1", + port=8082, + log_file="server.log", + log_api_error_tracebacks=False, + ) + runtime = api_runtime_mod.AppRuntime(app=MagicMock(), settings=settings) + + with ( + patch( + "messaging.platforms.factory.create_messaging_platform", + side_effect=RuntimeError("SECRET_RUNTIME_DETAIL"), + ), + caplog.at_level(logging.ERROR), + ): + await runtime._start_messaging_if_configured() + + blob = " | ".join(r.getMessage() for r in caplog.records) + assert "SECRET_RUNTIME_DETAIL" not in blob + assert "exc_type=RuntimeError" in blob + + +@pytest.mark.asyncio +async def test_best_effort_default_logs_exclude_exception_text(caplog): + api_runtime_mod = importlib.import_module("api.runtime") + + async def boom(): + raise ValueError("SECRET_SHUTDOWN") + + with caplog.at_level(logging.WARNING): + await api_runtime_mod.best_effort("test_step", boom(), log_verbose_errors=False) + + blob = " | ".join(r.getMessage() for r in caplog.records) + assert "SECRET_SHUTDOWN" not in blob + assert "exc_type=ValueError" in blob + + +@pytest.mark.asyncio +async def test_best_effort_verbose_includes_exception_text(caplog): + api_runtime_mod = importlib.import_module("api.runtime") + + async def boom(): + raise ValueError("VISIBLE_SHUTDOWN") + + with caplog.at_level(logging.WARNING): + await api_runtime_mod.best_effort("test_step", boom(), log_verbose_errors=True) + + blob = " | ".join(r.getMessage() for r in caplog.records) + assert "VISIBLE_SHUTDOWN" in blob diff --git a/tests/api/test_safe_logging.py b/tests/api/test_safe_logging.py new file mode 100644 index 0000000..1096516 --- /dev/null +++ b/tests/api/test_safe_logging.py @@ -0,0 +1,208 @@ +"""Tests that API and SSE logging avoid raw sensitive payloads by default.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +from fastapi import HTTPException + +from api import services as services_mod +from api.models.anthropic import Message, MessagesRequest +from api.services import ClaudeProxyService +from config.settings import Settings +from core.anthropic.sse import SSEBuilder + + +def test_create_message_skips_full_payload_debug_log_by_default(): + settings = Settings() + assert settings.log_raw_api_payloads is False + mock_provider = MagicMock() + + async def fake_stream(*_a, **_kw): + yield "event: ping\ndata: {}\n\n" + + mock_provider.stream_response = fake_stream + service = ClaudeProxyService(settings, provider_getter=lambda _: mock_provider) + + request = MessagesRequest( + model="claude-3-haiku-20240307", + max_tokens=10, + messages=[Message(role="user", content="secret-user-text")], + ) + + with patch.object(services_mod.logger, "debug") as mock_debug: + service.create_message(request) + + full_payload_calls = [ + c + for c in mock_debug.call_args_list + if c.args and str(c.args[0]) == "FULL_PAYLOAD [{}]: {}" + ] + assert not full_payload_calls + + +def test_create_message_logs_full_payload_when_opt_in(): + settings = Settings() + settings.log_raw_api_payloads = True + mock_provider = MagicMock() + + async def fake_stream(*_a, **_kw): + yield "event: ping\ndata: {}\n\n" + + mock_provider.stream_response = fake_stream + service = ClaudeProxyService(settings, provider_getter=lambda _: mock_provider) + request = MessagesRequest( + model="claude-3-haiku-20240307", + max_tokens=10, + messages=[Message(role="user", content="visible")], + ) + + with patch.object(services_mod.logger, "debug") as mock_debug: + service.create_message(request) + + keys = [c.args[0] for c in mock_debug.call_args_list if c.args] + assert any(k == "FULL_PAYLOAD [{}]: {}" for k in keys) + + +def test_sse_builder_default_debug_has_no_serialized_json_content(): + with patch("core.anthropic.sse.logger.debug") as mock_debug: + sse = SSEBuilder("msg_x", "m", 1, log_raw_events=False) + sse.message_start() + + assert mock_debug.call_count == 1 + message = str(mock_debug.call_args) + assert "serialized_bytes=" in message + assert "role" not in message + assert "assistant" not in message + + +def test_sse_builder_raw_logging_includes_event_body_when_enabled(): + with patch("core.anthropic.sse.logger.debug") as mock_debug: + sse = SSEBuilder("msg_x", "m", 1, log_raw_events=True) + sse.message_start() + + assert mock_debug.call_count == 1 + message = str(mock_debug.call_args) + assert "message_start" in message + assert "role" in message + + +def _flatten_log_calls(mock_log) -> str: + parts: list[str] = [] + for call in mock_log.call_args_list: + parts.extend(str(arg) for arg in call.args) + parts.append(repr(call.kwargs)) + return " ".join(parts) + + +def test_create_message_unexpected_error_default_logs_exclude_exception_text(): + settings = Settings() + assert settings.log_api_error_tracebacks is False + secret = "upstream-secret-token-abc" + + mock_provider = MagicMock() + + def stream_boom(*_a, **_kw): + raise RuntimeError(secret) + + mock_provider.stream_response = stream_boom + service = ClaudeProxyService(settings, provider_getter=lambda _: mock_provider) + request = MessagesRequest( + model="claude-3-haiku-20240307", + max_tokens=10, + messages=[Message(role="user", content="hi")], + ) + + with ( + patch.object(services_mod.logger, "error") as log_err, + pytest.raises(HTTPException), + ): + service.create_message(request) + + blob = _flatten_log_calls(log_err) + assert secret not in blob + assert "RuntimeError" in blob + + +def test_create_message_unexpected_error_always_returns_500(): + """Non-provider failures must not leak arbitrary status_code attributes.""" + + class WeirdError(Exception): + status_code = 418 + + settings = Settings() + mock_provider = MagicMock() + + def stream_boom(*_a, **_kw): + raise WeirdError("no") + + mock_provider.stream_response = stream_boom + service = ClaudeProxyService(settings, provider_getter=lambda _: mock_provider) + request = MessagesRequest( + model="claude-3-haiku-20240307", + max_tokens=10, + messages=[Message(role="user", content="hi")], + ) + + with pytest.raises(HTTPException) as excinfo: + service.create_message(request) + + assert excinfo.value.status_code == 500 + + +def test_parse_cli_event_error_logs_metadata_by_default(): + """CLI parser must not log raw error text unless LOG_RAW_CLI_DIAGNOSTICS is on.""" + from messaging.event_parser import parse_cli_event + + secret = "user-secret-parser-leak-xyz" + with patch("messaging.event_parser.logger.info") as log_info: + parse_cli_event( + {"type": "error", "error": {"message": secret}}, log_raw_cli=False + ) + flat = " ".join(str(c) for c in log_info.call_args_list) + assert secret not in flat + assert "message_chars" in flat + + +def test_parse_cli_event_error_logs_text_when_log_raw_cli_enabled(): + from messaging.event_parser import parse_cli_event + + secret = "visible-cli-parser-msg" + with patch("messaging.event_parser.logger.info") as log_info: + parse_cli_event( + {"type": "error", "error": {"message": secret}}, log_raw_cli=True + ) + flat = " ".join(str(c) for c in log_info.call_args_list) + assert secret in flat + + +def test_count_tokens_unexpected_error_default_logs_exclude_exception_text(): + settings = Settings() + assert settings.log_api_error_tracebacks is False + secret = "count-tokens-leak-xyz" + + def boom(*_a, **_kw): + raise ValueError(secret) + + service = ClaudeProxyService( + settings, + provider_getter=lambda _: MagicMock(), + token_counter=boom, + ) + from api.models.anthropic import TokenCountRequest + + req = TokenCountRequest( + model="claude-3-haiku-20240307", + messages=[Message(role="user", content="x")], + ) + + with ( + patch.object(services_mod.logger, "error") as log_err, + pytest.raises(HTTPException), + ): + service.count_tokens(req) + + blob = _flatten_log_calls(log_err) + assert secret not in blob + assert "ValueError" in blob diff --git a/tests/api/test_validation_log.py b/tests/api/test_validation_log.py new file mode 100644 index 0000000..9408745 --- /dev/null +++ b/tests/api/test_validation_log.py @@ -0,0 +1,33 @@ +"""Tests for validation log summaries (metadata only).""" + +from api.validation_log import summarize_request_validation_body + + +def test_summarize_lists_block_metadata_without_echoing_string_content(): + body = { + "messages": [ + { + "role": "user", + "content": "secret user phrase", + } + ], + "tools": [{"name": "web_search", "type": "web_search_20250305"}], + } + summary, tool_names = summarize_request_validation_body(body) + assert summary == [ + { + "role": "user", + "content_kind": "str", + "content_length": 18, + } + ] + assert tool_names == ["web_search"] + blob = repr(summary) + repr(tool_names) + assert "secret" not in blob + + +def test_summarize_handles_non_dict_messages_and_missing_tools(): + body = {"messages": ["not_a_dict"]} + summary, tool_names = summarize_request_validation_body(body) + assert summary == [{"message_kind": "str"}] + assert tool_names == [] diff --git a/tests/api/test_web_server_tools.py b/tests/api/test_web_server_tools.py index c00d538..c46ddf2 100644 --- a/tests/api/test_web_server_tools.py +++ b/tests/api/test_web_server_tools.py @@ -1,20 +1,65 @@ -import json +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch +import httpx import pytest +import api.web_tools.constants as web_tool_constants +from api.model_router import ModelRouter, ResolvedModel, RoutedMessagesRequest from api.models.anthropic import Message, MessagesRequest, Tool -from api.web_server_tools import ( - is_web_server_tool_request, - stream_web_server_tool_response, +from api.services import ClaudeProxyService +from api.web_tools import egress as web_egress +from api.web_tools.egress import ( + WebFetchEgressPolicy, + WebFetchEgressViolation, + enforce_web_fetch_egress, +) +from api.web_tools.outbound import ( + _drain_response_body_capped, + _read_response_body_capped, + _run_web_fetch, +) +from api.web_tools.request import is_web_server_tool_request +from api.web_tools.streaming import stream_web_server_tool_response +from config.settings import Settings +from core.anthropic.stream_contracts import ( + assert_anthropic_stream_contract, + parse_sse_text, + text_content, +) +from messaging.event_parser import parse_cli_event +from providers.exceptions import InvalidRequestError + +_STRICT_EGRESS = WebFetchEgressPolicy( + allow_private_network_targets=False, + allowed_schemes=frozenset({"http", "https"}), ) -def _event_data(event: str) -> dict: - data_line = next(line for line in event.splitlines() if line.startswith("data: ")) - return json.loads(data_line.removeprefix("data: ")) +class FixedProviderModelRouter(ModelRouter): + """Test double: pin ``provider_id`` for OpenAI vs native routing assertions.""" + + def __init__(self, settings: Settings, provider_id: str) -> None: + super().__init__(settings) + self._fixed_provider_id = provider_id + + def resolve_messages_request( + self, request: MessagesRequest + ) -> RoutedMessagesRequest: + resolved = ResolvedModel( + original_model=request.model, + provider_id=self._fixed_provider_id, + provider_model=request.model, + provider_model_ref=f"{self._fixed_provider_id}/{request.model}", + thinking_enabled=False, + ) + routed = request.model_copy(deep=True) + routed.model = resolved.provider_model + return RoutedMessagesRequest(request=routed, resolved=resolved) -def test_detects_web_search_server_tool_request(): +def test_web_server_tool_not_detected_when_tool_only_listed(): + """Listing web_search without forcing it must not skip the upstream provider.""" request = MessagesRequest( model="claude-haiku-4-5-20251001", max_tokens=100, @@ -22,16 +67,243 @@ def test_detects_web_search_server_tool_request(): tools=[Tool(name="web_search", type="web_search_20250305")], ) + assert not is_web_server_tool_request(request) + + +def test_web_server_tool_detected_when_tool_choice_forces_it(): + request = MessagesRequest( + model="claude-haiku-4-5-20251001", + max_tokens=100, + messages=[Message(role="user", content="search")], + tools=[Tool(name="web_search", type="web_search_20250305")], + tool_choice={"type": "tool", "name": "web_search"}, + ) + assert is_web_server_tool_request(request) +def test_web_server_tool_not_detected_when_forced_name_missing_from_tools(): + request = MessagesRequest( + model="claude-haiku-4-5-20251001", + max_tokens=100, + messages=[Message(role="user", content="hi")], + tools=[Tool(name="other", type="function")], + tool_choice={"type": "tool", "name": "web_search"}, + ) + + assert not is_web_server_tool_request(request) + + +def test_service_rejects_forced_server_tool_on_openai_when_disabled(): + """OpenAI Chat upstreams cannot run forced server tools without the local handler.""" + settings = Settings() + assert settings.enable_web_server_tools is False + service = ClaudeProxyService( + settings, + provider_getter=lambda _: MagicMock(), + model_router=FixedProviderModelRouter(settings, "nvidia_nim"), + ) + request = MessagesRequest( + model="claude-haiku-4-5-20251001", + max_tokens=100, + messages=[ + Message( + role="user", + content="Perform a web search for the query: DeepSeek V4 model release 2026", + ) + ], + tools=[Tool(name="web_search", type="web_search_20250305")], + tool_choice={"type": "tool", "name": "web_search"}, + ) + with pytest.raises(InvalidRequestError, match="ENABLE_WEB_SERVER_TOOLS"): + service.create_message(request) + + +@pytest.mark.parametrize( + "url", + [ + "http://127.0.0.1/", + "http://192.168.1.1/", + "http://10.0.0.1/", + "http://[::1]/", + "http://localhost/foo", + "http://mybox.local/", + "file:///etc/passwd", + "http://169.254.169.254/latest/meta-data/", + ], +) +def test_enforce_web_fetch_egress_blocks_internal_or_disallowed(url: str): + with pytest.raises(WebFetchEgressViolation): + enforce_web_fetch_egress(url, _STRICT_EGRESS) + + +def test_enforce_web_fetch_egress_allows_global_literal_ip(): + enforce_web_fetch_egress("http://8.8.8.8/", _STRICT_EGRESS) + + +def test_enforce_web_fetch_egress_skips_private_checks_when_opted_in(): + enforce_web_fetch_egress( + "http://127.0.0.1/", + WebFetchEgressPolicy( + allow_private_network_targets=True, + allowed_schemes=frozenset({"http", "https"}), + ), + ) + + +def _cm(mock_client: MagicMock) -> MagicMock: + cm = MagicMock() + cm.__aenter__ = AsyncMock(return_value=mock_client) + cm.__aexit__ = AsyncMock(return_value=None) + return cm + + +def _stream_cm(response: httpx.Response) -> MagicMock: + cm = MagicMock() + cm.__aenter__ = AsyncMock(return_value=response) + cm.__aexit__ = AsyncMock(return_value=None) + return cm + + +def _aiohttp_response( + status: int, + *, + url: str = "http://8.8.8.8/", + location: str | None = None, + body: bytes = b"hello world", +) -> MagicMock: + r = MagicMock() + r.status = status + r.url = url + hdrs: dict[str, str] = {} + if location is not None: + hdrs["location"] = location + r.headers = hdrs + r.get_encoding = MagicMock(return_value="utf-8") + r.raise_for_status = MagicMock() + r.request_info = MagicMock() + r.history = () + + async def iter_chunked(_n: int) -> Any: + yield body + + r.content.iter_chunked = MagicMock(side_effect=iter_chunked) + return r + + +def _aiohttp_client_session_patch( + *responses: MagicMock, +) -> tuple[MagicMock, MagicMock]: + """Build ``ClientSession`` mock that serves ``responses`` to successive ``get`` calls.""" + queue = list(responses) + n = 0 + + def get_side(*_a: Any, **_k: Any) -> Any: + nonlocal n + resp = queue[n] if n < len(queue) else queue[-1] + n += 1 + cm = MagicMock() + cm.__aenter__ = AsyncMock(return_value=resp) + cm.__aexit__ = AsyncMock(return_value=None) + return cm + + session = MagicMock() + session.get = MagicMock(side_effect=get_side) + + client_cm = MagicMock() + client_cm.__aenter__ = AsyncMock(return_value=session) + client_cm.__aexit__ = AsyncMock(return_value=None) + return client_cm, session + + +def test_enforce_web_fetch_egress_documents_connect_time_pinning(): + assert enforce_web_fetch_egress.__doc__ and "resolved addresses" in ( + enforce_web_fetch_egress.__doc__ or "" + ) + assert ( + web_egress.get_validated_stream_addrinfos_for_egress.__doc__ + and "pinning" + in (web_egress.get_validated_stream_addrinfos_for_egress.__doc__ or "") + ) + assert "DNS-pinned" in (_run_web_fetch.__doc__ or "") + + +@pytest.mark.asyncio +async def test_run_web_fetch_follows_redirect_when_each_hop_is_allowed(): + res_redirect = _aiohttp_response( + 302, url="http://8.8.8.8/start", location="/final", body=b"" + ) + res_ok = _aiohttp_response(200, url="http://8.8.8.8/final", body=b"hello world") + client_cm, session = _aiohttp_client_session_patch(res_redirect, res_ok) + with patch("api.web_tools.outbound.ClientSession", return_value=client_cm): + out = await _run_web_fetch("http://8.8.8.8/start", _STRICT_EGRESS) + + assert out["data"] == "hello world" + assert session.get.call_count == 2 + + +@pytest.mark.asyncio +async def test_run_web_fetch_truncates_large_body_to_byte_cap(monkeypatch): + huge = b"x" * 5000 + res_ok = _aiohttp_response(200, url="http://8.8.8.8/big", body=huge) + client_cm, _ = _aiohttp_client_session_patch(res_ok) + monkeypatch.setattr(web_tool_constants, "_MAX_WEB_FETCH_RESPONSE_BYTES", 100) + with patch("api.web_tools.outbound.ClientSession", return_value=client_cm): + out = await _run_web_fetch("http://8.8.8.8/big", _STRICT_EGRESS) + + assert len(out["data"]) <= 100 + assert out["data"] == "x" * 100 + + +@pytest.mark.asyncio +async def test_run_web_fetch_redirect_to_blocked_host_raises(): + res_redirect = _aiohttp_response( + 302, + url="http://8.8.8.8/start", + location="http://127.0.0.1/secret", + body=b"", + ) + client_cm, session = _aiohttp_client_session_patch(res_redirect) + with ( + patch("api.web_tools.outbound.ClientSession", return_value=client_cm), + pytest.raises(WebFetchEgressViolation), + ): + await _run_web_fetch("http://8.8.8.8/start", _STRICT_EGRESS) + + session.get.assert_called_once() + + +@pytest.mark.asyncio +async def test_run_web_fetch_redirect_without_location_raises(): + res_bad = _aiohttp_response(302, url="http://8.8.8.8/here", body=b"") + client_cm, _ = _aiohttp_client_session_patch(res_bad) + with ( + patch("api.web_tools.outbound.ClientSession", return_value=client_cm), + pytest.raises(WebFetchEgressViolation, match="missing Location"), + ): + await _run_web_fetch("http://8.8.8.8/here", _STRICT_EGRESS) + + +@pytest.mark.asyncio +async def test_run_web_fetch_excess_redirects_raises(): + res1 = _aiohttp_response(302, url="http://8.8.8.8/a", location="/b", body=b"") + res2 = _aiohttp_response(302, url="http://8.8.8.8/b", location="/c", body=b"") + client_cm, _ = _aiohttp_client_session_patch(res1, res2) + with ( + patch("api.web_tools.constants._MAX_WEB_FETCH_REDIRECTS", 1), + patch("api.web_tools.outbound.ClientSession", return_value=client_cm), + pytest.raises(WebFetchEgressViolation, match="exceeded maximum redirects"), + ): + await _run_web_fetch("http://8.8.8.8/a", _STRICT_EGRESS) + + @pytest.mark.asyncio async def test_streams_web_search_server_tool_result(monkeypatch): async def fake_search(query: str) -> list[dict[str, str]]: assert query == "DeepSeek V4 model release 2026" return [{"title": "DeepSeek V4 Released", "url": "https://example.com/v4"}] - monkeypatch.setattr("api.web_server_tools._run_web_search", fake_search) + monkeypatch.setattr("api.web_tools.outbound._run_web_search", fake_search) request = MessagesRequest( model="claude-haiku-4-5-20251001", max_tokens=100, @@ -47,22 +319,92 @@ async def test_streams_web_search_server_tool_result(monkeypatch): tool_choice={"type": "tool", "name": "web_search"}, ) - events = [ - event - async for event in stream_web_server_tool_response(request, input_tokens=42) + raw = "".join( + [ + event + async for event in stream_web_server_tool_response( + request, input_tokens=42, web_fetch_egress=_STRICT_EGRESS + ) + ] + ) + events = parse_sse_text(raw) + assert_anthropic_stream_contract(events) + starts = [e for e in events if e.event == "content_block_start"] + assert starts[0].data["content_block"]["type"] == "server_tool_use" + assert starts[0].data["content_block"]["name"] == "web_search" + tool_use_id = starts[0].data["content_block"]["id"] + assert starts[1].data["content_block"]["type"] == "web_search_tool_result" + assert starts[1].data["content_block"]["tool_use_id"] == tool_use_id + assert starts[1].data["content_block"]["content"][0]["url"] == ( + "https://example.com/v4" + ) + text_deltas = [ + e + for e in events + if e.event == "content_block_delta" + and e.data.get("delta", {}).get("type") == "text_delta" ] - payloads = [_event_data(event) for event in events] + assert text_deltas, "summary must be streamed as text_delta" + assert "example.com" in text_content(events) + cli_text: list[str] = [] + for ev in events: + cli_text.extend( + str(p.get("text", "")) + for p in parse_cli_event(ev.data) + if p.get("type") == "text_delta" + ) + assert "example.com" in "".join(cli_text) + deltas = [e for e in events if e.event == "message_delta"] + assert deltas[-1].data["usage"]["server_tool_use"] == {"web_search_requests": 1} - assert payloads[1]["content_block"]["type"] == "server_tool_use" - assert payloads[1]["content_block"]["name"] == "web_search" - assert payloads[3]["content_block"]["type"] == "web_search_tool_result" - assert payloads[3]["content_block"]["content"][0]["url"] == "https://example.com/v4" - assert payloads[-2]["usage"]["server_tool_use"] == {"web_search_requests": 1} + +@pytest.mark.asyncio +async def test_forced_web_fetch_ignores_stale_url_from_prior_user_turns(monkeypatch): + """Only the latest user message supplies the URL (not earlier transcript text).""" + target = "https://new-only.example.com/page" + + async def fake_fetch(url: str, _egress: WebFetchEgressPolicy) -> dict[str, str]: + assert url == target + return { + "url": url, + "title": "T", + "media_type": "text/plain", + "data": "x", + } + + monkeypatch.setattr("api.web_tools.outbound._run_web_fetch", fake_fetch) + request = MessagesRequest( + model="claude-haiku-4-5-20251001", + max_tokens=100, + messages=[ + Message( + role="user", + content="Earlier turn https://stale.com/old-article ignore this", + ), + Message(role="assistant", content="ok"), + Message( + role="user", + content=f"Please fetch {target} for the summary", + ), + ], + tools=[Tool(name="web_fetch", type="web_fetch_20250910")], + tool_choice={"type": "tool", "name": "web_fetch"}, + ) + + raw = "".join( + [ + event + async for event in stream_web_server_tool_response( + request, input_tokens=1, web_fetch_egress=_STRICT_EGRESS + ) + ] + ) + assert target in raw @pytest.mark.asyncio async def test_streams_web_fetch_server_tool_result(monkeypatch): - async def fake_fetch(url: str) -> dict[str, str]: + async def fake_fetch(url: str, _egress: WebFetchEgressPolicy) -> dict[str, str]: assert url == "https://example.com/article" return { "url": url, @@ -71,7 +413,7 @@ async def test_streams_web_fetch_server_tool_result(monkeypatch): "data": "Article body", } - monkeypatch.setattr("api.web_server_tools._run_web_fetch", fake_fetch) + monkeypatch.setattr("api.web_tools.outbound._run_web_fetch", fake_fetch) request = MessagesRequest( model="claude-haiku-4-5-20251001", max_tokens=100, @@ -82,15 +424,198 @@ async def test_streams_web_fetch_server_tool_result(monkeypatch): tool_choice={"type": "tool", "name": "web_fetch"}, ) - events = [ - event - async for event in stream_web_server_tool_response(request, input_tokens=42) - ] - payloads = [_event_data(event) for event in events] - - assert payloads[1]["content_block"]["type"] == "server_tool_use" - assert payloads[3]["content_block"]["type"] == "web_fetch_tool_result" - assert payloads[3]["content_block"]["content"]["content"]["title"] == ( + raw = "".join( + [ + event + async for event in stream_web_server_tool_response( + request, input_tokens=42, web_fetch_egress=_STRICT_EGRESS + ) + ] + ) + events = parse_sse_text(raw) + assert_anthropic_stream_contract(events) + starts = [e for e in events if e.event == "content_block_start"] + assert starts[0].data["content_block"]["type"] == "server_tool_use" + tool_use_id = starts[0].data["content_block"]["id"] + assert starts[1].data["content_block"]["type"] == "web_fetch_tool_result" + assert starts[1].data["content_block"]["tool_use_id"] == tool_use_id + assert starts[1].data["content_block"]["content"]["content"]["title"] == ( "Example Article" ) - assert payloads[-2]["usage"]["server_tool_use"] == {"web_fetch_requests": 1} + assert any( + e.event == "content_block_delta" + and e.data.get("delta", {}).get("type") == "text_delta" + for e in events + ) + assert "Article body" in text_content(events) + cli_text: list[str] = [] + for ev in events: + cli_text.extend( + str(p.get("text", "")) + for p in parse_cli_event(ev.data) + if p.get("type") == "text_delta" + ) + assert "Article body" in "".join(cli_text) + deltas = [e for e in events if e.event == "message_delta"] + assert deltas[-1].data["usage"]["server_tool_use"] == {"web_fetch_requests": 1} + + +@pytest.mark.asyncio +async def test_streams_web_fetch_error_summary_generic_by_default(monkeypatch): + secret = "sensitive-upstream-token" + + async def boom(_url: str, _egress: WebFetchEgressPolicy) -> dict[str, str]: + raise ValueError(secret) + + monkeypatch.setattr("api.web_tools.outbound._run_web_fetch", boom) + request = MessagesRequest( + model="claude-haiku-4-5-20251001", + max_tokens=100, + messages=[ + Message( + role="user", + content="Fetch https://example.com/sensitive-path?x=1 please", + ) + ], + tools=[Tool(name="web_fetch", type="web_fetch_20250910")], + tool_choice={"type": "tool", "name": "web_fetch"}, + ) + + with patch("api.web_tools.outbound.logger.warning") as log_warn: + raw = "".join( + [ + event + async for event in stream_web_server_tool_response( + request, + input_tokens=1, + web_fetch_egress=_STRICT_EGRESS, + verbose_client_errors=False, + ) + ] + ) + + assert secret not in raw + assert "ValueError" not in raw + assert "Web tool request failed." in raw + err_events = parse_sse_text(raw) + assert_anthropic_stream_contract(err_events) + assert any( + e.event == "content_block_delta" + and e.data.get("delta", {}).get("type") == "text_delta" + for e in err_events + ) + cli_err_text: list[str] = [] + for ev in err_events: + cli_err_text.extend( + str(p.get("text", "")) + for p in parse_cli_event(ev.data) + if p.get("type") == "text_delta" + ) + assert "Web tool request failed." in "".join(cli_err_text) + log_blob = " ".join(str(a) for c in log_warn.call_args_list for a in c.args) + assert secret not in log_blob + assert "example.com" in log_blob + assert "/sensitive-path" not in log_blob + + +@pytest.mark.asyncio +async def test_streams_web_fetch_error_summary_verbose_includes_exception_class( + monkeypatch, +): + async def boom(_url: str, _egress: WebFetchEgressPolicy) -> dict[str, str]: + raise OSError(5, "oops") + + monkeypatch.setattr("api.web_tools.outbound._run_web_fetch", boom) + request = MessagesRequest( + model="claude-haiku-4-5-20251001", + max_tokens=100, + messages=[Message(role="user", content="Fetch https://example.com/x")], + tools=[Tool(name="web_fetch", type="web_fetch_20250910")], + tool_choice={"type": "tool", "name": "web_fetch"}, + ) + + raw = "".join( + [ + event + async for event in stream_web_server_tool_response( + request, + input_tokens=1, + web_fetch_egress=_STRICT_EGRESS, + verbose_client_errors=True, + ) + ] + ) + assert "OSError" in raw + + +@pytest.mark.asyncio +async def test_read_response_body_capped_truncates_single_oversized_chunk(): + cap = 500 + + async def aiter_bytes(chunk_size=None): + yield b"z" * (cap * 20) + + response = MagicMock() + response.aiter_bytes = aiter_bytes + + out = await _read_response_body_capped(response, cap) + assert len(out) == cap + assert out == b"z" * cap + + +@pytest.mark.asyncio +async def test_drain_response_body_capped_stops_after_first_chunk_when_oversized(): + cap = 300 + chunk_calls = {"n": 0} + + async def aiter_bytes(chunk_size=None): + chunk_calls["n"] += 1 + yield b"y" * (cap * 10) + + response = MagicMock() + response.aiter_bytes = aiter_bytes + + await _drain_response_body_capped(response, cap) + assert chunk_calls["n"] == 1 + + +def test_service_rejects_listed_server_tools_on_openai_chat() -> None: + settings = Settings() + service = ClaudeProxyService( + settings, + provider_getter=lambda _: MagicMock(), + model_router=FixedProviderModelRouter(settings, "deepseek"), + ) + request = MessagesRequest( + model="m", + max_tokens=20, + messages=[Message(role="user", content="q")], + tools=[Tool(name="web_search", type="web_search_20250305")], + ) + with pytest.raises(InvalidRequestError, match="OpenAI Chat upstreams"): + service.create_message(request) + + +def test_listed_server_tools_routed_on_open_router() -> None: + """Native Anthropic transport may receive listed server tool definitions.""" + settings = Settings() + + async def fake_stream(*_a, **_k): + yield 'event: message_start\ndata: {"type":"message_start"}\n\n' + yield 'event: message_stop\ndata: {"type":"message_stop"}\n\n' + + mock_provider = MagicMock() + mock_provider.stream_response = fake_stream + service = ClaudeProxyService( + settings, + provider_getter=lambda _: mock_provider, + model_router=FixedProviderModelRouter(settings, "open_router"), + ) + request = MessagesRequest( + model="m", + max_tokens=20, + messages=[Message(role="user", content="q")], + tools=[Tool(name="web_search", type="web_search_20250305")], + ) + service.create_message(request) + mock_provider.preflight_stream.assert_called() diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index a22f9dd..590f51c 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -3,6 +3,7 @@ import asyncio import json import os +from typing import cast from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -300,7 +301,7 @@ class TestCLISession: mock_process = AsyncMock() mock_process.stdout.read.side_effect = [b""] # No stdout - mock_process.stderr.read.return_value = b"Fatal error" + mock_process.stderr.read.side_effect = [b"Fatal error", b""] mock_process.wait.return_value = 1 with patch( @@ -319,6 +320,62 @@ class TestCLISession: assert events[1]["code"] == 1 assert events[1]["stderr"] == "Fatal error" + @pytest.mark.asyncio + async def test_start_task_stderr_while_stdout_streams(self): + """Stderr is drained concurrently so stdout streaming is not blocked.""" + from cli.session import CLISession + + session = CLISession("/tmp", "http://localhost:8082/v1") + + mock_process = AsyncMock() + mock_process.stdout.read.side_effect = [ + b'{"type": "message", "content": "Hi"}\n', + b"", + ] + mock_process.stderr.read.side_effect = [b"warning on stderr\n", b""] + mock_process.wait.return_value = 0 + + with patch( + "asyncio.create_subprocess_exec", new_callable=AsyncMock + ) as mock_exec: + mock_exec.return_value = mock_process + + events = [e async for e in session.start_task("Hello")] + + assert mock_process.stderr.read.await_count >= 2 + err_events = [e for e in events if e.get("type") == "error"] + assert len(err_events) == 1 + assert "warning on stderr" in err_events[0]["error"]["message"] + assert events[-1]["type"] == "exit" + assert events[-1]["code"] == 0 + + @pytest.mark.asyncio + async def test_drain_stderr_bounded_retains_cap_but_drains_to_eof(self): + """Oversized stderr is fully drained so the pipe cannot deadlock; capture is bounded.""" + from cli.session import _MAX_STDERR_CAPTURE_BYTES, CLISession + + total_len = _MAX_STDERR_CAPTURE_BYTES + 100_000 + remaining: dict[str, int] = {"n": total_len} + + class _FakeStderr: + async def read(self, n: int = 65536) -> bytes: + left = remaining["n"] + if left <= 0: + return b"" + take = min(n, left) + remaining["n"] = left - take + return b"y" * take + + class _FakeProcess: + stderr = _FakeStderr() + + out = await CLISession._drain_stderr_bounded( + cast(asyncio.subprocess.Process, _FakeProcess()) + ) + assert len(out) == _MAX_STDERR_CAPTURE_BYTES + assert out == b"y" * _MAX_STDERR_CAPTURE_BYTES + assert remaining["n"] == 0 + @pytest.mark.asyncio async def test_stop_session(self): """Test stopping the session process.""" diff --git a/tests/config/test_config.py b/tests/config/test_config.py index 583dc75..e62f583 100644 --- a/tests/config/test_config.py +++ b/tests/config/test_config.py @@ -3,6 +3,10 @@ import pytest from pydantic import ValidationError +from config.constants import ( + ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS, + HTTP_CONNECT_TIMEOUT_DEFAULT, +) from config.nim import NimSettings @@ -22,6 +26,7 @@ class TestSettings: monkeypatch.delenv("MODEL", raising=False) monkeypatch.delenv("HTTP_READ_TIMEOUT", raising=False) + monkeypatch.delenv("HTTP_CONNECT_TIMEOUT", raising=False) monkeypatch.setitem(Settings.model_config, "env_file", ()) settings = Settings() assert settings.model == "nvidia_nim/z-ai/glm4.7" @@ -31,6 +36,12 @@ class TestSettings: assert isinstance(settings.fast_prefix_detection, bool) assert isinstance(settings.enable_model_thinking, bool) assert settings.http_read_timeout == 120.0 + assert settings.http_connect_timeout == HTTP_CONNECT_TIMEOUT_DEFAULT + assert settings.enable_web_server_tools is False + assert settings.log_raw_api_payloads is False + assert settings.log_raw_sse_events is False + assert settings.debug_platform_edits is False + assert settings.debug_subagent_stack is False def test_get_settings_cached(self): """Test get_settings returns cached instance.""" @@ -57,10 +68,10 @@ class TestSettings: assert len(settings.model) > 0 def test_base_url_constant(self): - """Test NVIDIA_NIM_BASE_URL is a constant.""" - from providers.nvidia_nim import NVIDIA_NIM_BASE_URL + """Test NVIDIA_NIM_DEFAULT_BASE is a constant.""" + from providers.nvidia_nim import NVIDIA_NIM_DEFAULT_BASE - assert NVIDIA_NIM_BASE_URL == "https://integrate.api.nvidia.com/v1" + assert NVIDIA_NIM_DEFAULT_BASE == "https://integrate.api.nvidia.com/v1" def test_lm_studio_base_url_from_env(self, monkeypatch): """LM_STUDIO_BASE_URL env var is loaded into settings.""" @@ -127,6 +138,18 @@ class TestSettings: settings = Settings() assert settings.http_connect_timeout == 5.0 + def test_http_connect_timeout_default_matches_shared_constant( + self, monkeypatch + ) -> None: + """Default must match config.constants (and README / .env.example).""" + from config.settings import Settings + + monkeypatch.delenv("HTTP_CONNECT_TIMEOUT", raising=False) + monkeypatch.setitem(Settings.model_config, "env_file", ()) + settings = Settings() + assert settings.http_connect_timeout == HTTP_CONNECT_TIMEOUT_DEFAULT + assert HTTP_CONNECT_TIMEOUT_DEFAULT == 10.0 + def test_enable_model_thinking_from_env(self, monkeypatch): """ENABLE_MODEL_THINKING env var is loaded into settings.""" from config.settings import Settings @@ -319,6 +342,9 @@ class TestNimSettingsInvalidBounds: class TestNimSettingsValidators: """Test custom field validators in NimSettings.""" + def test_default_max_tokens_matches_shared_constant(self): + assert NimSettings().max_tokens == ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS + @pytest.mark.parametrize( "seed_val,expected", [("", None), (None, None), ("42", 42), (42, 42)], diff --git a/tests/config/test_logging_config.py b/tests/config/test_logging_config.py index 11a818c..9f54df0 100644 --- a/tests/config/test_logging_config.py +++ b/tests/config/test_logging_config.py @@ -4,6 +4,8 @@ import json import logging from pathlib import Path +from loguru import logger + from config.logging_config import configure_logging @@ -57,3 +59,38 @@ def test_configure_logging_skips_when_already_configured(tmp_path): assert "Still goes to first file" in (tmp_path / "test.log").read_text( encoding="utf-8" ) + + +def test_telegram_bot_token_redacted_in_message_field(tmp_path) -> None: + log_file = str(tmp_path / "redact.log") + configure_logging(log_file, force=True, verbose_third_party=False) + token = "123456:ABCDEF-ghij-klm" + logger.info("Calling {}", f"https://api.telegram.org/bot{token}/getMe") + logger.complete() + text = Path(log_file).read_text(encoding="utf-8") + assert token not in text + assert "bot/" in text or "redacted" in text + + +def test_bearer_substring_redacted_in_log_file(tmp_path) -> None: + log_file = str(tmp_path / "bearer.log") + configure_logging(log_file, force=True, verbose_third_party=False) + secret = "ya29.secret-token-abc" + logger.info("Request headers: Authorization: Bearer {}", secret) + logger.complete() + text = Path(log_file).read_text(encoding="utf-8") + assert secret not in text + assert "Bearer" in text + + +def test_httpx_logger_quieted_when_not_verbose_third_party(tmp_path) -> None: + log_file = str(tmp_path / "quiet.log") + configure_logging(log_file, force=True, verbose_third_party=False) + assert logging.getLogger("httpx").level >= logging.WARNING + assert logging.getLogger("httpcore").level >= logging.WARNING + + +def test_httpx_resets_to_notset_when_verbose_third_party(tmp_path) -> None: + log_file = str(tmp_path / "verbose.log") + configure_logging(log_file, force=True, verbose_third_party=True) + assert logging.getLogger("httpx").level == logging.NOTSET diff --git a/tests/contracts/test_architecture_contracts.py b/tests/contracts/test_architecture_contracts.py index 344122e..25be850 100644 --- a/tests/contracts/test_architecture_contracts.py +++ b/tests/contracts/test_architecture_contracts.py @@ -13,6 +13,25 @@ def test_architecture_plan_exists() -> None: text = plan.read_text(encoding="utf-8") assert "Intended Dependency Direction" in text assert "Smoke Coverage Policy" in text + assert "providers.nvidia_nim.voice" in text + assert "no dedicated smoke SSE shim" in text + + +def test_smoke_lib_has_no_sse_shim_module() -> None: + repo_root = Path(__file__).resolve().parents[2] + assert not (repo_root / "smoke" / "lib" / "sse.py").exists() + + +def test_api_package_exports_match_plan() -> None: + import api + + assert set(api.__all__) == { + "MessagesRequest", + "MessagesResponse", + "TokenCountRequest", + "TokenCountResponse", + "create_app", + } def test_root_env_example_is_the_single_template_source() -> None: diff --git a/tests/contracts/test_import_boundaries.py b/tests/contracts/test_import_boundaries.py index 918d7d1..3bde11d 100644 --- a/tests/contracts/test_import_boundaries.py +++ b/tests/contracts/test_import_boundaries.py @@ -54,6 +54,15 @@ def test_core_does_not_import_product_packages() -> None: assert offenders == [] +def test_provider_catalog_is_single_source_for_supported_ids() -> None: + from config.provider_catalog import PROVIDER_CATALOG, SUPPORTED_PROVIDER_IDS + from providers.registry import PROVIDER_DESCRIPTORS, PROVIDER_FACTORIES + + assert tuple(PROVIDER_CATALOG.keys()) == SUPPORTED_PROVIDER_IDS + assert PROVIDER_DESCRIPTORS is PROVIDER_CATALOG + assert set(SUPPORTED_PROVIDER_IDS) == set(PROVIDER_FACTORIES) + + def test_config_does_not_import_non_config_packages() -> None: """Settings and env handling must not depend on transport or protocol layers.""" repo_root = Path(__file__).resolve().parents[2] @@ -71,14 +80,34 @@ def test_config_does_not_import_non_config_packages() -> None: assert offenders == [] -def test_messaging_does_not_import_api_or_cli_or_providers() -> None: - """Messaging is wired by ``api.runtime``; must not import server or provider adapters.""" +_MESSAGING_ALLOWED_PROVIDER_MODULES = frozenset({"providers.nvidia_nim.voice"}) + + +def test_messaging_does_not_import_disallowed_modules() -> None: + """Messaging is wired by ``api.runtime``; narrow provider imports only for NIM voice ASR.""" repo_root = Path(__file__).resolve().parents[2] - offenders = _imports_matching( - [repo_root / "messaging"], - forbidden_prefixes=("api.", "cli.", "providers.", "smoke."), - ) - assert offenders == [] + offenders: list[str] = [] + for path in (repo_root / "messaging").rglob("*.py"): + for imported in _imports_from(path, repo_root): + if imported is None: + continue + if ( + imported == "api" + or imported.startswith("api.") + or imported == "cli" + or imported.startswith("cli.") + or imported == "smoke" + or imported.startswith("smoke.") + ): + rel = path.relative_to(repo_root) + offenders.append(f"{rel}: {imported}") + elif imported.startswith("providers."): + if imported in _MESSAGING_ALLOWED_PROVIDER_MODULES: + continue + rel = path.relative_to(repo_root) + offenders.append(f"{rel}: {imported}") + + assert sorted(offenders) == [] def test_api_may_only_import_narrow_provider_facade() -> None: diff --git a/tests/contracts/test_smoke_sse_reexport.py b/tests/contracts/test_smoke_sse_reexport.py deleted file mode 100644 index 4190312..0000000 --- a/tests/contracts/test_smoke_sse_reexport.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Ensure smoke re-exports stay aligned with :mod:`core.anthropic.stream_contracts`.""" - -from __future__ import annotations - -import core.anthropic.stream_contracts as core_sc -import smoke.lib.sse as smoke_sse - - -def test_smoke_lib_sse_reexports_core_stream_contracts() -> None: - for name in smoke_sse.__all__: - assert getattr(smoke_sse, name) is getattr(core_sc, name) diff --git a/tests/contracts/test_stream_contracts.py b/tests/contracts/test_stream_contracts.py index 0125459..6898cc3 100644 --- a/tests/contracts/test_stream_contracts.py +++ b/tests/contracts/test_stream_contracts.py @@ -8,16 +8,14 @@ from __future__ import annotations from collections.abc import Iterable from core.anthropic import ContentType, HeuristicToolParser, SSEBuilder, ThinkTagParser +from core.anthropic.sse import format_sse_event from core.anthropic.stream_contracts import ( assert_anthropic_stream_contract, event_names, - has_tool_use, parse_sse_text, text_content, thinking_content, ) -from messaging.event_parser import parse_cli_event -from messaging.transcript import RenderCtx, TranscriptBuffer def test_interleaved_thinking_text_blocks_are_valid() -> None: @@ -59,44 +57,49 @@ def test_mixed_reasoning_content_and_think_tags_keep_order() -> None: assert text_content(events) == " visible done" -def test_thinking_tool_text_and_transcript_order_contract() -> None: - builder = SSEBuilder("msg_contract", "contract-model") - chunks = [builder.message_start()] - chunks.extend(builder.ensure_thinking_block()) - chunks.append(builder.emit_thinking_delta("inspect first")) - chunks.extend(builder.close_content_blocks()) - tool_block_index = builder.blocks.allocate_index() - chunks.append( - builder.content_block_start( - tool_block_index, "tool_use", id="toolu_1", name="Read" - ) - ) - chunks.append( - builder.content_block_delta( - tool_block_index, "input_json_delta", '{"file":"README.md"}' - ) - ) - chunks.append(builder.content_block_stop(tool_block_index)) - chunks.extend(builder.ensure_text_block()) - chunks.append(builder.emit_text_delta("done")) - chunks.extend(builder.close_all_blocks()) - chunks.append(builder.message_delta("end_turn", 20)) - chunks.append(builder.message_stop()) - +def test_redacted_thinking_block_start_stop_is_valid() -> None: + """Native redacted_thinking uses start/stop only (no deltas).""" + chunks = [ + format_sse_event( + "message_start", + { + "type": "message_start", + "message": { + "id": "msg_r", + "type": "message", + "role": "assistant", + "content": [], + "model": "m", + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": 1, "output_tokens": 1}, + }, + }, + ), + format_sse_event( + "content_block_start", + { + "type": "content_block_start", + "index": 0, + "content_block": {"type": "redacted_thinking", "data": "opaque"}, + }, + ), + format_sse_event( + "content_block_stop", + {"type": "content_block_stop", "index": 0}, + ), + format_sse_event( + "message_delta", + { + "type": "message_delta", + "delta": {"stop_reason": "end_turn", "stop_sequence": None}, + "usage": {"input_tokens": 1, "output_tokens": 2}, + }, + ), + format_sse_event("message_stop", {"type": "message_stop"}), + ] events = parse_sse_text("".join(chunks)) assert_anthropic_stream_contract(events) - assert has_tool_use(events) - - transcript = TranscriptBuffer() - for event in events: - for parsed in parse_cli_event(event.data): - transcript.apply(parsed) - rendered = transcript.render(_render_ctx(), limit_chars=3900, status=None) - assert ( - rendered.find("inspect first") - < rendered.find("Tool call:") - < rendered.find("done") - ) def test_enable_thinking_false_suppresses_reasoning_only() -> None: @@ -186,13 +189,3 @@ def _emit_parser_parts( def _parse_builder_events(chunks: Iterable[str]): return parse_sse_text("".join(chunks)) - - -def _render_ctx() -> RenderCtx: - return RenderCtx( - bold=lambda text: f"*{text}*", - code_inline=lambda text: f"`{text}`", - escape_code=lambda text: text, - escape_text=lambda text: text, - render_markdown=lambda text: text, - ) diff --git a/tests/core/anthropic/test_native_sse_block_policy.py b/tests/core/anthropic/test_native_sse_block_policy.py new file mode 100644 index 0000000..f1b6658 --- /dev/null +++ b/tests/core/anthropic/test_native_sse_block_policy.py @@ -0,0 +1,133 @@ +"""Unit tests for shared native Anthropic SSE thinking policy / block remapping.""" + +from __future__ import annotations + +import json + +from core.anthropic.native_sse_block_policy import ( + NativeSseBlockPolicyState, + format_native_sse_event, + transform_native_sse_block_event, +) + + +def test_thinking_start_dropped_when_disabled() -> None: + st = NativeSseBlockPolicyState() + payload = { + "type": "content_block_start", + "index": 0, + "content_block": {"type": "thinking", "thinking": ""}, + } + ev = format_native_sse_event( + "content_block_start", + json.dumps(payload), + ) + assert transform_native_sse_block_event(ev, st, thinking_enabled=False) is None + + +def test_thinking_delta_dropped_when_disabled() -> None: + st = NativeSseBlockPolicyState() + # No prior start in stream (OpenRouter-style: returns None when thinking off) + payload = { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "thinking_delta", "thinking": "secret"}, + } + ev = format_native_sse_event("content_block_delta", json.dumps(payload)) + assert transform_native_sse_block_event(ev, st, thinking_enabled=False) is None + + +def test_text_block_passthrough_when_thinking_disabled() -> None: + st = NativeSseBlockPolicyState() + payload = { + "type": "content_block_start", + "index": 0, + "content_block": {"type": "text", "text": ""}, + } + ev = format_native_sse_event("content_block_start", json.dumps(payload)) + out = transform_native_sse_block_event(ev, st, thinking_enabled=False) + assert out is not None + assert '"index": 0' in (out or "") + + +def test_interleaved_thinking_signature_delta_remaps_to_reopened_block_index() -> None: + """After text interrupts thinking, signature_delta must follow the reopened segment index.""" + st = NativeSseBlockPolicyState() + + def run(ev: str) -> str | None: + return transform_native_sse_block_event(ev, st, thinking_enabled=True) + + out1 = run( + format_native_sse_event( + "content_block_start", + json.dumps( + { + "type": "content_block_start", + "index": 0, + "content_block": {"type": "thinking", "thinking": ""}, + } + ), + ) + ) + assert out1 is not None and '"index": 0' in out1 + + out2 = run( + format_native_sse_event( + "content_block_start", + json.dumps( + { + "type": "content_block_start", + "index": 1, + "content_block": {"type": "text", "text": ""}, + } + ), + ) + ) + assert out2 is not None + + out3 = run( + format_native_sse_event( + "content_block_delta", + json.dumps( + { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "thinking_delta", "thinking": "plan"}, + } + ), + ) + ) + assert out3 is not None + assert "content_block_start" in out3 + + out4 = run( + format_native_sse_event( + "content_block_delta", + json.dumps( + { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "signature_delta", "signature": "sig"}, + } + ), + ) + ) + assert out4 is not None + assert '"index": 2' in out4 + assert "signature_delta" in out4 + + +def test_startless_text_delta_synthesizes_start_when_thinking_disabled() -> None: + """Startless text deltas must not be dropped when thinking is disabled (OpenRouter).""" + st = NativeSseBlockPolicyState() + payload = { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": "Hello"}, + } + ev = format_native_sse_event("content_block_delta", json.dumps(payload)) + out = transform_native_sse_block_event(ev, st, thinking_enabled=False) + assert out is not None + assert "content_block_start" in (out or "") + assert "Hello" in (out or "") + assert "text_delta" in (out or "") diff --git a/tests/core/test_strict_sliding_window.py b/tests/core/test_strict_sliding_window.py new file mode 100644 index 0000000..1607521 --- /dev/null +++ b/tests/core/test_strict_sliding_window.py @@ -0,0 +1,38 @@ +"""Direct tests for :class:`core.rate_limit.StrictSlidingWindowLimiter`.""" + +import time + +import pytest + +from core.rate_limit import StrictSlidingWindowLimiter + + +@pytest.mark.asyncio +async def test_strict_window_allows_burst_then_blocks(): + lim = StrictSlidingWindowLimiter(rate_limit=2, rate_window=0.2) + await lim.acquire() + await lim.acquire() + start = time.monotonic() + await lim.acquire() + assert time.monotonic() - start >= 0.15 + + +@pytest.mark.asyncio +async def test_strict_window_async_context_manager(): + lim = StrictSlidingWindowLimiter(rate_limit=1, rate_window=0.15) + + async def run(): + async with lim: + pass + + await run() + start = time.monotonic() + await run() + assert time.monotonic() - start >= 0.1 + + +def test_strict_window_rejects_invalid_config(): + with pytest.raises(ValueError): + StrictSlidingWindowLimiter(rate_limit=0, rate_window=1.0) + with pytest.raises(ValueError): + StrictSlidingWindowLimiter(rate_limit=1, rate_window=0.0) diff --git a/tests/messaging/test_handler.py b/tests/messaging/test_handler.py index 7695dd0..013b096 100644 --- a/tests/messaging/test_handler.py +++ b/tests/messaging/test_handler.py @@ -13,6 +13,49 @@ def handler(mock_platform, mock_cli_manager, mock_session_store): return ClaudeMessageHandler(mock_platform, mock_cli_manager, mock_session_store) +@pytest.mark.asyncio +async def test_handle_message_default_logs_text_len_not_content( + mock_platform, mock_cli_manager, mock_session_store, incoming_message_factory +): + secret = "user-secret-content-never-log-default" + handler = ClaudeMessageHandler( + mock_platform, + mock_cli_manager, + mock_session_store, + log_raw_messaging_content=False, + ) + incoming = incoming_message_factory(text=secret) + with ( + patch.object(handler, "_handle_message_impl", new_callable=AsyncMock), + patch("messaging.handler.logger.info") as log_info, + ): + await handler.handle_message(incoming) + blob = " ".join(str(c) for c in log_info.call_args_list) + assert secret not in blob + assert "text_len=" in blob + + +@pytest.mark.asyncio +async def test_handle_message_raw_content_logging_includes_preview( + mock_platform, mock_cli_manager, mock_session_store, incoming_message_factory +): + secret = "visible-preview-xyz" + handler = ClaudeMessageHandler( + mock_platform, + mock_cli_manager, + mock_session_store, + log_raw_messaging_content=True, + ) + incoming = incoming_message_factory(text=secret) + with ( + patch.object(handler, "_handle_message_impl", new_callable=AsyncMock), + patch("messaging.handler.logger.info") as log_info, + ): + await handler.handle_message(incoming) + blob = " ".join(str(c) for c in log_info.call_args_list) + assert secret in blob + + def test_get_initial_status_new_conversation(handler): """New conversation always returns launching message.""" result = handler._get_initial_status(None, None) diff --git a/tests/messaging/test_handler_format.py b/tests/messaging/test_handler_format.py index 2cf60ee..cede9fb 100644 --- a/tests/messaging/test_handler_format.py +++ b/tests/messaging/test_handler_format.py @@ -17,7 +17,6 @@ def handler(): platform = MagicMock() cli = MagicMock() store = MagicMock() - # Kept for backwards test structure; transcript rendering is now separate. return (platform, cli, store) diff --git a/tests/messaging/test_handler_markdown_and_status_edges.py b/tests/messaging/test_handler_markdown_and_status_edges.py index 185e8b1..662c6ff 100644 --- a/tests/messaging/test_handler_markdown_and_status_edges.py +++ b/tests/messaging/test_handler_markdown_and_status_edges.py @@ -4,6 +4,7 @@ import pytest from messaging.handler import ClaudeMessageHandler from messaging.models import IncomingMessage +from messaging.node_event_pipeline import process_parsed_cli_event from messaging.rendering.telegram_markdown import render_markdown_to_mdv2 from messaging.trees.data import MessageNode, MessageState @@ -333,7 +334,7 @@ async def test_handle_message_incoming_text_none_safe(): @pytest.mark.asyncio async def test_process_parsed_event_malformed_content_continues(): - """Malformed/unknown parsed event does not crash _process_parsed_event.""" + """Malformed/unknown parsed event does not crash process_parsed_cli_event.""" platform = MagicMock() platform.queue_edit_message = AsyncMock() @@ -344,7 +345,7 @@ async def test_process_parsed_event_malformed_content_continues(): transcript = MagicMock() update_ui = AsyncMock() - last_status, had = await handler._process_parsed_event( + last_status, had = await process_parsed_cli_event( parsed={"type": "unknown_type"}, transcript=transcript, update_ui=update_ui, @@ -353,6 +354,9 @@ async def test_process_parsed_event_malformed_content_continues(): tree=None, node_id="n1", captured_session_id=None, + session_store=session_store, + format_status=handler.format_status, + propagate_error_to_children=AsyncMock(), ) assert last_status is None assert had is False diff --git a/tests/messaging/test_limiter.py b/tests/messaging/test_limiter.py index 565e80f..0882c03 100644 --- a/tests/messaging/test_limiter.py +++ b/tests/messaging/test_limiter.py @@ -1,16 +1,10 @@ import asyncio -import os +import contextlib import time import pytest import pytest_asyncio -# Set environment variables relative to test execution -os.environ["MESSAGING_RATE_LIMIT"] = "1" -os.environ["MESSAGING_RATE_WINDOW"] = "0.5" - -import contextlib - from messaging.limiter import MessagingRateLimiter @@ -19,11 +13,8 @@ class TestMessagingRateLimiter: @pytest_asyncio.fixture(autouse=True) async def reset_limiter(self): - """Reset singleton and environment before each test.""" - # Ensure the singleton worker is stopped between tests to avoid dangling tasks. + """Reset singleton before each test.""" await MessagingRateLimiter.shutdown_instance(timeout=0.1) - os.environ["MESSAGING_RATE_LIMIT"] = "1" - os.environ["MESSAGING_RATE_WINDOW"] = "0.5" yield @@ -32,9 +23,16 @@ class TestMessagingRateLimiter: @pytest.mark.asyncio async def test_singleton_pattern(self): """Test that get_instance returns the same object.""" - limiter1 = await MessagingRateLimiter.get_instance() - limiter2 = await MessagingRateLimiter.get_instance() + limiter1 = await MessagingRateLimiter.get_instance( + rate_limit=1, rate_window=0.5 + ) + limiter2 = await MessagingRateLimiter.get_instance( + rate_limit=99, rate_window=99.0 + ) assert limiter1 is limiter2 + # First-construction wins for rate parameters + assert limiter1.limiter._rate_limit == 1 + assert limiter1.limiter._rate_window == 0.5 @pytest.mark.asyncio async def test_compaction(self): @@ -42,13 +40,8 @@ class TestMessagingRateLimiter: Verify multiple rapid requests with same dedup_key are compacted. Logic ported from verify_limiter.py """ - # Set slow rate for testing compaction - os.environ["MESSAGING_RATE_LIMIT"] = "1" - os.environ["MESSAGING_RATE_WINDOW"] = "1.0" - - # Must reset instance to pick up new env vars - MessagingRateLimiter._instance = None - limiter = await MessagingRateLimiter.get_instance() + await MessagingRateLimiter.shutdown_instance(timeout=0.1) + limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=1.0) call_counts = {} @@ -78,10 +71,8 @@ class TestMessagingRateLimiter: Verify that even when compacted, all futures resolve to the result of the LAST execution. Logic ported from verify_limiter_v2.py """ - os.environ["MESSAGING_RATE_LIMIT"] = "1" - os.environ["MESSAGING_RATE_WINDOW"] = "0.5" - MessagingRateLimiter._instance = None - limiter = await MessagingRateLimiter.get_instance() + await MessagingRateLimiter.shutdown_instance(timeout=0.1) + limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=0.5) call_counts = {} msg_id = "test_msg_hang" @@ -116,8 +107,8 @@ class TestMessagingRateLimiter: @pytest.mark.asyncio async def test_flood_wait_handling(self): """Test that FloodWait exceptions pause the worker.""" - MessagingRateLimiter._instance = None - limiter = await MessagingRateLimiter.get_instance() + await MessagingRateLimiter.shutdown_instance(timeout=0.1) + limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=1.0) # Mock exception with .seconds attribute class FloodWait(Exception): @@ -157,8 +148,8 @@ class TestMessagingRateLimiter: @pytest.mark.asyncio async def test_flood_wait_retry_after_parsing(self): """Error message with 'retry after N' parses the wait seconds.""" - MessagingRateLimiter._instance = None - limiter = await MessagingRateLimiter.get_instance() + await MessagingRateLimiter.shutdown_instance(timeout=0.1) + limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=1.0) async def mock_flood(): raise Exception("Flood wait: retry after 2 seconds") @@ -172,8 +163,8 @@ class TestMessagingRateLimiter: @pytest.mark.asyncio async def test_non_flood_exception_no_pause(self): """Non-flood exception doesn't trigger pause.""" - MessagingRateLimiter._instance = None - limiter = await MessagingRateLimiter.get_instance() + await MessagingRateLimiter.shutdown_instance(timeout=0.1) + limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=1.0) async def mock_error(): raise ValueError("some regular error") @@ -187,8 +178,8 @@ class TestMessagingRateLimiter: @pytest.mark.asyncio async def test_flood_with_seconds_attribute(self): """Exception with .seconds attribute uses that value for pause.""" - MessagingRateLimiter._instance = None - limiter = await MessagingRateLimiter.get_instance() + await MessagingRateLimiter.shutdown_instance(timeout=0.1) + limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=1.0) class FloodWaitCustom(Exception): def __init__(self): @@ -209,10 +200,8 @@ class TestMessagingRateLimiter: Proactive limiter should enforce a strict sliding window: for any i, t[i+rate_limit] - t[i] >= rate_window (within tolerance). """ - os.environ["MESSAGING_RATE_LIMIT"] = "2" - os.environ["MESSAGING_RATE_WINDOW"] = "0.5" - MessagingRateLimiter._instance = None - limiter = await MessagingRateLimiter.get_instance() + await MessagingRateLimiter.shutdown_instance(timeout=0.1) + limiter = await MessagingRateLimiter.get_instance(rate_limit=2, rate_window=0.5) async def acquire(i: int) -> float: async def _do() -> float: @@ -235,8 +224,8 @@ class TestMessagingRateLimiter: @pytest.mark.asyncio async def test_compaction_last_task_fails_all_futures_get_exception(self): """When compacted task's last func fails, all futures get the exception.""" - MessagingRateLimiter._instance = None - limiter = await MessagingRateLimiter.get_instance() + await MessagingRateLimiter.shutdown_instance(timeout=0.1) + limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=1.0) async def ok_task(): return "ok" @@ -255,8 +244,8 @@ class TestMessagingRateLimiter: @pytest.mark.asyncio async def test_fire_and_forget_failure_logged(self, caplog): """fire_and_forget with failing task logs error and does not re-raise.""" - MessagingRateLimiter._instance = None - limiter = await MessagingRateLimiter.get_instance() + await MessagingRateLimiter.shutdown_instance(timeout=0.1) + limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=1.0) async def fail_task(): raise ValueError("fire_and_forget failed") @@ -264,4 +253,6 @@ class TestMessagingRateLimiter: limiter.fire_and_forget(fail_task, dedup_key="fire_fail") await asyncio.sleep(1.5) - assert any("fire_and_forget failed" in str(r) for r in caplog.records) + joined = " ".join(str(r.message) for r in caplog.records) + assert "ValueError" in joined + assert "fire_and_forget failed" not in joined diff --git a/tests/messaging/test_messaging_factory.py b/tests/messaging/test_messaging_factory.py index 5d3c0a3..a2ed0cd 100644 --- a/tests/messaging/test_messaging_factory.py +++ b/tests/messaging/test_messaging_factory.py @@ -41,6 +41,10 @@ class TestCreateMessagingPlatform: whisper_device="cuda", hf_token="", nvidia_nim_api_key="", + messaging_rate_limit=1, + messaging_rate_window=1.0, + log_raw_messaging_content=False, + log_api_error_tracebacks=False, ) def test_telegram_without_token(self): @@ -85,6 +89,10 @@ class TestCreateMessagingPlatform: whisper_device="nvidia_nim", hf_token="", nvidia_nim_api_key="", + messaging_rate_limit=1, + messaging_rate_window=1.0, + log_raw_messaging_content=False, + log_api_error_tracebacks=False, ) def test_discord_without_token(self): diff --git a/tests/messaging/test_session_store_edge_cases.py b/tests/messaging/test_session_store_edge_cases.py index 31683d4..8656c36 100644 --- a/tests/messaging/test_session_store_edge_cases.py +++ b/tests/messaging/test_session_store_edge_cases.py @@ -81,15 +81,40 @@ class TestSessionStoreSaveEdgeCases: """Tests for save failure handling.""" def test_save_io_error_handled(self, tmp_store): - """Write failure in _write_data() raises (callers handle the error).""" + """Write failure during atomic replace is surfaced to callers.""" tmp_store.save_tree("r1", {"root_id": "r1", "nodes": {"r1": {}}}) with ( - patch("builtins.open", side_effect=OSError("disk full")), + patch("messaging.session.os.replace", side_effect=OSError("disk full")), pytest.raises(OSError), ): tmp_store._write_data(tmp_store._snapshot()) +class TestSessionStoreAtomicWrites: + """Atomic persistence: failed replace must not truncate the prior file.""" + + def test_failed_replace_keeps_prior_bytes_and_marks_dirty(self, tmp_path): + path = str(tmp_path / "sessions.json") + store = SessionStore(storage_path=path) + store.save_tree("r1", {"root_id": "r1", "nodes": {"r1": {}}}) + store.flush_pending_save() + with open(path, encoding="utf-8") as f: + disk_after_first = f.read() + + store.save_tree("r2", {"root_id": "r2", "nodes": {"r2": {}}}) + + with patch( + "messaging.session.os.replace", side_effect=OSError("replace failed") + ): + store.flush_pending_save() + + with open(path, encoding="utf-8") as f: + disk_after_failed = f.read() + assert disk_after_failed == disk_after_first + assert store._dirty is True + assert store.get_tree("r2") is not None + + class TestSessionStoreClearAll: def test_clear_all_wipes_state_and_persists(self, tmp_path): path = str(tmp_path / "sessions.json") diff --git a/tests/messaging/test_stream_transcript_contract.py b/tests/messaging/test_stream_transcript_contract.py new file mode 100644 index 0000000..2ef2b60 --- /dev/null +++ b/tests/messaging/test_stream_transcript_contract.py @@ -0,0 +1,62 @@ +"""Messaging-specific assertions built on neutral Anthropic stream contracts.""" + +from __future__ import annotations + +from core.anthropic import SSEBuilder +from core.anthropic.stream_contracts import ( + assert_anthropic_stream_contract, + has_tool_use, + parse_sse_text, +) +from messaging.event_parser import parse_cli_event +from messaging.transcript import RenderCtx, TranscriptBuffer + + +def test_thinking_tool_text_and_transcript_order_contract() -> None: + builder = SSEBuilder("msg_contract", "contract-model") + chunks = [builder.message_start()] + chunks.extend(builder.ensure_thinking_block()) + chunks.append(builder.emit_thinking_delta("inspect first")) + chunks.extend(builder.close_content_blocks()) + tool_block_index = builder.blocks.allocate_index() + chunks.append( + builder.content_block_start( + tool_block_index, "tool_use", id="toolu_1", name="Read" + ) + ) + chunks.append( + builder.content_block_delta( + tool_block_index, "input_json_delta", '{"file":"README.md"}' + ) + ) + chunks.append(builder.content_block_stop(tool_block_index)) + chunks.extend(builder.ensure_text_block()) + chunks.append(builder.emit_text_delta("done")) + chunks.extend(builder.close_all_blocks()) + chunks.append(builder.message_delta("end_turn", 20)) + chunks.append(builder.message_stop()) + + events = parse_sse_text("".join(chunks)) + assert_anthropic_stream_contract(events) + assert has_tool_use(events) + + transcript = TranscriptBuffer() + for event in events: + for parsed in parse_cli_event(event.data): + transcript.apply(parsed) + rendered = transcript.render(_render_ctx(), limit_chars=3900, status=None) + assert ( + rendered.find("inspect first") + < rendered.find("Tool call:") + < rendered.find("done") + ) + + +def _render_ctx() -> RenderCtx: + return RenderCtx( + bold=lambda s: f"*{s}*", + code_inline=lambda s: f"`{s}`", + escape_code=lambda s: s, + escape_text=lambda s: s, + render_markdown=lambda s: s, + ) diff --git a/tests/messaging/test_transcription.py b/tests/messaging/test_transcription.py index 6aaaca4..023f723 100644 --- a/tests/messaging/test_transcription.py +++ b/tests/messaging/test_transcription.py @@ -25,7 +25,7 @@ def test_transcribe_file_too_large_raises(): path = Path(f.name) try: with pytest.raises(ValueError, match="too large"): - transcribe_audio(path, "audio/ogg", whisper_device="auto") + transcribe_audio(path, "audio/ogg", whisper_device="cpu") finally: path.unlink(missing_ok=True) @@ -87,15 +87,9 @@ def test_transcribe_invalid_device_raises(): f.write(b"fake ogg") path = Path(f.name) try: - # Mock settings to return invalid device "auto" - mock_settings = MagicMock() - mock_settings.whisper_device = "auto" - mock_settings.whisper_model = "base" - # Patch _load_audio to avoid ImportError from missing librosa # Device validation happens in _get_pipeline before torch import with ( - patch("messaging.transcription.get_settings", return_value=mock_settings), patch("messaging.transcription._load_audio"), pytest.raises(ValueError, match="whisper_device must be 'cpu' or 'cuda'"), ): @@ -104,6 +98,24 @@ def test_transcribe_invalid_device_raises(): path.unlink(missing_ok=True) +def test_transcribe_nim_requires_api_key(): + """NIM path rejects empty API key without reading global settings.""" + with tempfile.NamedTemporaryFile(suffix=".ogg", delete=False) as f: + f.write(b"fake ogg") + path = Path(f.name) + try: + with pytest.raises(ValueError, match="non-empty"): + transcribe_audio( + path, + "audio/ogg", + whisper_device="nvidia_nim", + whisper_model="openai/whisper-large-v3", + nvidia_nim_api_key="", + ) + finally: + path.unlink(missing_ok=True) + + def test_transcribe_local_import_error_raises(): """Local backend when voice_local extra not installed raises ImportError.""" with tempfile.NamedTemporaryFile(suffix=".ogg", delete=False) as f: @@ -120,6 +132,6 @@ def test_transcribe_local_import_error_raises(): ), pytest.raises(ImportError, match="voice_local extra"), ): - transcribe_audio(path, "audio/ogg", whisper_device="auto") + transcribe_audio(path, "audio/ogg", whisper_device="cpu") finally: path.unlink(missing_ok=True) diff --git a/tests/messaging/test_transcription_nim.py b/tests/messaging/test_transcription_nim.py new file mode 100644 index 0000000..b15f89b --- /dev/null +++ b/tests/messaging/test_transcription_nim.py @@ -0,0 +1,39 @@ +"""Tests for NVIDIA NIM voice transcription wiring.""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import patch + +from messaging.transcription import transcribe_audio + + +def test_transcribe_audio_nvidia_nim_forwards_api_key(tmp_path: Path) -> None: + wav = tmp_path / "stub.wav" + wav.write_bytes(b"\x00" * 128) + with patch("messaging.transcription.transcribe_nvidia_nim_audio") as nim_fn: + nim_fn.return_value = "ok" + out = transcribe_audio( + wav, + "audio/wav", + whisper_model="openai/whisper-large-v3", + whisper_device="nvidia_nim", + nvidia_nim_api_key="test-nim-key", + ) + nim_fn.assert_called_once_with( + wav, "openai/whisper-large-v3", api_key="test-nim-key" + ) + assert out == "ok" + + +def test_nim_asr_model_map_entries_are_real_function_ids() -> None: + from providers.nvidia_nim.voice import _NIM_ASR_MODEL_MAP + + for function_id, language_code in _NIM_ASR_MODEL_MAP.values(): + assert function_id + assert function_id.strip().lower() != "none" + # Hosted NIM function-id is a lowercase UUID string. + parts = function_id.split("-") + assert len(parts) == 5 + assert all(p for p in parts) + assert language_code is not None diff --git a/tests/messaging/test_tree_processor.py b/tests/messaging/test_tree_processor.py index e6a3a03..664748d 100644 --- a/tests/messaging/test_tree_processor.py +++ b/tests/messaging/test_tree_processor.py @@ -6,7 +6,7 @@ import pytest from messaging.models import IncomingMessage from messaging.trees.data import MessageNode, MessageState, MessageTree -from messaging.trees.processor import TreeQueueProcessor +from messaging.trees.queue_manager import TreeQueueProcessor @pytest.fixture diff --git a/tests/messaging/test_tree_repository.py b/tests/messaging/test_tree_repository.py index 2860ae6..e5bf4ef 100644 --- a/tests/messaging/test_tree_repository.py +++ b/tests/messaging/test_tree_repository.py @@ -4,7 +4,7 @@ import pytest from messaging.models import IncomingMessage from messaging.trees.data import MessageNode, MessageState, MessageTree -from messaging.trees.repository import TreeRepository +from messaging.trees.queue_manager import TreeRepository @pytest.fixture diff --git a/tests/provider_request_mocks.py b/tests/provider_request_mocks.py new file mode 100644 index 0000000..ef2a40d --- /dev/null +++ b/tests/provider_request_mocks.py @@ -0,0 +1,25 @@ +"""Shared MagicMock request objects for OpenAI-compatible provider tests.""" + +from unittest.mock import MagicMock + + +def make_openai_compat_stream_request( + *, model: str = "test-model", stream: bool = True +) -> MagicMock: + """Minimal request stub matching :meth:`OpenAIChatTransport._build_request_body` needs.""" + req = MagicMock() + req.model = model + req.stream = stream + req.messages = [] + req.system = None + req.tools = None + req.tool_choice = None + req.metadata = None + req.max_tokens = 4096 + req.temperature = None + req.top_p = None + req.top_k = None + req.stop_sequences = None + req.extra_body = None + req.thinking = None + return req diff --git a/tests/providers/test_anthropic_messages.py b/tests/providers/test_anthropic_messages.py index 2d9498d..1cb57cb 100644 --- a/tests/providers/test_anthropic_messages.py +++ b/tests/providers/test_anthropic_messages.py @@ -6,8 +6,12 @@ from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest +from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS +from core.anthropic.sse import format_sse_event +from core.anthropic.stream_contracts import event_index, parse_sse_text from providers.anthropic_messages import AnthropicMessagesTransport from providers.base import ProviderConfig +from tests.stream_contract import assert_canonical_stream_error_envelope class NativeProvider(AnthropicMessagesTransport): @@ -32,8 +36,6 @@ class MockRequest: "model": self.model, "messages": [{"role": "user", "content": "Hello"}], "extra_body": {"ignored": True}, - "original_model": "claude", - "resolved_provider_model": "native/test-model", "thinking": {"enabled": thinking_enabled}, } @@ -42,16 +44,30 @@ class MockRequest: class FakeResponse: - def __init__(self, *, status_code=200, lines=None, text=""): + def __init__( + self, + *, + status_code=200, + lines=None, + text="", + raise_after_line_index: int | None = None, + ): self.status_code = status_code self._lines = lines or [] self._text = text + self._raise_after_line_index = raise_after_line_index self.is_closed = False self.request = httpx.Request("POST", "https://example.test/v1/messages") + self.headers = httpx.Headers() async def aiter_lines(self): - for line in self._lines: + for i, line in enumerate(self._lines): yield line + if ( + self._raise_after_line_index is not None + and i >= self._raise_after_line_index + ): + raise RuntimeError("mid-stream failure") async def aread(self): return self._text.encode() @@ -67,6 +83,11 @@ class FakeResponse: async def aclose(self): self.is_closed = True + async def aiter_bytes(self, chunk_size: int = 65_536): + data = self._text.encode("utf-8") + for offset in range(0, len(data), chunk_size): + yield data[offset : offset + chunk_size] + @pytest.fixture def provider_config(): @@ -122,10 +143,8 @@ def test_default_request_body_strips_internal_fields(provider_config): assert body["model"] == "test-model" assert body["thinking"] == {"type": "enabled"} - assert body["max_tokens"] == 81920 + assert body["max_tokens"] == ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS assert "extra_body" not in body - assert "original_model" not in body - assert "resolved_provider_model" not in body def test_default_request_body_preserves_thinking_budget(provider_config): @@ -210,7 +229,66 @@ async def test_stream_maps_non_200_to_error_event_and_closes_response( ] assert response.is_closed - assert len(events) == 1 - assert events[0].startswith("event: error\ndata: {") - assert "Internal Server Error" in events[0] - assert "REQ_123" in events[0] + assert_canonical_stream_error_envelope( + events, user_message_substr="Provider API request failed" + ) + blob = "".join(events) + assert "REQ_123" in blob + + +@pytest.mark.asyncio +async def test_midstream_error_closes_open_block_and_uses_fresh_content_index( + provider_config, +): + """After upstream message_start + content_block_start, synthetic errors must not reuse index 0.""" + provider = NativeProvider(provider_config) + req = MockRequest() + mid = "msg_midstream_err" + msg_start = format_sse_event( + "message_start", + { + "type": "message_start", + "message": { + "id": mid, + "type": "message", + "role": "assistant", + "content": [], + "model": "test-model", + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": 1, "output_tokens": 1}, + }, + }, + ) + block_start = format_sse_event( + "content_block_start", + { + "type": "content_block_start", + "index": 0, + "content_block": {"type": "text", "text": ""}, + }, + ) + lines: list[str] = [] + for blob in (msg_start, block_start): + lines.extend(blob.splitlines()) + response = FakeResponse(lines=lines, raise_after_line_index=len(lines) - 1) + + with ( + patch.object(provider._client, "build_request", return_value=MagicMock()), + patch.object( + provider._client, + "send", + new_callable=AsyncMock, + return_value=response, + ), + ): + events = [e async for e in provider.stream_response(req)] + + assert_canonical_stream_error_envelope( + events, user_message_substr="mid-stream failure" + ) + parsed = parse_sse_text("".join(events)) + starts = [e for e in parsed if e.event == "content_block_start"] + assert event_index(starts[0]) == 0 + assert event_index(starts[-1]) == 1 + assert {event_index(e) for e in parsed if e.event == "content_block_stop"} == {0, 1} diff --git a/tests/providers/test_anthropic_messages_429_retry.py b/tests/providers/test_anthropic_messages_429_retry.py new file mode 100644 index 0000000..e64bbb7 --- /dev/null +++ b/tests/providers/test_anthropic_messages_429_retry.py @@ -0,0 +1,125 @@ +"""Native Anthropic transport: HTTP 429 is retried inside execute_with_retry.""" + +from contextlib import asynccontextmanager +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from providers.base import ProviderConfig +from providers.rate_limit import GlobalRateLimiter +from tests.providers.test_anthropic_messages import ( + FakeResponse, + MockRequest, + NativeProvider, +) +from tests.stream_contract import assert_canonical_stream_error_envelope + + +@pytest.fixture +def provider_config(): + return ProviderConfig( + api_key="test-key", + base_url="https://custom.test/v1/", + rate_limit=100, + rate_window=60, + http_read_timeout=600.0, + http_write_timeout=15.0, + http_connect_timeout=5.0, + ) + + +@pytest.mark.asyncio +async def test_native_stream_retries_on_http_429_then_streams(provider_config): + """First response 429 (closed), second 200 streams; send is called twice.""" + GlobalRateLimiter.reset_instance() + try: + provider = NativeProvider(provider_config) + req = MockRequest() + request_obj = httpx.Request("POST", "https://custom.test/v1/messages") + ok_lines = [ + "event: message_start", + 'data: {"type":"message_start"}', + "", + ] + ok_response = FakeResponse(lines=ok_lines) + too_many = FakeResponse(status_code=429, text="rate limited") + + send_calls = {"n": 0} + + async def send_side_effect(*_a, **_kw): + send_calls["n"] += 1 + if send_calls["n"] == 1: + return too_many + return ok_response + + with ( + patch.object(provider._client, "build_request", return_value=request_obj), + patch.object( + provider._client, + "send", + new_callable=AsyncMock, + side_effect=send_side_effect, + ), + patch( + "asyncio.sleep", + new_callable=AsyncMock, + ), + ): + events = [e async for e in provider.stream_response(req)] + + assert send_calls["n"] == 2 + assert too_many.is_closed + assert ok_response.is_closed + assert events == [ + "event: message_start\n", + 'data: {"type":"message_start"}\n', + "\n", + ] + finally: + GlobalRateLimiter.reset_instance() + + +@pytest.mark.asyncio +async def test_non_429_http_error_not_retried(provider_config): + """HTTP 500 from upstream is not retried; single send.""" + GlobalRateLimiter.reset_instance() + try: + + @asynccontextmanager + async def _slot(): + yield + + with patch("providers.anthropic_messages.GlobalRateLimiter") as mock_gl: + instance = mock_gl.get_scoped_instance.return_value + + async def _passthrough(fn, *args, **kwargs): + return await fn(*args, **kwargs) + + instance.execute_with_retry = AsyncMock(side_effect=_passthrough) + instance.concurrency_slot.side_effect = _slot + + provider = NativeProvider(provider_config) + req = MockRequest() + err = FakeResponse(status_code=500, text="Internal Server Error") + + with ( + patch.object( + provider._client, "build_request", return_value=MagicMock() + ), + patch.object( + provider._client, + "send", + new_callable=AsyncMock, + return_value=err, + ) as mock_send, + ): + events = [e async for e in provider.stream_response(req)] + + mock_send.assert_awaited_once() + assert err.is_closed + assert_canonical_stream_error_envelope( + events, user_message_substr="Provider API request failed" + ) + finally: + GlobalRateLimiter.reset_instance() diff --git a/tests/providers/test_converter.py b/tests/providers/test_converter.py index 7312fbd..99aeba6 100644 --- a/tests/providers/test_converter.py +++ b/tests/providers/test_converter.py @@ -2,7 +2,12 @@ import json import pytest -from core.anthropic import AnthropicToOpenAIConverter +from api.models.anthropic import MessagesRequest +from core.anthropic import ( + AnthropicToOpenAIConverter, + OpenAIConversionError, + build_base_request_body, +) # --- Mock Classes --- @@ -468,3 +473,102 @@ def test_convert_multiple_tool_results(): assert len(result) == 2 assert result[0]["tool_call_id"] == "t1" assert result[1]["tool_call_id"] == "t2" + + +def test_convert_user_message_tool_result_dict_as_json(): + content = [ + MockBlock( + type="tool_result", + tool_use_id="t_dict", + content={"ok": True, "count": 2}, + ), + ] + messages = [MockMessage("user", content)] + result = AnthropicToOpenAIConverter.convert_messages(messages) + assert result[0]["role"] == "tool" + assert result[0]["content"] == '{"ok": true, "count": 2}' + + +def test_assistant_redacted_thinking_omitted_from_openai_chat(): + """Opaque redacted_thinking is not materialized as content or reasoning_content.""" + content = [ + MockBlock(type="redacted_thinking", data="secret-opaque"), + MockBlock(type="text", text="Visible."), + ] + messages = [MockMessage("assistant", content)] + result = AnthropicToOpenAIConverter.convert_messages( + messages, include_thinking=True, include_reasoning_content=True + ) + assert result[0]["content"] == "Visible." + assert "secret-opaque" not in result[0]["content"] + assert "reasoning_content" not in result[0] + + +def test_convert_user_message_image_raises(): + content = [ + MockBlock(type="image", source={"type": "url", "url": "https://example.com/x"}) + ] + messages = [MockMessage("user", content)] + with pytest.raises(OpenAIConversionError): + AnthropicToOpenAIConverter.convert_messages(messages) + + +def test_convert_assistant_text_after_tool_use_raises(): + content = [ + MockBlock(type="tool_use", id="call_z", name="Read", input={}), + MockBlock(type="text", text="Illegal after tool"), + ] + messages = [MockMessage("assistant", content)] + with pytest.raises(OpenAIConversionError): + AnthropicToOpenAIConverter.convert_messages(messages) + + +def test_openai_build_accepts_declared_native_top_level_hints() -> None: + """OpenAI conversion ignores known non-OpenAI hints (e.g. context_management) without 400.""" + req = MessagesRequest.model_validate( + { + "model": "m", + "messages": [{"role": "user", "content": "h"}], + "context_management": {"edits": []}, + "output_config": {"foo": 1}, + "mcp_servers": [{"type": "url", "url": "https://x.com"}], + } + ) + body = build_base_request_body(req, default_max_tokens=100) + assert "context_management" not in body + assert "output_config" not in body + assert "mcp_servers" not in body + assert body["model"] == "m" + + +def test_openai_build_rejects_unknown_top_level_extras() -> None: + """Truly unknown keys must still be rejected (not dropped silently).""" + req = MessagesRequest.model_validate( + { + "model": "m", + "messages": [{"role": "user", "content": "h"}], + "experimental_client_only_passthrough": True, + } + ) + with pytest.raises(OpenAIConversionError, match="top-level request fields"): + build_base_request_body(req) + + +@pytest.mark.parametrize( + "content", + [ + [MockBlock(type="server_tool_use", id="1", name="web_search", input={})], + [MockBlock(type="web_search_tool_result", tool_use_id="1", content=[])], + [ + MockBlock( + type="web_fetch_tool_result", + tool_use_id="1", + content={"type": "web_fetch_result", "url": "https://a.com/x"}, + ) + ], + ], +) +def test_convert_assistant_server_tool_blocks_raise(content) -> None: + messages = [MockMessage("assistant", content)] + with pytest.raises(OpenAIConversionError, match="server tool"): + AnthropicToOpenAIConverter.convert_messages(messages) diff --git a/tests/providers/test_deepseek.py b/tests/providers/test_deepseek.py index bb00e0a..7a4d997 100644 --- a/tests/providers/test_deepseek.py +++ b/tests/providers/test_deepseek.py @@ -4,8 +4,10 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from api.models.anthropic import ContentBlockImage, Message, MessagesRequest from providers.base import ProviderConfig -from providers.deepseek import DEEPSEEK_BASE_URL, DeepSeekProvider +from providers.deepseek import DEEPSEEK_DEFAULT_BASE, DeepSeekProvider +from providers.exceptions import InvalidRequestError class MockMessage: @@ -41,7 +43,7 @@ class MockRequest: def deepseek_config(): return ProviderConfig( api_key="test_deepseek_key", - base_url=DEEPSEEK_BASE_URL, + base_url=DEEPSEEK_DEFAULT_BASE, rate_limit=10, rate_window=60, enable_thinking=True, @@ -72,7 +74,7 @@ def test_init(deepseek_config): with patch("providers.openai_compat.AsyncOpenAI") as mock_openai: provider = DeepSeekProvider(deepseek_config) assert provider._api_key == "test_deepseek_key" - assert provider._base_url == DEEPSEEK_BASE_URL + assert provider._base_url == DEEPSEEK_DEFAULT_BASE mock_openai.assert_called_once() @@ -91,7 +93,7 @@ def test_build_request_body_global_disable_blocks_request_thinking(): provider = DeepSeekProvider( ProviderConfig( api_key="test_deepseek_key", - base_url=DEEPSEEK_BASE_URL, + base_url=DEEPSEEK_DEFAULT_BASE, rate_limit=10, rate_window=60, enable_thinking=False, @@ -152,6 +154,67 @@ def test_build_request_body_preserves_reasoning_content(deepseek_provider): assert body["messages"][0]["reasoning_content"] == "First think" +def test_build_request_body_disabled_thinking_omits_reasoning_and_thinking_tags(): + """Resolved thinking policy must strip assistant thinking from OpenAI history.""" + provider = DeepSeekProvider( + ProviderConfig( + api_key="test_deepseek_key", + base_url=DEEPSEEK_DEFAULT_BASE, + rate_limit=10, + rate_window=60, + enable_thinking=False, + ) + ) + req = MockRequest( + system=None, + model="deepseek-chat", + messages=[ + MockMessage( + "assistant", + [ + MockBlock(type="thinking", thinking="secret"), + MockBlock(type="text", text="hi"), + ], + ) + ], + ) + + body = provider._build_request_body(req) + + assistant = body["messages"][0] + assert "reasoning_content" not in assistant + assert "secret" not in assistant["content"] + assert assistant["content"] == "hi" + + +def test_build_request_body_disabled_thinking_omits_redacted_blocks(): + """redacted_thinking is not sent as OpenAI text when thinking is disabled.""" + provider = DeepSeekProvider( + ProviderConfig( + api_key="test_deepseek_key", + base_url=DEEPSEEK_DEFAULT_BASE, + rate_limit=10, + rate_window=60, + enable_thinking=False, + ) + ) + req = MockRequest( + system=None, + model="deepseek-chat", + messages=[ + MockMessage( + "assistant", + [ + MockBlock(type="redacted_thinking", data="opaque-xyz"), + MockBlock(type="text", text="hi"), + ], + ) + ], + ) + body = provider._build_request_body(req) + assert "opaque-xyz" not in body["messages"][0]["content"] + + @pytest.mark.asyncio async def test_stream_response_reasoning_content(deepseek_provider): """reasoning_content deltas are emitted as thinking blocks.""" @@ -179,3 +242,37 @@ async def test_stream_response_reasoning_content(deepseek_provider): assert any( '"thinking_delta"' in event and "Thinking..." in event for event in events ) + + +def test_preflight_stream_rejects_unsupported_user_image_for_openai_conversion(): + """Eager preflight: image block fails before a stream would be opened.""" + request = MessagesRequest( + model="deepseek/deepseek-chat", + max_tokens=100, + messages=[ + Message( + role="user", + content=[ + ContentBlockImage( + type="image", + source={ + "type": "base64", + "media_type": "image/png", + "data": "YQ==", + }, + ) + ], + ) + ], + ) + provider = DeepSeekProvider( + ProviderConfig( + api_key="k", + base_url=DEEPSEEK_DEFAULT_BASE, + rate_limit=10, + rate_window=60, + ) + ) + with pytest.raises(InvalidRequestError) as exc: + provider.preflight_stream(request, thinking_enabled=True) + assert "image" in str(exc.value).lower() diff --git a/tests/providers/test_error_mapping.py b/tests/providers/test_error_mapping.py index 162b05d..2a180e7 100644 --- a/tests/providers/test_error_mapping.py +++ b/tests/providers/test_error_mapping.py @@ -1,13 +1,21 @@ """Tests for provider error mapping and core error formatting.""" +from pathlib import Path from unittest.mock import MagicMock, patch import openai import pytest from httpx import ReadTimeout, Request, Response -from core.anthropic import append_request_id, get_user_facing_error_message -from providers.error_mapping import map_error +from core.anthropic import ( + append_request_id, + format_user_error_preview, + get_user_facing_error_message, +) +from providers.error_mapping import ( + map_error, + user_visible_message_for_mapped_provider_error, +) from providers.exceptions import ( APIError, AuthenticationError, @@ -101,27 +109,6 @@ class TestMapError: result = map_error(exc) assert result is exc - @pytest.mark.parametrize( - "exc_cls,expected_cls", - [ - (openai.AuthenticationError, AuthenticationError), - (openai.RateLimitError, RateLimitError), - (openai.BadRequestError, InvalidRequestError), - ], - ids=["auth", "rate_limit", "bad_request"], - ) - def test_mapping_parametrized(self, exc_cls, expected_cls): - """Parametrized check of openai -> provider error mapping.""" - status_map = { - openai.AuthenticationError: 401, - openai.RateLimitError: 429, - openai.BadRequestError: 400, - } - exc = _make_openai_error(exc_cls, status_code=status_map[exc_cls]) - with patch("providers.error_mapping.GlobalRateLimiter"): - result = map_error(exc) - assert isinstance(result, expected_cls) - def test_user_facing_message_read_timeout_empty_string(): """ReadTimeout wrapping TimeoutError should still produce readable text.""" @@ -134,3 +121,35 @@ def test_append_request_id_suffix(): """Request id suffix should be appended deterministically.""" message = append_request_id("Provider request failed.", "req_abc123") assert message == "Provider request failed. (request_id=req_abc123)" + + +def test_user_facing_message_bad_request_prefers_mapped_text_over_sdk_string(): + """BadRequestError should map to stable wording even when str(exc) is non-empty.""" + exc = _make_openai_error( + openai.BadRequestError, message="leaky-upstream-detail", status_code=400 + ) + assert get_user_facing_error_message(exc) == "Invalid request sent to provider." + + +def test_format_user_error_preview_truncates(): + exc = ValueError("x" * 500) + preview = format_user_error_preview(exc, max_len=20) + assert len(preview) == 20 + assert preview == "x" * 20 + + +def test_user_visible_message_for_mapped_provider_error_405(): + mapped = APIError("ignored", status_code=405, raw_error="") + msg = user_visible_message_for_mapped_provider_error( + mapped, provider_name="ACME", read_timeout_s=30.0 + ) + assert "ACME" in msg and "405" in msg + + +def test_streaming_transports_pass_scoped_rate_limiter_to_map_error(): + """Guardrail: streaming adapters must scope reactive 429 handling per provider.""" + root = Path(__file__).resolve().parents[2] + for name in ("anthropic_messages.py", "openai_compat.py"): + text = (root / "providers" / name).read_text(encoding="utf-8") + assert "map_error(" in text, name + assert "rate_limiter=self._global_rate_limiter" in text, name diff --git a/tests/providers/test_llamacpp.py b/tests/providers/test_llamacpp.py index 04e0874..df87069 100644 --- a/tests/providers/test_llamacpp.py +++ b/tests/providers/test_llamacpp.py @@ -5,8 +5,10 @@ from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest +from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS from providers.base import ProviderConfig from providers.llamacpp import LlamaCppProvider +from tests.stream_contract import assert_canonical_stream_error_envelope class MockMessage: @@ -221,7 +223,7 @@ async def test_stream_response_adds_max_tokens_if_missing(llamacpp_provider): [e async for e in llamacpp_provider.stream_response(req)] _, kwargs = mock_build.call_args - assert kwargs["json"]["max_tokens"] == 81920 + assert kwargs["json"]["max_tokens"] == ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS @pytest.mark.asyncio @@ -254,10 +256,10 @@ async def test_stream_error_status_code(llamacpp_provider): async for e in llamacpp_provider.stream_response(req, request_id="TEST_ID") ] - assert len(events) == 1 - assert events[0].startswith("event: error\ndata: {") - assert "Internal Server Error" in events[0] - assert "TEST_ID" in events[0] + assert_canonical_stream_error_envelope( + events, user_message_substr="Provider API request failed" + ) + assert "TEST_ID" in "".join(events) @pytest.mark.asyncio @@ -281,10 +283,11 @@ async def test_stream_network_error(llamacpp_provider): async for e in llamacpp_provider.stream_response(req, request_id="TEST_ID2") ] - assert len(events) == 1 - assert events[0].startswith("event: error\ndata: {") - assert "Connection refused" in events[0] - assert "TEST_ID2" in events[0] + blob = "".join(events) + assert_canonical_stream_error_envelope( + events, user_message_substr="Connection refused" + ) + assert "TEST_ID2" in blob @pytest.mark.asyncio @@ -315,8 +318,36 @@ async def test_stream_error_405_mentions_upstream_provider(llamacpp_provider): e async for e in llamacpp_provider.stream_response(req, request_id="REQ405") ] + blob = "".join(events) assert ( "Upstream provider LLAMACPP rejected the request method or endpoint (HTTP 405)." - in events[0] + in blob ) - assert "REQ405" in events[0] + assert "REQ405" in blob + + +def test_build_request_body_disabled_thinking_strips_native_thinking_history( + llamacpp_config, +): + """With thinking disabled, prior assistant thinking/redacted blocks are omitted.""" + config = llamacpp_config.model_copy(update={"enable_thinking": False}) + provider = LlamaCppProvider(config) + messages = [ + MockMessage("user", "Hi"), + MockMessage( + "assistant", + [ + {"type": "thinking", "thinking": "p"}, + {"type": "redacted_thinking", "data": "ZGF0YQ=="}, + ], + ), + ] + req = MockRequest( + system=None, + messages=messages, + ) + body = provider._build_request_body(req, thinking_enabled=False) + asst = body["messages"][1] + assert asst["content"] == "" + assert "thinking" not in str(body) + assert "redacted_thinking" not in str(body) diff --git a/tests/providers/test_lmstudio.py b/tests/providers/test_lmstudio.py index f8c4dee..06aca7a 100644 --- a/tests/providers/test_lmstudio.py +++ b/tests/providers/test_lmstudio.py @@ -5,8 +5,10 @@ from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest +from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS from providers.base import ProviderConfig from providers.lmstudio import LMStudioProvider +from tests.stream_contract import assert_canonical_stream_error_envelope class MockMessage: @@ -252,7 +254,7 @@ async def test_stream_response_adds_max_tokens_if_missing(lmstudio_provider): [e async for e in lmstudio_provider.stream_response(req)] _, kwargs = mock_build.call_args - assert kwargs["json"]["max_tokens"] == 81920 + assert kwargs["json"]["max_tokens"] == ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS @pytest.mark.asyncio @@ -285,10 +287,10 @@ async def test_stream_error_status_code(lmstudio_provider): async for e in lmstudio_provider.stream_response(req, request_id="TEST_ID") ] - assert len(events) == 1 - assert events[0].startswith("event: error\ndata: {") - assert "Internal Server Error" in events[0] - assert "TEST_ID" in events[0] + assert_canonical_stream_error_envelope( + events, user_message_substr="Provider API request failed" + ) + assert "TEST_ID" in "".join(events) @pytest.mark.asyncio @@ -312,10 +314,11 @@ async def test_stream_network_error(lmstudio_provider): async for e in lmstudio_provider.stream_response(req, request_id="TEST_ID2") ] - assert len(events) == 1 - assert events[0].startswith("event: error\ndata: {") - assert "Connection refused" in events[0] - assert "TEST_ID2" in events[0] + blob = "".join(events) + assert_canonical_stream_error_envelope( + events, user_message_substr="Connection refused" + ) + assert "TEST_ID2" in blob @pytest.mark.asyncio @@ -346,8 +349,61 @@ async def test_stream_error_405_mentions_upstream_provider(lmstudio_provider): e async for e in lmstudio_provider.stream_response(req, request_id="REQ405") ] + blob = "".join(events) assert ( "Upstream provider LMSTUDIO rejected the request method or endpoint (HTTP 405)." - in events[0] + in blob ) - assert "REQ405" in events[0] + assert "REQ405" in blob + + +def test_build_request_body_disabled_thinking_strips_native_thinking_history( + lmstudio_config, +): + """Disabled thinking must omit prior assistant thinking/redacted blocks in JSON.""" + provider = LMStudioProvider( + lmstudio_config.model_copy(update={"enable_thinking": False}) + ) + req = MockRequest( + system=None, + messages=[ + MockMessage("user", "hi"), + MockMessage( + "assistant", + [ + {"type": "thinking", "thinking": "t"}, + {"type": "redacted_thinking", "data": "opaque"}, + ], + ), + ], + ) + body = provider._build_request_body(req, thinking_enabled=False) + assert body["messages"][1]["content"] == "" + assert "redacted_thinking" not in str(body) + + +def test_build_request_body_preserves_signed_thinking_native_history(lmstudio_config): + """When thinking is enabled, signed thinking blocks are kept; unsigned stripped.""" + provider = LMStudioProvider(lmstudio_config) + req = MockRequest( + system=None, + messages=[ + MockMessage("user", "hi"), + MockMessage( + "assistant", + [ + { + "type": "thinking", + "thinking": "signed", + "signature": "sig", + }, + {"type": "redacted_thinking", "data": "opaque"}, + ], + ), + ], + ) + body = provider._build_request_body(req, thinking_enabled=True) + c = body["messages"][1]["content"] + assert isinstance(c, list) + assert any(isinstance(b, dict) and b.get("type") == "thinking" for b in c) + assert any(isinstance(b, dict) and b.get("type") == "redacted_thinking" for b in c) diff --git a/tests/providers/test_nim_request_clone.py b/tests/providers/test_nim_request_clone.py new file mode 100644 index 0000000..30260c1 --- /dev/null +++ b/tests/providers/test_nim_request_clone.py @@ -0,0 +1,40 @@ +"""Tests for NVIDIA NIM request body cloning helpers.""" + +from copy import deepcopy + +from providers.nvidia_nim.request import clone_body_without_reasoning_budget + + +def test_clone_body_without_reasoning_budget_strips_top_level_and_nested(): + body: dict = { + "model": "x", + "extra_body": { + "reasoning_budget": 99, + "chat_template_kwargs": {"reasoning_budget": 42, "thinking": True}, + "top_k": 1, + }, + } + original_extra = deepcopy(body["extra_body"]) + out = clone_body_without_reasoning_budget(body) + + assert out is not None + assert out["extra_body"]["chat_template_kwargs"] == {"thinking": True} + assert "reasoning_budget" not in out["extra_body"] + assert body["extra_body"] == original_extra + + +def test_clone_body_without_reasoning_budget_returns_none_when_unchanged(): + body = {"model": "x", "extra_body": {"top_k": 3}} + assert clone_body_without_reasoning_budget(body) is None + + +def test_clone_body_without_reasoning_budget_returns_none_without_extra_body(): + assert clone_body_without_reasoning_budget({"model": "y"}) is None + + +def test_clone_body_drops_empty_extra_body_after_strip(): + body = {"model": "z", "extra_body": {"reasoning_budget": 7}} + out = clone_body_without_reasoning_budget(body) + assert out is not None + assert "extra_body" not in out + assert "extra_body" in body diff --git a/tests/providers/test_nvidia_nim.py b/tests/providers/test_nvidia_nim.py index dcacbfd..4c190ed 100644 --- a/tests/providers/test_nvidia_nim.py +++ b/tests/providers/test_nvidia_nim.py @@ -5,6 +5,8 @@ import openai import pytest from httpx import Request, Response +from config.nim import NimSettings +from providers.defaults import NVIDIA_NIM_DEFAULT_BASE from providers.nvidia_nim import NvidiaNimProvider @@ -42,7 +44,7 @@ class MockRequest: def _make_bad_request_error(message: str) -> openai.BadRequestError: response = Response( status_code=400, - request=Request("POST", "https://integrate.api.nvidia.com/v1/chat/completions"), + request=Request("POST", f"{NVIDIA_NIM_DEFAULT_BASE}/chat/completions"), ) body = {"error": {"message": message, "type": "BadRequestError", "code": 400}} return openai.BadRequestError(message, response=response, body=body) @@ -67,8 +69,6 @@ def mock_rate_limiter(): async def test_init(provider_config): """Test provider initialization.""" with patch("providers.openai_compat.AsyncOpenAI") as mock_openai: - from config.nim import NimSettings - provider = NvidiaNimProvider(provider_config, nim_settings=NimSettings()) assert provider._api_key == "test_key" assert provider._base_url == "https://test.api.nvidia.com/v1" @@ -78,7 +78,6 @@ async def test_init(provider_config): @pytest.mark.asyncio async def test_init_uses_configurable_timeouts(): """Test that provider passes configurable read/write/connect timeouts to client.""" - from config.nim import NimSettings from providers.base import ProviderConfig config = ProviderConfig( @@ -100,8 +99,6 @@ async def test_init_uses_configurable_timeouts(): @pytest.mark.asyncio async def test_build_request_body(provider_config): """Test request body construction.""" - from config.nim import NimSettings - provider = NvidiaNimProvider(provider_config, nim_settings=NimSettings()) req = MockRequest() body = provider._build_request_body(req) @@ -124,8 +121,6 @@ async def test_build_request_body(provider_config): async def test_build_request_body_omits_reasoning_when_globally_disabled( provider_config, ): - from config.nim import NimSettings - provider = NvidiaNimProvider( provider_config.model_copy(update={"enable_thinking": False}), nim_settings=NimSettings(), @@ -142,8 +137,6 @@ async def test_build_request_body_omits_reasoning_when_globally_disabled( async def test_build_request_body_omits_reasoning_when_request_disables_thinking( provider_config, ): - from config.nim import NimSettings - provider = NvidiaNimProvider(provider_config, nim_settings=NimSettings()) req = MockRequest() req.thinking.enabled = False @@ -241,8 +234,6 @@ async def test_stream_response_thinking_reasoning_content(nim_provider): @pytest.mark.asyncio async def test_stream_response_suppresses_thinking_when_disabled(provider_config): - from config.nim import NimSettings - provider = NvidiaNimProvider( provider_config.model_copy(update={"enable_thinking": False}), nim_settings=NimSettings(), @@ -285,8 +276,6 @@ def _make_bad_request_error(message: str) -> openai.BadRequestError: @pytest.mark.asyncio async def test_stream_response_retries_without_chat_template(provider_config): - from config.nim import NimSettings - provider = NvidiaNimProvider( provider_config, nim_settings=NimSettings(chat_template="custom_template"), @@ -344,8 +333,6 @@ async def test_stream_response_retries_without_chat_template(provider_config): @pytest.mark.asyncio async def test_stream_response_does_not_retry_unrelated_bad_request(provider_config): - from config.nim import NimSettings - provider = NvidiaNimProvider( provider_config, nim_settings=NimSettings(chat_template="custom_template"), @@ -361,7 +348,7 @@ async def test_stream_response_does_not_retry_unrelated_bad_request(provider_con assert mock_create.await_count == 1 event_text = "".join(events) - assert "unrelated bad request" in event_text + assert "Invalid request sent to provider" in event_text assert "event: message_stop" in event_text @@ -457,5 +444,5 @@ async def test_stream_response_bad_request_without_reasoning_budget_does_not_ret events = [e async for e in nim_provider.stream_response(req)] assert mock_create.await_count == 1 - assert any("Unsupported field: top_k" in event for event in events) + assert any("Invalid request sent to provider" in event for event in events) assert any("message_stop" in event for event in events) diff --git a/tests/providers/test_ollama.py b/tests/providers/test_ollama.py index 1edeb3b..ffb90a7 100644 --- a/tests/providers/test_ollama.py +++ b/tests/providers/test_ollama.py @@ -6,7 +6,8 @@ import httpx import pytest from providers.base import ProviderConfig -from providers.ollama import OLLAMA_BASE_URL, OllamaProvider +from providers.ollama import OLLAMA_DEFAULT_BASE, OllamaProvider +from tests.stream_contract import assert_canonical_stream_error_envelope class MockMessage: @@ -92,7 +93,7 @@ def test_init_uses_default_base_url(): config = ProviderConfig(api_key="ollama", base_url=None) with patch("httpx.AsyncClient"): provider = OllamaProvider(config) - assert provider._base_url == OLLAMA_BASE_URL + assert provider._base_url == OLLAMA_DEFAULT_BASE def test_init_uses_configurable_timeouts(): @@ -199,6 +200,30 @@ async def test_build_request_body_omits_thinking_when_disabled(ollama_config): assert body["model"] == "llama3.1:8b" +def test_build_request_body_disabled_thinking_strips_assistant_thinking_blocks( + ollama_config, +): + """Prior assistant thinking/redacted blocks are removed when policy is off.""" + provider = OllamaProvider( + ollama_config.model_copy(update={"enable_thinking": False}) + ) + req = MockRequest( + system=None, + messages=[ + MockMessage("user", "hi"), + MockMessage( + "assistant", + [ + {"type": "thinking", "thinking": "t"}, + {"type": "redacted_thinking", "data": "opaque"}, + ], + ), + ], + ) + body = provider._build_request_body(req, thinking_enabled=False) + assert body["messages"][1]["content"] == "" + + @pytest.mark.asyncio async def test_stream_error_status_code(ollama_provider): """Non-200 status code is yielded as an SSE API error.""" @@ -228,10 +253,10 @@ async def test_stream_error_status_code(ollama_provider): async for event in ollama_provider.stream_response(req, request_id="REQ") ] - assert len(events) == 1 - assert events[0].startswith("event: error\ndata: {") - assert "Internal Server Error" in events[0] - assert "REQ" in events[0] + assert_canonical_stream_error_envelope( + events, user_message_substr="Provider API request failed" + ) + assert "REQ" in "".join(events) @pytest.mark.asyncio diff --git a/tests/providers/test_open_router.py b/tests/providers/test_open_router.py index d88b2c6..7a6da14 100644 --- a/tests/providers/test_open_router.py +++ b/tests/providers/test_open_router.py @@ -6,6 +6,12 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from core.anthropic.stream_contracts import ( + assert_anthropic_stream_contract, + parse_sse_text, + text_content, + thinking_content, +) from providers.base import ProviderConfig from providers.open_router import OpenRouterProvider from providers.open_router.request import OPENROUTER_DEFAULT_MAX_TOKENS @@ -30,8 +36,6 @@ class MockRequest: self.tool_choice = None self.metadata = None self.extra_body = {} - self.original_model = "claude-3-sonnet" - self.resolved_provider_model = "open_router/stepfun/step-3.5-flash:free" self.thinking = MagicMock() self.thinking.enabled = True for k, v in kwargs.items(): @@ -135,8 +139,26 @@ def test_build_request_body_is_native_anthropic(open_router_provider): assert body["system"] == "System prompt" assert body["reasoning"] == {"enabled": True} assert "extra_body" not in body - assert "original_model" not in body - assert "resolved_provider_model" not in body + + +def test_openrouter_extra_body_rejects_overriding_reserved_fields() -> None: + from providers.exceptions import InvalidRequestError + from providers.open_router.request import build_request_body + + r = MockRequest() + r.extra_body = {"model": "hijack"} + with pytest.raises(InvalidRequestError, match="model"): + build_request_body(r, thinking_enabled=True) + + +def test_openrouter_extra_body_allows_openrouter_only_keys() -> None: + from providers.open_router.request import build_request_body + + r = MockRequest() + r.extra_body = {"transforms": ["no-web"], "plugins": []} + body = build_request_body(r, thinking_enabled=False) + assert body["transforms"] == ["no-web"] + assert body["plugins"] == [] def test_build_request_body_omits_reasoning_when_globally_disabled( @@ -188,7 +210,6 @@ def test_build_request_body_default_max_tokens(open_router_provider): body = open_router_provider._build_request_body(req) assert body["max_tokens"] == OPENROUTER_DEFAULT_MAX_TOKENS - assert body["max_tokens"] == 81920 def test_build_request_body_strips_unsigned_thinking_history(open_router_provider): @@ -209,7 +230,32 @@ def test_build_request_body_strips_unsigned_thinking_history(open_router_provide body = open_router_provider._build_request_body(req) - assert body["messages"][1]["content"] == [{"type": "text", "text": "Hello"}] + assert body["messages"][1]["content"] == [ + {"type": "redacted_thinking", "data": "opaque"}, + {"type": "text", "text": "Hello"}, + ] + + +def test_build_request_body_strips_redacted_when_thinking_disabled( + open_router_config, +): + """Disabled thinking must remove all assistant thinking history including redacted.""" + provider = OpenRouterProvider( + open_router_config.model_copy(update={"enable_thinking": False}) + ) + req = MockRequest( + messages=[ + MockMessage( + "assistant", + [ + {"type": "redacted_thinking", "data": "opaque"}, + {"type": "text", "text": "Hi"}, + ], + ) + ] + ) + body = provider._build_request_body(req) + assert body["messages"][0]["content"] == [{"type": "text", "text": "Hi"}] def test_build_request_body_preserves_signed_thinking_history(open_router_provider): @@ -336,12 +382,12 @@ async def test_stream_response_suppresses_native_thinking_when_disabled( assert "Answer" in event_text text_start = next(event for event in events if "content_block_start" in event) - payload = json.loads(text_start.split("data: ", 1)[1]) + payload = parse_sse_text(text_start)[0].data assert payload["index"] == 0 @pytest.mark.asyncio -async def test_stream_response_drops_redacted_thinking_when_enabled( +async def test_stream_response_preserves_redacted_thinking_when_enabled( open_router_provider, ): response = FakeResponse( @@ -378,16 +424,216 @@ async def test_stream_response_drops_redacted_thinking_when_enabled( ): events = [e async for e in open_router_provider.stream_response(MockRequest())] + event_text = "".join(events) + assert "redacted_thinking" in event_text + assert "opaque" in event_text + assert "Answer" in event_text + + parsed = parse_sse_text(event_text) + first_start = next( + p + for p in parsed + if p.event == "content_block_start" + and p.data.get("content_block", {}).get("type") == "redacted_thinking" + ) + assert first_start.data["index"] == 0 + + +@pytest.mark.asyncio +async def test_stream_response_drops_redacted_thinking_when_disabled( + open_router_config, +): + provider = OpenRouterProvider( + open_router_config.model_copy(update={"enable_thinking": False}) + ) + response = FakeResponse( + lines=[ + "event: content_block_start", + 'data: {"type":"content_block_start","index":0,"content_block":{"type":"redacted_thinking","data":"opaque"}}', + "", + "event: content_block_stop", + 'data: {"type":"content_block_stop","index":0}', + "", + "event: content_block_start", + 'data: {"type":"content_block_start","index":1,"content_block":{"type":"text","text":""}}', + "", + "event: content_block_delta", + 'data: {"type":"content_block_delta","index":1,"delta":{"type":"text_delta","text":"Answer"}}', + "", + "event: content_block_stop", + 'data: {"type":"content_block_stop","index":1}', + "", + ] + ) + + with ( + patch.object(provider._client, "build_request"), + patch.object( + provider._client, + "send", + new_callable=AsyncMock, + return_value=response, + ), + ): + events = [e async for e in provider.stream_response(MockRequest())] + event_text = "".join(events) assert "redacted_thinking" not in event_text + assert "opaque" not in event_text assert "Answer" in event_text start_event = next(event for event in events if "content_block_start" in event) - payload = json.loads(start_event.split("data: ", 1)[1]) + payload = parse_sse_text(start_event)[0].data assert payload["index"] == 0 assert payload["content_block"]["type"] == "text" +@pytest.mark.asyncio +async def test_stream_response_reopens_interleaved_thinking_after_text( + open_router_provider, +): + """Overthinking+text+more thinking: downstream indices must not reuse closed blocks.""" + response = FakeResponse( + lines=[ + "event: message_start", + 'data: {"type":"message_start","message":{}}', + "", + "event: content_block_start", + 'data: {"type":"content_block_start","index":0,' + '"content_block":{"type":"thinking","thinking":"","signature":""}}', + "", + "event: content_block_delta", + 'data: {"type":"content_block_delta","index":0,' + '"delta":{"type":"thinking_delta","thinking":"first"}}', + "", + "event: content_block_start", + 'data: {"type":"content_block_start","index":1,' + '"content_block":{"type":"text","text":""}}', + "", + "event: content_block_delta", + 'data: {"type":"content_block_delta","index":0,' + '"delta":{"type":"thinking_delta","thinking":" second"}}', + "", + "event: content_block_delta", + 'data: {"type":"content_block_delta","index":1,' + '"delta":{"type":"text_delta","text":"Answer"}}', + "", + "event: content_block_stop", + 'data: {"type":"content_block_stop","index":1}', + "", + "event: content_block_stop", + 'data: {"type":"content_block_stop","index":0}', + "", + "event: message_stop", + 'data: {"type":"message_stop"}', + "", + ] + ) + + with ( + patch.object(open_router_provider._client, "build_request"), + patch.object( + open_router_provider._client, + "send", + new_callable=AsyncMock, + return_value=response, + ), + ): + events = [e async for e in open_router_provider.stream_response(MockRequest())] + + parsed = parse_sse_text("".join(events)) + assert_anthropic_stream_contract(parsed) + assert thinking_content(parsed) == "first second" + assert "Answer" in text_content(parsed) + stop_payloads = [ + p.data + for p in parsed + if p.event == "content_block_stop" + and p.data.get("type") == "content_block_stop" + ] + seen_stop_indices: set[int] = set() + for s in stop_payloads: + idx = s.get("index") + assert isinstance(idx, int) + assert idx not in seen_stop_indices, "stop reused or duplicated index" + seen_stop_indices.add(idx) + # Two distinct thinking block indices: initial + reopened segment + think_starts = [ + p + for p in parsed + if p.event == "content_block_start" + and p.data.get("content_block", {}).get("type") == "thinking" + ] + assert len(think_starts) == 2, ( + "reopened thinking must have its own `content_block_start`" + ) + + +@pytest.mark.asyncio +async def test_stream_response_reopened_tool_use_preserves_tool_identity( + open_router_provider, +): + """After overlapping close, resumed input_json_delta must keep original tool id/name.""" + lines: list[str] = [] + for payload in ( + { + "type": "content_block_start", + "index": 0, + "content_block": { + "type": "tool_use", + "id": "toolu_real_1", + "name": "Read", + "input": {}, + }, + }, + { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "input_json_delta", "partial_json": '{"path'}, + }, + { + "type": "content_block_start", + "index": 1, + "content_block": {"type": "text", "text": ""}, + }, + { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "input_json_delta", "partial_json": '":"/tmp"}'}, + }, + {"type": "content_block_stop", "index": 1}, + {"type": "content_block_stop", "index": 0}, + ): + event_name = payload["type"] + lines.extend((f"event: {event_name}", f"data: {json.dumps(payload)}", "")) + + response = FakeResponse(lines=lines) + + with ( + patch.object(open_router_provider._client, "build_request"), + patch.object( + open_router_provider._client, + "send", + new_callable=AsyncMock, + return_value=response, + ), + ): + events = [e async for e in open_router_provider.stream_response(MockRequest())] + + parsed = parse_sse_text("".join(events)) + tool_starts = [ + p + for p in parsed + if p.event == "content_block_start" + and p.data.get("content_block", {}).get("type") == "tool_use" + ] + assert len(tool_starts) == 2 + for start in tool_starts: + block = start.data["content_block"] + assert block["id"] == "toolu_real_1" + assert block["name"] == "Read" + + @pytest.mark.asyncio async def test_stream_response_closes_overlapping_thinking_before_text( open_router_provider, diff --git a/tests/providers/test_provider_rate_limit.py b/tests/providers/test_provider_rate_limit.py index 7cd00d9..7d59280 100644 --- a/tests/providers/test_provider_rate_limit.py +++ b/tests/providers/test_provider_rate_limit.py @@ -27,11 +27,11 @@ class TestProviderRateLimiter: GlobalRateLimiter.reset_instance() limiter = GlobalRateLimiter.get_instance(rate_limit=1, rate_window=0.25) - start_time = time.time() + start_time = time.monotonic() async def call_limiter(): await limiter.wait_if_blocked() - return time.time() + return time.monotonic() # 5 requests. # R0 -> 0s @@ -41,7 +41,7 @@ class TestProviderRateLimiter: # R4 -> 1.00s results = [await call_limiter() for _ in range(5)] - total_time = time.time() - start_time + total_time = time.monotonic() - start_time assert len(results) == 5 # Should take at least ~1.0s @@ -56,7 +56,7 @@ class TestProviderRateLimiter: GlobalRateLimiter.reset_instance() limiter = GlobalRateLimiter.get_instance() - start_time = time.time() + start_time = time.monotonic() # Manually block for 1.5s block_time = 1.5 @@ -71,7 +71,7 @@ class TestProviderRateLimiter: # They should both wait for the block time results = await asyncio.gather(call_limiter(), call_limiter()) - total_time = time.time() - start_time + total_time = time.monotonic() - start_time # Both should report having waited reactively assert all(results) is True @@ -121,10 +121,10 @@ class TestProviderRateLimiter: GlobalRateLimiter.reset_instance() limiter = GlobalRateLimiter.get_instance(rate_limit=10000, rate_window=60) - start = time.time() + start = time.monotonic() for _ in range(20): await limiter.wait_if_blocked() - duration = time.time() - start + duration = time.monotonic() - start # 20 requests with 10000 limit should be near-instant assert duration < 1.0, f"High rate limit caused throttling: {duration:.2f}s" @@ -252,6 +252,32 @@ class TestProviderRateLimiter: assert result == "ok" assert call_count == 2 + @pytest.mark.asyncio + async def test_execute_with_retry_succeeds_on_httpx_429(self): + """HTTP 429 as httpx.HTTPStatusError then success returns result.""" + import httpx + from httpx import Request, Response + + limiter = GlobalRateLimiter.get_instance(rate_limit=100, rate_window=60) + + call_count = 0 + + async def fail_then_ok(): + nonlocal call_count + call_count += 1 + if call_count == 1: + r = Response(429, request=Request("POST", "http://x"), text="slow") + raise httpx.HTTPStatusError( + "Too Many Requests", request=r.request, response=r + ) + return "ok" + + result = await limiter.execute_with_retry( + fail_then_ok, max_retries=2, base_delay=0.01, max_delay=0.1, jitter=0 + ) + assert result == "ok" + assert call_count == 2 + @pytest.mark.asyncio async def test_max_concurrency_zero_raises(self): """max_concurrency <= 0 raises ValueError.""" diff --git a/tests/providers/test_provider_transport_logging.py b/tests/providers/test_provider_transport_logging.py new file mode 100644 index 0000000..5076930 --- /dev/null +++ b/tests/providers/test_provider_transport_logging.py @@ -0,0 +1,263 @@ +"""Tests for metadata-only provider transport logging by default.""" + +import logging +from contextlib import asynccontextmanager +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from config.constants import NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES +from config.nim import NimSettings +from providers.anthropic_messages import AnthropicMessagesTransport +from providers.base import ProviderConfig +from providers.nvidia_nim import NvidiaNimProvider +from tests.provider_request_mocks import make_openai_compat_stream_request +from tests.providers.test_anthropic_messages import ( + FakeResponse, + MockRequest, + NativeProvider, +) + + +@pytest.fixture +def provider_config(): + return ProviderConfig( + api_key="test-key", + base_url="https://custom.test/v1/", + proxy="socks5://127.0.0.1:9999", + rate_limit=10, + rate_window=60, + http_read_timeout=600.0, + http_write_timeout=15.0, + http_connect_timeout=5.0, + ) + + +@pytest.fixture(autouse=True) +def mock_rate_limiter(): + @asynccontextmanager + async def _slot(): + yield + + with patch("providers.anthropic_messages.GlobalRateLimiter") as mock: + instance = mock.get_scoped_instance.return_value + + async def _passthrough(fn, *args, **kwargs): + return await fn(*args, **kwargs) + + instance.execute_with_retry = AsyncMock(side_effect=_passthrough) + instance.concurrency_slot.side_effect = _slot + yield instance + + +@pytest.mark.asyncio +async def test_native_non_200_logs_exclude_body_text_by_default( + caplog, provider_config +): + provider = NativeProvider(provider_config) + req = MockRequest() + response = FakeResponse(status_code=500, text="SECRET_UPSTREAM_BODY") + + with ( + patch.object(provider._client, "build_request", return_value=MagicMock()), + patch.object( + provider._client, + "send", + new_callable=AsyncMock, + return_value=response, + ), + caplog.at_level(logging.ERROR), + ): + _ = [e async for e in provider.stream_response(req)] + + messages = " | ".join(r.getMessage() for r in caplog.records) + assert "SECRET_UPSTREAM_BODY" not in messages + assert "HTTP 500" in messages + assert "body_preview_bytes=" not in messages + + +@pytest.mark.asyncio +async def test_native_non_200_logs_body_when_verbose(caplog, provider_config): + provider_config.log_api_error_tracebacks = True + provider = NativeProvider(provider_config) + req = MockRequest() + response = FakeResponse(status_code=500, text="SECRET_UPSTREAM_BODY") + + with ( + patch.object(provider._client, "build_request", return_value=MagicMock()), + patch.object( + provider._client, + "send", + new_callable=AsyncMock, + return_value=response, + ), + caplog.at_level(logging.ERROR), + ): + _ = [e async for e in provider.stream_response(req)] + + messages = " | ".join(r.getMessage() for r in caplog.records) + assert "SECRET_UPSTREAM_BODY" in messages + assert "truncated=False" in messages + + +@pytest.mark.asyncio +async def test_native_non_200_verbose_logs_only_capped_error_body( + caplog, provider_config +): + provider_config.log_api_error_tracebacks = True + provider = NativeProvider(provider_config) + req = MockRequest() + tail = "SECRET_TAIL_NOT_LOGGED" + huge = f"{'A' * (NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES + 50)}{tail}" + response = FakeResponse(status_code=500, text=huge) + + with ( + patch.object(provider._client, "build_request", return_value=MagicMock()), + patch.object( + provider._client, + "send", + new_callable=AsyncMock, + return_value=response, + ), + caplog.at_level(logging.ERROR), + ): + _ = [e async for e in provider.stream_response(req)] + + messages = " | ".join(r.getMessage() for r in caplog.records) + assert "SECRET_TAIL_NOT_LOGGED" not in messages + assert "truncated=True" in messages + assert f"body_preview_bytes={NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES}" in messages + + +@pytest.mark.asyncio +async def test_native_non_200_default_does_not_read_oversized_body( + caplog, provider_config +): + provider = NativeProvider(provider_config) + req = MockRequest() + huge = f"{'Z' * 500_000}LEAK_MARKER" + response = FakeResponse(status_code=500, text=huge) + + with ( + patch.object(provider._client, "build_request", return_value=MagicMock()), + patch.object( + provider._client, + "send", + new_callable=AsyncMock, + return_value=response, + ), + caplog.at_level(logging.ERROR), + ): + _ = [e async for e in provider.stream_response(req)] + + messages = " | ".join(r.getMessage() for r in caplog.records) + assert "LEAK_MARKER" not in messages + assert "ZZZ" not in messages + assert "HTTP 500" in messages + + +@pytest.mark.asyncio +async def test_native_stream_failure_logs_exclude_exception_str_by_default( + caplog, provider_config +): + provider = NativeProvider(provider_config) + req = MockRequest() + response = FakeResponse( + lines=[ + "event: ping", + 'data: {"type":"ping"}', + "", + ] + ) + + async def boom(_self, _response): + raise RuntimeError("SECRET_DETAIL") + if False: + yield "" + + with ( + patch.object(provider._client, "build_request", return_value=MagicMock()), + patch.object( + provider._client, + "send", + new_callable=AsyncMock, + return_value=response, + ), + patch.object(AnthropicMessagesTransport, "_iter_sse_events", boom), + caplog.at_level(logging.ERROR), + ): + _ = [e async for e in provider.stream_response(req)] + + messages = " | ".join(r.getMessage() for r in caplog.records) + assert "SECRET_DETAIL" not in messages + assert "exc_type=RuntimeError" in messages + assert "http_status=None" in messages + + +@pytest.mark.asyncio +async def test_openai_compat_stream_failure_default_logs_exclude_exception_str(caplog): + config = ProviderConfig( + api_key="k", + base_url="http://localhost:1/v1", + log_api_error_tracebacks=False, + ) + provider = NvidiaNimProvider(config, nim_settings=NimSettings()) + req = make_openai_compat_stream_request() + + @asynccontextmanager + async def _noop_slot(): + yield + + with ( + patch.object( + provider, + "_create_stream", + new_callable=AsyncMock, + side_effect=RuntimeError("SECRET_OPENAI_COMPAT"), + ), + patch.object( + provider._global_rate_limiter, + "concurrency_slot", + _noop_slot, + ), + caplog.at_level(logging.ERROR), + ): + _ = [e async for e in provider.stream_response(req)] + + messages = " | ".join(r.getMessage() for r in caplog.records) + assert "SECRET_OPENAI_COMPAT" not in messages + assert "exc_type=RuntimeError" in messages + + +@pytest.mark.asyncio +async def test_openai_compat_stream_failure_respects_verbose_flag(caplog): + config = ProviderConfig( + api_key="k", + base_url="http://localhost:1/v1", + log_api_error_tracebacks=True, + ) + provider = NvidiaNimProvider(config, nim_settings=NimSettings()) + req = make_openai_compat_stream_request() + + @asynccontextmanager + async def _noop_slot(): + yield + + with ( + patch.object( + provider, + "_create_stream", + new_callable=AsyncMock, + side_effect=RuntimeError("SECRET_OPENAI_COMPAT"), + ), + patch.object( + provider._global_rate_limiter, + "concurrency_slot", + _noop_slot, + ), + caplog.at_level(logging.ERROR), + ): + _ = [e async for e in provider.stream_response(req)] + + messages = " | ".join(r.getMessage() for r in caplog.records) + assert "SECRET_OPENAI_COMPAT" in messages diff --git a/tests/providers/test_registry.py b/tests/providers/test_registry.py index b00d780..ac8cbba 100644 --- a/tests/providers/test_registry.py +++ b/tests/providers/test_registry.py @@ -39,7 +39,7 @@ def _make_settings(**overrides): mock.provider_max_concurrency = 5 mock.http_read_timeout = 300.0 mock.http_write_timeout = 10.0 - mock.http_connect_timeout = 2.0 + mock.http_connect_timeout = 10.0 mock.enable_model_thinking = True mock.nim = NimSettings() for key, value in overrides.items(): diff --git a/tests/providers/test_sse_builder.py b/tests/providers/test_sse_builder.py index 07df2f3..01426e5 100644 --- a/tests/providers/test_sse_builder.py +++ b/tests/providers/test_sse_builder.py @@ -1,19 +1,20 @@ """Tests for core.anthropic.sse.""" -import json from unittest.mock import patch import pytest from core.anthropic import ContentBlockManager, SSEBuilder, map_stop_reason +from core.anthropic.sse import ToolCallState +from core.anthropic.stream_contracts import parse_sse_text def _parse_sse(sse_str: str) -> dict: """Parse an SSE event string into its data payload.""" - for line in sse_str.strip().split("\n"): - if line.startswith("data: "): - return json.loads(line[len("data: ") :]) - raise ValueError(f"No data line found in SSE: {sse_str}") + events = parse_sse_text(sse_str) + if len(events) != 1: + raise ValueError(f"expected 1 SSE event, got {len(events)} in {sse_str!r}") + return events[0].data class TestMapStopReason: @@ -61,6 +62,24 @@ class TestContentBlockManager: assert mgr.text_started is False assert mgr.tool_states == {} + def test_flush_task_arg_buffers_logs_digest_not_secret(self, caplog): + """Invalid Task JSON warnings must not echo argument prefixes (secrets).""" + mgr = ContentBlockManager() + mgr.tool_states[0] = ToolCallState( + block_index=0, tool_id="call_x", name="Task", started=True + ) + mgr.tool_states[ + 0 + ].task_arg_buffer = ( + '{"api_key": "sk-live-super-secret-do-not-log"}not_valid_json' + ) + with caplog.at_level("WARNING"): + out = mgr.flush_task_arg_buffers() + assert out == [(0, "{}")] + text = " | ".join(r.message for r in caplog.records) + assert "sk-live-super-secret" not in text + assert "buffer_sha256_prefix=" in text + class TestSSEBuilderMessageLifecycle: """Tests for message_start, message_delta, message_stop.""" diff --git a/tests/providers/test_streaming_errors.py b/tests/providers/test_streaming_errors.py index 0b5f614..085212b 100644 --- a/tests/providers/test_streaming_errors.py +++ b/tests/providers/test_streaming_errors.py @@ -1,14 +1,20 @@ """Tests for streaming error handling in providers/nvidia_nim/client.py.""" import json +from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest from config.nim import NimSettings +from core.anthropic.stream_contracts import ( + assert_anthropic_stream_contract, + parse_sse_text, +) from providers.base import ProviderConfig from providers.nvidia_nim import NvidiaNimProvider +from tests.provider_request_mocks import make_openai_compat_stream_request class AsyncStreamMock: @@ -53,22 +59,7 @@ def _make_provider_with_thinking_enabled(enabled: bool): def _make_request(model="test-model", stream=True): """Create a mock request with all fields build_request_body needs.""" - req = MagicMock() - req.model = model - req.stream = stream - req.messages = [] - req.system = None - req.tools = None - req.tool_choice = None - req.metadata = None - req.max_tokens = 4096 - req.temperature = None - req.top_p = None - req.top_k = None - req.stop_sequences = None - req.extra_body = None - req.thinking = None - return req + return make_openai_compat_stream_request(model=model, stream=stream) def _make_chunk( @@ -95,6 +86,29 @@ async def _collect_stream(provider, request): return [e async for e in provider.stream_response(request)] +def _assert_no_content_deltas_after_error_text( + events: list[str], error_substr: str +) -> None: + """After the error text delta, only block close + message tail events may follow.""" + parsed = parse_sse_text("".join(events)) + first_error_idx = None + for i, ev in enumerate(parsed): + if ev.event != "content_block_delta": + continue + delta = ev.data.get("delta", {}) + if delta.get("type") == "text_delta" and error_substr in str( + delta.get("text", "") + ): + first_error_idx = i + break + assert first_error_idx is not None, (error_substr, "".join(events)) + for ev in parsed[first_error_idx + 1 :]: + assert ev.event in ("content_block_stop", "message_delta", "message_stop"), ( + ev.event, + ev.data, + ) + + class TestStreamingExceptionHandling: """Tests for error paths during stream_response.""" @@ -128,6 +142,7 @@ class TestStreamingExceptionHandling: assert "message_start" in event_text assert "API failed" in event_text assert "message_stop" in event_text + _assert_no_content_deltas_after_error_text(events, "API failed") @pytest.mark.asyncio async def test_read_timeout_with_empty_message_emits_fallback(self): @@ -161,6 +176,7 @@ class TestStreamingExceptionHandling: assert "timed out after" in event_text assert "request_id=req_timeout123" in event_text assert "message_stop" in event_text + _assert_no_content_deltas_after_error_text(events, "timed out after") @pytest.mark.asyncio async def test_error_after_partial_content(self): @@ -191,6 +207,7 @@ class TestStreamingExceptionHandling: assert "Hello" in event_text assert "Connection lost" in event_text assert "message_stop" in event_text + _assert_no_content_deltas_after_error_text(events, "Connection lost") @pytest.mark.asyncio async def test_empty_response_gets_space(self): @@ -353,6 +370,10 @@ class TestStreamingExceptionHandling: in event_text ) assert "request_id=REQ405" in event_text + _assert_no_content_deltas_after_error_text( + events, + "Upstream provider NIM rejected the request method or endpoint (HTTP 405).", + ) @pytest.mark.asyncio async def test_stream_rate_limited_retries_via_execute_with_retry(self): @@ -406,6 +427,59 @@ class TestProcessToolCall: assert "search" in event_text assert "call_123" in event_text + def test_tool_call_id_arrives_before_name_still_emits_id_and_name(self): + """Split-stream tool: id (no name) then name then args; id preserved on start.""" + provider = _make_provider() + from core.anthropic import SSEBuilder + + sse = SSEBuilder("msg_test", "test-model") + t1 = { + "index": 0, + "id": "call_split", + "function": {"name": None, "arguments": ""}, + } + t2 = { + "index": 0, + "id": "call_split", + "function": {"name": "Grep", "arguments": ""}, + } + t3 = { + "index": 0, + "id": "call_split", + "function": {"name": None, "arguments": "{}"}, + } + b1 = "".join(provider._process_tool_call(t1, sse)) + b2 = "".join(provider._process_tool_call(t2, sse)) + b3 = "".join(provider._process_tool_call(t3, sse)) + combined = b1 + b2 + b3 + assert "call_split" in combined + assert "Grep" in combined + assert b1 == "" + + def test_tool_call_arguments_buffered_until_name(self): + """Argument deltas before tool name are emitted after the block starts.""" + provider = _make_provider() + from core.anthropic import SSEBuilder + + sse = SSEBuilder("msg_test", "test-model") + t1 = { + "index": 0, + "id": "call_buf", + "function": {"name": None, "arguments": '{"x":'}, + } + t2 = { + "index": 0, + "id": "call_buf", + "function": {"name": "Read", "arguments": "1}"}, + } + b1 = "".join(provider._process_tool_call(t1, sse)) + b2 = "".join(provider._process_tool_call(t2, sse)) + assert b1 == "" + combined = b2 + assert "Read" in combined + assert "call_buf" in combined + assert '{"x":' in combined or "partial_json" in combined + def test_tool_call_without_id_generates_uuid(self): """Tool call without id generates a uuid-based id.""" provider = _make_provider() @@ -617,6 +691,7 @@ class TestStreamChunkEdgeCases: assert "Partial" in event_text assert "Connection reset" in event_text assert "message_stop" in event_text + _assert_no_content_deltas_after_error_text(events, "Connection reset") def test_stream_malformed_tool_args_chunked(self): """Chunked tool args that never form valid JSON are flushed with {}.""" @@ -642,3 +717,36 @@ class TestStreamChunkEdgeCases: event_text = "".join(events1 + events2 + flushed) assert "tool_use" in event_text assert "{}" in event_text + + +@pytest.mark.asyncio +async def test_openai_compat_stream_ends_with_contract_when_tool_name_never_arrives() -> ( + None +): + """Nameless / incomplete tool-call buffer must not break Anthropic stream contract.""" + provider = _make_provider() + request = _make_request() + tc0 = SimpleNamespace( + index=0, + id="call_inc", + function=SimpleNamespace(name=None, arguments="{}"), + ) + stream_mock = AsyncStreamMock([_make_chunk(tool_calls=[tc0])]) + with ( + patch.object( + provider._client.chat.completions, + "create", + new_callable=AsyncMock, + return_value=stream_mock, + ), + patch.object( + provider._global_rate_limiter, + "wait_if_blocked", + new_callable=AsyncMock, + return_value=False, + ), + ): + events = await _collect_stream(provider, request) + text = "".join(events) + assert_anthropic_stream_contract(parse_sse_text(text)) + assert "text_delta" in text diff --git a/tests/stream_contract.py b/tests/stream_contract.py new file mode 100644 index 0000000..9fdc658 --- /dev/null +++ b/tests/stream_contract.py @@ -0,0 +1,18 @@ +"""Shared assertions for canonical provider streaming error envelopes.""" + +from core.anthropic.stream_contracts import ( + assert_anthropic_stream_contract, + parse_sse_text, + text_content, +) + + +def assert_canonical_stream_error_envelope( + events: list[str], *, user_message_substr: str +) -> None: + """Native transports emit message_start → text error → message_stop.""" + blob = "".join(events) + assert "event: error\ndata:" not in blob + parsed = parse_sse_text(blob) + assert_anthropic_stream_contract(parsed, allow_error=False) + assert user_message_substr in text_content(parsed)