diff --git a/.env.example b/.env.example
index 42d51a3..e9e1990 100644
--- a/.env.example
+++ b/.env.example
@@ -47,8 +47,8 @@ OPENROUTER_PROXY=""
 LMSTUDIO_PROXY=""
 LLAMACPP_PROXY=""
-PROVIDER_RATE_LIMIT=40
-PROVIDER_RATE_WINDOW=60
+PROVIDER_RATE_LIMIT=1
+PROVIDER_RATE_WINDOW=3
 PROVIDER_MAX_CONCURRENCY=5
@@ -102,3 +102,25 @@ ENABLE_NETWORK_PROBE_MOCK=true
 ENABLE_TITLE_GENERATION_SKIP=true
 ENABLE_SUGGESTION_MODE_SKIP=true
 ENABLE_FILEPATH_EXTRACTION_MOCK=true
+
+
+# Local Anthropic web_search / web_fetch handling (performs outbound HTTP; on by default)
+ENABLE_WEB_SERVER_TOOLS=true
+WEB_FETCH_ALLOWED_SCHEMES=http,https
+WEB_FETCH_ALLOW_PRIVATE_NETWORKS=false
+
+
+# Verbose diagnostics (avoid logging raw prompts / SSE bodies in production)
+DEBUG_PLATFORM_EDITS=false
+DEBUG_SUBAGENT_STACK=false
+# When true, also allows DEBUG-level httpx/httpcore/telegram log noise (not just payload logging).
+LOG_RAW_API_PAYLOADS=false
+LOG_RAW_SSE_EVENTS=false
+# When true, log full exception text and tracebacks for unhandled errors (may leak request-derived data).
+LOG_API_ERROR_TRACEBACKS=false
+# When true, log message/transcription text previews in messaging adapters (may leak user content).
+LOG_RAW_MESSAGING_CONTENT=false
+# When true, log full Claude CLI stderr, non-JSON stdout lines, and parser error text.
+LOG_RAW_CLI_DIAGNOSTICS=false
+# When true, log full exception and CLI error message strings in messaging (may leak user content).
+LOG_MESSAGING_ERROR_DETAILS=false
diff --git a/AGENTS.md b/AGENTS.md
index f4335d2..265b3a8 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -32,13 +32,14 @@
 - **Platform-agnostic naming**: Use generic names (e.g. `PLATFORM_EDIT`) not platform-specific ones (e.g. `TELEGRAM_EDIT`) in shared code.
 - **No type ignores**: Do not add `# type: ignore` or `# ty: ignore`. Fix the underlying type issue.
 - **Complete migrations**: When moving modules, update imports to the new owner and remove old compatibility shims in the same change unless preserving a published interface is explicitly required.
+- **Maximum test coverage**: Aim for maximum test coverage of every change, preferably live smoke coverage, so bugs surface early.

 ## COGNITIVE WORKFLOW

 1. **ANALYZE**: Read relevant files. Do not guess.
 2. **PLAN**: Map out the logic. Identify root cause or required changes. Order changes by dependency.
 3. **EXECUTE**: Fix the cause, not the symptom. Execute incrementally with clear commits.
-4. **VERIFY**: Run ci checks. Confirm the fix via logs or output.
+4. **VERIFY**: Run CI checks and relevant smoke tests. Confirm the fix via logs or output.
 5. **SPECIFICITY**: Do exactly as much as asked; nothing more, nothing less.
 6. **PROPAGATION**: Changes impact multiple files; propagate updates correctly.

diff --git a/PLAN.md b/PLAN.md
index 9f4607b..f5011f6 100644
--- a/PLAN.md
+++ b/PLAN.md
@@ -56,13 +56,20 @@ with **runtime composition** (e.g. `api.runtime` constructs `cli` and `messaging`
 **Contract highlights:** `api/` may import only `providers.base`,
 `providers.exceptions`, and `providers.registry` from the providers package
 (not per-adapter modules). `core/` stays free of `api`, `messaging`, `cli`,
 `providers`, `config`, and `smoke`.
-`messaging/` does not import `api`, `cli`, or `providers`. Neutral stream contract
-assertions for default CI live under `core/anthropic/stream_contracts.py`;
-`smoke.lib.sse` re-exports them for live smoke. 
Process-cached provider helpers -(`api.dependencies.get_provider` / `get_provider_for_type`) exist for scripts and -unit tests; production HTTP handlers must use `resolve_provider` with +`messaging/` does not import `api`, `cli`, or `smoke`, and may import `providers` +only via `providers.nvidia_nim.voice` (NVIDIA/Riva offline ASR). Stream contract +helpers live in `core/anthropic/stream_contracts.py`; live smoke imports that +module directly (no dedicated smoke SSE shim). NVIDIA NIM chat tuning uses the +canonical `config.nim.NimSettings` model on `Settings`; `providers.registry` +passes `settings.nim` into `NvidiaNimProvider` without a duplicate schema. +Default upstream base URLs use a single constant per endpoint in +`providers/defaults.py` (e.g. `NVIDIA_NIM_DEFAULT_BASE`). Process-cached provider +helpers (`api.dependencies.get_provider` / `get_provider_for_type`) exist for +scripts and unit tests; production HTTP handlers must use `resolve_provider` with `request.app` so the app-scoped `ProviderRegistry` is used. The `api` package -`__all__` exposes HTTP models and `create_app` only (not those helpers). +`__all__` exposes HTTP models and `create_app` only (not `app`, not those helpers). +`api.app:create_app` is the ASGI factory (e.g. `uvicorn api.app:create_app --factory`); +`server.py` still exposes `server:app` as a module-level instance for convenience. ## Target Boundaries diff --git a/api/__init__.py b/api/__init__.py index 03094ba..5f6a6e4 100644 --- a/api/__init__.py +++ b/api/__init__.py @@ -1,6 +1,6 @@ """API layer for Claude Code Proxy.""" -from .app import app, create_app +from .app import create_app from .models import ( MessagesRequest, MessagesResponse, @@ -13,6 +13,5 @@ __all__ = [ "MessagesResponse", "TokenCountRequest", "TokenCountResponse", - "app", "create_app", ] diff --git a/api/app.py b/api/app.py index b47deb3..5b70133 100644 --- a/api/app.py +++ b/api/app.py @@ -1,6 +1,6 @@ """FastAPI application factory and configuration.""" -import os +import traceback from contextlib import asynccontextmanager from typing import Any @@ -16,13 +16,7 @@ from providers.exceptions import ProviderError from .routes import router from .runtime import AppRuntime - -# Opt-in to future behavior for python-telegram-bot -os.environ["PTB_TIMEDELTA"] = "1" - -# Configure logging first (before any module logs) -_settings = get_settings() -configure_logging(_settings.log_file) +from .validation_log import summarize_request_validation_body @asynccontextmanager @@ -38,6 +32,11 @@ async def lifespan(app: FastAPI): def create_app() -> FastAPI: """Create and configure the FastAPI application.""" + settings = get_settings() + configure_logging( + settings.log_file, verbose_third_party=settings.log_raw_api_payloads + ) + app = FastAPI( title="Claude Code Proxy", version="2.0.0", @@ -57,33 +56,7 @@ def create_app() -> FastAPI: except Exception as e: body = {"_json_error": type(e).__name__} - messages = body.get("messages") if isinstance(body, dict) else None - message_summary: list[dict[str, Any]] = [] - if isinstance(messages, list): - for msg in messages: - if not isinstance(msg, dict): - message_summary.append({"message_kind": type(msg).__name__}) - continue - content = msg.get("content") - item: dict[str, Any] = { - "role": msg.get("role"), - "content_kind": type(content).__name__, - } - if isinstance(content, list): - item["block_types"] = [ - block.get("type", "dict") - if isinstance(block, dict) - else type(block).__name__ - for block in content[:12] - ] - item["block_keys"] = [ - 
sorted(str(key) for key in block)[:12] - for block in content[:5] - if isinstance(block, dict) - ] - elif isinstance(content, str): - item["content_length"] = len(content) - message_summary.append(item) + message_summary, tool_names = summarize_request_validation_body(body) logger.debug( "Request validation failed: path={} query={} error_locs={} error_types={} message_summary={} tool_names={}", @@ -92,20 +65,27 @@ def create_app() -> FastAPI: [list(error.get("loc", ())) for error in exc.errors()], [str(error.get("type", "")) for error in exc.errors()], message_summary, - [ - str(tool.get("name", "")) - for tool in body.get("tools", []) - if isinstance(body, dict) - and isinstance(body.get("tools"), list) - and isinstance(tool, dict) - ], + tool_names, ) return await request_validation_exception_handler(request, exc) @app.exception_handler(ProviderError) async def provider_error_handler(request: Request, exc: ProviderError): """Handle provider-specific errors and return Anthropic format.""" - logger.error(f"Provider Error: {exc.error_type} - {exc.message}") + err_settings = get_settings() + if err_settings.log_api_error_tracebacks: + logger.error( + "Provider Error: error_type={} status_code={} message={}", + exc.error_type, + exc.status_code, + exc.message, + ) + else: + logger.error( + "Provider Error: error_type={} status_code={}", + exc.error_type, + exc.status_code, + ) return JSONResponse( status_code=exc.status_code, content=exc.to_anthropic_format(), @@ -114,10 +94,17 @@ def create_app() -> FastAPI: @app.exception_handler(Exception) async def general_error_handler(request: Request, exc: Exception): """Handle general errors and return Anthropic format.""" - logger.error(f"General Error: {exc!s}") - import traceback - - logger.error(traceback.format_exc()) + settings = get_settings() + if settings.log_api_error_tracebacks: + logger.error("General Error: {}", exc) + logger.error(traceback.format_exc()) + else: + logger.error( + "General Error: path={} method={} exc_type={}", + request.url.path, + request.method, + type(exc).__name__, + ) return JSONResponse( status_code=500, content={ @@ -130,7 +117,3 @@ def create_app() -> FastAPI: ) return app - - -# Default app instance for uvicorn -app = create_app() diff --git a/api/command_utils.py b/api/command_utils.py index ea5251b..76963cd 100644 --- a/api/command_utils.py +++ b/api/command_utils.py @@ -135,5 +135,5 @@ def extract_filepaths_from_command(command: str, output: str) -> str: return "\n" - except Exception: + except ValueError: return "\n" diff --git a/api/dependencies.py b/api/dependencies.py index 3527eb8..c3e5743 100644 --- a/api/dependencies.py +++ b/api/dependencies.py @@ -8,7 +8,11 @@ from config.settings import Settings from config.settings import get_settings as _get_settings from core.anthropic import get_user_facing_error_message from providers.base import BaseProvider -from providers.exceptions import AuthenticationError, UnknownProviderTypeError +from providers.exceptions import ( + AuthenticationError, + ServiceUnavailableError, + UnknownProviderTypeError, +) from providers.registry import PROVIDER_DESCRIPTORS, ProviderRegistry # Process-level cache: only for :func:`get_provider_for_type` / :func:`get_provider` @@ -18,7 +22,7 @@ _providers: dict[str, BaseProvider] = {} def get_settings() -> Settings: - """Get application settings via dependency injection.""" + """Return cached :class:`~config.settings.Settings` (FastAPI-friendly alias).""" return _get_settings() @@ -31,10 +35,9 @@ def resolve_provider( """Resolve a 
provider using the app-scoped registry when ``app`` is set. When ``app`` is not ``None``, the app-owned :attr:`app.state.provider_registry` - is always used. If the registry is missing (e.g. a test app without - :class:`~api.runtime.AppRuntime` startup), a new :class:`ProviderRegistry` - is installed on ``app.state`` so the process cache is never mixed with - per-request app identity. + must exist (installed by :class:`~api.runtime.AppRuntime` during startup). + Callers that construct a bare ``FastAPI`` without lifespan must set + ``app.state.provider_registry`` explicitly. When ``app`` is ``None`` (no HTTP context), uses the process-level :data:`_providers` cache only. @@ -42,8 +45,10 @@ def resolve_provider( if app is not None: reg = getattr(app.state, "provider_registry", None) if reg is None: - reg = ProviderRegistry() - app.state.provider_registry = reg + raise ServiceUnavailableError( + "Provider registry is not configured. Ensure AppRuntime startup ran " + "or assign app.state.provider_registry for test apps." + ) return _resolve_with_registry(reg, provider_type, settings) return _resolve_with_registry(ProviderRegistry(_providers), provider_type, settings) @@ -55,9 +60,10 @@ def _resolve_with_registry( try: provider = registry.get(provider_type, settings) except AuthenticationError as e: - raise HTTPException( - status_code=503, detail=get_user_facing_error_message(e) - ) from e + # Provider :class:`~providers.exceptions.AuthenticationError` messages are + # curated configuration hints (env var names, docs links), not upstream noise. + detail = str(e).strip() or get_user_facing_error_message(e) + raise HTTPException(status_code=503, detail=detail) from e except UnknownProviderTypeError: logger.error( "Unknown provider_type: '{}'. Supported: {}", @@ -73,8 +79,9 @@ def _resolve_with_registry( def get_provider_for_type(provider_type: str) -> BaseProvider: """Get or create a provider in the process-level cache (no ``app``/Request). - For server requests, use :func:`resolve_provider` with the active - :attr:`request.app` so the app-scoped provider registry is used. + HTTP route handlers should call :func:`resolve_provider` with the active + :attr:`request.app` (via :class:`~api.runtime.AppRuntime`) instead of this + process-wide cache. 
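+
+    Illustrative sketch (assumes an ``open_router`` provider is configured)::
+
+        provider = get_provider_for_type("open_router")
+        assert provider is get_provider_for_type("open_router")  # process-cache hit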
""" return resolve_provider(provider_type, app=None, settings=get_settings()) diff --git a/api/detection.py b/api/detection.py index a299151..7977c1b 100644 --- a/api/detection.py +++ b/api/detection.py @@ -65,8 +65,8 @@ def is_prefix_detection_request(request_data: MessagesRequest) -> tuple[bool, st try: cmd_start = content.rfind("Command:") + len("Command:") return True, content[cmd_start:].strip() - except Exception: - pass + except TypeError: + return False, "" return False, "" @@ -121,19 +121,16 @@ def is_filepath_extraction_request( if not user_has_filepaths and not system_has_extract: return False, "", "" - try: - cmd_start = content.find("Command:") + len("Command:") - output_marker = content.find("Output:", cmd_start) - if output_marker == -1: - return False, "", "" - - command = content[cmd_start:output_marker].strip() - output = content[output_marker + len("Output:") :].strip() - - for marker in ["<", "\n\n"]: - if marker in output: - output = output.split(marker)[0].strip() - - return True, command, output - except Exception: + cmd_start = content.find("Command:") + len("Command:") + output_marker = content.find("Output:", cmd_start) + if output_marker == -1: return False, "", "" + + command = content[cmd_start:output_marker].strip() + output = content[output_marker + len("Output:") :].strip() + + for marker in ["<", "\n\n"]: + if marker in output: + output = output.split(marker)[0].strip() + + return True, command, output diff --git a/api/models/anthropic.py b/api/models/anthropic.py index 12b6533..5bda060 100644 --- a/api/models/anthropic.py +++ b/api/models/anthropic.py @@ -3,7 +3,7 @@ from enum import StrEnum from typing import Any, Literal -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict, Field # ============================================================================= @@ -15,41 +15,68 @@ class Role(StrEnum): system = "system" -class ContentBlockText(BaseModel): +class _AnthropicBlockBase(BaseModel): + """Pass through provider fields (e.g. ``cache_control``) for native transports.""" + + model_config = ConfigDict(extra="allow") + + +class ContentBlockText(_AnthropicBlockBase): type: Literal["text"] text: str -class ContentBlockImage(BaseModel): +class ContentBlockImage(_AnthropicBlockBase): type: Literal["image"] source: dict[str, Any] -class ContentBlockToolUse(BaseModel): +class ContentBlockToolUse(_AnthropicBlockBase): type: Literal["tool_use"] id: str name: str input: dict[str, Any] -class ContentBlockToolResult(BaseModel): +class ContentBlockToolResult(_AnthropicBlockBase): type: Literal["tool_result"] tool_use_id: str content: str | list[Any] | dict[str, Any] -class ContentBlockThinking(BaseModel): +class ContentBlockThinking(_AnthropicBlockBase): type: Literal["thinking"] thinking: str signature: str | None = None -class ContentBlockRedactedThinking(BaseModel): +class ContentBlockRedactedThinking(_AnthropicBlockBase): type: Literal["redacted_thinking"] data: str -class SystemContent(BaseModel): +class ContentBlockServerToolUse(_AnthropicBlockBase): + """Anthropic server-side tool invocation (e.g. 
``web_search``, ``web_fetch``).""" + + type: Literal["server_tool_use"] + id: str + name: str + input: dict[str, Any] + + +class ContentBlockWebSearchToolResult(_AnthropicBlockBase): + type: Literal["web_search_tool_result"] + tool_use_id: str + content: Any + + +class ContentBlockWebFetchToolResult(_AnthropicBlockBase): + type: Literal["web_fetch_tool_result"] + tool_use_id: str + content: Any + + +class SystemContent(_AnthropicBlockBase): type: Literal["text"] text: str @@ -68,12 +95,15 @@ class Message(BaseModel): | ContentBlockToolResult | ContentBlockThinking | ContentBlockRedactedThinking + | ContentBlockServerToolUse + | ContentBlockWebSearchToolResult + | ContentBlockWebFetchToolResult ] ) reasoning_content: str | None = None -class Tool(BaseModel): +class Tool(_AnthropicBlockBase): name: str # Anthropic server tools (e.g. web_search beta tools) include a ``type`` and # may omit ``input_schema`` because the provider owns the schema. @@ -92,7 +122,12 @@ class ThinkingConfig(BaseModel): # Request Models # ============================================================================= class MessagesRequest(BaseModel): + model_config = ConfigDict(extra="allow") + model: str + # Internal routing / debug: accepted on parse but not serialized to providers. + original_model: str | None = Field(default=None, exclude=True) + resolved_provider_model: str | None = Field(default=None, exclude=True) max_tokens: int | None = None messages: list[Message] system: str | list[SystemContent] | None = None @@ -105,13 +140,24 @@ class MessagesRequest(BaseModel): tools: list[Tool] | None = None tool_choice: dict[str, Any] | None = None thinking: ThinkingConfig | None = None + # Native Anthropic / SDK client hints: ignored (not forwarded) for OpenAI Chat conversion. 
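+    # Illustrative (hypothetical values): a native-SDK payload such as
+    #   {"model": "...", "messages": [...], "mcp_servers": [...], "output_config": {...}}
+    # still parses; the three fields below are simply dropped when the request is
+    # converted for an OpenAI Chat upstream (nvidia_nim / deepseek).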
+ context_management: dict[str, Any] | None = None + output_config: dict[str, Any] | None = None + mcp_servers: list[dict[str, Any]] | None = None extra_body: dict[str, Any] | None = None class TokenCountRequest(BaseModel): + model_config = ConfigDict(extra="allow") + model: str + original_model: str | None = Field(default=None, exclude=True) + resolved_provider_model: str | None = Field(default=None, exclude=True) messages: list[Message] system: str | list[SystemContent] | None = None tools: list[Tool] | None = None thinking: ThinkingConfig | None = None tool_choice: dict[str, Any] | None = None + context_management: dict[str, Any] | None = None + output_config: dict[str, Any] | None = None + mcp_servers: list[dict[str, Any]] | None = None diff --git a/api/runtime.py b/api/runtime.py index 56cfe07..56df000 100644 --- a/api/runtime.py +++ b/api/runtime.py @@ -23,15 +23,31 @@ _SHUTDOWN_TIMEOUT_S = 5.0 async def best_effort( - name: str, awaitable: Any, timeout_s: float = _SHUTDOWN_TIMEOUT_S + name: str, + awaitable: Any, + timeout_s: float = _SHUTDOWN_TIMEOUT_S, + *, + log_verbose_errors: bool = False, ) -> None: """Run a shutdown step with timeout; never raise to callers.""" try: await asyncio.wait_for(awaitable, timeout=timeout_s) except TimeoutError: - logger.warning(f"Shutdown step timed out: {name} ({timeout_s}s)") + logger.warning("Shutdown step timed out: {} ({}s)", name, timeout_s) except Exception as e: - logger.warning(f"Shutdown step failed: {name}: {type(e).__name__}: {e}") + if log_verbose_errors: + logger.warning( + "Shutdown step failed: {}: {}: {}", + name, + type(e).__name__, + e, + ) + else: + logger.warning( + "Shutdown step failed: {}: exc_type={}", + name, + type(e).__name__, + ) def warn_if_process_auth_token(settings: Settings) -> None: @@ -73,20 +89,37 @@ class AppRuntime: self._publish_state() async def shutdown(self) -> None: + verbose = self.settings.log_api_error_tracebacks if self.message_handler is not None: try: self.message_handler.session_store.flush_pending_save() except Exception as e: - logger.warning(f"Session store flush on shutdown: {e}") + if verbose: + logger.warning("Session store flush on shutdown: {}", e) + else: + logger.warning( + "Session store flush on shutdown: exc_type={}", + type(e).__name__, + ) logger.info("Shutdown requested, cleaning up...") if self.messaging_platform: - await best_effort("messaging_platform.stop", self.messaging_platform.stop()) + await best_effort( + "messaging_platform.stop", + self.messaging_platform.stop(), + log_verbose_errors=verbose, + ) if self.cli_manager: - await best_effort("cli_manager.stop_all", self.cli_manager.stop_all()) + await best_effort( + "cli_manager.stop_all", + self.cli_manager.stop_all(), + log_verbose_errors=verbose, + ) if self._provider_registry is not None: await best_effort( - "provider_registry.cleanup", self._provider_registry.cleanup() + "provider_registry.cleanup", + self._provider_registry.cleanup(), + log_verbose_errors=verbose, ) await self._shutdown_limiter() logger.info("Server shut down cleanly") @@ -110,6 +143,10 @@ class AppRuntime: whisper_device=self.settings.whisper_device, hf_token=self.settings.hf_token, nvidia_nim_api_key=self.settings.nvidia_nim_api_key, + messaging_rate_limit=self.settings.messaging_rate_limit, + messaging_rate_window=self.settings.messaging_rate_window, + log_raw_messaging_content=self.settings.log_raw_messaging_content, + log_api_error_tracebacks=self.settings.log_api_error_tracebacks, ), ) @@ -117,12 +154,24 @@ class AppRuntime: await 
self._start_message_handler() except ImportError as e: - logger.warning(f"Messaging module import error: {e}") + if self.settings.log_api_error_tracebacks: + logger.warning("Messaging module import error: {}", e) + else: + logger.warning( + "Messaging module import error: exc_type={}", + type(e).__name__, + ) except Exception as e: - logger.error(f"Failed to start messaging platform: {e}") - import traceback + if self.settings.log_api_error_tracebacks: + logger.error("Failed to start messaging platform: {}", e) + import traceback - logger.error(traceback.format_exc()) + logger.error(traceback.format_exc()) + else: + logger.error( + "Failed to start messaging platform: exc_type={}", + type(e).__name__, + ) async def _start_message_handler(self) -> None: from cli.manager import CLISessionManager @@ -151,10 +200,13 @@ class AppRuntime: allowed_dirs=allowed_dirs, plans_directory=plans_directory, claude_bin=self.settings.claude_cli_bin, + log_raw_cli_diagnostics=self.settings.log_raw_cli_diagnostics, + log_messaging_error_details=self.settings.log_messaging_error_details, ) session_store = SessionStore( - storage_path=os.path.join(data_path, "sessions.json") + storage_path=os.path.join(data_path, "sessions.json"), + message_log_cap=self.settings.max_message_log_entries_per_chat, ) platform = self.messaging_platform assert platform is not None @@ -162,6 +214,11 @@ class AppRuntime: platform=platform, cli_manager=self.cli_manager, session_store=session_store, + debug_platform_edits=self.settings.debug_platform_edits, + debug_subagent_stack=self.settings.debug_subagent_stack, + log_raw_messaging_content=self.settings.log_raw_messaging_content, + log_raw_cli_diagnostics=self.settings.log_raw_cli_diagnostics, + log_messaging_error_details=self.settings.log_messaging_error_details, ) self._restore_tree_state(session_store) @@ -201,18 +258,26 @@ class AppRuntime: self.app.state.cli_manager = self.cli_manager async def _shutdown_limiter(self) -> None: + verbose = self.settings.log_api_error_tracebacks try: from messaging.limiter import MessagingRateLimiter except Exception as e: - logger.debug( - "Rate limiter shutdown skipped (import failed): {}: {}", - type(e).__name__, - e, - ) + if verbose: + logger.debug( + "Rate limiter shutdown skipped (import failed): {}: {}", + type(e).__name__, + e, + ) + else: + logger.debug( + "Rate limiter shutdown skipped (import failed): exc_type={}", + type(e).__name__, + ) return await best_effort( "MessagingRateLimiter.shutdown_instance", MessagingRateLimiter.shutdown_instance(), timeout_s=2.0, + log_verbose_errors=verbose, ) diff --git a/api/services.py b/api/services.py index d0bbcaa..219287c 100644 --- a/api/services.py +++ b/api/services.py @@ -4,7 +4,7 @@ from __future__ import annotations import traceback import uuid -from collections.abc import Callable +from collections.abc import AsyncIterator, Callable from typing import Any from fastapi import HTTPException @@ -13,6 +13,7 @@ from loguru import logger from config.settings import Settings from core.anthropic import get_token_count, get_user_facing_error_message +from core.anthropic.sse import ANTHROPIC_SSE_RESPONSE_HEADERS from providers.base import BaseProvider from providers.exceptions import InvalidRequestError, ProviderError @@ -20,15 +21,67 @@ from .model_router import ModelRouter from .models.anthropic import MessagesRequest, TokenCountRequest from .models.responses import TokenCountResponse from .optimization_handlers import try_optimizations -from .web_server_tools import ( +from .web_tools.egress 
import WebFetchEgressPolicy +from .web_tools.request import ( is_web_server_tool_request, - stream_web_server_tool_response, + openai_chat_upstream_server_tool_error, ) +from .web_tools.streaming import stream_web_server_tool_response TokenCounter = Callable[[list[Any], str | list[Any] | None, list[Any] | None], int] ProviderGetter = Callable[[str], BaseProvider] +# Providers that use ``/chat/completions`` + Anthropic-to-OpenAI conversion (not native Messages). +_OPENAI_CHAT_UPSTREAM_IDS = frozenset({"nvidia_nim", "deepseek"}) + + +def anthropic_sse_streaming_response( + body: AsyncIterator[str], +) -> StreamingResponse: + """Return a :class:`StreamingResponse` for Anthropic-style SSE streams.""" + return StreamingResponse( + body, + media_type="text/event-stream", + headers=ANTHROPIC_SSE_RESPONSE_HEADERS, + ) + + +def _http_status_for_unexpected_service_exception(_exc: BaseException) -> int: + """HTTP status for uncaught non-provider failures (stable client contract).""" + return 500 + + +def _log_unexpected_service_exception( + settings: Settings, + exc: BaseException, + *, + context: str, + request_id: str | None = None, +) -> None: + """Log service-layer failures without echoing exception text unless opted in.""" + if settings.log_api_error_tracebacks: + if request_id is not None: + logger.error("{} request_id={}: {}", context, request_id, exc) + else: + logger.error("{}: {}", context, exc) + logger.error(traceback.format_exc()) + return + if request_id is not None: + logger.error( + "{} request_id={} exc_type={}", + context, + request_id, + type(exc).__name__, + ) + else: + logger.error("{} exc_type={}", context, type(exc).__name__) + + +def _require_non_empty_messages(messages: list[Any]) -> None: + if not messages: + raise InvalidRequestError("messages cannot be empty") + class ClaudeProxyService: """Coordinate request optimization, model routing, token count, and providers.""" @@ -48,25 +101,35 @@ class ClaudeProxyService: def create_message(self, request_data: MessagesRequest) -> object: """Create a message response or streaming response.""" try: - if not request_data.messages: - raise InvalidRequestError("messages cannot be empty") + _require_non_empty_messages(request_data.messages) routed = self._model_router.resolve_messages_request(request_data) - if is_web_server_tool_request(routed.request): + if routed.resolved.provider_id in _OPENAI_CHAT_UPSTREAM_IDS: + tool_err = openai_chat_upstream_server_tool_error( + routed.request, + web_tools_enabled=self._settings.enable_web_server_tools, + ) + if tool_err is not None: + raise InvalidRequestError(tool_err) + + if self._settings.enable_web_server_tools and is_web_server_tool_request( + routed.request + ): input_tokens = self._token_counter( routed.request.messages, routed.request.system, routed.request.tools ) logger.info("Optimization: Handling Anthropic web server tool") - return StreamingResponse( + egress = WebFetchEgressPolicy( + allow_private_network_targets=self._settings.web_fetch_allow_private_networks, + allowed_schemes=self._settings.web_fetch_allowed_scheme_set(), + ) + return anthropic_sse_streaming_response( stream_web_server_tool_response( - routed.request, input_tokens=input_tokens + routed.request, + input_tokens=input_tokens, + web_fetch_egress=egress, + verbose_client_errors=self._settings.log_api_error_tracebacks, ), - media_type="text/event-stream", - headers={ - "X-Accel-Buffering": "no", - "Cache-Control": "no-cache", - "Connection": "keep-alive", - }, ) optimized = try_optimizations(routed.request, 
self._settings) @@ -75,6 +138,10 @@ class ClaudeProxyService: logger.debug("No optimization matched, routing to provider") provider = self._provider_getter(routed.resolved.provider_id) + provider.preflight_stream( + routed.request, + thinking_enabled=routed.resolved.thinking_enabled, + ) request_id = f"req_{uuid.uuid4().hex[:12]}" logger.info( @@ -83,34 +150,31 @@ class ClaudeProxyService: routed.request.model, len(routed.request.messages), ) - logger.debug( - "FULL_PAYLOAD [{}]: {}", request_id, routed.request.model_dump() - ) + if self._settings.log_raw_api_payloads: + logger.debug( + "FULL_PAYLOAD [{}]: {}", request_id, routed.request.model_dump() + ) input_tokens = self._token_counter( routed.request.messages, routed.request.system, routed.request.tools ) - return StreamingResponse( + return anthropic_sse_streaming_response( provider.stream_response( routed.request, input_tokens=input_tokens, request_id=request_id, thinking_enabled=routed.resolved.thinking_enabled, ), - media_type="text/event-stream", - headers={ - "X-Accel-Buffering": "no", - "Cache-Control": "no-cache", - "Connection": "keep-alive", - }, ) except ProviderError: raise except Exception as e: - logger.error(f"Error: {e!s}\n{traceback.format_exc()}") + _log_unexpected_service_exception( + self._settings, e, context="CREATE_MESSAGE_ERROR" + ) raise HTTPException( - status_code=getattr(e, "status_code", 500), + status_code=_http_status_for_unexpected_service_exception(e), detail=get_user_facing_error_message(e), ) from e @@ -119,6 +183,7 @@ class ClaudeProxyService: request_id = f"req_{uuid.uuid4().hex[:12]}" with logger.contextualize(request_id=request_id): try: + _require_non_empty_messages(request_data.messages) routed = self._model_router.resolve_token_count_request(request_data) tokens = self._token_counter( routed.request.messages, routed.request.system, routed.request.tools @@ -131,13 +196,16 @@ class ClaudeProxyService: tokens, ) return TokenCountResponse(input_tokens=tokens) + except ProviderError: + raise except Exception as e: - logger.error( - "COUNT_TOKENS_ERROR: request_id={} error={}\n{}", - request_id, - get_user_facing_error_message(e), - traceback.format_exc(), + _log_unexpected_service_exception( + self._settings, + e, + context="COUNT_TOKENS_ERROR", + request_id=request_id, ) raise HTTPException( - status_code=500, detail=get_user_facing_error_message(e) + status_code=_http_status_for_unexpected_service_exception(e), + detail=get_user_facing_error_message(e), ) from e diff --git a/api/validation_log.py b/api/validation_log.py new file mode 100644 index 0000000..9ccdff0 --- /dev/null +++ b/api/validation_log.py @@ -0,0 +1,48 @@ +"""Safe metadata summaries for HTTP 422 validation logging (no raw text content).""" + +from __future__ import annotations + +from typing import Any + + +def summarize_request_validation_body( + body: Any, +) -> tuple[list[dict[str, Any]], list[str]]: + """Return message shape summary and tool name list for debug logs.""" + messages = body.get("messages") if isinstance(body, dict) else None + message_summary: list[dict[str, Any]] = [] + if isinstance(messages, list): + for msg in messages: + if not isinstance(msg, dict): + message_summary.append({"message_kind": type(msg).__name__}) + continue + content = msg.get("content") + item: dict[str, Any] = { + "role": msg.get("role"), + "content_kind": type(content).__name__, + } + if isinstance(content, list): + item["block_types"] = [ + block.get("type", "dict") + if isinstance(block, dict) + else type(block).__name__ + for block in 
content[:12] + ] + item["block_keys"] = [ + sorted(str(key) for key in block)[:12] + for block in content[:5] + if isinstance(block, dict) + ] + elif isinstance(content, str): + item["content_length"] = len(content) + message_summary.append(item) + + tool_names: list[str] = [] + if isinstance(body, dict) and isinstance(body.get("tools"), list): + tool_names = [ + str(tool.get("name", "")) + for tool in body["tools"] + if isinstance(tool, dict) + ] + + return message_summary, tool_names diff --git a/api/web_server_tools.py b/api/web_server_tools.py index 5daced4..cedaf95 100644 --- a/api/web_server_tools.py +++ b/api/web_server_tools.py @@ -1,331 +1,22 @@ -"""Local handlers for Anthropic web server tools. - -OpenAI-compatible upstreams can emit regular function calls, but Anthropic's -web tools are server-side: the API response itself must include the tool result. -""" +"""Compatibility re-exports for :mod:`api.web_tools` (web_search / web_fetch).""" from __future__ import annotations -import html -import json -import re -import uuid -from collections.abc import AsyncIterator -from datetime import UTC, datetime -from html.parser import HTMLParser -from typing import Any -from urllib.parse import parse_qs, unquote, urlparse - import httpx -from .models.anthropic import MessagesRequest +from api.web_tools.egress import ( + WebFetchEgressPolicy, + WebFetchEgressViolation, + enforce_web_fetch_egress, +) +from api.web_tools.request import is_web_server_tool_request +from api.web_tools.streaming import stream_web_server_tool_response -_REQUEST_TIMEOUT_S = 20.0 -_MAX_SEARCH_RESULTS = 10 -_MAX_FETCH_CHARS = 24_000 - - -class _SearchResultParser(HTMLParser): - def __init__(self) -> None: - super().__init__() - self.results: list[dict[str, str]] = [] - self._href: str | None = None - self._title_parts: list[str] = [] - - def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None: - if tag != "a": - return - href = dict(attrs).get("href") - if not href or "uddg=" not in href: - return - parsed = urlparse(href) - query = parse_qs(parsed.query) - uddg = query.get("uddg", [""])[0] - if not uddg: - return - self._href = unquote(uddg) - self._title_parts = [] - - def handle_data(self, data: str) -> None: - if self._href is not None: - self._title_parts.append(data) - - def handle_endtag(self, tag: str) -> None: - if tag != "a" or self._href is None: - return - title = " ".join("".join(self._title_parts).split()) - if title and not any(result["url"] == self._href for result in self.results): - self.results.append({"title": html.unescape(title), "url": self._href}) - self._href = None - self._title_parts = [] - - -class _HTMLTextParser(HTMLParser): - def __init__(self) -> None: - super().__init__() - self.title = "" - self.text_parts: list[str] = [] - self._in_title = False - self._skip_depth = 0 - - def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None: - if tag in {"script", "style", "noscript"}: - self._skip_depth += 1 - elif tag == "title": - self._in_title = True - - def handle_endtag(self, tag: str) -> None: - if tag in {"script", "style", "noscript"} and self._skip_depth: - self._skip_depth -= 1 - elif tag == "title": - self._in_title = False - - def handle_data(self, data: str) -> None: - text = " ".join(data.split()) - if not text: - return - if self._in_title: - self.title = f"{self.title} {text}".strip() - elif not self._skip_depth: - self.text_parts.append(text) - - -def _format_event(event_type: str, data: dict[str, Any]) -> str: - return f"event: 
{event_type}\ndata: {json.dumps(data)}\n\n" - - -def _content_text(content: Any) -> str: - if isinstance(content, str): - return content - if isinstance(content, list): - parts = [] - for item in content: - if isinstance(item, dict): - parts.append(str(item.get("text", ""))) - else: - parts.append(str(getattr(item, "text", ""))) - return "\n".join(part for part in parts if part) - return str(content) - - -def _request_text(request: MessagesRequest) -> str: - return "\n".join(_content_text(message.content) for message in request.messages) - - -def _web_tool_name(request: MessagesRequest) -> str | None: - for tool in request.tools or []: - name = tool.name - tool_type = tool.type or "" - if name in {"web_search", "web_fetch"} or tool_type.startswith("web_"): - return name - return None - - -def is_web_server_tool_request(request: MessagesRequest) -> bool: - return _web_tool_name(request) in {"web_search", "web_fetch"} - - -def _extract_query(text: str) -> str: - match = re.search(r"query:\s*(.+)", text, flags=re.IGNORECASE | re.DOTALL) - if match: - return match.group(1).strip().strip("\"'") - return text.strip() - - -def _extract_url(text: str) -> str: - match = re.search(r"https?://\S+", text) - return match.group(0).rstrip(").,]") if match else text.strip() - - -async def _run_web_search(query: str) -> list[dict[str, str]]: - async with httpx.AsyncClient( - timeout=_REQUEST_TIMEOUT_S, - follow_redirects=True, - headers={"User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0"}, - ) as client: - response = await client.get( - "https://lite.duckduckgo.com/lite/", - params={"q": query}, - ) - response.raise_for_status() - - parser = _SearchResultParser() - parser.feed(response.text) - return parser.results[:_MAX_SEARCH_RESULTS] - - -async def _run_web_fetch(url: str) -> dict[str, str]: - async with httpx.AsyncClient( - timeout=_REQUEST_TIMEOUT_S, - follow_redirects=True, - headers={"User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0"}, - ) as client: - response = await client.get(url) - response.raise_for_status() - - content_type = response.headers.get("content-type", "text/plain") - title = url - data = response.text - if "html" in content_type.lower(): - parser = _HTMLTextParser() - parser.feed(response.text) - title = parser.title or url - data = "\n".join(parser.text_parts) - return { - "url": str(response.url), - "title": title, - "media_type": "text/plain", - "data": data[:_MAX_FETCH_CHARS], - } - - -def _search_summary(query: str, results: list[dict[str, str]]) -> str: - if not results: - return f"No web search results found for: {query}" - lines = [f"Search results for: {query}"] - for index, result in enumerate(results, start=1): - lines.append(f"{index}. 
{result['title']}\n{result['url']}") - return "\n\n".join(lines) - - -async def stream_web_server_tool_response( - request: MessagesRequest, input_tokens: int -) -> AsyncIterator[str]: - tool_name = _web_tool_name(request) - if tool_name is None: - return - - text = _request_text(request) - message_id = f"msg_{uuid.uuid4()}" - tool_id = f"srvtoolu_{uuid.uuid4().hex}" - output_tokens = 1 - usage_key = ( - "web_search_requests" if tool_name == "web_search" else "web_fetch_requests" - ) - tool_input = ( - {"query": _extract_query(text)} - if tool_name == "web_search" - else {"url": _extract_url(text)} - ) - - yield _format_event( - "message_start", - { - "type": "message_start", - "message": { - "id": message_id, - "type": "message", - "role": "assistant", - "content": [], - "model": request.model, - "stop_reason": None, - "stop_sequence": None, - "usage": {"input_tokens": input_tokens, "output_tokens": 1}, - }, - }, - ) - yield _format_event( - "content_block_start", - { - "type": "content_block_start", - "index": 0, - "content_block": { - "type": "server_tool_use", - "id": tool_id, - "name": tool_name, - "input": tool_input, - }, - }, - ) - yield _format_event( - "content_block_stop", {"type": "content_block_stop", "index": 0} - ) - - try: - if tool_name == "web_search": - query = str(tool_input["query"]) - results = await _run_web_search(query) - result_content: Any = [ - { - "type": "web_search_result", - "title": result["title"], - "url": result["url"], - } - for result in results - ] - summary = _search_summary(query, results) - result_block_type = "web_search_tool_result" - else: - fetched = await _run_web_fetch(str(tool_input["url"])) - result_content = { - "type": "web_fetch_result", - "url": fetched["url"], - "content": { - "type": "document", - "source": { - "type": "text", - "media_type": fetched["media_type"], - "data": fetched["data"], - }, - "title": fetched["title"], - "citations": {"enabled": True}, - }, - "retrieved_at": datetime.now(UTC).isoformat(), - } - summary = fetched["data"][:_MAX_FETCH_CHARS] - result_block_type = "web_fetch_tool_result" - except Exception as error: - result_block_type = ( - "web_search_tool_result" - if tool_name == "web_search" - else "web_fetch_tool_result" - ) - error_type = ( - "web_search_tool_result_error" - if tool_name == "web_search" - else "web_fetch_tool_error" - ) - result_content = {"type": error_type, "error_code": "unavailable"} - summary = f"{tool_name} failed: {type(error).__name__}" - - output_tokens = max(1, len(summary) // 4) - - yield _format_event( - "content_block_start", - { - "type": "content_block_start", - "index": 1, - "content_block": { - "type": result_block_type, - "tool_use_id": tool_id, - "content": result_content, - }, - }, - ) - yield _format_event( - "content_block_stop", {"type": "content_block_stop", "index": 1} - ) - yield _format_event( - "content_block_start", - { - "type": "content_block_start", - "index": 2, - "content_block": {"type": "text", "text": summary}, - }, - ) - yield _format_event( - "content_block_stop", {"type": "content_block_stop", "index": 2} - ) - yield _format_event( - "message_delta", - { - "type": "message_delta", - "delta": {"stop_reason": "end_turn", "stop_sequence": None}, - "usage": { - "input_tokens": input_tokens, - "output_tokens": output_tokens, - "server_tool_use": {usage_key: 1}, - }, - }, - ) - yield _format_event("message_stop", {"type": "message_stop"}) +__all__ = [ + "WebFetchEgressPolicy", + "WebFetchEgressViolation", + "enforce_web_fetch_egress", + "httpx", + 
"is_web_server_tool_request", + "stream_web_server_tool_response", +] diff --git a/api/web_tools/__init__.py b/api/web_tools/__init__.py new file mode 100644 index 0000000..e0fd14c --- /dev/null +++ b/api/web_tools/__init__.py @@ -0,0 +1,17 @@ +"""Submodules for Anthropic web server tool handling (search/fetch, egress, streaming).""" + +from .egress import ( + WebFetchEgressPolicy, + WebFetchEgressViolation, + enforce_web_fetch_egress, +) +from .request import is_web_server_tool_request +from .streaming import stream_web_server_tool_response + +__all__ = [ + "WebFetchEgressPolicy", + "WebFetchEgressViolation", + "enforce_web_fetch_egress", + "is_web_server_tool_request", + "stream_web_server_tool_response", +] diff --git a/api/web_tools/constants.py b/api/web_tools/constants.py new file mode 100644 index 0000000..e7b2c01 --- /dev/null +++ b/api/web_tools/constants.py @@ -0,0 +1,15 @@ +"""Limits and defaults for outbound web server tool HTTP.""" + +_REQUEST_TIMEOUT_S = 20.0 +_MAX_SEARCH_RESULTS = 10 +_MAX_FETCH_CHARS = 24_000 +# Hard cap on raw bytes read from HTTP responses before decode / HTML parse (memory bound). +_MAX_WEB_FETCH_RESPONSE_BYTES = 2 * 1024 * 1024 +# Drain at most this many bytes from redirect responses before following Location. +_REDIRECT_RESPONSE_BODY_CAP_BYTES = 65_536 +_MAX_WEB_FETCH_REDIRECTS = 10 +_WEB_FETCH_REDIRECT_STATUSES = frozenset({301, 302, 303, 307, 308}) + +_WEB_TOOL_HTTP_HEADERS = { + "User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0", +} diff --git a/api/web_tools/egress.py b/api/web_tools/egress.py new file mode 100644 index 0000000..30b29d5 --- /dev/null +++ b/api/web_tools/egress.py @@ -0,0 +1,99 @@ +"""Egress policy for user-controlled web_fetch URLs (SSRF guard).""" + +from __future__ import annotations + +import ipaddress +import socket +from dataclasses import dataclass +from urllib.parse import urlparse + + +@dataclass(frozen=True, slots=True) +class WebFetchEgressPolicy: + """Egress rules for user-influenced web_fetch URLs.""" + + allow_private_network_targets: bool + allowed_schemes: frozenset[str] + + +class WebFetchEgressViolation(ValueError): + """Raised when a web_fetch URL is rejected by egress policy (SSRF guard).""" + + +def _port_for_url(parsed) -> int: + if parsed.port is not None: + return parsed.port + return 443 if (parsed.scheme or "").lower() == "https" else 80 + + +def _stream_getaddrinfo_or_raise(host: str, port: int) -> list[tuple]: + try: + return socket.getaddrinfo( + host, port, type=socket.SOCK_STREAM, proto=socket.IPPROTO_TCP + ) + except OSError as exc: + raise WebFetchEgressViolation( + f"Could not resolve host {host!r}: {exc}" + ) from exc + + +def get_validated_stream_addrinfos_for_egress( + url: str, policy: WebFetchEgressPolicy +) -> list[tuple]: + """Resolve and validate a URL for web_fetch, returning getaddrinfo rows for pinning. + + Each HTTP connect pins to only these `getaddrinfo` results so a malicious DNS + server cannot rebind to a disallowed address between resolution and the TCP + connect (used by :func:`api.web_tools.outbound._run_web_fetch`). 
+ """ + parsed = urlparse(url) + scheme = (parsed.scheme or "").lower() + if scheme not in policy.allowed_schemes: + raise WebFetchEgressViolation( + f"URL scheme {scheme!r} is not allowed for web_fetch" + ) + + host = parsed.hostname + if host is None or host == "": + raise WebFetchEgressViolation("web_fetch URL must include a host") + + port = _port_for_url(parsed) + + if policy.allow_private_network_targets: + return _stream_getaddrinfo_or_raise(host, port) + + host_lower = host.lower() + if host_lower == "localhost" or host_lower.endswith(".localhost"): + raise WebFetchEgressViolation("localhost targets are not allowed for web_fetch") + if host_lower.endswith(".local"): + raise WebFetchEgressViolation(".local hostnames are not allowed for web_fetch") + + try: + parsed_ip = ipaddress.ip_address(host) + except ValueError: + parsed_ip = None + + if parsed_ip is not None: + if not parsed_ip.is_global: + raise WebFetchEgressViolation( + f"Non-public IP host {host!r} is not allowed for web_fetch" + ) + return _stream_getaddrinfo_or_raise(host, port) + + infos = _stream_getaddrinfo_or_raise(host, port) + for *_, sockaddr in infos: + addr = sockaddr[0] + try: + resolved = ipaddress.ip_address(addr) + except ValueError: + continue + if not resolved.is_global: + raise WebFetchEgressViolation( + f"Host {host!r} resolves to a non-public address ({resolved})" + ) + return infos + + +def enforce_web_fetch_egress(url: str, policy: WebFetchEgressPolicy) -> None: + """Validate ``url`` (scheme, host, and resolved addresses) for web_fetch.""" + get_validated_stream_addrinfos_for_egress(url, policy) diff --git a/api/web_tools/outbound.py b/api/web_tools/outbound.py new file mode 100644 index 0000000..4f93f9b --- /dev/null +++ b/api/web_tools/outbound.py @@ -0,0 +1,278 @@ +"""Outbound HTTP for web_search / web_fetch (client, body caps, logging).""" + +from __future__ import annotations + +import asyncio +import socket +from collections.abc import AsyncIterator +from urllib.parse import urljoin, urlparse + +import aiohttp +import httpx +from aiohttp import ClientSession, ClientTimeout, TCPConnector +from aiohttp.abc import AbstractResolver, ResolveResult +from loguru import logger + +from . 
import constants +from .constants import ( + _MAX_FETCH_CHARS, + _MAX_SEARCH_RESULTS, + _REDIRECT_RESPONSE_BODY_CAP_BYTES, + _REQUEST_TIMEOUT_S, + _WEB_FETCH_REDIRECT_STATUSES, + _WEB_TOOL_HTTP_HEADERS, +) +from .egress import ( + WebFetchEgressPolicy, + WebFetchEgressViolation, + get_validated_stream_addrinfos_for_egress, +) +from .parsers import HTMLTextParser, SearchResultParser + + +def _safe_public_host_for_logs(url: str) -> str: + host = urlparse(url).hostname or "" + return host[:253] + + +def _log_web_tool_failure( + tool_name: str, + error: BaseException, + *, + fetch_url: str | None = None, +) -> None: + exc_type = type(error).__name__ + if isinstance(error, WebFetchEgressViolation): + host = _safe_public_host_for_logs(fetch_url) if fetch_url else "" + logger.warning( + "web_tool_egress_rejected tool={} exc_type={} host={!r}", + tool_name, + exc_type, + host, + ) + return + if tool_name == "web_fetch" and fetch_url: + logger.warning( + "web_tool_failure tool={} exc_type={} host={!r}", + tool_name, + exc_type, + _safe_public_host_for_logs(fetch_url), + ) + else: + logger.warning("web_tool_failure tool={} exc_type={}", tool_name, exc_type) + + +def _web_tool_client_error_summary( + tool_name: str, + error: BaseException, + *, + verbose: bool, +) -> str: + if verbose: + return f"{tool_name} failed: {type(error).__name__}" + return "Web tool request failed." + + +async def _iter_response_body_under_cap( + response: httpx.Response, max_bytes: int +) -> AsyncIterator[bytes]: + if max_bytes <= 0: + return + received = 0 + async for chunk in response.aiter_bytes(chunk_size=65_536): + if received >= max_bytes: + break + remaining = max_bytes - received + if len(chunk) <= remaining: + received += len(chunk) + yield chunk + if received >= max_bytes: + break + else: + yield chunk[:remaining] + break + + +async def _drain_response_body_capped(response: httpx.Response, max_bytes: int) -> None: + async for _ in _iter_response_body_under_cap(response, max_bytes): + pass + + +async def _read_response_body_capped(response: httpx.Response, max_bytes: int) -> bytes: + return b"".join( + [piece async for piece in _iter_response_body_under_cap(response, max_bytes)] + ) + + +_NUMERIC_RESOLVE_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV +_NAME_RESOLVE_FLAGS = socket.NI_NUMERICHOST | socket.NI_NUMERICSERV + + +def getaddrinfo_rows_to_resolve_results( + host: str, addrinfos: list[tuple] +) -> list[ResolveResult]: + """Map :func:`socket.getaddrinfo` rows to aiohttp :class:`ResolveResult` (ThreadedResolver logic).""" + out: list[ResolveResult] = [] + for family, _type, proto, _canon, sockaddr in addrinfos: + if family == socket.AF_INET6: + if len(sockaddr) < 3: + continue + if sockaddr[3]: + resolved_host, port = socket.getnameinfo(sockaddr, _NAME_RESOLVE_FLAGS) + else: + resolved_host, port = sockaddr[:2] + else: + assert family == socket.AF_INET, family + resolved_host, port = sockaddr[0], sockaddr[1] + resolved_host = str(resolved_host) + port = int(port) + out.append( + ResolveResult( + hostname=host, + host=resolved_host, + port=int(port), + family=family, + proto=proto, + flags=_NUMERIC_RESOLVE_FLAGS, + ) + ) + return out + + +class _PinnedEgressStaticResolver(AbstractResolver): + """Return only pre-validated :class:`ResolveResult` for the outbound request.""" + + def __init__(self, results: list[ResolveResult]) -> None: + self._results = results + + async def resolve( + self, host: str, port: int = 0, family: int = socket.AF_INET + ) -> list[ResolveResult]: + return self._results + + async 
def close(self) -> None: # pragma: no cover - aiohttp contract + return + + +async def _read_aiohttp_body_capped( + response: aiohttp.ClientResponse, max_bytes: int +) -> bytes: + received = 0 + parts: list[bytes] = [] + async for chunk in response.content.iter_chunked(65_536): + if received >= max_bytes: + break + remaining = max_bytes - received + if len(chunk) <= remaining: + received += len(chunk) + parts.append(chunk) + else: + parts.append(chunk[:remaining]) + break + return b"".join(parts) + + +async def _drain_aiohttp_body_capped( + response: aiohttp.ClientResponse, max_bytes: int +) -> None: + if max_bytes <= 0: + return + received = 0 + async for chunk in response.content.iter_chunked(65_536): + received += len(chunk) + if received >= max_bytes: + break + + +async def _run_web_search(query: str) -> list[dict[str, str]]: + async with ( + httpx.AsyncClient( + timeout=_REQUEST_TIMEOUT_S, + follow_redirects=True, + headers=_WEB_TOOL_HTTP_HEADERS, + ) as client, + client.stream( + "GET", + "https://lite.duckduckgo.com/lite/", + params={"q": query}, + ) as response, + ): + response.raise_for_status() + body_bytes = await _read_response_body_capped( + response, constants._MAX_WEB_FETCH_RESPONSE_BYTES + ) + text = body_bytes.decode("utf-8", errors="replace") + parser = SearchResultParser() + parser.feed(text) + return parser.results[:_MAX_SEARCH_RESULTS] + + +async def _run_web_fetch(url: str, egress: WebFetchEgressPolicy) -> dict[str, str]: + """Fetch URL with manual redirects; each hop is DNS-pinned to validated addresses.""" + current_url = url + redirect_hops = 0 + timeout = ClientTimeout(total=_REQUEST_TIMEOUT_S) + + while True: + addr_infos = await asyncio.to_thread( + get_validated_stream_addrinfos_for_egress, current_url, egress + ) + host = urlparse(current_url).hostname or "" + results = getaddrinfo_rows_to_resolve_results(host, addr_infos) + resolver = _PinnedEgressStaticResolver(results) + connector = TCPConnector( + resolver=resolver, + force_close=True, + ) + try: + async with ( + ClientSession( + timeout=timeout, + headers=_WEB_TOOL_HTTP_HEADERS, + connector=connector, + ) as session, + session.get(current_url, allow_redirects=False) as response, + ): + if response.status in _WEB_FETCH_REDIRECT_STATUSES: + await _drain_aiohttp_body_capped( + response, _REDIRECT_RESPONSE_BODY_CAP_BYTES + ) + if redirect_hops >= constants._MAX_WEB_FETCH_REDIRECTS: + raise WebFetchEgressViolation( + "web_fetch exceeded maximum redirects " + f"({constants._MAX_WEB_FETCH_REDIRECTS})" + ) + location = response.headers.get("location") + if not location or not location.strip(): + raise WebFetchEgressViolation( + "web_fetch redirect response missing Location header" + ) + current_url = urljoin(str(response.url), location.strip()) + redirect_hops += 1 + continue + response.raise_for_status() + content_type = response.headers.get("content-type", "text/plain") + final_url = str(response.url) + encoding = response.get_encoding() or "utf-8" + body_bytes = await _read_aiohttp_body_capped( + response, constants._MAX_WEB_FETCH_RESPONSE_BYTES + ) + finally: + await connector.close() + + break + + text = body_bytes.decode(encoding, errors="replace") + title = final_url + data = text + if "html" in content_type.lower(): + parser = HTMLTextParser() + parser.feed(text) + title = parser.title or final_url + data = "\n".join(parser.text_parts) + return { + "url": final_url, + "title": title, + "media_type": "text/plain", + "data": data[:_MAX_FETCH_CHARS], + } diff --git a/api/web_tools/parsers.py 
b/api/web_tools/parsers.py new file mode 100644 index 0000000..198b412 --- /dev/null +++ b/api/web_tools/parsers.py @@ -0,0 +1,104 @@ +"""HTML parsing for web_search / web_fetch.""" + +from __future__ import annotations + +import html +import re +from html.parser import HTMLParser +from typing import Any +from urllib.parse import parse_qs, unquote, urlparse + + +class SearchResultParser(HTMLParser): + """DuckDuckGo lite HTML: extract result links and titles.""" + + def __init__(self) -> None: + super().__init__() + self.results: list[dict[str, str]] = [] + self._href: str | None = None + self._title_parts: list[str] = [] + + def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None: + if tag != "a": + return + href = dict(attrs).get("href") + if not href or "uddg=" not in href: + return + parsed = urlparse(href) + query = parse_qs(parsed.query) + uddg = query.get("uddg", [""])[0] + if not uddg: + return + self._href = unquote(uddg) + self._title_parts = [] + + def handle_data(self, data: str) -> None: + if self._href is not None: + self._title_parts.append(data) + + def handle_endtag(self, tag: str) -> None: + if tag != "a" or self._href is None: + return + title = " ".join("".join(self._title_parts).split()) + if title and not any(result["url"] == self._href for result in self.results): + self.results.append({"title": html.unescape(title), "url": self._href}) + self._href = None + self._title_parts = [] + + +class HTMLTextParser(HTMLParser): + """Strip scripts/styles and collect visible text + title for fetch previews.""" + + def __init__(self) -> None: + super().__init__() + self.title = "" + self.text_parts: list[str] = [] + self._in_title = False + self._skip_depth = 0 + + def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None: + if tag in {"script", "style", "noscript"}: + self._skip_depth += 1 + elif tag == "title": + self._in_title = True + + def handle_endtag(self, tag: str) -> None: + if tag in {"script", "style", "noscript"} and self._skip_depth: + self._skip_depth -= 1 + elif tag == "title": + self._in_title = False + + def handle_data(self, data: str) -> None: + text = " ".join(data.split()) + if not text: + return + if self._in_title: + self.title = f"{self.title} {text}".strip() + elif not self._skip_depth: + self.text_parts.append(text) + + +def content_text(content: Any) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [] + for item in content: + if isinstance(item, dict): + parts.append(str(item.get("text", ""))) + else: + parts.append(str(getattr(item, "text", ""))) + return "\n".join(part for part in parts if part) + return str(content) + + +def extract_query(text: str) -> str: + match = re.search(r"query:\s*(.+)", text, flags=re.IGNORECASE | re.DOTALL) + if match: + return match.group(1).strip().strip("\"'") + return text.strip() + + +def extract_url(text: str) -> str: + match = re.search(r"https?://\S+", text) + return match.group(0).rstrip(").,]") if match else text.strip() diff --git a/api/web_tools/request.py b/api/web_tools/request.py new file mode 100644 index 0000000..b794685 --- /dev/null +++ b/api/web_tools/request.py @@ -0,0 +1,87 @@ +"""Detect forced Anthropic web server tool requests.""" + +from __future__ import annotations + +from api.models.anthropic import MessagesRequest, Tool + + +def request_text(request: MessagesRequest) -> str: + """Join all user/assistant message content into one string for tool input parsing.""" + from .parsers import content_text + + 
return "\n".join(content_text(message.content) for message in request.messages) + + +def forced_tool_turn_text(request: MessagesRequest) -> str: + """Text for parsing forced server-tool inputs: latest user turn only (avoids stale history).""" + if not request.messages: + return "" + + from .parsers import content_text + + for message in reversed(request.messages): + if message.role == "user": + return content_text(message.content) + return "" + + +def forced_server_tool_name(request: MessagesRequest) -> str | None: + """Return web_search or web_fetch only when tool_choice forces that server tool.""" + tc = request.tool_choice + if not isinstance(tc, dict): + return None + if tc.get("type") != "tool": + return None + name = tc.get("name") + if name in {"web_search", "web_fetch"}: + return str(name) + return None + + +def has_tool_named(request: MessagesRequest, name: str) -> bool: + return any(tool.name == name for tool in request.tools or []) + + +def is_web_server_tool_request(request: MessagesRequest) -> bool: + """True when the client forces a web server tool via tool_choice (not merely listed).""" + forced = forced_server_tool_name(request) + if forced is None: + return False + return has_tool_named(request, forced) + + +def is_anthropic_server_tool_definition(tool: Tool) -> bool: + """Whether ``tool`` refers to an Anthropic server tool (web_search / web_fetch family).""" + name = (tool.name or "").strip() + if name in ("web_search", "web_fetch"): + return True + typ = tool.type + if isinstance(typ, str): + return typ.startswith("web_search") or typ.startswith("web_fetch") + return False + + +def has_listed_anthropic_server_tools(request: MessagesRequest) -> bool: + """True when tools include web_search / web_fetch-style entries (listed, forced or not).""" + return any(is_anthropic_server_tool_definition(t) for t in (request.tools or [])) + + +def openai_chat_upstream_server_tool_error( + request: MessagesRequest, *, web_tools_enabled: bool +) -> str | None: + """Return a user-facing error when OpenAI Chat upstream cannot satisfy server-tool semantics.""" + forced = forced_server_tool_name(request) + if forced and not web_tools_enabled: + return ( + f"tool_choice forces Anthropic server tool {forced!r}, but local web server tools are " + "disabled (ENABLE_WEB_SERVER_TOOLS=false). Enable them or use a native Anthropic " + "Messages transport (e.g. open_router, ollama, lmstudio)." + ) + if not forced and has_listed_anthropic_server_tools(request): + return ( + "OpenAI Chat upstreams (NVIDIA NIM, DeepSeek) cannot use listed Anthropic server tools " + "(web_search / web_fetch) without the local web server tool handler. Use a native " + "Anthropic transport, set ENABLE_WEB_SERVER_TOOLS=true and force the tool with " + "tool_choice, or remove these tools from the request." + ) + return None diff --git a/api/web_tools/streaming.py b/api/web_tools/streaming.py new file mode 100644 index 0000000..eb9b414 --- /dev/null +++ b/api/web_tools/streaming.py @@ -0,0 +1,206 @@ +"""SSE streaming for local web_search / web_fetch server tool results.""" + +from __future__ import annotations + +import uuid +from collections.abc import AsyncIterator +from datetime import UTC, datetime +from typing import Any + +from api.models.anthropic import MessagesRequest +from core.anthropic.server_tool_sse import ( + SERVER_TOOL_USE, + WEB_FETCH_TOOL_ERROR, + WEB_FETCH_TOOL_RESULT, + WEB_SEARCH_TOOL_RESULT, + WEB_SEARCH_TOOL_RESULT_ERROR, +) +from core.anthropic.sse import format_sse_event + +from . 
import outbound +from .constants import _MAX_FETCH_CHARS +from .egress import WebFetchEgressPolicy +from .parsers import extract_query, extract_url +from .request import ( + forced_server_tool_name, + forced_tool_turn_text, + has_tool_named, +) + + +def _search_summary(query: str, results: list[dict[str, str]]) -> str: + if not results: + return f"No web search results found for: {query}" + lines = [f"Search results for: {query}"] + for index, result in enumerate(results, start=1): + lines.append(f"{index}. {result['title']}\n{result['url']}") + return "\n\n".join(lines) + + +async def stream_web_server_tool_response( + request: MessagesRequest, + input_tokens: int, + *, + web_fetch_egress: WebFetchEgressPolicy, + verbose_client_errors: bool = False, +) -> AsyncIterator[str]: + """Stream a minimal Anthropic-shaped turn for forced `web_search` / `web_fetch` (local fallback). + + When `ENABLE_WEB_SERVER_TOOLS` is on, this is a proxy-side execution path — not a full + hosted Anthropic citation or encrypted-content pipeline. + """ + tool_name = forced_server_tool_name(request) + if tool_name is None or not has_tool_named(request, tool_name): + return + + text = forced_tool_turn_text(request) + message_id = f"msg_{uuid.uuid4()}" + tool_id = f"srvtoolu_{uuid.uuid4().hex}" + usage_key = ( + "web_search_requests" if tool_name == "web_search" else "web_fetch_requests" + ) + tool_input = ( + {"query": extract_query(text)} + if tool_name == "web_search" + else {"url": extract_url(text)} + ) + _result_block_for_tool = { + "web_search": WEB_SEARCH_TOOL_RESULT, + "web_fetch": WEB_FETCH_TOOL_RESULT, + } + _error_payload_type_for_tool = { + "web_search": WEB_SEARCH_TOOL_RESULT_ERROR, + "web_fetch": WEB_FETCH_TOOL_ERROR, + } + + yield format_sse_event( + "message_start", + { + "type": "message_start", + "message": { + "id": message_id, + "type": "message", + "role": "assistant", + "content": [], + "model": request.model, + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": input_tokens, "output_tokens": 1}, + }, + }, + ) + yield format_sse_event( + "content_block_start", + { + "type": "content_block_start", + "index": 0, + "content_block": { + "type": SERVER_TOOL_USE, + "id": tool_id, + "name": tool_name, + "input": tool_input, + }, + }, + ) + yield format_sse_event( + "content_block_stop", {"type": "content_block_stop", "index": 0} + ) + + try: + if tool_name == "web_search": + query = str(tool_input["query"]) + results = await outbound._run_web_search(query) + result_content: Any = [ + { + "type": "web_search_result", + "title": result["title"], + "url": result["url"], + } + for result in results + ] + summary = _search_summary(query, results) + result_block_type = WEB_SEARCH_TOOL_RESULT + else: + fetched = await outbound._run_web_fetch( + str(tool_input["url"]), web_fetch_egress + ) + result_content = { + "type": "web_fetch_result", + "url": fetched["url"], + "content": { + "type": "document", + "source": { + "type": "text", + "media_type": fetched["media_type"], + "data": fetched["data"], + }, + "title": fetched["title"], + "citations": {"enabled": True}, + }, + "retrieved_at": datetime.now(UTC).isoformat(), + } + summary = fetched["data"][:_MAX_FETCH_CHARS] + result_block_type = WEB_FETCH_TOOL_RESULT + except Exception as error: + fetch_url = str(tool_input["url"]) if tool_name == "web_fetch" else None + outbound._log_web_tool_failure(tool_name, error, fetch_url=fetch_url) + result_block_type = _result_block_for_tool[tool_name] + result_content = { + "type": 
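+            # web_search errors carry web_search_tool_result_error; web_fetch
+            # errors carry web_fetch_tool_error (_error_payload_type_for_tool above).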
_error_payload_type_for_tool[tool_name], + "error_code": "unavailable", + } + summary = outbound._web_tool_client_error_summary( + tool_name, error, verbose=verbose_client_errors + ) + + output_tokens = max(1, len(summary) // 4) + + yield format_sse_event( + "content_block_start", + { + "type": "content_block_start", + "index": 1, + "content_block": { + "type": result_block_type, + "tool_use_id": tool_id, + "content": result_content, + }, + }, + ) + yield format_sse_event( + "content_block_stop", {"type": "content_block_stop", "index": 1} + ) + # Model-facing summary: stream as normal text deltas (CLI/transcript code reads `text_delta`, + # not eager `text` on `content_block_start`). + yield format_sse_event( + "content_block_start", + { + "type": "content_block_start", + "index": 2, + "content_block": {"type": "text", "text": ""}, + }, + ) + yield format_sse_event( + "content_block_delta", + { + "type": "content_block_delta", + "index": 2, + "delta": {"type": "text_delta", "text": summary}, + }, + ) + yield format_sse_event( + "content_block_stop", {"type": "content_block_stop", "index": 2} + ) + yield format_sse_event( + "message_delta", + { + "type": "message_delta", + "delta": {"stop_reason": "end_turn", "stop_sequence": None}, + "usage": { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "server_tool_use": {usage_key: 1}, + }, + }, + ) + yield format_sse_event("message_stop", {"type": "message_stop"}) diff --git a/cli/entrypoints.py b/cli/entrypoints.py index ba61a9e..c161de9 100644 --- a/cli/entrypoints.py +++ b/cli/entrypoints.py @@ -30,7 +30,8 @@ def serve() -> None: settings = get_settings() try: uvicorn.run( - "api.app:app", + "api.app:create_app", + factory=True, host=settings.host, port=settings.port, log_level="debug", diff --git a/cli/manager.py b/cli/manager.py index b267633..cb27419 100644 --- a/cli/manager.py +++ b/cli/manager.py @@ -29,6 +29,9 @@ class CLISessionManager: allowed_dirs: list[str] | None = None, plans_directory: str | None = None, claude_bin: str = "claude", + *, + log_raw_cli_diagnostics: bool = False, + log_messaging_error_details: bool = False, ): """ Initialize the session manager. @@ -44,6 +47,8 @@ class CLISessionManager: self.allowed_dirs = allowed_dirs or [] self.plans_directory = plans_directory self.claude_bin = claude_bin + self._log_raw_cli_diagnostics = log_raw_cli_diagnostics + self._log_messaging_error_details = log_messaging_error_details self._sessions: dict[str, CLISession] = {} self._pending_sessions: dict[str, CLISession] = {} @@ -79,6 +84,7 @@ class CLISessionManager: allowed_dirs=self.allowed_dirs, plans_directory=self.plans_directory, claude_bin=self.claude_bin, + log_raw_cli_diagnostics=self._log_raw_cli_diagnostics, ) self._pending_sessions[temp_id] = new_session logger.info(f"Created new session: {temp_id}") @@ -130,7 +136,17 @@ class CLISessionManager: try: await session.stop() except Exception as e: - logger.error(f"Error stopping session: {e}") + if self._log_messaging_error_details: + logger.error( + "Error stopping session: {}: {}", + type(e).__name__, + e, + ) + else: + logger.error( + "Error stopping session: exc_type={}", + type(e).__name__, + ) self._sessions.clear() self._pending_sessions.clear() diff --git a/cli/session.py b/cli/session.py index db18f6c..d007975 100644 --- a/cli/session.py +++ b/cli/session.py @@ -11,6 +11,9 @@ from loguru import logger from .process_registry import register_pid, unregister_pid +# Cap stderr capture so a runaway child cannot exhaust memory; pipe is still drained. 
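+# The first 256 KiB is kept and any excess discarded, so the start of the error
+# output survives even when a child process is extremely chatty.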
+_MAX_STDERR_CAPTURE_BYTES = 256 * 1024 + @dataclass(frozen=True, slots=True) class ClaudeCliConfig: @@ -33,6 +36,8 @@ class CLISession: allowed_dirs: list[str] | None = None, plans_directory: str | None = None, claude_bin: str = "claude", + *, + log_raw_cli_diagnostics: bool = False, ): self.config = ClaudeCliConfig( workspace_path=os.path.normpath(os.path.abspath(workspace_path)), @@ -46,11 +51,40 @@ class CLISession: self.allowed_dirs = self.config.allowed_dirs self.plans_directory = self.config.plans_directory self.claude_bin = self.config.claude_bin + self._log_raw_cli_diagnostics = log_raw_cli_diagnostics self.process: asyncio.subprocess.Process | None = None self.current_session_id: str | None = None self._is_busy = False self._cli_lock = asyncio.Lock() + @staticmethod + async def _drain_stderr_bounded( + process: asyncio.subprocess.Process, + *, + max_bytes: int = _MAX_STDERR_CAPTURE_BYTES, + ) -> bytes: + """Read stderr concurrently with stdout to avoid subprocess pipe deadlocks. + + Retains at most ``max_bytes`` for logging; any excess is discarded, but + the pipe is read until EOF so a noisy child cannot fill the buffer and + block forever. + """ + if not process.stderr: + return b"" + parts: list[bytes] = [] + received = 0 + while True: + chunk = await process.stderr.read(65_536) + if not chunk: + break + if received < max_bytes: + take = min(len(chunk), max_bytes - received) + if take: + parts.append(chunk[:take]) + received += take + # If already at cap, keep reading and discarding until EOF. + return b"".join(parts) + @property def is_busy(self) -> bool: """Check if a task is currently running.""" @@ -140,6 +174,11 @@ class CLISession: session_id_extracted = False buffer = bytearray() + stderr_task: asyncio.Task[bytes] | None = None + if self.process.stderr: + stderr_task = asyncio.create_task( + self._drain_stderr_bounded(self.process) + ) try: while True: @@ -179,23 +218,27 @@ class CLISession: except asyncio.CancelledError: # Cancelling the handler task should not leave a Claude CLI # subprocess running in the background. 
- try: - await asyncio.shield(self.stop()) - finally: - raise + await asyncio.shield(self.stop()) + raise + finally: + stderr_bytes = b"" + if stderr_task is not None: + stderr_bytes = await stderr_task stderr_text = None - if self.process.stderr: - stderr_output = await self.process.stderr.read() - if stderr_output: - stderr_text = stderr_output.decode( - "utf-8", errors="replace" - ).strip() - logger.error(f"Claude CLI Stderr: {stderr_text}") - # Yield stderr as error event so it shows in UI - if stderr_text: - logger.info("CLI_SESSION: Yielding error event from stderr") - yield {"type": "error", "error": {"message": stderr_text}} + if stderr_bytes: + stderr_text = stderr_bytes.decode("utf-8", errors="replace").strip() + if stderr_text: + if self._log_raw_cli_diagnostics: + logger.error("Claude CLI stderr: {}", stderr_text) + else: + logger.error( + "Claude CLI stderr: bytes={} text_chars={}", + len(stderr_bytes), + len(stderr_text), + ) + logger.info("CLI_SESSION: Yielding error event from stderr") + yield {"type": "error", "error": {"message": stderr_text}} return_code = await self.process.wait() logger.info( @@ -230,7 +273,10 @@ class CLISession: yield event except json.JSONDecodeError: - logger.debug(f"Non-JSON output: {line_str}") + if self._log_raw_cli_diagnostics: + logger.debug("Non-JSON output: {}", line_str) + else: + logger.debug("Non-JSON CLI line: char_len={}", len(line_str)) yield {"type": "raw", "content": line_str} def _extract_session_id(self, event: Any) -> str | None: @@ -273,6 +319,16 @@ class CLISession: unregister_pid(self.process.pid) return True except Exception as e: - logger.error(f"Error stopping process: {e}") + if self._log_raw_cli_diagnostics: + logger.error( + "Error stopping process: {}: {}", + type(e).__name__, + e, + ) + else: + logger.error( + "Error stopping process: exc_type={}", + type(e).__name__, + ) return False return False diff --git a/config/constants.py b/config/constants.py new file mode 100644 index 0000000..987027a --- /dev/null +++ b/config/constants.py @@ -0,0 +1,10 @@ +"""Shared defaults used by config models and provider adapters.""" + +# HTTP client connect timeout (seconds). Keep aligned with README.md and .env.example. +HTTP_CONNECT_TIMEOUT_DEFAULT = 10.0 + +# Anthropic Messages API default when the client omits max_tokens. +ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS = 81920 + +# Max bytes read from a non-200 native messages response when verbose error logging is on. +NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES = 4096 diff --git a/config/logging_config.py b/config/logging_config.py index e1a31e4..7c5d8fd 100644 --- a/config/logging_config.py +++ b/config/logging_config.py @@ -8,6 +8,7 @@ included at top level for easy grep/filter. 
import json import logging +import re from pathlib import Path from loguru import logger @@ -17,6 +18,22 @@ _configured = False # Context keys we promote to top-level JSON for traceability _CONTEXT_KEYS = ("request_id", "node_id", "chat_id") +_TELEGRAM_BOT_RE = re.compile( + r"(https?://api\.telegram\.org/)bot([0-9]+:[A-Za-z0-9_-]+)(/?)", + re.IGNORECASE, +) +# Authorization: Bearer (HTTP client / proxy debug lines) +_AUTH_BEARER_RE = re.compile( + r"(\bAuthorization\s*:\s*Bearer\s+)([^\s'\"]+)", + re.IGNORECASE, +) + + +def _redact_sensitive_substrings(message: str) -> str: + """Remove obvious API tokens and secrets before JSON log line emission.""" + text = _TELEGRAM_BOT_RE.sub(r"\1bot\3", message) + return _AUTH_BEARER_RE.sub(r"\1", text) + def _serialize_with_context(record) -> str: """Format record as JSON with context vars at top level. @@ -26,7 +43,7 @@ def _serialize_with_context(record) -> str: out = { "time": str(record["time"]), "level": record["level"].name, - "message": record["message"], + "message": _redact_sensitive_substrings(str(record["message"])), "module": record["name"], "function": record["function"], "line": record["line"], @@ -57,11 +74,16 @@ class InterceptHandler(logging.Handler): ) -def configure_logging(log_file: str, *, force: bool = False) -> None: +def configure_logging( + log_file: str, *, force: bool = False, verbose_third_party: bool = False +) -> None: """Configure loguru with JSON output to log_file and intercept stdlib logging. Idempotent: skips if already configured (e.g. hot reload). Use force=True to reconfigure (e.g. in tests with a different log path). + + When ``verbose_third_party`` is false, noisy HTTP and Telegram loggers are capped + at WARNING unless explicitly configured otherwise. """ global _configured if _configured and not force: @@ -88,3 +110,16 @@ def configure_logging(log_file: str, *, force: bool = False) -> None: intercept = InterceptHandler() logging.root.handlers = [intercept] logging.root.setLevel(logging.DEBUG) + + third_party = ( + "httpx", + "httpcore", + "httpcore.http11", + "httpcore.connection", + "telegram", + "telegram.ext", + ) + for name in third_party: + logging.getLogger(name).setLevel( + logging.WARNING if not verbose_third_party else logging.NOTSET + ) diff --git a/config/nim.py b/config/nim.py index b3a9c3b..5bf7761 100644 --- a/config/nim.py +++ b/config/nim.py @@ -2,6 +2,8 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator +from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS + class NimSettings(BaseModel): """Fixed NVIDIA NIM settings (not configurable via env).""" @@ -14,7 +16,9 @@ class NimSettings(BaseModel): ) top_k: int = -1 max_tokens: int = Field( - 81920, ge=1, description="Maximum number of tokens in output." + ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS, + ge=1, + description="Maximum number of tokens in output.", ) presence_penalty: float = Field(0.0, ge=-2.0, le=2.0) frequency_penalty: float = Field(0.0, ge=-2.0, le=2.0) @@ -68,7 +72,7 @@ class NimSettings(BaseModel): return field_defaults.get(key, 1.0) try: val = float(v) - except Exception as err: + except (TypeError, ValueError) as err: raise ValueError( f"{info.field_name} must be a float. Got {type(v).__name__}." 
) from err @@ -78,15 +82,15 @@ class NimSettings(BaseModel): @classmethod def validate_int_fields(cls, v, info: ValidationInfo): field_defaults = { - "max_tokens": 81920, + "max_tokens": ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS, "min_tokens": 0, } if v is None or v == "": key = info.field_name or "max_tokens" - return field_defaults.get(key, 81920) + return field_defaults.get(key, ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS) try: val = int(v) - except Exception as err: + except (TypeError, ValueError) as err: raise ValueError( f"{info.field_name} must be an int. Got {type(v).__name__}." ) from err @@ -99,7 +103,7 @@ class NimSettings(BaseModel): return None try: return int(v) - except Exception as err: + except (TypeError, ValueError) as err: raise ValueError( f"{info.field_name} must be an int or empty/None." ) from err diff --git a/config/provider_catalog.py b/config/provider_catalog.py new file mode 100644 index 0000000..b0b731e --- /dev/null +++ b/config/provider_catalog.py @@ -0,0 +1,108 @@ +"""Neutral provider catalog: IDs, credentials, defaults, proxy and capability metadata. + +Adapter factories live in :mod:`providers.registry`; this module stays free of +provider implementation imports (see contract tests). +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +TransportType = Literal["openai_chat", "anthropic_messages"] + +# Default upstream base URLs (also re-exported via :mod:`providers.defaults`) +NVIDIA_NIM_DEFAULT_BASE = "https://integrate.api.nvidia.com/v1" +DEEPSEEK_DEFAULT_BASE = "https://api.deepseek.com" +OPENROUTER_DEFAULT_BASE = "https://openrouter.ai/api/v1" +LMSTUDIO_DEFAULT_BASE = "http://localhost:1234/v1" +LLAMACPP_DEFAULT_BASE = "http://localhost:8080/v1" +OLLAMA_DEFAULT_BASE = "http://localhost:11434" + + +@dataclass(frozen=True, slots=True) +class ProviderDescriptor: + """Metadata for building :class:`~providers.base.ProviderConfig` and factory wiring.""" + + provider_id: str + transport_type: TransportType + capabilities: tuple[str, ...] 
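+    # Each descriptor supplies either credential_env (with credential_url /
+    # credential_attr) or static_credential, never both.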
+ credential_env: str | None = None + credential_url: str | None = None + credential_attr: str | None = None + static_credential: str | None = None + default_base_url: str | None = None + base_url_attr: str | None = None + proxy_attr: str | None = None + + +PROVIDER_CATALOG: dict[str, ProviderDescriptor] = { + "nvidia_nim": ProviderDescriptor( + provider_id="nvidia_nim", + transport_type="openai_chat", + credential_env="NVIDIA_NIM_API_KEY", + credential_url="https://build.nvidia.com/settings/api-keys", + credential_attr="nvidia_nim_api_key", + default_base_url=NVIDIA_NIM_DEFAULT_BASE, + proxy_attr="nvidia_nim_proxy", + capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"), + ), + "open_router": ProviderDescriptor( + provider_id="open_router", + transport_type="anthropic_messages", + credential_env="OPENROUTER_API_KEY", + credential_url="https://openrouter.ai/keys", + credential_attr="open_router_api_key", + default_base_url=OPENROUTER_DEFAULT_BASE, + proxy_attr="open_router_proxy", + capabilities=("chat", "streaming", "tools", "thinking", "native_anthropic"), + ), + "deepseek": ProviderDescriptor( + provider_id="deepseek", + transport_type="openai_chat", + credential_env="DEEPSEEK_API_KEY", + credential_url="https://platform.deepseek.com/api_keys", + credential_attr="deepseek_api_key", + default_base_url=DEEPSEEK_DEFAULT_BASE, + capabilities=("chat", "streaming", "thinking"), + ), + "lmstudio": ProviderDescriptor( + provider_id="lmstudio", + transport_type="anthropic_messages", + static_credential="lm-studio", + default_base_url=LMSTUDIO_DEFAULT_BASE, + base_url_attr="lm_studio_base_url", + proxy_attr="lmstudio_proxy", + capabilities=("chat", "streaming", "tools", "native_anthropic", "local"), + ), + "llamacpp": ProviderDescriptor( + provider_id="llamacpp", + transport_type="anthropic_messages", + static_credential="llamacpp", + default_base_url=LLAMACPP_DEFAULT_BASE, + base_url_attr="llamacpp_base_url", + proxy_attr="llamacpp_proxy", + capabilities=("chat", "streaming", "tools", "native_anthropic", "local"), + ), + "ollama": ProviderDescriptor( + provider_id="ollama", + transport_type="anthropic_messages", + static_credential="ollama", + default_base_url=OLLAMA_DEFAULT_BASE, + base_url_attr="ollama_base_url", + capabilities=( + "chat", + "streaming", + "tools", + "thinking", + "native_anthropic", + "local", + ), + ), +} + +# Order matches docs / historical error text; must match PROVIDER_CATALOG keys. +SUPPORTED_PROVIDER_IDS: tuple[str, ...] = tuple(PROVIDER_CATALOG.keys()) + +if len(set(SUPPORTED_PROVIDER_IDS)) != len(SUPPORTED_PROVIDER_IDS): + raise AssertionError("Duplicate provider ids in PROVIDER_CATALOG key order") diff --git a/config/provider_ids.py b/config/provider_ids.py index e3e50bf..fd08ab7 100644 --- a/config/provider_ids.py +++ b/config/provider_ids.py @@ -1,18 +1,7 @@ -"""Canonical list of model provider type prefixes (provider_id values). - -`providers.registry.PROVIDER_DESCRIPTORS` is the full metadata source; this -module holds the id set for config validation and must stay in sync -(registries assert in `providers.registry`). -""" +"""Canonical provider id tuple (re-exported from the provider catalog).""" from __future__ import annotations -# Order matches docs / historical error text; must match PROVIDER_DESCRIPTORS keys. -SUPPORTED_PROVIDER_IDS: tuple[str, ...] 
= ( - "nvidia_nim", - "open_router", - "deepseek", - "lmstudio", - "llamacpp", - "ollama", -) +from .provider_catalog import SUPPORTED_PROVIDER_IDS + +__all__ = ("SUPPORTED_PROVIDER_IDS",) diff --git a/config/settings.py b/config/settings.py index f7129fb..c4afb1b 100644 --- a/config/settings.py +++ b/config/settings.py @@ -10,6 +10,7 @@ from dotenv import dotenv_values from pydantic import Field, field_validator, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict +from .constants import HTTP_CONNECT_TIMEOUT_DEFAULT from .nim import NimSettings from .provider_ids import SUPPORTED_PROVIDER_IDS @@ -105,6 +106,12 @@ class Settings(BaseSettings): messaging_platform: str = Field( default="discord", validation_alias="MESSAGING_PLATFORM" ) + messaging_rate_limit: int = Field( + default=1, validation_alias="MESSAGING_RATE_LIMIT" + ) + messaging_rate_window: float = Field( + default=1.0, validation_alias="MESSAGING_RATE_WINDOW" + ) # ==================== NVIDIA NIM Config ==================== nvidia_nim_api_key: str = "" @@ -173,7 +180,8 @@ class Settings(BaseSettings): default=10.0, validation_alias="HTTP_WRITE_TIMEOUT" ) http_connect_timeout: float = Field( - default=2.0, validation_alias="HTTP_CONNECT_TIMEOUT" + default=HTTP_CONNECT_TIMEOUT_DEFAULT, + validation_alias="HTTP_CONNECT_TIMEOUT", ) # ==================== Fast Prefix Detection ==================== @@ -185,6 +193,51 @@ class Settings(BaseSettings): enable_suggestion_mode_skip: bool = True enable_filepath_extraction_mock: bool = True + # ==================== Local web server tools (web_search / web_fetch) ==================== + # Off by default: these tools perform outbound HTTP from the proxy (SSRF risk). + enable_web_server_tools: bool = Field( + default=False, validation_alias="ENABLE_WEB_SERVER_TOOLS" + ) + # Comma-separated URL schemes allowed for web_fetch (default: http,https). + web_fetch_allowed_schemes: str = Field( + default="http,https", validation_alias="WEB_FETCH_ALLOWED_SCHEMES" + ) + # When true, skip private/loopback/link-local IP blocking for web_fetch (lab only). + web_fetch_allow_private_networks: bool = Field( + default=False, validation_alias="WEB_FETCH_ALLOW_PRIVATE_NETWORKS" + ) + + # ==================== Debug / diagnostic logging (avoid sensitive content) ==================== + # When false (default), API and SSE helpers log only metadata (counts, lengths, ids). + log_raw_api_payloads: bool = Field( + default=False, validation_alias="LOG_RAW_API_PAYLOADS" + ) + log_raw_sse_events: bool = Field( + default=False, validation_alias="LOG_RAW_SSE_EVENTS" + ) + # When false (default), unhandled exceptions log only type + route metadata (no message/traceback). + log_api_error_tracebacks: bool = Field( + default=False, validation_alias="LOG_API_ERROR_TRACEBACKS" + ) + # When false (default), messaging logs omit text/transcription previews (metadata only). + log_raw_messaging_content: bool = Field( + default=False, validation_alias="LOG_RAW_MESSAGING_CONTENT" + ) + # When true, log full Claude CLI stderr, non-JSON lines, and parser error text. + log_raw_cli_diagnostics: bool = Field( + default=False, validation_alias="LOG_RAW_CLI_DIAGNOSTICS" + ) + # When true, log exception text / CLI error strings in messaging (may leak user content). 
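+    # Distinct from log_raw_cli_diagnostics above, which covers the CLI
+    # subprocess side (stderr, non-JSON stdout) rather than messaging replies.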
+ log_messaging_error_details: bool = Field( + default=False, validation_alias="LOG_MESSAGING_ERROR_DETAILS" + ) + debug_platform_edits: bool = Field( + default=False, validation_alias="DEBUG_PLATFORM_EDITS" + ) + debug_subagent_stack: bool = Field( + default=False, validation_alias="DEBUG_SUBAGENT_STACK" + ) + # ==================== NIM Settings ==================== nim: NimSettings = Field(default_factory=NimSettings) @@ -215,6 +268,9 @@ class Settings(BaseSettings): claude_workspace: str = "./agent_workspace" allowed_dir: str = "" claude_cli_bin: str = Field(default="claude", validation_alias="CLAUDE_CLI_BIN") + max_message_log_entries_per_chat: int | None = Field( + default=None, validation_alias="MAX_MESSAGE_LOG_ENTRIES_PER_CHAT" + ) # ==================== Server ==================== host: str = "0.0.0.0" @@ -254,6 +310,13 @@ class Settings(BaseSettings): return None return v + @field_validator("max_message_log_entries_per_chat", mode="before") + @classmethod + def parse_optional_log_cap(cls, v: Any) -> Any: + if v == "" or v is None: + return None + return v + @field_validator("whisper_device") @classmethod def validate_whisper_device(cls, v: str) -> str: @@ -272,6 +335,33 @@ class Settings(BaseSettings): ) return v + @field_validator("messaging_rate_limit") + @classmethod + def validate_messaging_rate_limit(cls, v: int) -> int: + if v <= 0: + raise ValueError("messaging_rate_limit must be > 0") + return v + + @field_validator("messaging_rate_window") + @classmethod + def validate_messaging_rate_window(cls, v: float) -> float: + if v <= 0: + raise ValueError("messaging_rate_window must be > 0") + return float(v) + + @field_validator("web_fetch_allowed_schemes") + @classmethod + def validate_web_fetch_allowed_schemes(cls, v: str) -> str: + schemes = [part.strip().lower() for part in v.split(",") if part.strip()] + if not schemes: + raise ValueError("web_fetch_allowed_schemes must list at least one scheme") + for scheme in schemes: + if not scheme.isascii() or not scheme.isalpha(): + raise ValueError( + f"Invalid URL scheme in web_fetch_allowed_schemes: {scheme!r}" + ) + return ",".join(schemes) + @field_validator("ollama_base_url") @classmethod def validate_ollama_base_url(cls, v: str) -> str: @@ -329,12 +419,12 @@ class Settings(BaseSettings): @property def provider_type(self) -> str: """Extract provider type from the default model string.""" - return self.model.split("/", 1)[0] + return Settings.parse_provider_type(self.model) @property def model_name(self) -> str: """Extract the actual model name from the default model string.""" - return self.model.split("/", 1)[1] + return Settings.parse_model_name(self.model) def resolve_model(self, claude_model_name: str) -> str: """Resolve a Claude model name to the configured provider/model string. 
@@ -362,6 +452,14 @@ class Settings(BaseSettings): return self.enable_sonnet_thinking return self.enable_model_thinking + def web_fetch_allowed_scheme_set(self) -> frozenset[str]: + """Return normalized schemes allowed for web_fetch.""" + return frozenset( + part.strip().lower() + for part in self.web_fetch_allowed_schemes.split(",") + if part.strip() + ) + @staticmethod def parse_provider_type(model_string: str) -> str: """Extract provider type from any 'provider/model' string.""" diff --git a/core/anthropic/__init__.py b/core/anthropic/__init__.py index 9af5987..eaea8a1 100644 --- a/core/anthropic/__init__.py +++ b/core/anthropic/__init__.py @@ -1,9 +1,19 @@ """Anthropic protocol helpers shared across API, providers, and integrations.""" from .content import extract_text_from_content, get_block_attr, get_block_type -from .conversion import AnthropicToOpenAIConverter, build_base_request_body -from .errors import append_request_id, get_user_facing_error_message -from .sse import ContentBlockManager, SSEBuilder, map_stop_reason +from .conversion import ( + AnthropicToOpenAIConverter, + OpenAIConversionError, + build_base_request_body, +) +from .errors import ( + append_request_id, + format_user_error_preview, + get_user_facing_error_message, +) +from .native_messages_request import sanitize_native_messages_thinking_policy +from .provider_stream_error import iter_provider_stream_error_sse_events +from .sse import ContentBlockManager, SSEBuilder, format_sse_event, map_stop_reason from .thinking import ContentChunk, ContentType, ThinkTagParser from .tokens import get_token_count from .tools import HeuristicToolParser @@ -15,15 +25,20 @@ __all__ = [ "ContentChunk", "ContentType", "HeuristicToolParser", + "OpenAIConversionError", "SSEBuilder", "ThinkTagParser", "append_request_id", "build_base_request_body", "extract_text_from_content", + "format_sse_event", + "format_user_error_preview", "get_block_attr", "get_block_type", "get_token_count", "get_user_facing_error_message", + "iter_provider_stream_error_sse_events", "map_stop_reason", + "sanitize_native_messages_thinking_policy", "set_if_not_none", ] diff --git a/core/anthropic/conversion.py b/core/anthropic/conversion.py index 5979ba1..c484ceb 100644 --- a/core/anthropic/conversion.py +++ b/core/anthropic/conversion.py @@ -3,10 +3,34 @@ import json from typing import Any +from pydantic import BaseModel + from .content import get_block_attr, get_block_type from .utils import set_if_not_none +class OpenAIConversionError(Exception): + """Raised when Anthropic content cannot be converted to OpenAI chat without data loss.""" + + +def _openai_reject_native_only_top_level_fields(request_data: Any) -> None: + """OpenAI chat providers may only convert known top-level request fields. + + First-class model fields (e.g. ``context_management``) are not forwarded to + the OpenAI API but are allowed so clients do not hit spurious 400s. + Unknown extra keys (``__pydantic_extra__``) are still rejected. + """ + if not isinstance(request_data, BaseModel): + return + extra = getattr(request_data, "__pydantic_extra__", None) + if not extra: + return + raise OpenAIConversionError( + "OpenAI chat conversion does not support these top-level request fields: " + f"{sorted(str(k) for k in extra)}. Use a native Anthropic transport provider." 
+ ) + + def _tool_name(tool: Any) -> str: return str(getattr(tool, "name", "") or "") @@ -18,6 +42,27 @@ def _tool_input_schema(tool: Any) -> dict[str, Any]: return {"type": "object", "properties": {}} +def _serialize_tool_result_content(tool_content: Any) -> str: + """Serialize tool_result content for OpenAI ``role: tool`` messages (stable JSON for structured values).""" + if tool_content is None: + return "" + if isinstance(tool_content, str): + return tool_content + if isinstance(tool_content, dict): + return json.dumps(tool_content, ensure_ascii=False) + if isinstance(tool_content, list): + parts: list[str] = [] + for item in tool_content: + if isinstance(item, dict) and item.get("type") == "text": + parts.append(str(item.get("text", ""))) + elif isinstance(item, dict): + parts.append(json.dumps(item, ensure_ascii=False)) + else: + parts.append(str(item)) + return "\n".join(parts) + return str(tool_content) + + class AnthropicToOpenAIConverter: """Convert Anthropic message format to OpenAI-compatible format.""" @@ -64,20 +109,38 @@ class AnthropicToOpenAIConverter: content_parts: list[str] = [] thinking_parts: list[str] = [] tool_calls: list[dict[str, Any]] = [] + seen_tool_use = False for block in content: block_type = get_block_type(block) if block_type == "text": + if seen_tool_use: + raise OpenAIConversionError( + "OpenAI chat conversion does not support assistant text after " + "tool_use in the same message; split the transcript or use a " + "native Anthropic provider." + ) content_parts.append(get_block_attr(block, "text", "")) elif block_type == "thinking": if not include_thinking: continue + if seen_tool_use: + raise OpenAIConversionError( + "OpenAI chat conversion does not support assistant thinking after " + "tool_use in the same message; split the transcript or use a " + "native Anthropic provider." + ) thinking = get_block_attr(block, "thinking", "") content_parts.append(f"\n{thinking}\n") if include_reasoning_content: thinking_parts.append(thinking) + elif block_type == "redacted_thinking": + # Opaque provider continuation data; do not materialize as model-visible text + # or reasoning_content for OpenAI chat upstreams. + continue elif block_type == "tool_use": + seen_tool_use = True tool_input = get_block_attr(block, "input", {}) tool_calls.append( { @@ -91,6 +154,19 @@ class AnthropicToOpenAIConverter: }, } ) + elif block_type == "image": + raise OpenAIConversionError( + "Assistant image blocks are not supported for OpenAI chat conversion." + ) + elif block_type in ( + "server_tool_use", + "web_search_tool_result", + "web_fetch_tool_result", + ): + raise OpenAIConversionError( + "OpenAI chat conversion does not support Anthropic server tool blocks " + f"({block_type!r} in an assistant message). Use a native Anthropic transport provider." + ) content_str = "\n\n".join(content_parts) if not content_str and not tool_calls: @@ -122,21 +198,21 @@ class AnthropicToOpenAIConverter: if block_type == "text": text_parts.append(get_block_attr(block, "text", "")) + elif block_type == "image": + raise OpenAIConversionError( + "User message image blocks are not supported for OpenAI chat " + "conversion; use a vision-capable native Anthropic provider or " + "extend the converter." 
+ ) elif block_type == "tool_result": flush_text() tool_content = get_block_attr(block, "content", "") - if isinstance(tool_content, list): - tool_content = "\n".join( - item.get("text", str(item)) - if isinstance(item, dict) - else str(item) - for item in tool_content - ) + serialized = _serialize_tool_result_content(tool_content) result.append( { "role": "tool", "tool_call_id": get_block_attr(block, "tool_use_id"), - "content": str(tool_content) if tool_content else "", + "content": serialized if serialized else "", } ) @@ -199,6 +275,7 @@ def build_base_request_body( include_reasoning_content: bool = False, ) -> dict[str, Any]: """Build the common parts of an OpenAI-format request body.""" + _openai_reject_native_only_top_level_fields(request_data) messages = AnthropicToOpenAIConverter.convert_messages( request_data.messages, include_thinking=include_thinking, diff --git a/core/anthropic/emitted_sse_tracker.py b/core/anthropic/emitted_sse_tracker.py new file mode 100644 index 0000000..079dd51 --- /dev/null +++ b/core/anthropic/emitted_sse_tracker.py @@ -0,0 +1,97 @@ +"""Track content-block state for native Anthropic SSE strings we emit to clients.""" + +from __future__ import annotations + +import uuid +from collections.abc import Iterator +from contextlib import suppress +from typing import Any + +from core.anthropic.sse import SSEBuilder, format_sse_event +from core.anthropic.stream_contracts import SSEEvent, event_index, parse_sse_lines + + +class EmittedNativeSseTracker: + """Parse emitted SSE frames so mid-stream errors can close blocks and pick a fresh index.""" + + def __init__(self) -> None: + self._buf = "" + self._open_stack: list[int] = [] + self._max_index = -1 + self.message_id: str | None = None + self.model: str = "" + + def feed(self, chunk: str) -> None: + """Record SSE frames completed by ``chunk`` (handles splitting across reads).""" + self._buf += chunk + while True: + sep = self._buf.find("\n\n") + if sep < 0: + break + frame = self._buf[:sep] + self._buf = self._buf[sep + 2 :] + if not frame.strip(): + continue + for event in parse_sse_lines(frame.splitlines()): + self._observe(event) + + def _observe(self, event: SSEEvent) -> None: + if event.event == "message_start": + message = event.data.get("message") + if isinstance(message, dict): + mid = message.get("id") + if isinstance(mid, str) and mid: + self.message_id = mid + model = message.get("model") + if isinstance(model, str) and model: + self.model = model + return + + if event.event == "content_block_start": + idx = event_index(event) + self._max_index = max(self._max_index, idx) + self._open_stack.append(idx) + return + + if event.event == "content_block_stop": + idx = event_index(event) + if self._open_stack and self._open_stack[-1] == idx: + self._open_stack.pop() + else: + with suppress(ValueError): + self._open_stack.remove(idx) + + def next_content_index(self) -> int: + """Next unused content block index based on emitted starts.""" + return self._max_index + 1 + + def iter_close_unclosed_blocks(self) -> Iterator[str]: + """Yield ``content_block_stop`` events for blocks that were started but not stopped.""" + while self._open_stack: + idx = self._open_stack.pop() + yield format_sse_event( + "content_block_stop", + {"type": "content_block_stop", "index": idx}, + ) + + def iter_midstream_error_tail( + self, + error_message: str, + *, + request: Any, + input_tokens: int, + log_raw_sse_events: bool, + ) -> Iterator[str]: + """Close dangling blocks, emit a text error block at a fresh index, then message 
tail.""" + mid = self.message_id or f"msg_{uuid.uuid4()}" + model = self.model or (getattr(request, "model", "") or "") + sse = SSEBuilder( + mid, + model, + input_tokens, + log_raw_events=log_raw_sse_events, + ) + sse.blocks.next_index = self.next_content_index() + yield from sse.emit_error(error_message) + yield sse.message_delta("end_turn", 1) + yield sse.message_stop() diff --git a/core/anthropic/errors.py b/core/anthropic/errors.py index 38e3d12..afe093d 100644 --- a/core/anthropic/errors.py +++ b/core/anthropic/errors.py @@ -9,11 +9,12 @@ def get_user_facing_error_message( *, read_timeout_s: float | None = None, ) -> str: - """Return a readable, non-empty error message for users.""" - message = str(e).strip() - if message: - return message + """Return a readable, non-empty error message for users. + Known transport and OpenAI SDK exception types are mapped to stable wording + before falling back to ``str(e)``, so empty or noisy SDK messages do not skip + the mapped path. + """ if isinstance(e, httpx.ReadTimeout): if read_timeout_s is not None: return f"Provider request timed out after {read_timeout_s:g}s." @@ -25,13 +26,20 @@ def get_user_facing_error_message( return f"Provider request timed out after {read_timeout_s:g}s." return "Request timed out." + if isinstance(e, openai.RateLimitError): + return "Provider rate limit reached. Please retry shortly." + if isinstance(e, openai.AuthenticationError): + return "Provider authentication failed. Check API key." + if isinstance(e, openai.BadRequestError): + return "Invalid request sent to provider." + name = type(e).__name__ status_code = getattr(e, "status_code", None) - if isinstance(e, openai.RateLimitError) or name == "RateLimitError": + if name == "RateLimitError": return "Provider rate limit reached. Please retry shortly." - if isinstance(e, openai.AuthenticationError) or name == "AuthenticationError": + if name == "AuthenticationError": return "Provider authentication failed. Check API key." - if isinstance(e, openai.BadRequestError) or name == "InvalidRequestError": + if name == "InvalidRequestError": return "Invalid request sent to provider." if name == "OverloadedError": return "Provider is currently overloaded. Please retry." @@ -42,9 +50,18 @@ def get_user_facing_error_message( if name.endswith("ProviderError") or name == "ProviderError": return "Provider request failed." + message = str(e).strip() + if message: + return message + return "Provider request failed unexpectedly." +def format_user_error_preview(exc: Exception, *, max_len: int = 200) -> str: + """Truncate a user-facing error string for short chat replies.""" + return get_user_facing_error_message(exc)[:max_len] + + def append_request_id(message: str, request_id: str | None) -> str: """Append request_id suffix when available.""" base = message.strip() or "Provider request failed unexpectedly." diff --git a/core/anthropic/native_messages_request.py b/core/anthropic/native_messages_request.py new file mode 100644 index 0000000..ed24c6c --- /dev/null +++ b/core/anthropic/native_messages_request.py @@ -0,0 +1,260 @@ +"""Native Anthropic Messages request body construction (JSON-ready dicts). + +Provider adapters supply policy via parameters (defaults, OpenRouter post-steps). 
+""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + +from pydantic import BaseModel + +_REQUEST_FIELDS = ( + "model", + "messages", + "system", + "max_tokens", + "stop_sequences", + "stream", + "temperature", + "top_p", + "top_k", + "metadata", + "tools", + "tool_choice", + "thinking", + "context_management", + "output_config", + "mcp_servers", + "extra_body", +) + +# Keys that would override routed canonical request fields if merged from ``extra_body``. +_OPENROUTER_EXTRA_BODY_FORBIDDEN_KEYS = frozenset( + { + "model", + "messages", + "system", + "tools", + "tool_choice", + "stream", + "max_tokens", + "temperature", + "top_p", + "top_k", + "metadata", + "stop_sequences", + "context_management", + "output_config", + "mcp_servers", + } +) + + +class OpenRouterExtraBodyError(ValueError): + """``extra_body`` contained reserved keys that would override canonical fields.""" + + +def validate_openrouter_extra_body(extra: Any) -> None: + """Reject ``extra_body`` keys that must not override routed request fields.""" + if not isinstance(extra, dict) or not extra: + return + bad = _OPENROUTER_EXTRA_BODY_FORBIDDEN_KEYS & extra.keys() + if bad: + raise OpenRouterExtraBodyError( + f"extra_body must not override canonical request fields: {sorted(bad)}" + ) + + +_INTERNAL_FIELDS = { + "thinking", + "extra_body", +} + + +def _serialize_value(value: Any) -> Any: + """Convert Pydantic models and lightweight objects into JSON-ready values.""" + if isinstance(value, BaseModel): + return value.model_dump(exclude_none=True) + if isinstance(value, dict): + return { + key: _serialize_value(item) + for key, item in value.items() + if item is not None + } + if isinstance(value, Sequence) and not isinstance(value, str | bytes | bytearray): + return [_serialize_value(item) for item in value] + if value is None or isinstance(value, str | int | float | bool): + return value + if hasattr(value, "__dict__"): + return { + key: _serialize_value(item) + for key, item in vars(value).items() + if not key.startswith("_") and item is not None + } + return value + + +def _dump_request_fields(request_data: Any) -> dict[str, Any]: + """Extract the public request fields (OpenRouter-style explicit field list).""" + if isinstance(request_data, BaseModel): + return request_data.model_dump(exclude_none=True) + + dumped: dict[str, Any] = {} + for field in _REQUEST_FIELDS: + value = getattr(request_data, field, None) + if value is not None: + dumped[field] = _serialize_value(value) + return dumped + + +def sanitize_native_messages_thinking_policy( + messages: Any, *, thinking_enabled: bool +) -> Any: + """Filter assistant message thinking blocks for upstream native Anthropic JSON. + + When ``thinking_enabled`` is false, remove ``thinking`` and ``redacted_thinking`` + history so disabled policy is not undermined by prior turns. + + When true, keep ``redacted_thinking`` and signed ``thinking``; remove only + unsigned plain ``thinking`` blocks (not replayable). 
+ """ + if not isinstance(messages, list): + return messages + + sanitized_messages: list[Any] = [] + for message in messages: + if not isinstance(message, dict): + sanitized_messages.append(message) + continue + + if message.get("role") != "assistant": + sanitized_messages.append(message) + continue + + content = message.get("content") + if not isinstance(content, list): + sanitized_messages.append(message) + continue + + if not thinking_enabled: + sanitized_content = [ + block + for block in content + if not ( + isinstance(block, dict) + and block.get("type") in ("thinking", "redacted_thinking") + ) + ] + else: + sanitized_content = [ + block + for block in content + if not ( + isinstance(block, dict) + and block.get("type") == "thinking" + and not isinstance(block.get("signature"), str) + ) + ] + + sanitized_message = dict(message) + sanitized_message["content"] = sanitized_content or "" + sanitized_messages.append(sanitized_message) + + return sanitized_messages + + +def _normalize_system_prompt_for_openrouter(system: Any) -> Any: + """Flatten Claude SDK system blocks for OpenRouter's native endpoint.""" + if not isinstance(system, list): + return system + + text_parts: list[str] = [] + for block in system: + if not isinstance(block, dict): + continue + if block.get("type") == "text" and isinstance(block.get("text"), str): + text_parts.append(block["text"]) + return "\n\n".join(text_parts).strip() if text_parts else system + + +def _apply_openrouter_reasoning_policy(body: dict[str, Any], thinking_cfg: Any) -> None: + """Map Anthropic thinking controls onto OpenRouter reasoning controls.""" + reasoning = body.setdefault("reasoning", {"enabled": True}) + if not isinstance(reasoning, dict): + return + reasoning.setdefault("enabled", True) + if not isinstance(thinking_cfg, dict): + return + budget_tokens = thinking_cfg.get("budget_tokens") + if isinstance(budget_tokens, int): + reasoning.setdefault("max_tokens", budget_tokens) + + +def build_base_native_anthropic_request_body( + request: Any, + *, + default_max_tokens: int, + thinking_enabled: bool, +) -> dict[str, Any]: + """Serialize a Pydantic messages request to a generic native Anthropic body.""" + body = request.model_dump(exclude_none=True) + + body.pop("extra_body", None) + + if "thinking" in body: + thinking_cfg = body.pop("thinking") + if thinking_enabled and isinstance(thinking_cfg, dict): + thinking_payload: dict[str, Any] = {"type": "enabled"} + budget_tokens = thinking_cfg.get("budget_tokens") + if isinstance(budget_tokens, int): + thinking_payload["budget_tokens"] = budget_tokens + body["thinking"] = thinking_payload + + if "max_tokens" not in body: + body["max_tokens"] = default_max_tokens + + if "messages" in body: + body["messages"] = sanitize_native_messages_thinking_policy( + body["messages"], + thinking_enabled=thinking_enabled, + ) + + return body + + +def build_openrouter_native_request_body( + request_data: Any, + *, + thinking_enabled: bool, + default_max_tokens: int, +) -> dict[str, Any]: + """Build an Anthropic-format request body for OpenRouter (policy hooks built-in).""" + dumped_request = _dump_request_fields(request_data) + request_extra = dumped_request.pop("extra_body", None) + thinking_cfg = dumped_request.get("thinking") + body: dict[str, Any] = { + key: value + for key, value in dumped_request.items() + if key not in _INTERNAL_FIELDS + } + + if isinstance(request_extra, dict): + validate_openrouter_extra_body(request_extra) + body.update(request_extra) + + body["messages"] = 
sanitize_native_messages_thinking_policy( + body.get("messages"), + thinking_enabled=thinking_enabled, + ) + if "system" in body: + body["system"] = _normalize_system_prompt_for_openrouter(body["system"]) + body["stream"] = True + if body.get("max_tokens") is None: + body["max_tokens"] = default_max_tokens + + if thinking_enabled: + _apply_openrouter_reasoning_policy(body, thinking_cfg) + + return body diff --git a/core/anthropic/native_sse_block_policy.py b/core/anthropic/native_sse_block_policy.py new file mode 100644 index 0000000..7a3fd81 --- /dev/null +++ b/core/anthropic/native_sse_block_policy.py @@ -0,0 +1,313 @@ +"""Shared native Anthropic SSE thinking policy, block remapping, and overlap repair. + +Used by :class:`OpenRouterProvider` and line-mode +:class:`providers.anthropic_messages.AnthropicMessagesTransport` providers. +""" + +from __future__ import annotations + +import copy +import json +from dataclasses import dataclass, field +from typing import Any + +__all__ = [ + "NativeSseBlockPolicyState", + "format_native_sse_event", + "is_terminal_openrouter_done_event", + "parse_native_sse_event", + "transform_native_sse_block_event", +] + + +@dataclass +class _UpstreamBlockState: + """Per-upstream content block: segment index and liveness in the model stream.""" + + block_type: str + down_index: int + open: bool + last_start_block: dict[str, Any] | None = None + + +@dataclass +class NativeSseBlockPolicyState: + """Track per-upstream content blocks and remapped Anthropic ``index`` field.""" + + next_index: int = 0 + by_upstream: dict[int, _UpstreamBlockState] = field(default_factory=dict) + dropped_indexes: set[int] = field(default_factory=set) + pending_suppressed_stops: set[int] = field(default_factory=set) + message_stopped: bool = False + + +def format_native_sse_event(event_name: str | None, data_text: str) -> str: + """Format an SSE event from its event name and data payload.""" + lines: list[str] = [] + if event_name: + lines.append(f"event: {event_name}") + lines.extend(f"data: {line}" for line in data_text.splitlines()) + return "\n".join(lines) + "\n\n" + + +def parse_native_sse_event(event: str) -> tuple[str | None, str]: + """Extract the event name and raw data payload from an SSE event.""" + event_name = None + data_lines: list[str] = [] + for line in event.strip().splitlines(): + if line.startswith("event:"): + event_name = line[6:].strip() + elif line.startswith("data:"): + data_lines.append(line[5:].lstrip()) + return event_name, "\n".join(data_lines) + + +def is_terminal_openrouter_done_event(event_name: str | None, data_text: str) -> bool: + """Return whether an event is OpenAI-style terminal noise (``[DONE]``).""" + return (event_name is None or event_name in {"data", "done"}) and ( + data_text.strip().upper() == "[DONE]" + ) + + +def _delta_type_to_block_kind(delta_type: Any) -> str | None: + """Map a content_block_delta type to a content block kind (text/thinking/tool_use).""" + if not isinstance(delta_type, str): + return None + if delta_type in {"thinking_delta", "signature_delta"}: + return "thinking" + if delta_type == "text_delta": + return "text" + if delta_type == "input_json_delta": + return "tool_use" + return None + + +def _synthetic_start_content_block( + block_kind: str, + *, + upstream_index: int, + stored_tool_block: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Build a `content_block` for a `content_block_start` with empty streaming fields.""" + if block_kind == "tool_use": + if ( + isinstance(stored_tool_block, dict) + and 
stored_tool_block.get("type") == "tool_use"
+        ):
+            tool_id = stored_tool_block.get("id")
+            name = stored_tool_block.get("name")
+            inp = stored_tool_block.get("input")
+            return {
+                "type": "tool_use",
+                "id": tool_id
+                if isinstance(tool_id, str) and tool_id
+                else f"toolu_or_{upstream_index}",
+                "name": name if isinstance(name, str) else "",
+                "input": inp if isinstance(inp, dict) else {},
+            }
+        return {
+            "type": "tool_use",
+            "id": f"toolu_or_{upstream_index}",
+            "name": "",
+            "input": {},
+        }
+    if block_kind == "thinking":
+        return {"type": "thinking", "thinking": ""}
+    if block_kind == "text":
+        return {"type": "text", "text": ""}
+    return {"type": "text", "text": ""}
+
+
+def _should_drop_block_type(block_type: Any, *, thinking_enabled: bool) -> bool:
+    if not isinstance(block_type, str):
+        return False
+    if block_type.startswith("redacted_thinking"):
+        return not thinking_enabled
+    return not thinking_enabled and "thinking" in block_type
+
+
+def _synthetic_close_other_open_blocks(
+    state: NativeSseBlockPolicyState, current_upstream: int
+) -> str:
+    """Close every open block except `current_upstream` and track duplicate upstream stops."""
+    out: list[str] = []
+    for upstream, seg in list(state.by_upstream.items()):
+        if upstream == current_upstream or not seg.open:
+            continue
+        out.append(
+            format_native_sse_event(
+                "content_block_stop",
+                json.dumps(
+                    {
+                        "type": "content_block_stop",
+                        "index": seg.down_index,
+                    }
+                ),
+            )
+        )
+        seg.open = False
+        state.pending_suppressed_stops.add(upstream)
+    return "".join(out)
+
+
+def _allocate_new_segment(
+    state: NativeSseBlockPolicyState,
+    upstream_index: int,
+    block_type: str,
+    *,
+    last_start_block: dict[str, Any] | None = None,
+) -> int:
+    """Assign a new downstream `index` for a segment and record upstream state."""
+    new_idx = state.next_index
+    state.next_index += 1
+    state.by_upstream[upstream_index] = _UpstreamBlockState(
+        block_type=block_type,
+        down_index=new_idx,
+        open=True,
+        last_start_block=last_start_block,
+    )
+    return new_idx
+
+
+def transform_native_sse_block_event(
+    event: str,
+    state: NativeSseBlockPolicyState,
+    *,
+    thinking_enabled: bool,
+) -> str | None:
+    """Normalize native Anthropic SSE events and enforce local thinking policy."""
+    event_name, data_text = parse_native_sse_event(event)
+    if not event_name or not data_text:
+        return event
+
+    try:
+        payload = json.loads(data_text)
+    except json.JSONDecodeError:
+        return event
+
+    if event_name == "content_block_start":
+        block = payload.get("content_block")
+        if not isinstance(block, dict):
+            return event
+        block_type = block.get("type")
+        upstream_index = payload.get("index")
+        if not isinstance(upstream_index, int):
+            return event
+        if _should_drop_block_type(block_type, thinking_enabled=thinking_enabled):
+            state.dropped_indexes.add(upstream_index)
+            return None
+
+        if not isinstance(block_type, str):
+            return event
+        prefix = _synthetic_close_other_open_blocks(state, upstream_index)
+        stored = copy.deepcopy(block)
+        new_idx = _allocate_new_segment(
+            state,
+            upstream_index,
+            block_type=block_type,
+            last_start_block=stored,
+        )
+        payload["index"] = new_idx
+        return prefix + format_native_sse_event(event_name, json.dumps(payload))
+
+    if event_name == "content_block_delta":
+        delta = payload.get("delta")
+        if not isinstance(delta, dict):
+            return event
+        delta_type = delta.get("type")
+        upstream_index = payload.get("index")
+        if not isinstance(upstream_index, int):
+            return event
+        if upstream_index in state.dropped_indexes:
+            return None
+        if _should_drop_block_type(delta_type, thinking_enabled=thinking_enabled):
+            return None
+
+        block_kind = _delta_type_to_block_kind(delta_type)
+        if block_kind is None:
+            return event
+
+        seg = state.by_upstream.get(upstream_index)
+        if seg and seg.open:
+            payload["index"] = seg.down_index
+            return format_native_sse_event(event_name, json.dumps(payload))
+
+        if seg is not None and not seg.open:
+            # More deltas for an upstream block after a synthetic (or other) close:
+            # reopen with a new downstream `index` and emit a synthetic
+            # `content_block_start` first.
+            state.pending_suppressed_stops.discard(upstream_index)
+            carry = seg.last_start_block
+            new_idx = _allocate_new_segment(
+                state,
+                upstream_index,
+                block_type=block_kind,
+                last_start_block=carry,
+            )
+            stored_tool = (
+                carry
+                if isinstance(carry, dict) and carry.get("type") == "tool_use"
+                else None
+            )
+            start_payload = {
+                "type": "content_block_start",
+                "index": new_idx,
+                "content_block": _synthetic_start_content_block(
+                    block_kind,
+                    upstream_index=upstream_index,
+                    stored_tool_block=stored_tool,
+                ),
+            }
+            prefix = format_native_sse_event(
+                "content_block_start", json.dumps(start_payload)
+            )
+            payload["index"] = new_idx
+            return prefix + format_native_sse_event(event_name, json.dumps(payload))
+
+        # Delta with no prior `content_block_start` in this stream
+        if block_kind in ("text", "tool_use"):
+            synthetic_block = _synthetic_start_content_block(
+                block_kind,
+                upstream_index=upstream_index,
+            )
+            new_idx = _allocate_new_segment(
+                state,
+                upstream_index,
+                block_type=block_kind,
+                last_start_block=copy.deepcopy(synthetic_block),
+            )
+            start_payload = {
+                "type": "content_block_start",
+                "index": new_idx,
+                "content_block": synthetic_block,
+            }
+            prefix = format_native_sse_event(
+                "content_block_start", json.dumps(start_payload)
+            )
+            payload["index"] = new_idx
+            return prefix + format_native_sse_event(event_name, json.dumps(payload))
+        # thinking: pass through raw (unusual upstream shape)
+        return event
+
+    if event_name == "content_block_stop":
+        upstream_index = payload.get("index")
+        if not isinstance(upstream_index, int):
+            return event
+        if upstream_index in state.dropped_indexes:
+            return None
+        if upstream_index in state.pending_suppressed_stops:
+            state.pending_suppressed_stops.discard(upstream_index)
+            return None
+
+        seg = state.by_upstream.get(upstream_index)
+        if seg is not None and seg.open:
+            payload["index"] = seg.down_index
+            seg.open = False
+            return format_native_sse_event(event_name, json.dumps(payload))
+        if seg is not None:
+            # Spurious or duplicate `content_block_stop` for a closed block.
+            return None
+        if not thinking_enabled:
+            return None
+        return event
+
+    return event
diff --git a/core/anthropic/provider_stream_error.py b/core/anthropic/provider_stream_error.py
new file mode 100644
index 0000000..5534c50
--- /dev/null
+++ b/core/anthropic/provider_stream_error.py
@@ -0,0 +1,34 @@
+"""Canonical Anthropic-style SSE sequence for provider-side streaming errors."""
+
+from __future__ import annotations
+
+import uuid
+from collections.abc import Iterator
+from typing import Any
+
+from core.anthropic.sse import SSEBuilder
+
+
+def iter_provider_stream_error_sse_events(
+    *,
+    request: Any,
+    input_tokens: int,
+    error_message: str,
+    sent_any_event: bool,
+    log_raw_sse_events: bool,
+    message_id: str | None = None,
+) -> Iterator[str]:
+    """Yield message_start (if needed), a text block with the error, then message_delta/stop."""
+    mid = message_id or f"msg_{uuid.uuid4()}"
+    model = getattr(request, "model", "") or ""
+    sse = SSEBuilder(
+        mid,
+        model,
+        input_tokens,
+        log_raw_events=log_raw_sse_events,
+    )
+    if not sent_any_event:
+        yield sse.message_start()
+    yield from sse.emit_error(error_message)
+    yield sse.message_delta("end_turn", 1)
+    yield sse.message_stop()
diff --git a/core/anthropic/server_tool_sse.py b/core/anthropic/server_tool_sse.py
new file mode 100644
index 0000000..1c8727d
--- /dev/null
+++ b/core/anthropic/server_tool_sse.py
@@ -0,0 +1,14 @@
+"""SSE content_block ``type`` values for Anthropic web server tools (local handlers).
+
+Shared by :mod:`api.web_tools` and stream contract tests to avoid drift.
+"""
+
+from __future__ import annotations
+
+from typing import Final
+
+SERVER_TOOL_USE: Final = "server_tool_use"
+WEB_SEARCH_TOOL_RESULT: Final = "web_search_tool_result"
+WEB_FETCH_TOOL_RESULT: Final = "web_fetch_tool_result"
+WEB_SEARCH_TOOL_RESULT_ERROR: Final = "web_search_tool_result_error"
+WEB_FETCH_TOOL_ERROR: Final = "web_fetch_tool_error"
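For reference, a minimal sketch of draining the new error-tail helper when a provider fails mid-stream (the call-site shape and error text here are illustrative assumptions, not part of this diff; only the imported helper is real):

# Hypothetical call site for iter_provider_stream_error_sse_events.
from core.anthropic.provider_stream_error import iter_provider_stream_error_sse_events

def provider_error_tail(request, input_tokens: int, sent_any_event: bool) -> str:
    # Joins message_start (only if nothing was sent yet), a text block carrying
    # the error message, then message_delta/message_stop into one SSE string.
    return "".join(
        iter_provider_stream_error_sse_events(
            request=request,
            input_tokens=input_tokens,
            error_message="upstream provider disconnected",  # illustrative
            sent_any_event=sent_any_event,
            log_raw_sse_events=False,
        )
    )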
diff --git a/core/anthropic/sse.py b/core/anthropic/sse.py
index 7998839..3bf5ebe 100644
--- a/core/anthropic/sse.py
+++ b/core/anthropic/sse.py
@@ -1,5 +1,6 @@
 """SSE event builder for Anthropic-format streaming responses."""
 
+import hashlib
 import json
 from collections.abc import Iterator
 from dataclasses import dataclass, field
@@ -14,6 +15,13 @@ except Exception:
     ENCODER = None
 
 
+# Standard headers for Anthropic-style ``text/event-stream`` responses from this proxy.
+ANTHROPIC_SSE_RESPONSE_HEADERS: dict[str, str] = {
+    "X-Accel-Buffering": "no",
+    "Cache-Control": "no-cache",
+    "Connection": "keep-alive",
+}
+
 STOP_REASON_MAP = {
     "stop": "end_turn",
     "length": "max_tokens",
@@ -29,6 +37,11 @@ def map_stop_reason(openai_reason: str | None) -> str:
     )
 
 
+def format_sse_event(event_type: str, data: dict) -> str:
+    """Format one Anthropic-style SSE event (no logging)."""
+    return f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
+
+
 @dataclass
 class ToolCallState:
     """State for a single streaming tool call."""
@@ -40,6 +53,7 @@ class ToolCallState:
     started: bool = False
     task_arg_buffer: str = ""
     task_args_emitted: bool = False
+    pre_start_args: str = ""
 
 
 @dataclass
@@ -58,7 +72,25 @@ class ContentBlockManager:
         self.next_index += 1
         return idx
 
+    def ensure_tool_state(self, index: int) -> ToolCallState:
+        """Create tool stream state for ``index`` when the first tool delta arrives."""
+        if index not in self.tool_states:
+            self.tool_states[index] = ToolCallState(block_index=-1, tool_id="", name="")
+        return self.tool_states[index]
+
+    def set_stream_tool_id(self, index: int, tool_id: str | None) -> None:
+        """Record OpenAI tool call id before ``content_block_start`` (split-stream providers)."""
+        if not tool_id:
+            return
+        state = self.ensure_tool_state(index)
+        state.tool_id = str(tool_id)
+
     def register_tool_name(self, index: int, name: str) -> None:
+        """Record tool name fragments as they arrive from chunked OpenAI streams.
+
+        Names may be split across deltas; later chunks can extend (``ab`` + ``c``)
+        or repeat prefixes, so we merge conservatively.
+        """
         if index not in self.tool_states:
             self.tool_states[index] = ToolCallState(
                 block_index=-1, tool_id="", name=name
@@ -82,8 +114,7 @@ class ContentBlockManager:
         except Exception:
             return None
 
-        if args_json.get("run_in_background") is not False:
-            args_json["run_in_background"] = False
+        _normalize_task_run_in_background(args_json)
 
         state.task_args_emitted = True
         state.task_arg_buffer = ""
@@ -98,16 +129,17 @@ class ContentBlockManager:
             out = "{}"
             try:
                 args_json = json.loads(state.task_arg_buffer)
-                if args_json.get("run_in_background") is not False:
-                    args_json["run_in_background"] = False
+                _normalize_task_run_in_background(args_json)
                 out = json.dumps(args_json)
-            except Exception as e:
-                prefix = state.task_arg_buffer[:120]
+            except (json.JSONDecodeError, TypeError, ValueError) as e:
+                digest = hashlib.sha256(
+                    state.task_arg_buffer.encode("utf-8", errors="replace")
+                ).hexdigest()[:16]
                 logger.warning(
-                    "Task args invalid JSON (id={} len={} prefix={!r}): {}",
+                    "Task args invalid JSON (id={} len={} buffer_sha256_prefix={}): {}",
                     state.tool_id or "unknown",
                     len(state.task_arg_buffer),
-                    prefix,
+                    digest,
                     e,
                 )
 
@@ -117,20 +149,41 @@ class ContentBlockManager:
 
         return results
 
 
+def _normalize_task_run_in_background(args_json: dict) -> None:
+    """Force Claude Code Task subagents to run in foreground (single shared rule)."""
+    if args_json.get("run_in_background") is not False:
+        args_json["run_in_background"] = False
+
+
 class SSEBuilder:
     """Builder for Anthropic SSE streaming events."""
 
-    def __init__(self, message_id: str, model: str, input_tokens: int = 0):
+    def __init__(
+        self,
+        message_id: str,
+        model: str,
+        input_tokens: int = 0,
+        *,
+        log_raw_events: bool = False,
+    ):
         self.message_id = message_id
         self.model = model
         self.input_tokens = input_tokens
+        self._log_raw_events = log_raw_events
         self.blocks = ContentBlockManager()
         self._accumulated_text_parts: list[str] = []
         self._accumulated_reasoning_parts: list[str] = []
 
     def _format_event(self, event_type: str, data: dict) -> str:
-        event_str = f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
-        logger.debug("SSE_EVENT: {} - {}", event_type, event_str.strip())
+        event_str = format_sse_event(event_type, data)
+        if self._log_raw_events:
+            logger.debug("SSE_EVENT: {} - {}", event_type, event_str.strip())
+        else:
+            logger.debug(
+                "SSE_EVENT: event_type={} serialized_bytes={}",
+                event_type,
+                len(event_str.encode("utf-8")),
+            )
         return event_str
 
     def message_start(self) -> str:
@@ -289,10 +342,7 @@ class SSEBuilder:
             yield self.stop_text_block()
 
     def close_all_blocks(self) -> Iterator[str]:
-        if self.blocks.thinking_started:
-            yield self.stop_thinking_block()
-        if self.blocks.text_started:
-            yield self.stop_text_block()
+        yield from self.close_content_blocks()
         for tool_index, state in list(self.blocks.tool_states.items()):
             if state.started:
                 yield self.stop_tool_block(tool_index)
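A quick usage sketch of the reworked builder (values illustrative; with `log_raw_events=False` the proxy now logs only event names and serialized byte counts, never raw SSE bodies):

# Illustrative only; method names are taken from the diff above.
from core.anthropic.sse import SSEBuilder, format_sse_event

sse = SSEBuilder("msg_123", "claude-sonnet", input_tokens=42, log_raw_events=False)
events = [sse.message_start(), sse.message_delta("end_turn", 1), sse.message_stop()]
ping = format_sse_event("ping", {"type": "ping"})  # stateless one-off event, no logging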
diff --git a/core/anthropic/stream_contracts.py b/core/anthropic/stream_contracts.py
index 8058392..ba2b4d8 100644
--- a/core/anthropic/stream_contracts.py
+++ b/core/anthropic/stream_contracts.py
@@ -1,7 +1,6 @@
 """Neutral SSE parsing and Anthropic stream shape assertions.
 
-Used by default CI contract tests and re-exported from ``smoke.lib.sse`` for
-opt-in smoke scenarios.
+Used by default CI contract tests and by opt-in live smoke scenarios.
 """
 
 from __future__ import annotations
@@ -11,6 +10,36 @@ from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Any
 
+from .server_tool_sse import (
+    SERVER_TOOL_USE,
+    WEB_FETCH_TOOL_RESULT,
+    WEB_SEARCH_TOOL_RESULT,
+)
+
+# Content blocks that only use content_block_start/stop (no deltas), including
+# Anthropic server tools and eager text emitted in a single start event.
+_NO_DELTA_BLOCK_KINDS = frozenset(
+    {
+        SERVER_TOOL_USE,
+        WEB_SEARCH_TOOL_RESULT,
+        WEB_FETCH_TOOL_RESULT,
+        "text_eager",
+        "redacted_thinking",
+    }
+)
+
+_ALLOWED_BLOCK_START_TYPES = frozenset(
+    {
+        "text",
+        "thinking",
+        "tool_use",
+        "redacted_thinking",
+        SERVER_TOOL_USE,
+        WEB_SEARCH_TOOL_RESULT,
+        WEB_FETCH_TOOL_RESULT,
+    }
+)
+
 
 @dataclass(frozen=True, slots=True)
 class SSEEvent:
@@ -86,35 +115,46 @@ def assert_anthropic_stream_contract(
             raise AssertionError(f"unexpected SSE error event: {event.data}")
 
         if event.event == "content_block_start":
-            index = _event_index(event)
+            index = event_index(event)
             block = event.data.get("content_block", {})
             assert isinstance(block, dict), event.data
             block_type = str(block.get("type", ""))
-            assert block_type in {"text", "thinking", "tool_use"}, event.data
+            assert block_type in _ALLOWED_BLOCK_START_TYPES, event.data
             assert index not in open_blocks, f"block {index} started twice"
             assert index not in seen_blocks, f"block {index} reused after stop"
-            open_blocks[index] = block_type
+            if block_type == "text" and str(block.get("text", "")).strip():
+                storage = "text_eager"
+            else:
+                storage = block_type
+            open_blocks[index] = storage
             seen_blocks.add(index)
             continue
 
        if event.event == "content_block_delta":
-            index = _event_index(event)
+            index = event_index(event)
             assert index in open_blocks, f"delta for unopened block {index}"
+            kind = open_blocks[index]
+            assert kind not in _NO_DELTA_BLOCK_KINDS, (
+                f"unexpected delta for start/stop-only block {kind} at index {index}"
+            )
             delta = event.data.get("delta", {})
             assert isinstance(delta, dict), event.data
             delta_type = str(delta.get("type", ""))
+            if kind == "thinking":
+                assert delta_type in (
+                    "thinking_delta",
+                    "signature_delta",
+                ), f"block {index} is {kind}, got {delta_type}"
+                continue
             expected = {
                 "text": "text_delta",
-                "thinking": "thinking_delta",
                 "tool_use": "input_json_delta",
-            }[open_blocks[index]]
-            assert delta_type == expected, (
-                f"block {index} is {open_blocks[index]}, got {delta_type}"
-            )
+            }[kind]
+            assert delta_type == expected, f"block {index} is {kind}, got {delta_type}"
             continue
 
         if event.event == "content_block_stop":
-            index = _event_index(event)
+            index = event_index(event)
             assert index in open_blocks, f"stop for unopened block {index}"
             open_blocks.pop(index)
@@ -129,6 +169,12 @@ def event_names(events: list[SSEEvent]) -> list[str]:
 def text_content(events: list[SSEEvent]) -> str:
     parts: list[str] = []
     for event in events:
+        if event.event == "content_block_start":
+            block = event.data.get("content_block", {})
+            if isinstance(block, dict) and block.get("type") == "text":
+                eager = str(block.get("text", ""))
+                if eager:
+                    parts.append(eager)
         delta = event.data.get("delta", {})
         if isinstance(delta, dict) and delta.get("type") == "text_delta":
             parts.append(str(delta.get("text", "")))
@@ -152,7 +198,8 @@ def has_tool_use(events: list[SSEEvent]) -> bool:
     return False
 
 
-def _event_index(event: SSEEvent) -> int:
+def event_index(event: SSEEvent) -> int:
+    """Return the content block ``index`` field from an SSE payload (strict)."""
     value = event.data.get("index")
     assert isinstance(value, int), event.data
     return value
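An illustrative contract check over a hand-built stream. The SSEEvent construction is shown keyword-style as an assumption (only its `.event` / `.data` attributes are relied on above), and the checker may enforce message-level events beyond this sketch:

# Sketch only: a minimal text block that should satisfy the block-level rules.
from core.anthropic.stream_contracts import SSEEvent, assert_anthropic_stream_contract

events = [
    SSEEvent(event="content_block_start",
             data={"index": 0, "content_block": {"type": "text", "text": ""}}),
    SSEEvent(event="content_block_delta",
             data={"index": 0, "delta": {"type": "text_delta", "text": "hi"}}),
    SSEEvent(event="content_block_stop", data={"index": 0}),
]
assert_anthropic_stream_contract(events)  # raises AssertionError on shape violations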
"web_search_tool_result", + "web_fetch_tool_result", + ): + if hasattr(block, "model_dump"): + blob: object = block.model_dump() + else: + blob = block + try: + total_tokens += len( + ENCODER.encode( + json.dumps(blob, default=str, ensure_ascii=False) + ) + ) + except (TypeError, ValueError, OverflowError) as e: + logger.debug( + "Block encode fallback b_type={} err={}", b_type, e + ) + total_tokens += len(ENCODER.encode(str(blob))) + total_tokens += 12 else: logger.debug( "Unexpected block type %r, falling back to json/str encoding", diff --git a/core/rate_limit.py b/core/rate_limit.py new file mode 100644 index 0000000..8d9bd81 --- /dev/null +++ b/core/rate_limit.py @@ -0,0 +1,60 @@ +"""Shared strict sliding-window rate limiting primitives.""" + +from __future__ import annotations + +import asyncio +import time +from collections import deque + + +class StrictSlidingWindowLimiter: + """Strict sliding window limiter. + + Guarantees: at most ``rate_limit`` acquisitions in any interval of length + ``rate_window`` (seconds). + + Implemented as an async context manager so call sites can do:: + + async with limiter: + ... + """ + + def __init__(self, rate_limit: int, rate_window: float) -> None: + if rate_limit <= 0: + raise ValueError("rate_limit must be > 0") + if rate_window <= 0: + raise ValueError("rate_window must be > 0") + + self._rate_limit = int(rate_limit) + self._rate_window = float(rate_window) + self._times: deque[float] = deque() + self._lock = asyncio.Lock() + + async def acquire(self) -> None: + while True: + wait_time = 0.0 + async with self._lock: + now = time.monotonic() + cutoff = now - self._rate_window + + while self._times and self._times[0] <= cutoff: + self._times.popleft() + + if len(self._times) < self._rate_limit: + self._times.append(now) + return + + oldest = self._times[0] + wait_time = max(0.0, (oldest + self._rate_window) - now) + + if wait_time > 0: + await asyncio.sleep(wait_time) + else: + await asyncio.sleep(0) + + async def __aenter__(self) -> StrictSlidingWindowLimiter: + await self.acquire() + return self + + async def __aexit__(self, exc_type, exc, tb) -> bool: + return False diff --git a/messaging/cli_event_constants.py b/messaging/cli_event_constants.py new file mode 100644 index 0000000..3d531ec --- /dev/null +++ b/messaging/cli_event_constants.py @@ -0,0 +1,67 @@ +"""CLI event types and status-line mapping for transcript / UI updates.""" + +from collections.abc import Callable +from typing import Any + +# Status message prefixes used to filter our own messages (ignore echo) +STATUS_MESSAGE_PREFIXES = ( + "⏳", + "💭", + "🔧", + "✅", + "❌", + "🚀", + "🤖", + "📋", + "📊", + "🔄", +) + +# Event types that update the transcript (frozenset for O(1) membership) +TRANSCRIPT_EVENT_TYPES = frozenset( + { + "thinking_start", + "thinking_delta", + "thinking_chunk", + "thinking_stop", + "text_start", + "text_delta", + "text_chunk", + "text_stop", + "tool_use_start", + "tool_use_delta", + "tool_use_stop", + "tool_use", + "tool_result", + "block_stop", + "error", + } +) + +# Event type -> (emoji, label) for status updates (O(1) lookup) +_EVENT_STATUS_MAP: dict[str, tuple[str, str]] = { + "thinking_start": ("🧠", "Claude is thinking..."), + "thinking_delta": ("🧠", "Claude is thinking..."), + "thinking_chunk": ("🧠", "Claude is thinking..."), + "text_start": ("🧠", "Claude is working..."), + "text_delta": ("🧠", "Claude is working..."), + "text_chunk": ("🧠", "Claude is working..."), + "tool_result": ("⏳", "Executing tools..."), +} + + +def get_status_for_event( + ptype: str, 
+    ptype: str,
+    parsed: dict[str, Any],
+    format_status_fn: Callable[..., str],
+) -> str | None:
+    """Return status string for event type, or None if no status update needed."""
+    entry = _EVENT_STATUS_MAP.get(ptype)
+    if entry is not None:
+        emoji, label = entry
+        return format_status_fn(emoji, label)
+    if ptype in ("tool_use_start", "tool_use_delta", "tool_use"):
+        if parsed.get("name") == "Task":
+            return format_status_fn("🤖", "Subagent working...")
+        return format_status_fn("⏳", "Executing tools...")
+    return None
diff --git a/messaging/event_parser.py b/messaging/event_parser.py
index 2499248..87eb82a 100644
--- a/messaging/event_parser.py
+++ b/messaging/event_parser.py
@@ -9,12 +9,14 @@ from typing import Any
 from loguru import logger
 
 
-def parse_cli_event(event: Any) -> list[dict]:
+def parse_cli_event(event: Any, *, log_raw_cli: bool = False) -> list[dict]:
     """
     Parse a CLI event and return a structured result.
 
     Args:
         event: Raw event dictionary from CLI
+        log_raw_cli: When True, log full error text from the CLI. Default is
+            metadata-only (lengths / exit codes) to avoid leaking user content.
 
     Returns:
         List of parsed event dicts. Empty list if not recognized.
@@ -140,7 +142,11 @@ def parse_cli_event(event: Any) -> list[dict]:
     if etype == "error":
         err = event.get("error")
         msg = err.get("message") if isinstance(err, dict) else str(err)
-        logger.info(f"CLI_PARSER: Parsed error event: {msg}")
+        if log_raw_cli:
+            logger.info("CLI_PARSER: Parsed error event: {}", msg)
+        else:
+            mlen = len(msg) if isinstance(msg, str) else 0
+            logger.info("CLI_PARSER: Parsed error event: message_chars={}", mlen)
         return [{"type": "error", "message": msg}]
     elif etype == "exit":
         code = event.get("code", 0)
@@ -151,7 +157,19 @@ def parse_cli_event(event: Any) -> list[dict]:
         else:
             # Non-zero exit is an error
             error_msg = stderr if stderr else f"Process exited with code {code}"
-            logger.warning(f"CLI_PARSER: Error exit (code={code}): {error_msg}")
+            if log_raw_cli:
+                logger.warning(
+                    "CLI_PARSER: Error exit (code={}): {}",
+                    code,
+                    error_msg,
+                )
+            else:
+                em = error_msg if isinstance(error_msg, str) else str(error_msg)
+                logger.warning(
+                    "CLI_PARSER: Error exit (code={}): message_chars={}",
+                    code,
+                    len(em),
+                )
             return [
                 {"type": "error", "message": error_msg},
                 {"type": "complete", "status": "failed"},
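Behavior of the extracted status mapping, shown with a trivial formatter (grounded in the table and branches above):

from messaging.cli_event_constants import get_status_for_event

def fmt(emoji: str, label: str) -> str:
    return f"{emoji} {label}"

assert get_status_for_event("thinking_delta", {}, fmt) == "🧠 Claude is thinking..."
assert get_status_for_event("tool_use_start", {"name": "Task"}, fmt) == "🤖 Subagent working..."
assert get_status_for_event("session_info", {}, fmt) is None  # no status update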
""" import asyncio -import os -import time from loguru import logger -from core.anthropic import get_user_facing_error_message +from core.anthropic import format_user_error_preview, get_user_facing_error_message +from .cli_event_constants import STATUS_MESSAGE_PREFIXES from .command_dispatcher import ( dispatch_command, message_kind_for_command, @@ -21,8 +20,10 @@ from .command_dispatcher import ( ) from .event_parser import parse_cli_event from .models import IncomingMessage +from .node_event_pipeline import handle_session_info_event, process_parsed_cli_event from .platforms.base import MessagingPlatform, SessionManagerInterface from .rendering.profiles import build_rendering_profile +from .safe_diagnostics import format_exception_for_log from .session import SessionStore from .transcript import RenderCtx, TranscriptBuffer from .trees.queue_manager import ( @@ -31,54 +32,7 @@ from .trees.queue_manager import ( MessageTree, TreeQueueManager, ) - -# Status message prefixes used to filter our own messages (ignore echo) -STATUS_MESSAGE_PREFIXES = ("⏳", "💭", "🔧", "✅", "❌", "🚀", "🤖", "📋", "📊", "🔄") - -# Event types that update the transcript (frozenset for O(1) membership) -TRANSCRIPT_EVENT_TYPES = frozenset( - { - "thinking_start", - "thinking_delta", - "thinking_chunk", - "thinking_stop", - "text_start", - "text_delta", - "text_chunk", - "text_stop", - "tool_use_start", - "tool_use_delta", - "tool_use_stop", - "tool_use", - "tool_result", - "block_stop", - "error", - } -) - -# Event type -> (emoji, label) for status updates (O(1) lookup) -_EVENT_STATUS_MAP = { - "thinking_start": ("🧠", "Claude is thinking..."), - "thinking_delta": ("🧠", "Claude is thinking..."), - "thinking_chunk": ("🧠", "Claude is thinking..."), - "text_start": ("🧠", "Claude is working..."), - "text_delta": ("🧠", "Claude is working..."), - "text_chunk": ("🧠", "Claude is working..."), - "tool_result": ("⏳", "Executing tools..."), -} - - -def _get_status_for_event(ptype: str, parsed: dict, format_status_fn) -> str | None: - """Return status string for event type, or None if no status update needed.""" - entry = _EVENT_STATUS_MAP.get(ptype) - if entry is not None: - emoji, label = entry - return format_status_fn(emoji, label) - if ptype in ("tool_use_start", "tool_use_delta", "tool_use"): - if parsed.get("name") == "Task": - return format_status_fn("🤖", "Subagent working...") - return format_status_fn("⏳", "Executing tools...") - return None +from .ui_updates import ThrottledTranscriptEditor class ClaudeMessageHandler: @@ -97,10 +51,21 @@ class ClaudeMessageHandler: platform: MessagingPlatform, cli_manager: SessionManagerInterface, session_store: SessionStore, + *, + debug_platform_edits: bool = False, + debug_subagent_stack: bool = False, + log_raw_messaging_content: bool = False, + log_raw_cli_diagnostics: bool = False, + log_messaging_error_details: bool = False, ): self.platform = platform self.cli_manager = cli_manager self.session_store = session_store + self._debug_platform_edits = debug_platform_edits + self._debug_subagent_stack = debug_subagent_stack + self._log_raw_messaging_content = log_raw_messaging_content + self._log_raw_cli_diagnostics = log_raw_cli_diagnostics + self._log_messaging_error_details = log_messaging_error_details self._tree_queue = TreeQueueManager( queue_update_callback=self.update_queue_positions, node_started_callback=self.mark_node_processing, @@ -137,16 +102,26 @@ class ClaudeMessageHandler: Determines if this is a new conversation or reply, creates/extends the message tree, and queues for 
processing. """ - text_preview = (incoming.text or "")[:80] - if len(incoming.text or "") > 80: - text_preview += "..." - logger.info( - "HANDLER_ENTRY: chat_id={} message_id={} reply_to={} text_preview={!r}", - incoming.chat_id, - incoming.message_id, - incoming.reply_to_message_id, - text_preview, - ) + raw = incoming.text or "" + if self._log_raw_messaging_content: + text_preview = raw[:80] + if len(raw) > 80: + text_preview += "..." + logger.info( + "HANDLER_ENTRY: chat_id={} message_id={} reply_to={} text_preview={!r}", + incoming.chat_id, + incoming.message_id, + incoming.reply_to_message_id, + text_preview, + ) + else: + logger.info( + "HANDLER_ENTRY: chat_id={} message_id={} reply_to={} text_len={}", + incoming.chat_id, + incoming.message_id, + incoming.reply_to_message_id, + len(raw), + ) with logger.contextualize( chat_id=incoming.chat_id, node_id=incoming.message_id @@ -169,7 +144,12 @@ class ClaudeMessageHandler: kind=message_kind_for_command(cmd_base), ) except Exception as e: - logger.debug(f"Failed to record incoming message_id: {e}") + logger.debug( + "Failed to record incoming message_id: {}", + format_exception_for_log( + e, log_full_message=self._log_messaging_error_details + ), + ) if await dispatch_command(self, incoming, cmd_base): return @@ -276,7 +256,12 @@ class ClaudeMessageHandler: try: queued_ids = await tree.get_queue_snapshot() except Exception as e: - logger.warning(f"Failed to read queue snapshot: {e}") + logger.warning( + "Failed to read queue snapshot: {}", + format_exception_for_log( + e, log_full_message=self._log_messaging_error_details + ), + ) return if not queued_ids: @@ -317,85 +302,11 @@ class ClaudeMessageHandler: self, ) -> tuple[TranscriptBuffer, RenderCtx]: """Create transcript buffer and render context for node processing.""" - transcript = TranscriptBuffer(show_tool_results=False) - return transcript, self.get_render_ctx() - - async def _handle_session_info_event( - self, - event_data: dict, - tree: MessageTree | None, - node_id: str, - captured_session_id: str | None, - temp_session_id: str | None, - ) -> tuple[str | None, str | None]: - """Handle session_info event; return updated (captured_session_id, temp_session_id).""" - if event_data.get("type") != "session_info": - return captured_session_id, temp_session_id - - real_session_id = event_data.get("session_id") - if not real_session_id or not temp_session_id: - return captured_session_id, temp_session_id - - await self.cli_manager.register_real_session_id( - temp_session_id, real_session_id + transcript = TranscriptBuffer( + show_tool_results=False, + debug_subagent_stack=self._debug_subagent_stack, ) - if tree and real_session_id: - await tree.update_state( - node_id, - MessageState.IN_PROGRESS, - session_id=real_session_id, - ) - self.session_store.save_tree(tree.root_id, tree.to_dict()) - - return real_session_id, None - - async def _process_parsed_event( - self, - parsed: dict, - transcript: TranscriptBuffer, - update_ui, - last_status: str | None, - had_transcript_events: bool, - tree: MessageTree | None, - node_id: str, - captured_session_id: str | None, - ) -> tuple[str | None, bool]: - """Process a single parsed CLI event. 
Returns (last_status, had_transcript_events).""" - ptype = parsed.get("type") or "" - - if ptype in TRANSCRIPT_EVENT_TYPES: - transcript.apply(parsed) - had_transcript_events = True - - status = _get_status_for_event(ptype, parsed, self.format_status) - if status is not None: - await update_ui(status) - last_status = status - elif ptype == "block_stop": - await update_ui(last_status, force=True) - elif ptype == "complete": - if not had_transcript_events: - transcript.apply({"type": "text_chunk", "text": "Done."}) - logger.info("HANDLER: Task complete, updating UI") - await update_ui(self.format_status("✅", "Complete"), force=True) - if tree and captured_session_id: - await tree.update_state( - node_id, - MessageState.COMPLETED, - session_id=captured_session_id, - ) - self.session_store.save_tree(tree.root_id, tree.to_dict()) - elif ptype == "error": - error_msg = parsed.get("message", "Unknown error") - logger.error(f"HANDLER: Error event received: {error_msg}") - logger.info("HANDLER: Updating UI with error status") - await update_ui(self.format_status("❌", "Error"), force=True) - if tree: - await self._propagate_error_to_children( - node_id, error_msg, "Parent task failed" - ) - - return last_status, had_transcript_events + return transcript, self.get_render_ctx() async def _process_node( self, @@ -426,8 +337,6 @@ class ClaudeMessageHandler: transcript, render_ctx = self._create_transcript_and_render_ctx() - last_ui_update = 0.0 - last_displayed_text = None had_transcript_events = False captured_session_id = None temp_session_id = None @@ -439,52 +348,21 @@ class ClaudeMessageHandler: if parent_session_id: logger.info(f"Will fork from parent session: {parent_session_id}") - async def update_ui(status: str | None = None, force: bool = False) -> None: - nonlocal last_ui_update, last_displayed_text, last_status - now = time.time() - if not force and now - last_ui_update < 1.0: - return + editor = ThrottledTranscriptEditor( + platform=self.platform, + parse_mode=self._parse_mode(), + get_limit_chars=self._get_limit_chars, + transcript=transcript, + render_ctx=render_ctx, + node_id=node_id, + chat_id=chat_id, + status_msg_id=status_msg_id, + debug_platform_edits=self._debug_platform_edits, + log_messaging_error_details=self._log_messaging_error_details, + ) - last_ui_update = now - if status is not None: - last_status = status - try: - display = transcript.render( - render_ctx, - limit_chars=self._get_limit_chars(), - status=status, - ) - except Exception as e: - logger.warning(f"Transcript render failed for node {node_id}: {e}") - return - if display and display != last_displayed_text: - logger.debug( - "PLATFORM_EDIT: node_id={} chat_id={} msg_id={} force={} status={!r} chars={}", - node_id, - chat_id, - status_msg_id, - bool(force), - status, - len(display), - ) - if os.getenv("DEBUG_PLATFORM_EDITS") == "1": - logger.debug("PLATFORM_EDIT_TEXT:\n{}", display) - else: - head = display[:500] - tail = display[-500:] if len(display) > 500 else "" - logger.debug("PLATFORM_EDIT_PREVIEW_HEAD:\n{}", head) - if tail: - logger.debug("PLATFORM_EDIT_PREVIEW_TAIL:\n{}", tail) - last_displayed_text = display - try: - await self.platform.queue_edit_message( - chat_id, - status_msg_id, - display, - parse_mode=self._parse_mode(), - ) - except Exception as e: - logger.warning(f"Failed to update platform for node {node_id}: {e}") + async def update_ui(status: str | None = None, force: bool = False) -> None: + await editor.update(status, force=force) try: try: @@ -531,20 +409,28 @@ class ClaudeMessageHandler: 
( captured_session_id, temp_session_id, - ) = await self._handle_session_info_event( - event_data, tree, node_id, captured_session_id, temp_session_id + ) = await handle_session_info_event( + event_data, + tree, + node_id, + captured_session_id, + temp_session_id, + cli_manager=self.cli_manager, + session_store=self.session_store, ) if event_data.get("type") == "session_info": continue - parsed_list = parse_cli_event(event_data) + parsed_list = parse_cli_event( + event_data, log_raw_cli=self._log_raw_cli_diagnostics + ) logger.debug(f"HANDLER: Parsed {len(parsed_list)} events from CLI") for parsed in parsed_list: ( last_status, had_transcript_events, - ) = await self._process_parsed_event( + ) = await process_parsed_cli_event( parsed, transcript, update_ui, @@ -553,6 +439,10 @@ class ClaudeMessageHandler: tree, node_id, captured_session_id, + session_store=self.session_store, + format_status=self.format_status, + propagate_error_to_children=self._propagate_error_to_children, + log_messaging_error_details=self._log_messaging_error_details, ) except asyncio.CancelledError: @@ -575,9 +465,12 @@ class ClaudeMessageHandler: ) except Exception as e: logger.error( - f"HANDLER: Task failed with exception: {type(e).__name__}: {e}" + "HANDLER: Task failed with exception: {}", + format_exception_for_log( + e, log_full_message=self._log_messaging_error_details + ), ) - error_msg = get_user_facing_error_message(e)[:200] + error_msg = format_user_error_preview(e) transcript.apply({"type": "error", "message": error_msg}) await update_ui(self.format_status("💥", "Task Failed"), force=True) if tree: @@ -595,7 +488,13 @@ class ClaudeMessageHandler: elif temp_session_id: await self.cli_manager.remove_session(temp_session_id) except Exception as e: - logger.debug(f"Failed to remove session for node {node_id}: {e}") + logger.debug( + "Failed to remove session for node {}: {}", + node_id, + format_exception_for_log( + e, log_full_message=self._log_messaging_error_details + ), + ) async def _propagate_error_to_children( self, @@ -691,7 +590,12 @@ class ClaudeMessageHandler: platform, chat_id, str(msg_id), direction="out", kind=kind ) except Exception as e: - logger.debug(f"Failed to record message_id: {e}") + logger.debug( + "Failed to record message_id: {}", + format_exception_for_log( + e, log_full_message=self._log_messaging_error_details + ), + ) def update_cancelled_nodes_ui(self, nodes: list[MessageNode]) -> None: """Update status messages and persist tree state for cancelled nodes.""" diff --git a/messaging/limiter.py b/messaging/limiter.py index 6367e9e..648d4a8 100644 --- a/messaging/limiter.py +++ b/messaging/limiter.py @@ -6,65 +6,16 @@ using a strict sliding window algorithm and a task queue. """ import asyncio -import os -import time from collections import deque from collections.abc import Awaitable, Callable from typing import Any from loguru import logger +from config.settings import get_settings +from core.rate_limit import StrictSlidingWindowLimiter as SlidingWindowLimiter -class SlidingWindowLimiter: - """Strict sliding window limiter. - - Guarantees: at most `rate_limit` acquisitions in any interval of length - `rate_window` (seconds). - - Implemented as an async context manager so call sites can do: - async with limiter: - ... 
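A sketch of threading the new constructor flags from settings at the composition root. The settings attribute names mirror the new .env keys and are an assumption here; `platform`, `cli_manager`, and `session_store` are placeholders built elsewhere:

from config.settings import get_settings
from messaging.handler import ClaudeMessageHandler

settings = get_settings()
handler = ClaudeMessageHandler(
    platform=platform,            # placeholder: constructed by the runtime
    cli_manager=cli_manager,      # placeholder
    session_store=session_store,  # placeholder
    debug_platform_edits=settings.debug_platform_edits,    # assumed field name
    debug_subagent_stack=settings.debug_subagent_stack,    # assumed field name
    log_raw_messaging_content=settings.log_raw_messaging_content,
    log_raw_cli_diagnostics=settings.log_raw_cli_diagnostics,
    log_messaging_error_details=settings.log_messaging_error_details,
)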
- """ - - def __init__(self, rate_limit: int, rate_window: float) -> None: - if rate_limit <= 0: - raise ValueError("rate_limit must be > 0") - if rate_window <= 0: - raise ValueError("rate_window must be > 0") - - self._rate_limit = int(rate_limit) - self._rate_window = float(rate_window) - self._times: deque[float] = deque() - self._lock = asyncio.Lock() - - async def acquire(self) -> None: - while True: - wait_time = 0.0 - async with self._lock: - now = time.monotonic() - cutoff = now - self._rate_window - - while self._times and self._times[0] <= cutoff: - self._times.popleft() - - if len(self._times) < self._rate_limit: - self._times.append(now) - return - - oldest = self._times[0] - wait_time = max(0.0, (oldest + self._rate_window) - now) - - if wait_time > 0: - await asyncio.sleep(wait_time) - else: - await asyncio.sleep(0) - - async def __aenter__(self) -> SlidingWindowLimiter: - await self.acquire() - return self - - async def __aexit__(self, exc_type, exc, tb) -> bool: - return False +from .safe_diagnostics import format_exception_for_log class MessagingRateLimiter: @@ -82,23 +33,29 @@ class MessagingRateLimiter: return super().__new__(cls) @classmethod - async def get_instance(cls) -> MessagingRateLimiter: - """Get the singleton instance of the limiter.""" + async def get_instance( + cls, + *, + rate_limit: int = 1, + rate_window: float = 1.0, + ) -> MessagingRateLimiter: + """Get the singleton instance of the limiter. + + ``rate_limit`` and ``rate_window`` apply only when the singleton is first + created. Call :meth:`shutdown_instance` before changing parameters. + """ async with cls._lock: if cls._instance is None: - cls._instance = cls() + cls._instance = cls(rate_limit=rate_limit, rate_window=rate_window) # Start the background worker (tracked for graceful shutdown). 
cls._instance._start_worker() return cls._instance - def __init__(self): + def __init__(self, *, rate_limit: int, rate_window: float) -> None: # Prevent double initialization in singleton if hasattr(self, "_initialized"): return - rate_limit = int(os.getenv("MESSAGING_RATE_LIMIT", "1")) - rate_window = float(os.getenv("MESSAGING_RATE_WINDOW", "2.0")) - self.limiter = SlidingWindowLimiter(rate_limit, rate_window) # Custom queue state - using deque for O(1) popleft self._queue_list: deque[str] = deque() # Deque of dedup_keys in order @@ -189,15 +146,27 @@ class MessagingRateLimiter: asyncio.get_event_loop().time() + wait_secs ) else: + d = get_settings().log_messaging_error_details logger.error( - f"Error in limiter worker for key {dedup_key}: {type(e).__name__}: {e}" + "Error in limiter worker for key {}: {}", + dedup_key, + format_exception_for_log(e, log_full_message=d), ) except asyncio.CancelledError: break except Exception as e: - logger.error( - f"MessagingRateLimiter worker critical error: {e}", exc_info=True - ) + d = get_settings().log_messaging_error_details + if d: + logger.error( + "MessagingRateLimiter worker critical error: {}", + e, + exc_info=True, + ) + else: + logger.error( + "MessagingRateLimiter worker critical error: exc_type={}", + type(e).__name__, + ) await asyncio.sleep(1) async def shutdown(self, timeout: float = 2.0) -> None: @@ -223,7 +192,11 @@ class MessagingRateLimiter: except asyncio.CancelledError: pass except Exception as e: - logger.debug(f"MessagingRateLimiter worker shutdown error: {e}") + d = get_settings().log_messaging_error_details + logger.debug( + "MessagingRateLimiter worker shutdown error: {}", + format_exception_for_log(e, log_full_message=d), + ) finally: self._worker_task = None @@ -296,14 +269,29 @@ class MessagingRateLimiter: x in error_msg for x in ["connect", "timeout", "broken"] ): wait = 2**attempt - logger.warning( - f"Limiter fire_and_forget transient error (attempt {attempt + 1}): {e}. Retrying in {wait}s..." - ) + d = get_settings().log_messaging_error_details + if d: + logger.warning( + "Limiter fire_and_forget transient error (attempt {}): {}. Retrying in {}s...", + attempt + 1, + e, + wait, + ) + else: + logger.warning( + "Limiter fire_and_forget transient error (attempt {}): exc_type={}. 
Retrying in {}s...", + attempt + 1, + type(e).__name__, + wait, + ) await asyncio.sleep(wait) continue + d = get_settings().log_messaging_error_details logger.error( - f"Final error in fire_and_forget for key {dedup_key}: {type(e).__name__}: {e}" + "Final error in fire_and_forget for key {}: {}", + dedup_key, + format_exception_for_log(e, log_full_message=d), ) if not future.done(): future.set_exception(e) diff --git a/messaging/node_event_pipeline.py b/messaging/node_event_pipeline.py new file mode 100644 index 0000000..6f233e2 --- /dev/null +++ b/messaging/node_event_pipeline.py @@ -0,0 +1,103 @@ +"""CLI event handling for a single queued node (transcript + session + errors).""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import Any + +from loguru import logger + +from .cli_event_constants import TRANSCRIPT_EVENT_TYPES, get_status_for_event +from .platforms.base import SessionManagerInterface +from .safe_diagnostics import text_len_hint +from .session import SessionStore +from .transcript import TranscriptBuffer +from .trees.queue_manager import MessageState, MessageTree + + +async def handle_session_info_event( + event_data: dict[str, Any], + tree: MessageTree | None, + node_id: str, + captured_session_id: str | None, + temp_session_id: str | None, + *, + cli_manager: SessionManagerInterface, + session_store: SessionStore, +) -> tuple[str | None, str | None]: + """Handle session_info event; return updated (captured_session_id, temp_session_id).""" + if event_data.get("type") != "session_info": + return captured_session_id, temp_session_id + + real_session_id = event_data.get("session_id") + if not real_session_id or not temp_session_id: + return captured_session_id, temp_session_id + + await cli_manager.register_real_session_id(temp_session_id, real_session_id) + if tree and real_session_id: + await tree.update_state( + node_id, + MessageState.IN_PROGRESS, + session_id=real_session_id, + ) + session_store.save_tree(tree.root_id, tree.to_dict()) + + return real_session_id, None + + +async def process_parsed_cli_event( + parsed: dict[str, Any], + transcript: TranscriptBuffer, + update_ui: Callable[..., Awaitable[None]], + last_status: str | None, + had_transcript_events: bool, + tree: MessageTree | None, + node_id: str, + captured_session_id: str | None, + *, + session_store: SessionStore, + format_status: Callable[..., str], + propagate_error_to_children: Callable[[str, str, str], Awaitable[None]], + log_messaging_error_details: bool = False, +) -> tuple[str | None, bool]: + """Process a single parsed CLI event. 
+    """Process a single parsed CLI event. Returns (last_status, had_transcript_events)."""
+    ptype = parsed.get("type") or ""
+
+    if ptype in TRANSCRIPT_EVENT_TYPES:
+        transcript.apply(parsed)
+        had_transcript_events = True
+
+    status = get_status_for_event(ptype, parsed, format_status)
+    if status is not None:
+        await update_ui(status)
+        last_status = status
+    elif ptype == "block_stop":
+        await update_ui(last_status, force=True)
+    elif ptype == "complete":
+        if not had_transcript_events:
+            transcript.apply({"type": "text_chunk", "text": "Done."})
+        logger.info("HANDLER: Task complete, updating UI")
+        await update_ui(format_status("✅", "Complete"), force=True)
+        if tree and captured_session_id:
+            await tree.update_state(
+                node_id,
+                MessageState.COMPLETED,
+                session_id=captured_session_id,
+            )
+            session_store.save_tree(tree.root_id, tree.to_dict())
+    elif ptype == "error":
+        error_msg = parsed.get("message", "Unknown error")
+        if log_messaging_error_details:
+            logger.error("HANDLER: Error event received: {}", error_msg)
+        else:
+            em = error_msg if isinstance(error_msg, str) else str(error_msg)
+            logger.error(
+                "HANDLER: Error event received: message_chars={}",
+                text_len_hint(em),
+            )
+        logger.info("HANDLER: Updating UI with error status")
+        await update_ui(format_status("❌", "Error"), force=True)
+        if tree:
+            await propagate_error_to_children(node_id, error_msg, "Parent task failed")
+
+    return last_status, had_transcript_events
diff --git a/messaging/platforms/discord.py b/messaging/platforms/discord.py
index 201682b..53aa7dc 100644
--- a/messaging/platforms/discord.py
+++ b/messaging/platforms/discord.py
@@ -6,7 +6,6 @@ Implements MessagingPlatform for Discord using discord.py.
 
 import asyncio
 import contextlib
-import os
 import tempfile
 from collections.abc import Awaitable, Callable
 from pathlib import Path
@@ -14,7 +13,7 @@ from typing import Any, cast
 
 from loguru import logger
 
-from core.anthropic import get_user_facing_error_message
+from core.anthropic import format_user_error_preview
 
 from ..models import IncomingMessage
 from ..rendering.discord_markdown import format_status_discord
@@ -97,15 +96,18 @@ class DiscordPlatform(MessagingPlatform):
         whisper_device: str = "cpu",
         hf_token: str = "",
         nvidia_nim_api_key: str = "",
+        messaging_rate_limit: int = 1,
+        messaging_rate_window: float = 1.0,
+        log_raw_messaging_content: bool = False,
+        log_api_error_tracebacks: bool = False,
     ):
         if not DISCORD_AVAILABLE:
             raise ImportError(
                 "discord.py is required. Install with: pip install discord.py"
             )
 
-        self.bot_token = bot_token or os.getenv("DISCORD_BOT_TOKEN")
-        raw_channels = allowed_channel_ids or os.getenv("ALLOWED_DISCORD_CHANNELS")
-        self.allowed_channel_ids = _parse_allowed_channels(raw_channels)
+        self.bot_token = bot_token
+        self.allowed_channel_ids = _parse_allowed_channels(allowed_channel_ids)
 
         if not self.bot_token:
             logger.warning("DISCORD_BOT_TOKEN not set")
@@ -130,6 +132,10 @@ class DiscordPlatform(MessagingPlatform):
         self._voice_note_enabled = voice_note_enabled
         self._whisper_model = whisper_model
         self._whisper_device = whisper_device
+        self._messaging_rate_limit = messaging_rate_limit
+        self._messaging_rate_window = messaging_rate_window
+        self._log_raw_messaging_content = log_raw_messaging_content
+        self._log_api_error_tracebacks = log_api_error_tracebacks
 
     async def _handle_client_message(self, message: Any) -> None:
         """Adapter entry point used by the internal discord client."""
@@ -234,23 +240,40 @@ class DiscordPlatform(MessagingPlatform):
                 status_message_id=status_msg_id,
             )
 
-            logger.info(
-                "DISCORD_VOICE: chat_id={} message_id={} transcribed={!r}",
-                channel_id,
-                message_id,
-                (transcribed[:80] + "..." if len(transcribed) > 80 else transcribed),
-            )
+            if self._log_raw_messaging_content:
+                logger.info(
+                    "DISCORD_VOICE: chat_id={} message_id={} transcribed={!r}",
+                    channel_id,
+                    message_id,
+                    (
+                        transcribed[:80] + "..."
+                        if len(transcribed) > 80
+                        else transcribed
+                    ),
+                )
+            else:
+                logger.info(
+                    "DISCORD_VOICE: chat_id={} message_id={} transcribed_len={}",
+                    channel_id,
+                    message_id,
+                    len(transcribed),
+                )
 
             await self._message_handler(incoming)
             return True
         except ValueError as e:
-            await message.reply(get_user_facing_error_message(e)[:200])
+            await message.reply(format_user_error_preview(e))
             return True
         except ImportError as e:
-            await message.reply(get_user_facing_error_message(e)[:200])
+            await message.reply(format_user_error_preview(e))
             return True
         except Exception as e:
-            logger.error(f"Voice transcription failed: {e}")
+            if self._log_api_error_tracebacks:
+                logger.error("Voice transcription failed: {}", e)
+            else:
+                logger.error(
+                    "Voice transcription failed: exc_type={}", type(e).__name__
+                )
             await message.reply(
                 "Could not transcribe voice note. Please try again or send text."
             )
@@ -285,16 +308,26 @@ class DiscordPlatform(MessagingPlatform):
             else None
         )
 
-        text_preview = (message.content or "")[:80]
-        if len(message.content or "") > 80:
-            text_preview += "..."
-        logger.info(
-            "DISCORD_MSG: chat_id={} message_id={} reply_to={} text_preview={!r}",
-            channel_id,
-            message_id,
-            reply_to,
-            text_preview,
-        )
+        raw_content = message.content or ""
+        if self._log_raw_messaging_content:
+            text_preview = raw_content[:80]
+            if len(raw_content) > 80:
+                text_preview += "..."
+            logger.info(
+                "DISCORD_MSG: chat_id={} message_id={} reply_to={} text_preview={!r}",
+                channel_id,
+                message_id,
+                reply_to,
+                text_preview,
+            )
+        else:
+            logger.info(
+                "DISCORD_MSG: chat_id={} message_id={} reply_to={} text_len={}",
+                channel_id,
+                message_id,
+                reply_to,
+                len(raw_content),
+            )
 
         if not self._message_handler:
             return
@@ -313,13 +346,14 @@ class DiscordPlatform(MessagingPlatform):
         try:
             await self._message_handler(incoming)
         except Exception as e:
-            logger.error(f"Error handling message: {e}")
+            if self._log_api_error_tracebacks:
+                logger.error("Error handling message: {}", e)
+            else:
+                logger.error("Error handling message: exc_type={}", type(e).__name__)
             with contextlib.suppress(Exception):
                 await self.send_message(
                     channel_id,
-                    format_status_discord(
-                        "Error:", get_user_facing_error_message(e)[:200]
-                    ),
+                    format_status_discord("Error:", format_user_error_preview(e)),
                     reply_to=message_id,
                 )
 
@@ -336,7 +370,10 @@ class DiscordPlatform(MessagingPlatform):
 
         from ..limiter import MessagingRateLimiter
 
-        self._limiter = await MessagingRateLimiter.get_instance()
+        self._limiter = await MessagingRateLimiter.get_instance(
+            rate_limit=self._messaging_rate_limit,
+            rate_window=self._messaging_rate_window,
+        )
 
         self._start_task = asyncio.create_task(
             self._client.start(self.bot_token),
diff --git a/messaging/platforms/factory.py b/messaging/platforms/factory.py
index b40fe87..772a31d 100644
--- a/messaging/platforms/factory.py
+++ b/messaging/platforms/factory.py
@@ -28,6 +28,10 @@ class MessagingPlatformOptions:
     whisper_device: str = "cpu"
     hf_token: str = ""
     nvidia_nim_api_key: str = ""
+    messaging_rate_limit: int = 1
+    messaging_rate_window: float = 1.0
+    log_raw_messaging_content: bool = False
+    log_api_error_tracebacks: bool = False
 
 
 def create_messaging_platform(
@@ -64,6 +68,10 @@ def create_messaging_platform(
             whisper_device=opts.whisper_device,
             hf_token=opts.hf_token,
             nvidia_nim_api_key=opts.nvidia_nim_api_key,
+            messaging_rate_limit=opts.messaging_rate_limit,
+            messaging_rate_window=opts.messaging_rate_window,
+            log_raw_messaging_content=opts.log_raw_messaging_content,
+            log_api_error_tracebacks=opts.log_api_error_tracebacks,
         )
 
     if platform_type == "discord":
@@ -82,6 +90,10 @@ def create_messaging_platform(
             whisper_device=opts.whisper_device,
             hf_token=opts.hf_token,
             nvidia_nim_api_key=opts.nvidia_nim_api_key,
+            messaging_rate_limit=opts.messaging_rate_limit,
+            messaging_rate_window=opts.messaging_rate_window,
+            log_raw_messaging_content=opts.log_raw_messaging_content,
+            log_api_error_tracebacks=opts.log_api_error_tracebacks,
         )
 
     logger.warning(
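With the factory change, the new knobs travel on the options dataclass instead of being read from os.environ inside each platform. A sketch (the factory call shape beyond `opts` is an assumption):

from messaging.platforms.factory import (
    MessagingPlatformOptions,
    create_messaging_platform,
)

opts = MessagingPlatformOptions(
    messaging_rate_limit=1,
    messaging_rate_window=2.0,
    log_raw_messaging_content=False,  # metadata-only message logging
    log_api_error_tracebacks=False,   # exception types only
)
platform = create_messaging_platform("telegram", opts)  # positional shape assumed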
diff --git a/messaging/platforms/telegram.py b/messaging/platforms/telegram.py
index d131e44..a8f34df 100644
--- a/messaging/platforms/telegram.py
+++ b/messaging/platforms/telegram.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any
 
 from loguru import logger
 
-from core.anthropic import get_user_facing_error_message
+from core.anthropic import format_user_error_preview
 
 if TYPE_CHECKING:
     from telegram import Update
@@ -68,14 +68,18 @@ class TelegramPlatform(MessagingPlatform):
         whisper_device: str = "cpu",
         hf_token: str = "",
         nvidia_nim_api_key: str = "",
+        messaging_rate_limit: int = 1,
+        messaging_rate_window: float = 1.0,
+        log_raw_messaging_content: bool = False,
+        log_api_error_tracebacks: bool = False,
     ):
         if not TELEGRAM_AVAILABLE:
             raise ImportError(
                 "python-telegram-bot is required. Install with: pip install python-telegram-bot"
             )
 
-        self.bot_token = bot_token or os.getenv("TELEGRAM_BOT_TOKEN")
-        self.allowed_user_id = allowed_user_id or os.getenv("ALLOWED_TELEGRAM_USER_ID")
+        self.bot_token = bot_token
+        self.allowed_user_id = allowed_user_id
 
         if not self.bot_token:
             # We don't raise here to allow instantiation for testing/conditional logic,
@@ -97,6 +101,10 @@ class TelegramPlatform(MessagingPlatform):
         self._voice_note_enabled = voice_note_enabled
         self._whisper_model = whisper_model
         self._whisper_device = whisper_device
+        self._messaging_rate_limit = messaging_rate_limit
+        self._messaging_rate_window = messaging_rate_window
+        self._log_raw_messaging_content = log_raw_messaging_content
+        self._log_api_error_tracebacks = log_api_error_tracebacks
 
     async def _register_pending_voice(
         self, chat_id: str, voice_msg_id: str, status_msg_id: str
@@ -172,7 +180,10 @@ class TelegramPlatform(MessagingPlatform):
         # Initialize rate limiter
         from ..limiter import MessagingRateLimiter
 
-        self._limiter = await MessagingRateLimiter.get_instance()
+        self._limiter = await MessagingRateLimiter.get_instance(
+            rate_limit=self._messaging_rate_limit,
+            rate_window=self._messaging_rate_window,
+        )
 
         # Send startup notification
         try:
@@ -187,7 +198,13 @@ class TelegramPlatform(MessagingPlatform):
                 startup_text,
             )
         except Exception as e:
-            logger.warning(f"Could not send startup message: {e}")
+            if self._log_api_error_tracebacks:
+                logger.warning("Could not send startup message: {}", e)
+            else:
+                logger.warning(
+                    "Could not send startup message: exc_type={}",
+                    type(e).__name__,
+                )
 
         logger.info("Telegram platform started (Bot API)")
 
@@ -506,16 +523,26 @@ class TelegramPlatform(MessagingPlatform):
             if getattr(update.message, "message_thread_id", None) is not None
             else None
         )
-        text_preview = (update.message.text or "")[:80]
-        if len(update.message.text or "") > 80:
-            text_preview += "..."
-        logger.info(
-            "TELEGRAM_MSG: chat_id={} message_id={} reply_to={} text_preview={!r}",
-            chat_id,
-            message_id,
-            reply_to,
-            text_preview,
-        )
+        raw_text = update.message.text or ""
+        if self._log_raw_messaging_content:
+            text_preview = raw_text[:80]
+            if len(raw_text) > 80:
+                text_preview += "..."
+            logger.info(
+                "TELEGRAM_MSG: chat_id={} message_id={} reply_to={} text_preview={!r}",
+                chat_id,
+                message_id,
+                reply_to,
+                text_preview,
+            )
+        else:
+            logger.info(
+                "TELEGRAM_MSG: chat_id={} message_id={} reply_to={} text_len={}",
+                chat_id,
+                message_id,
+                reply_to,
+                len(raw_text),
+            )
 
         if not self._message_handler:
             return
@@ -534,11 +561,14 @@ class TelegramPlatform(MessagingPlatform):
         try:
             await self._message_handler(incoming)
         except Exception as e:
-            logger.error(f"Error handling message: {e}")
+            if self._log_api_error_tracebacks:
+                logger.error("Error handling message: {}", e)
+            else:
+                logger.error("Error handling message: exc_type={}", type(e).__name__)
             with contextlib.suppress(Exception):
                 await self.send_message(
                     chat_id,
-                    f"❌ *{escape_md_v2('Error:')}* {escape_md_v2(get_user_facing_error_message(e)[:200])}",
+                    f"❌ *{escape_md_v2('Error:')}* {escape_md_v2(format_user_error_preview(e))}",
                     reply_to=incoming.message_id,
                     message_thread_id=thread_id,
                     parse_mode="MarkdownV2",
@@ -631,20 +661,37 @@ class TelegramPlatform(MessagingPlatform):
                 status_message_id=status_msg_id,
             )
 
-            logger.info(
-                "TELEGRAM_VOICE: chat_id={} message_id={} transcribed={!r}",
-                chat_id,
-                message_id,
-                (transcribed[:80] + "..." if len(transcribed) > 80 else transcribed),
-            )
+            if self._log_raw_messaging_content:
+                logger.info(
+                    "TELEGRAM_VOICE: chat_id={} message_id={} transcribed={!r}",
+                    chat_id,
+                    message_id,
+                    (
+                        transcribed[:80] + "..."
+                        if len(transcribed) > 80
+                        else transcribed
+                    ),
+                )
+            else:
+                logger.info(
+                    "TELEGRAM_VOICE: chat_id={} message_id={} transcribed_len={}",
+                    chat_id,
+                    message_id,
+                    len(transcribed),
+                )
 
             await self._message_handler(incoming)
         except ValueError as e:
-            await update.message.reply_text(get_user_facing_error_message(e)[:200])
+            await update.message.reply_text(format_user_error_preview(e))
         except ImportError as e:
-            await update.message.reply_text(get_user_facing_error_message(e)[:200])
+            await update.message.reply_text(format_user_error_preview(e))
         except Exception as e:
-            logger.error(f"Voice transcription failed: {e}")
+            if self._log_api_error_tracebacks:
+                logger.error("Voice transcription failed: {}", e)
+            else:
+                logger.error(
+                    "Voice transcription failed: exc_type={}", type(e).__name__
+                )
             await update.message.reply_text(
                 "Could not transcribe voice note. Please try again or send text."
             )
diff --git a/messaging/safe_diagnostics.py b/messaging/safe_diagnostics.py
new file mode 100644
index 0000000..add30c5
--- /dev/null
+++ b/messaging/safe_diagnostics.py
@@ -0,0 +1,17 @@
+"""Helpers for redacting user-derived content from log lines."""
+
+from __future__ import annotations
+
+
+def format_exception_for_log(exc: BaseException, *, log_full_message: bool) -> str:
+    """Return exception type and optionally ``str(exc)`` for operator diagnostics."""
+    if log_full_message:
+        return f"{type(exc).__name__}: {exc}"
+    return type(exc).__name__
+
+
+def text_len_hint(text: str | None) -> int:
+    """Length of text for metadata-only logging (0 when missing)."""
+    if not text:
+        return 0
+    return len(text)
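The redaction helper's two modes, concretely (example exception text illustrative):

from messaging.safe_diagnostics import format_exception_for_log

err = ValueError("token=sk-secret leaked into an exception message")
format_exception_for_log(err, log_full_message=False)  # -> "ValueError"
format_exception_for_log(err, log_full_message=True)   # -> "ValueError: token=sk-secret ..."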
diff --git a/messaging/session.py b/messaging/session.py
index a4e4e9d..bbb0cdc 100644
--- a/messaging/session.py
+++ b/messaging/session.py
@@ -5,8 +5,10 @@ Provides persistent storage for mapping platform messages to Claude CLI sessions
 and message trees for conversation continuation.
 """
 
+import contextlib
 import json
 import os
+import tempfile
 import threading
 from datetime import UTC, datetime
 from typing import Any
@@ -22,7 +24,12 @@ class SessionStore:
     Platform-agnostic: works with any messaging platform.
     """
 
-    def __init__(self, storage_path: str = "sessions.json"):
+    def __init__(
+        self,
+        storage_path: str = "sessions.json",
+        *,
+        message_log_cap: int | None = None,
+    ):
         self.storage_path = storage_path
         self._lock = threading.Lock()
         self._trees: dict[str, dict] = {}  # root_id -> tree data
@@ -34,11 +41,7 @@ class SessionStore:
         self._dirty = False
         self._save_timer: threading.Timer | None = None
         self._save_debounce_secs = 0.5
-        cap_raw = os.getenv("MAX_MESSAGE_LOG_ENTRIES_PER_CHAT", "").strip()
-        try:
-            self._message_log_cap: int | None = int(cap_raw) if cap_raw else None
-        except ValueError:
-            self._message_log_cap = None
+        self._message_log_cap: int | None = message_log_cap
         self._load()
 
     def _make_chat_key(self, platform: str, chat_id: str) -> str:
@@ -104,9 +107,22 @@
         }
 
     def _write_data(self, data: dict) -> None:
-        """Write data dict to disk. Must be called WITHOUT holding self._lock."""
-        with open(self.storage_path, "w", encoding="utf-8") as f:
-            json.dump(data, f, indent=2)
+        """Atomically write data dict to disk. Must be called WITHOUT holding self._lock."""
+        abs_target = os.path.abspath(self.storage_path)
+        dir_name = os.path.dirname(abs_target) or "."
+        fd, tmp_path = tempfile.mkstemp(
+            dir=dir_name, prefix=".sessions.", suffix=".tmp.json"
+        )
+        try:
+            with os.fdopen(fd, "w", encoding="utf-8") as f:
+                json.dump(data, f, indent=2)
+                f.flush()
+                os.fsync(f.fileno())
+            os.replace(tmp_path, abs_target)
+        except BaseException:
+            with contextlib.suppress(OSError):
+                os.unlink(tmp_path)
+            raise
 
     def _schedule_save(self) -> None:
         """Schedule a debounced save. Caller must hold self._lock."""
diff --git a/messaging/transcript.py b/messaging/transcript.py
index 073f40b..f1a015d 100644
--- a/messaging/transcript.py
+++ b/messaging/transcript.py
@@ -9,7 +9,6 @@ the transcript grows over time and older content must be truncated.
 from __future__ import annotations
 
 import json
-import os
 from abc import ABC, abstractmethod
 from collections import deque
 from collections.abc import Callable, Iterable
@@ -210,7 +209,12 @@ class RenderCtx:
 class TranscriptBuffer:
     """Maintains an ordered, truncatable transcript of events."""
 
-    def __init__(self, *, show_tool_results: bool = True) -> None:
+    def __init__(
+        self,
+        *,
+        show_tool_results: bool = True,
+        debug_subagent_stack: bool = False,
+    ) -> None:
         self._segments: list[Segment] = []
         self._open_thinking_by_index: dict[int, ThinkingSegment] = {}
         self._open_text_by_index: dict[int, TextSegment] = {}
@@ -227,7 +231,7 @@
         self._subagent_stack: list[str] = []
         # Parallel stack of segments for rendering nested subagents.
         self._subagent_segments: list[SubagentSegment] = []
-        self._debug_subagent_stack = os.getenv("DEBUG_SUBAGENT_STACK") == "1"
+        self._debug_subagent_stack = debug_subagent_stack
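SessionStore now gets its cap injected rather than parsing MAX_MESSAGE_LOG_ENTRIES_PER_CHAT itself, matching the config-injection pattern used throughout this diff. A sketch (the cap value is illustrative):

from messaging.session import SessionStore

# The caller (e.g. the runtime composition root) owns env/settings parsing now;
# writes land via tempfile.mkstemp + os.replace, so a crash mid-save cannot
# leave a truncated sessions.json behind.
store = SessionStore("sessions.json", message_log_cap=500)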
"") -> Any: """Lazy-load transformers Whisper pipeline. Raises ImportError if not installed.""" global _pipeline_cache if device not in ("cpu", "cuda"): raise ValueError(f"whisper_device must be 'cpu' or 'cuda', got {device!r}") - resolved_token = ( - hf_token if hf_token is not None else get_settings().hf_token - ) or "" + resolved_token = hf_token or "" cache_key = (model_id, device, resolved_token) if cache_key not in _pipeline_cache: try: import torch from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline - if resolved_token: - os.environ["HF_TOKEN"] = resolved_token + hf_auth_token = resolved_token or None use_cuda = device == "cuda" and torch.cuda.is_available() pipe_device = "cuda:0" if use_cuda else "cpu" @@ -77,9 +60,10 @@ def _get_pipeline(model_id: str, device: str, hf_token: str | None = None) -> An dtype=model_dtype, low_cpu_mem_usage=True, attn_implementation="sdpa", + token=hf_auth_token, ) model = model.to(pipe_device) - processor = AutoProcessor.from_pretrained(model_id) + processor = AutoProcessor.from_pretrained(model_id, token=hf_auth_token) pipe = pipeline( "automatic-speech-recognition", @@ -119,7 +103,7 @@ def transcribe_audio( file_path: Path to audio file (OGG, MP3, MP4, WAV, M4A supported) mime_type: MIME type of the audio (e.g. "audio/ogg") whisper_model: Model ID or short name (local) or NVIDIA NIM model - whisper_device: "cpu" | "cuda" | "nvidia_nim" (defaults to WHISPER_DEVICE env var) + whisper_device: "cpu" | "cuda" | "nvidia_nim" Returns: Transcribed text @@ -140,8 +124,8 @@ def transcribe_audio( ) if whisper_device == "nvidia_nim": - return _transcribe_nim( - file_path, whisper_model, nvidia_nim_api_key=nvidia_nim_api_key + return transcribe_nvidia_nim_audio( + file_path, whisper_model, api_key=nvidia_nim_api_key ) return _transcribe_local( file_path, whisper_model, whisper_device, hf_token=hf_token @@ -169,8 +153,7 @@ def _transcribe_local( ) -> str: """Transcribe using transformers Whisper pipeline.""" model_id = _resolve_model_id(whisper_model) - token: str | None = hf_token if hf_token else None - pipe = _get_pipeline(model_id, whisper_device, hf_token=token) + pipe = _get_pipeline(model_id, whisper_device, hf_token=hf_token) audio = _load_audio(file_path) result = pipe(audio, generate_kwargs={"language": "en", "task": "transcribe"}) text = result.get("text", "") or "" @@ -179,65 +162,3 @@ def _transcribe_local( result_text = text.strip() logger.debug(f"Local transcription: {len(result_text)} chars") return result_text or "(no speech detected)" - - -def _transcribe_nim( - file_path: Path, model: str, *, nvidia_nim_api_key: str = "" -) -> str: - """Transcribe using NVIDIA NIM Whisper API via Riva gRPC client.""" - try: - import riva.client - except ImportError as e: - raise ImportError( - "NVIDIA NIM transcription requires the voice extra. " - "Install with: uv sync --extra voice" - ) from e - - api_key = nvidia_nim_api_key or get_settings().nvidia_nim_api_key - - # Look up function ID and language code from model mapping - model_config = _NIM_MODEL_MAP.get(model) - if not model_config: - raise ValueError( - f"No NVIDIA NIM config found for model: {model}. 
" - f"Supported models: {', '.join(_NIM_MODEL_MAP.keys())}" - ) - function_id, language_code = model_config - - # Riva server configuration - server = "grpc.nvcf.nvidia.com:443" - - # Auth with SSL and metadata - auth = riva.client.Auth( - use_ssl=True, - uri=server, - metadata_args=[ - ["function-id", function_id], - ["authorization", f"Bearer {api_key}"], - ], - ) - - asr_service = riva.client.ASRService(auth) - - # Configure recognition - language_code from model config - config = riva.client.RecognitionConfig( - language_code=language_code, - max_alternatives=1, - verbatim_transcripts=True, - ) - - # Read audio file - with open(file_path, "rb") as f: - data = f.read() - - # Perform offline recognition - response = asr_service.offline_recognize(data, config) - - # Extract text from response - use getattr for safe attribute access - transcript = "" - results = getattr(response, "results", None) - if results and results[0].alternatives: - transcript = results[0].alternatives[0].transcript - - logger.debug(f"NIM transcription: {len(transcript)} chars") - return transcript or "(no speech detected)" diff --git a/messaging/trees/processor.py b/messaging/trees/processor.py deleted file mode 100644 index 5fdcca6..0000000 --- a/messaging/trees/processor.py +++ /dev/null @@ -1,165 +0,0 @@ -"""Async queue processor for message trees. - -Handles the async processing lifecycle of tree nodes. -""" - -import asyncio -from collections.abc import Awaitable, Callable - -from loguru import logger - -from core.anthropic import get_user_facing_error_message - -from .data import MessageNode, MessageState, MessageTree - - -class TreeQueueProcessor: - """ - Handles async queue processing for a single tree. - - Separates the async processing logic from the data management. - """ - - def __init__( - self, - queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None = None, - node_started_callback: Callable[[MessageTree, str], Awaitable[None]] - | None = None, - ): - self._queue_update_callback = queue_update_callback - self._node_started_callback = node_started_callback - - def set_queue_update_callback( - self, - queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None, - ) -> None: - """Update the callback used to refresh queue positions.""" - self._queue_update_callback = queue_update_callback - - def set_node_started_callback( - self, - node_started_callback: Callable[[MessageTree, str], Awaitable[None]] | None, - ) -> None: - """Update the callback used when a queued node starts processing.""" - self._node_started_callback = node_started_callback - - async def _notify_queue_updated(self, tree: MessageTree) -> None: - """Invoke queue update callback if set.""" - if not self._queue_update_callback: - return - try: - await self._queue_update_callback(tree) - except Exception as e: - logger.warning(f"Queue update callback failed: {e}") - - async def _notify_node_started(self, tree: MessageTree, node_id: str) -> None: - """Invoke node started callback if set.""" - if not self._node_started_callback: - return - try: - await self._node_started_callback(tree, node_id) - except Exception as e: - logger.warning(f"Node started callback failed: {e}") - - async def process_node( - self, - tree: MessageTree, - node: MessageNode, - processor: Callable[[str, MessageNode], Awaitable[None]], - ) -> None: - """Process a single node and then check the queue.""" - # Skip if already in terminal state (e.g. 
from error propagation) - if node.state == MessageState.ERROR: - logger.info( - f"Skipping node {node.node_id} as it is already in state {node.state}" - ) - # Still need to check for next messages - await self._process_next(tree, processor) - return - - try: - await processor(node.node_id, node) - except asyncio.CancelledError: - logger.info(f"Task for node {node.node_id} was cancelled") - raise - except Exception as e: - logger.error(f"Error processing node {node.node_id}: {e}") - await tree.update_state( - node.node_id, - MessageState.ERROR, - error_message=get_user_facing_error_message(e), - ) - finally: - async with tree.with_lock(): - tree.clear_current_node() - # Check if there are more messages in the queue - await self._process_next(tree, processor) - - async def _process_next( - self, - tree: MessageTree, - processor: Callable[[str, MessageNode], Awaitable[None]], - ) -> None: - """Process the next message in queue, if any.""" - next_node_id = None - node = None - async with tree.with_lock(): - next_node_id = await tree.dequeue() - - if not next_node_id: - tree.set_processing_state(None, False) - logger.debug(f"Tree {tree.root_id} queue empty, marking as free") - return - - tree.set_processing_state(next_node_id, True) - logger.info(f"Processing next queued node {next_node_id}") - - # Process next node (outside lock) - node = tree.get_node(next_node_id) - if node: - tree.set_current_task( - asyncio.create_task(self.process_node(tree, node, processor)) - ) - - # Notify that this node has started processing and refresh queue positions. - if next_node_id: - await self._notify_node_started(tree, next_node_id) - await self._notify_queue_updated(tree) - - async def enqueue_and_start( - self, - tree: MessageTree, - node_id: str, - processor: Callable[[str, MessageNode], Awaitable[None]], - ) -> bool: - """ - Enqueue a node or start processing immediately. - - Args: - tree: The message tree - node_id: Node to process - processor: Async function to process the node - - Returns: - True if queued, False if processing immediately - """ - async with tree.with_lock(): - if tree.is_processing: - tree.put_queue_unlocked(node_id) - queue_size = tree.get_queue_size() - logger.info(f"Queued node {node_id}, position {queue_size}") - return True - else: - tree.set_processing_state(node_id, True) - - # Process outside the lock - node = tree.get_node(node_id) - if node: - tree.set_current_task( - asyncio.create_task(self.process_node(tree, node, processor)) - ) - return False - - def cancel_current(self, tree: MessageTree) -> bool: - """Cancel the currently running task in a tree.""" - return tree.cancel_current_task() diff --git a/messaging/trees/queue_manager.py b/messaging/trees/queue_manager.py index fbf7029..2c0cecd 100644 --- a/messaging/trees/queue_manager.py +++ b/messaging/trees/queue_manager.py @@ -1,30 +1,345 @@ -"""Tree-Based Message Queue Manager - Refactored. - -Coordinates data access, async processing, and error handling. -Uses TreeRepository for data, TreeQueueProcessor for async logic. 
-""" +"""Tree-based message queue: index, async node processor, and public manager API.""" import asyncio from collections.abc import Awaitable, Callable from loguru import logger +from config.settings import get_settings +from core.anthropic import get_user_facing_error_message + from ..models import IncomingMessage +from ..safe_diagnostics import format_exception_for_log from .data import MessageNode, MessageState, MessageTree -from .processor import TreeQueueProcessor -from .repository import TreeRepository + + +class TreeRepository: + """ + In-memory index of trees and node-to-root mappings. + + Used only by :class:`TreeQueueManager`; kept as a named type for tests. + """ + + def __init__(self) -> None: + self._trees: dict[str, MessageTree] = {} # root_id -> tree + self._node_to_tree: dict[str, str] = {} # node_id -> root_id + + def get_tree(self, root_id: str) -> MessageTree | None: + """Get a tree by its root ID.""" + return self._trees.get(root_id) + + def get_tree_for_node(self, node_id: str) -> MessageTree | None: + """Get the tree containing a given node.""" + root_id = self._node_to_tree.get(node_id) + if not root_id: + return None + return self._trees.get(root_id) + + def get_node(self, node_id: str) -> MessageNode | None: + """Get a node from any tree.""" + tree = self.get_tree_for_node(node_id) + return tree.get_node(node_id) if tree else None + + def add_tree(self, root_id: str, tree: MessageTree) -> None: + """Add a new tree to the repository.""" + self._trees[root_id] = tree + self._node_to_tree[root_id] = root_id + logger.debug("TREE_REPO: add_tree root_id={}", root_id) + + def register_node(self, node_id: str, root_id: str) -> None: + """Register a node ID to a tree.""" + self._node_to_tree[node_id] = root_id + logger.debug("TREE_REPO: register_node node_id={} root_id={}", node_id, root_id) + + def has_node(self, node_id: str) -> bool: + """Check if a node is registered in any tree.""" + return node_id in self._node_to_tree + + def tree_count(self) -> int: + """Get the number of trees in the repository.""" + return len(self._trees) + + def is_tree_busy(self, root_id: str) -> bool: + """Check if a tree is currently processing.""" + tree = self._trees.get(root_id) + return tree.is_processing if tree else False + + def is_node_tree_busy(self, node_id: str) -> bool: + """Check if the tree containing a node is busy.""" + tree = self.get_tree_for_node(node_id) + return tree.is_processing if tree else False + + def get_queue_size(self, node_id: str) -> int: + """Get queue size for the tree containing a node.""" + tree = self.get_tree_for_node(node_id) + return tree.get_queue_size() if tree else 0 + + def resolve_parent_node_id(self, msg_id: str) -> str | None: + """ + Resolve a message ID to the actual parent node ID. + + Handles the case where msg_id is a status message ID + (which maps to the tree but isn't an actual node). + + Returns: + The node_id to use as parent, or None if not found + """ + tree = self.get_tree_for_node(msg_id) + if not tree: + return None + + if tree.has_node(msg_id): + return msg_id + + node = tree.find_node_by_status_message(msg_id) + if node: + return node.node_id + + return None + + def get_pending_children(self, node_id: str) -> list[MessageNode]: + """ + Get all pending child nodes (recursively) of a given node. + + Used for error propagation - when a node fails, its pending + children should also be marked as failed. 
+ """ + tree = self.get_tree_for_node(node_id) + if not tree: + return [] + + pending: list[MessageNode] = [] + stack = [node_id] + + while stack: + current_id = stack.pop() + node = tree.get_node(current_id) + if not node: + continue + for child_id in node.children_ids: + child = tree.get_node(child_id) + if child and child.state == MessageState.PENDING: + pending.append(child) + stack.append(child_id) + + return pending + + def all_trees(self) -> list[MessageTree]: + """Get all trees in the repository.""" + return list(self._trees.values()) + + def tree_ids(self) -> list[str]: + """Get all tree root IDs.""" + return list(self._trees.keys()) + + def unregister_nodes(self, node_ids: list[str]) -> None: + """Remove node IDs from the node-to-tree mapping.""" + for nid in node_ids: + self._node_to_tree.pop(nid, None) + + def remove_tree(self, root_id: str) -> MessageTree | None: + """ + Remove a tree and all its node mappings from the repository. + + Returns: + The removed tree, or None if not found. + """ + tree = self._trees.pop(root_id, None) + if not tree: + return None + for node in tree.all_nodes(): + self._node_to_tree.pop(node.node_id, None) + logger.debug("TREE_REPO: remove_tree root_id={}", root_id) + return tree + + def get_message_ids_for_chat(self, platform: str, chat_id: str) -> set[str]: + """Get all message IDs (incoming + status) for a given platform/chat.""" + msg_ids: set[str] = set() + for tree in self._trees.values(): + for node in tree.all_nodes(): + if str(node.incoming.platform) == str(platform) and str( + node.incoming.chat_id + ) == str(chat_id): + if node.incoming.message_id is not None: + msg_ids.add(str(node.incoming.message_id)) + if node.status_message_id: + msg_ids.add(str(node.status_message_id)) + return msg_ids + + def to_dict(self) -> dict: + """Serialize all trees.""" + return { + "trees": {rid: tree.to_dict() for rid, tree in self._trees.items()}, + "node_to_tree": self._node_to_tree.copy(), + } + + @classmethod + def from_dict(cls, data: dict) -> TreeRepository: + """Deserialize from dictionary.""" + repo = cls() + for root_id, tree_data in data.get("trees", {}).items(): + repo._trees[root_id] = MessageTree.from_dict(tree_data) + repo._node_to_tree = data.get("node_to_tree", {}) + return repo + + +class TreeQueueProcessor: + """ + Per-tree async queue processing (one manager owns one processor instance). 
+ """ + + def __init__( + self, + queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None = None, + node_started_callback: Callable[[MessageTree, str], Awaitable[None]] + | None = None, + ) -> None: + self._queue_update_callback = queue_update_callback + self._node_started_callback = node_started_callback + + def set_queue_update_callback( + self, + queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None, + ) -> None: + """Update the callback used to refresh queue positions.""" + self._queue_update_callback = queue_update_callback + + def set_node_started_callback( + self, + node_started_callback: Callable[[MessageTree, str], Awaitable[None]] | None, + ) -> None: + """Update the callback used when a queued node starts processing.""" + self._node_started_callback = node_started_callback + + async def _notify_queue_updated(self, tree: MessageTree) -> None: + """Invoke queue update callback if set.""" + if not self._queue_update_callback: + return + try: + await self._queue_update_callback(tree) + except Exception as e: + d = get_settings().log_messaging_error_details + logger.warning( + "Queue update callback failed: {}", + format_exception_for_log(e, log_full_message=d), + ) + + async def _notify_node_started(self, tree: MessageTree, node_id: str) -> None: + """Invoke node started callback if set.""" + if not self._node_started_callback: + return + try: + await self._node_started_callback(tree, node_id) + except Exception as e: + d = get_settings().log_messaging_error_details + logger.warning( + "Node started callback failed: {}", + format_exception_for_log(e, log_full_message=d), + ) + + async def process_node( + self, + tree: MessageTree, + node: MessageNode, + processor: Callable[[str, MessageNode], Awaitable[None]], + ) -> None: + """Process a single node and then check the queue.""" + if node.state == MessageState.ERROR: + logger.info( + f"Skipping node {node.node_id} as it is already in state {node.state}" + ) + await self._process_next(tree, processor) + return + + try: + await processor(node.node_id, node) + except asyncio.CancelledError: + logger.info(f"Task for node {node.node_id} was cancelled") + raise + except Exception as e: + d = get_settings().log_messaging_error_details + logger.error( + "Error processing node {}: {}", + node.node_id, + format_exception_for_log(e, log_full_message=d), + ) + await tree.update_state( + node.node_id, + MessageState.ERROR, + error_message=get_user_facing_error_message(e), + ) + finally: + async with tree.with_lock(): + tree.clear_current_node() + await self._process_next(tree, processor) + + async def _process_next( + self, + tree: MessageTree, + processor: Callable[[str, MessageNode], Awaitable[None]], + ) -> None: + """Process the next message in queue, if any.""" + next_node_id = None + async with tree.with_lock(): + next_node_id = await tree.dequeue() + + if not next_node_id: + tree.set_processing_state(None, False) + logger.debug(f"Tree {tree.root_id} queue empty, marking as free") + return + + tree.set_processing_state(next_node_id, True) + logger.info(f"Processing next queued node {next_node_id}") + + node = tree.get_node(next_node_id) + if node: + tree.set_current_task( + asyncio.create_task(self.process_node(tree, node, processor)) + ) + + if next_node_id: + await self._notify_node_started(tree, next_node_id) + await self._notify_queue_updated(tree) + + async def enqueue_and_start( + self, + tree: MessageTree, + node_id: str, + processor: Callable[[str, MessageNode], Awaitable[None]], + ) -> bool: + """ + 
Enqueue a node or start processing immediately. + + Returns: + True if queued, False if processing immediately + """ + async with tree.with_lock(): + if tree.is_processing: + tree.put_queue_unlocked(node_id) + queue_size = tree.get_queue_size() + logger.info(f"Queued node {node_id}, position {queue_size}") + return True + else: + tree.set_processing_state(node_id, True) + + node = tree.get_node(node_id) + if node: + tree.set_current_task( + asyncio.create_task(self.process_node(tree, node, processor)) + ) + return False + + def cancel_current(self, tree: MessageTree) -> bool: + """Cancel the currently running task in a tree.""" + return tree.cancel_current_task() class TreeQueueManager: """ - Manages multiple message trees. Facade that coordinates components. + Manages multiple message trees: index + async processing. Each new conversation creates a new tree. Replies to existing messages add nodes to existing trees. - - Components: - - TreeRepository: Data access layer - - TreeQueueProcessor: Async queue processing """ def __init__( @@ -33,7 +348,7 @@ class TreeQueueManager: node_started_callback: Callable[[MessageTree, str], Awaitable[None]] | None = None, _repository: TreeRepository | None = None, - ): + ) -> None: self._repository = _repository or TreeRepository() self._processor = TreeQueueProcessor( queue_update_callback=queue_update_callback, @@ -101,7 +416,6 @@ class TreeQueueManager: if not tree: raise ValueError(f"Parent node {parent_node_id} not found in any tree") - # Add node (tree has its own lock) - outside manager lock to avoid deadlock node = await tree.add_node( node_id=node_id, incoming=incoming, @@ -228,7 +542,6 @@ class TreeQueueManager: cleanup_count = 0 async with tree.with_lock(): - # 1. Cancel running task if tree.cancel_current_task(): current_id = tree.current_node_id if current_id: @@ -240,12 +553,10 @@ class TreeQueueManager: tree.set_node_error_sync(node, "Cancelled by user") cancelled_nodes.append(node) - # 2. Drain queue and mark nodes as cancelled queue_nodes = tree.drain_queue_and_mark_cancelled() cancelled_nodes.extend(queue_nodes) cancelled_ids = {n.node_id for n in cancelled_nodes} - # 3. Cleanup: Mark ANY other PENDING or IN_PROGRESS nodes as ERROR for node in tree.all_nodes(): if ( node.state in (MessageState.PENDING, MessageState.IN_PROGRESS) @@ -269,10 +580,6 @@ class TreeQueueManager: """ Cancel a single node (queued or in-progress) without affecting other nodes. - - If the node is currently running, cancels the current asyncio task. - - If the node is queued, removes it from the queue. - - Marks the node as ERROR with "Cancelled by user". - Returns: List containing the cancelled node if it was cancellable, else empty list. """ @@ -351,8 +658,6 @@ class TreeQueueManager: async def cancel_branch(self, branch_root_id: str) -> list[MessageNode]: """ Cancel all PENDING/IN_PROGRESS nodes in the subtree (branch_root + descendants). - - Does not call cli_manager.stop_all(). Returns list of cancelled nodes. """ tree = self._repository.get_tree_for_node(branch_root_id) if not tree: @@ -435,3 +740,10 @@ class TreeQueueManager: node_started_callback=node_started_callback, _repository=TreeRepository.from_dict(data), ) + + +__all__ = [ + "TreeQueueManager", + "TreeQueueProcessor", + "TreeRepository", +] diff --git a/messaging/trees/repository.py b/messaging/trees/repository.py deleted file mode 100644 index 2dae1fb..0000000 --- a/messaging/trees/repository.py +++ /dev/null @@ -1,186 +0,0 @@ -"""Repository for message tree data access. 
- -Provides data access layer for managing trees and node mappings. -""" - -from loguru import logger - -from .data import MessageNode, MessageState, MessageTree - - -class TreeRepository: - """ - Repository for message tree data access. - - Manages the storage and lookup of trees and node-to-tree mappings. - """ - - def __init__(self): - self._trees: dict[str, MessageTree] = {} # root_id -> tree - self._node_to_tree: dict[str, str] = {} # node_id -> root_id - - def get_tree(self, root_id: str) -> MessageTree | None: - """Get a tree by its root ID.""" - return self._trees.get(root_id) - - def get_tree_for_node(self, node_id: str) -> MessageTree | None: - """Get the tree containing a given node.""" - root_id = self._node_to_tree.get(node_id) - if not root_id: - return None - return self._trees.get(root_id) - - def get_node(self, node_id: str) -> MessageNode | None: - """Get a node from any tree.""" - tree = self.get_tree_for_node(node_id) - return tree.get_node(node_id) if tree else None - - def add_tree(self, root_id: str, tree: MessageTree) -> None: - """Add a new tree to the repository.""" - self._trees[root_id] = tree - self._node_to_tree[root_id] = root_id - logger.debug("TREE_REPO: add_tree root_id={}", root_id) - - def register_node(self, node_id: str, root_id: str) -> None: - """Register a node ID to a tree.""" - self._node_to_tree[node_id] = root_id - logger.debug("TREE_REPO: register_node node_id={} root_id={}", node_id, root_id) - - def has_node(self, node_id: str) -> bool: - """Check if a node is registered in any tree.""" - return node_id in self._node_to_tree - - def tree_count(self) -> int: - """Get the number of trees in the repository.""" - return len(self._trees) - - def is_tree_busy(self, root_id: str) -> bool: - """Check if a tree is currently processing.""" - tree = self._trees.get(root_id) - return tree.is_processing if tree else False - - def is_node_tree_busy(self, node_id: str) -> bool: - """Check if the tree containing a node is busy.""" - tree = self.get_tree_for_node(node_id) - return tree.is_processing if tree else False - - def get_queue_size(self, node_id: str) -> int: - """Get queue size for the tree containing a node.""" - tree = self.get_tree_for_node(node_id) - return tree.get_queue_size() if tree else 0 - - def resolve_parent_node_id(self, msg_id: str) -> str | None: - """ - Resolve a message ID to the actual parent node ID. - - Handles the case where msg_id is a status message ID - (which maps to the tree but isn't an actual node). - - Returns: - The node_id to use as parent, or None if not found - """ - tree = self.get_tree_for_node(msg_id) - if not tree: - return None - - # Check if msg_id is an actual node - if tree.has_node(msg_id): - return msg_id - - # Otherwise, it might be a status message - find the owning node - node = tree.find_node_by_status_message(msg_id) - if node: - return node.node_id - - return None - - def get_pending_children(self, node_id: str) -> list[MessageNode]: - """ - Get all pending child nodes (recursively) of a given node. - - Used for error propagation - when a node fails, its pending - children should also be marked as failed. 
- """ - tree = self.get_tree_for_node(node_id) - if not tree: - return [] - - pending: list[MessageNode] = [] - stack = [node_id] - - while stack: - current_id = stack.pop() - node = tree.get_node(current_id) - if not node: - continue - for child_id in node.children_ids: - child = tree.get_node(child_id) - if child and child.state == MessageState.PENDING: - pending.append(child) - stack.append(child_id) - - return pending - - def all_trees(self) -> list[MessageTree]: - """Get all trees in the repository.""" - return list(self._trees.values()) - - def tree_ids(self) -> list[str]: - """Get all tree root IDs.""" - return list(self._trees.keys()) - - def unregister_nodes(self, node_ids: list[str]) -> None: - """Remove node IDs from the node-to-tree mapping.""" - for nid in node_ids: - self._node_to_tree.pop(nid, None) - - def remove_tree(self, root_id: str) -> MessageTree | None: - """ - Remove a tree and all its node mappings from the repository. - - Returns: - The removed tree, or None if not found. - """ - tree = self._trees.pop(root_id, None) - if not tree: - return None - for node in tree.all_nodes(): - self._node_to_tree.pop(node.node_id, None) - logger.debug("TREE_REPO: remove_tree root_id={}", root_id) - return tree - - def get_message_ids_for_chat(self, platform: str, chat_id: str) -> set[str]: - """Get all message IDs (incoming + status) for a given platform/chat. - - Note: O(total_nodes) scan. Acceptable because this is only called - from /clear (user-initiated, infrequent). - """ - msg_ids: set[str] = set() - for tree in self._trees.values(): - for node in tree.all_nodes(): - if str(node.incoming.platform) == str(platform) and str( - node.incoming.chat_id - ) == str(chat_id): - if node.incoming.message_id is not None: - msg_ids.add(str(node.incoming.message_id)) - if node.status_message_id: - msg_ids.add(str(node.status_message_id)) - return msg_ids - - def to_dict(self) -> dict: - """Serialize all trees.""" - return { - "trees": {rid: tree.to_dict() for rid, tree in self._trees.items()}, - "node_to_tree": self._node_to_tree.copy(), - } - - @classmethod - def from_dict(cls, data: dict) -> TreeRepository: - """Deserialize from dictionary.""" - from .data import MessageTree - - repo = cls() - for root_id, tree_data in data.get("trees", {}).items(): - repo._trees[root_id] = MessageTree.from_dict(tree_data) - repo._node_to_tree = data.get("node_to_tree", {}) - return repo diff --git a/messaging/ui_updates.py b/messaging/ui_updates.py new file mode 100644 index 0000000..929202e --- /dev/null +++ b/messaging/ui_updates.py @@ -0,0 +1,101 @@ +"""Throttled platform UI updates driven by transcript rendering.""" + +from __future__ import annotations + +import time +from collections.abc import Callable + +from loguru import logger + +from .platforms.base import MessagingPlatform +from .safe_diagnostics import format_exception_for_log +from .transcript import RenderCtx, TranscriptBuffer + + +class ThrottledTranscriptEditor: + """Rate-limited status message edits from a growing transcript.""" + + def __init__( + self, + *, + platform: MessagingPlatform, + parse_mode: str | None, + get_limit_chars: Callable[[], int], + transcript: TranscriptBuffer, + render_ctx: RenderCtx, + node_id: str, + chat_id: str, + status_msg_id: str, + debug_platform_edits: bool, + log_messaging_error_details: bool = False, + ) -> None: + self._platform = platform + self._parse_mode = parse_mode + self._get_limit_chars = get_limit_chars + self._transcript = transcript + self._render_ctx = render_ctx + self._node_id = 
node_id + self._chat_id = chat_id + self._status_msg_id = status_msg_id + self._debug_platform_edits = debug_platform_edits + self._log_messaging_error_details = log_messaging_error_details + self._last_ui_update = 0.0 + self._last_displayed_text: str | None = None + self._last_status: str | None = None + + @property + def last_status(self) -> str | None: + return self._last_status + + async def update(self, status: str | None = None, *, force: bool = False) -> None: + """Render transcript + optional status line and edit the platform message.""" + now = time.time() + if not force and now - self._last_ui_update < 1.0: + return + + self._last_ui_update = now + if status is not None: + self._last_status = status + try: + display = self._transcript.render( + self._render_ctx, + limit_chars=self._get_limit_chars(), + status=status, + ) + except Exception as e: + logger.warning( + "Transcript render failed for node {}: {}", + self._node_id, + format_exception_for_log( + e, log_full_message=self._log_messaging_error_details + ), + ) + return + if display and display != self._last_displayed_text: + logger.debug( + "PLATFORM_EDIT: node_id={} chat_id={} msg_id={} force={} status={!r} chars={}", + self._node_id, + self._chat_id, + self._status_msg_id, + bool(force), + status, + len(display), + ) + if self._debug_platform_edits: + logger.debug("PLATFORM_EDIT_TEXT:\n{}", display) + self._last_displayed_text = display + try: + await self._platform.queue_edit_message( + self._chat_id, + self._status_msg_id, + display, + parse_mode=self._parse_mode, + ) + except Exception as e: + logger.warning( + "Failed to update platform for node {}: {}", + self._node_id, + format_exception_for_log( + e, log_full_message=self._log_messaging_error_details + ), + ) diff --git a/providers/anthropic_messages.py b/providers/anthropic_messages.py index c35c1d5..b501aac 100644 --- a/providers/anthropic_messages.py +++ b/providers/anthropic_messages.py @@ -2,19 +2,32 @@ from __future__ import annotations -import json from collections.abc import AsyncIterator, Iterator from typing import Any, Literal import httpx from loguru import logger -from core.anthropic import get_user_facing_error_message +from config.constants import ( + ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS, + NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES, +) +from core.anthropic import iter_provider_stream_error_sse_events +from core.anthropic.emitted_sse_tracker import EmittedNativeSseTracker +from core.anthropic.native_messages_request import ( + build_base_native_anthropic_request_body, +) +from core.anthropic.native_sse_block_policy import ( + NativeSseBlockPolicyState, + transform_native_sse_block_event, +) from providers.base import BaseProvider, ProviderConfig -from providers.error_mapping import map_error +from providers.error_mapping import ( + map_error, + user_visible_message_for_mapped_provider_error, +) from providers.rate_limit import GlobalRateLimiter -ANTHROPIC_DEFAULT_MAX_TOKENS = 81920 StreamChunkMode = Literal["line", "event"] @@ -64,25 +77,11 @@ class AnthropicMessagesTransport(BaseProvider): ) -> dict: """Build a native Anthropic request body.""" thinking_enabled = self._is_thinking_enabled(request, thinking_enabled) - body = request.model_dump(exclude_none=True) - - body.pop("extra_body", None) - body.pop("original_model", None) - body.pop("resolved_provider_model", None) - - if "thinking" in body: - thinking_cfg = body.pop("thinking") - if thinking_enabled and isinstance(thinking_cfg, dict): - thinking_payload = {"type": "enabled"} - budget_tokens = 
thinking_cfg.get("budget_tokens") - if isinstance(budget_tokens, int): - thinking_payload["budget_tokens"] = budget_tokens - body["thinking"] = thinking_payload - - if "max_tokens" not in body: - body["max_tokens"] = ANTHROPIC_DEFAULT_MAX_TOKENS - - return body + return build_base_native_anthropic_request_body( + request, + default_max_tokens=ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS, + thinking_enabled=thinking_enabled, + ) async def _send_stream_request(self, body: dict) -> httpx.Response: """Create a streaming messages response.""" @@ -97,30 +96,68 @@ class AnthropicMessagesTransport(BaseProvider): async def _raise_for_status( self, response: httpx.Response, *, req_tag: str ) -> None: - """Raise for non-200 responses after logging the upstream body.""" + """Raise for non-200 responses after logging safe metadata (or capped body if opted in).""" try: response.raise_for_status() except httpx.HTTPStatusError as error: - response_text = await self._read_error_body(response) - if response_text: + if self._config.log_api_error_tracebacks: + preview, truncated = await self._read_error_body_preview( + response, NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES + ) + if preview: + text = preview.decode("utf-8", errors="replace") + logger.error( + "{}_ERROR:{} HTTP {} body_preview_bytes={} truncated={}: {}", + self._provider_name, + req_tag, + response.status_code, + len(preview), + truncated, + text, + ) + else: + logger.error( + "{}_ERROR:{} HTTP {} (empty error body)", + self._provider_name, + req_tag, + response.status_code, + ) + else: + cl = response.headers.get("content-length", "").strip() + extra = f" content_length_declared={cl}" if cl.isdigit() else "" logger.error( - "{}_ERROR:{} HTTP {}: {}", + "{}_ERROR:{} HTTP {}{}", self._provider_name, req_tag, response.status_code, - response_text, + extra, ) raise error - async def _read_error_body(self, response: httpx.Response) -> str: - """Read a response body for diagnostics.""" - aread = getattr(response, "aread", None) - if aread is None: - return "" - body = await aread() - if isinstance(body, bytes): - return body.decode("utf-8", errors="replace") - return str(body) + async def _read_error_body_preview( + self, response: httpx.Response, max_bytes: int + ) -> tuple[bytes, bool]: + """Read at most ``max_bytes`` from the error body for logging. 
Returns (preview, truncated).""" + if max_bytes <= 0: + return b"", False + received = 0 + parts: list[bytes] = [] + truncated = False + async for chunk in response.aiter_bytes(chunk_size=65_536): + if received >= max_bytes: + truncated = True + break + remaining = max_bytes - received + take = chunk if len(chunk) <= remaining else chunk[:remaining] + if take: + parts.append(take) + received += len(take) + if len(chunk) > len(take): + truncated = True + break + if received >= max_bytes: + break + return (b"".join(parts), truncated) async def _iter_sse_lines(self, response: httpx.Response) -> AsyncIterator[str]: """Yield raw SSE line chunks preserving local provider behavior.""" @@ -145,6 +182,8 @@ class AnthropicMessagesTransport(BaseProvider): def _new_stream_state(self, request: Any, *, thinking_enabled: bool) -> Any: """Return per-stream provider state for event transformation.""" + if self.stream_chunk_mode == "line": + return NativeSseBlockPolicyState() return None def _transform_stream_event( @@ -155,6 +194,10 @@ class AnthropicMessagesTransport(BaseProvider): thinking_enabled: bool, ) -> str | None: """Transform or drop a grouped SSE event before yielding it downstream.""" + if isinstance(state, NativeSseBlockPolicyState): + return transform_native_sse_block_event( + event, state, thinking_enabled=thinking_enabled + ) return event def _format_error_message(self, base_message: str, request_id: str | None) -> str: @@ -166,15 +209,11 @@ class AnthropicMessagesTransport(BaseProvider): def _get_error_message(self, error: Exception, request_id: str | None) -> str: """Map an exception into a user-facing provider error message.""" mapped_error = map_error(error, rate_limiter=self._global_rate_limiter) - if getattr(mapped_error, "status_code", None) == 405: - base_message = ( - f"Upstream provider {self._provider_name} rejected the request method " - "or endpoint (HTTP 405)." 
- ) - else: - base_message = get_user_facing_error_message( - mapped_error, read_timeout_s=self._config.http_read_timeout - ) + base_message = user_visible_message_for_mapped_provider_error( + mapped_error, + provider_name=self._provider_name, + read_timeout_s=self._config.http_read_timeout, + ) return self._format_error_message(base_message, request_id) def _emit_error_events( @@ -185,12 +224,14 @@ class AnthropicMessagesTransport(BaseProvider): error_message: str, sent_any_event: bool, ) -> Iterator[str]: - """Emit a native Anthropic error event.""" - error_event = { - "type": "error", - "error": {"type": "api_error", "message": error_message}, - } - yield f"event: error\ndata: {json.dumps(error_event)}\n\n" + """Emit the same Anthropic message lifecycle used by OpenAI-compat providers.""" + yield from iter_provider_stream_error_sse_events( + request=request, + input_tokens=input_tokens, + error_message=error_message, + sent_any_event=sent_any_event, + log_raw_sse_events=self._config.log_raw_sse_events, + ) async def _iter_stream_chunks( self, @@ -200,6 +241,21 @@ class AnthropicMessagesTransport(BaseProvider): thinking_enabled: bool, ) -> AsyncIterator[str]: """Yield stream chunks according to the provider's observable chunk shape.""" + if self.stream_chunk_mode == "line" and isinstance( + state, NativeSseBlockPolicyState + ): + async for event in self._iter_sse_events(response): + output_event = self._transform_stream_event( + event, + state, + thinking_enabled=thinking_enabled, + ) + if output_event is None: + continue + for line in output_event.splitlines(keepends=True): + yield line + return + if self.stream_chunk_mode == "line": async for chunk in self._iter_sse_lines(response): yield chunk @@ -240,15 +296,28 @@ class AnthropicMessagesTransport(BaseProvider): response: httpx.Response | None = None sent_any_event = False state = self._new_stream_state(request, thinking_enabled=thinking_enabled) + emitted_tracker = EmittedNativeSseTracker() async with self._global_rate_limiter.concurrency_slot(): try: - response = await self._global_rate_limiter.execute_with_retry( - self._send_stream_request, body - ) - if response.status_code != 200: - await self._raise_for_status(response, req_tag=req_tag) + async def _validated_stream_send() -> httpx.Response: + """Send request; raise inside retry loop on 429 so rate limiter can backoff.""" + send_response = await self._send_stream_request(body) + if send_response.status_code == 429: + await send_response.aclose() + send_response.raise_for_status() + if send_response.status_code != 200: + try: + await self._raise_for_status(send_response, req_tag=req_tag) + finally: + if not send_response.is_closed: + await send_response.aclose() + return send_response + + response = await self._global_rate_limiter.execute_with_retry( + _validated_stream_send + ) async for chunk in self._iter_stream_chunks( response, @@ -256,12 +325,12 @@ class AnthropicMessagesTransport(BaseProvider): thinking_enabled=thinking_enabled, ): sent_any_event = True + emitted_tracker.feed(chunk) yield chunk except Exception as error: - logger.error( - "{}_ERROR:{} {}: {}", tag, req_tag, type(error).__name__, error - ) + if not isinstance(error, httpx.HTTPStatusError): + self._log_stream_transport_error(tag, req_tag, error) error_message = self._get_error_message(error, request_id) if response is not None and not response.is_closed: @@ -273,13 +342,24 @@ class AnthropicMessagesTransport(BaseProvider): type(error).__name__, req_tag, ) - for event in self._emit_error_events( - 
request=request, - input_tokens=input_tokens, - error_message=error_message, - sent_any_event=sent_any_event, - ): - yield event + if sent_any_event: + for event in emitted_tracker.iter_close_unclosed_blocks(): + yield event + for event in emitted_tracker.iter_midstream_error_tail( + error_message, + request=request, + input_tokens=input_tokens, + log_raw_sse_events=self._config.log_raw_sse_events, + ): + yield event + else: + for event in self._emit_error_events( + request=request, + input_tokens=input_tokens, + error_message=error_message, + sent_any_event=False, + ): + yield event return finally: if response is not None and not response.is_closed: diff --git a/providers/base.py b/providers/base.py index f8ff193..4f1a74b 100644 --- a/providers/base.py +++ b/providers/base.py @@ -6,6 +6,8 @@ from typing import Any from pydantic import BaseModel +from config.constants import HTTP_CONNECT_TIMEOUT_DEFAULT + class ProviderConfig(BaseModel): """Configuration for a provider. @@ -21,9 +23,11 @@ class ProviderConfig(BaseModel): max_concurrency: int = 5 http_read_timeout: float = 300.0 http_write_timeout: float = 10.0 - http_connect_timeout: float = 2.0 + http_connect_timeout: float = HTTP_CONNECT_TIMEOUT_DEFAULT enable_thinking: bool = True proxy: str = "" + log_raw_sse_events: bool = False + log_api_error_tracebacks: bool = False class BaseProvider(ABC): @@ -61,6 +65,42 @@ class BaseProvider(ABC): request_enabled = bool(enabled) return config_enabled and request_enabled + def preflight_stream( + self, request: Any, *, thinking_enabled: bool | None = None + ) -> None: + """Eagerly validate/build the upstream request before opening an SSE stream. + + Subclasses with ``_build_request_body`` (OpenAI and native) raise + :class:`providers.exceptions.InvalidRequestError` on conversion failures. + """ + build = getattr(self, "_build_request_body", None) + if build is None: + return + build(request, thinking_enabled=thinking_enabled) + + def _log_stream_transport_error( + self, tag: str, req_tag: str, error: Exception + ) -> None: + """Log streaming transport failures (metadata-only unless verbose is enabled).""" + from loguru import logger + + if self._config.log_api_error_tracebacks: + logger.error( + "{}_ERROR:{} {}: {}", tag, req_tag, type(error).__name__, error + ) + return + response = getattr(error, "response", None) + status_code = ( + getattr(response, "status_code", None) if response is not None else None + ) + logger.error( + "{}_ERROR:{} exc_type={} http_status={}", + tag, + req_tag, + type(error).__name__, + status_code, + ) + @abstractmethod async def cleanup(self) -> None: """Release any resources held by this provider.""" @@ -75,5 +115,7 @@ class BaseProvider(ABC): thinking_enabled: bool | None = None, ) -> AsyncIterator[str]: """Stream response in Anthropic SSE format.""" + # Typing: abstract async generators need a yield for AsyncIterator[str] + # inference; this branch is never executed. 
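+        # (Any ``yield`` in the body makes this an async generator function,
+        # so subclass streams satisfy the declared AsyncIterator[str] return
+        # type; the ``if False`` guard keeps the yield unreachable at runtime.)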
if False: - yield "" # Required for ty/mypy to accept abstract async generator + yield "" diff --git a/providers/deepseek/__init__.py b/providers/deepseek/__init__.py index 0ecf7d0..5ab4c79 100644 --- a/providers/deepseek/__init__.py +++ b/providers/deepseek/__init__.py @@ -1,5 +1,7 @@ """DeepSeek provider exports.""" -from .client import DEEPSEEK_BASE_URL, DeepSeekProvider +from providers.defaults import DEEPSEEK_DEFAULT_BASE -__all__ = ["DEEPSEEK_BASE_URL", "DeepSeekProvider"] +from .client import DeepSeekProvider + +__all__ = ["DEEPSEEK_DEFAULT_BASE", "DeepSeekProvider"] diff --git a/providers/deepseek/client.py b/providers/deepseek/client.py index 210f223..3996b61 100644 --- a/providers/deepseek/client.py +++ b/providers/deepseek/client.py @@ -3,7 +3,7 @@ from typing import Any from providers.base import ProviderConfig -from providers.defaults import DEEPSEEK_BASE_URL +from providers.defaults import DEEPSEEK_DEFAULT_BASE from providers.openai_compat import OpenAIChatTransport from .request import build_request_body @@ -16,7 +16,7 @@ class DeepSeekProvider(OpenAIChatTransport): super().__init__( config, provider_name="DEEPSEEK", - base_url=config.base_url or DEEPSEEK_BASE_URL, + base_url=config.base_url or DEEPSEEK_DEFAULT_BASE, api_key=config.api_key, ) diff --git a/providers/deepseek/request.py b/providers/deepseek/request.py index 67cabf9..f22a9f8 100644 --- a/providers/deepseek/request.py +++ b/providers/deepseek/request.py @@ -5,6 +5,8 @@ from typing import Any from loguru import logger from core.anthropic import build_base_request_body +from core.anthropic.conversion import OpenAIConversionError +from providers.exceptions import InvalidRequestError def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict: @@ -14,10 +16,14 @@ def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict: getattr(request_data, "model", "?"), len(getattr(request_data, "messages", [])), ) - body = build_base_request_body( - request_data, - include_reasoning_content=True, - ) + try: + body = build_base_request_body( + request_data, + include_thinking=thinking_enabled, + include_reasoning_content=thinking_enabled, + ) + except OpenAIConversionError as exc: + raise InvalidRequestError(str(exc)) from exc extra_body: dict[str, Any] = {} request_extra = getattr(request_data, "extra_body", None) diff --git a/providers/defaults.py b/providers/defaults.py index 5f7fcfe..80109dc 100644 --- a/providers/defaults.py +++ b/providers/defaults.py @@ -1,21 +1,19 @@ -"""Default upstream base URLs and shared provider constants. +"""Re-exports default upstream base URLs from the config provider catalog.""" -Adapters and :mod:`providers.registry` import from here to avoid duplicating -literals and to keep ``providers.registry`` free of per-adapter eager imports. 
-""" +from config.provider_catalog import ( + DEEPSEEK_DEFAULT_BASE, + LLAMACPP_DEFAULT_BASE, + LMSTUDIO_DEFAULT_BASE, + NVIDIA_NIM_DEFAULT_BASE, + OLLAMA_DEFAULT_BASE, + OPENROUTER_DEFAULT_BASE, +) -# OpenAI-compatible chat (NIM, DeepSeek) and local/native provider endpoints -NVIDIA_NIM_DEFAULT_BASE = "https://integrate.api.nvidia.com/v1" -DEEPSEEK_DEFAULT_BASE = "https://api.deepseek.com" -OPENROUTER_DEFAULT_BASE = "https://openrouter.ai/api/v1" -LMSTUDIO_DEFAULT_BASE = "http://localhost:1234/v1" -LLAMACPP_DEFAULT_BASE = "http://localhost:8080/v1" -OLLAMA_DEFAULT_BASE = "http://localhost:11434" - -# Backward-compatible names used by existing adapter modules -NVIDIA_NIM_BASE_URL = NVIDIA_NIM_DEFAULT_BASE -DEEPSEEK_BASE_URL = DEEPSEEK_DEFAULT_BASE -OPENROUTER_BASE_URL = OPENROUTER_DEFAULT_BASE -LMSTUDIO_DEFAULT_BASE_URL = LMSTUDIO_DEFAULT_BASE -LLAMACPP_DEFAULT_BASE_URL = LLAMACPP_DEFAULT_BASE -OLLAMA_DEFAULT_BASE_URL = OLLAMA_DEFAULT_BASE +__all__ = ( + "DEEPSEEK_DEFAULT_BASE", + "LLAMACPP_DEFAULT_BASE", + "LMSTUDIO_DEFAULT_BASE", + "NVIDIA_NIM_DEFAULT_BASE", + "OLLAMA_DEFAULT_BASE", + "OPENROUTER_DEFAULT_BASE", +) diff --git a/providers/error_mapping.py b/providers/error_mapping.py index 3826a84..44c33a5 100644 --- a/providers/error_mapping.py +++ b/providers/error_mapping.py @@ -14,10 +14,30 @@ from providers.exceptions import ( from providers.rate_limit import GlobalRateLimiter +def user_visible_message_for_mapped_provider_error( + mapped: Exception, + *, + provider_name: str, + read_timeout_s: float | None, +) -> str: + """Return the user-visible string after :func:`map_error` (405 + mapped types).""" + if getattr(mapped, "status_code", None) == 405: + return ( + f"Upstream provider {provider_name} rejected the request method " + "or endpoint (HTTP 405)." + ) + return get_user_facing_error_message(mapped, read_timeout_s=read_timeout_s) + + def map_error( e: Exception, *, rate_limiter: GlobalRateLimiter | None = None ) -> Exception: - """Map OpenAI or HTTPX exception to specific ProviderError.""" + """Map OpenAI or HTTPX exception to specific ProviderError. + + Streaming transports should pass their scoped limiter (``self._global_rate_limiter``) + so reactive 429 handling applies to the correct provider. Tests may omit + ``rate_limiter`` to use the process-wide singleton. + """ message = get_user_facing_error_message(e) limiter = rate_limiter or GlobalRateLimiter.get_instance() diff --git a/providers/exceptions.py b/providers/exceptions.py index 31c6781..6901b9c 100644 --- a/providers/exceptions.py +++ b/providers/exceptions.py @@ -90,8 +90,20 @@ class APIError(ProviderError): ) -class UnknownProviderTypeError(ValueError): +class UnknownProviderTypeError(InvalidRequestError): """Raised when ``provider_id`` is not registered in the provider map.""" def __init__(self, message: str) -> None: super().__init__(message) + + +class ServiceUnavailableError(ProviderError): + """Raised when the server is not ready (e.g. 
app lifespan did not wire state).""" + + def __init__(self, message: str, raw_error: Any = None): + super().__init__( + message, + status_code=503, + error_type="api_error", + raw_error=raw_error, + ) diff --git a/providers/llamacpp/client.py b/providers/llamacpp/client.py index de0c268..891022d 100644 --- a/providers/llamacpp/client.py +++ b/providers/llamacpp/client.py @@ -2,7 +2,7 @@ from providers.anthropic_messages import AnthropicMessagesTransport from providers.base import ProviderConfig -from providers.defaults import LLAMACPP_DEFAULT_BASE_URL +from providers.defaults import LLAMACPP_DEFAULT_BASE class LlamaCppProvider(AnthropicMessagesTransport): @@ -12,5 +12,5 @@ class LlamaCppProvider(AnthropicMessagesTransport): super().__init__( config, provider_name="LLAMACPP", - default_base_url=LLAMACPP_DEFAULT_BASE_URL, + default_base_url=LLAMACPP_DEFAULT_BASE, ) diff --git a/providers/lmstudio/client.py b/providers/lmstudio/client.py index 2961993..e32b835 100644 --- a/providers/lmstudio/client.py +++ b/providers/lmstudio/client.py @@ -2,7 +2,7 @@ from providers.anthropic_messages import AnthropicMessagesTransport from providers.base import ProviderConfig -from providers.defaults import LMSTUDIO_DEFAULT_BASE_URL +from providers.defaults import LMSTUDIO_DEFAULT_BASE class LMStudioProvider(AnthropicMessagesTransport): @@ -12,5 +12,5 @@ class LMStudioProvider(AnthropicMessagesTransport): super().__init__( config, provider_name="LMSTUDIO", - default_base_url=LMSTUDIO_DEFAULT_BASE_URL, + default_base_url=LMSTUDIO_DEFAULT_BASE, ) diff --git a/providers/nvidia_nim/__init__.py b/providers/nvidia_nim/__init__.py index b0a410f..253acd1 100644 --- a/providers/nvidia_nim/__init__.py +++ b/providers/nvidia_nim/__init__.py @@ -1,5 +1,7 @@ """NVIDIA NIM provider package.""" -from .client import NVIDIA_NIM_BASE_URL, NvidiaNimProvider +from providers.defaults import NVIDIA_NIM_DEFAULT_BASE -__all__ = ["NVIDIA_NIM_BASE_URL", "NvidiaNimProvider"] +from .client import NvidiaNimProvider + +__all__ = ["NVIDIA_NIM_DEFAULT_BASE", "NvidiaNimProvider"] diff --git a/providers/nvidia_nim/client.py b/providers/nvidia_nim/client.py index 5cc2f31..2661203 100644 --- a/providers/nvidia_nim/client.py +++ b/providers/nvidia_nim/client.py @@ -8,7 +8,7 @@ from loguru import logger from config.nim import NimSettings from providers.base import ProviderConfig -from providers.defaults import NVIDIA_NIM_BASE_URL +from providers.defaults import NVIDIA_NIM_DEFAULT_BASE from providers.openai_compat import OpenAIChatTransport from .request import ( @@ -25,7 +25,7 @@ class NvidiaNimProvider(OpenAIChatTransport): super().__init__( config, provider_name="NIM", - base_url=config.base_url or NVIDIA_NIM_BASE_URL, + base_url=config.base_url or NVIDIA_NIM_DEFAULT_BASE, api_key=config.api_key, ) self._nim_settings = nim_settings diff --git a/providers/nvidia_nim/request.py b/providers/nvidia_nim/request.py index c0a83f3..a4e96de 100644 --- a/providers/nvidia_nim/request.py +++ b/providers/nvidia_nim/request.py @@ -1,5 +1,6 @@ """Request builder for NVIDIA NIM provider.""" +from collections.abc import Callable from copy import deepcopy from typing import Any @@ -7,6 +8,42 @@ from loguru import logger from config.nim import NimSettings from core.anthropic import build_base_request_body, set_if_not_none +from core.anthropic.conversion import OpenAIConversionError +from providers.exceptions import InvalidRequestError + + +def _clone_strip_extra_body( + body: dict[str, Any], + strip: Callable[[dict[str, Any]], bool], +) -> dict[str, Any] | None: 
+ """Deep-clone ``body`` and remove fields via ``strip`` on ``extra_body`` only. + + Returns ``None`` when there is no ``extra_body`` dict or ``strip`` reports no change. + """ + cloned_body = deepcopy(body) + extra_body = cloned_body.get("extra_body") + if not isinstance(extra_body, dict): + return None + if not strip(extra_body): + return None + if not extra_body: + cloned_body.pop("extra_body", None) + return cloned_body + + +def _strip_reasoning_budget_fields(extra_body: dict[str, Any]) -> bool: + removed = extra_body.pop("reasoning_budget", None) is not None + chat_template_kwargs = extra_body.get("chat_template_kwargs") + if ( + isinstance(chat_template_kwargs, dict) + and chat_template_kwargs.pop("reasoning_budget", None) is not None + ): + removed = True + return removed + + +def _strip_chat_template_field(extra_body: dict[str, Any]) -> bool: + return extra_body.pop("chat_template", None) is not None def _set_extra( @@ -23,43 +60,12 @@ def _set_extra( def clone_body_without_reasoning_budget(body: dict[str, Any]) -> dict[str, Any] | None: """Clone a request body and strip only reasoning_budget fields.""" - cloned_body = deepcopy(body) - extra_body = cloned_body.get("extra_body") - if not isinstance(extra_body, dict): - return None - - removed = extra_body.pop("reasoning_budget", None) is not None - - chat_template_kwargs = extra_body.get("chat_template_kwargs") - if ( - isinstance(chat_template_kwargs, dict) - and chat_template_kwargs.pop("reasoning_budget", None) is not None - ): - removed = True - - if not extra_body: - cloned_body.pop("extra_body", None) - - if not removed: - return None - - return cloned_body + return _clone_strip_extra_body(body, _strip_reasoning_budget_fields) def clone_body_without_chat_template(body: dict[str, Any]) -> dict[str, Any] | None: """Clone a request body and strip only chat_template.""" - cloned_body = deepcopy(body) - extra_body = cloned_body.get("extra_body") - if not isinstance(extra_body, dict): - return None - - if extra_body.pop("chat_template", None) is None: - return None - - if not extra_body: - cloned_body.pop("extra_body", None) - - return cloned_body + return _clone_strip_extra_body(body, _strip_chat_template_field) def build_request_body( @@ -71,10 +77,13 @@ def build_request_body( getattr(request_data, "model", "?"), len(getattr(request_data, "messages", [])), ) - body = build_base_request_body( - request_data, - include_thinking=thinking_enabled, - ) + try: + body = build_base_request_body( + request_data, + include_thinking=thinking_enabled, + ) + except OpenAIConversionError as exc: + raise InvalidRequestError(str(exc)) from exc # NIM-specific max_tokens: cap against nim.max_tokens max_tokens = body.get("max_tokens") or getattr(request_data, "max_tokens", None) diff --git a/providers/nvidia_nim/voice.py b/providers/nvidia_nim/voice.py new file mode 100644 index 0000000..f26255a --- /dev/null +++ b/providers/nvidia_nim/voice.py @@ -0,0 +1,95 @@ +"""NVIDIA NIM / Riva offline ASR for voice notes (provider-owned transport).""" + +from __future__ import annotations + +from pathlib import Path + +from loguru import logger + +# NVIDIA NIM Whisper model mapping: (function_id, language_code) +_NIM_ASR_MODEL_MAP: dict[str, tuple[str, str]] = { + "nvidia/parakeet-ctc-0.6b-zh-tw": ("8473f56d-51ef-473c-bb26-efd4f5def2bf", "zh-TW"), + "nvidia/parakeet-ctc-0.6b-zh-cn": ("9add5ef7-322e-47e0-ad7a-5653fb8d259b", "zh-CN"), + # function-id from NVIDIA NIM API docs (parakeet-ctc-0.6b-es). 
+ "nvidia/parakeet-ctc-0.6b-es": ("a9eeee8f-b509-4712-b19d-194361fa5f31", "es-US"), + "nvidia/parakeet-ctc-0.6b-vi": ("f3dff2bb-99f9-403d-a5f1-f574a757deb0", "vi-VN"), + "nvidia/parakeet-ctc-1.1b-asr": ("1598d209-5e27-4d3c-8079-4751568b1081", "en-US"), + "nvidia/parakeet-ctc-0.6b-asr": ("d8dd4e9b-fbf5-4fb0-9dba-8cf436c8d965", "en-US"), + "nvidia/parakeet-1.1b-rnnt-multilingual-asr": ( + "71203149-d3b7-4460-8231-1be2543a1fca", + "", + ), + "openai/whisper-large-v3": ("b702f636-f60c-4a3d-a6f4-f3568c13bd7d", "multi"), +} + +_RIVA_SERVER = "grpc.nvcf.nvidia.com:443" + + +def transcribe_audio_file( + file_path: Path, + model: str, + *, + api_key: str, +) -> str: + """Transcribe audio using NVIDIA NIM / Riva gRPC (offline recognition). + + Args: + file_path: Path to encoded audio bytes readable by Riva. + model: Hugging Face-style NIM model id (see ``_NIM_ASR_MODEL_MAP``). + api_key: NVIDIA API key (Bearer token); must be non-empty. + + Returns: + Transcript text, or ``(no speech detected)`` when empty. + """ + key = (api_key or "").strip() + if not key: + raise ValueError( + "NVIDIA NIM transcription requires a non-empty nvidia_nim_api_key " + "(configure NVIDIA_NIM_API_KEY or pass api_key explicitly)." + ) + + try: + import riva.client + except ImportError as e: + raise ImportError( + "NVIDIA NIM transcription requires the voice extra. " + "Install with: uv sync --extra voice" + ) from e + + model_config = _NIM_ASR_MODEL_MAP.get(model) + if not model_config: + raise ValueError( + f"No NVIDIA NIM config found for model: {model}. " + f"Supported models: {', '.join(_NIM_ASR_MODEL_MAP.keys())}" + ) + function_id, language_code = model_config + + auth = riva.client.Auth( + use_ssl=True, + uri=_RIVA_SERVER, + metadata_args=[ + ["function-id", function_id], + ["authorization", f"Bearer {key}"], + ], + ) + + asr_service = riva.client.ASRService(auth) + + config = riva.client.RecognitionConfig( + language_code=language_code, + max_alternatives=1, + verbatim_transcripts=True, + ) + + with open(file_path, "rb") as f: + data = f.read() + + response = asr_service.offline_recognize(data, config) + + transcript = "" + results = getattr(response, "results", None) + if results and results[0].alternatives: + transcript = results[0].alternatives[0].transcript + + logger.debug(f"NIM transcription: {len(transcript)} chars") + return transcript or "(no speech detected)" diff --git a/providers/ollama/__init__.py b/providers/ollama/__init__.py index 365677c..8934ab9 100644 --- a/providers/ollama/__init__.py +++ b/providers/ollama/__init__.py @@ -1,5 +1,7 @@ """Ollama provider package.""" -from .client import OLLAMA_BASE_URL, OllamaProvider +from providers.defaults import OLLAMA_DEFAULT_BASE -__all__ = ["OLLAMA_BASE_URL", "OllamaProvider"] +from .client import OllamaProvider + +__all__ = ["OLLAMA_DEFAULT_BASE", "OllamaProvider"] diff --git a/providers/ollama/client.py b/providers/ollama/client.py index ebff031..6696bd8 100644 --- a/providers/ollama/client.py +++ b/providers/ollama/client.py @@ -6,8 +6,6 @@ from providers.anthropic_messages import AnthropicMessagesTransport from providers.base import ProviderConfig from providers.defaults import OLLAMA_DEFAULT_BASE -OLLAMA_BASE_URL = OLLAMA_DEFAULT_BASE - class OllamaProvider(AnthropicMessagesTransport): """Ollama provider using native Anthropic Messages API.""" @@ -16,7 +14,7 @@ class OllamaProvider(AnthropicMessagesTransport): super().__init__( config, provider_name="OLLAMA", - default_base_url=OLLAMA_BASE_URL, + default_base_url=OLLAMA_DEFAULT_BASE, ) self._api_key 
= config.api_key or "ollama" diff --git a/providers/open_router/__init__.py b/providers/open_router/__init__.py index 7a0cdad..df12496 100644 --- a/providers/open_router/__init__.py +++ b/providers/open_router/__init__.py @@ -1,5 +1,7 @@ """OpenRouter provider - Anthropic-compatible native transport.""" -from .client import OPENROUTER_BASE_URL, OpenRouterProvider +from providers.defaults import OPENROUTER_DEFAULT_BASE -__all__ = ["OPENROUTER_BASE_URL", "OpenRouterProvider"] +from .client import OpenRouterProvider + +__all__ = ["OPENROUTER_DEFAULT_BASE", "OpenRouterProvider"] diff --git a/providers/open_router/client.py b/providers/open_router/client.py index d5c2081..032df60 100644 --- a/providers/open_router/client.py +++ b/providers/open_router/client.py @@ -2,34 +2,25 @@ from __future__ import annotations -import json -import uuid from collections.abc import Iterator -from dataclasses import dataclass, field from typing import Any -from core.anthropic import SSEBuilder, append_request_id +from core.anthropic import append_request_id, iter_provider_stream_error_sse_events +from core.anthropic.native_sse_block_policy import ( + NativeSseBlockPolicyState, + is_terminal_openrouter_done_event, + parse_native_sse_event, + transform_native_sse_block_event, +) from providers.anthropic_messages import AnthropicMessagesTransport, StreamChunkMode from providers.base import ProviderConfig -from providers.defaults import OPENROUTER_BASE_URL +from providers.defaults import OPENROUTER_DEFAULT_BASE from .request import build_request_body _ANTHROPIC_VERSION = "2023-06-01" -@dataclass -class _SSEFilterState: - """Track Anthropic content block index remapping while filtering thinking.""" - - next_index: int = 0 - index_map: dict[int, int] = field(default_factory=dict) - dropped_indexes: set[int] = field(default_factory=set) - open_block_types: dict[int, str] = field(default_factory=dict) - closed_indexes: set[int] = field(default_factory=set) - message_stopped: bool = False - - class OpenRouterProvider(AnthropicMessagesTransport): """OpenRouter provider using the native Anthropic-compatible messages API.""" @@ -39,7 +30,7 @@ class OpenRouterProvider(AnthropicMessagesTransport): super().__init__( config, provider_name="OPENROUTER", - default_base_url=OPENROUTER_BASE_URL, + default_base_url=OPENROUTER_DEFAULT_BASE, ) def _build_request_body( @@ -60,163 +51,9 @@ class OpenRouterProvider(AnthropicMessagesTransport): "anthropic-version": _ANTHROPIC_VERSION, } - @staticmethod - def _format_sse_event(event_name: str | None, data_text: str) -> str: - """Format an SSE event from its event name and data payload.""" - lines: list[str] = [] - if event_name: - lines.append(f"event: {event_name}") - lines.extend(f"data: {line}" for line in data_text.splitlines()) - return "\n".join(lines) + "\n\n" - - @staticmethod - def _parse_sse_event(event: str) -> tuple[str | None, str]: - """Extract the event name and raw data payload from an SSE event.""" - event_name = None - data_lines: list[str] = [] - for line in event.strip().splitlines(): - if line.startswith("event:"): - event_name = line[6:].strip() - elif line.startswith("data:"): - data_lines.append(line[5:].lstrip()) - return event_name, "\n".join(data_lines) - - @staticmethod - def _is_terminal_done_event(event_name: str | None, data_text: str) -> bool: - """Return whether an event is OpenAI-style terminal noise.""" - return (event_name is None or event_name in {"data", "done"}) and ( - data_text.strip().upper() == "[DONE]" - ) - - @staticmethod - def _remap_index( - 
payload: dict[str, Any], state: _SSEFilterState, *, create: bool - ) -> int | None: - """Return the downstream index for a content block event.""" - upstream_index = payload.get("index") - if not isinstance(upstream_index, int): - return None - if upstream_index in state.dropped_indexes: - return None - mapped_index = state.index_map.get(upstream_index) - if mapped_index is None and create: - mapped_index = state.next_index - state.index_map[upstream_index] = mapped_index - state.next_index += 1 - return mapped_index - - def _close_open_blocks_before( - self, state: _SSEFilterState, upstream_index: int - ) -> str: - """Close overlapping upstream blocks before starting a new block.""" - events: list[str] = [] - for open_upstream_index in list(state.open_block_types): - if open_upstream_index == upstream_index: - continue - mapped_index = state.index_map.get(open_upstream_index) - if mapped_index is None: - continue - payload = {"type": "content_block_stop", "index": mapped_index} - events.append( - self._format_sse_event("content_block_stop", json.dumps(payload)) - ) - state.closed_indexes.add(open_upstream_index) - state.open_block_types.pop(open_upstream_index, None) - return "".join(events) - - @staticmethod - def _should_drop_block_type(block_type: Any, *, thinking_enabled: bool) -> bool: - if not isinstance(block_type, str): - return False - if block_type.startswith("redacted_thinking"): - return True - return not thinking_enabled and "thinking" in block_type - - def _transform_sse_payload( - self, - event: str, - state: _SSEFilterState, - *, - thinking_enabled: bool, - ) -> str | None: - """Normalize OpenRouter SSE events and enforce local thinking policy.""" - event_name, data_text = self._parse_sse_event(event) - if not event_name or not data_text: - return event - - try: - payload = json.loads(data_text) - except json.JSONDecodeError: - return event - - if event_name == "content_block_start": - block = payload.get("content_block") - if not isinstance(block, dict): - return event - block_type = block.get("type") - upstream_index = payload.get("index") - if self._should_drop_block_type( - block_type, thinking_enabled=thinking_enabled - ): - if isinstance(upstream_index, int): - state.dropped_indexes.add(upstream_index) - return None - - mapped_index = self._remap_index(payload, state, create=True) - if mapped_index is not None: - payload["index"] = mapped_index - if isinstance(upstream_index, int) and isinstance(block_type, str): - prefix = self._close_open_blocks_before(state, upstream_index) - state.open_block_types[upstream_index] = block_type - return prefix + self._format_sse_event( - event_name, json.dumps(payload) - ) - return self._format_sse_event(event_name, json.dumps(payload)) - return None if not thinking_enabled else event - - if event_name == "content_block_delta": - delta = payload.get("delta") - if not isinstance(delta, dict): - return event - delta_type = delta.get("type") - if self._should_drop_block_type( - delta_type, thinking_enabled=thinking_enabled - ): - return None - - mapped_index = self._remap_index(payload, state, create=False) - if mapped_index is not None: - payload["index"] = mapped_index - return self._format_sse_event(event_name, json.dumps(payload)) - if payload.get("index") in state.dropped_indexes: - return None - if not thinking_enabled: - return None - - if event_name == "content_block_stop": - upstream_index = payload.get("index") - if ( - isinstance(upstream_index, int) - and upstream_index in state.closed_indexes - ): - 
state.closed_indexes.discard(upstream_index) - return None - mapped_index = self._remap_index(payload, state, create=False) - if mapped_index is not None: - payload["index"] = mapped_index - if isinstance(upstream_index, int): - state.open_block_types.pop(upstream_index, None) - return self._format_sse_event(event_name, json.dumps(payload)) - if payload.get("index") in state.dropped_indexes: - return None - if not thinking_enabled: - return None - - return event - def _new_stream_state(self, request: Any, *, thinking_enabled: bool) -> Any: """Create per-stream state for thinking block filtering.""" - return _SSEFilterState() + return NativeSseBlockPolicyState() def _transform_stream_event( self, @@ -226,23 +63,17 @@ class OpenRouterProvider(AnthropicMessagesTransport): thinking_enabled: bool, ) -> str | None: """Drop provider-specific terminal noise and hidden thinking events.""" - if isinstance(state, _SSEFilterState): - event_name, data_text = self._parse_sse_event(event) - if state.message_stopped or self._is_terminal_done_event( + if isinstance(state, NativeSseBlockPolicyState): + event_name, data_text = parse_native_sse_event(event) + if state.message_stopped or is_terminal_openrouter_done_event( event_name, data_text ): return None if event_name == "message_stop": state.message_stopped = True - if thinking_enabled: - if isinstance(state, _SSEFilterState): - return self._transform_sse_payload( - event, state, thinking_enabled=thinking_enabled - ) - return event - if isinstance(state, _SSEFilterState): - return self._transform_sse_payload( + if isinstance(state, NativeSseBlockPolicyState): + return transform_native_sse_block_event( event, state, thinking_enabled=thinking_enabled ) return event @@ -260,9 +91,10 @@ class OpenRouterProvider(AnthropicMessagesTransport): sent_any_event: bool, ) -> Iterator[str]: """Emit the Anthropic SSE error shape expected by Claude clients.""" - sse = SSEBuilder(f"msg_{uuid.uuid4()}", request.model, input_tokens) - if not sent_any_event: - yield sse.message_start() - yield from sse.emit_error(error_message) - yield sse.message_delta("end_turn", 1) - yield sse.message_stop() + yield from iter_provider_stream_error_sse_events( + request=request, + input_tokens=input_tokens, + error_message=error_message, + sent_any_event=sent_any_event, + log_raw_sse_events=self._config.log_raw_sse_events, + ) diff --git a/providers/open_router/request.py b/providers/open_router/request.py index 89759ab..abd93df 100644 --- a/providers/open_router/request.py +++ b/providers/open_router/request.py @@ -2,136 +2,18 @@ from __future__ import annotations -from collections.abc import Sequence from typing import Any from loguru import logger -from pydantic import BaseModel -OPENROUTER_DEFAULT_MAX_TOKENS = 81920 - -_REQUEST_FIELDS = ( - "model", - "messages", - "system", - "max_tokens", - "stop_sequences", - "stream", - "temperature", - "top_p", - "top_k", - "metadata", - "tools", - "tool_choice", - "thinking", - "extra_body", - "original_model", - "resolved_provider_model", +from config.constants import ( + ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS as OPENROUTER_DEFAULT_MAX_TOKENS, ) - -_INTERNAL_FIELDS = { - "thinking", - "extra_body", - "original_model", - "resolved_provider_model", -} -_THINKING_HISTORY_BLOCK_TYPES = {"thinking", "redacted_thinking"} - - -def _serialize_value(value: Any) -> Any: - """Convert Pydantic models and lightweight objects into JSON-ready values.""" - if isinstance(value, BaseModel): - return value.model_dump(exclude_none=True) - if isinstance(value, dict): - 
return { - key: _serialize_value(item) - for key, item in value.items() - if item is not None - } - if isinstance(value, Sequence) and not isinstance(value, str | bytes | bytearray): - return [_serialize_value(item) for item in value] - if value is None or isinstance(value, str | int | float | bool): - return value - if hasattr(value, "__dict__"): - return { - key: _serialize_value(item) - for key, item in vars(value).items() - if not key.startswith("_") and item is not None - } - return value - - -def _dump_request_fields(request_data: Any) -> dict[str, Any]: - """Extract the public request fields forwarded to OpenRouter.""" - if isinstance(request_data, BaseModel): - return request_data.model_dump(exclude_none=True) - - dumped: dict[str, Any] = {} - for field in _REQUEST_FIELDS: - value = getattr(request_data, field, None) - if value is not None: - dumped[field] = _serialize_value(value) - return dumped - - -def _strip_unsigned_thinking_history(messages: Any) -> Any: - """Remove assistant thinking history blocks that OpenRouter cannot replay.""" - if not isinstance(messages, list): - return messages - - sanitized_messages: list[Any] = [] - for message in messages: - if not isinstance(message, dict): - sanitized_messages.append(message) - continue - - content = message.get("content") - if not isinstance(content, list): - sanitized_messages.append(message) - continue - - sanitized_content = [ - block - for block in content - if not ( - isinstance(block, dict) - and block.get("type") in _THINKING_HISTORY_BLOCK_TYPES - and not isinstance(block.get("signature"), str) - ) - ] - - sanitized_message = dict(message) - sanitized_message["content"] = sanitized_content or "" - sanitized_messages.append(sanitized_message) - - return sanitized_messages - - -def _normalize_system_prompt(system: Any) -> Any: - """Flatten Claude SDK system blocks for OpenRouter's native endpoint.""" - if not isinstance(system, list): - return system - - text_parts: list[str] = [] - for block in system: - if not isinstance(block, dict): - continue - if block.get("type") == "text" and isinstance(block.get("text"), str): - text_parts.append(block["text"]) - return "\n\n".join(text_parts).strip() if text_parts else system - - -def _apply_reasoning(body: dict[str, Any], thinking_cfg: Any) -> None: - """Map Anthropic thinking controls onto OpenRouter reasoning controls.""" - reasoning = body.setdefault("reasoning", {"enabled": True}) - if not isinstance(reasoning, dict): - return - reasoning.setdefault("enabled", True) - if not isinstance(thinking_cfg, dict): - return - budget_tokens = thinking_cfg.get("budget_tokens") - if isinstance(budget_tokens, int): - reasoning.setdefault("max_tokens", budget_tokens) +from core.anthropic.native_messages_request import ( + OpenRouterExtraBodyError, + build_openrouter_native_request_body, +) +from providers.exceptions import InvalidRequestError def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict: @@ -142,27 +24,14 @@ def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict: len(getattr(request_data, "messages", [])), ) - dumped_request = _dump_request_fields(request_data) - request_extra = dumped_request.pop("extra_body", None) - thinking_cfg = dumped_request.get("thinking") - body = { - key: value - for key, value in dumped_request.items() - if key not in _INTERNAL_FIELDS - } - - if isinstance(request_extra, dict): - body.update(request_extra) - - body["messages"] = _strip_unsigned_thinking_history(body.get("messages")) - if "system" in body: - 
body["system"] = _normalize_system_prompt(body["system"]) - body["stream"] = True - if body.get("max_tokens") is None: - body["max_tokens"] = OPENROUTER_DEFAULT_MAX_TOKENS - - if thinking_enabled: - _apply_reasoning(body, thinking_cfg) + try: + body = build_openrouter_native_request_body( + request_data, + thinking_enabled=thinking_enabled, + default_max_tokens=OPENROUTER_DEFAULT_MAX_TOKENS, + ) + except OpenRouterExtraBodyError as exc: + raise InvalidRequestError(str(exc)) from exc logger.debug( "OPENROUTER_REQUEST: conversion done model={} msgs={} tools={}", diff --git a/providers/openai_compat.py b/providers/openai_compat.py index d729896..31c1d7e 100644 --- a/providers/openai_compat.py +++ b/providers/openai_compat.py @@ -21,14 +21,40 @@ from core.anthropic import ( SSEBuilder, ThinkTagParser, append_request_id, - get_user_facing_error_message, map_stop_reason, ) from providers.base import BaseProvider, ProviderConfig -from providers.error_mapping import map_error +from providers.error_mapping import ( + map_error, + user_visible_message_for_mapped_provider_error, +) from providers.rate_limit import GlobalRateLimiter +def _iter_heuristic_tool_use_sse( + sse: SSEBuilder, tool_use: dict[str, Any] +) -> Iterator[str]: + """Emit SSE for one heuristic tool_use block (closes open text/thinking first).""" + if tool_use.get("name") == "Task" and isinstance(tool_use.get("input"), dict): + task_input = tool_use["input"] + if task_input.get("run_in_background") is not False: + task_input["run_in_background"] = False + yield from sse.close_content_blocks() + block_idx = sse.blocks.allocate_index() + yield sse.content_block_start( + block_idx, + "tool_use", + id=tool_use["id"], + name=tool_use["name"], + ) + yield sse.content_block_delta( + block_idx, + "input_json_delta", + json.dumps(tool_use["input"]), + ) + yield sse.content_block_stop(block_idx) + + class OpenAIChatTransport(BaseProvider): """Base for OpenAI-compatible ``/chat/completions`` adapters (NIM, DeepSeek, …).""" @@ -113,6 +139,22 @@ class OpenAIChatTransport(BaseProvider): ) return stream, retry_body + def _emit_tool_arg_delta( + self, sse: SSEBuilder, tc_index: int, args: str + ) -> Iterator[str]: + """Emit one argument fragment for a started tool block (Task buffer or raw JSON).""" + if not args: + return + state = sse.blocks.tool_states.get(tc_index) + if state is None: + return + if state.name == "Task": + parsed = sse.blocks.buffer_task_args(tc_index, args) + if parsed is not None: + yield sse.emit_tool_delta(tc_index, json.dumps(parsed)) + return + yield sse.emit_tool_delta(tc_index, args) + def _process_tool_call(self, tc: dict, sse: SSEBuilder) -> Iterator[str]: """Process a single tool call delta and yield SSE events.""" tc_index = tc.get("index", 0) @@ -121,34 +163,42 @@ class OpenAIChatTransport(BaseProvider): fn_delta = tc.get("function", {}) incoming_name = fn_delta.get("name") - arguments = fn_delta.get("arguments", "") + arguments = fn_delta.get("arguments", "") or "" + + if tc.get("id") is not None: + sse.blocks.set_stream_tool_id(tc_index, tc.get("id")) + if incoming_name is not None: sse.blocks.register_tool_name(tc_index, incoming_name) state = sse.blocks.tool_states.get(tc_index) + resolved_id = (state.tool_id if state and state.tool_id else None) or tc.get( + "id" + ) + resolved_name = (state.name if state else "") or "" + + if not state or not state.started: + name_ok = bool((resolved_name or "").strip()) + if name_ok: + tool_id = str(resolved_id) if resolved_id else f"tool_{uuid.uuid4()}" + display_name = 
(resolved_name or "").strip() or "tool_call"
+                yield sse.start_tool_block(tc_index, tool_id, display_name)
+                state = sse.blocks.tool_states[tc_index]
+                if state.pre_start_args:
+                    pre = state.pre_start_args
+                    state.pre_start_args = ""
+                    yield from self._emit_tool_arg_delta(sse, tc_index, pre)
+
+        state = sse.blocks.tool_states.get(tc_index)
+        if not arguments:
+            return
         if state is None or not state.started:
-            name = state.name if state else ""
-            if name or tc.get("id"):
-                tool_id = tc.get("id") or f"tool_{uuid.uuid4()}"
-                yield sse.start_tool_block(tc_index, tool_id, name)
-
-        args = arguments
-        if args:
-            state = sse.blocks.tool_states.get(tc_index)
-            if state is None or not state.started:
-                tool_id = tc.get("id") or f"tool_{uuid.uuid4()}"
-                name = (state.name if state else None) or "tool_call"
-                yield sse.start_tool_block(tc_index, tool_id, name)
-                state = sse.blocks.tool_states.get(tc_index)
-
-            current_name = state.name if state else ""
-            if current_name == "Task":
-                parsed = sse.blocks.buffer_task_args(tc_index, args)
-                if parsed is not None:
-                    yield sse.emit_tool_delta(tc_index, json.dumps(parsed))
+            state = sse.blocks.ensure_tool_state(tc_index)
+            if not (resolved_name or "").strip():
+                state.pre_start_args += arguments
                 return
-            yield sse.emit_tool_delta(tc_index, args)
+        yield from self._emit_tool_arg_delta(sse, tc_index, arguments)
 
     def _flush_task_arg_buffers(self, sse: SSEBuilder) -> Iterator[str]:
         """Emit buffered Task args as a single JSON delta (best-effort)."""
@@ -181,7 +231,12 @@
         """Shared streaming implementation."""
         tag = self._provider_name
         message_id = f"msg_{uuid.uuid4()}"
-        sse = SSEBuilder(message_id, request.model, input_tokens)
+        sse = SSEBuilder(
+            message_id,
+            request.model,
+            input_tokens,
+            log_raw_events=self._config.log_raw_sse_events,
+        )
 
         body = self._build_request_body(request, thinking_enabled=thinking_enabled)
         thinking_enabled = self._is_thinking_enabled(request, thinking_enabled)
@@ -201,8 +256,6 @@
         heuristic_parser = HeuristicToolParser()
         finish_reason = None
         usage_info = None
-        error_occurred = False
-        error_message = ""
 
         async with self._global_rate_limiter.concurrency_slot():
             try:
@@ -258,26 +311,10 @@
                             yield sse.emit_text_delta(filtered_text)
 
                         for tool_use in detected_tools:
-                            for event in sse.close_content_blocks():
-                                yield event
-
-                            block_idx = sse.blocks.allocate_index()
-                            if tool_use.get("name") == "Task" and isinstance(
-                                tool_use.get("input"), dict
+                            for event in _iter_heuristic_tool_use_sse(
+                                sse, tool_use
                             ):
-                                tool_use["input"]["run_in_background"] = False
-                            yield sse.content_block_start(
-                                block_idx,
-                                "tool_use",
-                                id=tool_use["id"],
-                                name=tool_use["name"],
-                            )
-                            yield sse.content_block_delta(
-                                block_idx,
-                                "input_json_delta",
-                                json.dumps(tool_use["input"]),
-                            )
-                            yield sse.content_block_stop(block_idx)
+                                yield event
 
                 # Handle native tool calls
                 if delta.tool_calls:
@@ -298,18 +335,13 @@
             except (asyncio.CancelledError, GeneratorExit):
                 raise
             except Exception as e:
-                logger.error("{}_ERROR:{} {}: {}", tag, req_tag, type(e).__name__, e)
+                self._log_stream_transport_error(tag, req_tag, e)
                 mapped_e = map_error(e, rate_limiter=self._global_rate_limiter)
-                error_occurred = True
-                if getattr(mapped_e, "status_code", None) == 405:
-                    base_message = (
-                        f"Upstream provider {tag} rejected the request method "
-                        "or endpoint (HTTP 405)."
- ) - else: - base_message = get_user_facing_error_message( - mapped_e, read_timeout_s=self._config.http_read_timeout - ) + base_message = user_visible_message_for_mapped_provider_error( + mapped_e, + provider_name=tag, + read_timeout_s=self._config.http_read_timeout, + ) error_message = append_request_id(base_message, request_id) logger.info( "{}_STREAM: Emitting SSE error event for {}{}", @@ -317,10 +349,13 @@ class OpenAIChatTransport(BaseProvider): type(e).__name__, req_tag, ) - for event in sse.close_content_blocks(): + for event in sse.close_all_blocks(): yield event for event in sse.emit_error(error_message): yield event + yield sse.message_delta("end_turn", 1) + yield sse.message_stop() + return # Flush remaining content remaining = think_parser.flush() @@ -338,32 +373,16 @@ class OpenAIChatTransport(BaseProvider): yield sse.emit_text_delta(remaining.content) for tool_use in heuristic_parser.flush(): - for event in sse.close_content_blocks(): + for event in _iter_heuristic_tool_use_sse(sse, tool_use): yield event - block_idx = sse.blocks.allocate_index() - yield sse.content_block_start( - block_idx, - "tool_use", - id=tool_use["id"], - name=tool_use["name"], - ) - if tool_use.get("name") == "Task" and isinstance( - tool_use.get("input"), dict - ): - tool_use["input"]["run_in_background"] = False - yield sse.content_block_delta( - block_idx, - "input_json_delta", - json.dumps(tool_use["input"]), - ) - yield sse.content_block_stop(block_idx) - - if ( - not error_occurred - and sse.blocks.text_index == -1 - and not sse.blocks.tool_states - ): + has_started_tool = any(s.started for s in sse.blocks.tool_states.values()) + has_content_blocks = ( + sse.blocks.text_index != -1 + or sse.blocks.thinking_index != -1 + or has_started_tool + ) + if not has_content_blocks: for event in sse.ensure_text_block(): yield event yield sse.emit_text_delta(" ") diff --git a/providers/rate_limit.py b/providers/rate_limit.py index 0e427e9..59e4291 100644 --- a/providers/rate_limit.py +++ b/providers/rate_limit.py @@ -3,14 +3,16 @@ import asyncio import random import time -from collections import deque from collections.abc import AsyncIterator, Callable from contextlib import asynccontextmanager from typing import Any, ClassVar, TypeVar +import httpx import openai from loguru import logger +from core.rate_limit import StrictSlidingWindowLimiter + T = TypeVar("T") @@ -51,10 +53,10 @@ class GlobalRateLimiter: self._rate_limit = rate_limit self._rate_window = float(rate_window) self._max_concurrency = max_concurrency - # Monotonic timestamps of the last granted slots. 
- self._request_times: deque[float] = deque() + self._proactive_limiter = StrictSlidingWindowLimiter( + self._rate_limit, self._rate_window + ) self._blocked_until: float = 0 - self._lock = asyncio.Lock() self._concurrency_sem = asyncio.Semaphore(max_concurrency) self._initialized = True @@ -107,12 +109,11 @@ class GlobalRateLimiter: logger.info( "Rebuilding provider rate limiter for updated scope '{}'", scope ) - if scope not in cls._scoped_instances or existing: - cls._scoped_instances[scope] = cls( - rate_limit=desired_rate_limit, - rate_window=desired_rate_window, - max_concurrency=max_concurrency, - ) + cls._scoped_instances[scope] = cls( + rate_limit=desired_rate_limit, + rate_window=desired_rate_window, + max_concurrency=max_concurrency, + ) return cls._scoped_instances[scope] @classmethod @@ -150,27 +151,7 @@ class GlobalRateLimiter: Guarantees: at most `self._rate_limit` acquisitions in any interval of length `self._rate_window` (seconds). """ - while True: - wait_time = 0.0 - async with self._lock: - now = time.monotonic() - cutoff = now - self._rate_window - - while self._request_times and self._request_times[0] <= cutoff: - self._request_times.popleft() - - if len(self._request_times) < self._rate_limit: - self._request_times.append(now) - return - - oldest = self._request_times[0] - wait_time = max(0.0, (oldest + self._rate_window) - now) - - # Sleep outside the lock so other tasks can continue to queue. - if wait_time > 0: - await asyncio.sleep(wait_time) - else: - await asyncio.sleep(0) + await self._proactive_limiter.acquire() def set_blocked(self, seconds: float = 60) -> None: """ @@ -263,6 +244,24 @@ class GlobalRateLimiter: ) self.set_blocked(delay) await asyncio.sleep(delay) + except httpx.HTTPStatusError as e: + if e.response.status_code != 429: + raise + last_exc = e + if attempt >= max_retries: + logger.warning( + f"HTTP 429 retry exhausted after {max_retries} retries" + ) + break + + delay = min(base_delay * (2**attempt), max_delay) + delay += random.uniform(0, jitter) + logger.warning( + f"HTTP 429 from upstream, attempt {attempt + 1}/{max_retries + 1}. " + f"Retrying in {delay:.1f}s..." + ) + self.set_blocked(delay) + await asyncio.sleep(delay) assert last_exc is not None raise last_exc diff --git a/providers/registry.py b/providers/registry.py index 75a6d8c..20ef4a0 100644 --- a/providers/registry.py +++ b/providers/registry.py @@ -3,108 +3,20 @@ from __future__ import annotations from collections.abc import Callable, MutableMapping -from dataclasses import dataclass -from typing import Literal -from config.provider_ids import SUPPORTED_PROVIDER_IDS +from config.provider_catalog import ( + PROVIDER_CATALOG, + SUPPORTED_PROVIDER_IDS, + ProviderDescriptor, +) from config.settings import Settings from providers.base import BaseProvider, ProviderConfig -from providers.defaults import ( - DEEPSEEK_DEFAULT_BASE, - LLAMACPP_DEFAULT_BASE, - LMSTUDIO_DEFAULT_BASE, - NVIDIA_NIM_DEFAULT_BASE, - OLLAMA_DEFAULT_BASE, - OPENROUTER_DEFAULT_BASE, -) from providers.exceptions import AuthenticationError, UnknownProviderTypeError -TransportType = Literal["openai_chat", "anthropic_messages"] ProviderFactory = Callable[[ProviderConfig, Settings], BaseProvider] - -@dataclass(frozen=True, slots=True) -class ProviderDescriptor: - """Metadata for building :class:`ProviderConfig` and factory wiring.""" - - provider_id: str - transport_type: TransportType - capabilities: tuple[str, ...] 
- credential_env: str | None = None - credential_url: str | None = None - # If set, read API key from this attribute on ``Settings`` (e.g. nvidia_nim_api_key). - credential_attr: str | None = None - # If set, use this fixed key for local adapters (e.g. lm-studio, llamacpp). - static_credential: str | None = None - default_base_url: str | None = None - base_url_attr: str | None = None - proxy_attr: str | None = None - - -PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = { - "nvidia_nim": ProviderDescriptor( - provider_id="nvidia_nim", - transport_type="openai_chat", - credential_env="NVIDIA_NIM_API_KEY", - credential_url="https://build.nvidia.com/settings/api-keys", - credential_attr="nvidia_nim_api_key", - default_base_url=NVIDIA_NIM_DEFAULT_BASE, - proxy_attr="nvidia_nim_proxy", - capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"), - ), - "open_router": ProviderDescriptor( - provider_id="open_router", - transport_type="anthropic_messages", - credential_env="OPENROUTER_API_KEY", - credential_url="https://openrouter.ai/keys", - credential_attr="open_router_api_key", - default_base_url=OPENROUTER_DEFAULT_BASE, - proxy_attr="open_router_proxy", - capabilities=("chat", "streaming", "tools", "thinking", "native_anthropic"), - ), - "deepseek": ProviderDescriptor( - provider_id="deepseek", - transport_type="openai_chat", - credential_env="DEEPSEEK_API_KEY", - credential_url="https://platform.deepseek.com/api_keys", - credential_attr="deepseek_api_key", - default_base_url=DEEPSEEK_DEFAULT_BASE, - capabilities=("chat", "streaming", "thinking"), - ), - "lmstudio": ProviderDescriptor( - provider_id="lmstudio", - transport_type="anthropic_messages", - static_credential="lm-studio", - default_base_url=LMSTUDIO_DEFAULT_BASE, - base_url_attr="lm_studio_base_url", - proxy_attr="lmstudio_proxy", - capabilities=("chat", "streaming", "tools", "native_anthropic", "local"), - ), - "llamacpp": ProviderDescriptor( - provider_id="llamacpp", - transport_type="anthropic_messages", - static_credential="llamacpp", - default_base_url=LLAMACPP_DEFAULT_BASE, - base_url_attr="llamacpp_base_url", - proxy_attr="llamacpp_proxy", - capabilities=("chat", "streaming", "tools", "native_anthropic", "local"), - ), - "ollama": ProviderDescriptor( - provider_id="ollama", - transport_type="anthropic_messages", - static_credential="ollama", - default_base_url=OLLAMA_DEFAULT_BASE, - base_url_attr="ollama_base_url", - capabilities=( - "chat", - "streaming", - "tools", - "thinking", - "native_anthropic", - "local", - ), - ), -} +# Backwards-compatible name for the catalog (single source: ``config.provider_catalog``). 
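+# Example: `PROVIDER_DESCRIPTORS["ollama"].default_base_url` resolves through the
+# catalog entry, so dict-style lookups at existing call sites keep working.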
+PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = PROVIDER_CATALOG def _create_nvidia_nim(config: ProviderConfig, settings: Settings) -> BaseProvider: @@ -119,25 +31,25 @@ def _create_open_router(config: ProviderConfig, _settings: Settings) -> BaseProv return OpenRouterProvider(config) -def _create_deepseek(config: ProviderConfig, settings: Settings) -> BaseProvider: +def _create_deepseek(config: ProviderConfig, _settings: Settings) -> BaseProvider: from providers.deepseek import DeepSeekProvider return DeepSeekProvider(config) -def _create_lmstudio(config: ProviderConfig, settings: Settings) -> BaseProvider: +def _create_lmstudio(config: ProviderConfig, _settings: Settings) -> BaseProvider: from providers.lmstudio import LMStudioProvider return LMStudioProvider(config) -def _create_llamacpp(config: ProviderConfig, settings: Settings) -> BaseProvider: +def _create_llamacpp(config: ProviderConfig, _settings: Settings) -> BaseProvider: from providers.llamacpp import LlamaCppProvider return LlamaCppProvider(config) -def _create_ollama(config: ProviderConfig, settings: Settings) -> BaseProvider: +def _create_ollama(config: ProviderConfig, _settings: Settings) -> BaseProvider: from providers.ollama import OllamaProvider return OllamaProvider(config) @@ -208,6 +120,8 @@ def build_provider_config( http_connect_timeout=settings.http_connect_timeout, enable_thinking=settings.enable_model_thinking, proxy=proxy, + log_raw_sse_events=settings.log_raw_sse_events, + log_api_error_tracebacks=settings.log_api_error_tracebacks, ) diff --git a/server.py b/server.py index b025be0..958e730 100644 --- a/server.py +++ b/server.py @@ -1,11 +1,13 @@ """ Claude Code Proxy - Entry Point -Minimal entry point that imports the app from the api module. +Minimal entry point that builds the ASGI app via :func:`api.app.create_app`. Run with: uv run uvicorn server:app --host 0.0.0.0 --port 8082 --timeout-graceful-shutdown 5 """ -from api.app import app, create_app +from api.app import create_app + +app = create_app() __all__ = ["app", "create_app"] diff --git a/smoke/capabilities.py b/smoke/capabilities.py index 5f29fff..eafae44 100644 --- a/smoke/capabilities.py +++ b/smoke/capabilities.py @@ -98,7 +98,7 @@ CAPABILITY_CONTRACTS: tuple[CapabilityContract, ...] 
= ( "providers.registry.ProviderRegistry", "provider id and Settings", "configured BaseProvider instance", - "503 for missing credentials; ValueError for unknown provider", + "503 for missing credentials; invalid_request_error for unknown provider", ("tests/api/test_dependencies.py", "tests/providers/test_registry.py"), ("test_configured_provider_models_stream_successfully",), ), diff --git a/smoke/lib/e2e.py b/smoke/lib/e2e.py index e669b51..8b20e80 100644 --- a/smoke/lib/e2e.py +++ b/smoke/lib/e2e.py @@ -19,6 +19,14 @@ import httpx import pytest from config.provider_ids import SUPPORTED_PROVIDER_IDS +from core.anthropic.stream_contracts import ( + SSEEvent, + assert_anthropic_stream_contract, + event_index, + has_tool_use, + parse_sse_lines, + text_content, +) from messaging.handler import ClaudeMessageHandler from messaging.models import IncomingMessage from messaging.platforms.base import MessagingPlatform @@ -26,13 +34,6 @@ from messaging.session import SessionStore from smoke.lib.config import ProviderModel, SmokeConfig, auth_headers from smoke.lib.server import RunningServer, start_server from smoke.lib.skips import fail_missing_env -from smoke.lib.sse import ( - SSEEvent, - assert_anthropic_stream_contract, - has_tool_use, - parse_sse_lines, - text_content, -) @dataclass(slots=True) @@ -590,14 +591,14 @@ def assistant_content_from_events(events: list[SSEEvent]) -> list[dict[str, Any] block_order: list[int] = [] for event in events: if event.event == "content_block_start": - index = _event_index(event) + index = event_index(event) block = event.data.get("content_block", {}) if isinstance(block, dict): blocks[index] = dict(block) block_order.append(index) continue if event.event == "content_block_delta": - index = _event_index(event) + index = event_index(event) block = blocks.get(index) delta = event.data.get("delta", {}) if not isinstance(block, dict) or not isinstance(delta, dict): @@ -663,12 +664,6 @@ def default_cli_events(session_id: str) -> list[dict[str, Any]]: ] -def _event_index(event: SSEEvent) -> int: - value = event.data.get("index") - assert isinstance(value, int), event.data - return value - - def assert_product_stream(events: list[SSEEvent]) -> None: assert_anthropic_stream_contract(events) assert text_content(events).strip() or has_tool_use(events), ( diff --git a/smoke/lib/http.py b/smoke/lib/http.py index f019473..71e2cb0 100644 --- a/smoke/lib/http.py +++ b/smoke/lib/http.py @@ -6,9 +6,10 @@ from typing import Any import httpx +from core.anthropic.stream_contracts import SSEEvent, parse_sse_lines + from .config import SmokeConfig, auth_headers, redacted from .server import RunningServer -from .sse import SSEEvent, parse_sse_lines def message_payload( diff --git a/smoke/lib/skips.py b/smoke/lib/skips.py index ffadb85..1cb1e29 100644 --- a/smoke/lib/skips.py +++ b/smoke/lib/skips.py @@ -5,7 +5,7 @@ from __future__ import annotations import httpx import pytest -from smoke.lib.sse import SSEEvent +from core.anthropic.stream_contracts import SSEEvent UPSTREAM_UNAVAILABLE_MARKERS = ( "connection refused", diff --git a/smoke/lib/sse.py b/smoke/lib/sse.py deleted file mode 100644 index a294ce5..0000000 --- a/smoke/lib/sse.py +++ /dev/null @@ -1,29 +0,0 @@ -"""SSE parsing and Anthropic stream assertions for smoke tests. - -Canonical implementation lives in :mod:`core.anthropic.stream_contracts`; this -module re-exports it for smoke and historical import paths. 
-""" - -from __future__ import annotations - -from core.anthropic.stream_contracts import ( - SSEEvent, - assert_anthropic_stream_contract, - event_names, - has_tool_use, - parse_sse_lines, - parse_sse_text, - text_content, - thinking_content, -) - -__all__ = [ - "SSEEvent", - "assert_anthropic_stream_contract", - "event_names", - "has_tool_use", - "parse_sse_lines", - "parse_sse_text", - "text_content", - "thinking_content", -] diff --git a/smoke/prereq/test_provider_prereq_live.py b/smoke/prereq/test_provider_prereq_live.py index 8731278..0d85356 100644 --- a/smoke/prereq/test_provider_prereq_live.py +++ b/smoke/prereq/test_provider_prereq_live.py @@ -5,6 +5,10 @@ import time import httpx import pytest +from core.anthropic.stream_contracts import ( + assert_anthropic_stream_contract, + text_content, +) from smoke.lib.config import SmokeConfig, auth_headers from smoke.lib.http import collect_message_stream, message_payload from smoke.lib.server import start_server @@ -12,7 +16,6 @@ from smoke.lib.skips import ( skip_if_upstream_unavailable_events, skip_if_upstream_unavailable_exception, ) -from smoke.lib.sse import assert_anthropic_stream_contract, text_content pytestmark = [pytest.mark.live, pytest.mark.smoke_target("providers")] diff --git a/smoke/prereq/test_tools_prereq_live.py b/smoke/prereq/test_tools_prereq_live.py index 824197b..716678c 100644 --- a/smoke/prereq/test_tools_prereq_live.py +++ b/smoke/prereq/test_tools_prereq_live.py @@ -2,11 +2,14 @@ from __future__ import annotations import pytest +from core.anthropic.stream_contracts import ( + assert_anthropic_stream_contract, + has_tool_use, +) from smoke.lib.config import SmokeConfig from smoke.lib.http import collect_message_stream, message_payload from smoke.lib.server import start_server from smoke.lib.skips import skip_if_upstream_unavailable_events -from smoke.lib.sse import assert_anthropic_stream_contract, has_tool_use pytestmark = [pytest.mark.live, pytest.mark.smoke_target("tools")] diff --git a/smoke/prereq/test_voice_prereq_live.py b/smoke/prereq/test_voice_prereq_live.py index 21e4dcb..50ce9e5 100644 --- a/smoke/prereq/test_voice_prereq_live.py +++ b/smoke/prereq/test_voice_prereq_live.py @@ -24,12 +24,13 @@ def test_voice_transcription_backend_when_explicitly_enabled( wav_path = tmp_path / "smoke-tone.wav" _write_tone_wav(wav_path) try: - text = transcribe_audio( - wav_path, - "audio/wav", - whisper_model=smoke_config.settings.whisper_model, - whisper_device=smoke_config.settings.whisper_device, - ) + t_kw: dict[str, str] = { + "whisper_model": smoke_config.settings.whisper_model, + "whisper_device": smoke_config.settings.whisper_device, + } + if smoke_config.settings.whisper_device == "nvidia_nim": + t_kw["nvidia_nim_api_key"] = smoke_config.settings.nvidia_nim_api_key + text = transcribe_audio(wav_path, "audio/wav", **t_kw) except ImportError as exc: pytest.skip(str(exc)) assert isinstance(text, str) diff --git a/smoke/product/test_provider_product_live.py b/smoke/product/test_provider_product_live.py index ac61839..be5d45e 100644 --- a/smoke/product/test_provider_product_live.py +++ b/smoke/product/test_provider_product_live.py @@ -3,6 +3,11 @@ from __future__ import annotations import httpx import pytest +from core.anthropic.stream_contracts import ( + assert_anthropic_stream_contract, + parse_sse_lines, + text_content, +) from smoke.lib.config import ProviderModel, SmokeConfig, auth_headers from smoke.lib.e2e import ( ConversationDriver, @@ -16,11 +21,6 @@ from smoke.lib.skips import ( 
skip_if_upstream_unavailable_events, skip_if_upstream_unavailable_exception, ) -from smoke.lib.sse import ( - assert_anthropic_stream_contract, - parse_sse_lines, - text_content, -) pytestmark = [pytest.mark.live, pytest.mark.smoke_target("providers")] diff --git a/smoke/product/test_voice_product_live.py b/smoke/product/test_voice_product_live.py index d1c27f5..0327168 100644 --- a/smoke/product/test_voice_product_live.py +++ b/smoke/product/test_voice_product_live.py @@ -55,6 +55,7 @@ def test_voice_nim_backend_e2e(smoke_config: SmokeConfig, tmp_path: Path) -> Non "audio/wav", whisper_model=smoke_config.settings.whisper_model, whisper_device="nvidia_nim", + nvidia_nim_api_key=smoke_config.settings.nvidia_nim_api_key, ) assert isinstance(text, str) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..10436b4 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test suite package.""" diff --git a/tests/api/test_anthropic_request_passthrough.py b/tests/api/test_anthropic_request_passthrough.py new file mode 100644 index 0000000..15ee3db --- /dev/null +++ b/tests/api/test_anthropic_request_passthrough.py @@ -0,0 +1,184 @@ +"""Pydantic passthrough of Anthropic protocol fields (e.g. ``cache_control``).""" + +from __future__ import annotations + +from api.models.anthropic import ( + ContentBlockServerToolUse, + ContentBlockText, + ContentBlockWebSearchToolResult, + Message, + MessagesRequest, + SystemContent, + Tool, +) +from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS +from core.anthropic.native_messages_request import ( + build_base_native_anthropic_request_body, +) + + +def test_cache_control_preserved_on_parsed_user_text_system_and_tool() -> None: + raw = { + "model": "m", + "max_tokens": 20, + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "hi", + "cache_control": {"type": "ephemeral"}, + } + ], + } + ], + "system": [ + { + "type": "text", + "text": "be brief", + "cache_control": {"type": "ephemeral"}, + } + ], + "tools": [ + { + "name": "alpha", + "input_schema": {"type": "object"}, + "cache_control": {"type": "ephemeral"}, + } + ], + } + req = MessagesRequest.model_validate(raw) + block = req.messages[0].content[0] + assert isinstance(block, ContentBlockText) + assert block.model_dump()["cache_control"] == {"type": "ephemeral"} + s0 = req.system[0] if isinstance(req.system, list) else None + assert isinstance(s0, SystemContent) + assert s0.model_dump()["cache_control"] == {"type": "ephemeral"} + t0 = req.tools[0] if req.tools else None + assert isinstance(t0, Tool) + assert t0.model_dump()["cache_control"] == {"type": "ephemeral"} + + +def test_build_base_native_body_includes_cache_control() -> None: + req = MessagesRequest( + model="m", + max_tokens=20, + messages=[ + Message( + role="user", + content=[ + ContentBlockText.model_validate( + { + "type": "text", + "text": "x", + "cache_control": {"type": "ephemeral"}, + } + ) + ], + ) + ], + system=[ + SystemContent.model_validate( + { + "type": "text", + "text": "s", + "cache_control": {"type": "ephemeral"}, + } + ) + ], + tools=[ + Tool.model_validate( + { + "name": "n", + "input_schema": {"type": "object"}, + "cache_control": {"type": "ephemeral"}, + } + ) + ], + ) + body = build_base_native_anthropic_request_body( + req, + default_max_tokens=ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS, + thinking_enabled=False, + ) + assert body["messages"][0]["content"][0]["cache_control"] == {"type": "ephemeral"} + assert body["system"][0]["cache_control"] == {"type": 
"ephemeral"} + assert body["tools"][0]["cache_control"] == {"type": "ephemeral"} + + +def test_pydantic_discriminator_still_distinguishes_blocks() -> None: + m = Message.model_validate( + { + "role": "user", + "content": [{"type": "text", "text": "a", "z": 1}], + } + ) + b = m.content[0] + assert isinstance(b, ContentBlockText) + assert b.model_dump()["z"] == 1 + + +def test_server_tool_assistant_blocks_round_trip_in_native_body() -> None: + """Local server-tool responses must parse as valid history for a follow-up request.""" + raw = { + "model": "m", + "max_tokens": 20, + "messages": [ + { + "role": "assistant", + "content": [ + { + "type": "server_tool_use", + "id": "srvtoolu_1", + "name": "web_search", + "input": {"query": "q"}, + }, + { + "type": "web_search_tool_result", + "tool_use_id": "srvtoolu_1", + "content": [ + { + "type": "web_search_result", + "title": "T", + "url": "https://example.com", + } + ], + }, + ], + } + ], + "mcp_servers": [{"type": "url", "url": "https://example.com/mcp"}], + } + req = MessagesRequest.model_validate(raw) + assert len(req.messages) == 1 + blocks = req.messages[0].content + assert isinstance(blocks, list) + assert isinstance(blocks[0], ContentBlockServerToolUse) + assert isinstance(blocks[1], ContentBlockWebSearchToolResult) + body = build_base_native_anthropic_request_body( + req, + default_max_tokens=ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS, + thinking_enabled=False, + ) + assert body["mcp_servers"][0]["type"] == "url" + assert body["messages"][0]["content"][0]["type"] == "server_tool_use" + assert body["messages"][0]["content"][1]["type"] == "web_search_tool_result" + + +def test_native_body_preserves_context_and_output_config() -> None: + raw = { + "model": "m", + "max_tokens": 20, + "messages": [{"role": "user", "content": "x"}], + "context_management": {"edits": [{"type": "clear"}]}, + "output_config": {"some": "hint"}, + } + req = MessagesRequest.model_validate(raw) + body = build_base_native_anthropic_request_body( + req, + default_max_tokens=ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS, + thinking_enabled=False, + ) + assert body["context_management"] == raw["context_management"] + assert body["output_config"] == raw["output_config"] diff --git a/tests/api/test_api.py b/tests/api/test_api.py index e4912e7..b387048 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -3,9 +3,11 @@ from unittest.mock import MagicMock, patch import pytest from fastapi.testclient import TestClient -from api.app import app +from api.app import create_app from providers.nvidia_nim import NvidiaNimProvider +app = create_app() + # Mock provider mock_provider = MagicMock(spec=NvidiaNimProvider) @@ -171,7 +173,7 @@ def test_generic_exception_returns_500(client: TestClient): def test_generic_exception_with_status_code(client: TestClient): - """Generic exception with status_code attribute uses that status (getattr fallback).""" + """Unexpected errors always map to HTTP 500 (ignore ad-hoc status_code attrs).""" class ExceptionWithStatus(RuntimeError): def __init__(self, msg: str, status_code: int = 500): @@ -191,7 +193,7 @@ def test_generic_exception_with_status_code(client: TestClient): "stream": True, }, ) - assert response.status_code == 502 + assert response.status_code == 500 mock_provider.stream_response = _mock_stream_response diff --git a/tests/api/test_app_lifespan_and_errors.py b/tests/api/test_app_lifespan_and_errors.py index 126a869..adc81c5 100644 --- a/tests/api/test_app_lifespan_and_errors.py +++ b/tests/api/test_app_lifespan_and_errors.py @@ -17,6 +17,15 
@@ _RUNTIME_EXTRAS = { "nvidia_nim_api_key": "", "claude_cli_bin": "claude", "uses_process_anthropic_auth_token": lambda: False, + "messaging_rate_limit": 1, + "messaging_rate_window": 1.0, + "max_message_log_entries_per_chat": None, + "debug_platform_edits": False, + "debug_subagent_stack": False, + "log_api_error_tracebacks": False, + "log_raw_messaging_content": False, + "log_raw_cli_diagnostics": False, + "log_messaging_error_details": False, } @@ -81,9 +90,50 @@ def test_create_app_provider_error_handler_returns_anthropic_format(): with TestClient(app) as client: resp = client.get("/raise_provider") assert resp.status_code == 401 - body = resp.json() - assert body["type"] == "error" - assert body["error"]["type"] == "authentication_error" + body = resp.json() + assert body["type"] == "error" + assert body["error"]["type"] == "authentication_error" + + +def test_create_app_provider_error_default_logs_exclude_provider_message(): + """Provider errors must not log exc.message by default.""" + from api.app import create_app + from providers.exceptions import AuthenticationError + + app = create_app() + secret = "provider-upstream-secret-detail" + + @app.get("/raise_provider_secret") + async def _raise(): + raise AuthenticationError(secret) + + api_app_mod = importlib.import_module("api.app") + settings = _app_settings( + messaging_platform="telegram", + telegram_bot_token=None, + allowed_telegram_user_id=None, + discord_bot_token=None, + allowed_discord_channels=None, + allowed_dir="", + claude_workspace="./agent_workspace", + host="127.0.0.1", + port=8082, + log_file="server.log", + log_api_error_tracebacks=False, + ) + with ( + patch.object(api_app_mod, "get_settings", return_value=settings), + patch.object(ProviderRegistry, "cleanup", new=AsyncMock()), + patch.object(api_app_mod.logger, "error") as log_err, + ): + with TestClient(app) as client: + resp = client.get("/raise_provider_secret") + assert resp.status_code == 401 + + blob = " ".join(str(a) for c in log_err.call_args_list for a in c.args) + blob += repr([c.kwargs for c in log_err.call_args_list]) + assert secret not in blob + assert "authentication_error" in blob def test_create_app_general_exception_handler_returns_500(): @@ -120,6 +170,50 @@ def test_create_app_general_exception_handler_returns_500(): assert body["error"]["type"] == "api_error" +def test_create_app_general_exception_default_logs_exclude_exception_message(): + """Unhandled errors must not log exception text by default (may echo user content).""" + from api.app import create_app + + app = create_app() + + secret = "user-provided-secret-token-xyzzy" + + @app.get("/raise_secret") + async def _raise_secret(): + raise ValueError(secret) + + api_app_mod = importlib.import_module("api.app") + settings = _app_settings( + messaging_platform="telegram", + telegram_bot_token=None, + allowed_telegram_user_id=None, + discord_bot_token=None, + allowed_discord_channels=None, + allowed_dir="", + claude_workspace="./agent_workspace", + host="127.0.0.1", + port=8082, + log_file="server.log", + log_api_error_tracebacks=False, + ) + with ( + patch.object(api_app_mod, "get_settings", return_value=settings), + patch.object(ProviderRegistry, "cleanup", new=AsyncMock()), + patch.object(api_app_mod.logger, "error") as log_err, + ): + with TestClient(app, raise_server_exceptions=False) as client: + resp = client.get("/raise_secret") + assert resp.status_code == 500 + + flattened: list[str] = [] + for call in log_err.call_args_list: + flattened.extend(str(arg) for arg in call.args) + 
flattened.append(repr(call.kwargs)) + blob = " ".join(flattened) + assert secret not in blob + assert "ValueError" in blob + + @pytest.mark.parametrize( "messaging_enabled", [True, False], ids=["with_platform", "no_platform"] ) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index e661873..99230a3 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -2,10 +2,12 @@ from unittest.mock import patch from fastapi.testclient import TestClient -from api.app import app +from api.app import create_app from api.dependencies import get_settings from config.settings import Settings +app = create_app() + def test_anthropic_auth_token_required_and_accepts_x_api_key(): client = TestClient(app) diff --git a/tests/api/test_dependencies.py b/tests/api/test_dependencies.py index 2849667..5be4c26 100644 --- a/tests/api/test_dependencies.py +++ b/tests/api/test_dependencies.py @@ -16,7 +16,7 @@ from api.dependencies import ( ) from config.nim import NimSettings from providers.deepseek import DeepSeekProvider -from providers.exceptions import UnknownProviderTypeError +from providers.exceptions import ServiceUnavailableError, UnknownProviderTypeError from providers.lmstudio import LMStudioProvider from providers.nvidia_nim import NvidiaNimProvider from providers.ollama import OllamaProvider @@ -40,7 +40,7 @@ def _make_mock_settings(**overrides): mock.nim = NimSettings() mock.http_read_timeout = 300.0 mock.http_write_timeout = 10.0 - mock.http_connect_timeout = 2.0 + mock.http_connect_timeout = 10.0 mock.enable_model_thinking = True for key, value in overrides.items(): setattr(mock, key, value) @@ -434,18 +434,17 @@ def test_resolve_provider_per_app_uses_separate_registries() -> None: assert p1 is not p2 -def test_resolve_provider_lazily_installs_registry() -> None: - """If app has no provider_registry, one is created on app.state.""" +def test_resolve_provider_missing_registry_raises_service_unavailable() -> None: + """HTTP apps must install app.state.provider_registry (e.g. 
via AppRuntime).""" with patch("api.dependencies.get_settings") as mock_settings: mock_settings.return_value = _make_mock_settings() settings = _make_mock_settings() app = SimpleNamespace(state=State()) assert getattr(app.state, "provider_registry", None) is None - resolve_provider("nvidia_nim", app=cast(Starlette, app), settings=settings) - reg = app.state.provider_registry - assert reg is not None - p2 = resolve_provider("nvidia_nim", app=cast(Starlette, app), settings=settings) - assert p2 is reg.get("nvidia_nim", settings) # same registry instance + with pytest.raises( + ServiceUnavailableError, match="Provider registry is not configured" + ): + resolve_provider("nvidia_nim", app=cast(Starlette, app), settings=settings) def test_resolve_provider_unrelated_value_error_is_not_unknown_provider_log() -> None: diff --git a/tests/api/test_routes_optimizations.py b/tests/api/test_routes_optimizations.py index e0c058c..eb8d2dd 100644 --- a/tests/api/test_routes_optimizations.py +++ b/tests/api/test_routes_optimizations.py @@ -3,10 +3,12 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi.testclient import TestClient -from api.app import app +from api.app import create_app from api.dependencies import get_settings from config.settings import Settings +app = create_app() + @pytest.fixture def client(): @@ -103,6 +105,17 @@ def test_create_message_empty_messages_returns_400(client): assert "cannot be empty" in data.get("error", {}).get("message", "") +def test_count_tokens_empty_messages_returns_400(client): + """POST /v1/messages/count_tokens with messages: [] matches messages validation.""" + payload = {"model": "claude-3-sonnet", "messages": []} + response = client.post("/v1/messages/count_tokens", json=payload) + assert response.status_code == 400 + data = response.json() + assert data.get("type") == "error" + assert data.get("error", {}).get("type") == "invalid_request_error" + assert "cannot be empty" in data.get("error", {}).get("message", "") + + def test_count_tokens_endpoint(client): payload = { "model": "claude-3-sonnet", diff --git a/tests/api/test_runtime_safe_logging.py b/tests/api/test_runtime_safe_logging.py new file mode 100644 index 0000000..5c86c15 --- /dev/null +++ b/tests/api/test_runtime_safe_logging.py @@ -0,0 +1,70 @@ +"""Tests for safe default logging in :mod:`api.runtime`.""" + +import importlib +import logging +from unittest.mock import MagicMock, patch + +import pytest + +from tests.api.test_app_lifespan_and_errors import _app_settings + + +@pytest.mark.asyncio +async def test_messaging_start_failure_default_logs_exclude_traceback(caplog): + api_runtime_mod = importlib.import_module("api.runtime") + settings = _app_settings( + messaging_platform="telegram", + telegram_bot_token="t", + allowed_telegram_user_id="1", + discord_bot_token=None, + allowed_discord_channels=None, + allowed_dir="", + claude_workspace="./agent_workspace", + host="127.0.0.1", + port=8082, + log_file="server.log", + log_api_error_tracebacks=False, + ) + runtime = api_runtime_mod.AppRuntime(app=MagicMock(), settings=settings) + + with ( + patch( + "messaging.platforms.factory.create_messaging_platform", + side_effect=RuntimeError("SECRET_RUNTIME_DETAIL"), + ), + caplog.at_level(logging.ERROR), + ): + await runtime._start_messaging_if_configured() + + blob = " | ".join(r.getMessage() for r in caplog.records) + assert "SECRET_RUNTIME_DETAIL" not in blob + assert "exc_type=RuntimeError" in blob + + +@pytest.mark.asyncio +async def 
test_best_effort_default_logs_exclude_exception_text(caplog): + api_runtime_mod = importlib.import_module("api.runtime") + + async def boom(): + raise ValueError("SECRET_SHUTDOWN") + + with caplog.at_level(logging.WARNING): + await api_runtime_mod.best_effort("test_step", boom(), log_verbose_errors=False) + + blob = " | ".join(r.getMessage() for r in caplog.records) + assert "SECRET_SHUTDOWN" not in blob + assert "exc_type=ValueError" in blob + + +@pytest.mark.asyncio +async def test_best_effort_verbose_includes_exception_text(caplog): + api_runtime_mod = importlib.import_module("api.runtime") + + async def boom(): + raise ValueError("VISIBLE_SHUTDOWN") + + with caplog.at_level(logging.WARNING): + await api_runtime_mod.best_effort("test_step", boom(), log_verbose_errors=True) + + blob = " | ".join(r.getMessage() for r in caplog.records) + assert "VISIBLE_SHUTDOWN" in blob diff --git a/tests/api/test_safe_logging.py b/tests/api/test_safe_logging.py new file mode 100644 index 0000000..1096516 --- /dev/null +++ b/tests/api/test_safe_logging.py @@ -0,0 +1,208 @@ +"""Tests that API and SSE logging avoid raw sensitive payloads by default.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +from fastapi import HTTPException + +from api import services as services_mod +from api.models.anthropic import Message, MessagesRequest +from api.services import ClaudeProxyService +from config.settings import Settings +from core.anthropic.sse import SSEBuilder + + +def test_create_message_skips_full_payload_debug_log_by_default(): + settings = Settings() + assert settings.log_raw_api_payloads is False + mock_provider = MagicMock() + + async def fake_stream(*_a, **_kw): + yield "event: ping\ndata: {}\n\n" + + mock_provider.stream_response = fake_stream + service = ClaudeProxyService(settings, provider_getter=lambda _: mock_provider) + + request = MessagesRequest( + model="claude-3-haiku-20240307", + max_tokens=10, + messages=[Message(role="user", content="secret-user-text")], + ) + + with patch.object(services_mod.logger, "debug") as mock_debug: + service.create_message(request) + + full_payload_calls = [ + c + for c in mock_debug.call_args_list + if c.args and str(c.args[0]) == "FULL_PAYLOAD [{}]: {}" + ] + assert not full_payload_calls + + +def test_create_message_logs_full_payload_when_opt_in(): + settings = Settings() + settings.log_raw_api_payloads = True + mock_provider = MagicMock() + + async def fake_stream(*_a, **_kw): + yield "event: ping\ndata: {}\n\n" + + mock_provider.stream_response = fake_stream + service = ClaudeProxyService(settings, provider_getter=lambda _: mock_provider) + request = MessagesRequest( + model="claude-3-haiku-20240307", + max_tokens=10, + messages=[Message(role="user", content="visible")], + ) + + with patch.object(services_mod.logger, "debug") as mock_debug: + service.create_message(request) + + keys = [c.args[0] for c in mock_debug.call_args_list if c.args] + assert any(k == "FULL_PAYLOAD [{}]: {}" for k in keys) + + +def test_sse_builder_default_debug_has_no_serialized_json_content(): + with patch("core.anthropic.sse.logger.debug") as mock_debug: + sse = SSEBuilder("msg_x", "m", 1, log_raw_events=False) + sse.message_start() + + assert mock_debug.call_count == 1 + message = str(mock_debug.call_args) + assert "serialized_bytes=" in message + assert "role" not in message + assert "assistant" not in message + + +def test_sse_builder_raw_logging_includes_event_body_when_enabled(): + with 
patch("core.anthropic.sse.logger.debug") as mock_debug: + sse = SSEBuilder("msg_x", "m", 1, log_raw_events=True) + sse.message_start() + + assert mock_debug.call_count == 1 + message = str(mock_debug.call_args) + assert "message_start" in message + assert "role" in message + + +def _flatten_log_calls(mock_log) -> str: + parts: list[str] = [] + for call in mock_log.call_args_list: + parts.extend(str(arg) for arg in call.args) + parts.append(repr(call.kwargs)) + return " ".join(parts) + + +def test_create_message_unexpected_error_default_logs_exclude_exception_text(): + settings = Settings() + assert settings.log_api_error_tracebacks is False + secret = "upstream-secret-token-abc" + + mock_provider = MagicMock() + + def stream_boom(*_a, **_kw): + raise RuntimeError(secret) + + mock_provider.stream_response = stream_boom + service = ClaudeProxyService(settings, provider_getter=lambda _: mock_provider) + request = MessagesRequest( + model="claude-3-haiku-20240307", + max_tokens=10, + messages=[Message(role="user", content="hi")], + ) + + with ( + patch.object(services_mod.logger, "error") as log_err, + pytest.raises(HTTPException), + ): + service.create_message(request) + + blob = _flatten_log_calls(log_err) + assert secret not in blob + assert "RuntimeError" in blob + + +def test_create_message_unexpected_error_always_returns_500(): + """Non-provider failures must not leak arbitrary status_code attributes.""" + + class WeirdError(Exception): + status_code = 418 + + settings = Settings() + mock_provider = MagicMock() + + def stream_boom(*_a, **_kw): + raise WeirdError("no") + + mock_provider.stream_response = stream_boom + service = ClaudeProxyService(settings, provider_getter=lambda _: mock_provider) + request = MessagesRequest( + model="claude-3-haiku-20240307", + max_tokens=10, + messages=[Message(role="user", content="hi")], + ) + + with pytest.raises(HTTPException) as excinfo: + service.create_message(request) + + assert excinfo.value.status_code == 500 + + +def test_parse_cli_event_error_logs_metadata_by_default(): + """CLI parser must not log raw error text unless LOG_RAW_CLI_DIAGNOSTICS is on.""" + from messaging.event_parser import parse_cli_event + + secret = "user-secret-parser-leak-xyz" + with patch("messaging.event_parser.logger.info") as log_info: + parse_cli_event( + {"type": "error", "error": {"message": secret}}, log_raw_cli=False + ) + flat = " ".join(str(c) for c in log_info.call_args_list) + assert secret not in flat + assert "message_chars" in flat + + +def test_parse_cli_event_error_logs_text_when_log_raw_cli_enabled(): + from messaging.event_parser import parse_cli_event + + secret = "visible-cli-parser-msg" + with patch("messaging.event_parser.logger.info") as log_info: + parse_cli_event( + {"type": "error", "error": {"message": secret}}, log_raw_cli=True + ) + flat = " ".join(str(c) for c in log_info.call_args_list) + assert secret in flat + + +def test_count_tokens_unexpected_error_default_logs_exclude_exception_text(): + settings = Settings() + assert settings.log_api_error_tracebacks is False + secret = "count-tokens-leak-xyz" + + def boom(*_a, **_kw): + raise ValueError(secret) + + service = ClaudeProxyService( + settings, + provider_getter=lambda _: MagicMock(), + token_counter=boom, + ) + from api.models.anthropic import TokenCountRequest + + req = TokenCountRequest( + model="claude-3-haiku-20240307", + messages=[Message(role="user", content="x")], + ) + + with ( + patch.object(services_mod.logger, "error") as log_err, + pytest.raises(HTTPException), + ): + 
service.count_tokens(req) + + blob = _flatten_log_calls(log_err) + assert secret not in blob + assert "ValueError" in blob diff --git a/tests/api/test_validation_log.py b/tests/api/test_validation_log.py new file mode 100644 index 0000000..9408745 --- /dev/null +++ b/tests/api/test_validation_log.py @@ -0,0 +1,33 @@ +"""Tests for validation log summaries (metadata only).""" + +from api.validation_log import summarize_request_validation_body + + +def test_summarize_lists_block_metadata_without_echoing_string_content(): + body = { + "messages": [ + { + "role": "user", + "content": "secret user phrase", + } + ], + "tools": [{"name": "web_search", "type": "web_search_20250305"}], + } + summary, tool_names = summarize_request_validation_body(body) + assert summary == [ + { + "role": "user", + "content_kind": "str", + "content_length": 18, + } + ] + assert tool_names == ["web_search"] + blob = repr(summary) + repr(tool_names) + assert "secret" not in blob + + +def test_summarize_handles_non_dict_messages_and_missing_tools(): + body = {"messages": ["not_a_dict"]} + summary, tool_names = summarize_request_validation_body(body) + assert summary == [{"message_kind": "str"}] + assert tool_names == [] diff --git a/tests/api/test_web_server_tools.py b/tests/api/test_web_server_tools.py index c00d538..c46ddf2 100644 --- a/tests/api/test_web_server_tools.py +++ b/tests/api/test_web_server_tools.py @@ -1,20 +1,65 @@ -import json +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch +import httpx import pytest +import api.web_tools.constants as web_tool_constants +from api.model_router import ModelRouter, ResolvedModel, RoutedMessagesRequest from api.models.anthropic import Message, MessagesRequest, Tool -from api.web_server_tools import ( - is_web_server_tool_request, - stream_web_server_tool_response, +from api.services import ClaudeProxyService +from api.web_tools import egress as web_egress +from api.web_tools.egress import ( + WebFetchEgressPolicy, + WebFetchEgressViolation, + enforce_web_fetch_egress, +) +from api.web_tools.outbound import ( + _drain_response_body_capped, + _read_response_body_capped, + _run_web_fetch, +) +from api.web_tools.request import is_web_server_tool_request +from api.web_tools.streaming import stream_web_server_tool_response +from config.settings import Settings +from core.anthropic.stream_contracts import ( + assert_anthropic_stream_contract, + parse_sse_text, + text_content, +) +from messaging.event_parser import parse_cli_event +from providers.exceptions import InvalidRequestError + +_STRICT_EGRESS = WebFetchEgressPolicy( + allow_private_network_targets=False, + allowed_schemes=frozenset({"http", "https"}), ) -def _event_data(event: str) -> dict: - data_line = next(line for line in event.splitlines() if line.startswith("data: ")) - return json.loads(data_line.removeprefix("data: ")) +class FixedProviderModelRouter(ModelRouter): + """Test double: pin ``provider_id`` for OpenAI vs native routing assertions.""" + + def __init__(self, settings: Settings, provider_id: str) -> None: + super().__init__(settings) + self._fixed_provider_id = provider_id + + def resolve_messages_request( + self, request: MessagesRequest + ) -> RoutedMessagesRequest: + resolved = ResolvedModel( + original_model=request.model, + provider_id=self._fixed_provider_id, + provider_model=request.model, + provider_model_ref=f"{self._fixed_provider_id}/{request.model}", + thinking_enabled=False, + ) + routed = request.model_copy(deep=True) + routed.model = resolved.provider_model + 
return RoutedMessagesRequest(request=routed, resolved=resolved) -def test_detects_web_search_server_tool_request(): +def test_web_server_tool_not_detected_when_tool_only_listed(): + """Listing web_search without forcing it must not skip the upstream provider.""" request = MessagesRequest( model="claude-haiku-4-5-20251001", max_tokens=100, @@ -22,16 +67,243 @@ def test_detects_web_search_server_tool_request(): tools=[Tool(name="web_search", type="web_search_20250305")], ) + assert not is_web_server_tool_request(request) + + +def test_web_server_tool_detected_when_tool_choice_forces_it(): + request = MessagesRequest( + model="claude-haiku-4-5-20251001", + max_tokens=100, + messages=[Message(role="user", content="search")], + tools=[Tool(name="web_search", type="web_search_20250305")], + tool_choice={"type": "tool", "name": "web_search"}, + ) + assert is_web_server_tool_request(request) +def test_web_server_tool_not_detected_when_forced_name_missing_from_tools(): + request = MessagesRequest( + model="claude-haiku-4-5-20251001", + max_tokens=100, + messages=[Message(role="user", content="hi")], + tools=[Tool(name="other", type="function")], + tool_choice={"type": "tool", "name": "web_search"}, + ) + + assert not is_web_server_tool_request(request) + + +def test_service_rejects_forced_server_tool_on_openai_when_disabled(): + """OpenAI Chat upstreams cannot run forced server tools without the local handler.""" + settings = Settings() + assert settings.enable_web_server_tools is False + service = ClaudeProxyService( + settings, + provider_getter=lambda _: MagicMock(), + model_router=FixedProviderModelRouter(settings, "nvidia_nim"), + ) + request = MessagesRequest( + model="claude-haiku-4-5-20251001", + max_tokens=100, + messages=[ + Message( + role="user", + content="Perform a web search for the query: DeepSeek V4 model release 2026", + ) + ], + tools=[Tool(name="web_search", type="web_search_20250305")], + tool_choice={"type": "tool", "name": "web_search"}, + ) + with pytest.raises(InvalidRequestError, match="ENABLE_WEB_SERVER_TOOLS"): + service.create_message(request) + + +@pytest.mark.parametrize( + "url", + [ + "http://127.0.0.1/", + "http://192.168.1.1/", + "http://10.0.0.1/", + "http://[::1]/", + "http://localhost/foo", + "http://mybox.local/", + "file:///etc/passwd", + "http://169.254.169.254/latest/meta-data/", + ], +) +def test_enforce_web_fetch_egress_blocks_internal_or_disallowed(url: str): + with pytest.raises(WebFetchEgressViolation): + enforce_web_fetch_egress(url, _STRICT_EGRESS) + + +def test_enforce_web_fetch_egress_allows_global_literal_ip(): + enforce_web_fetch_egress("http://8.8.8.8/", _STRICT_EGRESS) + + +def test_enforce_web_fetch_egress_skips_private_checks_when_opted_in(): + enforce_web_fetch_egress( + "http://127.0.0.1/", + WebFetchEgressPolicy( + allow_private_network_targets=True, + allowed_schemes=frozenset({"http", "https"}), + ), + ) + + +def _cm(mock_client: MagicMock) -> MagicMock: + cm = MagicMock() + cm.__aenter__ = AsyncMock(return_value=mock_client) + cm.__aexit__ = AsyncMock(return_value=None) + return cm + + +def _stream_cm(response: httpx.Response) -> MagicMock: + cm = MagicMock() + cm.__aenter__ = AsyncMock(return_value=response) + cm.__aexit__ = AsyncMock(return_value=None) + return cm + + +def _aiohttp_response( + status: int, + *, + url: str = "http://8.8.8.8/", + location: str | None = None, + body: bytes = b"hello world", +) -> MagicMock: + r = MagicMock() + r.status = status + r.url = url + hdrs: dict[str, str] = {} + if location is not None: + 
hdrs["location"] = location + r.headers = hdrs + r.get_encoding = MagicMock(return_value="utf-8") + r.raise_for_status = MagicMock() + r.request_info = MagicMock() + r.history = () + + async def iter_chunked(_n: int) -> Any: + yield body + + r.content.iter_chunked = MagicMock(side_effect=iter_chunked) + return r + + +def _aiohttp_client_session_patch( + *responses: MagicMock, +) -> tuple[MagicMock, MagicMock]: + """Build ``ClientSession`` mock that serves ``responses`` to successive ``get`` calls.""" + queue = list(responses) + n = 0 + + def get_side(*_a: Any, **_k: Any) -> Any: + nonlocal n + resp = queue[n] if n < len(queue) else queue[-1] + n += 1 + cm = MagicMock() + cm.__aenter__ = AsyncMock(return_value=resp) + cm.__aexit__ = AsyncMock(return_value=None) + return cm + + session = MagicMock() + session.get = MagicMock(side_effect=get_side) + + client_cm = MagicMock() + client_cm.__aenter__ = AsyncMock(return_value=session) + client_cm.__aexit__ = AsyncMock(return_value=None) + return client_cm, session + + +def test_enforce_web_fetch_egress_documents_connect_time_pinning(): + assert enforce_web_fetch_egress.__doc__ and "resolved addresses" in ( + enforce_web_fetch_egress.__doc__ or "" + ) + assert ( + web_egress.get_validated_stream_addrinfos_for_egress.__doc__ + and "pinning" + in (web_egress.get_validated_stream_addrinfos_for_egress.__doc__ or "") + ) + assert "DNS-pinned" in (_run_web_fetch.__doc__ or "") + + +@pytest.mark.asyncio +async def test_run_web_fetch_follows_redirect_when_each_hop_is_allowed(): + res_redirect = _aiohttp_response( + 302, url="http://8.8.8.8/start", location="/final", body=b"" + ) + res_ok = _aiohttp_response(200, url="http://8.8.8.8/final", body=b"hello world") + client_cm, session = _aiohttp_client_session_patch(res_redirect, res_ok) + with patch("api.web_tools.outbound.ClientSession", return_value=client_cm): + out = await _run_web_fetch("http://8.8.8.8/start", _STRICT_EGRESS) + + assert out["data"] == "hello world" + assert session.get.call_count == 2 + + +@pytest.mark.asyncio +async def test_run_web_fetch_truncates_large_body_to_byte_cap(monkeypatch): + huge = b"x" * 5000 + res_ok = _aiohttp_response(200, url="http://8.8.8.8/big", body=huge) + client_cm, _ = _aiohttp_client_session_patch(res_ok) + monkeypatch.setattr(web_tool_constants, "_MAX_WEB_FETCH_RESPONSE_BYTES", 100) + with patch("api.web_tools.outbound.ClientSession", return_value=client_cm): + out = await _run_web_fetch("http://8.8.8.8/big", _STRICT_EGRESS) + + assert len(out["data"]) <= 100 + assert out["data"] == "x" * 100 + + +@pytest.mark.asyncio +async def test_run_web_fetch_redirect_to_blocked_host_raises(): + res_redirect = _aiohttp_response( + 302, + url="http://8.8.8.8/start", + location="http://127.0.0.1/secret", + body=b"", + ) + client_cm, session = _aiohttp_client_session_patch(res_redirect) + with ( + patch("api.web_tools.outbound.ClientSession", return_value=client_cm), + pytest.raises(WebFetchEgressViolation), + ): + await _run_web_fetch("http://8.8.8.8/start", _STRICT_EGRESS) + + session.get.assert_called_once() + + +@pytest.mark.asyncio +async def test_run_web_fetch_redirect_without_location_raises(): + res_bad = _aiohttp_response(302, url="http://8.8.8.8/here", body=b"") + client_cm, _ = _aiohttp_client_session_patch(res_bad) + with ( + patch("api.web_tools.outbound.ClientSession", return_value=client_cm), + pytest.raises(WebFetchEgressViolation, match="missing Location"), + ): + await _run_web_fetch("http://8.8.8.8/here", _STRICT_EGRESS) + + +@pytest.mark.asyncio +async def 
test_run_web_fetch_excess_redirects_raises(): + res1 = _aiohttp_response(302, url="http://8.8.8.8/a", location="/b", body=b"") + res2 = _aiohttp_response(302, url="http://8.8.8.8/b", location="/c", body=b"") + client_cm, _ = _aiohttp_client_session_patch(res1, res2) + with ( + patch("api.web_tools.constants._MAX_WEB_FETCH_REDIRECTS", 1), + patch("api.web_tools.outbound.ClientSession", return_value=client_cm), + pytest.raises(WebFetchEgressViolation, match="exceeded maximum redirects"), + ): + await _run_web_fetch("http://8.8.8.8/a", _STRICT_EGRESS) + + @pytest.mark.asyncio async def test_streams_web_search_server_tool_result(monkeypatch): async def fake_search(query: str) -> list[dict[str, str]]: assert query == "DeepSeek V4 model release 2026" return [{"title": "DeepSeek V4 Released", "url": "https://example.com/v4"}] - monkeypatch.setattr("api.web_server_tools._run_web_search", fake_search) + monkeypatch.setattr("api.web_tools.outbound._run_web_search", fake_search) request = MessagesRequest( model="claude-haiku-4-5-20251001", max_tokens=100, @@ -47,22 +319,92 @@ async def test_streams_web_search_server_tool_result(monkeypatch): tool_choice={"type": "tool", "name": "web_search"}, ) - events = [ - event - async for event in stream_web_server_tool_response(request, input_tokens=42) + raw = "".join( + [ + event + async for event in stream_web_server_tool_response( + request, input_tokens=42, web_fetch_egress=_STRICT_EGRESS + ) + ] + ) + events = parse_sse_text(raw) + assert_anthropic_stream_contract(events) + starts = [e for e in events if e.event == "content_block_start"] + assert starts[0].data["content_block"]["type"] == "server_tool_use" + assert starts[0].data["content_block"]["name"] == "web_search" + tool_use_id = starts[0].data["content_block"]["id"] + assert starts[1].data["content_block"]["type"] == "web_search_tool_result" + assert starts[1].data["content_block"]["tool_use_id"] == tool_use_id + assert starts[1].data["content_block"]["content"][0]["url"] == ( + "https://example.com/v4" + ) + text_deltas = [ + e + for e in events + if e.event == "content_block_delta" + and e.data.get("delta", {}).get("type") == "text_delta" ] - payloads = [_event_data(event) for event in events] + assert text_deltas, "summary must be streamed as text_delta" + assert "example.com" in text_content(events) + cli_text: list[str] = [] + for ev in events: + cli_text.extend( + str(p.get("text", "")) + for p in parse_cli_event(ev.data) + if p.get("type") == "text_delta" + ) + assert "example.com" in "".join(cli_text) + deltas = [e for e in events if e.event == "message_delta"] + assert deltas[-1].data["usage"]["server_tool_use"] == {"web_search_requests": 1} - assert payloads[1]["content_block"]["type"] == "server_tool_use" - assert payloads[1]["content_block"]["name"] == "web_search" - assert payloads[3]["content_block"]["type"] == "web_search_tool_result" - assert payloads[3]["content_block"]["content"][0]["url"] == "https://example.com/v4" - assert payloads[-2]["usage"]["server_tool_use"] == {"web_search_requests": 1} + +@pytest.mark.asyncio +async def test_forced_web_fetch_ignores_stale_url_from_prior_user_turns(monkeypatch): + """Only the latest user message supplies the URL (not earlier transcript text).""" + target = "https://new-only.example.com/page" + + async def fake_fetch(url: str, _egress: WebFetchEgressPolicy) -> dict[str, str]: + assert url == target + return { + "url": url, + "title": "T", + "media_type": "text/plain", + "data": "x", + } + + 
monkeypatch.setattr("api.web_tools.outbound._run_web_fetch", fake_fetch) + request = MessagesRequest( + model="claude-haiku-4-5-20251001", + max_tokens=100, + messages=[ + Message( + role="user", + content="Earlier turn https://stale.com/old-article ignore this", + ), + Message(role="assistant", content="ok"), + Message( + role="user", + content=f"Please fetch {target} for the summary", + ), + ], + tools=[Tool(name="web_fetch", type="web_fetch_20250910")], + tool_choice={"type": "tool", "name": "web_fetch"}, + ) + + raw = "".join( + [ + event + async for event in stream_web_server_tool_response( + request, input_tokens=1, web_fetch_egress=_STRICT_EGRESS + ) + ] + ) + assert target in raw @pytest.mark.asyncio async def test_streams_web_fetch_server_tool_result(monkeypatch): - async def fake_fetch(url: str) -> dict[str, str]: + async def fake_fetch(url: str, _egress: WebFetchEgressPolicy) -> dict[str, str]: assert url == "https://example.com/article" return { "url": url, @@ -71,7 +413,7 @@ async def test_streams_web_fetch_server_tool_result(monkeypatch): "data": "Article body", } - monkeypatch.setattr("api.web_server_tools._run_web_fetch", fake_fetch) + monkeypatch.setattr("api.web_tools.outbound._run_web_fetch", fake_fetch) request = MessagesRequest( model="claude-haiku-4-5-20251001", max_tokens=100, @@ -82,15 +424,198 @@ async def test_streams_web_fetch_server_tool_result(monkeypatch): tool_choice={"type": "tool", "name": "web_fetch"}, ) - events = [ - event - async for event in stream_web_server_tool_response(request, input_tokens=42) - ] - payloads = [_event_data(event) for event in events] - - assert payloads[1]["content_block"]["type"] == "server_tool_use" - assert payloads[3]["content_block"]["type"] == "web_fetch_tool_result" - assert payloads[3]["content_block"]["content"]["content"]["title"] == ( + raw = "".join( + [ + event + async for event in stream_web_server_tool_response( + request, input_tokens=42, web_fetch_egress=_STRICT_EGRESS + ) + ] + ) + events = parse_sse_text(raw) + assert_anthropic_stream_contract(events) + starts = [e for e in events if e.event == "content_block_start"] + assert starts[0].data["content_block"]["type"] == "server_tool_use" + tool_use_id = starts[0].data["content_block"]["id"] + assert starts[1].data["content_block"]["type"] == "web_fetch_tool_result" + assert starts[1].data["content_block"]["tool_use_id"] == tool_use_id + assert starts[1].data["content_block"]["content"]["content"]["title"] == ( "Example Article" ) - assert payloads[-2]["usage"]["server_tool_use"] == {"web_fetch_requests": 1} + assert any( + e.event == "content_block_delta" + and e.data.get("delta", {}).get("type") == "text_delta" + for e in events + ) + assert "Article body" in text_content(events) + cli_text: list[str] = [] + for ev in events: + cli_text.extend( + str(p.get("text", "")) + for p in parse_cli_event(ev.data) + if p.get("type") == "text_delta" + ) + assert "Article body" in "".join(cli_text) + deltas = [e for e in events if e.event == "message_delta"] + assert deltas[-1].data["usage"]["server_tool_use"] == {"web_fetch_requests": 1} + + +@pytest.mark.asyncio +async def test_streams_web_fetch_error_summary_generic_by_default(monkeypatch): + secret = "sensitive-upstream-token" + + async def boom(_url: str, _egress: WebFetchEgressPolicy) -> dict[str, str]: + raise ValueError(secret) + + monkeypatch.setattr("api.web_tools.outbound._run_web_fetch", boom) + request = MessagesRequest( + model="claude-haiku-4-5-20251001", + max_tokens=100, + messages=[ + Message( + 
role="user", + content="Fetch https://example.com/sensitive-path?x=1 please", + ) + ], + tools=[Tool(name="web_fetch", type="web_fetch_20250910")], + tool_choice={"type": "tool", "name": "web_fetch"}, + ) + + with patch("api.web_tools.outbound.logger.warning") as log_warn: + raw = "".join( + [ + event + async for event in stream_web_server_tool_response( + request, + input_tokens=1, + web_fetch_egress=_STRICT_EGRESS, + verbose_client_errors=False, + ) + ] + ) + + assert secret not in raw + assert "ValueError" not in raw + assert "Web tool request failed." in raw + err_events = parse_sse_text(raw) + assert_anthropic_stream_contract(err_events) + assert any( + e.event == "content_block_delta" + and e.data.get("delta", {}).get("type") == "text_delta" + for e in err_events + ) + cli_err_text: list[str] = [] + for ev in err_events: + cli_err_text.extend( + str(p.get("text", "")) + for p in parse_cli_event(ev.data) + if p.get("type") == "text_delta" + ) + assert "Web tool request failed." in "".join(cli_err_text) + log_blob = " ".join(str(a) for c in log_warn.call_args_list for a in c.args) + assert secret not in log_blob + assert "example.com" in log_blob + assert "/sensitive-path" not in log_blob + + +@pytest.mark.asyncio +async def test_streams_web_fetch_error_summary_verbose_includes_exception_class( + monkeypatch, +): + async def boom(_url: str, _egress: WebFetchEgressPolicy) -> dict[str, str]: + raise OSError(5, "oops") + + monkeypatch.setattr("api.web_tools.outbound._run_web_fetch", boom) + request = MessagesRequest( + model="claude-haiku-4-5-20251001", + max_tokens=100, + messages=[Message(role="user", content="Fetch https://example.com/x")], + tools=[Tool(name="web_fetch", type="web_fetch_20250910")], + tool_choice={"type": "tool", "name": "web_fetch"}, + ) + + raw = "".join( + [ + event + async for event in stream_web_server_tool_response( + request, + input_tokens=1, + web_fetch_egress=_STRICT_EGRESS, + verbose_client_errors=True, + ) + ] + ) + assert "OSError" in raw + + +@pytest.mark.asyncio +async def test_read_response_body_capped_truncates_single_oversized_chunk(): + cap = 500 + + async def aiter_bytes(chunk_size=None): + yield b"z" * (cap * 20) + + response = MagicMock() + response.aiter_bytes = aiter_bytes + + out = await _read_response_body_capped(response, cap) + assert len(out) == cap + assert out == b"z" * cap + + +@pytest.mark.asyncio +async def test_drain_response_body_capped_stops_after_first_chunk_when_oversized(): + cap = 300 + chunk_calls = {"n": 0} + + async def aiter_bytes(chunk_size=None): + chunk_calls["n"] += 1 + yield b"y" * (cap * 10) + + response = MagicMock() + response.aiter_bytes = aiter_bytes + + await _drain_response_body_capped(response, cap) + assert chunk_calls["n"] == 1 + + +def test_service_rejects_listed_server_tools_on_openai_chat() -> None: + settings = Settings() + service = ClaudeProxyService( + settings, + provider_getter=lambda _: MagicMock(), + model_router=FixedProviderModelRouter(settings, "deepseek"), + ) + request = MessagesRequest( + model="m", + max_tokens=20, + messages=[Message(role="user", content="q")], + tools=[Tool(name="web_search", type="web_search_20250305")], + ) + with pytest.raises(InvalidRequestError, match="OpenAI Chat upstreams"): + service.create_message(request) + + +def test_listed_server_tools_routed_on_open_router() -> None: + """Native Anthropic transport may receive listed server tool definitions.""" + settings = Settings() + + async def fake_stream(*_a, **_k): + yield 'event: message_start\ndata: 
{"type":"message_start"}\n\n' + yield 'event: message_stop\ndata: {"type":"message_stop"}\n\n' + + mock_provider = MagicMock() + mock_provider.stream_response = fake_stream + service = ClaudeProxyService( + settings, + provider_getter=lambda _: mock_provider, + model_router=FixedProviderModelRouter(settings, "open_router"), + ) + request = MessagesRequest( + model="m", + max_tokens=20, + messages=[Message(role="user", content="q")], + tools=[Tool(name="web_search", type="web_search_20250305")], + ) + service.create_message(request) + mock_provider.preflight_stream.assert_called() diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index a22f9dd..590f51c 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -3,6 +3,7 @@ import asyncio import json import os +from typing import cast from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -300,7 +301,7 @@ class TestCLISession: mock_process = AsyncMock() mock_process.stdout.read.side_effect = [b""] # No stdout - mock_process.stderr.read.return_value = b"Fatal error" + mock_process.stderr.read.side_effect = [b"Fatal error", b""] mock_process.wait.return_value = 1 with patch( @@ -319,6 +320,62 @@ class TestCLISession: assert events[1]["code"] == 1 assert events[1]["stderr"] == "Fatal error" + @pytest.mark.asyncio + async def test_start_task_stderr_while_stdout_streams(self): + """Stderr is drained concurrently so stdout streaming is not blocked.""" + from cli.session import CLISession + + session = CLISession("/tmp", "http://localhost:8082/v1") + + mock_process = AsyncMock() + mock_process.stdout.read.side_effect = [ + b'{"type": "message", "content": "Hi"}\n', + b"", + ] + mock_process.stderr.read.side_effect = [b"warning on stderr\n", b""] + mock_process.wait.return_value = 0 + + with patch( + "asyncio.create_subprocess_exec", new_callable=AsyncMock + ) as mock_exec: + mock_exec.return_value = mock_process + + events = [e async for e in session.start_task("Hello")] + + assert mock_process.stderr.read.await_count >= 2 + err_events = [e for e in events if e.get("type") == "error"] + assert len(err_events) == 1 + assert "warning on stderr" in err_events[0]["error"]["message"] + assert events[-1]["type"] == "exit" + assert events[-1]["code"] == 0 + + @pytest.mark.asyncio + async def test_drain_stderr_bounded_retains_cap_but_drains_to_eof(self): + """Oversized stderr is fully drained so the pipe cannot deadlock; capture is bounded.""" + from cli.session import _MAX_STDERR_CAPTURE_BYTES, CLISession + + total_len = _MAX_STDERR_CAPTURE_BYTES + 100_000 + remaining: dict[str, int] = {"n": total_len} + + class _FakeStderr: + async def read(self, n: int = 65536) -> bytes: + left = remaining["n"] + if left <= 0: + return b"" + take = min(n, left) + remaining["n"] = left - take + return b"y" * take + + class _FakeProcess: + stderr = _FakeStderr() + + out = await CLISession._drain_stderr_bounded( + cast(asyncio.subprocess.Process, _FakeProcess()) + ) + assert len(out) == _MAX_STDERR_CAPTURE_BYTES + assert out == b"y" * _MAX_STDERR_CAPTURE_BYTES + assert remaining["n"] == 0 + @pytest.mark.asyncio async def test_stop_session(self): """Test stopping the session process.""" diff --git a/tests/config/test_config.py b/tests/config/test_config.py index 583dc75..e62f583 100644 --- a/tests/config/test_config.py +++ b/tests/config/test_config.py @@ -3,6 +3,10 @@ import pytest from pydantic import ValidationError +from config.constants import ( + ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS, + HTTP_CONNECT_TIMEOUT_DEFAULT, +) from config.nim 
import NimSettings @@ -22,6 +26,7 @@ class TestSettings: monkeypatch.delenv("MODEL", raising=False) monkeypatch.delenv("HTTP_READ_TIMEOUT", raising=False) + monkeypatch.delenv("HTTP_CONNECT_TIMEOUT", raising=False) monkeypatch.setitem(Settings.model_config, "env_file", ()) settings = Settings() assert settings.model == "nvidia_nim/z-ai/glm4.7" @@ -31,6 +36,12 @@ class TestSettings: assert isinstance(settings.fast_prefix_detection, bool) assert isinstance(settings.enable_model_thinking, bool) assert settings.http_read_timeout == 120.0 + assert settings.http_connect_timeout == HTTP_CONNECT_TIMEOUT_DEFAULT + assert settings.enable_web_server_tools is False + assert settings.log_raw_api_payloads is False + assert settings.log_raw_sse_events is False + assert settings.debug_platform_edits is False + assert settings.debug_subagent_stack is False def test_get_settings_cached(self): """Test get_settings returns cached instance.""" @@ -57,10 +68,10 @@ class TestSettings: assert len(settings.model) > 0 def test_base_url_constant(self): - """Test NVIDIA_NIM_BASE_URL is a constant.""" - from providers.nvidia_nim import NVIDIA_NIM_BASE_URL + """Test NVIDIA_NIM_DEFAULT_BASE is a constant.""" + from providers.nvidia_nim import NVIDIA_NIM_DEFAULT_BASE - assert NVIDIA_NIM_BASE_URL == "https://integrate.api.nvidia.com/v1" + assert NVIDIA_NIM_DEFAULT_BASE == "https://integrate.api.nvidia.com/v1" def test_lm_studio_base_url_from_env(self, monkeypatch): """LM_STUDIO_BASE_URL env var is loaded into settings.""" @@ -127,6 +138,18 @@ class TestSettings: settings = Settings() assert settings.http_connect_timeout == 5.0 + def test_http_connect_timeout_default_matches_shared_constant( + self, monkeypatch + ) -> None: + """Default must match config.constants (and README / .env.example).""" + from config.settings import Settings + + monkeypatch.delenv("HTTP_CONNECT_TIMEOUT", raising=False) + monkeypatch.setitem(Settings.model_config, "env_file", ()) + settings = Settings() + assert settings.http_connect_timeout == HTTP_CONNECT_TIMEOUT_DEFAULT + assert HTTP_CONNECT_TIMEOUT_DEFAULT == 10.0 + def test_enable_model_thinking_from_env(self, monkeypatch): """ENABLE_MODEL_THINKING env var is loaded into settings.""" from config.settings import Settings @@ -319,6 +342,9 @@ class TestNimSettingsInvalidBounds: class TestNimSettingsValidators: """Test custom field validators in NimSettings.""" + def test_default_max_tokens_matches_shared_constant(self): + assert NimSettings().max_tokens == ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS + @pytest.mark.parametrize( "seed_val,expected", [("", None), (None, None), ("42", 42), (42, 42)], diff --git a/tests/config/test_logging_config.py b/tests/config/test_logging_config.py index 11a818c..9f54df0 100644 --- a/tests/config/test_logging_config.py +++ b/tests/config/test_logging_config.py @@ -4,6 +4,8 @@ import json import logging from pathlib import Path +from loguru import logger + from config.logging_config import configure_logging @@ -57,3 +59,38 @@ def test_configure_logging_skips_when_already_configured(tmp_path): assert "Still goes to first file" in (tmp_path / "test.log").read_text( encoding="utf-8" ) + + +def test_telegram_bot_token_redacted_in_message_field(tmp_path) -> None: + log_file = str(tmp_path / "redact.log") + configure_logging(log_file, force=True, verbose_third_party=False) + token = "123456:ABCDEF-ghij-klm" + logger.info("Calling {}", f"https://api.telegram.org/bot{token}/getMe") + logger.complete() + text = Path(log_file).read_text(encoding="utf-8") + assert token not in text 
+ assert "bot/" in text or "redacted" in text + + +def test_bearer_substring_redacted_in_log_file(tmp_path) -> None: + log_file = str(tmp_path / "bearer.log") + configure_logging(log_file, force=True, verbose_third_party=False) + secret = "ya29.secret-token-abc" + logger.info("Request headers: Authorization: Bearer {}", secret) + logger.complete() + text = Path(log_file).read_text(encoding="utf-8") + assert secret not in text + assert "Bearer" in text + + +def test_httpx_logger_quieted_when_not_verbose_third_party(tmp_path) -> None: + log_file = str(tmp_path / "quiet.log") + configure_logging(log_file, force=True, verbose_third_party=False) + assert logging.getLogger("httpx").level >= logging.WARNING + assert logging.getLogger("httpcore").level >= logging.WARNING + + +def test_httpx_resets_to_notset_when_verbose_third_party(tmp_path) -> None: + log_file = str(tmp_path / "verbose.log") + configure_logging(log_file, force=True, verbose_third_party=True) + assert logging.getLogger("httpx").level == logging.NOTSET diff --git a/tests/contracts/test_architecture_contracts.py b/tests/contracts/test_architecture_contracts.py index 344122e..25be850 100644 --- a/tests/contracts/test_architecture_contracts.py +++ b/tests/contracts/test_architecture_contracts.py @@ -13,6 +13,25 @@ def test_architecture_plan_exists() -> None: text = plan.read_text(encoding="utf-8") assert "Intended Dependency Direction" in text assert "Smoke Coverage Policy" in text + assert "providers.nvidia_nim.voice" in text + assert "no dedicated smoke SSE shim" in text + + +def test_smoke_lib_has_no_sse_shim_module() -> None: + repo_root = Path(__file__).resolve().parents[2] + assert not (repo_root / "smoke" / "lib" / "sse.py").exists() + + +def test_api_package_exports_match_plan() -> None: + import api + + assert set(api.__all__) == { + "MessagesRequest", + "MessagesResponse", + "TokenCountRequest", + "TokenCountResponse", + "create_app", + } def test_root_env_example_is_the_single_template_source() -> None: diff --git a/tests/contracts/test_import_boundaries.py b/tests/contracts/test_import_boundaries.py index 918d7d1..3bde11d 100644 --- a/tests/contracts/test_import_boundaries.py +++ b/tests/contracts/test_import_boundaries.py @@ -54,6 +54,15 @@ def test_core_does_not_import_product_packages() -> None: assert offenders == [] +def test_provider_catalog_is_single_source_for_supported_ids() -> None: + from config.provider_catalog import PROVIDER_CATALOG, SUPPORTED_PROVIDER_IDS + from providers.registry import PROVIDER_DESCRIPTORS, PROVIDER_FACTORIES + + assert tuple(PROVIDER_CATALOG.keys()) == SUPPORTED_PROVIDER_IDS + assert PROVIDER_DESCRIPTORS is PROVIDER_CATALOG + assert set(SUPPORTED_PROVIDER_IDS) == set(PROVIDER_FACTORIES) + + def test_config_does_not_import_non_config_packages() -> None: """Settings and env handling must not depend on transport or protocol layers.""" repo_root = Path(__file__).resolve().parents[2] @@ -71,14 +80,34 @@ def test_config_does_not_import_non_config_packages() -> None: assert offenders == [] -def test_messaging_does_not_import_api_or_cli_or_providers() -> None: - """Messaging is wired by ``api.runtime``; must not import server or provider adapters.""" +_MESSAGING_ALLOWED_PROVIDER_MODULES = frozenset({"providers.nvidia_nim.voice"}) + + +def test_messaging_does_not_import_disallowed_modules() -> None: + """Messaging is wired by ``api.runtime``; narrow provider imports only for NIM voice ASR.""" repo_root = Path(__file__).resolve().parents[2] - offenders = _imports_matching( - [repo_root / 
"messaging"], - forbidden_prefixes=("api.", "cli.", "providers.", "smoke."), - ) - assert offenders == [] + offenders: list[str] = [] + for path in (repo_root / "messaging").rglob("*.py"): + for imported in _imports_from(path, repo_root): + if imported is None: + continue + if ( + imported == "api" + or imported.startswith("api.") + or imported == "cli" + or imported.startswith("cli.") + or imported == "smoke" + or imported.startswith("smoke.") + ): + rel = path.relative_to(repo_root) + offenders.append(f"{rel}: {imported}") + elif imported.startswith("providers."): + if imported in _MESSAGING_ALLOWED_PROVIDER_MODULES: + continue + rel = path.relative_to(repo_root) + offenders.append(f"{rel}: {imported}") + + assert sorted(offenders) == [] def test_api_may_only_import_narrow_provider_facade() -> None: diff --git a/tests/contracts/test_smoke_sse_reexport.py b/tests/contracts/test_smoke_sse_reexport.py deleted file mode 100644 index 4190312..0000000 --- a/tests/contracts/test_smoke_sse_reexport.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Ensure smoke re-exports stay aligned with :mod:`core.anthropic.stream_contracts`.""" - -from __future__ import annotations - -import core.anthropic.stream_contracts as core_sc -import smoke.lib.sse as smoke_sse - - -def test_smoke_lib_sse_reexports_core_stream_contracts() -> None: - for name in smoke_sse.__all__: - assert getattr(smoke_sse, name) is getattr(core_sc, name) diff --git a/tests/contracts/test_stream_contracts.py b/tests/contracts/test_stream_contracts.py index 0125459..6898cc3 100644 --- a/tests/contracts/test_stream_contracts.py +++ b/tests/contracts/test_stream_contracts.py @@ -8,16 +8,14 @@ from __future__ import annotations from collections.abc import Iterable from core.anthropic import ContentType, HeuristicToolParser, SSEBuilder, ThinkTagParser +from core.anthropic.sse import format_sse_event from core.anthropic.stream_contracts import ( assert_anthropic_stream_contract, event_names, - has_tool_use, parse_sse_text, text_content, thinking_content, ) -from messaging.event_parser import parse_cli_event -from messaging.transcript import RenderCtx, TranscriptBuffer def test_interleaved_thinking_text_blocks_are_valid() -> None: @@ -59,44 +57,49 @@ def test_mixed_reasoning_content_and_think_tags_keep_order() -> None: assert text_content(events) == " visible done" -def test_thinking_tool_text_and_transcript_order_contract() -> None: - builder = SSEBuilder("msg_contract", "contract-model") - chunks = [builder.message_start()] - chunks.extend(builder.ensure_thinking_block()) - chunks.append(builder.emit_thinking_delta("inspect first")) - chunks.extend(builder.close_content_blocks()) - tool_block_index = builder.blocks.allocate_index() - chunks.append( - builder.content_block_start( - tool_block_index, "tool_use", id="toolu_1", name="Read" - ) - ) - chunks.append( - builder.content_block_delta( - tool_block_index, "input_json_delta", '{"file":"README.md"}' - ) - ) - chunks.append(builder.content_block_stop(tool_block_index)) - chunks.extend(builder.ensure_text_block()) - chunks.append(builder.emit_text_delta("done")) - chunks.extend(builder.close_all_blocks()) - chunks.append(builder.message_delta("end_turn", 20)) - chunks.append(builder.message_stop()) - +def test_redacted_thinking_block_start_stop_is_valid() -> None: + """Native redacted_thinking uses start/stop only (no deltas).""" + chunks = [ + format_sse_event( + "message_start", + { + "type": "message_start", + "message": { + "id": "msg_r", + "type": "message", + "role": "assistant", + "content": 
[], + "model": "m", + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": 1, "output_tokens": 1}, + }, + }, + ), + format_sse_event( + "content_block_start", + { + "type": "content_block_start", + "index": 0, + "content_block": {"type": "redacted_thinking", "data": "opaque"}, + }, + ), + format_sse_event( + "content_block_stop", + {"type": "content_block_stop", "index": 0}, + ), + format_sse_event( + "message_delta", + { + "type": "message_delta", + "delta": {"stop_reason": "end_turn", "stop_sequence": None}, + "usage": {"input_tokens": 1, "output_tokens": 2}, + }, + ), + format_sse_event("message_stop", {"type": "message_stop"}), + ] events = parse_sse_text("".join(chunks)) assert_anthropic_stream_contract(events) - assert has_tool_use(events) - - transcript = TranscriptBuffer() - for event in events: - for parsed in parse_cli_event(event.data): - transcript.apply(parsed) - rendered = transcript.render(_render_ctx(), limit_chars=3900, status=None) - assert ( - rendered.find("inspect first") - < rendered.find("Tool call:") - < rendered.find("done") - ) def test_enable_thinking_false_suppresses_reasoning_only() -> None: @@ -186,13 +189,3 @@ def _emit_parser_parts( def _parse_builder_events(chunks: Iterable[str]): return parse_sse_text("".join(chunks)) - - -def _render_ctx() -> RenderCtx: - return RenderCtx( - bold=lambda text: f"*{text}*", - code_inline=lambda text: f"`{text}`", - escape_code=lambda text: text, - escape_text=lambda text: text, - render_markdown=lambda text: text, - ) diff --git a/tests/core/anthropic/test_native_sse_block_policy.py b/tests/core/anthropic/test_native_sse_block_policy.py new file mode 100644 index 0000000..f1b6658 --- /dev/null +++ b/tests/core/anthropic/test_native_sse_block_policy.py @@ -0,0 +1,133 @@ +"""Unit tests for shared native Anthropic SSE thinking policy / block remapping.""" + +from __future__ import annotations + +import json + +from core.anthropic.native_sse_block_policy import ( + NativeSseBlockPolicyState, + format_native_sse_event, + transform_native_sse_block_event, +) + + +def test_thinking_start_dropped_when_disabled() -> None: + st = NativeSseBlockPolicyState() + payload = { + "type": "content_block_start", + "index": 0, + "content_block": {"type": "thinking", "thinking": ""}, + } + ev = format_native_sse_event( + "content_block_start", + json.dumps(payload), + ) + assert transform_native_sse_block_event(ev, st, thinking_enabled=False) is None + + +def test_thinking_delta_dropped_when_disabled() -> None: + st = NativeSseBlockPolicyState() + # No prior start in stream (OpenRouter-style: returns None when thinking off) + payload = { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "thinking_delta", "thinking": "secret"}, + } + ev = format_native_sse_event("content_block_delta", json.dumps(payload)) + assert transform_native_sse_block_event(ev, st, thinking_enabled=False) is None + + +def test_text_block_passthrough_when_thinking_disabled() -> None: + st = NativeSseBlockPolicyState() + payload = { + "type": "content_block_start", + "index": 0, + "content_block": {"type": "text", "text": ""}, + } + ev = format_native_sse_event("content_block_start", json.dumps(payload)) + out = transform_native_sse_block_event(ev, st, thinking_enabled=False) + assert out is not None + assert '"index": 0' in (out or "") + + +def test_interleaved_thinking_signature_delta_remaps_to_reopened_block_index() -> None: + """After text interrupts thinking, signature_delta must follow the reopened segment index.""" + st = 
NativeSseBlockPolicyState() + + def run(ev: str) -> str | None: + return transform_native_sse_block_event(ev, st, thinking_enabled=True) + + out1 = run( + format_native_sse_event( + "content_block_start", + json.dumps( + { + "type": "content_block_start", + "index": 0, + "content_block": {"type": "thinking", "thinking": ""}, + } + ), + ) + ) + assert out1 is not None and '"index": 0' in out1 + + out2 = run( + format_native_sse_event( + "content_block_start", + json.dumps( + { + "type": "content_block_start", + "index": 1, + "content_block": {"type": "text", "text": ""}, + } + ), + ) + ) + assert out2 is not None + + out3 = run( + format_native_sse_event( + "content_block_delta", + json.dumps( + { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "thinking_delta", "thinking": "plan"}, + } + ), + ) + ) + assert out3 is not None + assert "content_block_start" in out3 + + out4 = run( + format_native_sse_event( + "content_block_delta", + json.dumps( + { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "signature_delta", "signature": "sig"}, + } + ), + ) + ) + assert out4 is not None + assert '"index": 2' in out4 + assert "signature_delta" in out4 + + +def test_startless_text_delta_synthesizes_start_when_thinking_disabled() -> None: + """Startless text deltas must not be dropped when thinking is disabled (OpenRouter).""" + st = NativeSseBlockPolicyState() + payload = { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": "Hello"}, + } + ev = format_native_sse_event("content_block_delta", json.dumps(payload)) + out = transform_native_sse_block_event(ev, st, thinking_enabled=False) + assert out is not None + assert "content_block_start" in (out or "") + assert "Hello" in (out or "") + assert "text_delta" in (out or "") diff --git a/tests/core/test_strict_sliding_window.py b/tests/core/test_strict_sliding_window.py new file mode 100644 index 0000000..1607521 --- /dev/null +++ b/tests/core/test_strict_sliding_window.py @@ -0,0 +1,38 @@ +"""Direct tests for :class:`core.rate_limit.StrictSlidingWindowLimiter`.""" + +import time + +import pytest + +from core.rate_limit import StrictSlidingWindowLimiter + + +@pytest.mark.asyncio +async def test_strict_window_allows_burst_then_blocks(): + lim = StrictSlidingWindowLimiter(rate_limit=2, rate_window=0.2) + await lim.acquire() + await lim.acquire() + start = time.monotonic() + await lim.acquire() + assert time.monotonic() - start >= 0.15 + + +@pytest.mark.asyncio +async def test_strict_window_async_context_manager(): + lim = StrictSlidingWindowLimiter(rate_limit=1, rate_window=0.15) + + async def run(): + async with lim: + pass + + await run() + start = time.monotonic() + await run() + assert time.monotonic() - start >= 0.1 + + +def test_strict_window_rejects_invalid_config(): + with pytest.raises(ValueError): + StrictSlidingWindowLimiter(rate_limit=0, rate_window=1.0) + with pytest.raises(ValueError): + StrictSlidingWindowLimiter(rate_limit=1, rate_window=0.0) diff --git a/tests/messaging/test_handler.py b/tests/messaging/test_handler.py index 7695dd0..013b096 100644 --- a/tests/messaging/test_handler.py +++ b/tests/messaging/test_handler.py @@ -13,6 +13,49 @@ def handler(mock_platform, mock_cli_manager, mock_session_store): return ClaudeMessageHandler(mock_platform, mock_cli_manager, mock_session_store) +@pytest.mark.asyncio +async def test_handle_message_default_logs_text_len_not_content( + mock_platform, mock_cli_manager, mock_session_store, incoming_message_factory +): + secret = 
"user-secret-content-never-log-default" + handler = ClaudeMessageHandler( + mock_platform, + mock_cli_manager, + mock_session_store, + log_raw_messaging_content=False, + ) + incoming = incoming_message_factory(text=secret) + with ( + patch.object(handler, "_handle_message_impl", new_callable=AsyncMock), + patch("messaging.handler.logger.info") as log_info, + ): + await handler.handle_message(incoming) + blob = " ".join(str(c) for c in log_info.call_args_list) + assert secret not in blob + assert "text_len=" in blob + + +@pytest.mark.asyncio +async def test_handle_message_raw_content_logging_includes_preview( + mock_platform, mock_cli_manager, mock_session_store, incoming_message_factory +): + secret = "visible-preview-xyz" + handler = ClaudeMessageHandler( + mock_platform, + mock_cli_manager, + mock_session_store, + log_raw_messaging_content=True, + ) + incoming = incoming_message_factory(text=secret) + with ( + patch.object(handler, "_handle_message_impl", new_callable=AsyncMock), + patch("messaging.handler.logger.info") as log_info, + ): + await handler.handle_message(incoming) + blob = " ".join(str(c) for c in log_info.call_args_list) + assert secret in blob + + def test_get_initial_status_new_conversation(handler): """New conversation always returns launching message.""" result = handler._get_initial_status(None, None) diff --git a/tests/messaging/test_handler_format.py b/tests/messaging/test_handler_format.py index 2cf60ee..cede9fb 100644 --- a/tests/messaging/test_handler_format.py +++ b/tests/messaging/test_handler_format.py @@ -17,7 +17,6 @@ def handler(): platform = MagicMock() cli = MagicMock() store = MagicMock() - # Kept for backwards test structure; transcript rendering is now separate. return (platform, cli, store) diff --git a/tests/messaging/test_handler_markdown_and_status_edges.py b/tests/messaging/test_handler_markdown_and_status_edges.py index 185e8b1..662c6ff 100644 --- a/tests/messaging/test_handler_markdown_and_status_edges.py +++ b/tests/messaging/test_handler_markdown_and_status_edges.py @@ -4,6 +4,7 @@ import pytest from messaging.handler import ClaudeMessageHandler from messaging.models import IncomingMessage +from messaging.node_event_pipeline import process_parsed_cli_event from messaging.rendering.telegram_markdown import render_markdown_to_mdv2 from messaging.trees.data import MessageNode, MessageState @@ -333,7 +334,7 @@ async def test_handle_message_incoming_text_none_safe(): @pytest.mark.asyncio async def test_process_parsed_event_malformed_content_continues(): - """Malformed/unknown parsed event does not crash _process_parsed_event.""" + """Malformed/unknown parsed event does not crash process_parsed_cli_event.""" platform = MagicMock() platform.queue_edit_message = AsyncMock() @@ -344,7 +345,7 @@ async def test_process_parsed_event_malformed_content_continues(): transcript = MagicMock() update_ui = AsyncMock() - last_status, had = await handler._process_parsed_event( + last_status, had = await process_parsed_cli_event( parsed={"type": "unknown_type"}, transcript=transcript, update_ui=update_ui, @@ -353,6 +354,9 @@ async def test_process_parsed_event_malformed_content_continues(): tree=None, node_id="n1", captured_session_id=None, + session_store=session_store, + format_status=handler.format_status, + propagate_error_to_children=AsyncMock(), ) assert last_status is None assert had is False diff --git a/tests/messaging/test_limiter.py b/tests/messaging/test_limiter.py index 565e80f..0882c03 100644 --- a/tests/messaging/test_limiter.py +++ 
b/tests/messaging/test_limiter.py @@ -1,16 +1,10 @@ import asyncio -import os +import contextlib import time import pytest import pytest_asyncio -# Set environment variables relative to test execution -os.environ["MESSAGING_RATE_LIMIT"] = "1" -os.environ["MESSAGING_RATE_WINDOW"] = "0.5" - -import contextlib - from messaging.limiter import MessagingRateLimiter @@ -19,11 +13,8 @@ class TestMessagingRateLimiter: @pytest_asyncio.fixture(autouse=True) async def reset_limiter(self): - """Reset singleton and environment before each test.""" - # Ensure the singleton worker is stopped between tests to avoid dangling tasks. + """Reset singleton before each test.""" await MessagingRateLimiter.shutdown_instance(timeout=0.1) - os.environ["MESSAGING_RATE_LIMIT"] = "1" - os.environ["MESSAGING_RATE_WINDOW"] = "0.5" yield @@ -32,9 +23,16 @@ class TestMessagingRateLimiter: @pytest.mark.asyncio async def test_singleton_pattern(self): """Test that get_instance returns the same object.""" - limiter1 = await MessagingRateLimiter.get_instance() - limiter2 = await MessagingRateLimiter.get_instance() + limiter1 = await MessagingRateLimiter.get_instance( + rate_limit=1, rate_window=0.5 + ) + limiter2 = await MessagingRateLimiter.get_instance( + rate_limit=99, rate_window=99.0 + ) assert limiter1 is limiter2 + # First-construction wins for rate parameters + assert limiter1.limiter._rate_limit == 1 + assert limiter1.limiter._rate_window == 0.5 @pytest.mark.asyncio async def test_compaction(self): @@ -42,13 +40,8 @@ class TestMessagingRateLimiter: Verify multiple rapid requests with same dedup_key are compacted. Logic ported from verify_limiter.py """ - # Set slow rate for testing compaction - os.environ["MESSAGING_RATE_LIMIT"] = "1" - os.environ["MESSAGING_RATE_WINDOW"] = "1.0" - - # Must reset instance to pick up new env vars - MessagingRateLimiter._instance = None - limiter = await MessagingRateLimiter.get_instance() + await MessagingRateLimiter.shutdown_instance(timeout=0.1) + limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=1.0) call_counts = {} @@ -78,10 +71,8 @@ class TestMessagingRateLimiter: Verify that even when compacted, all futures resolve to the result of the LAST execution. 
Logic ported from verify_limiter_v2.py """ - os.environ["MESSAGING_RATE_LIMIT"] = "1" - os.environ["MESSAGING_RATE_WINDOW"] = "0.5" - MessagingRateLimiter._instance = None - limiter = await MessagingRateLimiter.get_instance() + await MessagingRateLimiter.shutdown_instance(timeout=0.1) + limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=0.5) call_counts = {} msg_id = "test_msg_hang" @@ -116,8 +107,8 @@ class TestMessagingRateLimiter: @pytest.mark.asyncio async def test_flood_wait_handling(self): """Test that FloodWait exceptions pause the worker.""" - MessagingRateLimiter._instance = None - limiter = await MessagingRateLimiter.get_instance() + await MessagingRateLimiter.shutdown_instance(timeout=0.1) + limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=1.0) # Mock exception with .seconds attribute class FloodWait(Exception): @@ -157,8 +148,8 @@ class TestMessagingRateLimiter: @pytest.mark.asyncio async def test_flood_wait_retry_after_parsing(self): """Error message with 'retry after N' parses the wait seconds.""" - MessagingRateLimiter._instance = None - limiter = await MessagingRateLimiter.get_instance() + await MessagingRateLimiter.shutdown_instance(timeout=0.1) + limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=1.0) async def mock_flood(): raise Exception("Flood wait: retry after 2 seconds") @@ -172,8 +163,8 @@ class TestMessagingRateLimiter: @pytest.mark.asyncio async def test_non_flood_exception_no_pause(self): """Non-flood exception doesn't trigger pause.""" - MessagingRateLimiter._instance = None - limiter = await MessagingRateLimiter.get_instance() + await MessagingRateLimiter.shutdown_instance(timeout=0.1) + limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=1.0) async def mock_error(): raise ValueError("some regular error") @@ -187,8 +178,8 @@ class TestMessagingRateLimiter: @pytest.mark.asyncio async def test_flood_with_seconds_attribute(self): """Exception with .seconds attribute uses that value for pause.""" - MessagingRateLimiter._instance = None - limiter = await MessagingRateLimiter.get_instance() + await MessagingRateLimiter.shutdown_instance(timeout=0.1) + limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=1.0) class FloodWaitCustom(Exception): def __init__(self): @@ -209,10 +200,8 @@ class TestMessagingRateLimiter: Proactive limiter should enforce a strict sliding window: for any i, t[i+rate_limit] - t[i] >= rate_window (within tolerance). 
""" - os.environ["MESSAGING_RATE_LIMIT"] = "2" - os.environ["MESSAGING_RATE_WINDOW"] = "0.5" - MessagingRateLimiter._instance = None - limiter = await MessagingRateLimiter.get_instance() + await MessagingRateLimiter.shutdown_instance(timeout=0.1) + limiter = await MessagingRateLimiter.get_instance(rate_limit=2, rate_window=0.5) async def acquire(i: int) -> float: async def _do() -> float: @@ -235,8 +224,8 @@ class TestMessagingRateLimiter: @pytest.mark.asyncio async def test_compaction_last_task_fails_all_futures_get_exception(self): """When compacted task's last func fails, all futures get the exception.""" - MessagingRateLimiter._instance = None - limiter = await MessagingRateLimiter.get_instance() + await MessagingRateLimiter.shutdown_instance(timeout=0.1) + limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=1.0) async def ok_task(): return "ok" @@ -255,8 +244,8 @@ class TestMessagingRateLimiter: @pytest.mark.asyncio async def test_fire_and_forget_failure_logged(self, caplog): """fire_and_forget with failing task logs error and does not re-raise.""" - MessagingRateLimiter._instance = None - limiter = await MessagingRateLimiter.get_instance() + await MessagingRateLimiter.shutdown_instance(timeout=0.1) + limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=1.0) async def fail_task(): raise ValueError("fire_and_forget failed") @@ -264,4 +253,6 @@ class TestMessagingRateLimiter: limiter.fire_and_forget(fail_task, dedup_key="fire_fail") await asyncio.sleep(1.5) - assert any("fire_and_forget failed" in str(r) for r in caplog.records) + joined = " ".join(str(r.message) for r in caplog.records) + assert "ValueError" in joined + assert "fire_and_forget failed" not in joined diff --git a/tests/messaging/test_messaging_factory.py b/tests/messaging/test_messaging_factory.py index 5d3c0a3..a2ed0cd 100644 --- a/tests/messaging/test_messaging_factory.py +++ b/tests/messaging/test_messaging_factory.py @@ -41,6 +41,10 @@ class TestCreateMessagingPlatform: whisper_device="cuda", hf_token="", nvidia_nim_api_key="", + messaging_rate_limit=1, + messaging_rate_window=1.0, + log_raw_messaging_content=False, + log_api_error_tracebacks=False, ) def test_telegram_without_token(self): @@ -85,6 +89,10 @@ class TestCreateMessagingPlatform: whisper_device="nvidia_nim", hf_token="", nvidia_nim_api_key="", + messaging_rate_limit=1, + messaging_rate_window=1.0, + log_raw_messaging_content=False, + log_api_error_tracebacks=False, ) def test_discord_without_token(self): diff --git a/tests/messaging/test_session_store_edge_cases.py b/tests/messaging/test_session_store_edge_cases.py index 31683d4..8656c36 100644 --- a/tests/messaging/test_session_store_edge_cases.py +++ b/tests/messaging/test_session_store_edge_cases.py @@ -81,15 +81,40 @@ class TestSessionStoreSaveEdgeCases: """Tests for save failure handling.""" def test_save_io_error_handled(self, tmp_store): - """Write failure in _write_data() raises (callers handle the error).""" + """Write failure during atomic replace is surfaced to callers.""" tmp_store.save_tree("r1", {"root_id": "r1", "nodes": {"r1": {}}}) with ( - patch("builtins.open", side_effect=OSError("disk full")), + patch("messaging.session.os.replace", side_effect=OSError("disk full")), pytest.raises(OSError), ): tmp_store._write_data(tmp_store._snapshot()) +class TestSessionStoreAtomicWrites: + """Atomic persistence: failed replace must not truncate the prior file.""" + + def test_failed_replace_keeps_prior_bytes_and_marks_dirty(self, tmp_path): + path = 
str(tmp_path / "sessions.json") + store = SessionStore(storage_path=path) + store.save_tree("r1", {"root_id": "r1", "nodes": {"r1": {}}}) + store.flush_pending_save() + with open(path, encoding="utf-8") as f: + disk_after_first = f.read() + + store.save_tree("r2", {"root_id": "r2", "nodes": {"r2": {}}}) + + with patch( + "messaging.session.os.replace", side_effect=OSError("replace failed") + ): + store.flush_pending_save() + + with open(path, encoding="utf-8") as f: + disk_after_failed = f.read() + assert disk_after_failed == disk_after_first + assert store._dirty is True + assert store.get_tree("r2") is not None + + class TestSessionStoreClearAll: def test_clear_all_wipes_state_and_persists(self, tmp_path): path = str(tmp_path / "sessions.json") diff --git a/tests/messaging/test_stream_transcript_contract.py b/tests/messaging/test_stream_transcript_contract.py new file mode 100644 index 0000000..2ef2b60 --- /dev/null +++ b/tests/messaging/test_stream_transcript_contract.py @@ -0,0 +1,62 @@ +"""Messaging-specific assertions built on neutral Anthropic stream contracts.""" + +from __future__ import annotations + +from core.anthropic import SSEBuilder +from core.anthropic.stream_contracts import ( + assert_anthropic_stream_contract, + has_tool_use, + parse_sse_text, +) +from messaging.event_parser import parse_cli_event +from messaging.transcript import RenderCtx, TranscriptBuffer + + +def test_thinking_tool_text_and_transcript_order_contract() -> None: + builder = SSEBuilder("msg_contract", "contract-model") + chunks = [builder.message_start()] + chunks.extend(builder.ensure_thinking_block()) + chunks.append(builder.emit_thinking_delta("inspect first")) + chunks.extend(builder.close_content_blocks()) + tool_block_index = builder.blocks.allocate_index() + chunks.append( + builder.content_block_start( + tool_block_index, "tool_use", id="toolu_1", name="Read" + ) + ) + chunks.append( + builder.content_block_delta( + tool_block_index, "input_json_delta", '{"file":"README.md"}' + ) + ) + chunks.append(builder.content_block_stop(tool_block_index)) + chunks.extend(builder.ensure_text_block()) + chunks.append(builder.emit_text_delta("done")) + chunks.extend(builder.close_all_blocks()) + chunks.append(builder.message_delta("end_turn", 20)) + chunks.append(builder.message_stop()) + + events = parse_sse_text("".join(chunks)) + assert_anthropic_stream_contract(events) + assert has_tool_use(events) + + transcript = TranscriptBuffer() + for event in events: + for parsed in parse_cli_event(event.data): + transcript.apply(parsed) + rendered = transcript.render(_render_ctx(), limit_chars=3900, status=None) + assert ( + rendered.find("inspect first") + < rendered.find("Tool call:") + < rendered.find("done") + ) + + +def _render_ctx() -> RenderCtx: + return RenderCtx( + bold=lambda s: f"*{s}*", + code_inline=lambda s: f"`{s}`", + escape_code=lambda s: s, + escape_text=lambda s: s, + render_markdown=lambda s: s, + ) diff --git a/tests/messaging/test_transcription.py b/tests/messaging/test_transcription.py index 6aaaca4..023f723 100644 --- a/tests/messaging/test_transcription.py +++ b/tests/messaging/test_transcription.py @@ -25,7 +25,7 @@ def test_transcribe_file_too_large_raises(): path = Path(f.name) try: with pytest.raises(ValueError, match="too large"): - transcribe_audio(path, "audio/ogg", whisper_device="auto") + transcribe_audio(path, "audio/ogg", whisper_device="cpu") finally: path.unlink(missing_ok=True) @@ -87,15 +87,9 @@ def test_transcribe_invalid_device_raises(): f.write(b"fake ogg") path = 
Path(f.name) try: - # Mock settings to return invalid device "auto" - mock_settings = MagicMock() - mock_settings.whisper_device = "auto" - mock_settings.whisper_model = "base" - # Patch _load_audio to avoid ImportError from missing librosa # Device validation happens in _get_pipeline before torch import with ( - patch("messaging.transcription.get_settings", return_value=mock_settings), patch("messaging.transcription._load_audio"), pytest.raises(ValueError, match="whisper_device must be 'cpu' or 'cuda'"), ): @@ -104,6 +98,24 @@ def test_transcribe_invalid_device_raises(): path.unlink(missing_ok=True) +def test_transcribe_nim_requires_api_key(): + """NIM path rejects empty API key without reading global settings.""" + with tempfile.NamedTemporaryFile(suffix=".ogg", delete=False) as f: + f.write(b"fake ogg") + path = Path(f.name) + try: + with pytest.raises(ValueError, match="non-empty"): + transcribe_audio( + path, + "audio/ogg", + whisper_device="nvidia_nim", + whisper_model="openai/whisper-large-v3", + nvidia_nim_api_key="", + ) + finally: + path.unlink(missing_ok=True) + + def test_transcribe_local_import_error_raises(): """Local backend when voice_local extra not installed raises ImportError.""" with tempfile.NamedTemporaryFile(suffix=".ogg", delete=False) as f: @@ -120,6 +132,6 @@ def test_transcribe_local_import_error_raises(): ), pytest.raises(ImportError, match="voice_local extra"), ): - transcribe_audio(path, "audio/ogg", whisper_device="auto") + transcribe_audio(path, "audio/ogg", whisper_device="cpu") finally: path.unlink(missing_ok=True) diff --git a/tests/messaging/test_transcription_nim.py b/tests/messaging/test_transcription_nim.py new file mode 100644 index 0000000..b15f89b --- /dev/null +++ b/tests/messaging/test_transcription_nim.py @@ -0,0 +1,39 @@ +"""Tests for NVIDIA NIM voice transcription wiring.""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import patch + +from messaging.transcription import transcribe_audio + + +def test_transcribe_audio_nvidia_nim_forwards_api_key(tmp_path: Path) -> None: + wav = tmp_path / "stub.wav" + wav.write_bytes(b"\x00" * 128) + with patch("messaging.transcription.transcribe_nvidia_nim_audio") as nim_fn: + nim_fn.return_value = "ok" + out = transcribe_audio( + wav, + "audio/wav", + whisper_model="openai/whisper-large-v3", + whisper_device="nvidia_nim", + nvidia_nim_api_key="test-nim-key", + ) + nim_fn.assert_called_once_with( + wav, "openai/whisper-large-v3", api_key="test-nim-key" + ) + assert out == "ok" + + +def test_nim_asr_model_map_entries_are_real_function_ids() -> None: + from providers.nvidia_nim.voice import _NIM_ASR_MODEL_MAP + + for function_id, language_code in _NIM_ASR_MODEL_MAP.values(): + assert function_id + assert function_id.strip().lower() != "none" + # Hosted NIM function-id is a lowercase UUID string. 
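+        # Illustrative extra check, assuming the lowercase claim above holds
+        # for every hosted function-id:
+        assert function_id == function_id.lower()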
+        parts = function_id.split("-")
+        assert len(parts) == 5
+        assert all(p for p in parts)
+        assert language_code is not None
diff --git a/tests/messaging/test_tree_processor.py b/tests/messaging/test_tree_processor.py
index e6a3a03..664748d 100644
--- a/tests/messaging/test_tree_processor.py
+++ b/tests/messaging/test_tree_processor.py
@@ -6,7 +6,7 @@ import pytest

 from messaging.models import IncomingMessage
 from messaging.trees.data import MessageNode, MessageState, MessageTree
-from messaging.trees.processor import TreeQueueProcessor
+from messaging.trees.queue_manager import TreeQueueProcessor


 @pytest.fixture
diff --git a/tests/messaging/test_tree_repository.py b/tests/messaging/test_tree_repository.py
index 2860ae6..e5bf4ef 100644
--- a/tests/messaging/test_tree_repository.py
+++ b/tests/messaging/test_tree_repository.py
@@ -4,7 +4,7 @@ import pytest

 from messaging.models import IncomingMessage
 from messaging.trees.data import MessageNode, MessageState, MessageTree
-from messaging.trees.repository import TreeRepository
+from messaging.trees.queue_manager import TreeRepository


 @pytest.fixture
diff --git a/tests/provider_request_mocks.py b/tests/provider_request_mocks.py
new file mode 100644
index 0000000..ef2a40d
--- /dev/null
+++ b/tests/provider_request_mocks.py
@@ -0,0 +1,25 @@
+"""Shared MagicMock request objects for OpenAI-compatible provider tests."""
+
+from unittest.mock import MagicMock
+
+
+def make_openai_compat_stream_request(
+    *, model: str = "test-model", stream: bool = True
+) -> MagicMock:
+    """Minimal request stub matching :meth:`OpenAIChatTransport._build_request_body` needs."""
+    req = MagicMock()
+    req.model = model
+    req.stream = stream
+    req.messages = []
+    req.system = None
+    req.tools = None
+    req.tool_choice = None
+    req.metadata = None
+    req.max_tokens = 4096
+    req.temperature = None
+    req.top_p = None
+    req.top_k = None
+    req.stop_sequences = None
+    req.extra_body = None
+    req.thinking = None
+    return req
diff --git a/tests/providers/test_anthropic_messages.py b/tests/providers/test_anthropic_messages.py
index 2d9498d..1cb57cb 100644
--- a/tests/providers/test_anthropic_messages.py
+++ b/tests/providers/test_anthropic_messages.py
@@ -6,8 +6,12 @@ from unittest.mock import AsyncMock, MagicMock, patch
 import httpx
 import pytest

+from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
+from core.anthropic.sse import format_sse_event
+from core.anthropic.stream_contracts import event_index, parse_sse_text
 from providers.anthropic_messages import AnthropicMessagesTransport
 from providers.base import ProviderConfig
+from tests.stream_contract import assert_canonical_stream_error_envelope


 class NativeProvider(AnthropicMessagesTransport):
@@ -32,8 +36,6 @@ class MockRequest:
            "model": self.model,
            "messages": [{"role": "user", "content": "Hello"}],
            "extra_body": {"ignored": True},
-            "original_model": "claude",
-            "resolved_provider_model": "native/test-model",
            "thinking": {"enabled": thinking_enabled},
        }
@@ -42,16 +44,30 @@ class MockRequest:


 class FakeResponse:
-    def __init__(self, *, status_code=200, lines=None, text=""):
+    def __init__(
+        self,
+        *,
+        status_code=200,
+        lines=None,
+        text="",
+        raise_after_line_index: int | None = None,
+    ):
        self.status_code = status_code
        self._lines = lines or []
        self._text = text
+        self._raise_after_line_index = raise_after_line_index
        self.is_closed = False
        self.request = httpx.Request("POST", "https://example.test/v1/messages")
+        self.headers = httpx.Headers()

    async def aiter_lines(self):
-        for line in self._lines:
+        for i, line in enumerate(self._lines):
            yield line
+            if (
+                self._raise_after_line_index is not None
+                and i >= self._raise_after_line_index
+            ):
+                raise RuntimeError("mid-stream failure")

    async def aread(self):
        return self._text.encode()
@@ -67,6 +83,11 @@ class FakeResponse:
    async def aclose(self):
        self.is_closed = True

+    async def aiter_bytes(self, chunk_size: int = 65_536):
+        data = self._text.encode("utf-8")
+        for offset in range(0, len(data), chunk_size):
+            yield data[offset : offset + chunk_size]
+

 @pytest.fixture
 def provider_config():
@@ -122,10 +143,8 @@ def test_default_request_body_strips_internal_fields(provider_config):

    assert body["model"] == "test-model"
    assert body["thinking"] == {"type": "enabled"}
-    assert body["max_tokens"] == 81920
+    assert body["max_tokens"] == ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
    assert "extra_body" not in body
-    assert "original_model" not in body
-    assert "resolved_provider_model" not in body


 def test_default_request_body_preserves_thinking_budget(provider_config):
@@ -210,7 +229,66 @@ async def test_stream_maps_non_200_to_error_event_and_closes_response(
    ]

    assert response.is_closed
-    assert len(events) == 1
-    assert events[0].startswith("event: error\ndata: {")
-    assert "Internal Server Error" in events[0]
-    assert "REQ_123" in events[0]
+    assert_canonical_stream_error_envelope(
+        events, user_message_substr="Provider API request failed"
+    )
+    blob = "".join(events)
+    assert "REQ_123" in blob
+
+
+@pytest.mark.asyncio
+async def test_midstream_error_closes_open_block_and_uses_fresh_content_index(
+    provider_config,
+):
+    """After upstream message_start + content_block_start, synthetic errors must not reuse index 0."""
+    provider = NativeProvider(provider_config)
+    req = MockRequest()
+    mid = "msg_midstream_err"
+    msg_start = format_sse_event(
+        "message_start",
+        {
+            "type": "message_start",
+            "message": {
+                "id": mid,
+                "type": "message",
+                "role": "assistant",
+                "content": [],
+                "model": "test-model",
+                "stop_reason": None,
+                "stop_sequence": None,
+                "usage": {"input_tokens": 1, "output_tokens": 1},
+            },
+        },
+    )
+    block_start = format_sse_event(
+        "content_block_start",
+        {
+            "type": "content_block_start",
+            "index": 0,
+            "content_block": {"type": "text", "text": ""},
+        },
+    )
+    lines: list[str] = []
+    for blob in (msg_start, block_start):
+        lines.extend(blob.splitlines())
+    response = FakeResponse(lines=lines, raise_after_line_index=len(lines) - 1)
+
+    with (
+        patch.object(provider._client, "build_request", return_value=MagicMock()),
+        patch.object(
+            provider._client,
+            "send",
+            new_callable=AsyncMock,
+            return_value=response,
+        ),
+    ):
+        events = [e async for e in provider.stream_response(req)]
+
+    assert_canonical_stream_error_envelope(
+        events, user_message_substr="mid-stream failure"
+    )
+    parsed = parse_sse_text("".join(events))
+    starts = [e for e in parsed if e.event == "content_block_start"]
+    assert event_index(starts[0]) == 0
+    assert event_index(starts[-1]) == 1
+    assert {event_index(e) for e in parsed if e.event == "content_block_stop"} == {0, 1}
diff --git a/tests/providers/test_anthropic_messages_429_retry.py b/tests/providers/test_anthropic_messages_429_retry.py
new file mode 100644
index 0000000..e64bbb7
--- /dev/null
+++ b/tests/providers/test_anthropic_messages_429_retry.py
@@ -0,0 +1,125 @@
+"""Native Anthropic transport: HTTP 429 is retried inside execute_with_retry."""
+
+from contextlib import asynccontextmanager
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import httpx
+import pytest
+
+from providers.base import ProviderConfig
+from providers.rate_limit import GlobalRateLimiter
+from tests.providers.test_anthropic_messages import (
+    FakeResponse,
+    MockRequest,
+    NativeProvider,
+)
+from tests.stream_contract import assert_canonical_stream_error_envelope
+
+
+@pytest.fixture
+def provider_config():
+    return ProviderConfig(
+        api_key="test-key",
+        base_url="https://custom.test/v1/",
+        rate_limit=100,
+        rate_window=60,
+        http_read_timeout=600.0,
+        http_write_timeout=15.0,
+        http_connect_timeout=5.0,
+    )
+
+
+@pytest.mark.asyncio
+async def test_native_stream_retries_on_http_429_then_streams(provider_config):
+    """First response 429 (closed), second 200 streams; send is called twice."""
+    GlobalRateLimiter.reset_instance()
+    try:
+        provider = NativeProvider(provider_config)
+        req = MockRequest()
+        request_obj = httpx.Request("POST", "https://custom.test/v1/messages")
+        ok_lines = [
+            "event: message_start",
+            'data: {"type":"message_start"}',
+            "",
+        ]
+        ok_response = FakeResponse(lines=ok_lines)
+        too_many = FakeResponse(status_code=429, text="rate limited")
+
+        send_calls = {"n": 0}
+
+        async def send_side_effect(*_a, **_kw):
+            send_calls["n"] += 1
+            if send_calls["n"] == 1:
+                return too_many
+            return ok_response
+
+        with (
+            patch.object(provider._client, "build_request", return_value=request_obj),
+            patch.object(
+                provider._client,
+                "send",
+                new_callable=AsyncMock,
+                side_effect=send_side_effect,
+            ),
+            patch(
+                "asyncio.sleep",
+                new_callable=AsyncMock,
+            ),
+        ):
+            events = [e async for e in provider.stream_response(req)]
+
+        assert send_calls["n"] == 2
+        assert too_many.is_closed
+        assert ok_response.is_closed
+        assert events == [
+            "event: message_start\n",
+            'data: {"type":"message_start"}\n',
+            "\n",
+        ]
+    finally:
+        GlobalRateLimiter.reset_instance()
+
+
+@pytest.mark.asyncio
+async def test_non_429_http_error_not_retried(provider_config):
+    """HTTP 500 from upstream is not retried; single send."""
+    GlobalRateLimiter.reset_instance()
+    try:
+
+        @asynccontextmanager
+        async def _slot():
+            yield
+
+        with patch("providers.anthropic_messages.GlobalRateLimiter") as mock_gl:
+            instance = mock_gl.get_scoped_instance.return_value
+
+            async def _passthrough(fn, *args, **kwargs):
+                return await fn(*args, **kwargs)
+
+            instance.execute_with_retry = AsyncMock(side_effect=_passthrough)
+            instance.concurrency_slot.side_effect = _slot
+
+            provider = NativeProvider(provider_config)
+            req = MockRequest()
+            err = FakeResponse(status_code=500, text="Internal Server Error")
+
+            with (
+                patch.object(
+                    provider._client, "build_request", return_value=MagicMock()
+                ),
+                patch.object(
+                    provider._client,
+                    "send",
+                    new_callable=AsyncMock,
+                    return_value=err,
+                ) as mock_send,
+            ):
+                events = [e async for e in provider.stream_response(req)]
+
+            mock_send.assert_awaited_once()
+            assert err.is_closed
+            assert_canonical_stream_error_envelope(
+                events, user_message_substr="Provider API request failed"
+            )
+    finally:
+        GlobalRateLimiter.reset_instance()
diff --git a/tests/providers/test_converter.py b/tests/providers/test_converter.py
index 7312fbd..99aeba6 100644
--- a/tests/providers/test_converter.py
+++ b/tests/providers/test_converter.py
@@ -2,7 +2,12 @@
 import json

 import pytest

-from core.anthropic import AnthropicToOpenAIConverter
+from api.models.anthropic import MessagesRequest
+from core.anthropic import (
+    AnthropicToOpenAIConverter,
+    OpenAIConversionError,
+    build_base_request_body,
+)

 # --- Mock Classes ---
@@ -468,3 +473,102 @@ def test_convert_multiple_tool_results():
    assert len(result) == 2
    assert result[0]["tool_call_id"] == "t1"
    assert result[1]["tool_call_id"] == "t2"
+
+
+def test_convert_user_message_tool_result_dict_as_json():
+    content = [
+        MockBlock(
+            type="tool_result",
+            tool_use_id="t_dict",
+            content={"ok": True, "count": 2},
+        ),
+    ]
+    messages = [MockMessage("user", content)]
+    result = AnthropicToOpenAIConverter.convert_messages(messages)
+    assert result[0]["role"] == "tool"
+    assert result[0]["content"] == '{"ok": true, "count": 2}'
+
+
+def test_assistant_redacted_thinking_omitted_from_openai_chat():
+    """Opaque redacted_thinking is not materialized as content or reasoning_content."""
+    content = [
+        MockBlock(type="redacted_thinking", data="secret-opaque"),
+        MockBlock(type="text", text="Visible."),
+    ]
+    messages = [MockMessage("assistant", content)]
+    result = AnthropicToOpenAIConverter.convert_messages(
+        messages, include_thinking=True, include_reasoning_content=True
+    )
+    assert result[0]["content"] == "Visible."
+    assert "secret-opaque" not in result[0]["content"]
+    assert "reasoning_content" not in result[0]
+
+
+def test_convert_user_message_image_raises():
+    content = [
+        MockBlock(type="image", source={"type": "url", "url": "https://example.com/x"})
+    ]
+    messages = [MockMessage("user", content)]
+    with pytest.raises(OpenAIConversionError):
+        AnthropicToOpenAIConverter.convert_messages(messages)
+
+
+def test_convert_assistant_text_after_tool_use_raises():
+    content = [
+        MockBlock(type="tool_use", id="call_z", name="Read", input={}),
+        MockBlock(type="text", text="Illegal after tool"),
+    ]
+    messages = [MockMessage("assistant", content)]
+    with pytest.raises(OpenAIConversionError):
+        AnthropicToOpenAIConverter.convert_messages(messages)
+
+
+def test_openai_build_accepts_declared_native_top_level_hints() -> None:
+    """OpenAI conversion ignores known non-OpenAI hints (e.g. context_management) without 400."""
+    req = MessagesRequest.model_validate(
+        {
+            "model": "m",
+            "messages": [{"role": "user", "content": "h"}],
+            "context_management": {"edits": []},
+            "output_config": {"foo": 1},
+            "mcp_servers": [{"type": "url", "url": "https://x.com"}],
+        }
+    )
+    body = build_base_request_body(req, default_max_tokens=100)
+    assert "context_management" not in body
+    assert "output_config" not in body
+    assert "mcp_servers" not in body
+    assert body["model"] == "m"
+
+
+def test_openai_build_rejects_unknown_top_level_extras() -> None:
+    """Truly unknown keys must still be rejected (not dropped silently)."""
+    req = MessagesRequest.model_validate(
+        {
+            "model": "m",
+            "messages": [{"role": "user", "content": "h"}],
+            "experimental_client_only_passthrough": True,
+        }
+    )
+    with pytest.raises(OpenAIConversionError, match="top-level request fields"):
+        build_base_request_body(req)
+
+
+@pytest.mark.parametrize(
+    "content",
+    [
+        [MockBlock(type="server_tool_use", id="1", name="web_search", input={})],
+        [MockBlock(type="web_search_tool_result", tool_use_id="1", content=[])],
+        [
+            MockBlock(
+                type="web_fetch_tool_result",
+                tool_use_id="1",
+                content={"type": "web_fetch_result", "url": "https://a.com/x"},
+            )
+        ],
+    ],
+)
+def test_convert_assistant_server_tool_blocks_raise(content) -> None:
+    messages = [MockMessage("assistant", content)]
+    with pytest.raises(OpenAIConversionError, match="server tool"):
+        AnthropicToOpenAIConverter.convert_messages(messages)
diff --git a/tests/providers/test_deepseek.py b/tests/providers/test_deepseek.py
index bb00e0a..7a4d997 100644
--- a/tests/providers/test_deepseek.py
+++ b/tests/providers/test_deepseek.py
@@ -4,8 +4,10 @@ from unittest.mock import AsyncMock, MagicMock, patch

 import pytest

+from api.models.anthropic import ContentBlockImage, Message, MessagesRequest
 from providers.base import ProviderConfig
-from providers.deepseek import DEEPSEEK_BASE_URL, DeepSeekProvider
+from providers.deepseek import DEEPSEEK_DEFAULT_BASE, DeepSeekProvider
+from providers.exceptions import InvalidRequestError


 class MockMessage:
@@ -41,7 +43,7 @@ class MockRequest:
 def deepseek_config():
    return ProviderConfig(
        api_key="test_deepseek_key",
-        base_url=DEEPSEEK_BASE_URL,
+        base_url=DEEPSEEK_DEFAULT_BASE,
        rate_limit=10,
        rate_window=60,
        enable_thinking=True,
@@ -72,7 +74,7 @@ def test_init(deepseek_config):
    with patch("providers.openai_compat.AsyncOpenAI") as mock_openai:
        provider = DeepSeekProvider(deepseek_config)
        assert provider._api_key == "test_deepseek_key"
-        assert provider._base_url == DEEPSEEK_BASE_URL
+        assert provider._base_url == DEEPSEEK_DEFAULT_BASE
        mock_openai.assert_called_once()


@@ -91,7 +93,7 @@ def test_build_request_body_global_disable_blocks_request_thinking():
    provider = DeepSeekProvider(
        ProviderConfig(
            api_key="test_deepseek_key",
-            base_url=DEEPSEEK_BASE_URL,
+            base_url=DEEPSEEK_DEFAULT_BASE,
            rate_limit=10,
            rate_window=60,
            enable_thinking=False,
@@ -152,6 +154,67 @@ def test_build_request_body_preserves_reasoning_content(deepseek_provider):
    assert body["messages"][0]["reasoning_content"] == "First think"


+def test_build_request_body_disabled_thinking_omits_reasoning_and_thinking_tags():
+    """Resolved thinking policy must strip assistant thinking from OpenAI history."""
+    provider = DeepSeekProvider(
+        ProviderConfig(
+            api_key="test_deepseek_key",
+            base_url=DEEPSEEK_DEFAULT_BASE,
+            rate_limit=10,
+            rate_window=60,
+            enable_thinking=False,
+        )
+    )
+    req = MockRequest(
+        system=None,
+        model="deepseek-chat",
+        messages=[
+            MockMessage(
+                "assistant",
+                [
+                    MockBlock(type="thinking", thinking="secret"),
+                    MockBlock(type="text", text="hi"),
+                ],
+            )
+        ],
+    )
+
+    body = provider._build_request_body(req)
+
+    assistant = body["messages"][0]
+    assert "reasoning_content" not in assistant
+    assert "secret" not in assistant["content"]
+    assert assistant["content"] == "hi"
+
+
+def test_build_request_body_disabled_thinking_omits_redacted_blocks():
+    """redacted_thinking is not sent as OpenAI text when thinking is disabled."""
+    provider = DeepSeekProvider(
+        ProviderConfig(
+            api_key="test_deepseek_key",
+            base_url=DEEPSEEK_DEFAULT_BASE,
+            rate_limit=10,
+            rate_window=60,
+            enable_thinking=False,
+        )
+    )
+    req = MockRequest(
+        system=None,
+        model="deepseek-chat",
+        messages=[
+            MockMessage(
+                "assistant",
+                [
+                    MockBlock(type="redacted_thinking", data="opaque-xyz"),
+                    MockBlock(type="text", text="hi"),
+                ],
+            )
+        ],
+    )
+    body = provider._build_request_body(req)
+    assert "opaque-xyz" not in body["messages"][0]["content"]
+
+
 @pytest.mark.asyncio
 async def test_stream_response_reasoning_content(deepseek_provider):
    """reasoning_content deltas are emitted as thinking blocks."""
@@ -179,3 +242,37 @@ async def test_stream_response_reasoning_content(deepseek_provider):
    assert any(
        '"thinking_delta"' in event and "Thinking..." in event for event in events
    )
+
+
+def test_preflight_stream_rejects_unsupported_user_image_for_openai_conversion():
+    """Eager preflight: image block fails before a stream would be opened."""
+    request = MessagesRequest(
+        model="deepseek/deepseek-chat",
+        max_tokens=100,
+        messages=[
+            Message(
+                role="user",
+                content=[
+                    ContentBlockImage(
+                        type="image",
+                        source={
+                            "type": "base64",
+                            "media_type": "image/png",
+                            "data": "YQ==",
+                        },
+                    )
+                ],
+            )
+        ],
+    )
+    provider = DeepSeekProvider(
+        ProviderConfig(
+            api_key="k",
+            base_url=DEEPSEEK_DEFAULT_BASE,
+            rate_limit=10,
+            rate_window=60,
+        )
+    )
+    with pytest.raises(InvalidRequestError) as exc:
+        provider.preflight_stream(request, thinking_enabled=True)
+    assert "image" in str(exc.value).lower()
diff --git a/tests/providers/test_error_mapping.py b/tests/providers/test_error_mapping.py
index 162b05d..2a180e7 100644
--- a/tests/providers/test_error_mapping.py
+++ b/tests/providers/test_error_mapping.py
@@ -1,13 +1,21 @@
 """Tests for provider error mapping and core error formatting."""

+from pathlib import Path
 from unittest.mock import MagicMock, patch

 import openai
 import pytest
 from httpx import ReadTimeout, Request, Response

-from core.anthropic import append_request_id, get_user_facing_error_message
-from providers.error_mapping import map_error
+from core.anthropic import (
+    append_request_id,
+    format_user_error_preview,
+    get_user_facing_error_message,
+)
+from providers.error_mapping import (
+    map_error,
+    user_visible_message_for_mapped_provider_error,
+)
 from providers.exceptions import (
    APIError,
    AuthenticationError,
@@ -101,27 +109,6 @@ class TestMapError:
        result = map_error(exc)
        assert result is exc

-    @pytest.mark.parametrize(
-        "exc_cls,expected_cls",
-        [
-            (openai.AuthenticationError, AuthenticationError),
-            (openai.RateLimitError, RateLimitError),
-            (openai.BadRequestError, InvalidRequestError),
-        ],
-        ids=["auth", "rate_limit", "bad_request"],
-    )
-    def test_mapping_parametrized(self, exc_cls, expected_cls):
-        """Parametrized check of openai -> provider error mapping."""
-        status_map = {
-            openai.AuthenticationError: 401,
-            openai.RateLimitError: 429,
-            openai.BadRequestError: 400,
-        }
-        exc = _make_openai_error(exc_cls, status_code=status_map[exc_cls])
-        with patch("providers.error_mapping.GlobalRateLimiter"):
-            result = map_error(exc)
-            assert isinstance(result, expected_cls)
-

 def test_user_facing_message_read_timeout_empty_string():
    """ReadTimeout wrapping TimeoutError should still produce readable text."""
@@ -134,3 +121,35 @@ def test_append_request_id_suffix():
    """Request id suffix should be appended deterministically."""
    message = append_request_id("Provider request failed.", "req_abc123")
    assert message == "Provider request failed. (request_id=req_abc123)"
+
+
+def test_user_facing_message_bad_request_prefers_mapped_text_over_sdk_string():
+    """BadRequestError should map to stable wording even when str(exc) is non-empty."""
+    exc = _make_openai_error(
+        openai.BadRequestError, message="leaky-upstream-detail", status_code=400
+    )
+    assert get_user_facing_error_message(exc) == "Invalid request sent to provider."
+
+
+def test_format_user_error_preview_truncates():
+    exc = ValueError("x" * 500)
+    preview = format_user_error_preview(exc, max_len=20)
+    assert len(preview) == 20
+    assert preview == "x" * 20
+
+
+def test_user_visible_message_for_mapped_provider_error_405():
+    mapped = APIError("ignored", status_code=405, raw_error="")
+    msg = user_visible_message_for_mapped_provider_error(
+        mapped, provider_name="ACME", read_timeout_s=30.0
+    )
+    assert "ACME" in msg and "405" in msg
+
+
+def test_streaming_transports_pass_scoped_rate_limiter_to_map_error():
+    """Guardrail: streaming adapters must scope reactive 429 handling per provider."""
+    root = Path(__file__).resolve().parents[2]
+    for name in ("anthropic_messages.py", "openai_compat.py"):
+        text = (root / "providers" / name).read_text(encoding="utf-8")
+        assert "map_error(" in text, name
+        assert "rate_limiter=self._global_rate_limiter" in text, name
diff --git a/tests/providers/test_llamacpp.py b/tests/providers/test_llamacpp.py
index 04e0874..df87069 100644
--- a/tests/providers/test_llamacpp.py
+++ b/tests/providers/test_llamacpp.py
@@ -5,8 +5,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
 import httpx
 import pytest

+from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
 from providers.base import ProviderConfig
 from providers.llamacpp import LlamaCppProvider
+from tests.stream_contract import assert_canonical_stream_error_envelope


 class MockMessage:
@@ -221,7 +223,7 @@ async def test_stream_response_adds_max_tokens_if_missing(llamacpp_provider):
        [e async for e in llamacpp_provider.stream_response(req)]

    _, kwargs = mock_build.call_args
-    assert kwargs["json"]["max_tokens"] == 81920
+    assert kwargs["json"]["max_tokens"] == ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS


 @pytest.mark.asyncio
@@ -254,10 +256,10 @@ async def test_stream_error_status_code(llamacpp_provider):
        async for e in llamacpp_provider.stream_response(req, request_id="TEST_ID")
    ]

-    assert len(events) == 1
-    assert events[0].startswith("event: error\ndata: {")
-    assert "Internal Server Error" in events[0]
-    assert "TEST_ID" in events[0]
+    assert_canonical_stream_error_envelope(
+        events, user_message_substr="Provider API request failed"
+    )
+    assert "TEST_ID" in "".join(events)


 @pytest.mark.asyncio
@@ -281,10 +283,11 @@ async def test_stream_network_error(llamacpp_provider):
        async for e in llamacpp_provider.stream_response(req, request_id="TEST_ID2")
    ]

-    assert len(events) == 1
-    assert events[0].startswith("event: error\ndata: {")
-    assert "Connection refused" in events[0]
-    assert "TEST_ID2" in events[0]
+    blob = "".join(events)
+    assert_canonical_stream_error_envelope(
+        events, user_message_substr="Connection refused"
+    )
+    assert "TEST_ID2" in blob


 @pytest.mark.asyncio
@@ -315,8 +318,36 @@ async def test_stream_error_405_mentions_upstream_provider(llamacpp_provider):
        e async for e in llamacpp_provider.stream_response(req, request_id="REQ405")
    ]

+    blob = "".join(events)
    assert (
        "Upstream provider LLAMACPP rejected the request method or endpoint (HTTP 405)."
-        in events[0]
+        in blob
    )
-    assert "REQ405" in events[0]
+    assert "REQ405" in blob
+
+
+def test_build_request_body_disabled_thinking_strips_native_thinking_history(
+    llamacpp_config,
+):
+    """With thinking disabled, prior assistant thinking/redacted blocks are omitted."""
+    config = llamacpp_config.model_copy(update={"enable_thinking": False})
+    provider = LlamaCppProvider(config)
+    messages = [
+        MockMessage("user", "Hi"),
+        MockMessage(
+            "assistant",
+            [
+                {"type": "thinking", "thinking": "p"},
+                {"type": "redacted_thinking", "data": "ZGF0YQ=="},
+            ],
+        ),
+    ]
+    req = MockRequest(
+        system=None,
+        messages=messages,
+    )
+    body = provider._build_request_body(req, thinking_enabled=False)
+    asst = body["messages"][1]
+    assert asst["content"] == ""
+    assert "thinking" not in str(body)
+    assert "redacted_thinking" not in str(body)
diff --git a/tests/providers/test_lmstudio.py b/tests/providers/test_lmstudio.py
index f8c4dee..06aca7a 100644
--- a/tests/providers/test_lmstudio.py
+++ b/tests/providers/test_lmstudio.py
@@ -5,8 +5,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
 import httpx
 import pytest

+from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
 from providers.base import ProviderConfig
 from providers.lmstudio import LMStudioProvider
+from tests.stream_contract import assert_canonical_stream_error_envelope


 class MockMessage:
@@ -252,7 +254,7 @@ async def test_stream_response_adds_max_tokens_if_missing(lmstudio_provider):
        [e async for e in lmstudio_provider.stream_response(req)]

    _, kwargs = mock_build.call_args
-    assert kwargs["json"]["max_tokens"] == 81920
+    assert kwargs["json"]["max_tokens"] == ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS


 @pytest.mark.asyncio
@@ -285,10 +287,10 @@ async def test_stream_error_status_code(lmstudio_provider):
        async for e in lmstudio_provider.stream_response(req, request_id="TEST_ID")
    ]

-    assert len(events) == 1
-    assert events[0].startswith("event: error\ndata: {")
-    assert "Internal Server Error" in events[0]
-    assert "TEST_ID" in events[0]
+    assert_canonical_stream_error_envelope(
+        events, user_message_substr="Provider API request failed"
+    )
+    assert "TEST_ID" in "".join(events)


 @pytest.mark.asyncio
@@ -312,10 +314,11 @@ async def test_stream_network_error(lmstudio_provider):
        async for e in lmstudio_provider.stream_response(req, request_id="TEST_ID2")
    ]

-    assert len(events) == 1
-    assert events[0].startswith("event: error\ndata: {")
-    assert "Connection refused" in events[0]
-    assert "TEST_ID2" in events[0]
+    blob = "".join(events)
+    assert_canonical_stream_error_envelope(
+        events, user_message_substr="Connection refused"
+    )
+    assert "TEST_ID2" in blob


 @pytest.mark.asyncio
@@ -346,8 +349,61 @@ async def test_stream_error_405_mentions_upstream_provider(lmstudio_provider):
        e async for e in lmstudio_provider.stream_response(req, request_id="REQ405")
    ]

+    blob = "".join(events)
    assert (
        "Upstream provider LMSTUDIO rejected the request method or endpoint (HTTP 405)."
-        in events[0]
+        in blob
    )
-    assert "REQ405" in events[0]
+    assert "REQ405" in blob
+
+
+def test_build_request_body_disabled_thinking_strips_native_thinking_history(
+    lmstudio_config,
+):
+    """Disabled thinking must omit prior assistant thinking/redacted blocks in JSON."""
+    provider = LMStudioProvider(
+        lmstudio_config.model_copy(update={"enable_thinking": False})
+    )
+    req = MockRequest(
+        system=None,
+        messages=[
+            MockMessage("user", "hi"),
+            MockMessage(
+                "assistant",
+                [
+                    {"type": "thinking", "thinking": "t"},
+                    {"type": "redacted_thinking", "data": "opaque"},
+                ],
+            ),
+        ],
+    )
+    body = provider._build_request_body(req, thinking_enabled=False)
+    assert body["messages"][1]["content"] == ""
+    assert "redacted_thinking" not in str(body)
+
+
+def test_build_request_body_preserves_signed_thinking_native_history(lmstudio_config):
+    """When thinking is enabled, signed thinking blocks are kept; unsigned stripped."""
+    provider = LMStudioProvider(lmstudio_config)
+    req = MockRequest(
+        system=None,
+        messages=[
+            MockMessage("user", "hi"),
+            MockMessage(
+                "assistant",
+                [
+                    {
+                        "type": "thinking",
+                        "thinking": "signed",
+                        "signature": "sig",
+                    },
+                    {"type": "redacted_thinking", "data": "opaque"},
+                ],
+            ),
+        ],
+    )
+    body = provider._build_request_body(req, thinking_enabled=True)
+    c = body["messages"][1]["content"]
+    assert isinstance(c, list)
+    assert any(isinstance(b, dict) and b.get("type") == "thinking" for b in c)
+    assert any(isinstance(b, dict) and b.get("type") == "redacted_thinking" for b in c)
diff --git a/tests/providers/test_nim_request_clone.py b/tests/providers/test_nim_request_clone.py
new file mode 100644
index 0000000..30260c1
--- /dev/null
+++ b/tests/providers/test_nim_request_clone.py
@@ -0,0 +1,40 @@
+"""Tests for NVIDIA NIM request body cloning helpers."""
+
+from copy import deepcopy
+
+from providers.nvidia_nim.request import clone_body_without_reasoning_budget
+
+
+def test_clone_body_without_reasoning_budget_strips_top_level_and_nested():
+    body: dict = {
+        "model": "x",
+        "extra_body": {
+            "reasoning_budget": 99,
+            "chat_template_kwargs": {"reasoning_budget": 42, "thinking": True},
+            "top_k": 1,
+        },
+    }
+    original_extra = deepcopy(body["extra_body"])
+    out = clone_body_without_reasoning_budget(body)
+
+    assert out is not None
+    assert out["extra_body"]["chat_template_kwargs"] == {"thinking": True}
+    assert "reasoning_budget" not in out["extra_body"]
+    assert body["extra_body"] == original_extra
+
+
+def test_clone_body_without_reasoning_budget_returns_none_when_unchanged():
+    body = {"model": "x", "extra_body": {"top_k": 3}}
+    assert clone_body_without_reasoning_budget(body) is None
+
+
+def test_clone_body_without_reasoning_budget_returns_none_without_extra_body():
+    assert clone_body_without_reasoning_budget({"model": "y"}) is None
+
+
+def test_clone_body_drops_empty_extra_body_after_strip():
+    body = {"model": "z", "extra_body": {"reasoning_budget": 7}}
+    out = clone_body_without_reasoning_budget(body)
+    assert out is not None
+    assert "extra_body" not in out
+    assert "extra_body" in body
diff --git a/tests/providers/test_nvidia_nim.py b/tests/providers/test_nvidia_nim.py
index dcacbfd..4c190ed 100644
--- a/tests/providers/test_nvidia_nim.py
+++ b/tests/providers/test_nvidia_nim.py
@@ -5,6 +5,8 @@ import openai
 import pytest
 from httpx import Request, Response

+from config.nim import NimSettings
+from providers.defaults import NVIDIA_NIM_DEFAULT_BASE
 from providers.nvidia_nim import NvidiaNimProvider


@@ -42,7 +44,7 @@ class MockRequest:
 def _make_bad_request_error(message: str) -> openai.BadRequestError:
    response = Response(
        status_code=400,
-        request=Request("POST", "https://integrate.api.nvidia.com/v1/chat/completions"),
+        request=Request("POST", f"{NVIDIA_NIM_DEFAULT_BASE}/chat/completions"),
    )
    body = {"error": {"message": message, "type": "BadRequestError", "code": 400}}
    return openai.BadRequestError(message, response=response, body=body)
@@ -67,8 +69,6 @@ def mock_rate_limiter():
 async def test_init(provider_config):
    """Test provider initialization."""
    with patch("providers.openai_compat.AsyncOpenAI") as mock_openai:
-        from config.nim import NimSettings
-
        provider = NvidiaNimProvider(provider_config, nim_settings=NimSettings())
        assert provider._api_key == "test_key"
        assert provider._base_url == "https://test.api.nvidia.com/v1"
@@ -78,7 +78,6 @@ async def test_init(provider_config):
 @pytest.mark.asyncio
 async def test_init_uses_configurable_timeouts():
    """Test that provider passes configurable read/write/connect timeouts to client."""
-    from config.nim import NimSettings
    from providers.base import ProviderConfig

    config = ProviderConfig(
@@ -100,8 +99,6 @@ async def test_init_uses_configurable_timeouts():
 @pytest.mark.asyncio
 async def test_build_request_body(provider_config):
    """Test request body construction."""
-    from config.nim import NimSettings
-
    provider = NvidiaNimProvider(provider_config, nim_settings=NimSettings())
    req = MockRequest()
    body = provider._build_request_body(req)
@@ -124,8 +121,6 @@ async def test_build_request_body(provider_config):
 async def test_build_request_body_omits_reasoning_when_globally_disabled(
    provider_config,
 ):
-    from config.nim import NimSettings
-
    provider = NvidiaNimProvider(
        provider_config.model_copy(update={"enable_thinking": False}),
        nim_settings=NimSettings(),
@@ -142,8 +137,6 @@ async def test_build_request_body_omits_reasoning_when_globally_disabled(
 async def test_build_request_body_omits_reasoning_when_request_disables_thinking(
    provider_config,
 ):
-    from config.nim import NimSettings
-
    provider = NvidiaNimProvider(provider_config, nim_settings=NimSettings())
    req = MockRequest()
    req.thinking.enabled = False
@@ -241,8 +234,6 @@ async def test_stream_response_thinking_reasoning_content(nim_provider):

 @pytest.mark.asyncio
 async def test_stream_response_suppresses_thinking_when_disabled(provider_config):
-    from config.nim import NimSettings
-
    provider = NvidiaNimProvider(
        provider_config.model_copy(update={"enable_thinking": False}),
        nim_settings=NimSettings(),
@@ -285,8 +276,6 @@ def _make_bad_request_error(message: str) -> openai.BadRequestError:

 @pytest.mark.asyncio
 async def test_stream_response_retries_without_chat_template(provider_config):
-    from config.nim import NimSettings
-
    provider = NvidiaNimProvider(
        provider_config,
        nim_settings=NimSettings(chat_template="custom_template"),
@@ -344,8 +333,6 @@ async def test_stream_response_retries_without_chat_template(provider_config):

 @pytest.mark.asyncio
 async def test_stream_response_does_not_retry_unrelated_bad_request(provider_config):
-    from config.nim import NimSettings
-
    provider = NvidiaNimProvider(
        provider_config,
        nim_settings=NimSettings(chat_template="custom_template"),
@@ -361,7 +348,7 @@ async def test_stream_response_does_not_retry_unrelated_bad_request(provider_con
    assert mock_create.await_count == 1
    event_text = "".join(events)
-    assert "unrelated bad request" in event_text
+    assert "Invalid request sent to provider" in event_text
    assert "event: message_stop" in event_text


@@ -457,5 +444,5 @@ async def test_stream_response_bad_request_without_reasoning_budget_does_not_ret
    events = [e async for e in nim_provider.stream_response(req)]

    assert mock_create.await_count == 1
-    assert any("Unsupported field: top_k" in event for event in events)
+    assert any("Invalid request sent to provider" in event for event in events)
    assert any("message_stop" in event for event in events)
diff --git a/tests/providers/test_ollama.py b/tests/providers/test_ollama.py
index 1edeb3b..ffb90a7 100644
--- a/tests/providers/test_ollama.py
+++ b/tests/providers/test_ollama.py
@@ -6,7 +6,8 @@ import httpx
 import pytest

 from providers.base import ProviderConfig
-from providers.ollama import OLLAMA_BASE_URL, OllamaProvider
+from providers.ollama import OLLAMA_DEFAULT_BASE, OllamaProvider
+from tests.stream_contract import assert_canonical_stream_error_envelope


 class MockMessage:
@@ -92,7 +93,7 @@ def test_init_uses_default_base_url():
    config = ProviderConfig(api_key="ollama", base_url=None)
    with patch("httpx.AsyncClient"):
        provider = OllamaProvider(config)
-    assert provider._base_url == OLLAMA_BASE_URL
+    assert provider._base_url == OLLAMA_DEFAULT_BASE


 def test_init_uses_configurable_timeouts():
@@ -199,6 +200,30 @@ async def test_build_request_body_omits_thinking_when_disabled(ollama_config):
    assert body["model"] == "llama3.1:8b"


+def test_build_request_body_disabled_thinking_strips_assistant_thinking_blocks(
+    ollama_config,
+):
+    """Prior assistant thinking/redacted blocks are removed when policy is off."""
+    provider = OllamaProvider(
+        ollama_config.model_copy(update={"enable_thinking": False})
+    )
+    req = MockRequest(
+        system=None,
+        messages=[
+            MockMessage("user", "hi"),
+            MockMessage(
+                "assistant",
+                [
+                    {"type": "thinking", "thinking": "t"},
+                    {"type": "redacted_thinking", "data": "opaque"},
+                ],
+            ),
+        ],
+    )
+    body = provider._build_request_body(req, thinking_enabled=False)
+    assert body["messages"][1]["content"] == ""
+
+
 @pytest.mark.asyncio
 async def test_stream_error_status_code(ollama_provider):
    """Non-200 status code is yielded as an SSE API error."""
@@ -228,10 +253,10 @@ async def test_stream_error_status_code(ollama_provider):
        async for event in ollama_provider.stream_response(req, request_id="REQ")
    ]

-    assert len(events) == 1
-    assert events[0].startswith("event: error\ndata: {")
-    assert "Internal Server Error" in events[0]
-    assert "REQ" in events[0]
+    assert_canonical_stream_error_envelope(
+        events, user_message_substr="Provider API request failed"
+    )
+    assert "REQ" in "".join(events)


 @pytest.mark.asyncio
diff --git a/tests/providers/test_open_router.py b/tests/providers/test_open_router.py
index d88b2c6..7a6da14 100644
--- a/tests/providers/test_open_router.py
+++ b/tests/providers/test_open_router.py
@@ -6,6 +6,12 @@ from unittest.mock import AsyncMock, MagicMock, patch

 import pytest

+from core.anthropic.stream_contracts import (
+    assert_anthropic_stream_contract,
+    parse_sse_text,
+    text_content,
+    thinking_content,
+)
 from providers.base import ProviderConfig
 from providers.open_router import OpenRouterProvider
 from providers.open_router.request import OPENROUTER_DEFAULT_MAX_TOKENS
@@ -30,8 +36,6 @@ class MockRequest:
        self.tool_choice = None
        self.metadata = None
        self.extra_body = {}
-        self.original_model = "claude-3-sonnet"
-        self.resolved_provider_model = "open_router/stepfun/step-3.5-flash:free"
        self.thinking = MagicMock()
        self.thinking.enabled = True
        for k, v in kwargs.items():
@@ -135,8 +139,26 @@ def test_build_request_body_is_native_anthropic(open_router_provider):
    assert body["system"] == "System prompt"
    assert body["reasoning"] == {"enabled": True}
    assert "extra_body" not in body
-    assert "original_model" not in body
-    assert "resolved_provider_model" not in body
+
+
+def test_openrouter_extra_body_rejects_overriding_reserved_fields() -> None:
+    from providers.exceptions import InvalidRequestError
+    from providers.open_router.request import build_request_body
+
+    r = MockRequest()
+    r.extra_body = {"model": "hijack"}
+    with pytest.raises(InvalidRequestError, match="model"):
+        build_request_body(r, thinking_enabled=True)
+
+
+def test_openrouter_extra_body_allows_openrouter_only_keys() -> None:
+    from providers.open_router.request import build_request_body
+
+    r = MockRequest()
+    r.extra_body = {"transforms": ["no-web"], "plugins": []}
+    body = build_request_body(r, thinking_enabled=False)
+    assert body["transforms"] == ["no-web"]
+    assert body["plugins"] == []


 def test_build_request_body_omits_reasoning_when_globally_disabled(
@@ -188,7 +210,6 @@ def test_build_request_body_default_max_tokens(open_router_provider):
    body = open_router_provider._build_request_body(req)

    assert body["max_tokens"] == OPENROUTER_DEFAULT_MAX_TOKENS
-    assert body["max_tokens"] == 81920


 def test_build_request_body_strips_unsigned_thinking_history(open_router_provider):
@@ -209,7 +230,32 @@ def test_build_request_body_strips_unsigned_thinking_history(open_router_provide

    body = open_router_provider._build_request_body(req)

-    assert body["messages"][1]["content"] == [{"type": "text", "text": "Hello"}]
+    assert body["messages"][1]["content"] == [
+        {"type": "redacted_thinking", "data": "opaque"},
+        {"type": "text", "text": "Hello"},
+    ]
+
+
+def test_build_request_body_strips_redacted_when_thinking_disabled(
+    open_router_config,
+):
+    """Disabled thinking must remove all assistant thinking history including redacted."""
+    provider = OpenRouterProvider(
+        open_router_config.model_copy(update={"enable_thinking": False})
+    )
+    req = MockRequest(
+        messages=[
+            MockMessage(
+                "assistant",
+                [
+                    {"type": "redacted_thinking", "data": "opaque"},
+                    {"type": "text", "text": "Hi"},
+                ],
+            )
+        ]
+    )
+    body = provider._build_request_body(req)
+    assert body["messages"][0]["content"] == [{"type": "text", "text": "Hi"}]


 def test_build_request_body_preserves_signed_thinking_history(open_router_provider):
@@ -336,12 +382,12 @@ async def test_stream_response_suppresses_native_thinking_when_disabled(
    assert "Answer" in event_text

    text_start = next(event for event in events if "content_block_start" in event)
-    payload = json.loads(text_start.split("data: ", 1)[1])
+    payload = parse_sse_text(text_start)[0].data
    assert payload["index"] == 0


 @pytest.mark.asyncio
-async def test_stream_response_drops_redacted_thinking_when_enabled(
+async def test_stream_response_preserves_redacted_thinking_when_enabled(
    open_router_provider,
 ):
    response = FakeResponse(
@@ -378,16 +424,216 @@
    ):
        events = [e async for e in open_router_provider.stream_response(MockRequest())]

+    event_text = "".join(events)
+    assert "redacted_thinking" in event_text
+    assert "opaque" in event_text
+    assert "Answer" in event_text
+
+    parsed = parse_sse_text(event_text)
+    first_start = next(
+        p
+        for p in parsed
+        if p.event == "content_block_start"
+        and p.data.get("content_block", {}).get("type") == "redacted_thinking"
+    )
+    assert first_start.data["index"] == 0
+
+
+@pytest.mark.asyncio
+async def test_stream_response_drops_redacted_thinking_when_disabled(
+    open_router_config,
+):
+    provider = OpenRouterProvider(
+        open_router_config.model_copy(update={"enable_thinking": False})
+    )
+    response = FakeResponse(
+        lines=[
+            "event: content_block_start",
+            'data: {"type":"content_block_start","index":0,"content_block":{"type":"redacted_thinking","data":"opaque"}}',
+            "",
+            "event: content_block_stop",
+            'data: {"type":"content_block_stop","index":0}',
+            "",
+            "event: content_block_start",
+            'data: {"type":"content_block_start","index":1,"content_block":{"type":"text","text":""}}',
+            "",
+            "event: content_block_delta",
+            'data: {"type":"content_block_delta","index":1,"delta":{"type":"text_delta","text":"Answer"}}',
+            "",
+            "event: content_block_stop",
+            'data: {"type":"content_block_stop","index":1}',
+            "",
+        ]
+    )
+
+    with (
+        patch.object(provider._client, "build_request"),
+        patch.object(
+            provider._client,
+            "send",
+            new_callable=AsyncMock,
+            return_value=response,
+        ),
+    ):
+        events = [e async for e in provider.stream_response(MockRequest())]
+
    event_text = "".join(events)
    assert "redacted_thinking" not in event_text
+    assert "opaque" not in event_text
    assert "Answer" in event_text

    start_event = next(event for event in events if "content_block_start" in event)
-    payload = json.loads(start_event.split("data: ", 1)[1])
+    payload = parse_sse_text(start_event)[0].data
    assert payload["index"] == 0
    assert payload["content_block"]["type"] == "text"


+@pytest.mark.asyncio
+async def test_stream_response_reopens_interleaved_thinking_after_text(
+    open_router_provider,
+):
+    """Interleaved thinking + text + more thinking: downstream indices must not reuse closed blocks."""
+    response = FakeResponse(
+        lines=[
+            "event: message_start",
+            'data: {"type":"message_start","message":{}}',
+            "",
+            "event: content_block_start",
+            'data: {"type":"content_block_start","index":0,'
+            '"content_block":{"type":"thinking","thinking":"","signature":""}}',
+            "",
+            "event: content_block_delta",
+            'data: {"type":"content_block_delta","index":0,'
+            '"delta":{"type":"thinking_delta","thinking":"first"}}',
+            "",
+            "event: content_block_start",
+            'data: {"type":"content_block_start","index":1,'
+            '"content_block":{"type":"text","text":""}}',
+            "",
+            "event: content_block_delta",
+            'data: {"type":"content_block_delta","index":0,'
+            '"delta":{"type":"thinking_delta","thinking":" second"}}',
+            "",
+            "event: content_block_delta",
+            'data: {"type":"content_block_delta","index":1,'
+            '"delta":{"type":"text_delta","text":"Answer"}}',
+            "",
+            "event: content_block_stop",
+            'data: {"type":"content_block_stop","index":1}',
+            "",
+            "event: content_block_stop",
+            'data: {"type":"content_block_stop","index":0}',
+            "",
+            "event: message_stop",
+            'data: {"type":"message_stop"}',
+            "",
+        ]
+    )
+
+    with (
+        patch.object(open_router_provider._client, "build_request"),
+        patch.object(
+            open_router_provider._client,
+            "send",
+            new_callable=AsyncMock,
+            return_value=response,
+        ),
+    ):
+        events = [e async for e in open_router_provider.stream_response(MockRequest())]
+
+    parsed = parse_sse_text("".join(events))
+    assert_anthropic_stream_contract(parsed)
+    assert thinking_content(parsed) == "first second"
+    assert "Answer" in text_content(parsed)
+    stop_payloads = [
+        p.data
+        for p in parsed
+        if p.event == "content_block_stop"
+        and p.data.get("type") == "content_block_stop"
+    ]
+    seen_stop_indices: set[int] = set()
+    for s in stop_payloads:
+        idx = s.get("index")
+        assert isinstance(idx, int)
+        assert idx not in seen_stop_indices, "stop reused or duplicated index"
+        seen_stop_indices.add(idx)
+    # Two distinct thinking block indices: initial + reopened segment
+    think_starts = [
+        p
+        for p in parsed
+        if p.event == "content_block_start"
+        and p.data.get("content_block", {}).get("type") == "thinking"
+    ]
+    assert len(think_starts) == 2, (
+        "reopened thinking must have its own `content_block_start`"
+    )
+
+
+@pytest.mark.asyncio
+async def test_stream_response_reopened_tool_use_preserves_tool_identity(
+    open_router_provider,
+):
+    """After overlapping close, resumed input_json_delta must keep original tool id/name."""
+    lines: list[str] = []
+    for payload in (
+        {
+            "type": "content_block_start",
+            "index": 0,
+            "content_block": {
+                "type": "tool_use",
+                "id": "toolu_real_1",
+                "name": "Read",
+                "input": {},
+            },
+        },
+        {
+            "type": "content_block_delta",
+            "index": 0,
+            "delta": {"type": "input_json_delta", "partial_json": '{"path'},
+        },
+        {
+            "type": "content_block_start",
+            "index": 1,
+            "content_block": {"type": "text", "text": ""},
+        },
+        {
+            "type": "content_block_delta",
+            "index": 0,
+            "delta": {"type": "input_json_delta", "partial_json": '":"/tmp"}'},
+        },
+        {"type": "content_block_stop", "index": 1},
+        {"type": "content_block_stop", "index": 0},
+    ):
+        event_name = payload["type"]
+        lines.extend((f"event: {event_name}", f"data: {json.dumps(payload)}", ""))
+
+    response = FakeResponse(lines=lines)
+
+    with (
+        patch.object(open_router_provider._client, "build_request"),
+        patch.object(
+            open_router_provider._client,
+            "send",
+            new_callable=AsyncMock,
+            return_value=response,
+        ),
+    ):
+        events = [e async for e in open_router_provider.stream_response(MockRequest())]
+
+    parsed = parse_sse_text("".join(events))
+    tool_starts = [
+        p
+        for p in parsed
+        if p.event == "content_block_start"
+        and p.data.get("content_block", {}).get("type") == "tool_use"
+    ]
+    assert len(tool_starts) == 2
+    for start in tool_starts:
+        block = start.data["content_block"]
+        assert block["id"] == "toolu_real_1"
+        assert block["name"] == "Read"
+
+
 @pytest.mark.asyncio
 async def test_stream_response_closes_overlapping_thinking_before_text(
    open_router_provider,
diff --git a/tests/providers/test_provider_rate_limit.py b/tests/providers/test_provider_rate_limit.py
index 7cd00d9..7d59280 100644
--- a/tests/providers/test_provider_rate_limit.py
+++ b/tests/providers/test_provider_rate_limit.py
@@ -27,11 +27,13 @@ class TestProviderRateLimiter:
        GlobalRateLimiter.reset_instance()
        limiter = GlobalRateLimiter.get_instance(rate_limit=1, rate_window=0.25)

-        start_time = time.time()
+        start_time = time.monotonic()

        async def call_limiter():
            await limiter.wait_if_blocked()
-            return time.time()
+            return time.monotonic()

        # 5 requests.
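+        # (With rate_limit=1 and rate_window=0.25, the proactive limiter spaces
+        # requests 0.25s apart, which yields the schedule below.)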
        # R0 -> 0s
@@ -41,7 +41,7 @@ class TestProviderRateLimiter:
        # R4 -> 1.00s
        results = [await call_limiter() for _ in range(5)]

-        total_time = time.time() - start_time
+        total_time = time.monotonic() - start_time

        assert len(results) == 5
        # Should take at least ~1.0s
@@ -56,7 +56,7 @@ class TestProviderRateLimiter:
        GlobalRateLimiter.reset_instance()
        limiter = GlobalRateLimiter.get_instance()

-        start_time = time.time()
+        start_time = time.monotonic()

        # Manually block for 1.5s
        block_time = 1.5
@@ -71,7 +71,7 @@ class TestProviderRateLimiter:
        # They should both wait for the block time
        results = await asyncio.gather(call_limiter(), call_limiter())

-        total_time = time.time() - start_time
+        total_time = time.monotonic() - start_time

        # Both should report having waited reactively
        assert all(results) is True
@@ -121,10 +121,10 @@ class TestProviderRateLimiter:
        GlobalRateLimiter.reset_instance()
        limiter = GlobalRateLimiter.get_instance(rate_limit=10000, rate_window=60)

-        start = time.time()
+        start = time.monotonic()
        for _ in range(20):
            await limiter.wait_if_blocked()
-        duration = time.time() - start
+        duration = time.monotonic() - start

        # 20 requests with 10000 limit should be near-instant
        assert duration < 1.0, f"High rate limit caused throttling: {duration:.2f}s"
@@ -252,6 +252,32 @@ class TestProviderRateLimiter:
        assert result == "ok"
        assert call_count == 2

+    @pytest.mark.asyncio
+    async def test_execute_with_retry_succeeds_on_httpx_429(self):
+        """HTTP 429 as httpx.HTTPStatusError then success returns result."""
+        import httpx
+        from httpx import Request, Response
+
+        limiter = GlobalRateLimiter.get_instance(rate_limit=100, rate_window=60)
+
+        call_count = 0
+
+        async def fail_then_ok():
+            nonlocal call_count
+            call_count += 1
+            if call_count == 1:
+                r = Response(429, request=Request("POST", "http://x"), text="slow")
+                raise httpx.HTTPStatusError(
+                    "Too Many Requests", request=r.request, response=r
+                )
+            return "ok"
+
+        result = await limiter.execute_with_retry(
+            fail_then_ok, max_retries=2, base_delay=0.01, max_delay=0.1, jitter=0
+        )
+        assert result == "ok"
+        assert call_count == 2
+
    @pytest.mark.asyncio
    async def test_max_concurrency_zero_raises(self):
        """max_concurrency <= 0 raises ValueError."""
diff --git a/tests/providers/test_provider_transport_logging.py b/tests/providers/test_provider_transport_logging.py
new file mode 100644
index 0000000..5076930
--- /dev/null
+++ b/tests/providers/test_provider_transport_logging.py
@@ -0,0 +1,265 @@
+"""Tests for metadata-only provider transport logging by default."""
+
+import logging
+from contextlib import asynccontextmanager
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from config.constants import NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES
+from config.nim import NimSettings
+from providers.anthropic_messages import AnthropicMessagesTransport
+from providers.base import ProviderConfig
+from providers.nvidia_nim import NvidiaNimProvider
+from tests.provider_request_mocks import make_openai_compat_stream_request
+from tests.providers.test_anthropic_messages import (
+    FakeResponse,
+    MockRequest,
+    NativeProvider,
+)
+
+
+@pytest.fixture
+def provider_config():
+    return ProviderConfig(
+        api_key="test-key",
+        base_url="https://custom.test/v1/",
+        proxy="socks5://127.0.0.1:9999",
+        rate_limit=10,
+        rate_window=60,
+        http_read_timeout=600.0,
+        http_write_timeout=15.0,
+        http_connect_timeout=5.0,
+    )
+
+
+@pytest.fixture(autouse=True)
+def mock_rate_limiter():
+    @asynccontextmanager
+    async def _slot():
+        yield
+
+    with patch("providers.anthropic_messages.GlobalRateLimiter") as mock:
+        instance = mock.get_scoped_instance.return_value
+
+        async def _passthrough(fn, *args, **kwargs):
+            return await fn(*args, **kwargs)
+
+        instance.execute_with_retry = AsyncMock(side_effect=_passthrough)
+        instance.concurrency_slot.side_effect = _slot
+        yield instance
+
+
+@pytest.mark.asyncio
+async def test_native_non_200_logs_exclude_body_text_by_default(
+    caplog, provider_config
+):
+    provider = NativeProvider(provider_config)
+    req = MockRequest()
+    response = FakeResponse(status_code=500, text="SECRET_UPSTREAM_BODY")
+
+    with (
+        patch.object(provider._client, "build_request", return_value=MagicMock()),
+        patch.object(
+            provider._client,
+            "send",
+            new_callable=AsyncMock,
+            return_value=response,
+        ),
+        caplog.at_level(logging.ERROR),
+    ):
+        _ = [e async for e in provider.stream_response(req)]
+
+    messages = " | ".join(r.getMessage() for r in caplog.records)
+    assert "SECRET_UPSTREAM_BODY" not in messages
+    assert "HTTP 500" in messages
+    assert "body_preview_bytes=" not in messages
+
+
+@pytest.mark.asyncio
+async def test_native_non_200_logs_body_when_verbose(caplog, provider_config):
+    provider_config.log_api_error_tracebacks = True
+    provider = NativeProvider(provider_config)
+    req = MockRequest()
+    response = FakeResponse(status_code=500, text="SECRET_UPSTREAM_BODY")
+
+    with (
+        patch.object(provider._client, "build_request", return_value=MagicMock()),
+        patch.object(
+            provider._client,
+            "send",
+            new_callable=AsyncMock,
+            return_value=response,
+        ),
+        caplog.at_level(logging.ERROR),
+    ):
+        _ = [e async for e in provider.stream_response(req)]
+
+    messages = " | ".join(r.getMessage() for r in caplog.records)
+    assert "SECRET_UPSTREAM_BODY" in messages
+    assert "truncated=False" in messages
+
+
+@pytest.mark.asyncio
+async def test_native_non_200_verbose_logs_only_capped_error_body(
+    caplog, provider_config
+):
+    provider_config.log_api_error_tracebacks = True
+    provider = NativeProvider(provider_config)
+    req = MockRequest()
+    tail = "SECRET_TAIL_NOT_LOGGED"
+    huge = f"{'A' * (NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES + 50)}{tail}"
+    response = FakeResponse(status_code=500, text=huge)
+
+    with (
+        patch.object(provider._client, "build_request", return_value=MagicMock()),
+        patch.object(
+            provider._client,
+            "send",
+            new_callable=AsyncMock,
+            return_value=response,
+        ),
+        caplog.at_level(logging.ERROR),
+    ):
+        _ = [e async for e in provider.stream_response(req)]
+
+    messages = " | ".join(r.getMessage() for r in caplog.records)
+    assert "SECRET_TAIL_NOT_LOGGED" not in messages
+    assert "truncated=True" in messages
+    assert f"body_preview_bytes={NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES}" in messages
+
+
+@pytest.mark.asyncio
+async def test_native_non_200_default_does_not_read_oversized_body(
+    caplog, provider_config
+):
+    provider = NativeProvider(provider_config)
+    req = MockRequest()
+    huge = f"{'Z' * 500_000}LEAK_MARKER"
+    response = FakeResponse(status_code=500, text=huge)
+
+    with (
+        patch.object(provider._client, "build_request", return_value=MagicMock()),
+        patch.object(
+            provider._client,
+            "send",
+            new_callable=AsyncMock,
+            return_value=response,
+        ),
+        caplog.at_level(logging.ERROR),
+    ):
+        _ = [e async for e in provider.stream_response(req)]
+
+    messages = " | ".join(r.getMessage() for r in caplog.records)
+    assert "LEAK_MARKER" not in messages
+    assert "ZZZ" not in messages
+    assert "HTTP 500" in messages
+
+
+@pytest.mark.asyncio
+async def test_native_stream_failure_logs_exclude_exception_str_by_default(
+    caplog, provider_config
+):
+    provider = NativeProvider(provider_config)
+    req = MockRequest()
+    response = FakeResponse(
+        lines=[
+            "event: ping",
+            'data: {"type":"ping"}',
+            "",
+        ]
+    )
+
+    async def boom(_self, _response):
+        # The unreachable `yield` below marks this as an async generator, so the
+        # patched `_iter_sse_events` raises only once iteration begins.
+        raise RuntimeError("SECRET_DETAIL")
+        if False:
+            yield ""
+
+    with (
+        patch.object(provider._client, "build_request", return_value=MagicMock()),
+        patch.object(
+            provider._client,
+            "send",
+            new_callable=AsyncMock,
+            return_value=response,
+        ),
+        patch.object(AnthropicMessagesTransport, "_iter_sse_events", boom),
+        caplog.at_level(logging.ERROR),
+    ):
+        _ = [e async for e in provider.stream_response(req)]
+
+    messages = " | ".join(r.getMessage() for r in caplog.records)
+    assert "SECRET_DETAIL" not in messages
+    assert "exc_type=RuntimeError" in messages
+    assert "http_status=None" in messages
+
+
+@pytest.mark.asyncio
+async def test_openai_compat_stream_failure_default_logs_exclude_exception_str(caplog):
+    config = ProviderConfig(
+        api_key="k",
+        base_url="http://localhost:1/v1",
+        log_api_error_tracebacks=False,
+    )
+    provider = NvidiaNimProvider(config, nim_settings=NimSettings())
+    req = make_openai_compat_stream_request()
+
+    @asynccontextmanager
+    async def _noop_slot():
+        yield
+
+    with (
+        patch.object(
+            provider,
+            "_create_stream",
+            new_callable=AsyncMock,
+            side_effect=RuntimeError("SECRET_OPENAI_COMPAT"),
+        ),
+        patch.object(
+            provider._global_rate_limiter,
+            "concurrency_slot",
+            _noop_slot,
+        ),
+        caplog.at_level(logging.ERROR),
+    ):
+        _ = [e async for e in provider.stream_response(req)]
+
+    messages = " | ".join(r.getMessage() for r in caplog.records)
+    assert "SECRET_OPENAI_COMPAT" not in messages
+    assert "exc_type=RuntimeError" in messages
+
+
+@pytest.mark.asyncio
+async def test_openai_compat_stream_failure_respects_verbose_flag(caplog):
+    config = ProviderConfig(
+        api_key="k",
+        base_url="http://localhost:1/v1",
+        log_api_error_tracebacks=True,
+    )
+    provider = NvidiaNimProvider(config, nim_settings=NimSettings())
+    req = make_openai_compat_stream_request()
+
+    @asynccontextmanager
+    async def _noop_slot():
+        yield
+
+    with (
+        patch.object(
+            provider,
+            "_create_stream",
+            new_callable=AsyncMock,
+            side_effect=RuntimeError("SECRET_OPENAI_COMPAT"),
+        ),
+        patch.object(
+            provider._global_rate_limiter,
+            "concurrency_slot",
+            _noop_slot,
+        ),
+        caplog.at_level(logging.ERROR),
+    ):
+        _ = [e async for e in provider.stream_response(req)]
+
+    messages = " | ".join(r.getMessage() for r in caplog.records)
+    assert "SECRET_OPENAI_COMPAT" in messages
diff --git a/tests/providers/test_registry.py b/tests/providers/test_registry.py
index b00d780..ac8cbba 100644
--- a/tests/providers/test_registry.py
+++ b/tests/providers/test_registry.py
@@ -39,7 +39,7 @@ def _make_settings(**overrides):
    mock.provider_max_concurrency = 5
    mock.http_read_timeout = 300.0
    mock.http_write_timeout = 10.0
-    mock.http_connect_timeout = 2.0
+    mock.http_connect_timeout = 10.0
    mock.enable_model_thinking = True
    mock.nim = NimSettings()
    for key, value in overrides.items():
diff --git a/tests/providers/test_sse_builder.py b/tests/providers/test_sse_builder.py
index 07df2f3..01426e5 100644
--- a/tests/providers/test_sse_builder.py
+++ b/tests/providers/test_sse_builder.py
@@ -1,19 +1,20 @@
 """Tests for core.anthropic.sse."""

-import json
 from unittest.mock import patch

 import pytest

 from core.anthropic import ContentBlockManager, SSEBuilder, map_stop_reason
+from core.anthropic.sse import ToolCallState
+from
core.anthropic.stream_contracts import parse_sse_text


 def _parse_sse(sse_str: str) -> dict:
     """Parse an SSE event string into its data payload."""
-    for line in sse_str.strip().split("\n"):
-        if line.startswith("data: "):
-            return json.loads(line[len("data: ") :])
-    raise ValueError(f"No data line found in SSE: {sse_str}")
+    events = parse_sse_text(sse_str)
+    if len(events) != 1:
+        raise ValueError(f"expected 1 SSE event, got {len(events)} in {sse_str!r}")
+    return events[0].data


 class TestMapStopReason:
@@ -61,6 +62,22 @@ class TestContentBlockManager:
        assert mgr.text_started is False
        assert mgr.tool_states == {}

+    def test_flush_task_arg_buffers_logs_digest_not_secret(self, caplog):
+        """Invalid Task JSON warnings must not echo argument prefixes (secrets)."""
+        mgr = ContentBlockManager()
+        mgr.tool_states[0] = ToolCallState(
+            block_index=0, tool_id="call_x", name="Task", started=True
+        )
+        mgr.tool_states[0].task_arg_buffer = (
+            '{"api_key": "sk-live-super-secret-do-not-log"}not_valid_json'
+        )
+        with caplog.at_level("WARNING"):
+            out = mgr.flush_task_arg_buffers()
+        assert out == [(0, "{}")]
+        text = " | ".join(r.message for r in caplog.records)
+        assert "sk-live-super-secret" not in text
+        assert "buffer_sha256_prefix=" in text
+

 class TestSSEBuilderMessageLifecycle:
     """Tests for message_start, message_delta, message_stop."""
diff --git a/tests/providers/test_streaming_errors.py b/tests/providers/test_streaming_errors.py
index 0b5f614..085212b 100644
--- a/tests/providers/test_streaming_errors.py
+++ b/tests/providers/test_streaming_errors.py
@@ -1,14 +1,20 @@
 """Tests for streaming error handling in providers/nvidia_nim/client.py."""

 import json
+from types import SimpleNamespace
 from unittest.mock import AsyncMock, MagicMock, patch

 import httpx
 import pytest

 from config.nim import NimSettings
+from core.anthropic.stream_contracts import (
+    assert_anthropic_stream_contract,
+    parse_sse_text,
+)
 from providers.base import ProviderConfig
 from providers.nvidia_nim import NvidiaNimProvider
+from tests.provider_request_mocks import make_openai_compat_stream_request


 class AsyncStreamMock:
@@ -53,22 +59,7 @@ def _make_provider_with_thinking_enabled(enabled: bool):

 def _make_request(model="test-model", stream=True):
     """Create a mock request with all fields build_request_body needs."""
-    req = MagicMock()
-    req.model = model
-    req.stream = stream
-    req.messages = []
-    req.system = None
-    req.tools = None
-    req.tool_choice = None
-    req.metadata = None
-    req.max_tokens = 4096
-    req.temperature = None
-    req.top_p = None
-    req.top_k = None
-    req.stop_sequences = None
-    req.extra_body = None
-    req.thinking = None
-    return req
+    return make_openai_compat_stream_request(model=model, stream=stream)


 def _make_chunk(
@@ -95,6 +86,29 @@ async def _collect_stream(provider, request):
     return [e async for e in provider.stream_response(request)]


+def _assert_no_content_deltas_after_error_text(
+    events: list[str], error_substr: str
+) -> None:
+    """After the error text delta, only block close + message tail events may follow."""
+    parsed = parse_sse_text("".join(events))
+    first_error_idx = None
+    for i, ev in enumerate(parsed):
+        if ev.event != "content_block_delta":
+            continue
+        delta = ev.data.get("delta", {})
+        if delta.get("type") == "text_delta" and error_substr in str(
+            delta.get("text", "")
+        ):
+            first_error_idx = i
+            break
+    assert first_error_idx is not None, (error_substr, "".join(events))
+    for ev in parsed[first_error_idx + 1 :]:
+        assert ev.event in ("content_block_stop",
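+            # message_delta (carrying stop_reason) and message_stop form the standard stream tail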
"message_delta", "message_stop"), ( + ev.event, + ev.data, + ) + + class TestStreamingExceptionHandling: """Tests for error paths during stream_response.""" @@ -128,6 +142,7 @@ class TestStreamingExceptionHandling: assert "message_start" in event_text assert "API failed" in event_text assert "message_stop" in event_text + _assert_no_content_deltas_after_error_text(events, "API failed") @pytest.mark.asyncio async def test_read_timeout_with_empty_message_emits_fallback(self): @@ -161,6 +176,7 @@ class TestStreamingExceptionHandling: assert "timed out after" in event_text assert "request_id=req_timeout123" in event_text assert "message_stop" in event_text + _assert_no_content_deltas_after_error_text(events, "timed out after") @pytest.mark.asyncio async def test_error_after_partial_content(self): @@ -191,6 +207,7 @@ class TestStreamingExceptionHandling: assert "Hello" in event_text assert "Connection lost" in event_text assert "message_stop" in event_text + _assert_no_content_deltas_after_error_text(events, "Connection lost") @pytest.mark.asyncio async def test_empty_response_gets_space(self): @@ -353,6 +370,10 @@ class TestStreamingExceptionHandling: in event_text ) assert "request_id=REQ405" in event_text + _assert_no_content_deltas_after_error_text( + events, + "Upstream provider NIM rejected the request method or endpoint (HTTP 405).", + ) @pytest.mark.asyncio async def test_stream_rate_limited_retries_via_execute_with_retry(self): @@ -406,6 +427,59 @@ class TestProcessToolCall: assert "search" in event_text assert "call_123" in event_text + def test_tool_call_id_arrives_before_name_still_emits_id_and_name(self): + """Split-stream tool: id (no name) then name then args; id preserved on start.""" + provider = _make_provider() + from core.anthropic import SSEBuilder + + sse = SSEBuilder("msg_test", "test-model") + t1 = { + "index": 0, + "id": "call_split", + "function": {"name": None, "arguments": ""}, + } + t2 = { + "index": 0, + "id": "call_split", + "function": {"name": "Grep", "arguments": ""}, + } + t3 = { + "index": 0, + "id": "call_split", + "function": {"name": None, "arguments": "{}"}, + } + b1 = "".join(provider._process_tool_call(t1, sse)) + b2 = "".join(provider._process_tool_call(t2, sse)) + b3 = "".join(provider._process_tool_call(t3, sse)) + combined = b1 + b2 + b3 + assert "call_split" in combined + assert "Grep" in combined + assert b1 == "" + + def test_tool_call_arguments_buffered_until_name(self): + """Argument deltas before tool name are emitted after the block starts.""" + provider = _make_provider() + from core.anthropic import SSEBuilder + + sse = SSEBuilder("msg_test", "test-model") + t1 = { + "index": 0, + "id": "call_buf", + "function": {"name": None, "arguments": '{"x":'}, + } + t2 = { + "index": 0, + "id": "call_buf", + "function": {"name": "Read", "arguments": "1}"}, + } + b1 = "".join(provider._process_tool_call(t1, sse)) + b2 = "".join(provider._process_tool_call(t2, sse)) + assert b1 == "" + combined = b2 + assert "Read" in combined + assert "call_buf" in combined + assert '{"x":' in combined or "partial_json" in combined + def test_tool_call_without_id_generates_uuid(self): """Tool call without id generates a uuid-based id.""" provider = _make_provider() @@ -617,6 +691,7 @@ class TestStreamChunkEdgeCases: assert "Partial" in event_text assert "Connection reset" in event_text assert "message_stop" in event_text + _assert_no_content_deltas_after_error_text(events, "Connection reset") def test_stream_malformed_tool_args_chunked(self): """Chunked tool args that 
never form valid JSON are flushed with {}.""" @@ -642,3 +717,36 @@ class TestStreamChunkEdgeCases: event_text = "".join(events1 + events2 + flushed) assert "tool_use" in event_text assert "{}" in event_text + + +@pytest.mark.asyncio +async def test_openai_compat_stream_ends_with_contract_when_tool_name_never_arrives() -> ( + None +): + """Nameless / incomplete tool-call buffer must not break Anthropic stream contract.""" + provider = _make_provider() + request = _make_request() + tc0 = SimpleNamespace( + index=0, + id="call_inc", + function=SimpleNamespace(name=None, arguments="{}"), + ) + stream_mock = AsyncStreamMock([_make_chunk(tool_calls=[tc0])]) + with ( + patch.object( + provider._client.chat.completions, + "create", + new_callable=AsyncMock, + return_value=stream_mock, + ), + patch.object( + provider._global_rate_limiter, + "wait_if_blocked", + new_callable=AsyncMock, + return_value=False, + ), + ): + events = await _collect_stream(provider, request) + text = "".join(events) + assert_anthropic_stream_contract(parse_sse_text(text)) + assert "text_delta" in text diff --git a/tests/stream_contract.py b/tests/stream_contract.py new file mode 100644 index 0000000..9fdc658 --- /dev/null +++ b/tests/stream_contract.py @@ -0,0 +1,18 @@ +"""Shared assertions for canonical provider streaming error envelopes.""" + +from core.anthropic.stream_contracts import ( + assert_anthropic_stream_contract, + parse_sse_text, + text_content, +) + + +def assert_canonical_stream_error_envelope( + events: list[str], *, user_message_substr: str +) -> None: + """Native transports emit message_start → text error → message_stop.""" + blob = "".join(events) + assert "event: error\ndata:" not in blob + parsed = parse_sse_text(blob) + assert_anthropic_stream_contract(parsed, allow_error=False) + assert user_message_substr in text_content(parsed)
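Usage sketch for the shared helper above (illustrative only, not part of this diff): a hypothetical test that drives the native transport against an HTTP 500 upstream and checks the canonical error envelope. It reuses `FakeResponse`, `MockRequest`, `NativeProvider`, and the `provider_config` fixture from the new transport-logging tests; the asserted substring is an assumption about the user-facing error wording.

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from tests.providers.test_anthropic_messages import (
    FakeResponse,
    MockRequest,
    NativeProvider,
)
from tests.stream_contract import assert_canonical_stream_error_envelope


@pytest.mark.asyncio
async def test_http_500_surfaces_canonical_error_envelope(provider_config):
    # Hypothetical: upstream answers HTTP 500; the transport must surface the
    # failure inside a normal Anthropic message rather than a raw `event: error`.
    provider = NativeProvider(provider_config)
    response = FakeResponse(status_code=500, text="upstream exploded")

    with (
        patch.object(provider._client, "build_request", return_value=MagicMock()),
        patch.object(
            provider._client, "send", new_callable=AsyncMock, return_value=response
        ),
    ):
        events = [e async for e in provider.stream_response(MockRequest())]

    # message_start -> text delta carrying the error -> message_stop; the
    # substring assumes the surfaced text mentions the upstream status code.
    assert_canonical_stream_error_envelope(events, user_message_substr="500")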