Mirror of https://github.com/Alishahryar1/free-claude-code.git (synced 2026-04-26 10:31:07 +00:00)
feat: Anthropic web server tools, provider metadata, messaging hardening

- Add local web_search/web_fetch SSE handling and optional tool schemas
- Extend HeuristicToolParser for JSON-style WebFetch/WebSearch text
- Consolidate provider defaults, ids, and exception typing; stream contracts
- Messaging: typed options, voice config injection, platform contract cleanup
- Tests for web server tools, converters, parsers, contracts; ignore debug-*.log
This commit is contained in:
parent 4b89183ba0
commit b926f60f64

50 changed files with 1658 additions and 439 deletions
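As a quick illustration of the new behavior described above, a Messages request that declares the `web_search` server tool is now answered locally with an SSE stream instead of being forwarded upstream as a plain function call. The endpoint path, port, and model id below are assumptions for illustration; only the tool shape and the `query:` convention come from this diff.

```python
# Hypothetical client call against the proxy; URL, port, and model id are placeholders.
import httpx

payload = {
    "model": "claude-sonnet-4-20250514",  # any model id the proxy routes
    "max_tokens": 1024,
    "stream": True,
    "messages": [{"role": "user", "content": "query: latest httpx release"}],
    # Server tool entry: carries a `type` and may omit `input_schema` (see the Tool model change below).
    "tools": [{"type": "web_search_20250305", "name": "web_search"}],
}

with httpx.stream("POST", "http://localhost:8000/v1/messages", json=payload, timeout=60) as response:
    for line in response.iter_lines():
        print(line)  # message_start, server_tool_use, web_search_tool_result, text summary, message_stop
```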
@@ -66,6 +6,7 @@ VOICE_NOTE_ENABLED=false
# WHISPER_DEVICE: "cpu" | "cuda" | "nvidia_nim"
# - "cpu"/"cuda": Hugging Face transformers Whisper (offline, free; install with: uv sync --extra voice_local)
# - "nvidia_nim": NVIDIA NIM Whisper via Riva gRPC (requires NVIDIA_NIM_API_KEY; install with: uv sync --extra voice)
# (Independent of MODEL=nvidia_nim/...: that selects the *chat* provider; this selects voice STT only.)
WHISPER_DEVICE="nvidia_nim"
# WHISPER_MODEL:
# - For cpu/cuda: Hugging Face ID or short name (tiny, base, small, medium, large-v2, large-v3, large-v3-turbo)
1 .gitignore (vendored)
@@ -8,6 +8,7 @@ __pycache__
agent_workspace
.env
server.log
debug-*.log
.coverage
llama_cache
.smoke-results
27 PLAN.md
@@ -33,20 +33,43 @@ flowchart TD
core --> providers
core --> messaging
providers --> api
api --> cli[cli]
api --> messaging
cli --> messaging
messaging --> api
```

Runtime note: `api.runtime` imports `cli` and `messaging` to wire the optional
messaging stack; `messaging` does not import `cli` (session/CLI access is passed
in from `api.runtime`).

The practical rule is simpler than the graph: shared protocol helpers belong in
neutral core modules, not under a provider package. Provider adapters may depend
on the neutral protocol layer, but API and messaging code should not import
provider internals.

The diagram above mixes **Python import direction** (e.g. `config` → `providers`)
with **runtime composition** (e.g. `api.runtime` constructs `cli` and `messaging`).
`PLAN.md` remains the product map; **encoded** rules (including root imports like
`import api`, relative imports, and `api` → `providers` facade allowlists) live in
`tests/contracts/test_import_boundaries.py`.

**Contract highlights:** `api/` may import only `providers.base`, `providers.exceptions`,
and `providers.registry` from the providers package (not per-adapter modules).
`core/` stays free of `api`, `messaging`, `cli`, `providers`, `config`, and `smoke`.
`messaging/` does not import `api`, `cli`, or `providers`. Neutral stream contract
assertions for default CI live under `core/anthropic/stream_contracts.py`;
`smoke.lib.sse` re-exports them for live smoke. Process-cached provider helpers
(`api.dependencies.get_provider` / `get_provider_for_type`) exist for scripts and
unit tests; production HTTP handlers must use `resolve_provider` with
`request.app` so the app-scoped `ProviderRegistry` is used. The `api` package
`__all__` exposes HTTP models and `create_app` only (not those helpers).

## Target Boundaries

- `core/anthropic/`: Anthropic protocol helpers, stream primitives, content
  extraction, token estimation, user-facing error strings, request conversion,
  thinking, and tool helpers shared across API, providers, messaging, and tests.
  thinking, tool helpers, and stream contract assertions
  (`stream_contracts.py`) shared across API, providers, messaging, and tests.
- `api/runtime.py`: application composition, optional messaging startup,
  session store restoration, and cleanup ownership.
- `providers/`: provider descriptors, credential resolution, transport
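The import rules referenced in the contract highlights live in `tests/contracts/test_import_boundaries.py`, which is not part of this diff. A minimal sketch of what such a check could look like, using only the boundaries named above (the helper and its AST walk are assumptions, not the project's actual test):

```python
# Sketch only; the real encoded rules are in tests/contracts/test_import_boundaries.py.
import ast
from pathlib import Path

FORBIDDEN_IMPORTS = {
    "core": {"api", "messaging", "cli", "providers", "config", "smoke"},
    "messaging": {"api", "cli", "providers"},
}

def _imported_roots(path: Path) -> set[str]:
    """Top-level package names imported by a module (absolute imports only)."""
    roots: set[str] = set()
    for node in ast.walk(ast.parse(path.read_text())):
        if isinstance(node, ast.Import):
            roots.update(alias.name.split(".")[0] for alias in node.names)
        elif isinstance(node, ast.ImportFrom) and node.level == 0 and node.module:
            roots.add(node.module.split(".")[0])
    return roots

def test_package_boundaries() -> None:
    for package, banned in FORBIDDEN_IMPORTS.items():
        for module in Path(package).rglob("*.py"):
            crossed = _imported_roots(module) & banned
            assert not crossed, f"{module} imports {sorted(crossed)}"
```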
@@ -1,7 +1,6 @@
"""API layer for Claude Code Proxy."""

from .app import app, create_app
from .dependencies import get_provider, get_provider_for_type
from .models import (
    MessagesRequest,
    MessagesResponse,

@@ -16,6 +15,4 @@ __all__ = [
    "TokenCountResponse",
    "app",
    "create_app",
    "get_provider",
    "get_provider_for_type",
]
62 api/app.py
@ -2,8 +2,11 @@
|
|||
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.exception_handlers import request_validation_exception_handler
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from loguru import logger
|
||||
|
||||
|
|
@ -11,7 +14,6 @@ from config.logging_config import configure_logging
|
|||
from config.settings import get_settings
|
||||
from providers.exceptions import ProviderError
|
||||
|
||||
from .dependencies import cleanup_provider
|
||||
from .routes import router
|
||||
from .runtime import AppRuntime
|
||||
|
||||
|
|
@ -26,9 +28,7 @@ configure_logging(_settings.log_file)
|
|||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Application lifespan manager."""
|
||||
runtime = AppRuntime.for_app(
|
||||
app, settings=get_settings(), provider_cleanup=cleanup_provider
|
||||
)
|
||||
runtime = AppRuntime.for_app(app, settings=get_settings())
|
||||
await runtime.startup()
|
||||
|
||||
yield
|
||||
|
|
@ -48,6 +48,60 @@ def create_app() -> FastAPI:
|
|||
app.include_router(router)
|
||||
|
||||
# Exception handlers
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_error_handler(request: Request, exc: RequestValidationError):
|
||||
"""Log request shape for 422 debugging without content values."""
|
||||
body: Any
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception as e:
|
||||
body = {"_json_error": type(e).__name__}
|
||||
|
||||
messages = body.get("messages") if isinstance(body, dict) else None
|
||||
message_summary: list[dict[str, Any]] = []
|
||||
if isinstance(messages, list):
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
message_summary.append({"message_kind": type(msg).__name__})
|
||||
continue
|
||||
content = msg.get("content")
|
||||
item: dict[str, Any] = {
|
||||
"role": msg.get("role"),
|
||||
"content_kind": type(content).__name__,
|
||||
}
|
||||
if isinstance(content, list):
|
||||
item["block_types"] = [
|
||||
block.get("type", "dict")
|
||||
if isinstance(block, dict)
|
||||
else type(block).__name__
|
||||
for block in content[:12]
|
||||
]
|
||||
item["block_keys"] = [
|
||||
sorted(str(key) for key in block)[:12]
|
||||
for block in content[:5]
|
||||
if isinstance(block, dict)
|
||||
]
|
||||
elif isinstance(content, str):
|
||||
item["content_length"] = len(content)
|
||||
message_summary.append(item)
|
||||
|
||||
logger.debug(
|
||||
"Request validation failed: path={} query={} error_locs={} error_types={} message_summary={} tool_names={}",
|
||||
request.url.path,
|
||||
str(request.url.query),
|
||||
[list(error.get("loc", ())) for error in exc.errors()],
|
||||
[str(error.get("type", "")) for error in exc.errors()],
|
||||
message_summary,
|
||||
[
|
||||
str(tool.get("name", ""))
|
||||
for tool in body.get("tools", [])
|
||||
if isinstance(body, dict)
|
||||
and isinstance(body.get("tools"), list)
|
||||
and isinstance(tool, dict)
|
||||
],
|
||||
)
|
||||
return await request_validation_exception_handler(request, exc)
|
||||
|
||||
@app.exception_handler(ProviderError)
|
||||
async def provider_error_handler(request: Request, exc: ProviderError):
|
||||
"""Handle provider-specific errors and return Anthropic format."""
|
||||
|
|
|
|||
|
|
@ -2,15 +2,18 @@
|
|||
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
from loguru import logger
|
||||
from starlette.applications import Starlette
|
||||
|
||||
from config.settings import Settings
|
||||
from config.settings import get_settings as _get_settings
|
||||
from core.anthropic import get_user_facing_error_message
|
||||
from providers.base import BaseProvider
|
||||
from providers.exceptions import AuthenticationError
|
||||
from providers.exceptions import AuthenticationError, UnknownProviderTypeError
|
||||
from providers.registry import PROVIDER_DESCRIPTORS, ProviderRegistry
|
||||
|
||||
# Provider registry: keyed by provider type string, lazily populated
|
||||
# Process-level cache: only for :func:`get_provider_for_type` / :func:`get_provider`
|
||||
# when there is no ``Request``/``app`` (unit tests, scripts). HTTP handlers must pass
|
||||
# ``app`` to :func:`resolve_provider` so the app-scoped registry is used.
|
||||
_providers: dict[str, BaseProvider] = {}
|
||||
|
||||
|
||||
|
|
@ -19,19 +22,43 @@ def get_settings() -> Settings:
|
|||
return _get_settings()
|
||||
|
||||
|
||||
def get_provider_for_type(provider_type: str) -> BaseProvider:
|
||||
"""Get or create a provider for the given provider type.
|
||||
def resolve_provider(
|
||||
provider_type: str,
|
||||
*,
|
||||
app: Starlette | None,
|
||||
settings: Settings,
|
||||
) -> BaseProvider:
|
||||
"""Resolve a provider using the app-scoped registry when ``app`` is set.
|
||||
|
||||
Providers are cached in the registry and reused across requests.
|
||||
When ``app`` is not ``None``, the app-owned :attr:`app.state.provider_registry`
|
||||
is always used. If the registry is missing (e.g. a test app without
|
||||
:class:`~api.runtime.AppRuntime` startup), a new :class:`ProviderRegistry`
|
||||
is installed on ``app.state`` so the process cache is never mixed with
|
||||
per-request app identity.
|
||||
|
||||
When ``app`` is ``None`` (no HTTP context), uses the process-level
|
||||
:data:`_providers` cache only.
|
||||
"""
|
||||
should_log_init = provider_type not in _providers
|
||||
if app is not None:
|
||||
reg = getattr(app.state, "provider_registry", None)
|
||||
if reg is None:
|
||||
reg = ProviderRegistry()
|
||||
app.state.provider_registry = reg
|
||||
return _resolve_with_registry(reg, provider_type, settings)
|
||||
return _resolve_with_registry(ProviderRegistry(_providers), provider_type, settings)
|
||||
|
||||
|
||||
def _resolve_with_registry(
|
||||
registry: ProviderRegistry, provider_type: str, settings: Settings
|
||||
) -> BaseProvider:
|
||||
should_log_init = not registry.is_cached(provider_type)
|
||||
try:
|
||||
provider = ProviderRegistry(_providers).get(provider_type, get_settings())
|
||||
provider = registry.get(provider_type, settings)
|
||||
except AuthenticationError as e:
|
||||
raise HTTPException(
|
||||
status_code=503, detail=get_user_facing_error_message(e)
|
||||
) from e
|
||||
except ValueError:
|
||||
except UnknownProviderTypeError:
|
||||
logger.error(
|
||||
"Unknown provider_type: '{}'. Supported: {}",
|
||||
provider_type,
|
||||
|
|
@ -43,6 +70,15 @@ def get_provider_for_type(provider_type: str) -> BaseProvider:
|
|||
return provider
|
||||
|
||||
|
||||
def get_provider_for_type(provider_type: str) -> BaseProvider:
|
||||
"""Get or create a provider in the process-level cache (no ``app``/Request).
|
||||
|
||||
For server requests, use :func:`resolve_provider` with the active
|
||||
:attr:`request.app` so the app-scoped provider registry is used.
|
||||
"""
|
||||
return resolve_provider(provider_type, app=None, settings=get_settings())
|
||||
|
||||
|
||||
def require_api_key(
|
||||
request: Request, settings: Settings = Depends(get_settings)
|
||||
) -> None:
|
||||
|
|
@ -78,9 +114,11 @@ def require_api_key(
|
|||
|
||||
|
||||
def get_provider() -> BaseProvider:
|
||||
"""Get or create the default provider (based on MODEL env var).
|
||||
"""Get or create the default provider (``MODEL`` / ``provider_type``).
|
||||
|
||||
Backward-compatible convenience for health/root endpoints and tests.
|
||||
Process-cache helper for scripts, unit tests, and non-FastAPI callers. HTTP
|
||||
handlers must use :func:`resolve_provider` with :attr:`request.app` so the
|
||||
app-scoped :class:`~providers.registry.ProviderRegistry` is used.
|
||||
"""
|
||||
return get_provider_for_type(get_settings().provider_type)
|
||||
|
||||
|
|
|
|||
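Taken together, the two entry points in `api/dependencies.py` now split cleanly by call site. A short usage sketch (the surrounding handler and the `open_router` id are illustrative; the signatures are the ones added in this diff):

```python
# Illustrative call sites only; signatures match api/dependencies.py above.
from fastapi import Request

from api import dependencies
from config.settings import get_settings

def inside_a_route_handler(request: Request):
    # HTTP path: always resolve through the app-scoped ProviderRegistry.
    settings = get_settings()
    return dependencies.resolve_provider(
        settings.provider_type, app=request.app, settings=settings
    )

# Script / unit-test path: no app available, so the process-level cache is acceptable.
provider = dependencies.get_provider_for_type("open_router")
```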
|
|
@ -75,8 +75,11 @@ class Message(BaseModel):
|
|||
|
||||
class Tool(BaseModel):
|
||||
name: str
|
||||
# Anthropic server tools (e.g. web_search beta tools) include a ``type`` and
|
||||
# may omit ``input_schema`` because the provider owns the schema.
|
||||
type: str | None = None
|
||||
description: str | None = None
|
||||
input_schema: dict[str, Any]
|
||||
input_schema: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ThinkingConfig(BaseModel):
|
||||
|
|
|
|||
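With `type` added and `input_schema` made optional, both regular client-defined tools and Anthropic server tools validate against the same model. Field values below are illustrative; the import path is assumed from the other imports in this diff:

```python
# Assumed module path; Tool is the model shown in the hunk above.
from api.models.anthropic import Tool

regular_tool = Tool(
    name="get_weather",
    description="Look up current weather",
    input_schema={"type": "object", "properties": {"city": {"type": "string"}}},
)

server_tool = Tool(
    type="web_search_20250305",  # Anthropic server tool: the provider owns the schema,
    name="web_search",           # so no input_schema is required any more
)
```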
|
|
@ -6,7 +6,8 @@ from loguru import logger
|
|||
from config.settings import Settings
|
||||
from core.anthropic import get_token_count
|
||||
|
||||
from .dependencies import get_provider_for_type, get_settings, require_api_key
|
||||
from . import dependencies
|
||||
from .dependencies import get_settings, require_api_key
|
||||
from .models.anthropic import MessagesRequest, TokenCountRequest
|
||||
from .models.responses import ModelResponse, ModelsListResponse
|
||||
from .services import ClaudeProxyService
|
||||
|
|
@ -54,12 +55,15 @@ SUPPORTED_CLAUDE_MODELS = [
|
|||
|
||||
|
||||
def get_proxy_service(
|
||||
request: Request,
|
||||
settings: Settings = Depends(get_settings),
|
||||
) -> ClaudeProxyService:
|
||||
"""Build the request service for route handlers."""
|
||||
return ClaudeProxyService(
|
||||
settings,
|
||||
provider_getter=get_provider_for_type,
|
||||
provider_getter=lambda provider_type: dependencies.resolve_provider(
|
||||
provider_type, app=request.app, settings=settings
|
||||
),
|
||||
token_counter=get_token_count,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,16 +4,20 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
import os
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from fastapi import FastAPI
|
||||
from loguru import logger
|
||||
|
||||
from config.settings import Settings, get_settings
|
||||
from providers.registry import ProviderRegistry
|
||||
|
||||
from .dependencies import cleanup_provider
|
||||
if TYPE_CHECKING:
|
||||
from cli.manager import CLISessionManager
|
||||
from messaging.handler import ClaudeMessageHandler
|
||||
from messaging.platforms.base import MessagingPlatform
|
||||
from messaging.session import SessionStore
|
||||
|
||||
_SHUTDOWN_TIMEOUT_S = 5.0
|
||||
|
||||
|
|
@ -32,8 +36,7 @@ async def best_effort(
|
|||
|
||||
def warn_if_process_auth_token(settings: Settings) -> None:
|
||||
"""Warn when server auth was implicitly inherited from the shell."""
|
||||
uses_process_token = getattr(settings, "uses_process_anthropic_auth_token", None)
|
||||
if callable(uses_process_token) and uses_process_token():
|
||||
if settings.uses_process_anthropic_auth_token():
|
||||
logger.warning(
|
||||
"ANTHROPIC_AUTH_TOKEN is set in the process environment but not in "
|
||||
"a configured .env file. The proxy will require that token. Add "
|
||||
|
|
@ -48,32 +51,29 @@ class AppRuntime:
|
|||
|
||||
app: FastAPI
|
||||
settings: Settings
|
||||
provider_cleanup: Callable[[], Awaitable[None]] = cleanup_provider
|
||||
messaging_platform: Any = None
|
||||
message_handler: Any = None
|
||||
cli_manager: Any = None
|
||||
_provider_registry: ProviderRegistry | None = field(default=None, init=False)
|
||||
messaging_platform: MessagingPlatform | None = None
|
||||
message_handler: ClaudeMessageHandler | None = None
|
||||
cli_manager: CLISessionManager | None = None
|
||||
|
||||
@classmethod
|
||||
def for_app(
|
||||
cls,
|
||||
app: FastAPI,
|
||||
settings: Settings | None = None,
|
||||
provider_cleanup: Callable[[], Awaitable[None]] = cleanup_provider,
|
||||
) -> AppRuntime:
|
||||
return cls(
|
||||
app=app,
|
||||
settings=settings or get_settings(),
|
||||
provider_cleanup=provider_cleanup,
|
||||
)
|
||||
return cls(app=app, settings=settings or get_settings())
|
||||
|
||||
async def startup(self) -> None:
|
||||
logger.info("Starting Claude Code Proxy...")
|
||||
self._provider_registry = ProviderRegistry()
|
||||
self.app.state.provider_registry = self._provider_registry
|
||||
warn_if_process_auth_token(self.settings)
|
||||
await self._start_messaging_if_configured()
|
||||
self._publish_state()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self.message_handler and hasattr(self.message_handler, "session_store"):
|
||||
if self.message_handler is not None:
|
||||
try:
|
||||
self.message_handler.session_store.flush_pending_save()
|
||||
except Exception as e:
|
||||
|
|
@ -84,20 +84,33 @@ class AppRuntime:
|
|||
await best_effort("messaging_platform.stop", self.messaging_platform.stop())
|
||||
if self.cli_manager:
|
||||
await best_effort("cli_manager.stop_all", self.cli_manager.stop_all())
|
||||
await best_effort("cleanup_provider", self.provider_cleanup())
|
||||
if self._provider_registry is not None:
|
||||
await best_effort(
|
||||
"provider_registry.cleanup", self._provider_registry.cleanup()
|
||||
)
|
||||
await self._shutdown_limiter()
|
||||
logger.info("Server shut down cleanly")
|
||||
|
||||
async def _start_messaging_if_configured(self) -> None:
|
||||
try:
|
||||
from messaging.platforms.factory import create_messaging_platform
|
||||
from messaging.platforms.factory import (
|
||||
MessagingPlatformOptions,
|
||||
create_messaging_platform,
|
||||
)
|
||||
|
||||
self.messaging_platform = create_messaging_platform(
|
||||
platform_type=self.settings.messaging_platform,
|
||||
bot_token=self.settings.telegram_bot_token,
|
||||
allowed_user_id=self.settings.allowed_telegram_user_id,
|
||||
self.settings.messaging_platform,
|
||||
MessagingPlatformOptions(
|
||||
telegram_bot_token=self.settings.telegram_bot_token,
|
||||
allowed_telegram_user_id=self.settings.allowed_telegram_user_id,
|
||||
discord_bot_token=self.settings.discord_bot_token,
|
||||
allowed_discord_channels=self.settings.allowed_discord_channels,
|
||||
voice_note_enabled=self.settings.voice_note_enabled,
|
||||
whisper_model=self.settings.whisper_model,
|
||||
whisper_device=self.settings.whisper_device,
|
||||
hf_token=self.settings.hf_token,
|
||||
nvidia_nim_api_key=self.settings.nvidia_nim_api_key,
|
||||
),
|
||||
)
|
||||
|
||||
if self.messaging_platform:
|
||||
|
|
@ -137,29 +150,31 @@ class AppRuntime:
|
|||
api_url=api_url,
|
||||
allowed_dirs=allowed_dirs,
|
||||
plans_directory=plans_directory,
|
||||
claude_bin=getattr(self.settings, "claude_cli_bin", "claude"),
|
||||
claude_bin=self.settings.claude_cli_bin,
|
||||
)
|
||||
|
||||
session_store = SessionStore(
|
||||
storage_path=os.path.join(data_path, "sessions.json")
|
||||
)
|
||||
platform = self.messaging_platform
|
||||
assert platform is not None
|
||||
self.message_handler = ClaudeMessageHandler(
|
||||
platform=self.messaging_platform,
|
||||
platform=platform,
|
||||
cli_manager=self.cli_manager,
|
||||
session_store=session_store,
|
||||
)
|
||||
self._restore_tree_state(session_store)
|
||||
|
||||
self.messaging_platform.on_message(self.message_handler.handle_message)
|
||||
await self.messaging_platform.start()
|
||||
logger.info(
|
||||
f"{self.messaging_platform.name} platform started with message handler"
|
||||
)
|
||||
platform.on_message(self.message_handler.handle_message)
|
||||
await platform.start()
|
||||
logger.info(f"{platform.name} platform started with message handler")
|
||||
|
||||
def _restore_tree_state(self, session_store: Any) -> None:
|
||||
def _restore_tree_state(self, session_store: SessionStore) -> None:
|
||||
saved_trees = session_store.get_all_trees()
|
||||
if not saved_trees:
|
||||
return
|
||||
if self.message_handler is None:
|
||||
return
|
||||
|
||||
logger.info(f"Restoring {len(saved_trees)} conversation trees...")
|
||||
from messaging.trees.queue_manager import TreeQueueManager
|
||||
|
|
@ -188,11 +203,16 @@ class AppRuntime:
|
|||
async def _shutdown_limiter(self) -> None:
|
||||
try:
|
||||
from messaging.limiter import MessagingRateLimiter
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Rate limiter shutdown skipped (import failed): {}: {}",
|
||||
type(e).__name__,
|
||||
e,
|
||||
)
|
||||
return
|
||||
|
||||
await best_effort(
|
||||
"MessagingRateLimiter.shutdown_instance",
|
||||
MessagingRateLimiter.shutdown_instance(),
|
||||
timeout_s=2.0,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -20,6 +20,10 @@ from .model_router import ModelRouter
|
|||
from .models.anthropic import MessagesRequest, TokenCountRequest
|
||||
from .models.responses import TokenCountResponse
|
||||
from .optimization_handlers import try_optimizations
|
||||
from .web_server_tools import (
|
||||
is_web_server_tool_request,
|
||||
stream_web_server_tool_response,
|
||||
)
|
||||
|
||||
TokenCounter = Callable[[list[Any], str | list[Any] | None, list[Any] | None], int]
|
||||
|
||||
|
|
@ -48,6 +52,22 @@ class ClaudeProxyService:
|
|||
raise InvalidRequestError("messages cannot be empty")
|
||||
|
||||
routed = self._model_router.resolve_messages_request(request_data)
|
||||
if is_web_server_tool_request(routed.request):
|
||||
input_tokens = self._token_counter(
|
||||
routed.request.messages, routed.request.system, routed.request.tools
|
||||
)
|
||||
logger.info("Optimization: Handling Anthropic web server tool")
|
||||
return StreamingResponse(
|
||||
stream_web_server_tool_response(
|
||||
routed.request, input_tokens=input_tokens
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"X-Accel-Buffering": "no",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
|
||||
optimized = try_optimizations(routed.request, self._settings)
|
||||
if optimized is not None:
|
||||
|
|
|
|||
331 api/web_server_tools.py (new file)
@ -0,0 +1,331 @@
|
|||
"""Local handlers for Anthropic web server tools.
|
||||
|
||||
OpenAI-compatible upstreams can emit regular function calls, but Anthropic's
|
||||
web tools are server-side: the API response itself must include the tool result.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import html
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import UTC, datetime
|
||||
from html.parser import HTMLParser
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs, unquote, urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from .models.anthropic import MessagesRequest
|
||||
|
||||
_REQUEST_TIMEOUT_S = 20.0
|
||||
_MAX_SEARCH_RESULTS = 10
|
||||
_MAX_FETCH_CHARS = 24_000
|
||||
|
||||
|
||||
class _SearchResultParser(HTMLParser):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.results: list[dict[str, str]] = []
|
||||
self._href: str | None = None
|
||||
self._title_parts: list[str] = []
|
||||
|
||||
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
|
||||
if tag != "a":
|
||||
return
|
||||
href = dict(attrs).get("href")
|
||||
if not href or "uddg=" not in href:
|
||||
return
|
||||
parsed = urlparse(href)
|
||||
query = parse_qs(parsed.query)
|
||||
uddg = query.get("uddg", [""])[0]
|
||||
if not uddg:
|
||||
return
|
||||
self._href = unquote(uddg)
|
||||
self._title_parts = []
|
||||
|
||||
def handle_data(self, data: str) -> None:
|
||||
if self._href is not None:
|
||||
self._title_parts.append(data)
|
||||
|
||||
def handle_endtag(self, tag: str) -> None:
|
||||
if tag != "a" or self._href is None:
|
||||
return
|
||||
title = " ".join("".join(self._title_parts).split())
|
||||
if title and not any(result["url"] == self._href for result in self.results):
|
||||
self.results.append({"title": html.unescape(title), "url": self._href})
|
||||
self._href = None
|
||||
self._title_parts = []
|
||||
|
||||
|
||||
class _HTMLTextParser(HTMLParser):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.title = ""
|
||||
self.text_parts: list[str] = []
|
||||
self._in_title = False
|
||||
self._skip_depth = 0
|
||||
|
||||
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
|
||||
if tag in {"script", "style", "noscript"}:
|
||||
self._skip_depth += 1
|
||||
elif tag == "title":
|
||||
self._in_title = True
|
||||
|
||||
def handle_endtag(self, tag: str) -> None:
|
||||
if tag in {"script", "style", "noscript"} and self._skip_depth:
|
||||
self._skip_depth -= 1
|
||||
elif tag == "title":
|
||||
self._in_title = False
|
||||
|
||||
def handle_data(self, data: str) -> None:
|
||||
text = " ".join(data.split())
|
||||
if not text:
|
||||
return
|
||||
if self._in_title:
|
||||
self.title = f"{self.title} {text}".strip()
|
||||
elif not self._skip_depth:
|
||||
self.text_parts.append(text)
|
||||
|
||||
|
||||
def _format_event(event_type: str, data: dict[str, Any]) -> str:
|
||||
return f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
|
||||
|
||||
|
||||
def _content_text(content: Any) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
parts.append(str(item.get("text", "")))
|
||||
else:
|
||||
parts.append(str(getattr(item, "text", "")))
|
||||
return "\n".join(part for part in parts if part)
|
||||
return str(content)
|
||||
|
||||
|
||||
def _request_text(request: MessagesRequest) -> str:
|
||||
return "\n".join(_content_text(message.content) for message in request.messages)
|
||||
|
||||
|
||||
def _web_tool_name(request: MessagesRequest) -> str | None:
|
||||
for tool in request.tools or []:
|
||||
name = tool.name
|
||||
tool_type = tool.type or ""
|
||||
if name in {"web_search", "web_fetch"} or tool_type.startswith("web_"):
|
||||
return name
|
||||
return None
|
||||
|
||||
|
||||
def is_web_server_tool_request(request: MessagesRequest) -> bool:
|
||||
return _web_tool_name(request) in {"web_search", "web_fetch"}
|
||||
|
||||
|
||||
def _extract_query(text: str) -> str:
|
||||
match = re.search(r"query:\s*(.+)", text, flags=re.IGNORECASE | re.DOTALL)
|
||||
if match:
|
||||
return match.group(1).strip().strip("\"'")
|
||||
return text.strip()
|
||||
|
||||
|
||||
def _extract_url(text: str) -> str:
|
||||
match = re.search(r"https?://\S+", text)
|
||||
return match.group(0).rstrip(").,]") if match else text.strip()
|
||||
|
||||
|
||||
async def _run_web_search(query: str) -> list[dict[str, str]]:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=_REQUEST_TIMEOUT_S,
|
||||
follow_redirects=True,
|
||||
headers={"User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0"},
|
||||
) as client:
|
||||
response = await client.get(
|
||||
"https://lite.duckduckgo.com/lite/",
|
||||
params={"q": query},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
parser = _SearchResultParser()
|
||||
parser.feed(response.text)
|
||||
return parser.results[:_MAX_SEARCH_RESULTS]
|
||||
|
||||
|
||||
async def _run_web_fetch(url: str) -> dict[str, str]:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=_REQUEST_TIMEOUT_S,
|
||||
follow_redirects=True,
|
||||
headers={"User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0"},
|
||||
) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
content_type = response.headers.get("content-type", "text/plain")
|
||||
title = url
|
||||
data = response.text
|
||||
if "html" in content_type.lower():
|
||||
parser = _HTMLTextParser()
|
||||
parser.feed(response.text)
|
||||
title = parser.title or url
|
||||
data = "\n".join(parser.text_parts)
|
||||
return {
|
||||
"url": str(response.url),
|
||||
"title": title,
|
||||
"media_type": "text/plain",
|
||||
"data": data[:_MAX_FETCH_CHARS],
|
||||
}
|
||||
|
||||
|
||||
def _search_summary(query: str, results: list[dict[str, str]]) -> str:
|
||||
if not results:
|
||||
return f"No web search results found for: {query}"
|
||||
lines = [f"Search results for: {query}"]
|
||||
for index, result in enumerate(results, start=1):
|
||||
lines.append(f"{index}. {result['title']}\n{result['url']}")
|
||||
return "\n\n".join(lines)
|
||||
|
||||
|
||||
async def stream_web_server_tool_response(
|
||||
request: MessagesRequest, input_tokens: int
|
||||
) -> AsyncIterator[str]:
|
||||
tool_name = _web_tool_name(request)
|
||||
if tool_name is None:
|
||||
return
|
||||
|
||||
text = _request_text(request)
|
||||
message_id = f"msg_{uuid.uuid4()}"
|
||||
tool_id = f"srvtoolu_{uuid.uuid4().hex}"
|
||||
output_tokens = 1
|
||||
usage_key = (
|
||||
"web_search_requests" if tool_name == "web_search" else "web_fetch_requests"
|
||||
)
|
||||
tool_input = (
|
||||
{"query": _extract_query(text)}
|
||||
if tool_name == "web_search"
|
||||
else {"url": _extract_url(text)}
|
||||
)
|
||||
|
||||
yield _format_event(
|
||||
"message_start",
|
||||
{
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"id": message_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [],
|
||||
"model": request.model,
|
||||
"stop_reason": None,
|
||||
"stop_sequence": None,
|
||||
"usage": {"input_tokens": input_tokens, "output_tokens": 1},
|
||||
},
|
||||
},
|
||||
)
|
||||
yield _format_event(
|
||||
"content_block_start",
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": {
|
||||
"type": "server_tool_use",
|
||||
"id": tool_id,
|
||||
"name": tool_name,
|
||||
"input": tool_input,
|
||||
},
|
||||
},
|
||||
)
|
||||
yield _format_event(
|
||||
"content_block_stop", {"type": "content_block_stop", "index": 0}
|
||||
)
|
||||
|
||||
try:
|
||||
if tool_name == "web_search":
|
||||
query = str(tool_input["query"])
|
||||
results = await _run_web_search(query)
|
||||
result_content: Any = [
|
||||
{
|
||||
"type": "web_search_result",
|
||||
"title": result["title"],
|
||||
"url": result["url"],
|
||||
}
|
||||
for result in results
|
||||
]
|
||||
summary = _search_summary(query, results)
|
||||
result_block_type = "web_search_tool_result"
|
||||
else:
|
||||
fetched = await _run_web_fetch(str(tool_input["url"]))
|
||||
result_content = {
|
||||
"type": "web_fetch_result",
|
||||
"url": fetched["url"],
|
||||
"content": {
|
||||
"type": "document",
|
||||
"source": {
|
||||
"type": "text",
|
||||
"media_type": fetched["media_type"],
|
||||
"data": fetched["data"],
|
||||
},
|
||||
"title": fetched["title"],
|
||||
"citations": {"enabled": True},
|
||||
},
|
||||
"retrieved_at": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
summary = fetched["data"][:_MAX_FETCH_CHARS]
|
||||
result_block_type = "web_fetch_tool_result"
|
||||
except Exception as error:
|
||||
result_block_type = (
|
||||
"web_search_tool_result"
|
||||
if tool_name == "web_search"
|
||||
else "web_fetch_tool_result"
|
||||
)
|
||||
error_type = (
|
||||
"web_search_tool_result_error"
|
||||
if tool_name == "web_search"
|
||||
else "web_fetch_tool_error"
|
||||
)
|
||||
result_content = {"type": error_type, "error_code": "unavailable"}
|
||||
summary = f"{tool_name} failed: {type(error).__name__}"
|
||||
|
||||
output_tokens = max(1, len(summary) // 4)
|
||||
|
||||
yield _format_event(
|
||||
"content_block_start",
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 1,
|
||||
"content_block": {
|
||||
"type": result_block_type,
|
||||
"tool_use_id": tool_id,
|
||||
"content": result_content,
|
||||
},
|
||||
},
|
||||
)
|
||||
yield _format_event(
|
||||
"content_block_stop", {"type": "content_block_stop", "index": 1}
|
||||
)
|
||||
yield _format_event(
|
||||
"content_block_start",
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 2,
|
||||
"content_block": {"type": "text", "text": summary},
|
||||
},
|
||||
)
|
||||
yield _format_event(
|
||||
"content_block_stop", {"type": "content_block_stop", "index": 2}
|
||||
)
|
||||
yield _format_event(
|
||||
"message_delta",
|
||||
{
|
||||
"type": "message_delta",
|
||||
"delta": {"stop_reason": "end_turn", "stop_sequence": None},
|
||||
"usage": {
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"server_tool_use": {usage_key: 1},
|
||||
},
|
||||
},
|
||||
)
|
||||
yield _format_event("message_stop", {"type": "message_stop"})
|
||||
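For reference, the happy-path event sequence the generator above emits for a `web_search` request (derived directly from the yields in this file):

```python
EXPECTED_EVENT_ORDER = [
    "message_start",
    "content_block_start",  # index 0: server_tool_use with the parsed query/url
    "content_block_stop",
    "content_block_start",  # index 1: web_search_tool_result (or the *_error shape on failure)
    "content_block_stop",
    "content_block_start",  # index 2: plain-text summary of the results
    "content_block_stop",
    "message_delta",        # stop_reason=end_turn plus server_tool_use usage counters
    "message_stop",
]
```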
17 config/provider_ids.py (new file)
@ -0,0 +1,17 @@
|
|||
"""Canonical list of model provider type prefixes (provider_id values).
|
||||
|
||||
`providers.registry.PROVIDER_DESCRIPTORS` is the full metadata source; this
|
||||
module holds the id set for config validation and must stay in sync
|
||||
(registries assert in `providers.registry`).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# Order matches docs / historical error text; must match PROVIDER_DESCRIPTORS keys.
|
||||
SUPPORTED_PROVIDER_IDS: tuple[str, ...] = (
|
||||
"nvidia_nim",
|
||||
"open_router",
|
||||
"deepseek",
|
||||
"lmstudio",
|
||||
"llamacpp",
|
||||
)
|
||||
|
|
@ -11,6 +11,7 @@ from pydantic import Field, field_validator, model_validator
|
|||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from .nim import NimSettings
|
||||
from .provider_ids import SUPPORTED_PROVIDER_IDS
|
||||
|
||||
|
||||
def _env_files() -> tuple[Path, ...]:
|
||||
|
|
@ -252,25 +253,16 @@ class Settings(BaseSettings):
|
|||
def validate_model_format(cls, v: str | None) -> str | None:
|
||||
if v is None:
|
||||
return None
|
||||
valid_providers = (
|
||||
"nvidia_nim",
|
||||
"open_router",
|
||||
"deepseek",
|
||||
"lmstudio",
|
||||
"llamacpp",
|
||||
)
|
||||
if "/" not in v:
|
||||
raise ValueError(
|
||||
f"Model must be prefixed with provider type. "
|
||||
f"Valid providers: {', '.join(valid_providers)}. "
|
||||
f"Valid providers: {', '.join(SUPPORTED_PROVIDER_IDS)}. "
|
||||
f"Format: provider_type/model/name"
|
||||
)
|
||||
provider = v.split("/", 1)[0]
|
||||
if provider not in valid_providers:
|
||||
raise ValueError(
|
||||
f"Invalid provider: '{provider}'. "
|
||||
f"Supported: 'nvidia_nim', 'open_router', 'deepseek', 'lmstudio', 'llamacpp'"
|
||||
)
|
||||
if provider not in SUPPORTED_PROVIDER_IDS:
|
||||
supported = ", ".join(f"'{p}'" for p in SUPPORTED_PROVIDER_IDS)
|
||||
raise ValueError(f"Invalid provider: '{provider}'. Supported: {supported}")
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
|
|
|
|||
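A few illustrative values against this validator (the model names after the prefix are made up; only the provider prefix is checked here):

```python
# Accepted: prefix is in SUPPORTED_PROVIDER_IDS, the rest of the path is passed through.
ok = [
    "open_router/qwen/qwen3-coder",
    "llamacpp/my-local-gguf",
    "nvidia_nim/meta/llama-3.1-70b-instruct",
]
# Rejected by validate_model_format:
bad = [
    "gpt-4o",         # no provider prefix -> "Model must be prefixed with provider type..."
    "openai/gpt-4o",  # unknown prefix     -> "Invalid provider: 'openai'..."
]
```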
|
|
@ -7,6 +7,17 @@ from .content import get_block_attr, get_block_type
|
|||
from .utils import set_if_not_none
|
||||
|
||||
|
||||
def _tool_name(tool: Any) -> str:
|
||||
return str(getattr(tool, "name", "") or "")
|
||||
|
||||
|
||||
def _tool_input_schema(tool: Any) -> dict[str, Any]:
|
||||
schema = getattr(tool, "input_schema", None)
|
||||
if isinstance(schema, dict):
|
||||
return schema
|
||||
return {"type": "object", "properties": {}}
|
||||
|
||||
|
||||
class AnthropicToOpenAIConverter:
|
||||
"""Convert Anthropic message format to OpenAI-compatible format."""
|
||||
|
||||
|
|
@ -140,7 +151,7 @@ class AnthropicToOpenAIConverter:
|
|||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
"parameters": tool.input_schema,
|
||||
"parameters": _tool_input_schema(tool),
|
||||
},
|
||||
}
|
||||
for tool in tools
|
||||
|
|
|
|||
158 core/anthropic/stream_contracts.py (new file)
@ -0,0 +1,158 @@
|
|||
"""Neutral SSE parsing and Anthropic stream shape assertions.
|
||||
|
||||
Used by default CI contract tests and re-exported from ``smoke.lib.sse`` for
|
||||
opt-in smoke scenarios.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class SSEEvent:
|
||||
event: str
|
||||
data: dict[str, Any]
|
||||
raw: str
|
||||
|
||||
|
||||
def parse_sse_lines(lines: Iterable[str]) -> list[SSEEvent]:
|
||||
events: list[SSEEvent] = []
|
||||
current_event = ""
|
||||
data_parts: list[str] = []
|
||||
raw_parts: list[str] = []
|
||||
|
||||
for line in lines:
|
||||
stripped = line.rstrip("\r\n")
|
||||
if stripped == "":
|
||||
_append_event(events, current_event, data_parts, raw_parts)
|
||||
current_event = ""
|
||||
data_parts = []
|
||||
raw_parts = []
|
||||
continue
|
||||
raw_parts.append(stripped)
|
||||
if stripped.startswith("event:"):
|
||||
current_event = stripped.split(":", 1)[1].strip()
|
||||
elif stripped.startswith("data:"):
|
||||
data_parts.append(stripped.split(":", 1)[1].strip())
|
||||
|
||||
_append_event(events, current_event, data_parts, raw_parts)
|
||||
return events
|
||||
|
||||
|
||||
def parse_sse_text(text: str) -> list[SSEEvent]:
|
||||
return parse_sse_lines(text.splitlines())
|
||||
|
||||
|
||||
def _append_event(
|
||||
events: list[SSEEvent],
|
||||
current_event: str,
|
||||
data_parts: list[str],
|
||||
raw_parts: list[str],
|
||||
) -> None:
|
||||
if not current_event and not data_parts:
|
||||
return
|
||||
data_text = "\n".join(data_parts)
|
||||
data: dict[str, Any]
|
||||
try:
|
||||
parsed = json.loads(data_text) if data_text else {}
|
||||
data = parsed if isinstance(parsed, dict) else {"value": parsed}
|
||||
except json.JSONDecodeError:
|
||||
data = {"raw": data_text}
|
||||
events.append(SSEEvent(current_event, data, "\n".join(raw_parts)))
|
||||
|
||||
|
||||
def assert_anthropic_stream_contract(
|
||||
events: list[SSEEvent], *, allow_error: bool = False
|
||||
) -> None:
|
||||
"""Check minimal Anthropic-style SSE invariants: start/stop, block nesting.
|
||||
|
||||
Does *not* assert strict event ordering (e.g. :class:`message_delta` vs
|
||||
content blocks) beyond presence of a final ``message_stop``; stricter
|
||||
ordering can be tested in product or transport-specific suites.
|
||||
"""
|
||||
assert events, "stream produced no SSE events"
|
||||
event_names = [event.event for event in events]
|
||||
assert "message_start" in event_names, event_names
|
||||
assert event_names[-1] == "message_stop", event_names
|
||||
|
||||
open_blocks: dict[int, str] = {}
|
||||
seen_blocks: set[int] = set()
|
||||
for event in events:
|
||||
if event.event == "error" and not allow_error:
|
||||
raise AssertionError(f"unexpected SSE error event: {event.data}")
|
||||
|
||||
if event.event == "content_block_start":
|
||||
index = _event_index(event)
|
||||
block = event.data.get("content_block", {})
|
||||
assert isinstance(block, dict), event.data
|
||||
block_type = str(block.get("type", ""))
|
||||
assert block_type in {"text", "thinking", "tool_use"}, event.data
|
||||
assert index not in open_blocks, f"block {index} started twice"
|
||||
assert index not in seen_blocks, f"block {index} reused after stop"
|
||||
open_blocks[index] = block_type
|
||||
seen_blocks.add(index)
|
||||
continue
|
||||
|
||||
if event.event == "content_block_delta":
|
||||
index = _event_index(event)
|
||||
assert index in open_blocks, f"delta for unopened block {index}"
|
||||
delta = event.data.get("delta", {})
|
||||
assert isinstance(delta, dict), event.data
|
||||
delta_type = str(delta.get("type", ""))
|
||||
expected = {
|
||||
"text": "text_delta",
|
||||
"thinking": "thinking_delta",
|
||||
"tool_use": "input_json_delta",
|
||||
}[open_blocks[index]]
|
||||
assert delta_type == expected, (
|
||||
f"block {index} is {open_blocks[index]}, got {delta_type}"
|
||||
)
|
||||
continue
|
||||
|
||||
if event.event == "content_block_stop":
|
||||
index = _event_index(event)
|
||||
assert index in open_blocks, f"stop for unopened block {index}"
|
||||
open_blocks.pop(index)
|
||||
|
||||
assert not open_blocks, f"unclosed blocks: {open_blocks}"
|
||||
assert seen_blocks, "stream did not emit any content blocks"
|
||||
|
||||
|
||||
def event_names(events: list[SSEEvent]) -> list[str]:
|
||||
return [event.event for event in events]
|
||||
|
||||
|
||||
def text_content(events: list[SSEEvent]) -> str:
|
||||
parts: list[str] = []
|
||||
for event in events:
|
||||
delta = event.data.get("delta", {})
|
||||
if isinstance(delta, dict) and delta.get("type") == "text_delta":
|
||||
parts.append(str(delta.get("text", "")))
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def thinking_content(events: list[SSEEvent]) -> str:
|
||||
parts: list[str] = []
|
||||
for event in events:
|
||||
delta = event.data.get("delta", {})
|
||||
if isinstance(delta, dict) and delta.get("type") == "thinking_delta":
|
||||
parts.append(str(delta.get("thinking", "")))
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def has_tool_use(events: list[SSEEvent]) -> bool:
|
||||
for event in events:
|
||||
block = event.data.get("content_block", {})
|
||||
if isinstance(block, dict) and block.get("type") == "tool_use":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _event_index(event: SSEEvent) -> int:
|
||||
value = event.data.get("index")
|
||||
assert isinstance(value, int), event.data
|
||||
return value
|
||||
|
|
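A minimal sketch of how a default-CI contract test might consume these helpers (the test wiring is assumed; the three helpers are the ones defined above):

```python
from core.anthropic.stream_contracts import (
    assert_anthropic_stream_contract,
    parse_sse_text,
    text_content,
)

def check_stream(body: str) -> None:
    events = parse_sse_text(body)             # raw SSE text -> list[SSEEvent]
    assert_anthropic_stream_contract(events)  # message_start/stop + block nesting invariants
    assert text_content(events), "expected at least one text_delta"
```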
@ -1,5 +1,6 @@
|
|||
"""Heuristic parser for text-emitted tool calls."""
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from enum import Enum
|
||||
|
|
@ -31,6 +32,9 @@ class HeuristicToolParser:
|
|||
_PARAM_PATTERN = re.compile(
|
||||
r"<parameter=([^>]+)>(.*?)(?:</parameter>|$)", re.DOTALL
|
||||
)
|
||||
_WEB_TOOL_JSON_PATTERN = re.compile(
|
||||
r"(?is)\b(?:use\s+)?(?P<tool>WebFetch|WebSearch)\b.*?(?P<json>\{.*?\})"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
self._state = ParserState.TEXT
|
||||
|
|
@ -39,6 +43,41 @@ class HeuristicToolParser:
|
|||
self._current_function_name = None
|
||||
self._current_parameters = {}
|
||||
|
||||
def _extract_web_tool_json_calls(self) -> tuple[str, list[dict[str, Any]]]:
|
||||
detected_tools: list[dict[str, Any]] = []
|
||||
|
||||
for match in self._WEB_TOOL_JSON_PATTERN.finditer(self._buffer):
|
||||
try:
|
||||
tool_input = json.loads(match.group("json"))
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if not isinstance(tool_input, dict):
|
||||
continue
|
||||
|
||||
tool_name = match.group("tool")
|
||||
if tool_name == "WebFetch" and "url" not in tool_input:
|
||||
continue
|
||||
if tool_name == "WebSearch" and "query" not in tool_input:
|
||||
continue
|
||||
|
||||
detected_tools.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": f"toolu_heuristic_{uuid.uuid4().hex[:8]}",
|
||||
"name": tool_name,
|
||||
"input": tool_input,
|
||||
}
|
||||
)
|
||||
logger.debug(
|
||||
"Heuristic bypass: Detected JSON-style tool call '{}'",
|
||||
tool_name,
|
||||
)
|
||||
|
||||
if not detected_tools:
|
||||
return self._buffer, []
|
||||
|
||||
return "", detected_tools
|
||||
|
||||
def _strip_control_tokens(self, text: str) -> str:
|
||||
return _CONTROL_TOKEN_RE.sub("", text)
|
||||
|
||||
|
|
@ -58,7 +97,7 @@ class HeuristicToolParser:
|
|||
"""Feed text and return safe text plus detected tool calls."""
|
||||
self._buffer += text
|
||||
self._buffer = self._strip_control_tokens(self._buffer)
|
||||
detected_tools = []
|
||||
self._buffer, detected_tools = self._extract_web_tool_json_calls()
|
||||
filtered_output_parts: list[str] = []
|
||||
|
||||
while True:
|
||||
|
|
|
|||
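The new pattern in isolation, to show the kind of free text it now catches (the example sentence is made up; the regex is copied from the class above):

```python
import json
import re

_WEB_TOOL_JSON_PATTERN = re.compile(
    r"(?is)\b(?:use\s+)?(?P<tool>WebFetch|WebSearch)\b.*?(?P<json>\{.*?\})"
)

text = 'I will use WebSearch {"query": "claude code proxy"} to check.'
match = _WEB_TOOL_JSON_PATTERN.search(text)
print(match.group("tool"), json.loads(match.group("json")))
# -> WebSearch {'query': 'claude code proxy'}
```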
|
|
@ -114,23 +114,15 @@ async def _delete_message_ids(
|
|||
numeric.sort(reverse=True)
|
||||
ordered = [mid for _, mid in numeric] + non_numeric
|
||||
|
||||
batch_fn = getattr(handler.platform, "queue_delete_messages", None)
|
||||
if callable(batch_fn):
|
||||
try:
|
||||
CHUNK = 100
|
||||
for i in range(0, len(ordered), CHUNK):
|
||||
chunk = ordered[i : i + CHUNK]
|
||||
await batch_fn(chat_id, chunk, fire_and_forget=False)
|
||||
except Exception as e:
|
||||
logger.debug(f"Batch delete failed: {type(e).__name__}: {e}")
|
||||
else:
|
||||
for mid in ordered:
|
||||
try:
|
||||
await handler.platform.queue_delete_message(
|
||||
chat_id, mid, fire_and_forget=False
|
||||
await handler.platform.queue_delete_messages(
|
||||
chat_id, chunk, fire_and_forget=False
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Delete failed for msg {mid}: {type(e).__name__}: {e}")
|
||||
logger.debug(f"Batch delete failed: {type(e).__name__}: {e}")
|
||||
|
||||
|
||||
async def _handle_clear_branch(
|
||||
|
|
|
|||
|
|
@ -467,7 +467,7 @@ class ClaudeMessageHandler:
|
|||
status,
|
||||
len(display),
|
||||
)
|
||||
if os.getenv("DEBUG_TELEGRAM_EDITS") == "1":
|
||||
if os.getenv("DEBUG_PLATFORM_EDITS") == "1":
|
||||
logger.debug("PLATFORM_EDIT_TEXT:\n{}", display)
|
||||
else:
|
||||
head = display[:500]
|
||||
|
|
|
|||
|
|
@ -17,15 +17,10 @@ class CLISession(Protocol):
|
|||
|
||||
def start_task(
|
||||
self, prompt: str, session_id: str | None = None, fork_session: bool = False
|
||||
) -> AsyncGenerator[dict, Any]:
|
||||
"""Start a task in the CLI session."""
|
||||
...
|
||||
) -> AsyncGenerator[dict, Any]: ...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def is_busy(self) -> bool:
|
||||
"""Check if session is busy."""
|
||||
pass
|
||||
def is_busy(self) -> bool: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
|
@ -101,7 +96,8 @@ class MessagingPlatform(ABC):
|
|||
text: Message content
|
||||
reply_to: Optional message ID to reply to
|
||||
parse_mode: Optional formatting mode ("markdown", "html")
|
||||
message_thread_id: Optional forum topic ID (Telegram)
|
||||
message_thread_id: Optional thread or topic id for threaded channels
|
||||
(e.g. forum topics); unused on platforms that do not support it.
|
||||
|
||||
Returns:
|
||||
The message ID of the sent message
|
||||
|
|
@ -192,6 +188,22 @@ class MessagingPlatform(ABC):
|
|||
"""
|
||||
pass
|
||||
|
||||
async def queue_delete_messages(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_ids: list[str],
|
||||
*,
|
||||
fire_and_forget: bool = True,
|
||||
) -> None:
|
||||
"""Delete many messages; default loops :meth:`queue_delete_message`.
|
||||
|
||||
Adapters with native bulk delete should override.
|
||||
"""
|
||||
for mid in message_ids:
|
||||
await self.queue_delete_message(
|
||||
chat_id, mid, fire_and_forget=fire_and_forget
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def on_message(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -91,6 +91,12 @@ class DiscordPlatform(MessagingPlatform):
|
|||
self,
|
||||
bot_token: str | None = None,
|
||||
allowed_channel_ids: str | None = None,
|
||||
*,
|
||||
voice_note_enabled: bool = True,
|
||||
whisper_model: str = "base",
|
||||
whisper_device: str = "cpu",
|
||||
hf_token: str = "",
|
||||
nvidia_nim_api_key: str = "",
|
||||
):
|
||||
if not DISCORD_AVAILABLE:
|
||||
raise ImportError(
|
||||
|
|
@ -117,7 +123,13 @@ class DiscordPlatform(MessagingPlatform):
|
|||
self._limiter: Any | None = None
|
||||
self._start_task: asyncio.Task | None = None
|
||||
self._pending_voice = PendingVoiceRegistry()
|
||||
self._voice_transcription = VoiceTranscriptionService()
|
||||
self._voice_transcription = VoiceTranscriptionService(
|
||||
hf_token=hf_token,
|
||||
nvidia_nim_api_key=nvidia_nim_api_key,
|
||||
)
|
||||
self._voice_note_enabled = voice_note_enabled
|
||||
self._whisper_model = whisper_model
|
||||
self._whisper_device = whisper_device
|
||||
|
||||
async def _handle_client_message(self, message: Any) -> None:
|
||||
"""Adapter entry point used by the internal discord client."""
|
||||
|
|
@ -154,10 +166,7 @@ class DiscordPlatform(MessagingPlatform):
|
|||
self, message: Any, attachment: Any, channel_id: str
|
||||
) -> bool:
|
||||
"""Handle voice/audio attachment. Returns True if handled."""
|
||||
from config.settings import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
if not settings.voice_note_enabled:
|
||||
if not self._voice_note_enabled:
|
||||
await message.reply("Voice notes are disabled.")
|
||||
return True
|
||||
|
||||
|
|
@ -201,8 +210,8 @@ class DiscordPlatform(MessagingPlatform):
|
|||
transcribed = await self._voice_transcription.transcribe(
|
||||
tmp_path,
|
||||
ct,
|
||||
whisper_model=settings.whisper_model,
|
||||
whisper_device=settings.whisper_device,
|
||||
whisper_model=self._whisper_model,
|
||||
whisper_device=self._whisper_device,
|
||||
)
|
||||
|
||||
if not await self._is_voice_still_pending(channel_id, message_id):
|
||||
|
|
|
|||
|
|
@ -6,30 +6,50 @@ To add a new platform (e.g. Discord, Slack):
|
|||
2. Add a case to create_messaging_platform() below
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from .base import MessagingPlatform
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class MessagingPlatformOptions:
|
||||
"""Typed wiring from :class:`~api.runtime.AppRuntime` into platform adapters."""
|
||||
|
||||
telegram_bot_token: str | None = None
|
||||
allowed_telegram_user_id: str | None = None
|
||||
discord_bot_token: str | None = None
|
||||
allowed_discord_channels: str | None = None
|
||||
voice_note_enabled: bool = True
|
||||
whisper_model: str = "base"
|
||||
whisper_device: str = "cpu"
|
||||
hf_token: str = ""
|
||||
nvidia_nim_api_key: str = ""
|
||||
|
||||
|
||||
def create_messaging_platform(
|
||||
platform_type: str,
|
||||
**kwargs,
|
||||
options: MessagingPlatformOptions | None = None,
|
||||
) -> MessagingPlatform | None:
|
||||
"""Create a messaging platform instance based on type.
|
||||
|
||||
Args:
|
||||
platform_type: Platform identifier ("telegram", "discord", etc.)
|
||||
**kwargs: Platform-specific configuration passed to the constructor.
|
||||
platform_type: Platform identifier (``telegram``, ``discord``, ``none``).
|
||||
options: Token, allowlist, and voice / transcription settings.
|
||||
|
||||
Returns:
|
||||
Configured MessagingPlatform instance, or None if not configured.
|
||||
Configured :class:`MessagingPlatform` instance, or None if not configured.
|
||||
"""
|
||||
opts = options or MessagingPlatformOptions()
|
||||
if platform_type == "none":
|
||||
logger.info("Messaging platform disabled by configuration")
|
||||
return None
|
||||
|
||||
if platform_type == "telegram":
|
||||
bot_token = kwargs.get("bot_token")
|
||||
bot_token = opts.telegram_bot_token
|
||||
if not bot_token:
|
||||
logger.info("No Telegram bot token configured, skipping platform setup")
|
||||
return None
|
||||
|
|
@ -38,11 +58,16 @@ def create_messaging_platform(
|
|||
|
||||
return TelegramPlatform(
|
||||
bot_token=bot_token,
|
||||
allowed_user_id=kwargs.get("allowed_user_id"),
|
||||
allowed_user_id=opts.allowed_telegram_user_id,
|
||||
voice_note_enabled=opts.voice_note_enabled,
|
||||
whisper_model=opts.whisper_model,
|
||||
whisper_device=opts.whisper_device,
|
||||
hf_token=opts.hf_token,
|
||||
nvidia_nim_api_key=opts.nvidia_nim_api_key,
|
||||
)
|
||||
|
||||
if platform_type == "discord":
|
||||
bot_token = kwargs.get("discord_bot_token")
|
||||
bot_token = opts.discord_bot_token
|
||||
if not bot_token:
|
||||
logger.info("No Discord bot token configured, skipping platform setup")
|
||||
return None
|
||||
|
|
@ -51,7 +76,12 @@ def create_messaging_platform(
|
|||
|
||||
return DiscordPlatform(
|
||||
bot_token=bot_token,
|
||||
allowed_channel_ids=kwargs.get("allowed_discord_channels"),
|
||||
allowed_channel_ids=opts.allowed_discord_channels,
|
||||
voice_note_enabled=opts.voice_note_enabled,
|
||||
whisper_model=opts.whisper_model,
|
||||
whisper_device=opts.whisper_device,
|
||||
hf_token=opts.hf_token,
|
||||
nvidia_nim_api_key=opts.nvidia_nim_api_key,
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
|
|
|
|||
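This is how `AppRuntime` now drives the factory, shown here as a standalone call (token and user id are placeholders; the option names come from the dataclass above):

```python
from messaging.platforms.factory import MessagingPlatformOptions, create_messaging_platform

platform = create_messaging_platform(
    "telegram",
    MessagingPlatformOptions(
        telegram_bot_token="123456:ABC-placeholder",
        allowed_telegram_user_id="42",
        voice_note_enabled=False,  # voice settings are injected; adapters no longer call get_settings()
    ),
)
# Returns None when platform_type is "none" or the matching bot token is missing.
```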
|
|
@ -62,6 +62,12 @@ class TelegramPlatform(MessagingPlatform):
|
|||
self,
|
||||
bot_token: str | None = None,
|
||||
allowed_user_id: str | None = None,
|
||||
*,
|
||||
voice_note_enabled: bool = True,
|
||||
whisper_model: str = "base",
|
||||
whisper_device: str = "cpu",
|
||||
hf_token: str = "",
|
||||
nvidia_nim_api_key: str = "",
|
||||
):
|
||||
if not TELEGRAM_AVAILABLE:
|
||||
raise ImportError(
|
||||
|
|
@ -84,7 +90,13 @@ class TelegramPlatform(MessagingPlatform):
|
|||
self._limiter: Any | None = None # Will be MessagingRateLimiter
|
||||
# Pending voice transcriptions: (chat_id, msg_id) -> (voice_msg_id, status_msg_id)
|
||||
self._pending_voice = PendingVoiceRegistry()
|
||||
self._voice_transcription = VoiceTranscriptionService()
|
||||
self._voice_transcription = VoiceTranscriptionService(
|
||||
hf_token=hf_token,
|
||||
nvidia_nim_api_key=nvidia_nim_api_key,
|
||||
)
|
||||
self._voice_note_enabled = voice_note_enabled
|
||||
self._whisper_model = whisper_model
|
||||
self._whisper_device = whisper_device
|
||||
|
||||
async def _register_pending_voice(
|
||||
self, chat_id: str, voice_msg_id: str, status_msg_id: str
|
||||
|
|
@ -544,10 +556,7 @@ class TelegramPlatform(MessagingPlatform):
|
|||
):
|
||||
return
|
||||
|
||||
from config.settings import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
if not settings.voice_note_enabled:
|
||||
if not self._voice_note_enabled:
|
||||
await update.message.reply_text("Voice notes are disabled.")
|
||||
return
|
||||
|
||||
|
|
@ -600,8 +609,8 @@ class TelegramPlatform(MessagingPlatform):
|
|||
transcribed = await self._voice_transcription.transcribe(
|
||||
tmp_path,
|
||||
voice.mime_type or "audio/ogg",
|
||||
whisper_model=settings.whisper_model,
|
||||
whisper_device=settings.whisper_device,
|
||||
whisper_model=self._whisper_model,
|
||||
whisper_device=self._whisper_device,
|
||||
)
|
||||
|
||||
if not await self._is_voice_still_pending(chat_id, message_id):
|
||||
|
|
|
|||
|
|
@ -42,8 +42,8 @@ _MODEL_MAP: dict[str, str] = {
|
|||
"large-v3-turbo": "openai/whisper-large-v3-turbo",
|
||||
}
|
||||
|
||||
# Lazy-loaded pipelines: (model_id, device) -> pipeline
|
||||
_pipeline_cache: dict[tuple[str, str], Any] = {}
|
||||
# Lazy-loaded pipelines: (model_id, device, hf_token_fingerprint) -> pipeline
|
||||
_pipeline_cache: dict[tuple[str, str, str], Any] = {}
|
||||
|
||||
|
||||
def _resolve_model_id(whisper_model: str) -> str:
|
||||
|
|
@ -51,20 +51,22 @@ def _resolve_model_id(whisper_model: str) -> str:
|
|||
return _MODEL_MAP.get(whisper_model, whisper_model)
|
||||
|
||||
|
||||
def _get_pipeline(model_id: str, device: str) -> Any:
|
||||
def _get_pipeline(model_id: str, device: str, hf_token: str | None = None) -> Any:
|
||||
"""Lazy-load transformers Whisper pipeline. Raises ImportError if not installed."""
|
||||
global _pipeline_cache
|
||||
if device not in ("cpu", "cuda"):
|
||||
raise ValueError(f"whisper_device must be 'cpu' or 'cuda', got {device!r}")
|
||||
cache_key = (model_id, device)
|
||||
resolved_token = (
|
||||
hf_token if hf_token is not None else get_settings().hf_token
|
||||
) or ""
|
||||
cache_key = (model_id, device, resolved_token)
|
||||
if cache_key not in _pipeline_cache:
|
||||
try:
|
||||
import torch
|
||||
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
|
||||
|
||||
token = get_settings().hf_token
|
||||
if token:
|
||||
os.environ["HF_TOKEN"] = token
|
||||
if resolved_token:
|
||||
os.environ["HF_TOKEN"] = resolved_token
|
||||
|
||||
use_cuda = device == "cuda" and torch.cuda.is_available()
|
||||
pipe_device = "cuda:0" if use_cuda else "cpu"
|
||||
|
|
@ -103,6 +105,8 @@ def transcribe_audio(
|
|||
*,
|
||||
whisper_model: str = "base",
|
||||
whisper_device: str = "cpu",
|
||||
hf_token: str = "",
|
||||
nvidia_nim_api_key: str = "",
|
||||
) -> str:
|
||||
"""
|
||||
Transcribe audio file to text.
|
||||
|
|
@ -136,9 +140,12 @@ def transcribe_audio(
|
|||
)
|
||||
|
||||
if whisper_device == "nvidia_nim":
|
||||
return _transcribe_nim(file_path, whisper_model)
|
||||
else:
|
||||
return _transcribe_local(file_path, whisper_model, whisper_device)
|
||||
return _transcribe_nim(
|
||||
file_path, whisper_model, nvidia_nim_api_key=nvidia_nim_api_key
|
||||
)
|
||||
return _transcribe_local(
|
||||
file_path, whisper_model, whisper_device, hf_token=hf_token
|
||||
)
|
||||
|
||||
|
||||
# Whisper expects 16 kHz sample rate
|
||||
|
|
@ -153,10 +160,17 @@ def _load_audio(file_path: Path) -> dict[str, Any]:
|
|||
return {"array": waveform, "sampling_rate": sr}
|
||||
|
||||
|
||||
def _transcribe_local(file_path: Path, whisper_model: str, whisper_device: str) -> str:
|
||||
def _transcribe_local(
|
||||
file_path: Path,
|
||||
whisper_model: str,
|
||||
whisper_device: str,
|
||||
*,
|
||||
hf_token: str = "",
|
||||
) -> str:
|
||||
"""Transcribe using transformers Whisper pipeline."""
|
||||
model_id = _resolve_model_id(whisper_model)
|
||||
pipe = _get_pipeline(model_id, whisper_device)
|
||||
token: str | None = hf_token if hf_token else None
|
||||
pipe = _get_pipeline(model_id, whisper_device, hf_token=token)
|
||||
audio = _load_audio(file_path)
|
||||
result = pipe(audio, generate_kwargs={"language": "en", "task": "transcribe"})
|
||||
text = result.get("text", "") or ""
|
||||
|
|
@ -167,7 +181,9 @@ def _transcribe_local(file_path: Path, whisper_model: str, whisper_device: str)
|
|||
return result_text or "(no speech detected)"
|
||||
|
||||
|
||||
def _transcribe_nim(file_path: Path, model: str) -> str:
|
||||
def _transcribe_nim(
|
||||
file_path: Path, model: str, *, nvidia_nim_api_key: str = ""
|
||||
) -> str:
|
||||
"""Transcribe using NVIDIA NIM Whisper API via Riva gRPC client."""
|
||||
try:
|
||||
import riva.client
|
||||
|
|
@ -177,8 +193,7 @@ def _transcribe_nim(file_path: Path, model: str) -> str:
|
|||
"Install with: uv sync --extra voice"
|
||||
) from e
|
||||
|
||||
settings = get_settings()
|
||||
api_key = settings.nvidia_nim_api_key
|
||||
api_key = nvidia_nim_api_key or get_settings().nvidia_nim_api_key
|
||||
|
||||
# Look up function ID and language code from model mapping
|
||||
model_config = _NIM_MODEL_MAP.get(model)
|
||||
|
|
|
|||
|
|
@@ -46,6 +46,15 @@ class PendingVoiceRegistry:
 class VoiceTranscriptionService:
     """Run configured transcription backends off the event loop."""
 
+    def __init__(
+        self,
+        *,
+        hf_token: str = "",
+        nvidia_nim_api_key: str = "",
+    ) -> None:
+        self._hf_token = hf_token
+        self._nvidia_nim_api_key = nvidia_nim_api_key
+
     async def transcribe(
         self,
         file_path: Path,

@@ -62,4 +71,6 @@ class VoiceTranscriptionService:
             mime_type,
             whisper_model=whisper_model,
             whisper_device=whisper_device,
+            hf_token=self._hf_token,
+            nvidia_nim_api_key=self._nvidia_nim_api_key,
         )
|
||||
|
|
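A short sketch of how the injected-credential service might be constructed and used from the wiring layer. The call shape mirrors the signatures visible in this diff; the import path of the service is assumed for illustration.

    import asyncio
    from pathlib import Path

    from config.settings import get_settings
    from messaging.voice import VoiceTranscriptionService  # import path assumed

    async def transcribe_note(service: VoiceTranscriptionService, path: Path) -> str:
        # Tokens were injected at construction, so call sites only pass audio details.
        return await service.transcribe(
            path,
            "audio/ogg",
            whisper_model="base",
            whisper_device="cpu",
        )

    settings = get_settings()
    service = VoiceTranscriptionService(
        hf_token=settings.hf_token,
        nvidia_nim_api_key=settings.nvidia_nim_api_key,
    )
    text = asyncio.run(transcribe_note(service, Path("note.ogg")))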
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
"""Providers package - implement your own provider by extending BaseProvider."""
|
||||
"""Providers package - implement your own provider by extending BaseProvider.
|
||||
|
||||
Concrete adapters (e.g. ``NvidiaNimProvider``) live in subpackages; import them
|
||||
from ``providers.nvidia_nim`` etc. to avoid loading every adapter when the
|
||||
``providers`` package is imported.
|
||||
"""
|
||||
|
||||
from .anthropic_messages import AnthropicMessagesTransport
|
||||
from .base import BaseProvider, ProviderConfig
|
||||
from .deepseek import DeepSeekProvider
|
||||
from .exceptions import (
|
||||
APIError,
|
||||
AuthenticationError,
|
||||
|
|
@ -10,27 +13,17 @@ from .exceptions import (
|
|||
OverloadedError,
|
||||
ProviderError,
|
||||
RateLimitError,
|
||||
UnknownProviderTypeError,
|
||||
)
|
||||
from .llamacpp import LlamaCppProvider
|
||||
from .lmstudio import LMStudioProvider
|
||||
from .nvidia_nim import NvidiaNimProvider
|
||||
from .open_router import OpenRouterProvider
|
||||
from .openai_compat import OpenAIChatTransport
|
||||
|
||||
__all__ = [
|
||||
"APIError",
|
||||
"AnthropicMessagesTransport",
|
||||
"AuthenticationError",
|
||||
"BaseProvider",
|
||||
"DeepSeekProvider",
|
||||
"InvalidRequestError",
|
||||
"LMStudioProvider",
|
||||
"LlamaCppProvider",
|
||||
"NvidiaNimProvider",
|
||||
"OpenAIChatTransport",
|
||||
"OpenRouterProvider",
|
||||
"OverloadedError",
|
||||
"ProviderConfig",
|
||||
"ProviderError",
|
||||
"RateLimitError",
|
||||
"UnknownProviderTypeError",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -3,12 +3,11 @@
|
|||
from typing import Any
|
||||
|
||||
from providers.base import ProviderConfig
|
||||
from providers.defaults import DEEPSEEK_BASE_URL
|
||||
from providers.openai_compat import OpenAIChatTransport
|
||||
|
||||
from .request import build_request_body
|
||||
|
||||
DEEPSEEK_BASE_URL = "https://api.deepseek.com"
|
||||
|
||||
|
||||
class DeepSeekProvider(OpenAIChatTransport):
|
||||
"""DeepSeek provider using OpenAI-compatible chat completions."""
|
||||
|
|
|
|||
19  providers/defaults.py  (new file)
|
|
@@ -0,0 +1,19 @@
+"""Default upstream base URLs and shared provider constants.
+
+Adapters and :mod:`providers.registry` import from here to avoid duplicating
+literals and to keep ``providers.registry`` free of per-adapter eager imports.
+"""
+
+# OpenAI-compatible chat (NIM, DeepSeek) and local OpenAI-shaped endpoints
+NVIDIA_NIM_DEFAULT_BASE = "https://integrate.api.nvidia.com/v1"
+DEEPSEEK_DEFAULT_BASE = "https://api.deepseek.com"
+OPENROUTER_DEFAULT_BASE = "https://openrouter.ai/api/v1"
+LMSTUDIO_DEFAULT_BASE = "http://localhost:1234/v1"
+LLAMACPP_DEFAULT_BASE = "http://localhost:8080/v1"
+
+# Backward-compatible names used by existing adapter modules
+NVIDIA_NIM_BASE_URL = NVIDIA_NIM_DEFAULT_BASE
+DEEPSEEK_BASE_URL = DEEPSEEK_DEFAULT_BASE
+OPENROUTER_BASE_URL = OPENROUTER_DEFAULT_BASE
+LMSTUDIO_DEFAULT_BASE_URL = LMSTUDIO_DEFAULT_BASE
+LLAMACPP_DEFAULT_BASE_URL = LLAMACPP_DEFAULT_BASE
|
|
@@ -88,3 +88,10 @@ class APIError(ProviderError):
             error_type="api_error",
             raw_error=raw_error,
         )
+
+
+class UnknownProviderTypeError(ValueError):
+    """Raised when ``provider_id`` is not registered in the provider map."""
+
+    def __init__(self, message: str) -> None:
+        super().__init__(message)
|
||||
|
|
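Because the lookup failure is now a dedicated subclass of ValueError, callers can branch on it without string matching. A minimal sketch of that mapping (the HTTP translation below is illustrative, not the project's actual handler); it stays inside the narrow providers facade the contract tests allow for the API layer.

    from fastapi import HTTPException

    from providers.exceptions import UnknownProviderTypeError
    from providers.registry import create_provider

    def provider_or_400(provider_id: str, settings):
        """Map a bad provider id to a client error instead of a generic 500."""
        try:
            return create_provider(provider_id, settings)
        except UnknownProviderTypeError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc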
|
|||
|
|
@ -2,8 +2,7 @@
|
|||
|
||||
from providers.anthropic_messages import AnthropicMessagesTransport
|
||||
from providers.base import ProviderConfig
|
||||
|
||||
LLAMACPP_DEFAULT_BASE_URL = "http://localhost:8080/v1"
|
||||
from providers.defaults import LLAMACPP_DEFAULT_BASE_URL
|
||||
|
||||
|
||||
class LlamaCppProvider(AnthropicMessagesTransport):
|
||||
|
|
|
|||
|
|
@ -2,8 +2,7 @@
|
|||
|
||||
from providers.anthropic_messages import AnthropicMessagesTransport
|
||||
from providers.base import ProviderConfig
|
||||
|
||||
LMSTUDIO_DEFAULT_BASE_URL = "http://localhost:1234/v1"
|
||||
from providers.defaults import LMSTUDIO_DEFAULT_BASE_URL
|
||||
|
||||
|
||||
class LMStudioProvider(AnthropicMessagesTransport):
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from loguru import logger
|
|||
|
||||
from config.nim import NimSettings
|
||||
from providers.base import ProviderConfig
|
||||
from providers.defaults import NVIDIA_NIM_BASE_URL
|
||||
from providers.openai_compat import OpenAIChatTransport
|
||||
|
||||
from .request import (
|
||||
|
|
@ -16,8 +17,6 @@ from .request import (
|
|||
clone_body_without_reasoning_budget,
|
||||
)
|
||||
|
||||
NVIDIA_NIM_BASE_URL = "https://integrate.api.nvidia.com/v1"
|
||||
|
||||
|
||||
class NvidiaNimProvider(OpenAIChatTransport):
|
||||
"""NVIDIA NIM provider using official OpenAI client."""
|
||||
|
|
|
|||
|
|
@ -11,10 +11,10 @@ from typing import Any
|
|||
from core.anthropic import SSEBuilder, append_request_id
|
||||
from providers.anthropic_messages import AnthropicMessagesTransport, StreamChunkMode
|
||||
from providers.base import ProviderConfig
|
||||
from providers.defaults import OPENROUTER_BASE_URL
|
||||
|
||||
from .request import build_request_body
|
||||
|
||||
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
|
||||
_ANTHROPIC_VERSION = "2023-06-01"
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,10 @@
|
|||
"""Shared base class for OpenAI-compatible providers (NIM, OpenRouter, LM Studio)."""
|
||||
"""OpenAI-style chat base for :class:`OpenAIChatTransport` (NIM, DeepSeek, etc.).
|
||||
|
||||
``AnthropicMessagesTransport``-based providers (OpenRouter, LM Studio, …) live
|
||||
in separate modules; do not list them as subclasses of this class.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from abc import abstractmethod
|
||||
|
|
@ -25,7 +30,7 @@ from providers.rate_limit import GlobalRateLimiter
|
|||
|
||||
|
||||
class OpenAIChatTransport(BaseProvider):
|
||||
"""Base class for providers using OpenAI-compatible chat completions API."""
|
||||
"""Base for OpenAI-compatible ``/chat/completions`` adapters (NIM, DeepSeek, …)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -114,6 +119,7 @@ class OpenAIChatTransport(BaseProvider):
|
|||
|
||||
fn_delta = tc.get("function", {})
|
||||
incoming_name = fn_delta.get("name")
|
||||
arguments = fn_delta.get("arguments", "")
|
||||
if incoming_name is not None:
|
||||
sse.blocks.register_tool_name(tc_index, incoming_name)
|
||||
|
||||
|
|
@ -124,7 +130,7 @@ class OpenAIChatTransport(BaseProvider):
|
|||
tool_id = tc.get("id") or f"tool_{uuid.uuid4()}"
|
||||
yield sse.start_tool_block(tc_index, tool_id, name)
|
||||
|
||||
args = fn_delta.get("arguments", "")
|
||||
args = arguments
|
||||
if args:
|
||||
state = sse.blocks.tool_states.get(tc_index)
|
||||
if state is None or not state.started:
|
||||
|
|
@ -285,6 +291,8 @@ class OpenAIChatTransport(BaseProvider):
|
|||
for event in self._process_tool_call(tc_info, sse):
|
||||
yield event
|
||||
|
||||
except (asyncio.CancelledError, GeneratorExit):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("{}_ERROR:{} {}: {}", tag, req_tag, type(e).__name__, e)
|
||||
mapped_e = map_error(e, rate_limiter=self._global_rate_limiter)
|
||||
|
|
|
|||
|
|
@ -6,17 +6,17 @@ from collections.abc import Callable, MutableMapping
|
|||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from config.provider_ids import SUPPORTED_PROVIDER_IDS
|
||||
from config.settings import Settings
|
||||
from providers.base import BaseProvider, ProviderConfig
|
||||
from providers.deepseek import DEEPSEEK_BASE_URL, DeepSeekProvider
|
||||
from providers.exceptions import AuthenticationError
|
||||
from providers.llamacpp import LlamaCppProvider
|
||||
from providers.lmstudio import LMStudioProvider
|
||||
from providers.nvidia_nim import NVIDIA_NIM_BASE_URL, NvidiaNimProvider
|
||||
from providers.open_router import (
|
||||
OPENROUTER_BASE_URL,
|
||||
OpenRouterProvider,
|
||||
from providers.defaults import (
|
||||
DEEPSEEK_DEFAULT_BASE,
|
||||
LLAMACPP_DEFAULT_BASE,
|
||||
LMSTUDIO_DEFAULT_BASE,
|
||||
NVIDIA_NIM_DEFAULT_BASE,
|
||||
OPENROUTER_DEFAULT_BASE,
|
||||
)
|
||||
from providers.exceptions import AuthenticationError, UnknownProviderTypeError
|
||||
|
||||
TransportType = Literal["openai_chat", "anthropic_messages"]
|
||||
ProviderFactory = Callable[[ProviderConfig, Settings], BaseProvider]
|
||||
|
|
@ -24,11 +24,17 @@ ProviderFactory = Callable[[ProviderConfig, Settings], BaseProvider]
|
|||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ProviderDescriptor:
|
||||
"""Metadata for building :class:`ProviderConfig` and factory wiring."""
|
||||
|
||||
provider_id: str
|
||||
transport_type: TransportType
|
||||
capabilities: tuple[str, ...]
|
||||
credential_env: str | None = None
|
||||
credential_url: str | None = None
|
||||
# If set, read API key from this attribute on ``Settings`` (e.g. nvidia_nim_api_key).
|
||||
credential_attr: str | None = None
|
||||
# If set, use this fixed key for local adapters (e.g. lm-studio, llamacpp).
|
||||
static_credential: str | None = None
|
||||
default_base_url: str | None = None
|
||||
base_url_attr: str | None = None
|
||||
proxy_attr: str | None = None
|
||||
|
|
@ -40,7 +46,8 @@ PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = {
|
|||
transport_type="openai_chat",
|
||||
credential_env="NVIDIA_NIM_API_KEY",
|
||||
credential_url="https://build.nvidia.com/settings/api-keys",
|
||||
default_base_url=NVIDIA_NIM_BASE_URL,
|
||||
credential_attr="nvidia_nim_api_key",
|
||||
default_base_url=NVIDIA_NIM_DEFAULT_BASE,
|
||||
proxy_attr="nvidia_nim_proxy",
|
||||
capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"),
|
||||
),
|
||||
|
|
@ -49,7 +56,8 @@ PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = {
|
|||
transport_type="anthropic_messages",
|
||||
credential_env="OPENROUTER_API_KEY",
|
||||
credential_url="https://openrouter.ai/keys",
|
||||
default_base_url=OPENROUTER_BASE_URL,
|
||||
credential_attr="open_router_api_key",
|
||||
default_base_url=OPENROUTER_DEFAULT_BASE,
|
||||
proxy_attr="open_router_proxy",
|
||||
capabilities=("chat", "streaming", "tools", "thinking", "native_anthropic"),
|
||||
),
|
||||
|
|
@ -58,13 +66,15 @@ PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = {
|
|||
transport_type="openai_chat",
|
||||
credential_env="DEEPSEEK_API_KEY",
|
||||
credential_url="https://platform.deepseek.com/api_keys",
|
||||
default_base_url=DEEPSEEK_BASE_URL,
|
||||
credential_attr="deepseek_api_key",
|
||||
default_base_url=DEEPSEEK_DEFAULT_BASE,
|
||||
capabilities=("chat", "streaming", "thinking"),
|
||||
),
|
||||
"lmstudio": ProviderDescriptor(
|
||||
provider_id="lmstudio",
|
||||
transport_type="anthropic_messages",
|
||||
default_base_url="http://localhost:1234/v1",
|
||||
static_credential="lm-studio",
|
||||
default_base_url=LMSTUDIO_DEFAULT_BASE,
|
||||
base_url_attr="lm_studio_base_url",
|
||||
proxy_attr="lmstudio_proxy",
|
||||
capabilities=("chat", "streaming", "tools", "native_anthropic", "local"),
|
||||
|
|
@ -72,7 +82,8 @@ PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = {
|
|||
"llamacpp": ProviderDescriptor(
|
||||
provider_id="llamacpp",
|
||||
transport_type="anthropic_messages",
|
||||
default_base_url="http://localhost:8080/v1",
|
||||
static_credential="llamacpp",
|
||||
default_base_url=LLAMACPP_DEFAULT_BASE,
|
||||
base_url_attr="llamacpp_base_url",
|
||||
proxy_attr="llamacpp_proxy",
|
||||
capabilities=("chat", "streaming", "tools", "native_anthropic", "local"),
|
||||
|
|
@ -81,22 +92,32 @@ PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = {
|
|||
|
||||
|
||||
def _create_nvidia_nim(config: ProviderConfig, settings: Settings) -> BaseProvider:
|
||||
from providers.nvidia_nim import NvidiaNimProvider
|
||||
|
||||
return NvidiaNimProvider(config, nim_settings=settings.nim)
|
||||
|
||||
|
||||
def _create_open_router(config: ProviderConfig, _settings: Settings) -> BaseProvider:
|
||||
from providers.open_router import OpenRouterProvider
|
||||
|
||||
return OpenRouterProvider(config)
|
||||
|
||||
|
||||
def _create_deepseek(config: ProviderConfig, settings: Settings) -> BaseProvider:
|
||||
from providers.deepseek import DeepSeekProvider
|
||||
|
||||
return DeepSeekProvider(config)
|
||||
|
||||
|
||||
def _create_lmstudio(config: ProviderConfig, settings: Settings) -> BaseProvider:
|
||||
from providers.lmstudio import LMStudioProvider
|
||||
|
||||
return LMStudioProvider(config)
|
||||
|
||||
|
||||
def _create_llamacpp(config: ProviderConfig, settings: Settings) -> BaseProvider:
|
||||
from providers.llamacpp import LlamaCppProvider
|
||||
|
||||
return LlamaCppProvider(config)
|
||||
|
||||
|
||||
|
|
@ -108,6 +129,15 @@ PROVIDER_FACTORIES: dict[str, ProviderFactory] = {
|
|||
"llamacpp": _create_llamacpp,
|
||||
}
|
||||
|
||||
if set(PROVIDER_DESCRIPTORS) != set(SUPPORTED_PROVIDER_IDS) or set(
|
||||
PROVIDER_FACTORIES
|
||||
) != set(SUPPORTED_PROVIDER_IDS):
|
||||
raise AssertionError(
|
||||
"PROVIDER_DESCRIPTORS, PROVIDER_FACTORIES, and SUPPORTED_PROVIDER_IDS are out of sync: "
|
||||
f"descriptors={set(PROVIDER_DESCRIPTORS)!r} factories={set(PROVIDER_FACTORIES)!r} "
|
||||
f"ids={set(SUPPORTED_PROVIDER_IDS)!r}"
|
||||
)
|
||||
|
||||
|
||||
def _string_attr(settings: Settings, attr_name: str | None, default: str = "") -> str:
|
||||
if attr_name is None:
|
||||
|
|
@@ -116,17 +146,11 @@ def _string_attr(settings: Settings, attr_name: str | None, default: str = "") -> str:
     return value if isinstance(value, str) else default
 
 
-def _credential_for(provider_id: str, settings: Settings) -> str:
-    if provider_id == "nvidia_nim":
-        return settings.nvidia_nim_api_key
-    if provider_id == "open_router":
-        return settings.open_router_api_key
-    if provider_id == "deepseek":
-        return settings.deepseek_api_key
-    if provider_id == "lmstudio":
-        return "lm-studio"
-    if provider_id == "llamacpp":
-        return "llamacpp"
+def _credential_for(descriptor: ProviderDescriptor, settings: Settings) -> str:
+    if descriptor.static_credential is not None:
+        return descriptor.static_credential
+    if descriptor.credential_attr:
+        return _string_attr(settings, descriptor.credential_attr)
+    return ""
|
||||
|
||||
|
||||
|
|
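With credentials expressed declaratively, adding a provider is mostly a matter of filling in a descriptor. A hypothetical entry (names, URLs, and settings fields invented for the example) showing the two credential styles this hunk introduces:

    from providers.registry import ProviderDescriptor

    EXAMPLE_DESCRIPTORS = {
        # Hosted adapter: key read from Settings.example_api_key (hypothetical field).
        "example_hosted": ProviderDescriptor(
            provider_id="example_hosted",
            transport_type="openai_chat",
            credential_env="EXAMPLE_API_KEY",
            credential_attr="example_api_key",
            default_base_url="https://api.example.com/v1",
            capabilities=("chat", "streaming"),
        ),
        # Local adapter: fixed placeholder key, no settings lookup needed.
        "example_local": ProviderDescriptor(
            provider_id="example_local",
            transport_type="anthropic_messages",
            static_credential="example-local",
            default_base_url="http://localhost:9999/v1",
            capabilities=("chat", "streaming", "local"),
        ),
    }

`_credential_for` then resolves the hosted entry from settings and the local entry from its static placeholder, with no per-provider if/elif chain.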
@ -144,7 +168,7 @@ def _require_credential(descriptor: ProviderDescriptor, credential: str) -> None
|
|||
def build_provider_config(
|
||||
descriptor: ProviderDescriptor, settings: Settings
|
||||
) -> ProviderConfig:
|
||||
credential = _credential_for(descriptor.provider_id, settings)
|
||||
credential = _credential_for(descriptor, settings)
|
||||
_require_credential(descriptor, credential)
|
||||
base_url = _string_attr(
|
||||
settings, descriptor.base_url_attr, descriptor.default_base_url or ""
|
||||
|
|
@ -168,7 +192,7 @@ def create_provider(provider_id: str, settings: Settings) -> BaseProvider:
|
|||
descriptor = PROVIDER_DESCRIPTORS.get(provider_id)
|
||||
if descriptor is None:
|
||||
supported = "', '".join(PROVIDER_DESCRIPTORS)
|
||||
raise ValueError(
|
||||
raise UnknownProviderTypeError(
|
||||
f"Unknown provider_type: '{provider_id}'. Supported: '{supported}'"
|
||||
)
|
||||
|
||||
|
|
@@ -185,12 +209,33 @@ class ProviderRegistry:
     def __init__(self, providers: MutableMapping[str, BaseProvider] | None = None):
         self._providers = providers if providers is not None else {}
 
+    def is_cached(self, provider_id: str) -> bool:
+        """Return whether a provider for this id is already in the cache."""
+        return provider_id in self._providers
+
     def get(self, provider_id: str, settings: Settings) -> BaseProvider:
         if provider_id not in self._providers:
             self._providers[provider_id] = create_provider(provider_id, settings)
         return self._providers[provider_id]
 
     async def cleanup(self) -> None:
-        for provider in self._providers.values():
+        """Call ``cleanup`` on every cached provider, then clear the cache.
+
+        Attempts all providers even if one fails. A single failure is re-raised
+        as-is; multiple failures are wrapped in :exc:`ExceptionGroup`.
+        """
+        items = list(self._providers.items())
+        errors: list[Exception] = []
+        try:
+            for _pid, provider in items:
+                try:
+                    await provider.cleanup()
+                except Exception as e:
+                    errors.append(e)
+        finally:
+            self._providers.clear()
+        if len(errors) == 1:
+            raise errors[0]
+        if len(errors) > 1:
+            msg = "One or more provider cleanups failed"
+            raise ExceptionGroup(msg, errors)
|
||||
|
|
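How a caller might drive that shutdown path: one failure surfaces unwrapped, several arrive as an ExceptionGroup (Python 3.11+). A sketch only; the logging is illustrative, not the app's actual lifespan code.

    import logging

    from providers.registry import ProviderRegistry

    async def shutdown_registry(registry: ProviderRegistry) -> None:
        try:
            await registry.cleanup()
        except ExceptionGroup as group:
            # Several providers failed to close; report each but keep shutting down.
            for err in group.exceptions:
                logging.warning("provider cleanup failed: %s", err)
        except Exception as err:
            # Exactly one provider failed; cleanup() re-raises it as-is.
            logging.warning("provider cleanup failed: %s", err)

The ExceptionGroup handler has to come first, since ExceptionGroup is itself a subclass of Exception.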
|
|||
|
|
@ -18,6 +18,7 @@ from typing import Any
|
|||
import httpx
|
||||
import pytest
|
||||
|
||||
from config.provider_ids import SUPPORTED_PROVIDER_IDS
|
||||
from messaging.handler import ClaudeMessageHandler
|
||||
from messaging.models import IncomingMessage
|
||||
from messaging.platforms.base import MessagingPlatform
|
||||
|
|
@ -153,7 +154,7 @@ class ConversationDriver:
|
|||
class ProviderMatrixDriver:
|
||||
"""Resolve provider models and enforce matrix semantics for product smoke."""
|
||||
|
||||
ALL_PROVIDERS = ("nvidia_nim", "open_router", "deepseek", "lmstudio", "llamacpp")
|
||||
ALL_PROVIDERS: tuple[str, ...] = SUPPORTED_PROVIDER_IDS
|
||||
|
||||
def __init__(self, config: SmokeConfig) -> None:
|
||||
self.config = config
|
||||
|
|
|
|||
167  smoke/lib/sse.py
|
|
@ -1,148 +1,29 @@
|
|||
"""SSE parsing and Anthropic stream assertions for smoke tests."""
|
||||
"""SSE parsing and Anthropic stream assertions for smoke tests.
|
||||
|
||||
Canonical implementation lives in :mod:`core.anthropic.stream_contracts`; this
|
||||
module re-exports it for smoke and historical import paths.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class SSEEvent:
|
||||
event: str
|
||||
data: dict[str, Any]
|
||||
raw: str
|
||||
|
||||
|
||||
def parse_sse_lines(lines: Iterable[str]) -> list[SSEEvent]:
|
||||
events: list[SSEEvent] = []
|
||||
current_event = ""
|
||||
data_parts: list[str] = []
|
||||
raw_parts: list[str] = []
|
||||
|
||||
for line in lines:
|
||||
stripped = line.rstrip("\r\n")
|
||||
if stripped == "":
|
||||
_append_event(events, current_event, data_parts, raw_parts)
|
||||
current_event = ""
|
||||
data_parts = []
|
||||
raw_parts = []
|
||||
continue
|
||||
raw_parts.append(stripped)
|
||||
if stripped.startswith("event:"):
|
||||
current_event = stripped.split(":", 1)[1].strip()
|
||||
elif stripped.startswith("data:"):
|
||||
data_parts.append(stripped.split(":", 1)[1].strip())
|
||||
|
||||
_append_event(events, current_event, data_parts, raw_parts)
|
||||
return events
|
||||
|
||||
|
||||
def parse_sse_text(text: str) -> list[SSEEvent]:
|
||||
return parse_sse_lines(text.splitlines())
|
||||
|
||||
|
||||
def _append_event(
|
||||
events: list[SSEEvent],
|
||||
current_event: str,
|
||||
data_parts: list[str],
|
||||
raw_parts: list[str],
|
||||
) -> None:
|
||||
if not current_event and not data_parts:
|
||||
return
|
||||
data_text = "\n".join(data_parts)
|
||||
data: dict[str, Any]
|
||||
try:
|
||||
parsed = json.loads(data_text) if data_text else {}
|
||||
data = parsed if isinstance(parsed, dict) else {"value": parsed}
|
||||
except json.JSONDecodeError:
|
||||
data = {"raw": data_text}
|
||||
events.append(SSEEvent(current_event, data, "\n".join(raw_parts)))
|
||||
|
||||
|
||||
def assert_anthropic_stream_contract(
|
||||
events: list[SSEEvent], *, allow_error: bool = False
|
||||
) -> None:
|
||||
assert events, "stream produced no SSE events"
|
||||
event_names = [event.event for event in events]
|
||||
assert "message_start" in event_names, event_names
|
||||
assert event_names[-1] == "message_stop", event_names
|
||||
|
||||
open_blocks: dict[int, str] = {}
|
||||
seen_blocks: set[int] = set()
|
||||
for event in events:
|
||||
if event.event == "error" and not allow_error:
|
||||
raise AssertionError(f"unexpected SSE error event: {event.data}")
|
||||
|
||||
if event.event == "content_block_start":
|
||||
index = _event_index(event)
|
||||
block = event.data.get("content_block", {})
|
||||
assert isinstance(block, dict), event.data
|
||||
block_type = str(block.get("type", ""))
|
||||
assert block_type in {"text", "thinking", "tool_use"}, event.data
|
||||
assert index not in open_blocks, f"block {index} started twice"
|
||||
assert index not in seen_blocks, f"block {index} reused after stop"
|
||||
open_blocks[index] = block_type
|
||||
seen_blocks.add(index)
|
||||
continue
|
||||
|
||||
if event.event == "content_block_delta":
|
||||
index = _event_index(event)
|
||||
assert index in open_blocks, f"delta for unopened block {index}"
|
||||
delta = event.data.get("delta", {})
|
||||
assert isinstance(delta, dict), event.data
|
||||
delta_type = str(delta.get("type", ""))
|
||||
expected = {
|
||||
"text": "text_delta",
|
||||
"thinking": "thinking_delta",
|
||||
"tool_use": "input_json_delta",
|
||||
}[open_blocks[index]]
|
||||
assert delta_type == expected, (
|
||||
f"block {index} is {open_blocks[index]}, got {delta_type}"
|
||||
from core.anthropic.stream_contracts import (
|
||||
SSEEvent,
|
||||
assert_anthropic_stream_contract,
|
||||
event_names,
|
||||
has_tool_use,
|
||||
parse_sse_lines,
|
||||
parse_sse_text,
|
||||
text_content,
|
||||
thinking_content,
|
||||
)
|
||||
continue
|
||||
|
||||
if event.event == "content_block_stop":
|
||||
index = _event_index(event)
|
||||
assert index in open_blocks, f"stop for unopened block {index}"
|
||||
open_blocks.pop(index)
|
||||
|
||||
assert not open_blocks, f"unclosed blocks: {open_blocks}"
|
||||
assert seen_blocks, "stream did not emit any content blocks"
|
||||
|
||||
|
||||
def event_names(events: list[SSEEvent]) -> list[str]:
|
||||
return [event.event for event in events]
|
||||
|
||||
|
||||
def text_content(events: list[SSEEvent]) -> str:
|
||||
parts: list[str] = []
|
||||
for event in events:
|
||||
delta = event.data.get("delta", {})
|
||||
if isinstance(delta, dict) and delta.get("type") == "text_delta":
|
||||
parts.append(str(delta.get("text", "")))
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def thinking_content(events: list[SSEEvent]) -> str:
|
||||
parts: list[str] = []
|
||||
for event in events:
|
||||
delta = event.data.get("delta", {})
|
||||
if isinstance(delta, dict) and delta.get("type") == "thinking_delta":
|
||||
parts.append(str(delta.get("thinking", "")))
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def has_tool_use(events: list[SSEEvent]) -> bool:
|
||||
for event in events:
|
||||
block = event.data.get("content_block", {})
|
||||
if isinstance(block, dict) and block.get("type") == "tool_use":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _event_index(event: SSEEvent) -> int:
|
||||
value = event.data.get("index")
|
||||
assert isinstance(value, int), event.data
|
||||
return value
|
||||
__all__ = [
|
||||
"SSEEvent",
|
||||
"assert_anthropic_stream_contract",
|
||||
"event_names",
|
||||
"has_tool_use",
|
||||
"parse_sse_lines",
|
||||
"parse_sse_text",
|
||||
"text_content",
|
||||
"thinking_content",
|
||||
]
|
||||
|
|
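Either import path now resolves to the same objects, so a quick contract check on a captured transcript can go through core directly. A minimal sketch, using the semantics of the parser and assertion shown above:

    from core.anthropic.stream_contracts import (
        assert_anthropic_stream_contract,
        parse_sse_text,
        text_content,
    )

    transcript = (
        "event: message_start\n"
        'data: {"type": "message_start"}\n'
        "\n"
        "event: content_block_start\n"
        'data: {"type": "content_block_start", "index": 0, "content_block": {"type": "text"}}\n'
        "\n"
        "event: content_block_delta\n"
        'data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "hi"}}\n'
        "\n"
        "event: content_block_stop\n"
        'data: {"type": "content_block_stop", "index": 0}\n'
        "\n"
        "event: message_stop\n"
        'data: {"type": "message_stop"}\n'
    )

    events = parse_sse_text(transcript)
    assert_anthropic_stream_contract(events)
    assert text_content(events) == "hi"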
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from api.app import app
|
||||
|
|
@ -9,7 +10,7 @@ from providers.nvidia_nim import NvidiaNimProvider
|
|||
mock_provider = MagicMock(spec=NvidiaNimProvider)
|
||||
|
||||
# Track stream_response calls for test_model_mapping
|
||||
_stream_response_calls = []
|
||||
_stream_response_calls: list = []
|
||||
|
||||
|
||||
async def _mock_stream_response(*args, **kwargs):
|
||||
|
|
@ -21,26 +22,30 @@ async def _mock_stream_response(*args, **kwargs):
|
|||
|
||||
mock_provider.stream_response = _mock_stream_response
|
||||
|
||||
# Patch get_provider_for_type to always return mock_provider
|
||||
_patcher = patch("api.routes.get_provider_for_type", return_value=mock_provider)
|
||||
_patcher.start()
|
||||
|
||||
client = TestClient(app)
|
||||
@pytest.fixture(scope="module")
|
||||
def client():
|
||||
"""HTTP client with provider resolution stubbed; patch only for this file."""
|
||||
with (
|
||||
patch("api.dependencies.resolve_provider", return_value=mock_provider),
|
||||
TestClient(app) as test_client,
|
||||
):
|
||||
yield test_client
|
||||
|
||||
|
||||
def test_root():
|
||||
def test_root(client: TestClient):
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "ok"
|
||||
|
||||
|
||||
def test_health():
|
||||
def test_health(client: TestClient):
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "healthy"
|
||||
|
||||
|
||||
def test_models_list():
|
||||
def test_models_list(client: TestClient):
|
||||
response = client.get("/v1/models")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
|
@ -51,7 +56,7 @@ def test_models_list():
|
|||
assert data["last_id"] == ids[-1]
|
||||
|
||||
|
||||
def test_probe_endpoints_return_204_with_allow_headers():
|
||||
def test_probe_endpoints_return_204_with_allow_headers(client: TestClient):
|
||||
responses = [
|
||||
client.head("/"),
|
||||
client.options("/"),
|
||||
|
|
@ -68,7 +73,7 @@ def test_probe_endpoints_return_204_with_allow_headers():
|
|||
assert "Allow" in response.headers
|
||||
|
||||
|
||||
def test_create_message_stream():
|
||||
def test_create_message_stream(client: TestClient):
|
||||
"""Create message returns streaming response."""
|
||||
payload = {
|
||||
"model": "claude-3-sonnet",
|
||||
|
|
@ -83,7 +88,7 @@ def test_create_message_stream():
|
|||
assert b"message_start" in content or b"event:" in content
|
||||
|
||||
|
||||
def test_model_mapping():
|
||||
def test_model_mapping(client: TestClient):
|
||||
# Test Haiku mapping
|
||||
_stream_response_calls.clear()
|
||||
payload_haiku = {
|
||||
|
|
@ -98,7 +103,7 @@ def test_model_mapping():
|
|||
assert args[0].model != "claude-3-haiku-20240307"
|
||||
|
||||
|
||||
def test_error_fallbacks():
|
||||
def test_error_fallbacks(client: TestClient):
|
||||
from providers.exceptions import (
|
||||
AuthenticationError,
|
||||
OverloadedError,
|
||||
|
|
@ -143,7 +148,7 @@ def test_error_fallbacks():
|
|||
mock_provider.stream_response = _mock_stream_response
|
||||
|
||||
|
||||
def test_generic_exception_returns_500():
|
||||
def test_generic_exception_returns_500(client: TestClient):
|
||||
"""Non-ProviderError exceptions are caught and returned as HTTPException(500)."""
|
||||
|
||||
def _raise_runtime(*args, **kwargs):
|
||||
|
|
@ -163,7 +168,7 @@ def test_generic_exception_returns_500():
|
|||
mock_provider.stream_response = _mock_stream_response
|
||||
|
||||
|
||||
def test_generic_exception_with_status_code():
|
||||
def test_generic_exception_with_status_code(client: TestClient):
|
||||
"""Generic exception with status_code attribute uses that status (getattr fallback)."""
|
||||
|
||||
class ExceptionWithStatus(RuntimeError):
|
||||
|
|
@ -188,7 +193,7 @@ def test_generic_exception_with_status_code():
|
|||
mock_provider.stream_response = _mock_stream_response
|
||||
|
||||
|
||||
def test_generic_exception_empty_message_returns_non_empty_detail():
|
||||
def test_generic_exception_empty_message_returns_non_empty_detail(client: TestClient):
|
||||
"""Exceptions with empty __str__ still return a readable HTTP detail."""
|
||||
|
||||
class SilentError(RuntimeError):
|
||||
|
|
@ -213,7 +218,7 @@ def test_generic_exception_empty_message_returns_non_empty_detail():
|
|||
mock_provider.stream_response = _mock_stream_response
|
||||
|
||||
|
||||
def test_count_tokens_endpoint():
|
||||
def test_count_tokens_endpoint(client: TestClient):
|
||||
"""count_tokens endpoint returns token count."""
|
||||
response = client.post(
|
||||
"/v1/messages/count_tokens",
|
||||
|
|
@ -223,7 +228,7 @@ def test_count_tokens_endpoint():
|
|||
assert "input_tokens" in response.json()
|
||||
|
||||
|
||||
def test_stop_endpoint_no_handler_no_cli_503():
|
||||
def test_stop_endpoint_no_handler_no_cli_503(client: TestClient):
|
||||
"""POST /stop without handler or cli_manager returns 503."""
|
||||
# Ensure no handler or cli_manager on app state
|
||||
if hasattr(app.state, "message_handler"):
|
||||
|
|
|
|||
|
|
@ -7,6 +7,23 @@ import pytest
|
|||
from fastapi.testclient import TestClient
|
||||
|
||||
from config.settings import Settings
|
||||
from providers.registry import ProviderRegistry
|
||||
|
||||
_RUNTIME_EXTRAS = {
|
||||
"voice_note_enabled": True,
|
||||
"whisper_model": "base",
|
||||
"whisper_device": "cpu",
|
||||
"hf_token": "",
|
||||
"nvidia_nim_api_key": "",
|
||||
"claude_cli_bin": "claude",
|
||||
"uses_process_anthropic_auth_token": lambda: False,
|
||||
}
|
||||
|
||||
|
||||
def _app_settings(**kwargs):
|
||||
"""Minimal settings namespace for AppRuntime (matches typed :class:`Settings` fields used)."""
|
||||
data = {**_RUNTIME_EXTRAS, **kwargs}
|
||||
return SimpleNamespace(**data)
|
||||
|
||||
|
||||
def test_warn_if_process_auth_token_logs_warning():
|
||||
|
|
@ -45,7 +62,7 @@ def test_create_app_provider_error_handler_returns_anthropic_format():
|
|||
raise AuthenticationError("bad key")
|
||||
|
||||
api_app_mod = importlib.import_module("api.app")
|
||||
settings = SimpleNamespace(
|
||||
settings = _app_settings(
|
||||
messaging_platform="telegram",
|
||||
telegram_bot_token=None,
|
||||
allowed_telegram_user_id=None,
|
||||
|
|
@ -59,7 +76,7 @@ def test_create_app_provider_error_handler_returns_anthropic_format():
|
|||
)
|
||||
with (
|
||||
patch.object(api_app_mod, "get_settings", return_value=settings),
|
||||
patch.object(api_app_mod, "cleanup_provider", new=AsyncMock()),
|
||||
patch.object(ProviderRegistry, "cleanup", new=AsyncMock()),
|
||||
):
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/raise_provider")
|
||||
|
|
@ -79,7 +96,7 @@ def test_create_app_general_exception_handler_returns_500():
|
|||
raise RuntimeError("boom")
|
||||
|
||||
api_app_mod = importlib.import_module("api.app")
|
||||
settings = SimpleNamespace(
|
||||
settings = _app_settings(
|
||||
messaging_platform="telegram",
|
||||
telegram_bot_token=None,
|
||||
allowed_telegram_user_id=None,
|
||||
|
|
@ -93,7 +110,7 @@ def test_create_app_general_exception_handler_returns_500():
|
|||
)
|
||||
with (
|
||||
patch.object(api_app_mod, "get_settings", return_value=settings),
|
||||
patch.object(api_app_mod, "cleanup_provider", new=AsyncMock()),
|
||||
patch.object(ProviderRegistry, "cleanup", new=AsyncMock()),
|
||||
):
|
||||
with TestClient(app, raise_server_exceptions=False) as client:
|
||||
resp = client.get("/raise_general")
|
||||
|
|
@ -111,7 +128,7 @@ def test_app_lifespan_sets_state_and_cleans_up(tmp_path, messaging_enabled):
|
|||
|
||||
app = create_app()
|
||||
|
||||
settings = SimpleNamespace(
|
||||
settings = _app_settings(
|
||||
messaging_platform="telegram",
|
||||
telegram_bot_token="token" if messaging_enabled else None,
|
||||
allowed_telegram_user_id="123",
|
||||
|
|
@ -147,10 +164,10 @@ def test_app_lifespan_sets_state_and_cleans_up(tmp_path, messaging_enabled):
|
|||
|
||||
api_app_mod = importlib.import_module("api.app")
|
||||
|
||||
cleanup_provider = AsyncMock()
|
||||
registry_cleanup = AsyncMock()
|
||||
with (
|
||||
patch.object(api_app_mod, "get_settings", return_value=settings),
|
||||
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
|
||||
patch.object(ProviderRegistry, "cleanup", new=registry_cleanup),
|
||||
patch(
|
||||
"messaging.platforms.factory.create_messaging_platform",
|
||||
return_value=fake_platform if messaging_enabled else None,
|
||||
|
|
@ -182,7 +199,7 @@ def test_app_lifespan_sets_state_and_cleans_up(tmp_path, messaging_enabled):
|
|||
cli_manager.stop_all.assert_not_awaited()
|
||||
assert getattr(app.state, "messaging_platform", "missing") is None
|
||||
|
||||
cleanup_provider.assert_awaited_once()
|
||||
registry_cleanup.assert_awaited_once()
|
||||
|
||||
|
||||
def test_app_lifespan_cleanup_continues_if_platform_stop_raises(tmp_path):
|
||||
|
|
@ -190,7 +207,7 @@ def test_app_lifespan_cleanup_continues_if_platform_stop_raises(tmp_path):
|
|||
|
||||
app = create_app()
|
||||
|
||||
settings = SimpleNamespace(
|
||||
settings = _app_settings(
|
||||
messaging_platform="telegram",
|
||||
telegram_bot_token="token",
|
||||
allowed_telegram_user_id="123",
|
||||
|
|
@ -218,10 +235,10 @@ def test_app_lifespan_cleanup_continues_if_platform_stop_raises(tmp_path):
|
|||
cli_manager.stop_all = AsyncMock()
|
||||
|
||||
api_app_mod = importlib.import_module("api.app")
|
||||
cleanup_provider = AsyncMock()
|
||||
registry_cleanup = AsyncMock()
|
||||
with (
|
||||
patch.object(api_app_mod, "get_settings", return_value=settings),
|
||||
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
|
||||
patch.object(ProviderRegistry, "cleanup", new=registry_cleanup),
|
||||
patch(
|
||||
"messaging.platforms.factory.create_messaging_platform",
|
||||
return_value=fake_platform,
|
||||
|
|
@ -234,7 +251,7 @@ def test_app_lifespan_cleanup_continues_if_platform_stop_raises(tmp_path):
|
|||
|
||||
fake_platform.stop.assert_awaited_once()
|
||||
cli_manager.stop_all.assert_awaited_once()
|
||||
cleanup_provider.assert_awaited_once()
|
||||
registry_cleanup.assert_awaited_once()
|
||||
|
||||
|
||||
def test_app_lifespan_messaging_import_error_no_crash(tmp_path, caplog):
|
||||
|
|
@ -243,7 +260,7 @@ def test_app_lifespan_messaging_import_error_no_crash(tmp_path, caplog):
|
|||
|
||||
app = create_app()
|
||||
|
||||
settings = SimpleNamespace(
|
||||
settings = _app_settings(
|
||||
messaging_platform="telegram",
|
||||
telegram_bot_token="token",
|
||||
allowed_telegram_user_id="123",
|
||||
|
|
@ -257,10 +274,10 @@ def test_app_lifespan_messaging_import_error_no_crash(tmp_path, caplog):
|
|||
)
|
||||
|
||||
api_app_mod = importlib.import_module("api.app")
|
||||
cleanup_provider = AsyncMock()
|
||||
registry_cleanup = AsyncMock()
|
||||
with (
|
||||
patch.object(api_app_mod, "get_settings", return_value=settings),
|
||||
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
|
||||
patch.object(ProviderRegistry, "cleanup", new=registry_cleanup),
|
||||
patch(
|
||||
"messaging.platforms.factory.create_messaging_platform",
|
||||
side_effect=ImportError("discord not installed"),
|
||||
|
|
@ -270,7 +287,7 @@ def test_app_lifespan_messaging_import_error_no_crash(tmp_path, caplog):
|
|||
pass
|
||||
|
||||
assert getattr(app.state, "messaging_platform", None) is None
|
||||
cleanup_provider.assert_awaited_once()
|
||||
registry_cleanup.assert_awaited_once()
|
||||
|
||||
|
||||
def test_app_lifespan_platform_start_exception_cleanup_still_runs(tmp_path):
|
||||
|
|
@ -279,7 +296,7 @@ def test_app_lifespan_platform_start_exception_cleanup_still_runs(tmp_path):
|
|||
|
||||
app = create_app()
|
||||
|
||||
settings = SimpleNamespace(
|
||||
settings = _app_settings(
|
||||
messaging_platform="telegram",
|
||||
telegram_bot_token="token",
|
||||
allowed_telegram_user_id="123",
|
||||
|
|
@ -307,10 +324,10 @@ def test_app_lifespan_platform_start_exception_cleanup_still_runs(tmp_path):
|
|||
cli_manager.stop_all = AsyncMock()
|
||||
|
||||
api_app_mod = importlib.import_module("api.app")
|
||||
cleanup_provider = AsyncMock()
|
||||
registry_cleanup = AsyncMock()
|
||||
with (
|
||||
patch.object(api_app_mod, "get_settings", return_value=settings),
|
||||
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
|
||||
patch.object(ProviderRegistry, "cleanup", new=registry_cleanup),
|
||||
patch(
|
||||
"messaging.platforms.factory.create_messaging_platform",
|
||||
return_value=fake_platform,
|
||||
|
|
@ -321,7 +338,7 @@ def test_app_lifespan_platform_start_exception_cleanup_still_runs(tmp_path):
|
|||
):
|
||||
pass
|
||||
|
||||
cleanup_provider.assert_awaited_once()
|
||||
registry_cleanup.assert_awaited_once()
|
||||
|
||||
|
||||
def test_app_lifespan_flush_pending_save_exception_warning_only(tmp_path):
|
||||
|
|
@ -330,7 +347,7 @@ def test_app_lifespan_flush_pending_save_exception_warning_only(tmp_path):
|
|||
|
||||
app = create_app()
|
||||
|
||||
settings = SimpleNamespace(
|
||||
settings = _app_settings(
|
||||
messaging_platform="telegram",
|
||||
telegram_bot_token="token",
|
||||
allowed_telegram_user_id="123",
|
||||
|
|
@ -359,10 +376,10 @@ def test_app_lifespan_flush_pending_save_exception_warning_only(tmp_path):
|
|||
cli_manager.stop_all = AsyncMock()
|
||||
|
||||
api_app_mod = importlib.import_module("api.app")
|
||||
cleanup_provider = AsyncMock()
|
||||
registry_cleanup = AsyncMock()
|
||||
with (
|
||||
patch.object(api_app_mod, "get_settings", return_value=settings),
|
||||
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
|
||||
patch.object(ProviderRegistry, "cleanup", new=registry_cleanup),
|
||||
patch(
|
||||
"messaging.platforms.factory.create_messaging_platform",
|
||||
return_value=fake_platform,
|
||||
|
|
@ -374,4 +391,4 @@ def test_app_lifespan_flush_pending_save_exception_warning_only(tmp_path):
|
|||
pass
|
||||
|
||||
session_store.flush_pending_save.assert_called_once()
|
||||
cleanup_provider.assert_awaited_once()
|
||||
registry_cleanup.assert_awaited_once()
|
||||
|
|
|
|||
|
|
@ -1,19 +1,26 @@
|
|||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from starlette.applications import Starlette
|
||||
from starlette.datastructures import State
|
||||
|
||||
from api.dependencies import (
|
||||
cleanup_provider,
|
||||
get_provider,
|
||||
get_provider_for_type,
|
||||
get_settings,
|
||||
resolve_provider,
|
||||
)
|
||||
from config.nim import NimSettings
|
||||
from providers.deepseek import DeepSeekProvider
|
||||
from providers.exceptions import UnknownProviderTypeError
|
||||
from providers.lmstudio import LMStudioProvider
|
||||
from providers.nvidia_nim import NvidiaNimProvider
|
||||
from providers.open_router import OpenRouterProvider
|
||||
from providers.registry import ProviderRegistry
|
||||
|
||||
|
||||
def _make_mock_settings(**overrides):
|
||||
|
|
@ -304,11 +311,11 @@ async def test_get_provider_deepseek_missing_api_key():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_provider_unknown_type():
|
||||
"""Test that unknown provider_type raises ValueError."""
|
||||
"""Unknown ``provider_type`` raises :exc:`~providers.exceptions.UnknownProviderTypeError`."""
|
||||
with patch("api.dependencies.get_settings") as mock_settings:
|
||||
mock_settings.return_value = _make_mock_settings(provider_type="unknown")
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown provider_type"):
|
||||
with pytest.raises(UnknownProviderTypeError, match="Unknown provider_type"):
|
||||
get_provider()
|
||||
|
||||
|
||||
|
|
@ -390,3 +397,55 @@ async def test_cleanup_provider_cleans_all():
|
|||
|
||||
nim._client.aclose.assert_called_once()
|
||||
lmstudio._client.aclose.assert_called_once()
|
||||
|
||||
|
||||
def test_resolve_provider_per_app_uses_separate_registries() -> None:
|
||||
"""With app set, each app gets its own provider cache (not process _providers)."""
|
||||
with patch("api.dependencies.get_settings") as mock_settings:
|
||||
mock_settings.return_value = _make_mock_settings()
|
||||
settings = _make_mock_settings()
|
||||
app1 = SimpleNamespace(state=State())
|
||||
app2 = SimpleNamespace(state=State())
|
||||
app1.state.provider_registry = ProviderRegistry()
|
||||
app2.state.provider_registry = ProviderRegistry()
|
||||
p1 = resolve_provider(
|
||||
"nvidia_nim", app=cast(Starlette, app1), settings=settings
|
||||
)
|
||||
p2 = resolve_provider(
|
||||
"nvidia_nim", app=cast(Starlette, app2), settings=settings
|
||||
)
|
||||
assert isinstance(p1, NvidiaNimProvider)
|
||||
assert isinstance(p2, NvidiaNimProvider)
|
||||
assert p1 is not p2
|
||||
|
||||
|
||||
def test_resolve_provider_lazily_installs_registry() -> None:
|
||||
"""If app has no provider_registry, one is created on app.state."""
|
||||
with patch("api.dependencies.get_settings") as mock_settings:
|
||||
mock_settings.return_value = _make_mock_settings()
|
||||
settings = _make_mock_settings()
|
||||
app = SimpleNamespace(state=State())
|
||||
assert getattr(app.state, "provider_registry", None) is None
|
||||
resolve_provider("nvidia_nim", app=cast(Starlette, app), settings=settings)
|
||||
reg = app.state.provider_registry
|
||||
assert reg is not None
|
||||
p2 = resolve_provider("nvidia_nim", app=cast(Starlette, app), settings=settings)
|
||||
assert p2 is reg.get("nvidia_nim", settings) # same registry instance
|
||||
|
||||
|
||||
def test_resolve_provider_unrelated_value_error_is_not_unknown_provider_log() -> None:
|
||||
"""Only :exc:`~providers.exceptions.UnknownProviderTypeError` logs unknown provider."""
|
||||
import api.dependencies as deps
|
||||
|
||||
with (
|
||||
patch.object(deps, "get_settings", return_value=_make_mock_settings()),
|
||||
patch.object(
|
||||
ProviderRegistry,
|
||||
"get",
|
||||
side_effect=ValueError("unrelated config"),
|
||||
),
|
||||
patch.object(deps.logger, "error") as log_err,
|
||||
pytest.raises(ValueError, match="unrelated config"),
|
||||
):
|
||||
deps.resolve_provider("nvidia_nim", app=None, settings=_make_mock_settings())
|
||||
log_err.assert_not_called()
|
||||
|
|
|
|||
|
|
@ -91,6 +91,21 @@ def test_messages_request_accepts_adaptive_thinking_type():
|
|||
assert dumped["thinking"]["type"] == "adaptive"
|
||||
|
||||
|
||||
def test_messages_request_accepts_anthropic_server_tool_without_input_schema():
|
||||
request = MessagesRequest.model_validate(
|
||||
{
|
||||
"model": "claude-opus-4-7",
|
||||
"max_tokens": 100,
|
||||
"messages": [{"role": "user", "content": "search"}],
|
||||
"tools": [{"type": "web_search_20250305", "name": "web_search"}],
|
||||
}
|
||||
)
|
||||
|
||||
dumped = request.model_dump(exclude_none=True)
|
||||
|
||||
assert dumped["tools"] == [{"name": "web_search", "type": "web_search_20250305"}]
|
||||
|
||||
|
||||
def test_messages_request_accepts_redacted_thinking_blocks():
|
||||
request = MessagesRequest.model_validate(
|
||||
{
|
||||
|
|
|
|||
96  tests/api/test_web_server_tools.py  (new file)
|
|
@@ -0,0 +1,96 @@
+import json
+
+import pytest
+
+from api.models.anthropic import Message, MessagesRequest, Tool
+from api.web_server_tools import (
+    is_web_server_tool_request,
+    stream_web_server_tool_response,
+)
+
+
+def _event_data(event: str) -> dict:
+    data_line = next(line for line in event.splitlines() if line.startswith("data: "))
+    return json.loads(data_line.removeprefix("data: "))
+
+
+def test_detects_web_search_server_tool_request():
+    request = MessagesRequest(
+        model="claude-haiku-4-5-20251001",
+        max_tokens=100,
+        messages=[Message(role="user", content="search")],
+        tools=[Tool(name="web_search", type="web_search_20250305")],
+    )
+
+    assert is_web_server_tool_request(request)
+
+
+@pytest.mark.asyncio
+async def test_streams_web_search_server_tool_result(monkeypatch):
+    async def fake_search(query: str) -> list[dict[str, str]]:
+        assert query == "DeepSeek V4 model release 2026"
+        return [{"title": "DeepSeek V4 Released", "url": "https://example.com/v4"}]
+
+    monkeypatch.setattr("api.web_server_tools._run_web_search", fake_search)
+    request = MessagesRequest(
+        model="claude-haiku-4-5-20251001",
+        max_tokens=100,
+        messages=[
+            Message(
+                role="user",
+                content=(
+                    "Perform a web search for the query: DeepSeek V4 model release 2026"
+                ),
+            )
+        ],
+        tools=[Tool(name="web_search", type="web_search_20250305")],
+        tool_choice={"type": "tool", "name": "web_search"},
+    )
+
+    events = [
+        event
+        async for event in stream_web_server_tool_response(request, input_tokens=42)
+    ]
+    payloads = [_event_data(event) for event in events]
+
+    assert payloads[1]["content_block"]["type"] == "server_tool_use"
+    assert payloads[1]["content_block"]["name"] == "web_search"
+    assert payloads[3]["content_block"]["type"] == "web_search_tool_result"
+    assert payloads[3]["content_block"]["content"][0]["url"] == "https://example.com/v4"
+    assert payloads[-2]["usage"]["server_tool_use"] == {"web_search_requests": 1}
+
+
+@pytest.mark.asyncio
+async def test_streams_web_fetch_server_tool_result(monkeypatch):
+    async def fake_fetch(url: str) -> dict[str, str]:
+        assert url == "https://example.com/article"
+        return {
+            "url": url,
+            "title": "Example Article",
+            "media_type": "text/plain",
+            "data": "Article body",
+        }
+
+    monkeypatch.setattr("api.web_server_tools._run_web_fetch", fake_fetch)
+    request = MessagesRequest(
+        model="claude-haiku-4-5-20251001",
+        max_tokens=100,
+        messages=[
+            Message(role="user", content="Fetch https://example.com/article please")
+        ],
+        tools=[Tool(name="web_fetch", type="web_fetch_20250910")],
+        tool_choice={"type": "tool", "name": "web_fetch"},
+    )
+
+    events = [
+        event
+        async for event in stream_web_server_tool_response(request, input_tokens=42)
+    ]
+    payloads = [_event_data(event) for event in events]
+
+    assert payloads[1]["content_block"]["type"] == "server_tool_use"
+    assert payloads[3]["content_block"]["type"] == "web_fetch_tool_result"
+    assert payloads[3]["content_block"]["content"]["content"]["title"] == (
+        "Example Article"
+    )
+    assert payloads[-2]["usage"]["server_tool_use"] == {"web_fetch_requests": 1}
|
||||
|
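A sketch of how a messages route might branch on these helpers before touching any upstream provider. The FastAPI wiring and the route path are illustrative; the actual route in this repo may look different.

    from fastapi import APIRouter
    from fastapi.responses import StreamingResponse

    from api.models.anthropic import MessagesRequest
    from api.web_server_tools import (
        is_web_server_tool_request,
        stream_web_server_tool_response,
    )

    router = APIRouter()

    @router.post("/v1/messages/local-web-tools")  # illustrative path
    async def messages_with_local_web_tools(request: MessagesRequest):
        if is_web_server_tool_request(request):
            # Serve web_search / web_fetch locally as an Anthropic-style SSE stream.
            return StreamingResponse(
                stream_web_server_tool_response(request, input_tokens=0),
                media_type="text/event-stream",
            )
        ...  # otherwise fall through to the normal provider streaming path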
|
@ -118,6 +118,15 @@ def mock_platform():
|
|||
platform.queue_edit_message = AsyncMock()
|
||||
platform.queue_delete_message = AsyncMock()
|
||||
|
||||
async def _queue_delete_messages(
|
||||
chat_id: str, message_ids: list[str], *, fire_and_forget: bool = True
|
||||
) -> None:
|
||||
qdm = platform.queue_delete_message
|
||||
for mid in message_ids:
|
||||
await qdm(chat_id, mid, fire_and_forget=fire_and_forget)
|
||||
|
||||
platform.queue_delete_messages = AsyncMock(side_effect=_queue_delete_messages)
|
||||
|
||||
def _fire_and_forget(task):
|
||||
if asyncio.iscoroutine(task):
|
||||
# Create a task to avoid "coroutine was never awaited" warning
|
||||
|
|
|
|||
|
|
@ -1,8 +1,20 @@
|
|||
"""Package import contract tests (static AST; dynamic ``importlib`` loads are not scanned)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from pathlib import Path
|
||||
|
||||
# `api` may only import this narrow ``providers`` surface (AGENTS/PLAN).
|
||||
_API_ALLOWED_PROVIDER_MODULES = frozenset(
|
||||
{
|
||||
"providers",
|
||||
"providers.base",
|
||||
"providers.exceptions",
|
||||
"providers.registry",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_api_and_messaging_do_not_import_provider_common() -> None:
|
||||
repo_root = Path(__file__).resolve().parents[2]
|
||||
|
|
@ -25,6 +37,66 @@ def test_provider_adapters_do_not_import_runtime_layers() -> None:
|
|||
assert offenders == []
|
||||
|
||||
|
||||
def test_core_does_not_import_product_packages() -> None:
|
||||
"""Neutral ``core`` must stay independent of API, workers, and providers."""
|
||||
repo_root = Path(__file__).resolve().parents[2]
|
||||
offenders = _imports_matching(
|
||||
[repo_root / "core"],
|
||||
forbidden_prefixes=(
|
||||
"api.",
|
||||
"messaging.",
|
||||
"cli.",
|
||||
"smoke.",
|
||||
"providers.",
|
||||
"config.",
|
||||
),
|
||||
)
|
||||
assert offenders == []
|
||||
|
||||
|
||||
def test_config_does_not_import_non_config_packages() -> None:
|
||||
"""Settings and env handling must not depend on transport or protocol layers."""
|
||||
repo_root = Path(__file__).resolve().parents[2]
|
||||
offenders = _imports_matching(
|
||||
[repo_root / "config"],
|
||||
forbidden_prefixes=(
|
||||
"api.",
|
||||
"messaging.",
|
||||
"cli.",
|
||||
"smoke.",
|
||||
"providers.",
|
||||
"core.",
|
||||
),
|
||||
)
|
||||
assert offenders == []
|
||||
|
||||
|
||||
def test_messaging_does_not_import_api_or_cli_or_providers() -> None:
|
||||
"""Messaging is wired by ``api.runtime``; must not import server or provider adapters."""
|
||||
repo_root = Path(__file__).resolve().parents[2]
|
||||
offenders = _imports_matching(
|
||||
[repo_root / "messaging"],
|
||||
forbidden_prefixes=("api.", "cli.", "providers.", "smoke."),
|
||||
)
|
||||
assert offenders == []
|
||||
|
||||
|
||||
def test_api_may_only_import_narrow_provider_facade() -> None:
|
||||
"""HTTP layer must not depend on per-adapter provider subpackages."""
|
||||
repo_root = Path(__file__).resolve().parents[2]
|
||||
offenders: list[str] = []
|
||||
for path in (repo_root / "api").rglob("*.py"):
|
||||
for imported in _imports_from(path, repo_root):
|
||||
if imported is None or not imported.startswith("providers"):
|
||||
continue
|
||||
if imported in _API_ALLOWED_PROVIDER_MODULES:
|
||||
continue
|
||||
if imported.startswith("providers."):
|
||||
rel = path.relative_to(repo_root)
|
||||
offenders.append(f"{rel}: {imported}")
|
||||
assert sorted(offenders) == []
|
||||
|
||||
|
||||
def test_removed_openrouter_rollback_transport_stays_removed() -> None:
|
||||
repo_root = Path(__file__).resolve().parents[2]
|
||||
|
||||
|
|
@ -35,6 +107,11 @@ def test_removed_openrouter_rollback_transport_stays_removed() -> None:
|
|||
|
||||
def test_architecture_doc_names_enforced_boundaries() -> None:
|
||||
repo_root = Path(__file__).resolve().parents[2]
|
||||
contract_test = repo_root / "tests" / "contracts" / "test_import_boundaries.py"
|
||||
assert contract_test.is_file()
|
||||
stream_contracts = repo_root / "core" / "anthropic" / "stream_contracts.py"
|
||||
assert stream_contracts.is_file()
|
||||
|
||||
text = (repo_root / "PLAN.md").read_text(encoding="utf-8")
|
||||
|
||||
assert "core/anthropic/" in text
|
||||
|
|
@ -46,26 +123,89 @@ def _imports_matching(
|
|||
roots: list[Path], *, forbidden_prefixes: tuple[str, ...]
|
||||
) -> list[str]:
|
||||
offenders: list[str] = []
|
||||
repo_root = roots[0].parent
|
||||
for root in roots:
|
||||
for path in root.rglob("*.py"):
|
||||
rel = path.relative_to(root.parent)
|
||||
offenders.extend(
|
||||
f"{rel}: {imported}"
|
||||
for imported in _imports_from(path)
|
||||
if imported in forbidden_prefixes
|
||||
or imported.startswith(forbidden_prefixes)
|
||||
for imported in _imports_from(path, repo_root)
|
||||
if imported is not None and _is_forbidden(imported, forbidden_prefixes)
|
||||
)
|
||||
return sorted(offenders)
|
||||
|
||||
|
||||
def _imports_from(path: Path) -> list[str]:
|
||||
def _is_forbidden(name: str, forbidden: tuple[str, ...]) -> bool:
|
||||
"""Match root modules (``import api``) and submodules (``import api.x``)."""
|
||||
for token in forbidden:
|
||||
if not token:
|
||||
continue
|
||||
root = token.rstrip(".")
|
||||
if name == root or name.startswith(f"{root}."):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _module_fqn_from_path(repo_root: Path, path: Path) -> str:
|
||||
rel = path.relative_to(repo_root)
|
||||
if rel.name == "__init__.py":
|
||||
return ".".join(rel.parent.parts) if rel.parent != Path() else rel.parent.name
|
||||
return ".".join(rel.with_suffix("").parts)
|
||||
|
||||
|
||||
def _importing_package_parts(repo_root: Path, path: Path) -> list[str]:
|
||||
"""Package in which this file's module lives (for relative imports)."""
|
||||
rel = path.relative_to(repo_root)
|
||||
if rel.name == "__init__.py":
|
||||
return list(rel.parent.parts)
|
||||
fqn = _module_fqn_from_path(repo_root, path)
|
||||
parts = fqn.split(".")
|
||||
if len(parts) <= 1:
|
||||
return []
|
||||
return parts[:-1]
|
||||
|
||||
|
||||
def _resolve_relative_import(
|
||||
repo_root: Path, path: Path, node: ast.ImportFrom
|
||||
) -> str | None:
|
||||
"""Best-effort absolute name for ``from .x`` / ``from ..y`` (level >= 1)."""
|
||||
if node.level == 0 and node.module:
|
||||
return node.module
|
||||
base = _importing_package_parts(repo_root, path)
|
||||
for _ in range(node.level - 1):
|
||||
if not base:
|
||||
return None
|
||||
base.pop()
|
||||
if not node.module:
|
||||
return ".".join(base) if base else None
|
||||
return ".".join(base + node.module.split("."))
|
||||
|
||||
|
||||
def _imports_from(path: Path, repo_root: Path) -> list[str]:
|
||||
tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
|
||||
imports: list[str] = []
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Import):
|
||||
imports.extend(alias.name for alias in node.names)
|
||||
elif isinstance(node, ast.ImportFrom) and node.module:
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
if node.level == 0:
|
||||
if node.module:
|
||||
imports.append(node.module)
|
||||
continue
|
||||
if node.module is not None:
|
||||
resolved = _resolve_relative_import(repo_root, path, node)
|
||||
if resolved:
|
||||
imports.append(resolved)
|
||||
else:
|
||||
base = _importing_package_parts(repo_root, path).copy()
|
||||
for _ in range(node.level - 1):
|
||||
if base:
|
||||
base.pop()
|
||||
for alias in node.names:
|
||||
if base:
|
||||
imports.append(".".join([*base, alias.name]))
|
||||
else:
|
||||
imports.append(alias.name)
|
||||
return imports
|
||||
|
||||
|
||||
|
|
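The resolution rule is easiest to see on a concrete case. A small, self-contained illustration of what relative-import resolution should produce for a file at messaging/platforms/telegram.py, assuming the semantics of the helpers above (values chosen for the example):

    # For a module messaging/platforms/telegram.py:
    #   from .base import MessagingPlatform   -> "messaging.platforms.base"
    #   from ..models import IncomingMessage  -> "messaging.models"
    #   import providers.registry             -> "providers.registry" (unchanged)
    package_parts = ["messaging", "platforms"]  # what _importing_package_parts returns

    def resolve(level: int, module: str | None) -> str | None:
        base = package_parts.copy()
        for _ in range(level - 1):
            if not base:
                return None
            base.pop()
        if module is None:
            return ".".join(base) or None
        return ".".join([*base, *module.split(".")])

    assert resolve(1, "base") == "messaging.platforms.base"
    assert resolve(2, "models") == "messaging.models"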
|
|||
11  tests/contracts/test_smoke_sse_reexport.py  (new file)
|
|
@@ -0,0 +1,11 @@
+"""Ensure smoke re-exports stay aligned with :mod:`core.anthropic.stream_contracts`."""
+
+from __future__ import annotations
+
+import core.anthropic.stream_contracts as core_sc
+import smoke.lib.sse as smoke_sse
+
+
+def test_smoke_lib_sse_reexports_core_stream_contracts() -> None:
+    for name in smoke_sse.__all__:
+        assert getattr(smoke_sse, name) is getattr(core_sc, name)
|
||||
|
|
@@ -1,11 +1,14 @@
"""Stream/SSE contract tests. Strict transcript *ordering* is covered here for
``SSEBuilder`` output; for transport-integrated ordering, add messaging or API
integration tests.
"""

from __future__ import annotations

from collections.abc import Iterable

from core.anthropic import ContentType, HeuristicToolParser, SSEBuilder, ThinkTagParser
-from messaging.event_parser import parse_cli_event
-from messaging.transcript import RenderCtx, TranscriptBuffer
-from smoke.lib.sse import (
+from core.anthropic.stream_contracts import (
    assert_anthropic_stream_contract,
    event_names,
    has_tool_use,
@@ -13,6 +16,8 @@ from smoke.lib.sse import (
    text_content,
    thinking_content,
)
+from messaging.event_parser import parse_cli_event
+from messaging.transcript import RenderCtx, TranscriptBuffer


def test_interleaved_thinking_text_blocks_are_valid() -> None:

@@ -2,7 +2,10 @@

from unittest.mock import MagicMock, patch

-from messaging.platforms.factory import create_messaging_platform
+from messaging.platforms.factory import (
+    MessagingPlatformOptions,
+    create_messaging_platform,
+)


class TestCreateMessagingPlatform:
@@ -16,15 +19,29 @@ class TestCreateMessagingPlatform:
            patch(
                "messaging.platforms.telegram.TelegramPlatform",
                return_value=mock_platform,
-            ),
+            ) as platform_cls,
        ):
            result = create_messaging_platform(
                "telegram",
-                bot_token="test_token",
-                allowed_user_id="12345",
+                MessagingPlatformOptions(
+                    telegram_bot_token="test_token",
+                    allowed_telegram_user_id="12345",
+                    voice_note_enabled=False,
+                    whisper_model="large-v3",
+                    whisper_device="cuda",
+                ),
            )

        assert result is mock_platform
+        platform_cls.assert_called_once_with(
+            bot_token="test_token",
+            allowed_user_id="12345",
+            voice_note_enabled=False,
+            whisper_model="large-v3",
+            whisper_device="cuda",
+            hf_token="",
+            nvidia_nim_api_key="",
+        )

    def test_telegram_without_token(self):
        """Return None when no bot_token for Telegram."""
@@ -33,7 +50,9 @@ class TestCreateMessagingPlatform:

    def test_telegram_empty_token(self):
        """Return None when bot_token is empty string."""
-        result = create_messaging_platform("telegram", bot_token="")
+        result = create_messaging_platform(
+            "telegram", MessagingPlatformOptions(telegram_bot_token="")
+        )
        assert result is None

    def test_discord_with_token(self):
@@ -44,15 +63,29 @@ class TestCreateMessagingPlatform:
            patch(
                "messaging.platforms.discord.DiscordPlatform",
                return_value=mock_platform,
-            ),
+            ) as platform_cls,
        ):
            result = create_messaging_platform(
                "discord",
+                MessagingPlatformOptions(
+                    discord_bot_token="test_token",
+                    allowed_discord_channels="123,456",
+                    voice_note_enabled=False,
+                    whisper_model="small",
+                    whisper_device="nvidia_nim",
+                ),
            )

        assert result is mock_platform
+        platform_cls.assert_called_once_with(
+            bot_token="test_token",
+            allowed_channel_ids="123,456",
+            voice_note_enabled=False,
+            whisper_model="small",
+            whisper_device="nvidia_nim",
+            hf_token="",
+            nvidia_nim_api_key="",
+        )

    def test_discord_without_token(self):
        """Return None when no discord_bot_token for Discord."""
@@ -62,7 +95,11 @@ class TestCreateMessagingPlatform:
    def test_discord_empty_token(self):
        """Return None when discord_bot_token is empty string."""
        result = create_messaging_platform(
-            "discord", discord_bot_token="", allowed_discord_channels="123"
+            "discord",
+            MessagingPlatformOptions(
+                discord_bot_token="",
+                allowed_discord_channels="123",
+            ),
        )
        assert result is None

@@ -73,5 +110,7 @@ class TestCreateMessagingPlatform:

    def test_unknown_platform_with_kwargs(self):
        """Return None for unknown platform even with kwargs."""
-        result = create_messaging_platform("slack", bot_token="token")
+        result = create_messaging_platform(
+            "slack", MessagingPlatformOptions(telegram_bot_token="token")
+        )
        assert result is None

@@ -17,18 +17,20 @@ def telegram_platform():


@pytest.mark.asyncio
-async def test_telegram_voice_disabled_sends_reply(telegram_platform):
+async def test_telegram_voice_disabled_sends_reply():
    """When voice_note_enabled is False, reply with disabled message."""
+    with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
+        telegram_platform = TelegramPlatform(
+            bot_token="test_token",
+            allowed_user_id="12345",
+            voice_note_enabled=False,
+        )
    mock_update = MagicMock()
    mock_update.message.voice = MagicMock(file_id="f1", mime_type="audio/ogg")
    mock_update.effective_user.id = 12345
    mock_update.effective_chat.id = 6789
    mock_update.message.reply_text = AsyncMock()

-    with patch(
-        "config.settings.get_settings",
-        return_value=MagicMock(voice_note_enabled=False),
-    ):
    await telegram_platform._on_telegram_voice(mock_update, MagicMock())

    mock_update.message.reply_text.assert_called_once_with("Voice notes are disabled.")
@@ -42,10 +44,6 @@ async def test_telegram_voice_unauthorized_ignored(telegram_platform):
    mock_update.effective_user.id = 99999 # Not 12345
    mock_update.message.reply_text = AsyncMock()

-    with patch(
-        "config.settings.get_settings",
-        return_value=MagicMock(voice_note_enabled=True),
-    ):
    await telegram_platform._on_telegram_voice(mock_update, MagicMock())

    mock_update.message.reply_text.assert_not_called()
@@ -82,17 +80,8 @@ async def test_telegram_voice_success_invokes_handler(telegram_platform):
    mock_file.download_to_drive = fake_download

-    mock_settings = MagicMock(
-        voice_note_enabled=True,
-        whisper_model="base",
-    )
-
    mock_queue_send = AsyncMock(return_value="999")
    with (
-        patch(
-            "config.settings.get_settings",
-            return_value=mock_settings,
-        ),
        patch(
            "messaging.transcription.transcribe_audio",
            return_value="Hello from voice",
@@ -164,7 +153,11 @@ class TestDiscordGetAudioAttachment:
@pytest.mark.asyncio
async def test_discord_voice_disabled_sends_reply():
    """When voice_note_enabled is False, reply with disabled message."""
-    platform = DiscordPlatform(bot_token="token", allowed_channel_ids="123")
+    platform = DiscordPlatform(
+        bot_token="token",
+        allowed_channel_ids="123",
+        voice_note_enabled=False,
+    )
    platform._message_handler = None

    mock_message = MagicMock()
@@ -178,10 +171,6 @@ async def test_discord_voice_disabled_sends_reply():
    mock_att.filename = "voice.ogg"
    mock_message.attachments = [mock_att]

-    with patch(
-        "config.settings.get_settings",
-        return_value=MagicMock(voice_note_enabled=False),
-    ):
    await platform._on_discord_message(mock_message)

    mock_message.reply.assert_called_once_with("Voice notes are disabled.")

@@ -24,7 +24,7 @@ class MockBlock:


class MockTool:
-    def __init__(self, name, description, input_schema):
+    def __init__(self, name, description, input_schema=None):
        self.name = name
        self.description = description
        self.input_schema = input_schema
@@ -79,6 +79,23 @@ def test_convert_tools():
    assert result[1]["function"]["description"] == "" # Check default empty string


+def test_convert_tool_without_input_schema_uses_empty_object_schema():
+    tools = [MockTool("web_search", None)]
+
+    result = AnthropicToOpenAIConverter.convert_tools(tools)
+
+    assert result == [
+        {
+            "type": "function",
+            "function": {
+                "name": "web_search",
+                "description": "",
+                "parameters": {"type": "object", "properties": {}},
+            },
+        }
+    ]
+
+
@pytest.mark.parametrize(
    "tool_choice,expected",
    [

@@ -449,6 +449,40 @@ def test_heuristic_tool_parser_flush_no_tool():
    assert tools == []


+def test_heuristic_tool_parser_json_style_web_fetch_tool_call():
+    parser = HeuristicToolParser()
+    text = (
+        "Use WebFetch on the article.\n\n"
+        "{\n"
+        ' "url": "https://example.com/article",\n'
+        ' "prompt": "Summarize it."\n'
+        "}\n"
+    )
+
+    filtered, tools = parser.feed(text)
+    tools.extend(parser.flush())
+
+    assert filtered == ""
+    assert len(tools) == 1
+    assert tools[0]["name"] == "WebFetch"
+    assert tools[0]["input"] == {
+        "url": "https://example.com/article",
+        "prompt": "Summarize it.",
+    }
+
+
+def test_heuristic_tool_parser_json_style_web_search_tool_call():
+    parser = HeuristicToolParser()
+
+    filtered, tools = parser.feed('Use WebSearch {"query": "DeepSeek V4"}')
+    tools.extend(parser.flush())
+
+    assert filtered == ""
+    assert len(tools) == 1
+    assert tools[0]["name"] == "WebSearch"
+    assert tools[0]["input"] == {"query": "DeepSeek V4"}
+
+
def test_heuristic_tool_parser_unicode_function_name():
    """Unicode characters in function parameters."""
    parser = HeuristicToolParser()

@@ -1,9 +1,13 @@
-from unittest.mock import MagicMock, patch
+import subprocess
+import sys
+from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from config.nim import NimSettings
+from config.provider_ids import SUPPORTED_PROVIDER_IDS
from providers.deepseek import DeepSeekProvider
+from providers.exceptions import UnknownProviderTypeError
from providers.llamacpp import LlamaCppProvider
from providers.lmstudio import LMStudioProvider
from providers.nvidia_nim import NvidiaNimProvider
@@ -41,14 +45,24 @@ def _make_settings(**overrides):
    return mock


+def test_importing_registry_does_not_eager_load_other_adapters() -> None:
+    """Registry metadata must not import every provider adapter up front."""
+    code = (
+        "import sys\n"
+        "import providers.registry\n"
+        "assert 'providers.open_router' not in sys.modules\n"
+    )
+    proc = subprocess.run(
+        [sys.executable, "-c", code],
+        check=False,
+        capture_output=True,
+        text=True,
+    )
+    assert proc.returncode == 0, proc.stderr or proc.stdout
+
+
def test_descriptors_cover_advertised_provider_ids():
-    assert set(PROVIDER_DESCRIPTORS) == {
-        "nvidia_nim",
-        "open_router",
-        "deepseek",
-        "lmstudio",
-        "llamacpp",
-    }
+    assert set(PROVIDER_DESCRIPTORS) == set(SUPPORTED_PROVIDER_IDS)
    for descriptor in PROVIDER_DESCRIPTORS.values():
        assert descriptor.provider_id
        assert descriptor.transport_type in {"openai_chat", "anthropic_messages"}
@@ -90,6 +104,38 @@ def test_provider_registry_caches_by_provider_id():
    assert first is second


-def test_unknown_provider_raises_value_error():
-    with pytest.raises(ValueError, match="Unknown provider_type"):
+def test_unknown_provider_raises_unknown_provider_type_error():
+    with pytest.raises(UnknownProviderTypeError, match="Unknown provider_type"):
        create_provider("unknown", _make_settings())
+
+
+@pytest.mark.asyncio
+async def test_provider_registry_cleanup_runs_all_even_if_one_fails() -> None:
+    """Every provider gets cleanup; cache is cleared even when one raises."""
+    reg = ProviderRegistry()
+    p1 = MagicMock()
+    p1.cleanup = AsyncMock(side_effect=RuntimeError("first"))
+    p2 = MagicMock()
+    p2.cleanup = AsyncMock()
+    reg._providers["a"] = p1
+    reg._providers["b"] = p2
+    with pytest.raises(RuntimeError, match="first"):
+        await reg.cleanup()
+    p1.cleanup.assert_awaited_once()
+    p2.cleanup.assert_awaited_once()
+    assert reg._providers == {}
+
+
+@pytest.mark.asyncio
+async def test_provider_registry_cleanup_exceptiongroup_on_multiple_failures() -> None:
+    reg = ProviderRegistry()
+    p1 = MagicMock()
+    p1.cleanup = AsyncMock(side_effect=RuntimeError("a"))
+    p2 = MagicMock()
+    p2.cleanup = AsyncMock(side_effect=RuntimeError("b"))
+    reg._providers["x"] = p1
+    reg._providers["y"] = p2
+    with pytest.raises(ExceptionGroup) as exc_info:
+        await reg.cleanup()
+    assert len(exc_info.value.exceptions) == 2
+    assert reg._providers == {}