From b926f60f6476a378e5b055d330b4993d091d9ac1 Mon Sep 17 00:00:00 2001 From: Alishahryar1 Date: Fri, 24 Apr 2026 23:01:14 -0700 Subject: [PATCH] feat: Anthropic web server tools, provider metadata, messaging hardening - Add local web_search/web_fetch SSE handling and optional tool schemas - Extend HeuristicToolParser for JSON-style WebFetch/WebSearch text - Consolidate provider defaults, ids, and exception typing; stream contracts - Messaging: typed options, voice config injection, platform contract cleanup - Tests for web server tools, converters, parsers, contracts; ignore debug-*.log --- .env.example | 1 + .gitignore | 1 + PLAN.md | 27 +- api/__init__.py | 3 - api/app.py | 62 +++- api/dependencies.py | 58 +++- api/models/anthropic.py | 5 +- api/routes.py | 8 +- api/runtime.py | 98 +++--- api/services.py | 20 ++ api/web_server_tools.py | 331 +++++++++++++++++++++ config/provider_ids.py | 17 ++ config/settings.py | 18 +- core/anthropic/conversion.py | 13 +- core/anthropic/stream_contracts.py | 158 ++++++++++ core/anthropic/tools.py | 41 ++- messaging/commands.py | 26 +- messaging/handler.py | 2 +- messaging/platforms/base.py | 28 +- messaging/platforms/discord.py | 23 +- messaging/platforms/factory.py | 46 ++- messaging/platforms/telegram.py | 23 +- messaging/transcription.py | 45 ++- messaging/voice.py | 11 + providers/__init__.py | 23 +- providers/deepseek/client.py | 3 +- providers/defaults.py | 19 ++ providers/exceptions.py | 7 + providers/llamacpp/client.py | 3 +- providers/lmstudio/client.py | 3 +- providers/nvidia_nim/client.py | 3 +- providers/open_router/client.py | 2 +- providers/openai_compat.py | 14 +- providers/registry.py | 103 +++++-- smoke/lib/e2e.py | 3 +- smoke/lib/sse.py | 169 ++--------- tests/api/test_api.py | 39 +-- tests/api/test_app_lifespan_and_errors.py | 65 ++-- tests/api/test_dependencies.py | 63 +++- tests/api/test_models_validators.py | 15 + tests/api/test_web_server_tools.py | 96 ++++++ tests/conftest.py | 9 + 
tests/contracts/test_import_boundaries.py | 152 +++++++++- tests/contracts/test_smoke_sse_reexport.py | 11 + tests/contracts/test_stream_contracts.py | 11 +- tests/messaging/test_messaging_factory.py | 59 +++- tests/messaging/test_voice_handlers.py | 41 +-- tests/providers/test_converter.py | 19 +- tests/providers/test_parsers.py | 34 +++ tests/providers/test_registry.py | 66 +++- 50 files changed, 1658 insertions(+), 439 deletions(-) create mode 100644 api/web_server_tools.py create mode 100644 config/provider_ids.py create mode 100644 core/anthropic/stream_contracts.py create mode 100644 providers/defaults.py create mode 100644 tests/api/test_web_server_tools.py create mode 100644 tests/contracts/test_smoke_sse_reexport.py diff --git a/.env.example b/.env.example index 58d5125..ec442ac 100644 --- a/.env.example +++ b/.env.example @@ -66,6 +66,7 @@ VOICE_NOTE_ENABLED=false # WHISPER_DEVICE: "cpu" | "cuda" | "nvidia_nim" # - "cpu"/"cuda": Hugging Face transformers Whisper (offline, free; install with: uv sync --extra voice_local) # - "nvidia_nim": NVIDIA NIM Whisper via Riva gRPC (requires NVIDIA_NIM_API_KEY; install with: uv sync --extra voice) +# (Independent of MODEL=nvidia_nim/...: that selects the *chat* provider; this selects voice STT only.) 
WHISPER_DEVICE="nvidia_nim" # WHISPER_MODEL: # - For cpu/cuda: Hugging Face ID or short name (tiny, base, small, medium, large-v2, large-v3, large-v3-turbo) diff --git a/.gitignore b/.gitignore index 8ec5364..5400e2f 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ __pycache__ agent_workspace .env server.log +debug-*.log .coverage llama_cache .smoke-results diff --git a/PLAN.md b/PLAN.md index d66f2c7..9f4607b 100644 --- a/PLAN.md +++ b/PLAN.md @@ -33,20 +33,43 @@ flowchart TD core --> providers core --> messaging providers --> api + api --> cli[cli] + api --> messaging cli --> messaging - messaging --> api ``` +Runtime note: `api.runtime` imports `cli` and `messaging` to wire the optional +messaging stack; `messaging` does not import `cli` (session/CLI access is passed +in from `api.runtime`). + The practical rule is simpler than the graph: shared protocol helpers belong in neutral core modules, not under a provider package. Provider adapters may depend on the neutral protocol layer, but API and messaging code should not import provider internals. +The diagram above mixes **Python import direction** (e.g. `config` → `providers`) +with **runtime composition** (e.g. `api.runtime` constructs `cli` and `messaging`). +`PLAN.md` remains the product map; **encoded** rules (including root imports like +`import api`, relative imports, and `api` → `providers` facade allowlists) live in +`tests/contracts/test_import_boundaries.py`. + +**Contract highlights:** `api/` may import only `providers.base`, `providers.exceptions`, +and `providers.registry` from the providers package (not per-adapter modules). +`core/` stays free of `api`, `messaging`, `cli`, `providers`, `config`, and `smoke`. +`messaging/` does not import `api`, `cli`, or `providers`. Neutral stream contract +assertions for default CI live under `core/anthropic/stream_contracts.py`; +`smoke.lib.sse` re-exports them for live smoke. 
Process-cached provider helpers +(`api.dependencies.get_provider` / `get_provider_for_type`) exist for scripts and +unit tests; production HTTP handlers must use `resolve_provider` with +`request.app` so the app-scoped `ProviderRegistry` is used. The `api` package +`__all__` exposes HTTP models and `create_app` only (not those helpers). + ## Target Boundaries - `core/anthropic/`: Anthropic protocol helpers, stream primitives, content extraction, token estimation, user-facing error strings, request conversion, - thinking, and tool helpers shared across API, providers, messaging, and tests. + thinking, tool helpers, and stream contract assertions + (`stream_contracts.py`) shared across API, providers, messaging, and tests. - `api/runtime.py`: application composition, optional messaging startup, session store restoration, and cleanup ownership. - `providers/`: provider descriptors, credential resolution, transport diff --git a/api/__init__.py b/api/__init__.py index 459d019..03094ba 100644 --- a/api/__init__.py +++ b/api/__init__.py @@ -1,7 +1,6 @@ """API layer for Claude Code Proxy.""" from .app import app, create_app -from .dependencies import get_provider, get_provider_for_type from .models import ( MessagesRequest, MessagesResponse, @@ -16,6 +15,4 @@ __all__ = [ "TokenCountResponse", "app", "create_app", - "get_provider", - "get_provider_for_type", ] diff --git a/api/app.py b/api/app.py index fc46f6d..b47deb3 100644 --- a/api/app.py +++ b/api/app.py @@ -2,8 +2,11 @@ import os from contextlib import asynccontextmanager +from typing import Any from fastapi import FastAPI, Request +from fastapi.exception_handlers import request_validation_exception_handler +from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse from loguru import logger @@ -11,7 +14,6 @@ from config.logging_config import configure_logging from config.settings import get_settings from providers.exceptions import ProviderError -from .dependencies import 
cleanup_provider from .routes import router from .runtime import AppRuntime @@ -26,9 +28,7 @@ configure_logging(_settings.log_file) @asynccontextmanager async def lifespan(app: FastAPI): """Application lifespan manager.""" - runtime = AppRuntime.for_app( - app, settings=get_settings(), provider_cleanup=cleanup_provider - ) + runtime = AppRuntime.for_app(app, settings=get_settings()) await runtime.startup() yield @@ -48,6 +48,60 @@ def create_app() -> FastAPI: app.include_router(router) # Exception handlers + @app.exception_handler(RequestValidationError) + async def validation_error_handler(request: Request, exc: RequestValidationError): + """Log request shape for 422 debugging without content values.""" + body: Any + try: + body = await request.json() + except Exception as e: + body = {"_json_error": type(e).__name__} + + messages = body.get("messages") if isinstance(body, dict) else None + message_summary: list[dict[str, Any]] = [] + if isinstance(messages, list): + for msg in messages: + if not isinstance(msg, dict): + message_summary.append({"message_kind": type(msg).__name__}) + continue + content = msg.get("content") + item: dict[str, Any] = { + "role": msg.get("role"), + "content_kind": type(content).__name__, + } + if isinstance(content, list): + item["block_types"] = [ + block.get("type", "dict") + if isinstance(block, dict) + else type(block).__name__ + for block in content[:12] + ] + item["block_keys"] = [ + sorted(str(key) for key in block)[:12] + for block in content[:5] + if isinstance(block, dict) + ] + elif isinstance(content, str): + item["content_length"] = len(content) + message_summary.append(item) + + logger.debug( + "Request validation failed: path={} query={} error_locs={} error_types={} message_summary={} tool_names={}", + request.url.path, + str(request.url.query), + [list(error.get("loc", ())) for error in exc.errors()], + [str(error.get("type", "")) for error in exc.errors()], + message_summary, + [ + str(tool.get("name", "")) + for tool 
in body.get("tools", []) + if isinstance(body, dict) + and isinstance(body.get("tools"), list) + and isinstance(tool, dict) + ], + ) + return await request_validation_exception_handler(request, exc) + @app.exception_handler(ProviderError) async def provider_error_handler(request: Request, exc: ProviderError): """Handle provider-specific errors and return Anthropic format.""" diff --git a/api/dependencies.py b/api/dependencies.py index ede7588..3527eb8 100644 --- a/api/dependencies.py +++ b/api/dependencies.py @@ -2,15 +2,18 @@ from fastapi import Depends, HTTPException, Request from loguru import logger +from starlette.applications import Starlette from config.settings import Settings from config.settings import get_settings as _get_settings from core.anthropic import get_user_facing_error_message from providers.base import BaseProvider -from providers.exceptions import AuthenticationError +from providers.exceptions import AuthenticationError, UnknownProviderTypeError from providers.registry import PROVIDER_DESCRIPTORS, ProviderRegistry -# Provider registry: keyed by provider type string, lazily populated +# Process-level cache: only for :func:`get_provider_for_type` / :func:`get_provider` +# when there is no ``Request``/``app`` (unit tests, scripts). HTTP handlers must pass +# ``app`` to :func:`resolve_provider` so the app-scoped registry is used. _providers: dict[str, BaseProvider] = {} @@ -19,19 +22,43 @@ def get_settings() -> Settings: return _get_settings() -def get_provider_for_type(provider_type: str) -> BaseProvider: - """Get or create a provider for the given provider type. +def resolve_provider( + provider_type: str, + *, + app: Starlette | None, + settings: Settings, +) -> BaseProvider: + """Resolve a provider using the app-scoped registry when ``app`` is set. - Providers are cached in the registry and reused across requests. + When ``app`` is not ``None``, the app-owned :attr:`app.state.provider_registry` + is always used. 
If the registry is missing (e.g. a test app without + :class:`~api.runtime.AppRuntime` startup), a new :class:`ProviderRegistry` + is installed on ``app.state`` so the process cache is never mixed with + per-request app identity. + + When ``app`` is ``None`` (no HTTP context), uses the process-level + :data:`_providers` cache only. """ - should_log_init = provider_type not in _providers + if app is not None: + reg = getattr(app.state, "provider_registry", None) + if reg is None: + reg = ProviderRegistry() + app.state.provider_registry = reg + return _resolve_with_registry(reg, provider_type, settings) + return _resolve_with_registry(ProviderRegistry(_providers), provider_type, settings) + + +def _resolve_with_registry( + registry: ProviderRegistry, provider_type: str, settings: Settings +) -> BaseProvider: + should_log_init = not registry.is_cached(provider_type) try: - provider = ProviderRegistry(_providers).get(provider_type, get_settings()) + provider = registry.get(provider_type, settings) except AuthenticationError as e: raise HTTPException( status_code=503, detail=get_user_facing_error_message(e) ) from e - except ValueError: + except UnknownProviderTypeError: logger.error( "Unknown provider_type: '{}'. Supported: {}", provider_type, @@ -43,6 +70,15 @@ def get_provider_for_type(provider_type: str) -> BaseProvider: return provider +def get_provider_for_type(provider_type: str) -> BaseProvider: + """Get or create a provider in the process-level cache (no ``app``/Request). + + For server requests, use :func:`resolve_provider` with the active + :attr:`request.app` so the app-scoped provider registry is used. + """ + return resolve_provider(provider_type, app=None, settings=get_settings()) + + def require_api_key( request: Request, settings: Settings = Depends(get_settings) ) -> None: @@ -78,9 +114,11 @@ def require_api_key( def get_provider() -> BaseProvider: - """Get or create the default provider (based on MODEL env var). 
+ """Get or create the default provider (``MODEL`` / ``provider_type``). - Backward-compatible convenience for health/root endpoints and tests. + Process-cache helper for scripts, unit tests, and non-FastAPI callers. HTTP + handlers must use :func:`resolve_provider` with :attr:`request.app` so the + app-scoped :class:`~providers.registry.ProviderRegistry` is used. """ return get_provider_for_type(get_settings().provider_type) diff --git a/api/models/anthropic.py b/api/models/anthropic.py index a8bbef3..12b6533 100644 --- a/api/models/anthropic.py +++ b/api/models/anthropic.py @@ -75,8 +75,11 @@ class Message(BaseModel): class Tool(BaseModel): name: str + # Anthropic server tools (e.g. web_search beta tools) include a ``type`` and + # may omit ``input_schema`` because the provider owns the schema. + type: str | None = None description: str | None = None - input_schema: dict[str, Any] + input_schema: dict[str, Any] | None = None class ThinkingConfig(BaseModel): diff --git a/api/routes.py b/api/routes.py index 4dfc982..39e9837 100644 --- a/api/routes.py +++ b/api/routes.py @@ -6,7 +6,8 @@ from loguru import logger from config.settings import Settings from core.anthropic import get_token_count -from .dependencies import get_provider_for_type, get_settings, require_api_key +from . 
import dependencies +from .dependencies import get_settings, require_api_key from .models.anthropic import MessagesRequest, TokenCountRequest from .models.responses import ModelResponse, ModelsListResponse from .services import ClaudeProxyService @@ -54,12 +55,15 @@ SUPPORTED_CLAUDE_MODELS = [ def get_proxy_service( + request: Request, settings: Settings = Depends(get_settings), ) -> ClaudeProxyService: """Build the request service for route handlers.""" return ClaudeProxyService( settings, - provider_getter=get_provider_for_type, + provider_getter=lambda provider_type: dependencies.resolve_provider( + provider_type, app=request.app, settings=settings + ), token_counter=get_token_count, ) diff --git a/api/runtime.py b/api/runtime.py index bb6d8c8..56cfe07 100644 --- a/api/runtime.py +++ b/api/runtime.py @@ -4,16 +4,20 @@ from __future__ import annotations import asyncio import os -from collections.abc import Awaitable, Callable -from dataclasses import dataclass -from typing import Any +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any from fastapi import FastAPI from loguru import logger from config.settings import Settings, get_settings +from providers.registry import ProviderRegistry -from .dependencies import cleanup_provider +if TYPE_CHECKING: + from cli.manager import CLISessionManager + from messaging.handler import ClaudeMessageHandler + from messaging.platforms.base import MessagingPlatform + from messaging.session import SessionStore _SHUTDOWN_TIMEOUT_S = 5.0 @@ -32,8 +36,7 @@ async def best_effort( def warn_if_process_auth_token(settings: Settings) -> None: """Warn when server auth was implicitly inherited from the shell.""" - uses_process_token = getattr(settings, "uses_process_anthropic_auth_token", None) - if callable(uses_process_token) and uses_process_token(): + if settings.uses_process_anthropic_auth_token(): logger.warning( "ANTHROPIC_AUTH_TOKEN is set in the process environment but not in " "a configured .env file. 
The proxy will require that token. Add " @@ -48,32 +51,29 @@ class AppRuntime: app: FastAPI settings: Settings - provider_cleanup: Callable[[], Awaitable[None]] = cleanup_provider - messaging_platform: Any = None - message_handler: Any = None - cli_manager: Any = None + _provider_registry: ProviderRegistry | None = field(default=None, init=False) + messaging_platform: MessagingPlatform | None = None + message_handler: ClaudeMessageHandler | None = None + cli_manager: CLISessionManager | None = None @classmethod def for_app( cls, app: FastAPI, settings: Settings | None = None, - provider_cleanup: Callable[[], Awaitable[None]] = cleanup_provider, ) -> AppRuntime: - return cls( - app=app, - settings=settings or get_settings(), - provider_cleanup=provider_cleanup, - ) + return cls(app=app, settings=settings or get_settings()) async def startup(self) -> None: logger.info("Starting Claude Code Proxy...") + self._provider_registry = ProviderRegistry() + self.app.state.provider_registry = self._provider_registry warn_if_process_auth_token(self.settings) await self._start_messaging_if_configured() self._publish_state() async def shutdown(self) -> None: - if self.message_handler and hasattr(self.message_handler, "session_store"): + if self.message_handler is not None: try: self.message_handler.session_store.flush_pending_save() except Exception as e: @@ -84,20 +84,33 @@ class AppRuntime: await best_effort("messaging_platform.stop", self.messaging_platform.stop()) if self.cli_manager: await best_effort("cli_manager.stop_all", self.cli_manager.stop_all()) - await best_effort("cleanup_provider", self.provider_cleanup()) + if self._provider_registry is not None: + await best_effort( + "provider_registry.cleanup", self._provider_registry.cleanup() + ) await self._shutdown_limiter() logger.info("Server shut down cleanly") async def _start_messaging_if_configured(self) -> None: try: - from messaging.platforms.factory import create_messaging_platform + from 
messaging.platforms.factory import ( + MessagingPlatformOptions, + create_messaging_platform, + ) self.messaging_platform = create_messaging_platform( - platform_type=self.settings.messaging_platform, - bot_token=self.settings.telegram_bot_token, - allowed_user_id=self.settings.allowed_telegram_user_id, - discord_bot_token=self.settings.discord_bot_token, - allowed_discord_channels=self.settings.allowed_discord_channels, + self.settings.messaging_platform, + MessagingPlatformOptions( + telegram_bot_token=self.settings.telegram_bot_token, + allowed_telegram_user_id=self.settings.allowed_telegram_user_id, + discord_bot_token=self.settings.discord_bot_token, + allowed_discord_channels=self.settings.allowed_discord_channels, + voice_note_enabled=self.settings.voice_note_enabled, + whisper_model=self.settings.whisper_model, + whisper_device=self.settings.whisper_device, + hf_token=self.settings.hf_token, + nvidia_nim_api_key=self.settings.nvidia_nim_api_key, + ), ) if self.messaging_platform: @@ -137,29 +150,31 @@ class AppRuntime: api_url=api_url, allowed_dirs=allowed_dirs, plans_directory=plans_directory, - claude_bin=getattr(self.settings, "claude_cli_bin", "claude"), + claude_bin=self.settings.claude_cli_bin, ) session_store = SessionStore( storage_path=os.path.join(data_path, "sessions.json") ) + platform = self.messaging_platform + assert platform is not None self.message_handler = ClaudeMessageHandler( - platform=self.messaging_platform, + platform=platform, cli_manager=self.cli_manager, session_store=session_store, ) self._restore_tree_state(session_store) - self.messaging_platform.on_message(self.message_handler.handle_message) - await self.messaging_platform.start() - logger.info( - f"{self.messaging_platform.name} platform started with message handler" - ) + platform.on_message(self.message_handler.handle_message) + await platform.start() + logger.info(f"{platform.name} platform started with message handler") - def _restore_tree_state(self, session_store: 
Any) -> None: + def _restore_tree_state(self, session_store: SessionStore) -> None: saved_trees = session_store.get_all_trees() if not saved_trees: return + if self.message_handler is None: + return logger.info(f"Restoring {len(saved_trees)} conversation trees...") from messaging.trees.queue_manager import TreeQueueManager @@ -188,11 +203,16 @@ class AppRuntime: async def _shutdown_limiter(self) -> None: try: from messaging.limiter import MessagingRateLimiter - - await best_effort( - "MessagingRateLimiter.shutdown_instance", - MessagingRateLimiter.shutdown_instance(), - timeout_s=2.0, + except Exception as e: + logger.debug( + "Rate limiter shutdown skipped (import failed): {}: {}", + type(e).__name__, + e, ) - except Exception: - pass + return + + await best_effort( + "MessagingRateLimiter.shutdown_instance", + MessagingRateLimiter.shutdown_instance(), + timeout_s=2.0, + ) diff --git a/api/services.py b/api/services.py index bc31fc5..037f48d 100644 --- a/api/services.py +++ b/api/services.py @@ -20,6 +20,10 @@ from .model_router import ModelRouter from .models.anthropic import MessagesRequest, TokenCountRequest from .models.responses import TokenCountResponse from .optimization_handlers import try_optimizations +from .web_server_tools import ( + is_web_server_tool_request, + stream_web_server_tool_response, +) TokenCounter = Callable[[list[Any], str | list[Any] | None, list[Any] | None], int] @@ -48,6 +52,22 @@ class ClaudeProxyService: raise InvalidRequestError("messages cannot be empty") routed = self._model_router.resolve_messages_request(request_data) + if is_web_server_tool_request(routed.request): + input_tokens = self._token_counter( + routed.request.messages, routed.request.system, routed.request.tools + ) + logger.info("Optimization: Handling Anthropic web server tool") + return StreamingResponse( + stream_web_server_tool_response( + routed.request, input_tokens=input_tokens + ), + media_type="text/event-stream", + headers={ + "X-Accel-Buffering": "no", + 
"Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + ) optimized = try_optimizations(routed.request, self._settings) if optimized is not None: diff --git a/api/web_server_tools.py b/api/web_server_tools.py new file mode 100644 index 0000000..5daced4 --- /dev/null +++ b/api/web_server_tools.py @@ -0,0 +1,331 @@ +"""Local handlers for Anthropic web server tools. + +OpenAI-compatible upstreams can emit regular function calls, but Anthropic's +web tools are server-side: the API response itself must include the tool result. +""" + +from __future__ import annotations + +import html +import json +import re +import uuid +from collections.abc import AsyncIterator +from datetime import UTC, datetime +from html.parser import HTMLParser +from typing import Any +from urllib.parse import parse_qs, unquote, urlparse + +import httpx + +from .models.anthropic import MessagesRequest + +_REQUEST_TIMEOUT_S = 20.0 +_MAX_SEARCH_RESULTS = 10 +_MAX_FETCH_CHARS = 24_000 + + +class _SearchResultParser(HTMLParser): + def __init__(self) -> None: + super().__init__() + self.results: list[dict[str, str]] = [] + self._href: str | None = None + self._title_parts: list[str] = [] + + def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None: + if tag != "a": + return + href = dict(attrs).get("href") + if not href or "uddg=" not in href: + return + parsed = urlparse(href) + query = parse_qs(parsed.query) + uddg = query.get("uddg", [""])[0] + if not uddg: + return + self._href = unquote(uddg) + self._title_parts = [] + + def handle_data(self, data: str) -> None: + if self._href is not None: + self._title_parts.append(data) + + def handle_endtag(self, tag: str) -> None: + if tag != "a" or self._href is None: + return + title = " ".join("".join(self._title_parts).split()) + if title and not any(result["url"] == self._href for result in self.results): + self.results.append({"title": html.unescape(title), "url": self._href}) + self._href = None + self._title_parts = [] 
+ + +class _HTMLTextParser(HTMLParser): + def __init__(self) -> None: + super().__init__() + self.title = "" + self.text_parts: list[str] = [] + self._in_title = False + self._skip_depth = 0 + + def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None: + if tag in {"script", "style", "noscript"}: + self._skip_depth += 1 + elif tag == "title": + self._in_title = True + + def handle_endtag(self, tag: str) -> None: + if tag in {"script", "style", "noscript"} and self._skip_depth: + self._skip_depth -= 1 + elif tag == "title": + self._in_title = False + + def handle_data(self, data: str) -> None: + text = " ".join(data.split()) + if not text: + return + if self._in_title: + self.title = f"{self.title} {text}".strip() + elif not self._skip_depth: + self.text_parts.append(text) + + +def _format_event(event_type: str, data: dict[str, Any]) -> str: + return f"event: {event_type}\ndata: {json.dumps(data)}\n\n" + + +def _content_text(content: Any) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [] + for item in content: + if isinstance(item, dict): + parts.append(str(item.get("text", ""))) + else: + parts.append(str(getattr(item, "text", ""))) + return "\n".join(part for part in parts if part) + return str(content) + + +def _request_text(request: MessagesRequest) -> str: + return "\n".join(_content_text(message.content) for message in request.messages) + + +def _web_tool_name(request: MessagesRequest) -> str | None: + for tool in request.tools or []: + name = tool.name + tool_type = tool.type or "" + if name in {"web_search", "web_fetch"} or tool_type.startswith("web_"): + return name + return None + + +def is_web_server_tool_request(request: MessagesRequest) -> bool: + return _web_tool_name(request) in {"web_search", "web_fetch"} + + +def _extract_query(text: str) -> str: + match = re.search(r"query:\s*(.+)", text, flags=re.IGNORECASE | re.DOTALL) + if match: + return 
match.group(1).strip().strip("\"'") + return text.strip() + + +def _extract_url(text: str) -> str: + match = re.search(r"https?://\S+", text) + return match.group(0).rstrip(").,]") if match else text.strip() + + +async def _run_web_search(query: str) -> list[dict[str, str]]: + async with httpx.AsyncClient( + timeout=_REQUEST_TIMEOUT_S, + follow_redirects=True, + headers={"User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0"}, + ) as client: + response = await client.get( + "https://lite.duckduckgo.com/lite/", + params={"q": query}, + ) + response.raise_for_status() + + parser = _SearchResultParser() + parser.feed(response.text) + return parser.results[:_MAX_SEARCH_RESULTS] + + +async def _run_web_fetch(url: str) -> dict[str, str]: + async with httpx.AsyncClient( + timeout=_REQUEST_TIMEOUT_S, + follow_redirects=True, + headers={"User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0"}, + ) as client: + response = await client.get(url) + response.raise_for_status() + + content_type = response.headers.get("content-type", "text/plain") + title = url + data = response.text + if "html" in content_type.lower(): + parser = _HTMLTextParser() + parser.feed(response.text) + title = parser.title or url + data = "\n".join(parser.text_parts) + return { + "url": str(response.url), + "title": title, + "media_type": "text/plain", + "data": data[:_MAX_FETCH_CHARS], + } + + +def _search_summary(query: str, results: list[dict[str, str]]) -> str: + if not results: + return f"No web search results found for: {query}" + lines = [f"Search results for: {query}"] + for index, result in enumerate(results, start=1): + lines.append(f"{index}. 
{result['title']}\n{result['url']}") + return "\n\n".join(lines) + + +async def stream_web_server_tool_response( + request: MessagesRequest, input_tokens: int +) -> AsyncIterator[str]: + tool_name = _web_tool_name(request) + if tool_name is None: + return + + text = _request_text(request) + message_id = f"msg_{uuid.uuid4()}" + tool_id = f"srvtoolu_{uuid.uuid4().hex}" + output_tokens = 1 + usage_key = ( + "web_search_requests" if tool_name == "web_search" else "web_fetch_requests" + ) + tool_input = ( + {"query": _extract_query(text)} + if tool_name == "web_search" + else {"url": _extract_url(text)} + ) + + yield _format_event( + "message_start", + { + "type": "message_start", + "message": { + "id": message_id, + "type": "message", + "role": "assistant", + "content": [], + "model": request.model, + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": input_tokens, "output_tokens": 1}, + }, + }, + ) + yield _format_event( + "content_block_start", + { + "type": "content_block_start", + "index": 0, + "content_block": { + "type": "server_tool_use", + "id": tool_id, + "name": tool_name, + "input": tool_input, + }, + }, + ) + yield _format_event( + "content_block_stop", {"type": "content_block_stop", "index": 0} + ) + + try: + if tool_name == "web_search": + query = str(tool_input["query"]) + results = await _run_web_search(query) + result_content: Any = [ + { + "type": "web_search_result", + "title": result["title"], + "url": result["url"], + } + for result in results + ] + summary = _search_summary(query, results) + result_block_type = "web_search_tool_result" + else: + fetched = await _run_web_fetch(str(tool_input["url"])) + result_content = { + "type": "web_fetch_result", + "url": fetched["url"], + "content": { + "type": "document", + "source": { + "type": "text", + "media_type": fetched["media_type"], + "data": fetched["data"], + }, + "title": fetched["title"], + "citations": {"enabled": True}, + }, + "retrieved_at": 
datetime.now(UTC).isoformat(), + } + summary = fetched["data"][:_MAX_FETCH_CHARS] + result_block_type = "web_fetch_tool_result" + except Exception as error: + result_block_type = ( + "web_search_tool_result" + if tool_name == "web_search" + else "web_fetch_tool_result" + ) + error_type = ( + "web_search_tool_result_error" + if tool_name == "web_search" + else "web_fetch_tool_error" + ) + result_content = {"type": error_type, "error_code": "unavailable"} + summary = f"{tool_name} failed: {type(error).__name__}" + + output_tokens = max(1, len(summary) // 4) + + yield _format_event( + "content_block_start", + { + "type": "content_block_start", + "index": 1, + "content_block": { + "type": result_block_type, + "tool_use_id": tool_id, + "content": result_content, + }, + }, + ) + yield _format_event( + "content_block_stop", {"type": "content_block_stop", "index": 1} + ) + yield _format_event( + "content_block_start", + { + "type": "content_block_start", + "index": 2, + "content_block": {"type": "text", "text": summary}, + }, + ) + yield _format_event( + "content_block_stop", {"type": "content_block_stop", "index": 2} + ) + yield _format_event( + "message_delta", + { + "type": "message_delta", + "delta": {"stop_reason": "end_turn", "stop_sequence": None}, + "usage": { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "server_tool_use": {usage_key: 1}, + }, + }, + ) + yield _format_event("message_stop", {"type": "message_stop"}) diff --git a/config/provider_ids.py b/config/provider_ids.py new file mode 100644 index 0000000..c8ecb5e --- /dev/null +++ b/config/provider_ids.py @@ -0,0 +1,17 @@ +"""Canonical list of model provider type prefixes (provider_id values). + +`providers.registry.PROVIDER_DESCRIPTORS` is the full metadata source; this +module holds the id set for config validation and must stay in sync +(registries assert in `providers.registry`). 
+""" + +from __future__ import annotations + +# Order matches docs / historical error text; must match PROVIDER_DESCRIPTORS keys. +SUPPORTED_PROVIDER_IDS: tuple[str, ...] = ( + "nvidia_nim", + "open_router", + "deepseek", + "lmstudio", + "llamacpp", +) diff --git a/config/settings.py b/config/settings.py index 137eba3..184f4ae 100644 --- a/config/settings.py +++ b/config/settings.py @@ -11,6 +11,7 @@ from pydantic import Field, field_validator, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict from .nim import NimSettings +from .provider_ids import SUPPORTED_PROVIDER_IDS def _env_files() -> tuple[Path, ...]: @@ -252,25 +253,16 @@ class Settings(BaseSettings): def validate_model_format(cls, v: str | None) -> str | None: if v is None: return None - valid_providers = ( - "nvidia_nim", - "open_router", - "deepseek", - "lmstudio", - "llamacpp", - ) if "/" not in v: raise ValueError( f"Model must be prefixed with provider type. " - f"Valid providers: {', '.join(valid_providers)}. " + f"Valid providers: {', '.join(SUPPORTED_PROVIDER_IDS)}. " f"Format: provider_type/model/name" ) provider = v.split("/", 1)[0] - if provider not in valid_providers: - raise ValueError( - f"Invalid provider: '{provider}'. " - f"Supported: 'nvidia_nim', 'open_router', 'deepseek', 'lmstudio', 'llamacpp'" - ) + if provider not in SUPPORTED_PROVIDER_IDS: + supported = ", ".join(f"'{p}'" for p in SUPPORTED_PROVIDER_IDS) + raise ValueError(f"Invalid provider: '{provider}'. 
Supported: {supported}") return v @model_validator(mode="after") diff --git a/core/anthropic/conversion.py b/core/anthropic/conversion.py index d05f781..5979ba1 100644 --- a/core/anthropic/conversion.py +++ b/core/anthropic/conversion.py @@ -7,6 +7,17 @@ from .content import get_block_attr, get_block_type from .utils import set_if_not_none +def _tool_name(tool: Any) -> str: + return str(getattr(tool, "name", "") or "") + + +def _tool_input_schema(tool: Any) -> dict[str, Any]: + schema = getattr(tool, "input_schema", None) + if isinstance(schema, dict): + return schema + return {"type": "object", "properties": {}} + + class AnthropicToOpenAIConverter: """Convert Anthropic message format to OpenAI-compatible format.""" @@ -140,7 +151,7 @@ class AnthropicToOpenAIConverter: "function": { "name": tool.name, "description": tool.description or "", - "parameters": tool.input_schema, + "parameters": _tool_input_schema(tool), }, } for tool in tools diff --git a/core/anthropic/stream_contracts.py b/core/anthropic/stream_contracts.py new file mode 100644 index 0000000..8058392 --- /dev/null +++ b/core/anthropic/stream_contracts.py @@ -0,0 +1,158 @@ +"""Neutral SSE parsing and Anthropic stream shape assertions. + +Used by default CI contract tests and re-exported from ``smoke.lib.sse`` for +opt-in smoke scenarios. 
+""" + +from __future__ import annotations + +import json +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True, slots=True) +class SSEEvent: + event: str + data: dict[str, Any] + raw: str + + +def parse_sse_lines(lines: Iterable[str]) -> list[SSEEvent]: + events: list[SSEEvent] = [] + current_event = "" + data_parts: list[str] = [] + raw_parts: list[str] = [] + + for line in lines: + stripped = line.rstrip("\r\n") + if stripped == "": + _append_event(events, current_event, data_parts, raw_parts) + current_event = "" + data_parts = [] + raw_parts = [] + continue + raw_parts.append(stripped) + if stripped.startswith("event:"): + current_event = stripped.split(":", 1)[1].strip() + elif stripped.startswith("data:"): + data_parts.append(stripped.split(":", 1)[1].strip()) + + _append_event(events, current_event, data_parts, raw_parts) + return events + + +def parse_sse_text(text: str) -> list[SSEEvent]: + return parse_sse_lines(text.splitlines()) + + +def _append_event( + events: list[SSEEvent], + current_event: str, + data_parts: list[str], + raw_parts: list[str], +) -> None: + if not current_event and not data_parts: + return + data_text = "\n".join(data_parts) + data: dict[str, Any] + try: + parsed = json.loads(data_text) if data_text else {} + data = parsed if isinstance(parsed, dict) else {"value": parsed} + except json.JSONDecodeError: + data = {"raw": data_text} + events.append(SSEEvent(current_event, data, "\n".join(raw_parts))) + + +def assert_anthropic_stream_contract( + events: list[SSEEvent], *, allow_error: bool = False +) -> None: + """Check minimal Anthropic-style SSE invariants: start/stop, block nesting. + + Does *not* assert strict event ordering (e.g. :class:`message_delta` vs + content blocks) beyond presence of a final ``message_stop``; stricter + ordering can be tested in product or transport-specific suites. 
+ """ + assert events, "stream produced no SSE events" + event_names = [event.event for event in events] + assert "message_start" in event_names, event_names + assert event_names[-1] == "message_stop", event_names + + open_blocks: dict[int, str] = {} + seen_blocks: set[int] = set() + for event in events: + if event.event == "error" and not allow_error: + raise AssertionError(f"unexpected SSE error event: {event.data}") + + if event.event == "content_block_start": + index = _event_index(event) + block = event.data.get("content_block", {}) + assert isinstance(block, dict), event.data + block_type = str(block.get("type", "")) + assert block_type in {"text", "thinking", "tool_use"}, event.data + assert index not in open_blocks, f"block {index} started twice" + assert index not in seen_blocks, f"block {index} reused after stop" + open_blocks[index] = block_type + seen_blocks.add(index) + continue + + if event.event == "content_block_delta": + index = _event_index(event) + assert index in open_blocks, f"delta for unopened block {index}" + delta = event.data.get("delta", {}) + assert isinstance(delta, dict), event.data + delta_type = str(delta.get("type", "")) + expected = { + "text": "text_delta", + "thinking": "thinking_delta", + "tool_use": "input_json_delta", + }[open_blocks[index]] + assert delta_type == expected, ( + f"block {index} is {open_blocks[index]}, got {delta_type}" + ) + continue + + if event.event == "content_block_stop": + index = _event_index(event) + assert index in open_blocks, f"stop for unopened block {index}" + open_blocks.pop(index) + + assert not open_blocks, f"unclosed blocks: {open_blocks}" + assert seen_blocks, "stream did not emit any content blocks" + + +def event_names(events: list[SSEEvent]) -> list[str]: + return [event.event for event in events] + + +def text_content(events: list[SSEEvent]) -> str: + parts: list[str] = [] + for event in events: + delta = event.data.get("delta", {}) + if isinstance(delta, dict) and delta.get("type") == 
"text_delta": + parts.append(str(delta.get("text", ""))) + return "".join(parts) + + +def thinking_content(events: list[SSEEvent]) -> str: + parts: list[str] = [] + for event in events: + delta = event.data.get("delta", {}) + if isinstance(delta, dict) and delta.get("type") == "thinking_delta": + parts.append(str(delta.get("thinking", ""))) + return "".join(parts) + + +def has_tool_use(events: list[SSEEvent]) -> bool: + for event in events: + block = event.data.get("content_block", {}) + if isinstance(block, dict) and block.get("type") == "tool_use": + return True + return False + + +def _event_index(event: SSEEvent) -> int: + value = event.data.get("index") + assert isinstance(value, int), event.data + return value diff --git a/core/anthropic/tools.py b/core/anthropic/tools.py index 0751d9d..c09beb5 100644 --- a/core/anthropic/tools.py +++ b/core/anthropic/tools.py @@ -1,5 +1,6 @@ """Heuristic parser for text-emitted tool calls.""" +import json import re import uuid from enum import Enum @@ -31,6 +32,9 @@ class HeuristicToolParser: _PARAM_PATTERN = re.compile( r"]+)>(.*?)(?:|$)", re.DOTALL ) + _WEB_TOOL_JSON_PATTERN = re.compile( + r"(?is)\b(?:use\s+)?(?PWebFetch|WebSearch)\b.*?(?P\{.*?\})" + ) def __init__(self): self._state = ParserState.TEXT @@ -39,6 +43,41 @@ class HeuristicToolParser: self._current_function_name = None self._current_parameters = {} + def _extract_web_tool_json_calls(self) -> tuple[str, list[dict[str, Any]]]: + detected_tools: list[dict[str, Any]] = [] + + for match in self._WEB_TOOL_JSON_PATTERN.finditer(self._buffer): + try: + tool_input = json.loads(match.group("json")) + except json.JSONDecodeError: + continue + if not isinstance(tool_input, dict): + continue + + tool_name = match.group("tool") + if tool_name == "WebFetch" and "url" not in tool_input: + continue + if tool_name == "WebSearch" and "query" not in tool_input: + continue + + detected_tools.append( + { + "type": "tool_use", + "id": f"toolu_heuristic_{uuid.uuid4().hex[:8]}", + 
"name": tool_name, + "input": tool_input, + } + ) + logger.debug( + "Heuristic bypass: Detected JSON-style tool call '{}'", + tool_name, + ) + + if not detected_tools: + return self._buffer, [] + + return "", detected_tools + def _strip_control_tokens(self, text: str) -> str: return _CONTROL_TOKEN_RE.sub("", text) @@ -58,7 +97,7 @@ class HeuristicToolParser: """Feed text and return safe text plus detected tool calls.""" self._buffer += text self._buffer = self._strip_control_tokens(self._buffer) - detected_tools = [] + self._buffer, detected_tools = self._extract_web_tool_json_calls() filtered_output_parts: list[str] = [] while True: diff --git a/messaging/commands.py b/messaging/commands.py index 5a16ef5..e3f71d6 100644 --- a/messaging/commands.py +++ b/messaging/commands.py @@ -114,23 +114,15 @@ async def _delete_message_ids( numeric.sort(reverse=True) ordered = [mid for _, mid in numeric] + non_numeric - batch_fn = getattr(handler.platform, "queue_delete_messages", None) - if callable(batch_fn): - try: - CHUNK = 100 - for i in range(0, len(ordered), CHUNK): - chunk = ordered[i : i + CHUNK] - await batch_fn(chat_id, chunk, fire_and_forget=False) - except Exception as e: - logger.debug(f"Batch delete failed: {type(e).__name__}: {e}") - else: - for mid in ordered: - try: - await handler.platform.queue_delete_message( - chat_id, mid, fire_and_forget=False - ) - except Exception as e: - logger.debug(f"Delete failed for msg {mid}: {type(e).__name__}: {e}") + try: + CHUNK = 100 + for i in range(0, len(ordered), CHUNK): + chunk = ordered[i : i + CHUNK] + await handler.platform.queue_delete_messages( + chat_id, chunk, fire_and_forget=False + ) + except Exception as e: + logger.debug(f"Batch delete failed: {type(e).__name__}: {e}") async def _handle_clear_branch( diff --git a/messaging/handler.py b/messaging/handler.py index c1365a4..949cafb 100644 --- a/messaging/handler.py +++ b/messaging/handler.py @@ -467,7 +467,7 @@ class ClaudeMessageHandler: status, len(display), ) 
async def queue_delete_messages(
    self,
    chat_id: str,
    message_ids: list[str],
    *,
    fire_and_forget: bool = True,
) -> None:
    """Delete many messages; default loops :meth:`queue_delete_message`.

    Adapters with native bulk delete should override.

    Deletion is best-effort per message: a failure for one id does not stop
    the remaining ids from being attempted. Without this, a single message
    that can no longer be deleted would leave every later message of the
    batch untouched.

    Args:
        chat_id: Chat the messages belong to.
        message_ids: Ids to delete, attempted in the given order.
        fire_and_forget: Forwarded unchanged to :meth:`queue_delete_message`.

    Raises:
        Exception: The first error raised by :meth:`queue_delete_message`,
            re-raised after every id has been attempted.
    """
    first_error: Exception | None = None
    for mid in message_ids:
        try:
            await self.queue_delete_message(
                chat_id, mid, fire_and_forget=fire_and_forget
            )
        except Exception as error:
            # Remember only the first failure; keep deleting the rest.
            if first_error is None:
                first_error = error
    if first_error is not None:
        raise first_error
Returns True if handled.""" - from config.settings import get_settings - - settings = get_settings() - if not settings.voice_note_enabled: + if not self._voice_note_enabled: await message.reply("Voice notes are disabled.") return True @@ -201,8 +210,8 @@ class DiscordPlatform(MessagingPlatform): transcribed = await self._voice_transcription.transcribe( tmp_path, ct, - whisper_model=settings.whisper_model, - whisper_device=settings.whisper_device, + whisper_model=self._whisper_model, + whisper_device=self._whisper_device, ) if not await self._is_voice_still_pending(channel_id, message_id): diff --git a/messaging/platforms/factory.py b/messaging/platforms/factory.py index b34b110..b40fe87 100644 --- a/messaging/platforms/factory.py +++ b/messaging/platforms/factory.py @@ -6,30 +6,50 @@ To add a new platform (e.g. Discord, Slack): 2. Add a case to create_messaging_platform() below """ +from __future__ import annotations + +from dataclasses import dataclass + from loguru import logger from .base import MessagingPlatform +@dataclass(frozen=True, slots=True) +class MessagingPlatformOptions: + """Typed wiring from :class:`~api.runtime.AppRuntime` into platform adapters.""" + + telegram_bot_token: str | None = None + allowed_telegram_user_id: str | None = None + discord_bot_token: str | None = None + allowed_discord_channels: str | None = None + voice_note_enabled: bool = True + whisper_model: str = "base" + whisper_device: str = "cpu" + hf_token: str = "" + nvidia_nim_api_key: str = "" + + def create_messaging_platform( platform_type: str, - **kwargs, + options: MessagingPlatformOptions | None = None, ) -> MessagingPlatform | None: """Create a messaging platform instance based on type. Args: - platform_type: Platform identifier ("telegram", "discord", etc.) - **kwargs: Platform-specific configuration passed to the constructor. + platform_type: Platform identifier (``telegram``, ``discord``, ``none``). + options: Token, allowlist, and voice / transcription settings. 
Returns: - Configured MessagingPlatform instance, or None if not configured. + Configured :class:`MessagingPlatform` instance, or None if not configured. """ + opts = options or MessagingPlatformOptions() if platform_type == "none": logger.info("Messaging platform disabled by configuration") return None if platform_type == "telegram": - bot_token = kwargs.get("bot_token") + bot_token = opts.telegram_bot_token if not bot_token: logger.info("No Telegram bot token configured, skipping platform setup") return None @@ -38,11 +58,16 @@ def create_messaging_platform( return TelegramPlatform( bot_token=bot_token, - allowed_user_id=kwargs.get("allowed_user_id"), + allowed_user_id=opts.allowed_telegram_user_id, + voice_note_enabled=opts.voice_note_enabled, + whisper_model=opts.whisper_model, + whisper_device=opts.whisper_device, + hf_token=opts.hf_token, + nvidia_nim_api_key=opts.nvidia_nim_api_key, ) if platform_type == "discord": - bot_token = kwargs.get("discord_bot_token") + bot_token = opts.discord_bot_token if not bot_token: logger.info("No Discord bot token configured, skipping platform setup") return None @@ -51,7 +76,12 @@ def create_messaging_platform( return DiscordPlatform( bot_token=bot_token, - allowed_channel_ids=kwargs.get("allowed_discord_channels"), + allowed_channel_ids=opts.allowed_discord_channels, + voice_note_enabled=opts.voice_note_enabled, + whisper_model=opts.whisper_model, + whisper_device=opts.whisper_device, + hf_token=opts.hf_token, + nvidia_nim_api_key=opts.nvidia_nim_api_key, ) logger.warning( diff --git a/messaging/platforms/telegram.py b/messaging/platforms/telegram.py index f9ffdf2..d131e44 100644 --- a/messaging/platforms/telegram.py +++ b/messaging/platforms/telegram.py @@ -62,6 +62,12 @@ class TelegramPlatform(MessagingPlatform): self, bot_token: str | None = None, allowed_user_id: str | None = None, + *, + voice_note_enabled: bool = True, + whisper_model: str = "base", + whisper_device: str = "cpu", + hf_token: str = "", + 
nvidia_nim_api_key: str = "", ): if not TELEGRAM_AVAILABLE: raise ImportError( @@ -84,7 +90,13 @@ class TelegramPlatform(MessagingPlatform): self._limiter: Any | None = None # Will be MessagingRateLimiter # Pending voice transcriptions: (chat_id, msg_id) -> (voice_msg_id, status_msg_id) self._pending_voice = PendingVoiceRegistry() - self._voice_transcription = VoiceTranscriptionService() + self._voice_transcription = VoiceTranscriptionService( + hf_token=hf_token, + nvidia_nim_api_key=nvidia_nim_api_key, + ) + self._voice_note_enabled = voice_note_enabled + self._whisper_model = whisper_model + self._whisper_device = whisper_device async def _register_pending_voice( self, chat_id: str, voice_msg_id: str, status_msg_id: str @@ -544,10 +556,7 @@ class TelegramPlatform(MessagingPlatform): ): return - from config.settings import get_settings - - settings = get_settings() - if not settings.voice_note_enabled: + if not self._voice_note_enabled: await update.message.reply_text("Voice notes are disabled.") return @@ -600,8 +609,8 @@ class TelegramPlatform(MessagingPlatform): transcribed = await self._voice_transcription.transcribe( tmp_path, voice.mime_type or "audio/ogg", - whisper_model=settings.whisper_model, - whisper_device=settings.whisper_device, + whisper_model=self._whisper_model, + whisper_device=self._whisper_device, ) if not await self._is_voice_still_pending(chat_id, message_id): diff --git a/messaging/transcription.py b/messaging/transcription.py index 37524b5..389a9be 100644 --- a/messaging/transcription.py +++ b/messaging/transcription.py @@ -42,8 +42,8 @@ _MODEL_MAP: dict[str, str] = { "large-v3-turbo": "openai/whisper-large-v3-turbo", } -# Lazy-loaded pipelines: (model_id, device) -> pipeline -_pipeline_cache: dict[tuple[str, str], Any] = {} +# Lazy-loaded pipelines: (model_id, device, hf_token_fingerprint) -> pipeline +_pipeline_cache: dict[tuple[str, str, str], Any] = {} def _resolve_model_id(whisper_model: str) -> str: @@ -51,20 +51,22 @@ def 
_resolve_model_id(whisper_model: str) -> str: return _MODEL_MAP.get(whisper_model, whisper_model) -def _get_pipeline(model_id: str, device: str) -> Any: +def _get_pipeline(model_id: str, device: str, hf_token: str | None = None) -> Any: """Lazy-load transformers Whisper pipeline. Raises ImportError if not installed.""" global _pipeline_cache if device not in ("cpu", "cuda"): raise ValueError(f"whisper_device must be 'cpu' or 'cuda', got {device!r}") - cache_key = (model_id, device) + resolved_token = ( + hf_token if hf_token is not None else get_settings().hf_token + ) or "" + cache_key = (model_id, device, resolved_token) if cache_key not in _pipeline_cache: try: import torch from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline - token = get_settings().hf_token - if token: - os.environ["HF_TOKEN"] = token + if resolved_token: + os.environ["HF_TOKEN"] = resolved_token use_cuda = device == "cuda" and torch.cuda.is_available() pipe_device = "cuda:0" if use_cuda else "cpu" @@ -103,6 +105,8 @@ def transcribe_audio( *, whisper_model: str = "base", whisper_device: str = "cpu", + hf_token: str = "", + nvidia_nim_api_key: str = "", ) -> str: """ Transcribe audio file to text. 
@@ -136,9 +140,12 @@ def transcribe_audio( ) if whisper_device == "nvidia_nim": - return _transcribe_nim(file_path, whisper_model) - else: - return _transcribe_local(file_path, whisper_model, whisper_device) + return _transcribe_nim( + file_path, whisper_model, nvidia_nim_api_key=nvidia_nim_api_key + ) + return _transcribe_local( + file_path, whisper_model, whisper_device, hf_token=hf_token + ) # Whisper expects 16 kHz sample rate @@ -153,10 +160,17 @@ def _load_audio(file_path: Path) -> dict[str, Any]: return {"array": waveform, "sampling_rate": sr} -def _transcribe_local(file_path: Path, whisper_model: str, whisper_device: str) -> str: +def _transcribe_local( + file_path: Path, + whisper_model: str, + whisper_device: str, + *, + hf_token: str = "", +) -> str: """Transcribe using transformers Whisper pipeline.""" model_id = _resolve_model_id(whisper_model) - pipe = _get_pipeline(model_id, whisper_device) + token: str | None = hf_token if hf_token else None + pipe = _get_pipeline(model_id, whisper_device, hf_token=token) audio = _load_audio(file_path) result = pipe(audio, generate_kwargs={"language": "en", "task": "transcribe"}) text = result.get("text", "") or "" @@ -167,7 +181,9 @@ def _transcribe_local(file_path: Path, whisper_model: str, whisper_device: str) return result_text or "(no speech detected)" -def _transcribe_nim(file_path: Path, model: str) -> str: +def _transcribe_nim( + file_path: Path, model: str, *, nvidia_nim_api_key: str = "" +) -> str: """Transcribe using NVIDIA NIM Whisper API via Riva gRPC client.""" try: import riva.client @@ -177,8 +193,7 @@ def _transcribe_nim(file_path: Path, model: str) -> str: "Install with: uv sync --extra voice" ) from e - settings = get_settings() - api_key = settings.nvidia_nim_api_key + api_key = nvidia_nim_api_key or get_settings().nvidia_nim_api_key # Look up function ID and language code from model mapping model_config = _NIM_MODEL_MAP.get(model) diff --git a/messaging/voice.py b/messaging/voice.py index 
821176b..0dd4b12 100644 --- a/messaging/voice.py +++ b/messaging/voice.py @@ -46,6 +46,15 @@ class PendingVoiceRegistry: class VoiceTranscriptionService: """Run configured transcription backends off the event loop.""" + def __init__( + self, + *, + hf_token: str = "", + nvidia_nim_api_key: str = "", + ) -> None: + self._hf_token = hf_token + self._nvidia_nim_api_key = nvidia_nim_api_key + async def transcribe( self, file_path: Path, @@ -62,4 +71,6 @@ class VoiceTranscriptionService: mime_type, whisper_model=whisper_model, whisper_device=whisper_device, + hf_token=self._hf_token, + nvidia_nim_api_key=self._nvidia_nim_api_key, ) diff --git a/providers/__init__.py b/providers/__init__.py index 61ceda7..b1e37de 100644 --- a/providers/__init__.py +++ b/providers/__init__.py @@ -1,8 +1,11 @@ -"""Providers package - implement your own provider by extending BaseProvider.""" +"""Providers package - implement your own provider by extending BaseProvider. + +Concrete adapters (e.g. ``NvidiaNimProvider``) live in subpackages; import them +from ``providers.nvidia_nim`` etc. to avoid loading every adapter when the +``providers`` package is imported. 
+""" -from .anthropic_messages import AnthropicMessagesTransport from .base import BaseProvider, ProviderConfig -from .deepseek import DeepSeekProvider from .exceptions import ( APIError, AuthenticationError, @@ -10,27 +13,17 @@ from .exceptions import ( OverloadedError, ProviderError, RateLimitError, + UnknownProviderTypeError, ) -from .llamacpp import LlamaCppProvider -from .lmstudio import LMStudioProvider -from .nvidia_nim import NvidiaNimProvider -from .open_router import OpenRouterProvider -from .openai_compat import OpenAIChatTransport __all__ = [ "APIError", - "AnthropicMessagesTransport", "AuthenticationError", "BaseProvider", - "DeepSeekProvider", "InvalidRequestError", - "LMStudioProvider", - "LlamaCppProvider", - "NvidiaNimProvider", - "OpenAIChatTransport", - "OpenRouterProvider", "OverloadedError", "ProviderConfig", "ProviderError", "RateLimitError", + "UnknownProviderTypeError", ] diff --git a/providers/deepseek/client.py b/providers/deepseek/client.py index 3fb2d26..dc21bf7 100644 --- a/providers/deepseek/client.py +++ b/providers/deepseek/client.py @@ -3,12 +3,11 @@ from typing import Any from providers.base import ProviderConfig +from providers.defaults import DEEPSEEK_BASE_URL from providers.openai_compat import OpenAIChatTransport from .request import build_request_body -DEEPSEEK_BASE_URL = "https://api.deepseek.com" - class DeepSeekProvider(OpenAIChatTransport): """DeepSeek provider using OpenAI-compatible chat completions.""" diff --git a/providers/defaults.py b/providers/defaults.py new file mode 100644 index 0000000..89aaa8b --- /dev/null +++ b/providers/defaults.py @@ -0,0 +1,19 @@ +"""Default upstream base URLs and shared provider constants. + +Adapters and :mod:`providers.registry` import from here to avoid duplicating +literals and to keep ``providers.registry`` free of per-adapter eager imports. 
+""" + +# OpenAI-compatible chat (NIM, DeepSeek) and local OpenAI-shaped endpoints +NVIDIA_NIM_DEFAULT_BASE = "https://integrate.api.nvidia.com/v1" +DEEPSEEK_DEFAULT_BASE = "https://api.deepseek.com" +OPENROUTER_DEFAULT_BASE = "https://openrouter.ai/api/v1" +LMSTUDIO_DEFAULT_BASE = "http://localhost:1234/v1" +LLAMACPP_DEFAULT_BASE = "http://localhost:8080/v1" + +# Backward-compatible names used by existing adapter modules +NVIDIA_NIM_BASE_URL = NVIDIA_NIM_DEFAULT_BASE +DEEPSEEK_BASE_URL = DEEPSEEK_DEFAULT_BASE +OPENROUTER_BASE_URL = OPENROUTER_DEFAULT_BASE +LMSTUDIO_DEFAULT_BASE_URL = LMSTUDIO_DEFAULT_BASE +LLAMACPP_DEFAULT_BASE_URL = LLAMACPP_DEFAULT_BASE diff --git a/providers/exceptions.py b/providers/exceptions.py index d009986..31c6781 100644 --- a/providers/exceptions.py +++ b/providers/exceptions.py @@ -88,3 +88,10 @@ class APIError(ProviderError): error_type="api_error", raw_error=raw_error, ) + + +class UnknownProviderTypeError(ValueError): + """Raised when ``provider_id`` is not registered in the provider map.""" + + def __init__(self, message: str) -> None: + super().__init__(message) diff --git a/providers/llamacpp/client.py b/providers/llamacpp/client.py index 21eb004..de0c268 100644 --- a/providers/llamacpp/client.py +++ b/providers/llamacpp/client.py @@ -2,8 +2,7 @@ from providers.anthropic_messages import AnthropicMessagesTransport from providers.base import ProviderConfig - -LLAMACPP_DEFAULT_BASE_URL = "http://localhost:8080/v1" +from providers.defaults import LLAMACPP_DEFAULT_BASE_URL class LlamaCppProvider(AnthropicMessagesTransport): diff --git a/providers/lmstudio/client.py b/providers/lmstudio/client.py index fa1a8a0..2961993 100644 --- a/providers/lmstudio/client.py +++ b/providers/lmstudio/client.py @@ -2,8 +2,7 @@ from providers.anthropic_messages import AnthropicMessagesTransport from providers.base import ProviderConfig - -LMSTUDIO_DEFAULT_BASE_URL = "http://localhost:1234/v1" +from providers.defaults import LMSTUDIO_DEFAULT_BASE_URL 
class LMStudioProvider(AnthropicMessagesTransport): diff --git a/providers/nvidia_nim/client.py b/providers/nvidia_nim/client.py index e31a207..63218ff 100644 --- a/providers/nvidia_nim/client.py +++ b/providers/nvidia_nim/client.py @@ -8,6 +8,7 @@ from loguru import logger from config.nim import NimSettings from providers.base import ProviderConfig +from providers.defaults import NVIDIA_NIM_BASE_URL from providers.openai_compat import OpenAIChatTransport from .request import ( @@ -16,8 +17,6 @@ from .request import ( clone_body_without_reasoning_budget, ) -NVIDIA_NIM_BASE_URL = "https://integrate.api.nvidia.com/v1" - class NvidiaNimProvider(OpenAIChatTransport): """NVIDIA NIM provider using official OpenAI client.""" diff --git a/providers/open_router/client.py b/providers/open_router/client.py index 3067351..f0df371 100644 --- a/providers/open_router/client.py +++ b/providers/open_router/client.py @@ -11,10 +11,10 @@ from typing import Any from core.anthropic import SSEBuilder, append_request_id from providers.anthropic_messages import AnthropicMessagesTransport, StreamChunkMode from providers.base import ProviderConfig +from providers.defaults import OPENROUTER_BASE_URL from .request import build_request_body -OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1" _ANTHROPIC_VERSION = "2023-06-01" diff --git a/providers/openai_compat.py b/providers/openai_compat.py index 47655f0..93c3af4 100644 --- a/providers/openai_compat.py +++ b/providers/openai_compat.py @@ -1,5 +1,10 @@ -"""Shared base class for OpenAI-compatible providers (NIM, OpenRouter, LM Studio).""" +"""OpenAI-style chat base for :class:`OpenAIChatTransport` (NIM, DeepSeek, etc.). +``AnthropicMessagesTransport``-based providers (OpenRouter, LM Studio, …) live +in separate modules; do not list them as subclasses of this class. 
+""" + +import asyncio import json import uuid from abc import abstractmethod @@ -25,7 +30,7 @@ from providers.rate_limit import GlobalRateLimiter class OpenAIChatTransport(BaseProvider): - """Base class for providers using OpenAI-compatible chat completions API.""" + """Base for OpenAI-compatible ``/chat/completions`` adapters (NIM, DeepSeek, …).""" def __init__( self, @@ -114,6 +119,7 @@ class OpenAIChatTransport(BaseProvider): fn_delta = tc.get("function", {}) incoming_name = fn_delta.get("name") + arguments = fn_delta.get("arguments", "") if incoming_name is not None: sse.blocks.register_tool_name(tc_index, incoming_name) @@ -124,7 +130,7 @@ class OpenAIChatTransport(BaseProvider): tool_id = tc.get("id") or f"tool_{uuid.uuid4()}" yield sse.start_tool_block(tc_index, tool_id, name) - args = fn_delta.get("arguments", "") + args = arguments if args: state = sse.blocks.tool_states.get(tc_index) if state is None or not state.started: @@ -285,6 +291,8 @@ class OpenAIChatTransport(BaseProvider): for event in self._process_tool_call(tc_info, sse): yield event + except asyncio.CancelledError, GeneratorExit: + raise except Exception as e: logger.error("{}_ERROR:{} {}: {}", tag, req_tag, type(e).__name__, e) mapped_e = map_error(e, rate_limiter=self._global_rate_limiter) diff --git a/providers/registry.py b/providers/registry.py index 3ef6159..acbe73a 100644 --- a/providers/registry.py +++ b/providers/registry.py @@ -6,17 +6,17 @@ from collections.abc import Callable, MutableMapping from dataclasses import dataclass from typing import Literal +from config.provider_ids import SUPPORTED_PROVIDER_IDS from config.settings import Settings from providers.base import BaseProvider, ProviderConfig -from providers.deepseek import DEEPSEEK_BASE_URL, DeepSeekProvider -from providers.exceptions import AuthenticationError -from providers.llamacpp import LlamaCppProvider -from providers.lmstudio import LMStudioProvider -from providers.nvidia_nim import NVIDIA_NIM_BASE_URL, 
NvidiaNimProvider -from providers.open_router import ( - OPENROUTER_BASE_URL, - OpenRouterProvider, +from providers.defaults import ( + DEEPSEEK_DEFAULT_BASE, + LLAMACPP_DEFAULT_BASE, + LMSTUDIO_DEFAULT_BASE, + NVIDIA_NIM_DEFAULT_BASE, + OPENROUTER_DEFAULT_BASE, ) +from providers.exceptions import AuthenticationError, UnknownProviderTypeError TransportType = Literal["openai_chat", "anthropic_messages"] ProviderFactory = Callable[[ProviderConfig, Settings], BaseProvider] @@ -24,11 +24,17 @@ ProviderFactory = Callable[[ProviderConfig, Settings], BaseProvider] @dataclass(frozen=True, slots=True) class ProviderDescriptor: + """Metadata for building :class:`ProviderConfig` and factory wiring.""" + provider_id: str transport_type: TransportType capabilities: tuple[str, ...] credential_env: str | None = None credential_url: str | None = None + # If set, read API key from this attribute on ``Settings`` (e.g. nvidia_nim_api_key). + credential_attr: str | None = None + # If set, use this fixed key for local adapters (e.g. lm-studio, llamacpp). 
+ static_credential: str | None = None default_base_url: str | None = None base_url_attr: str | None = None proxy_attr: str | None = None @@ -40,7 +46,8 @@ PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = { transport_type="openai_chat", credential_env="NVIDIA_NIM_API_KEY", credential_url="https://build.nvidia.com/settings/api-keys", - default_base_url=NVIDIA_NIM_BASE_URL, + credential_attr="nvidia_nim_api_key", + default_base_url=NVIDIA_NIM_DEFAULT_BASE, proxy_attr="nvidia_nim_proxy", capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"), ), @@ -49,7 +56,8 @@ PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = { transport_type="anthropic_messages", credential_env="OPENROUTER_API_KEY", credential_url="https://openrouter.ai/keys", - default_base_url=OPENROUTER_BASE_URL, + credential_attr="open_router_api_key", + default_base_url=OPENROUTER_DEFAULT_BASE, proxy_attr="open_router_proxy", capabilities=("chat", "streaming", "tools", "thinking", "native_anthropic"), ), @@ -58,13 +66,15 @@ PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = { transport_type="openai_chat", credential_env="DEEPSEEK_API_KEY", credential_url="https://platform.deepseek.com/api_keys", - default_base_url=DEEPSEEK_BASE_URL, + credential_attr="deepseek_api_key", + default_base_url=DEEPSEEK_DEFAULT_BASE, capabilities=("chat", "streaming", "thinking"), ), "lmstudio": ProviderDescriptor( provider_id="lmstudio", transport_type="anthropic_messages", - default_base_url="http://localhost:1234/v1", + static_credential="lm-studio", + default_base_url=LMSTUDIO_DEFAULT_BASE, base_url_attr="lm_studio_base_url", proxy_attr="lmstudio_proxy", capabilities=("chat", "streaming", "tools", "native_anthropic", "local"), @@ -72,7 +82,8 @@ PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = { "llamacpp": ProviderDescriptor( provider_id="llamacpp", transport_type="anthropic_messages", - default_base_url="http://localhost:8080/v1", + static_credential="llamacpp", + 
default_base_url=LLAMACPP_DEFAULT_BASE, base_url_attr="llamacpp_base_url", proxy_attr="llamacpp_proxy", capabilities=("chat", "streaming", "tools", "native_anthropic", "local"), @@ -81,22 +92,32 @@ PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = { def _create_nvidia_nim(config: ProviderConfig, settings: Settings) -> BaseProvider: + from providers.nvidia_nim import NvidiaNimProvider + return NvidiaNimProvider(config, nim_settings=settings.nim) def _create_open_router(config: ProviderConfig, _settings: Settings) -> BaseProvider: + from providers.open_router import OpenRouterProvider + return OpenRouterProvider(config) def _create_deepseek(config: ProviderConfig, settings: Settings) -> BaseProvider: + from providers.deepseek import DeepSeekProvider + return DeepSeekProvider(config) def _create_lmstudio(config: ProviderConfig, settings: Settings) -> BaseProvider: + from providers.lmstudio import LMStudioProvider + return LMStudioProvider(config) def _create_llamacpp(config: ProviderConfig, settings: Settings) -> BaseProvider: + from providers.llamacpp import LlamaCppProvider + return LlamaCppProvider(config) @@ -108,6 +129,15 @@ PROVIDER_FACTORIES: dict[str, ProviderFactory] = { "llamacpp": _create_llamacpp, } +if set(PROVIDER_DESCRIPTORS) != set(SUPPORTED_PROVIDER_IDS) or set( + PROVIDER_FACTORIES +) != set(SUPPORTED_PROVIDER_IDS): + raise AssertionError( + "PROVIDER_DESCRIPTORS, PROVIDER_FACTORIES, and SUPPORTED_PROVIDER_IDS are out of sync: " + f"descriptors={set(PROVIDER_DESCRIPTORS)!r} factories={set(PROVIDER_FACTORIES)!r} " + f"ids={set(SUPPORTED_PROVIDER_IDS)!r}" + ) + def _string_attr(settings: Settings, attr_name: str | None, default: str = "") -> str: if attr_name is None: @@ -116,17 +146,11 @@ def _string_attr(settings: Settings, attr_name: str | None, default: str = "") - return value if isinstance(value, str) else default -def _credential_for(provider_id: str, settings: Settings) -> str: - if provider_id == "nvidia_nim": - return 
settings.nvidia_nim_api_key - if provider_id == "open_router": - return settings.open_router_api_key - if provider_id == "deepseek": - return settings.deepseek_api_key - if provider_id == "lmstudio": - return "lm-studio" - if provider_id == "llamacpp": - return "llamacpp" +def _credential_for(descriptor: ProviderDescriptor, settings: Settings) -> str: + if descriptor.static_credential is not None: + return descriptor.static_credential + if descriptor.credential_attr: + return _string_attr(settings, descriptor.credential_attr) return "" @@ -144,7 +168,7 @@ def _require_credential(descriptor: ProviderDescriptor, credential: str) -> None def build_provider_config( descriptor: ProviderDescriptor, settings: Settings ) -> ProviderConfig: - credential = _credential_for(descriptor.provider_id, settings) + credential = _credential_for(descriptor, settings) _require_credential(descriptor, credential) base_url = _string_attr( settings, descriptor.base_url_attr, descriptor.default_base_url or "" @@ -168,7 +192,7 @@ def create_provider(provider_id: str, settings: Settings) -> BaseProvider: descriptor = PROVIDER_DESCRIPTORS.get(provider_id) if descriptor is None: supported = "', '".join(PROVIDER_DESCRIPTORS) - raise ValueError( + raise UnknownProviderTypeError( f"Unknown provider_type: '{provider_id}'. 
Supported: '{supported}'" ) @@ -185,12 +209,33 @@ class ProviderRegistry: def __init__(self, providers: MutableMapping[str, BaseProvider] | None = None): self._providers = providers if providers is not None else {} + def is_cached(self, provider_id: str) -> bool: + """Return whether a provider for this id is already in the cache.""" + return provider_id in self._providers + def get(self, provider_id: str, settings: Settings) -> BaseProvider: if provider_id not in self._providers: self._providers[provider_id] = create_provider(provider_id, settings) return self._providers[provider_id] async def cleanup(self) -> None: - for provider in self._providers.values(): - await provider.cleanup() - self._providers.clear() + """Call ``cleanup`` on every cached provider, then clear the cache. + + Attempts all providers even if one fails. A single failure is re-raised + as-is; multiple failures are wrapped in :exc:`ExceptionGroup`. + """ + items = list(self._providers.items()) + errors: list[Exception] = [] + try: + for _pid, provider in items: + try: + await provider.cleanup() + except Exception as e: + errors.append(e) + finally: + self._providers.clear() + if len(errors) == 1: + raise errors[0] + if len(errors) > 1: + msg = "One or more provider cleanups failed" + raise ExceptionGroup(msg, errors) diff --git a/smoke/lib/e2e.py b/smoke/lib/e2e.py index d5e0592..e669b51 100644 --- a/smoke/lib/e2e.py +++ b/smoke/lib/e2e.py @@ -18,6 +18,7 @@ from typing import Any import httpx import pytest +from config.provider_ids import SUPPORTED_PROVIDER_IDS from messaging.handler import ClaudeMessageHandler from messaging.models import IncomingMessage from messaging.platforms.base import MessagingPlatform @@ -153,7 +154,7 @@ class ConversationDriver: class ProviderMatrixDriver: """Resolve provider models and enforce matrix semantics for product smoke.""" - ALL_PROVIDERS = ("nvidia_nim", "open_router", "deepseek", "lmstudio", "llamacpp") + ALL_PROVIDERS: tuple[str, ...] 
= SUPPORTED_PROVIDER_IDS def __init__(self, config: SmokeConfig) -> None: self.config = config diff --git a/smoke/lib/sse.py b/smoke/lib/sse.py index 0948dd0..a294ce5 100644 --- a/smoke/lib/sse.py +++ b/smoke/lib/sse.py @@ -1,148 +1,29 @@ -"""SSE parsing and Anthropic stream assertions for smoke tests.""" +"""SSE parsing and Anthropic stream assertions for smoke tests. + +Canonical implementation lives in :mod:`core.anthropic.stream_contracts`; this +module re-exports it for smoke and historical import paths. +""" from __future__ import annotations -import json -from collections.abc import Iterable -from dataclasses import dataclass -from typing import Any +from core.anthropic.stream_contracts import ( + SSEEvent, + assert_anthropic_stream_contract, + event_names, + has_tool_use, + parse_sse_lines, + parse_sse_text, + text_content, + thinking_content, +) - -@dataclass(frozen=True, slots=True) -class SSEEvent: - event: str - data: dict[str, Any] - raw: str - - -def parse_sse_lines(lines: Iterable[str]) -> list[SSEEvent]: - events: list[SSEEvent] = [] - current_event = "" - data_parts: list[str] = [] - raw_parts: list[str] = [] - - for line in lines: - stripped = line.rstrip("\r\n") - if stripped == "": - _append_event(events, current_event, data_parts, raw_parts) - current_event = "" - data_parts = [] - raw_parts = [] - continue - raw_parts.append(stripped) - if stripped.startswith("event:"): - current_event = stripped.split(":", 1)[1].strip() - elif stripped.startswith("data:"): - data_parts.append(stripped.split(":", 1)[1].strip()) - - _append_event(events, current_event, data_parts, raw_parts) - return events - - -def parse_sse_text(text: str) -> list[SSEEvent]: - return parse_sse_lines(text.splitlines()) - - -def _append_event( - events: list[SSEEvent], - current_event: str, - data_parts: list[str], - raw_parts: list[str], -) -> None: - if not current_event and not data_parts: - return - data_text = "\n".join(data_parts) - data: dict[str, Any] - try: - parsed = 
json.loads(data_text) if data_text else {} - data = parsed if isinstance(parsed, dict) else {"value": parsed} - except json.JSONDecodeError: - data = {"raw": data_text} - events.append(SSEEvent(current_event, data, "\n".join(raw_parts))) - - -def assert_anthropic_stream_contract( - events: list[SSEEvent], *, allow_error: bool = False -) -> None: - assert events, "stream produced no SSE events" - event_names = [event.event for event in events] - assert "message_start" in event_names, event_names - assert event_names[-1] == "message_stop", event_names - - open_blocks: dict[int, str] = {} - seen_blocks: set[int] = set() - for event in events: - if event.event == "error" and not allow_error: - raise AssertionError(f"unexpected SSE error event: {event.data}") - - if event.event == "content_block_start": - index = _event_index(event) - block = event.data.get("content_block", {}) - assert isinstance(block, dict), event.data - block_type = str(block.get("type", "")) - assert block_type in {"text", "thinking", "tool_use"}, event.data - assert index not in open_blocks, f"block {index} started twice" - assert index not in seen_blocks, f"block {index} reused after stop" - open_blocks[index] = block_type - seen_blocks.add(index) - continue - - if event.event == "content_block_delta": - index = _event_index(event) - assert index in open_blocks, f"delta for unopened block {index}" - delta = event.data.get("delta", {}) - assert isinstance(delta, dict), event.data - delta_type = str(delta.get("type", "")) - expected = { - "text": "text_delta", - "thinking": "thinking_delta", - "tool_use": "input_json_delta", - }[open_blocks[index]] - assert delta_type == expected, ( - f"block {index} is {open_blocks[index]}, got {delta_type}" - ) - continue - - if event.event == "content_block_stop": - index = _event_index(event) - assert index in open_blocks, f"stop for unopened block {index}" - open_blocks.pop(index) - - assert not open_blocks, f"unclosed blocks: {open_blocks}" - assert 
seen_blocks, "stream did not emit any content blocks" - - -def event_names(events: list[SSEEvent]) -> list[str]: - return [event.event for event in events] - - -def text_content(events: list[SSEEvent]) -> str: - parts: list[str] = [] - for event in events: - delta = event.data.get("delta", {}) - if isinstance(delta, dict) and delta.get("type") == "text_delta": - parts.append(str(delta.get("text", ""))) - return "".join(parts) - - -def thinking_content(events: list[SSEEvent]) -> str: - parts: list[str] = [] - for event in events: - delta = event.data.get("delta", {}) - if isinstance(delta, dict) and delta.get("type") == "thinking_delta": - parts.append(str(delta.get("thinking", ""))) - return "".join(parts) - - -def has_tool_use(events: list[SSEEvent]) -> bool: - for event in events: - block = event.data.get("content_block", {}) - if isinstance(block, dict) and block.get("type") == "tool_use": - return True - return False - - -def _event_index(event: SSEEvent) -> int: - value = event.data.get("index") - assert isinstance(value, int), event.data - return value +__all__ = [ + "SSEEvent", + "assert_anthropic_stream_contract", + "event_names", + "has_tool_use", + "parse_sse_lines", + "parse_sse_text", + "text_content", + "thinking_content", +] diff --git a/tests/api/test_api.py b/tests/api/test_api.py index cc53294..f9149c8 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -1,5 +1,6 @@ from unittest.mock import MagicMock, patch +import pytest from fastapi.testclient import TestClient from api.app import app @@ -9,7 +10,7 @@ from providers.nvidia_nim import NvidiaNimProvider mock_provider = MagicMock(spec=NvidiaNimProvider) # Track stream_response calls for test_model_mapping -_stream_response_calls = [] +_stream_response_calls: list = [] async def _mock_stream_response(*args, **kwargs): @@ -21,26 +22,30 @@ async def _mock_stream_response(*args, **kwargs): mock_provider.stream_response = _mock_stream_response -# Patch get_provider_for_type to always 
return mock_provider -_patcher = patch("api.routes.get_provider_for_type", return_value=mock_provider) -_patcher.start() -client = TestClient(app) +@pytest.fixture(scope="module") +def client(): + """HTTP client with provider resolution stubbed; patch only for this file.""" + with ( + patch("api.dependencies.resolve_provider", return_value=mock_provider), + TestClient(app) as test_client, + ): + yield test_client -def test_root(): +def test_root(client: TestClient): response = client.get("/") assert response.status_code == 200 assert response.json()["status"] == "ok" -def test_health(): +def test_health(client: TestClient): response = client.get("/health") assert response.status_code == 200 assert response.json()["status"] == "healthy" -def test_models_list(): +def test_models_list(client: TestClient): response = client.get("/v1/models") assert response.status_code == 200 data = response.json() @@ -51,7 +56,7 @@ def test_models_list(): assert data["last_id"] == ids[-1] -def test_probe_endpoints_return_204_with_allow_headers(): +def test_probe_endpoints_return_204_with_allow_headers(client: TestClient): responses = [ client.head("/"), client.options("/"), @@ -68,7 +73,7 @@ def test_probe_endpoints_return_204_with_allow_headers(): assert "Allow" in response.headers -def test_create_message_stream(): +def test_create_message_stream(client: TestClient): """Create message returns streaming response.""" payload = { "model": "claude-3-sonnet", @@ -83,7 +88,7 @@ def test_create_message_stream(): assert b"message_start" in content or b"event:" in content -def test_model_mapping(): +def test_model_mapping(client: TestClient): # Test Haiku mapping _stream_response_calls.clear() payload_haiku = { @@ -98,7 +103,7 @@ def test_model_mapping(): assert args[0].model != "claude-3-haiku-20240307" -def test_error_fallbacks(): +def test_error_fallbacks(client: TestClient): from providers.exceptions import ( AuthenticationError, OverloadedError, @@ -143,7 +148,7 @@ def 
test_error_fallbacks(): mock_provider.stream_response = _mock_stream_response -def test_generic_exception_returns_500(): +def test_generic_exception_returns_500(client: TestClient): """Non-ProviderError exceptions are caught and returned as HTTPException(500).""" def _raise_runtime(*args, **kwargs): @@ -163,7 +168,7 @@ def test_generic_exception_returns_500(): mock_provider.stream_response = _mock_stream_response -def test_generic_exception_with_status_code(): +def test_generic_exception_with_status_code(client: TestClient): """Generic exception with status_code attribute uses that status (getattr fallback).""" class ExceptionWithStatus(RuntimeError): @@ -188,7 +193,7 @@ def test_generic_exception_with_status_code(): mock_provider.stream_response = _mock_stream_response -def test_generic_exception_empty_message_returns_non_empty_detail(): +def test_generic_exception_empty_message_returns_non_empty_detail(client: TestClient): """Exceptions with empty __str__ still return a readable HTTP detail.""" class SilentError(RuntimeError): @@ -213,7 +218,7 @@ def test_generic_exception_empty_message_returns_non_empty_detail(): mock_provider.stream_response = _mock_stream_response -def test_count_tokens_endpoint(): +def test_count_tokens_endpoint(client: TestClient): """count_tokens endpoint returns token count.""" response = client.post( "/v1/messages/count_tokens", @@ -223,7 +228,7 @@ def test_count_tokens_endpoint(): assert "input_tokens" in response.json() -def test_stop_endpoint_no_handler_no_cli_503(): +def test_stop_endpoint_no_handler_no_cli_503(client: TestClient): """POST /stop without handler or cli_manager returns 503.""" # Ensure no handler or cli_manager on app state if hasattr(app.state, "message_handler"): diff --git a/tests/api/test_app_lifespan_and_errors.py b/tests/api/test_app_lifespan_and_errors.py index 614de3b..126a869 100644 --- a/tests/api/test_app_lifespan_and_errors.py +++ b/tests/api/test_app_lifespan_and_errors.py @@ -7,6 +7,23 @@ import pytest 
from fastapi.testclient import TestClient from config.settings import Settings +from providers.registry import ProviderRegistry + +_RUNTIME_EXTRAS = { + "voice_note_enabled": True, + "whisper_model": "base", + "whisper_device": "cpu", + "hf_token": "", + "nvidia_nim_api_key": "", + "claude_cli_bin": "claude", + "uses_process_anthropic_auth_token": lambda: False, +} + + +def _app_settings(**kwargs): + """Minimal settings namespace for AppRuntime (matches typed :class:`Settings` fields used).""" + data = {**_RUNTIME_EXTRAS, **kwargs} + return SimpleNamespace(**data) def test_warn_if_process_auth_token_logs_warning(): @@ -45,7 +62,7 @@ def test_create_app_provider_error_handler_returns_anthropic_format(): raise AuthenticationError("bad key") api_app_mod = importlib.import_module("api.app") - settings = SimpleNamespace( + settings = _app_settings( messaging_platform="telegram", telegram_bot_token=None, allowed_telegram_user_id=None, @@ -59,7 +76,7 @@ def test_create_app_provider_error_handler_returns_anthropic_format(): ) with ( patch.object(api_app_mod, "get_settings", return_value=settings), - patch.object(api_app_mod, "cleanup_provider", new=AsyncMock()), + patch.object(ProviderRegistry, "cleanup", new=AsyncMock()), ): with TestClient(app) as client: resp = client.get("/raise_provider") @@ -79,7 +96,7 @@ def test_create_app_general_exception_handler_returns_500(): raise RuntimeError("boom") api_app_mod = importlib.import_module("api.app") - settings = SimpleNamespace( + settings = _app_settings( messaging_platform="telegram", telegram_bot_token=None, allowed_telegram_user_id=None, @@ -93,7 +110,7 @@ def test_create_app_general_exception_handler_returns_500(): ) with ( patch.object(api_app_mod, "get_settings", return_value=settings), - patch.object(api_app_mod, "cleanup_provider", new=AsyncMock()), + patch.object(ProviderRegistry, "cleanup", new=AsyncMock()), ): with TestClient(app, raise_server_exceptions=False) as client: resp = client.get("/raise_general") @@ 
-111,7 +128,7 @@ def test_app_lifespan_sets_state_and_cleans_up(tmp_path, messaging_enabled): app = create_app() - settings = SimpleNamespace( + settings = _app_settings( messaging_platform="telegram", telegram_bot_token="token" if messaging_enabled else None, allowed_telegram_user_id="123", @@ -147,10 +164,10 @@ def test_app_lifespan_sets_state_and_cleans_up(tmp_path, messaging_enabled): api_app_mod = importlib.import_module("api.app") - cleanup_provider = AsyncMock() + registry_cleanup = AsyncMock() with ( patch.object(api_app_mod, "get_settings", return_value=settings), - patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider), + patch.object(ProviderRegistry, "cleanup", new=registry_cleanup), patch( "messaging.platforms.factory.create_messaging_platform", return_value=fake_platform if messaging_enabled else None, @@ -182,7 +199,7 @@ def test_app_lifespan_sets_state_and_cleans_up(tmp_path, messaging_enabled): cli_manager.stop_all.assert_not_awaited() assert getattr(app.state, "messaging_platform", "missing") is None - cleanup_provider.assert_awaited_once() + registry_cleanup.assert_awaited_once() def test_app_lifespan_cleanup_continues_if_platform_stop_raises(tmp_path): @@ -190,7 +207,7 @@ def test_app_lifespan_cleanup_continues_if_platform_stop_raises(tmp_path): app = create_app() - settings = SimpleNamespace( + settings = _app_settings( messaging_platform="telegram", telegram_bot_token="token", allowed_telegram_user_id="123", @@ -218,10 +235,10 @@ def test_app_lifespan_cleanup_continues_if_platform_stop_raises(tmp_path): cli_manager.stop_all = AsyncMock() api_app_mod = importlib.import_module("api.app") - cleanup_provider = AsyncMock() + registry_cleanup = AsyncMock() with ( patch.object(api_app_mod, "get_settings", return_value=settings), - patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider), + patch.object(ProviderRegistry, "cleanup", new=registry_cleanup), patch( "messaging.platforms.factory.create_messaging_platform", 
return_value=fake_platform, @@ -234,7 +251,7 @@ def test_app_lifespan_cleanup_continues_if_platform_stop_raises(tmp_path): fake_platform.stop.assert_awaited_once() cli_manager.stop_all.assert_awaited_once() - cleanup_provider.assert_awaited_once() + registry_cleanup.assert_awaited_once() def test_app_lifespan_messaging_import_error_no_crash(tmp_path, caplog): @@ -243,7 +260,7 @@ def test_app_lifespan_messaging_import_error_no_crash(tmp_path, caplog): app = create_app() - settings = SimpleNamespace( + settings = _app_settings( messaging_platform="telegram", telegram_bot_token="token", allowed_telegram_user_id="123", @@ -257,10 +274,10 @@ def test_app_lifespan_messaging_import_error_no_crash(tmp_path, caplog): ) api_app_mod = importlib.import_module("api.app") - cleanup_provider = AsyncMock() + registry_cleanup = AsyncMock() with ( patch.object(api_app_mod, "get_settings", return_value=settings), - patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider), + patch.object(ProviderRegistry, "cleanup", new=registry_cleanup), patch( "messaging.platforms.factory.create_messaging_platform", side_effect=ImportError("discord not installed"), @@ -270,7 +287,7 @@ def test_app_lifespan_messaging_import_error_no_crash(tmp_path, caplog): pass assert getattr(app.state, "messaging_platform", None) is None - cleanup_provider.assert_awaited_once() + registry_cleanup.assert_awaited_once() def test_app_lifespan_platform_start_exception_cleanup_still_runs(tmp_path): @@ -279,7 +296,7 @@ def test_app_lifespan_platform_start_exception_cleanup_still_runs(tmp_path): app = create_app() - settings = SimpleNamespace( + settings = _app_settings( messaging_platform="telegram", telegram_bot_token="token", allowed_telegram_user_id="123", @@ -307,10 +324,10 @@ def test_app_lifespan_platform_start_exception_cleanup_still_runs(tmp_path): cli_manager.stop_all = AsyncMock() api_app_mod = importlib.import_module("api.app") - cleanup_provider = AsyncMock() + registry_cleanup = AsyncMock() with ( 
patch.object(api_app_mod, "get_settings", return_value=settings), - patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider), + patch.object(ProviderRegistry, "cleanup", new=registry_cleanup), patch( "messaging.platforms.factory.create_messaging_platform", return_value=fake_platform, @@ -321,7 +338,7 @@ def test_app_lifespan_platform_start_exception_cleanup_still_runs(tmp_path): ): pass - cleanup_provider.assert_awaited_once() + registry_cleanup.assert_awaited_once() def test_app_lifespan_flush_pending_save_exception_warning_only(tmp_path): @@ -330,7 +347,7 @@ def test_app_lifespan_flush_pending_save_exception_warning_only(tmp_path): app = create_app() - settings = SimpleNamespace( + settings = _app_settings( messaging_platform="telegram", telegram_bot_token="token", allowed_telegram_user_id="123", @@ -359,10 +376,10 @@ def test_app_lifespan_flush_pending_save_exception_warning_only(tmp_path): cli_manager.stop_all = AsyncMock() api_app_mod = importlib.import_module("api.app") - cleanup_provider = AsyncMock() + registry_cleanup = AsyncMock() with ( patch.object(api_app_mod, "get_settings", return_value=settings), - patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider), + patch.object(ProviderRegistry, "cleanup", new=registry_cleanup), patch( "messaging.platforms.factory.create_messaging_platform", return_value=fake_platform, @@ -374,4 +391,4 @@ def test_app_lifespan_flush_pending_save_exception_warning_only(tmp_path): pass session_store.flush_pending_save.assert_called_once() - cleanup_provider.assert_awaited_once() + registry_cleanup.assert_awaited_once() diff --git a/tests/api/test_dependencies.py b/tests/api/test_dependencies.py index 892373a..ccec712 100644 --- a/tests/api/test_dependencies.py +++ b/tests/api/test_dependencies.py @@ -1,19 +1,26 @@ +from types import SimpleNamespace +from typing import cast from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi import HTTPException +from starlette.applications 
import Starlette +from starlette.datastructures import State from api.dependencies import ( cleanup_provider, get_provider, get_provider_for_type, get_settings, + resolve_provider, ) from config.nim import NimSettings from providers.deepseek import DeepSeekProvider +from providers.exceptions import UnknownProviderTypeError from providers.lmstudio import LMStudioProvider from providers.nvidia_nim import NvidiaNimProvider from providers.open_router import OpenRouterProvider +from providers.registry import ProviderRegistry def _make_mock_settings(**overrides): @@ -304,11 +311,11 @@ async def test_get_provider_deepseek_missing_api_key(): @pytest.mark.asyncio async def test_get_provider_unknown_type(): - """Test that unknown provider_type raises ValueError.""" + """Unknown ``provider_type`` raises :exc:`~providers.exceptions.UnknownProviderTypeError`.""" with patch("api.dependencies.get_settings") as mock_settings: mock_settings.return_value = _make_mock_settings(provider_type="unknown") - with pytest.raises(ValueError, match="Unknown provider_type"): + with pytest.raises(UnknownProviderTypeError, match="Unknown provider_type"): get_provider() @@ -390,3 +397,55 @@ async def test_cleanup_provider_cleans_all(): nim._client.aclose.assert_called_once() lmstudio._client.aclose.assert_called_once() + + +def test_resolve_provider_per_app_uses_separate_registries() -> None: + """With app set, each app gets its own provider cache (not process _providers).""" + with patch("api.dependencies.get_settings") as mock_settings: + mock_settings.return_value = _make_mock_settings() + settings = _make_mock_settings() + app1 = SimpleNamespace(state=State()) + app2 = SimpleNamespace(state=State()) + app1.state.provider_registry = ProviderRegistry() + app2.state.provider_registry = ProviderRegistry() + p1 = resolve_provider( + "nvidia_nim", app=cast(Starlette, app1), settings=settings + ) + p2 = resolve_provider( + "nvidia_nim", app=cast(Starlette, app2), settings=settings + ) + assert 
isinstance(p1, NvidiaNimProvider) + assert isinstance(p2, NvidiaNimProvider) + assert p1 is not p2 + + +def test_resolve_provider_lazily_installs_registry() -> None: + """If app has no provider_registry, one is created on app.state.""" + with patch("api.dependencies.get_settings") as mock_settings: + mock_settings.return_value = _make_mock_settings() + settings = _make_mock_settings() + app = SimpleNamespace(state=State()) + assert getattr(app.state, "provider_registry", None) is None + resolve_provider("nvidia_nim", app=cast(Starlette, app), settings=settings) + reg = app.state.provider_registry + assert reg is not None + p2 = resolve_provider("nvidia_nim", app=cast(Starlette, app), settings=settings) + assert p2 is reg.get("nvidia_nim", settings) # same registry instance + + +def test_resolve_provider_unrelated_value_error_is_not_unknown_provider_log() -> None: + """Only :exc:`~providers.exceptions.UnknownProviderTypeError` logs unknown provider.""" + import api.dependencies as deps + + with ( + patch.object(deps, "get_settings", return_value=_make_mock_settings()), + patch.object( + ProviderRegistry, + "get", + side_effect=ValueError("unrelated config"), + ), + patch.object(deps.logger, "error") as log_err, + pytest.raises(ValueError, match="unrelated config"), + ): + deps.resolve_provider("nvidia_nim", app=None, settings=_make_mock_settings()) + log_err.assert_not_called() diff --git a/tests/api/test_models_validators.py b/tests/api/test_models_validators.py index 4e7740b..20393a4 100644 --- a/tests/api/test_models_validators.py +++ b/tests/api/test_models_validators.py @@ -91,6 +91,21 @@ def test_messages_request_accepts_adaptive_thinking_type(): assert dumped["thinking"]["type"] == "adaptive" +def test_messages_request_accepts_anthropic_server_tool_without_input_schema(): + request = MessagesRequest.model_validate( + { + "model": "claude-opus-4-7", + "max_tokens": 100, + "messages": [{"role": "user", "content": "search"}], + "tools": [{"type": 
"web_search_20250305", "name": "web_search"}], + } + ) + + dumped = request.model_dump(exclude_none=True) + + assert dumped["tools"] == [{"name": "web_search", "type": "web_search_20250305"}] + + def test_messages_request_accepts_redacted_thinking_blocks(): request = MessagesRequest.model_validate( { diff --git a/tests/api/test_web_server_tools.py b/tests/api/test_web_server_tools.py new file mode 100644 index 0000000..c00d538 --- /dev/null +++ b/tests/api/test_web_server_tools.py @@ -0,0 +1,96 @@ +import json + +import pytest + +from api.models.anthropic import Message, MessagesRequest, Tool +from api.web_server_tools import ( + is_web_server_tool_request, + stream_web_server_tool_response, +) + + +def _event_data(event: str) -> dict: + data_line = next(line for line in event.splitlines() if line.startswith("data: ")) + return json.loads(data_line.removeprefix("data: ")) + + +def test_detects_web_search_server_tool_request(): + request = MessagesRequest( + model="claude-haiku-4-5-20251001", + max_tokens=100, + messages=[Message(role="user", content="search")], + tools=[Tool(name="web_search", type="web_search_20250305")], + ) + + assert is_web_server_tool_request(request) + + +@pytest.mark.asyncio +async def test_streams_web_search_server_tool_result(monkeypatch): + async def fake_search(query: str) -> list[dict[str, str]]: + assert query == "DeepSeek V4 model release 2026" + return [{"title": "DeepSeek V4 Released", "url": "https://example.com/v4"}] + + monkeypatch.setattr("api.web_server_tools._run_web_search", fake_search) + request = MessagesRequest( + model="claude-haiku-4-5-20251001", + max_tokens=100, + messages=[ + Message( + role="user", + content=( + "Perform a web search for the query: DeepSeek V4 model release 2026" + ), + ) + ], + tools=[Tool(name="web_search", type="web_search_20250305")], + tool_choice={"type": "tool", "name": "web_search"}, + ) + + events = [ + event + async for event in stream_web_server_tool_response(request, input_tokens=42) + ] 
+ payloads = [_event_data(event) for event in events] + + assert payloads[1]["content_block"]["type"] == "server_tool_use" + assert payloads[1]["content_block"]["name"] == "web_search" + assert payloads[3]["content_block"]["type"] == "web_search_tool_result" + assert payloads[3]["content_block"]["content"][0]["url"] == "https://example.com/v4" + assert payloads[-2]["usage"]["server_tool_use"] == {"web_search_requests": 1} + + +@pytest.mark.asyncio +async def test_streams_web_fetch_server_tool_result(monkeypatch): + async def fake_fetch(url: str) -> dict[str, str]: + assert url == "https://example.com/article" + return { + "url": url, + "title": "Example Article", + "media_type": "text/plain", + "data": "Article body", + } + + monkeypatch.setattr("api.web_server_tools._run_web_fetch", fake_fetch) + request = MessagesRequest( + model="claude-haiku-4-5-20251001", + max_tokens=100, + messages=[ + Message(role="user", content="Fetch https://example.com/article please") + ], + tools=[Tool(name="web_fetch", type="web_fetch_20250910")], + tool_choice={"type": "tool", "name": "web_fetch"}, + ) + + events = [ + event + async for event in stream_web_server_tool_response(request, input_tokens=42) + ] + payloads = [_event_data(event) for event in events] + + assert payloads[1]["content_block"]["type"] == "server_tool_use" + assert payloads[3]["content_block"]["type"] == "web_fetch_tool_result" + assert payloads[3]["content_block"]["content"]["content"]["title"] == ( + "Example Article" + ) + assert payloads[-2]["usage"]["server_tool_use"] == {"web_fetch_requests": 1} diff --git a/tests/conftest.py b/tests/conftest.py index 1967be2..463b557 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -118,6 +118,15 @@ def mock_platform(): platform.queue_edit_message = AsyncMock() platform.queue_delete_message = AsyncMock() + async def _queue_delete_messages( + chat_id: str, message_ids: list[str], *, fire_and_forget: bool = True + ) -> None: + qdm = platform.queue_delete_message + 
for mid in message_ids: + await qdm(chat_id, mid, fire_and_forget=fire_and_forget) + + platform.queue_delete_messages = AsyncMock(side_effect=_queue_delete_messages) + def _fire_and_forget(task): if asyncio.iscoroutine(task): # Create a task to avoid "coroutine was never awaited" warning diff --git a/tests/contracts/test_import_boundaries.py b/tests/contracts/test_import_boundaries.py index 0eed7d6..918d7d1 100644 --- a/tests/contracts/test_import_boundaries.py +++ b/tests/contracts/test_import_boundaries.py @@ -1,8 +1,20 @@ +"""Package import contract tests (static AST; dynamic ``importlib`` loads are not scanned).""" + from __future__ import annotations import ast from pathlib import Path +# `api` may only import this narrow ``providers`` surface (AGENTS/PLAN). +_API_ALLOWED_PROVIDER_MODULES = frozenset( + { + "providers", + "providers.base", + "providers.exceptions", + "providers.registry", + } +) + def test_api_and_messaging_do_not_import_provider_common() -> None: repo_root = Path(__file__).resolve().parents[2] @@ -25,6 +37,66 @@ def test_provider_adapters_do_not_import_runtime_layers() -> None: assert offenders == [] +def test_core_does_not_import_product_packages() -> None: + """Neutral ``core`` must stay independent of API, workers, and providers.""" + repo_root = Path(__file__).resolve().parents[2] + offenders = _imports_matching( + [repo_root / "core"], + forbidden_prefixes=( + "api.", + "messaging.", + "cli.", + "smoke.", + "providers.", + "config.", + ), + ) + assert offenders == [] + + +def test_config_does_not_import_non_config_packages() -> None: + """Settings and env handling must not depend on transport or protocol layers.""" + repo_root = Path(__file__).resolve().parents[2] + offenders = _imports_matching( + [repo_root / "config"], + forbidden_prefixes=( + "api.", + "messaging.", + "cli.", + "smoke.", + "providers.", + "core.", + ), + ) + assert offenders == [] + + +def test_messaging_does_not_import_api_or_cli_or_providers() -> None: + 
"""Messaging is wired by ``api.runtime``; must not import server or provider adapters.""" + repo_root = Path(__file__).resolve().parents[2] + offenders = _imports_matching( + [repo_root / "messaging"], + forbidden_prefixes=("api.", "cli.", "providers.", "smoke."), + ) + assert offenders == [] + + +def test_api_may_only_import_narrow_provider_facade() -> None: + """HTTP layer must not depend on per-adapter provider subpackages.""" + repo_root = Path(__file__).resolve().parents[2] + offenders: list[str] = [] + for path in (repo_root / "api").rglob("*.py"): + for imported in _imports_from(path, repo_root): + if imported is None or not imported.startswith("providers"): + continue + if imported in _API_ALLOWED_PROVIDER_MODULES: + continue + if imported.startswith("providers."): + rel = path.relative_to(repo_root) + offenders.append(f"{rel}: {imported}") + assert sorted(offenders) == [] + + def test_removed_openrouter_rollback_transport_stays_removed() -> None: repo_root = Path(__file__).resolve().parents[2] @@ -35,6 +107,11 @@ def test_removed_openrouter_rollback_transport_stays_removed() -> None: def test_architecture_doc_names_enforced_boundaries() -> None: repo_root = Path(__file__).resolve().parents[2] + contract_test = repo_root / "tests" / "contracts" / "test_import_boundaries.py" + assert contract_test.is_file() + stream_contracts = repo_root / "core" / "anthropic" / "stream_contracts.py" + assert stream_contracts.is_file() + text = (repo_root / "PLAN.md").read_text(encoding="utf-8") assert "core/anthropic/" in text @@ -46,26 +123,89 @@ def _imports_matching( roots: list[Path], *, forbidden_prefixes: tuple[str, ...] 
) -> list[str]: offenders: list[str] = [] + repo_root = roots[0].parent for root in roots: for path in root.rglob("*.py"): rel = path.relative_to(root.parent) offenders.extend( f"{rel}: {imported}" - for imported in _imports_from(path) - if imported in forbidden_prefixes - or imported.startswith(forbidden_prefixes) + for imported in _imports_from(path, repo_root) + if imported is not None and _is_forbidden(imported, forbidden_prefixes) ) return sorted(offenders) -def _imports_from(path: Path) -> list[str]: +def _is_forbidden(name: str, forbidden: tuple[str, ...]) -> bool: + """Match root modules (``import api``) and submodules (``import api.x``).""" + for token in forbidden: + if not token: + continue + root = token.rstrip(".") + if name == root or name.startswith(f"{root}."): + return True + return False + + +def _module_fqn_from_path(repo_root: Path, path: Path) -> str: + rel = path.relative_to(repo_root) + if rel.name == "__init__.py": + return ".".join(rel.parent.parts) if rel.parent != Path() else rel.parent.name + return ".".join(rel.with_suffix("").parts) + + +def _importing_package_parts(repo_root: Path, path: Path) -> list[str]: + """Package in which this file's module lives (for relative imports).""" + rel = path.relative_to(repo_root) + if rel.name == "__init__.py": + return list(rel.parent.parts) + fqn = _module_fqn_from_path(repo_root, path) + parts = fqn.split(".") + if len(parts) <= 1: + return [] + return parts[:-1] + + +def _resolve_relative_import( + repo_root: Path, path: Path, node: ast.ImportFrom +) -> str | None: + """Best-effort absolute name for ``from .x`` / ``from ..y`` (level >= 1).""" + if node.level == 0 and node.module: + return node.module + base = _importing_package_parts(repo_root, path) + for _ in range(node.level - 1): + if not base: + return None + base.pop() + if not node.module: + return ".".join(base) if base else None + return ".".join(base + node.module.split(".")) + + +def _imports_from(path: Path, repo_root: Path) -> 
list[str]: tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path)) imports: list[str] = [] for node in ast.walk(tree): if isinstance(node, ast.Import): imports.extend(alias.name for alias in node.names) - elif isinstance(node, ast.ImportFrom) and node.module: - imports.append(node.module) + elif isinstance(node, ast.ImportFrom): + if node.level == 0: + if node.module: + imports.append(node.module) + continue + if node.module is not None: + resolved = _resolve_relative_import(repo_root, path, node) + if resolved: + imports.append(resolved) + else: + base = _importing_package_parts(repo_root, path).copy() + for _ in range(node.level - 1): + if base: + base.pop() + for alias in node.names: + if base: + imports.append(".".join([*base, alias.name])) + else: + imports.append(alias.name) return imports diff --git a/tests/contracts/test_smoke_sse_reexport.py b/tests/contracts/test_smoke_sse_reexport.py new file mode 100644 index 0000000..4190312 --- /dev/null +++ b/tests/contracts/test_smoke_sse_reexport.py @@ -0,0 +1,11 @@ +"""Ensure smoke re-exports stay aligned with :mod:`core.anthropic.stream_contracts`.""" + +from __future__ import annotations + +import core.anthropic.stream_contracts as core_sc +import smoke.lib.sse as smoke_sse + + +def test_smoke_lib_sse_reexports_core_stream_contracts() -> None: + for name in smoke_sse.__all__: + assert getattr(smoke_sse, name) is getattr(core_sc, name) diff --git a/tests/contracts/test_stream_contracts.py b/tests/contracts/test_stream_contracts.py index f6e964f..0125459 100644 --- a/tests/contracts/test_stream_contracts.py +++ b/tests/contracts/test_stream_contracts.py @@ -1,11 +1,14 @@ +"""Stream/SSE contract tests. Strict transcript *ordering* is covered here for +``SSEBuilder`` output; for transport-integrated ordering, add messaging or API +integration tests. 
+""" + from __future__ import annotations from collections.abc import Iterable from core.anthropic import ContentType, HeuristicToolParser, SSEBuilder, ThinkTagParser -from messaging.event_parser import parse_cli_event -from messaging.transcript import RenderCtx, TranscriptBuffer -from smoke.lib.sse import ( +from core.anthropic.stream_contracts import ( assert_anthropic_stream_contract, event_names, has_tool_use, @@ -13,6 +16,8 @@ from smoke.lib.sse import ( text_content, thinking_content, ) +from messaging.event_parser import parse_cli_event +from messaging.transcript import RenderCtx, TranscriptBuffer def test_interleaved_thinking_text_blocks_are_valid() -> None: diff --git a/tests/messaging/test_messaging_factory.py b/tests/messaging/test_messaging_factory.py index 9078e65..5d3c0a3 100644 --- a/tests/messaging/test_messaging_factory.py +++ b/tests/messaging/test_messaging_factory.py @@ -2,7 +2,10 @@ from unittest.mock import MagicMock, patch -from messaging.platforms.factory import create_messaging_platform +from messaging.platforms.factory import ( + MessagingPlatformOptions, + create_messaging_platform, +) class TestCreateMessagingPlatform: @@ -16,15 +19,29 @@ class TestCreateMessagingPlatform: patch( "messaging.platforms.telegram.TelegramPlatform", return_value=mock_platform, - ), + ) as platform_cls, ): result = create_messaging_platform( "telegram", - bot_token="test_token", - allowed_user_id="12345", + MessagingPlatformOptions( + telegram_bot_token="test_token", + allowed_telegram_user_id="12345", + voice_note_enabled=False, + whisper_model="large-v3", + whisper_device="cuda", + ), ) assert result is mock_platform + platform_cls.assert_called_once_with( + bot_token="test_token", + allowed_user_id="12345", + voice_note_enabled=False, + whisper_model="large-v3", + whisper_device="cuda", + hf_token="", + nvidia_nim_api_key="", + ) def test_telegram_without_token(self): """Return None when no bot_token for Telegram.""" @@ -33,7 +50,9 @@ class 
TestCreateMessagingPlatform: def test_telegram_empty_token(self): """Return None when bot_token is empty string.""" - result = create_messaging_platform("telegram", bot_token="") + result = create_messaging_platform( + "telegram", MessagingPlatformOptions(telegram_bot_token="") + ) assert result is None def test_discord_with_token(self): @@ -44,15 +63,29 @@ class TestCreateMessagingPlatform: patch( "messaging.platforms.discord.DiscordPlatform", return_value=mock_platform, - ), + ) as platform_cls, ): result = create_messaging_platform( "discord", - discord_bot_token="test_token", - allowed_discord_channels="123,456", + MessagingPlatformOptions( + discord_bot_token="test_token", + allowed_discord_channels="123,456", + voice_note_enabled=False, + whisper_model="small", + whisper_device="nvidia_nim", + ), ) assert result is mock_platform + platform_cls.assert_called_once_with( + bot_token="test_token", + allowed_channel_ids="123,456", + voice_note_enabled=False, + whisper_model="small", + whisper_device="nvidia_nim", + hf_token="", + nvidia_nim_api_key="", + ) def test_discord_without_token(self): """Return None when no discord_bot_token for Discord.""" @@ -62,7 +95,11 @@ class TestCreateMessagingPlatform: def test_discord_empty_token(self): """Return None when discord_bot_token is empty string.""" result = create_messaging_platform( - "discord", discord_bot_token="", allowed_discord_channels="123" + "discord", + MessagingPlatformOptions( + discord_bot_token="", + allowed_discord_channels="123", + ), ) assert result is None @@ -73,5 +110,7 @@ class TestCreateMessagingPlatform: def test_unknown_platform_with_kwargs(self): """Return None for unknown platform even with kwargs.""" - result = create_messaging_platform("slack", bot_token="token") + result = create_messaging_platform( + "slack", MessagingPlatformOptions(telegram_bot_token="token") + ) assert result is None diff --git a/tests/messaging/test_voice_handlers.py b/tests/messaging/test_voice_handlers.py index 
b2537c8..1228507 100644 --- a/tests/messaging/test_voice_handlers.py +++ b/tests/messaging/test_voice_handlers.py @@ -17,19 +17,21 @@ def telegram_platform(): @pytest.mark.asyncio -async def test_telegram_voice_disabled_sends_reply(telegram_platform): +async def test_telegram_voice_disabled_sends_reply(): """When voice_note_enabled is False, reply with disabled message.""" + with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True): + telegram_platform = TelegramPlatform( + bot_token="test_token", + allowed_user_id="12345", + voice_note_enabled=False, + ) mock_update = MagicMock() mock_update.message.voice = MagicMock(file_id="f1", mime_type="audio/ogg") mock_update.effective_user.id = 12345 mock_update.effective_chat.id = 6789 mock_update.message.reply_text = AsyncMock() - with patch( - "config.settings.get_settings", - return_value=MagicMock(voice_note_enabled=False), - ): - await telegram_platform._on_telegram_voice(mock_update, MagicMock()) + await telegram_platform._on_telegram_voice(mock_update, MagicMock()) mock_update.message.reply_text.assert_called_once_with("Voice notes are disabled.") @@ -42,11 +44,7 @@ async def test_telegram_voice_unauthorized_ignored(telegram_platform): mock_update.effective_user.id = 99999 # Not 12345 mock_update.message.reply_text = AsyncMock() - with patch( - "config.settings.get_settings", - return_value=MagicMock(voice_note_enabled=True), - ): - await telegram_platform._on_telegram_voice(mock_update, MagicMock()) + await telegram_platform._on_telegram_voice(mock_update, MagicMock()) mock_update.message.reply_text.assert_not_called() @@ -82,17 +80,8 @@ async def test_telegram_voice_success_invokes_handler(telegram_platform): mock_file.download_to_drive = fake_download - mock_settings = MagicMock( - voice_note_enabled=True, - whisper_model="base", - ) - mock_queue_send = AsyncMock(return_value="999") with ( - patch( - "config.settings.get_settings", - return_value=mock_settings, - ), patch( 
"messaging.transcription.transcribe_audio", return_value="Hello from voice", @@ -164,7 +153,11 @@ class TestDiscordGetAudioAttachment: @pytest.mark.asyncio async def test_discord_voice_disabled_sends_reply(): """When voice_note_enabled is False, reply with disabled message.""" - platform = DiscordPlatform(bot_token="token", allowed_channel_ids="123") + platform = DiscordPlatform( + bot_token="token", + allowed_channel_ids="123", + voice_note_enabled=False, + ) platform._message_handler = None mock_message = MagicMock() @@ -178,10 +171,6 @@ async def test_discord_voice_disabled_sends_reply(): mock_att.filename = "voice.ogg" mock_message.attachments = [mock_att] - with patch( - "config.settings.get_settings", - return_value=MagicMock(voice_note_enabled=False), - ): - await platform._on_discord_message(mock_message) + await platform._on_discord_message(mock_message) mock_message.reply.assert_called_once_with("Voice notes are disabled.") diff --git a/tests/providers/test_converter.py b/tests/providers/test_converter.py index ae5ec6c..7312fbd 100644 --- a/tests/providers/test_converter.py +++ b/tests/providers/test_converter.py @@ -24,7 +24,7 @@ class MockBlock: class MockTool: - def __init__(self, name, description, input_schema): + def __init__(self, name, description, input_schema=None): self.name = name self.description = description self.input_schema = input_schema @@ -79,6 +79,23 @@ def test_convert_tools(): assert result[1]["function"]["description"] == "" # Check default empty string +def test_convert_tool_without_input_schema_uses_empty_object_schema(): + tools = [MockTool("web_search", None)] + + result = AnthropicToOpenAIConverter.convert_tools(tools) + + assert result == [ + { + "type": "function", + "function": { + "name": "web_search", + "description": "", + "parameters": {"type": "object", "properties": {}}, + }, + } + ] + + @pytest.mark.parametrize( "tool_choice,expected", [ diff --git a/tests/providers/test_parsers.py b/tests/providers/test_parsers.py 
index 0b36dc7..c351a7a 100644 --- a/tests/providers/test_parsers.py +++ b/tests/providers/test_parsers.py @@ -449,6 +449,40 @@ def test_heuristic_tool_parser_flush_no_tool(): assert tools == [] +def test_heuristic_tool_parser_json_style_web_fetch_tool_call(): + parser = HeuristicToolParser() + text = ( + "Use WebFetch on the article.\n\n" + "{\n" + ' "url": "https://example.com/article",\n' + ' "prompt": "Summarize it."\n' + "}\n" + ) + + filtered, tools = parser.feed(text) + tools.extend(parser.flush()) + + assert filtered == "" + assert len(tools) == 1 + assert tools[0]["name"] == "WebFetch" + assert tools[0]["input"] == { + "url": "https://example.com/article", + "prompt": "Summarize it.", + } + + +def test_heuristic_tool_parser_json_style_web_search_tool_call(): + parser = HeuristicToolParser() + + filtered, tools = parser.feed('Use WebSearch {"query": "DeepSeek V4"}') + tools.extend(parser.flush()) + + assert filtered == "" + assert len(tools) == 1 + assert tools[0]["name"] == "WebSearch" + assert tools[0]["input"] == {"query": "DeepSeek V4"} + + def test_heuristic_tool_parser_unicode_function_name(): """Unicode characters in function parameters.""" parser = HeuristicToolParser() diff --git a/tests/providers/test_registry.py b/tests/providers/test_registry.py index 72b4714..f30f74b 100644 --- a/tests/providers/test_registry.py +++ b/tests/providers/test_registry.py @@ -1,9 +1,13 @@ -from unittest.mock import MagicMock, patch +import subprocess +import sys +from unittest.mock import AsyncMock, MagicMock, patch import pytest from config.nim import NimSettings +from config.provider_ids import SUPPORTED_PROVIDER_IDS from providers.deepseek import DeepSeekProvider +from providers.exceptions import UnknownProviderTypeError from providers.llamacpp import LlamaCppProvider from providers.lmstudio import LMStudioProvider from providers.nvidia_nim import NvidiaNimProvider @@ -41,14 +45,24 @@ def _make_settings(**overrides): return mock +def 
test_importing_registry_does_not_eager_load_other_adapters() -> None: + """Registry metadata must not import every provider adapter up front.""" + code = ( + "import sys\n" + "import providers.registry\n" + "assert 'providers.open_router' not in sys.modules\n" + ) + proc = subprocess.run( + [sys.executable, "-c", code], + check=False, + capture_output=True, + text=True, + ) + assert proc.returncode == 0, proc.stderr or proc.stdout + + def test_descriptors_cover_advertised_provider_ids(): - assert set(PROVIDER_DESCRIPTORS) == { - "nvidia_nim", - "open_router", - "deepseek", - "lmstudio", - "llamacpp", - } + assert set(PROVIDER_DESCRIPTORS) == set(SUPPORTED_PROVIDER_IDS) for descriptor in PROVIDER_DESCRIPTORS.values(): assert descriptor.provider_id assert descriptor.transport_type in {"openai_chat", "anthropic_messages"} @@ -90,6 +104,38 @@ def test_provider_registry_caches_by_provider_id(): assert first is second -def test_unknown_provider_raises_value_error(): - with pytest.raises(ValueError, match="Unknown provider_type"): +def test_unknown_provider_raises_unknown_provider_type_error(): + with pytest.raises(UnknownProviderTypeError, match="Unknown provider_type"): create_provider("unknown", _make_settings()) + + +@pytest.mark.asyncio +async def test_provider_registry_cleanup_runs_all_even_if_one_fails() -> None: + """Every provider gets cleanup; cache is cleared even when one raises.""" + reg = ProviderRegistry() + p1 = MagicMock() + p1.cleanup = AsyncMock(side_effect=RuntimeError("first")) + p2 = MagicMock() + p2.cleanup = AsyncMock() + reg._providers["a"] = p1 + reg._providers["b"] = p2 + with pytest.raises(RuntimeError, match="first"): + await reg.cleanup() + p1.cleanup.assert_awaited_once() + p2.cleanup.assert_awaited_once() + assert reg._providers == {} + + +@pytest.mark.asyncio +async def test_provider_registry_cleanup_exceptiongroup_on_multiple_failures() -> None: + reg = ProviderRegistry() + p1 = MagicMock() + p1.cleanup = 
AsyncMock(side_effect=RuntimeError("a")) + p2 = MagicMock() + p2.cleanup = AsyncMock(side_effect=RuntimeError("b")) + reg._providers["x"] = p1 + reg._providers["y"] = p2 + with pytest.raises(ExceptionGroup) as exc_info: + await reg.cleanup() + assert len(exc_info.value.exceptions) == 2 + assert reg._providers == {}