feat: Anthropic web server tools, provider metadata, messaging hardening

- Add local web_search/web_fetch SSE handling and optional tool schemas
- Extend HeuristicToolParser for JSON-style WebFetch/WebSearch text
- Consolidate provider defaults, ids, and exception typing; stream contracts
- Messaging: typed options, voice config injection, platform contract cleanup
- Tests for web server tools, converters, parsers, contracts; ignore debug-*.log
This commit is contained in:
Alishahryar1 2026-04-24 23:01:14 -07:00
parent 4b89183ba0
commit b926f60f64
50 changed files with 1658 additions and 439 deletions

View file

@ -66,6 +66,7 @@ VOICE_NOTE_ENABLED=false
# WHISPER_DEVICE: "cpu" | "cuda" | "nvidia_nim"
# - "cpu"/"cuda": Hugging Face transformers Whisper (offline, free; install with: uv sync --extra voice_local)
# - "nvidia_nim": NVIDIA NIM Whisper via Riva gRPC (requires NVIDIA_NIM_API_KEY; install with: uv sync --extra voice)
# (Independent of MODEL=nvidia_nim/...: that selects the *chat* provider; this selects voice STT only.)
WHISPER_DEVICE="nvidia_nim"
# WHISPER_MODEL:
# - For cpu/cuda: Hugging Face ID or short name (tiny, base, small, medium, large-v2, large-v3, large-v3-turbo)

1
.gitignore vendored
View file

@ -8,6 +8,7 @@ __pycache__
agent_workspace
.env
server.log
debug-*.log
.coverage
llama_cache
.smoke-results

27
PLAN.md
View file

@ -33,20 +33,43 @@ flowchart TD
core --> providers
core --> messaging
providers --> api
api --> cli[cli]
api --> messaging
cli --> messaging
messaging --> api
```
Runtime note: `api.runtime` imports `cli` and `messaging` to wire the optional
messaging stack; `messaging` does not import `cli` (session/CLI access is passed
in from `api.runtime`).
The practical rule is simpler than the graph: shared protocol helpers belong in
neutral core modules, not under a provider package. Provider adapters may depend
on the neutral protocol layer, but API and messaging code should not import
provider internals.
The diagram above mixes **Python import direction** (e.g. `config``providers`)
with **runtime composition** (e.g. `api.runtime` constructs `cli` and `messaging`).
`PLAN.md` remains the product map; **encoded** rules (including root imports like
`import api`, relative imports, and `api``providers` facade allowlists) live in
`tests/contracts/test_import_boundaries.py`.
**Contract highlights:** `api/` may import only `providers.base`, `providers.exceptions`,
and `providers.registry` from the providers package (not per-adapter modules).
`core/` stays free of `api`, `messaging`, `cli`, `providers`, `config`, and `smoke`.
`messaging/` does not import `api`, `cli`, or `providers`. Neutral stream contract
assertions for default CI live under `core/anthropic/stream_contracts.py`;
`smoke.lib.sse` re-exports them for live smoke. Process-cached provider helpers
(`api.dependencies.get_provider` / `get_provider_for_type`) exist for scripts and
unit tests; production HTTP handlers must use `resolve_provider` with
`request.app` so the app-scoped `ProviderRegistry` is used. The `api` package
`__all__` exposes HTTP models and `create_app` only (not those helpers).
## Target Boundaries
- `core/anthropic/`: Anthropic protocol helpers, stream primitives, content
extraction, token estimation, user-facing error strings, request conversion,
thinking, and tool helpers shared across API, providers, messaging, and tests.
thinking, tool helpers, and stream contract assertions
(`stream_contracts.py`) shared across API, providers, messaging, and tests.
- `api/runtime.py`: application composition, optional messaging startup,
session store restoration, and cleanup ownership.
- `providers/`: provider descriptors, credential resolution, transport

View file

@ -1,7 +1,6 @@
"""API layer for Claude Code Proxy."""
from .app import app, create_app
from .dependencies import get_provider, get_provider_for_type
from .models import (
MessagesRequest,
MessagesResponse,
@ -16,6 +15,4 @@ __all__ = [
"TokenCountResponse",
"app",
"create_app",
"get_provider",
"get_provider_for_type",
]

View file

@ -2,8 +2,11 @@
import os
from contextlib import asynccontextmanager
from typing import Any
from fastapi import FastAPI, Request
from fastapi.exception_handlers import request_validation_exception_handler
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from loguru import logger
@ -11,7 +14,6 @@ from config.logging_config import configure_logging
from config.settings import get_settings
from providers.exceptions import ProviderError
from .dependencies import cleanup_provider
from .routes import router
from .runtime import AppRuntime
@ -26,9 +28,7 @@ configure_logging(_settings.log_file)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan manager."""
runtime = AppRuntime.for_app(
app, settings=get_settings(), provider_cleanup=cleanup_provider
)
runtime = AppRuntime.for_app(app, settings=get_settings())
await runtime.startup()
yield
@ -48,6 +48,60 @@ def create_app() -> FastAPI:
app.include_router(router)
# Exception handlers
@app.exception_handler(RequestValidationError)
async def validation_error_handler(request: Request, exc: RequestValidationError):
    """Log request shape for 422 debugging without content values.

    Summarizes message roles, content kinds, block types/keys, and declared
    tool names — never the content itself — then delegates to FastAPI's
    default 422 handler.
    """
    body: Any
    try:
        body = await request.json()
    except Exception as e:
        # Body may be absent or non-JSON; record why instead of failing.
        body = {"_json_error": type(e).__name__}
    messages = body.get("messages") if isinstance(body, dict) else None
    message_summary: list[dict[str, Any]] = []
    if isinstance(messages, list):
        for msg in messages:
            if not isinstance(msg, dict):
                message_summary.append({"message_kind": type(msg).__name__})
                continue
            content = msg.get("content")
            item: dict[str, Any] = {
                "role": msg.get("role"),
                "content_kind": type(content).__name__,
            }
            if isinstance(content, list):
                # Cap the number of blocks/keys logged to keep entries small.
                item["block_types"] = [
                    block.get("type", "dict")
                    if isinstance(block, dict)
                    else type(block).__name__
                    for block in content[:12]
                ]
                item["block_keys"] = [
                    sorted(str(key) for key in block)[:12]
                    for block in content[:5]
                    if isinstance(block, dict)
                ]
            elif isinstance(content, str):
                item["content_length"] = len(content)
            message_summary.append(item)
    # Guard BEFORE iterating: ``body`` may be any JSON value (list/str/number),
    # and calling ``body.get`` on a non-dict would raise AttributeError inside
    # this error handler, masking the original validation failure.
    tools = body.get("tools") if isinstance(body, dict) else None
    tool_names = (
        [str(tool.get("name", "")) for tool in tools if isinstance(tool, dict)]
        if isinstance(tools, list)
        else []
    )
    logger.debug(
        "Request validation failed: path={} query={} error_locs={} error_types={} message_summary={} tool_names={}",
        request.url.path,
        str(request.url.query),
        [list(error.get("loc", ())) for error in exc.errors()],
        [str(error.get("type", "")) for error in exc.errors()],
        message_summary,
        tool_names,
    )
    return await request_validation_exception_handler(request, exc)
@app.exception_handler(ProviderError)
async def provider_error_handler(request: Request, exc: ProviderError):
"""Handle provider-specific errors and return Anthropic format."""

View file

@ -2,15 +2,18 @@
from fastapi import Depends, HTTPException, Request
from loguru import logger
from starlette.applications import Starlette
from config.settings import Settings
from config.settings import get_settings as _get_settings
from core.anthropic import get_user_facing_error_message
from providers.base import BaseProvider
from providers.exceptions import AuthenticationError
from providers.exceptions import AuthenticationError, UnknownProviderTypeError
from providers.registry import PROVIDER_DESCRIPTORS, ProviderRegistry
# Provider registry: keyed by provider type string, lazily populated
# Process-level cache: only for :func:`get_provider_for_type` / :func:`get_provider`
# when there is no ``Request``/``app`` (unit tests, scripts). HTTP handlers must pass
# ``app`` to :func:`resolve_provider` so the app-scoped registry is used.
_providers: dict[str, BaseProvider] = {}
@ -19,19 +22,43 @@ def get_settings() -> Settings:
return _get_settings()
def get_provider_for_type(provider_type: str) -> BaseProvider:
"""Get or create a provider for the given provider type.
def resolve_provider(
provider_type: str,
*,
app: Starlette | None,
settings: Settings,
) -> BaseProvider:
"""Resolve a provider using the app-scoped registry when ``app`` is set.
Providers are cached in the registry and reused across requests.
When ``app`` is not ``None``, the app-owned :attr:`app.state.provider_registry`
is always used. If the registry is missing (e.g. a test app without
:class:`~api.runtime.AppRuntime` startup), a new :class:`ProviderRegistry`
is installed on ``app.state`` so the process cache is never mixed with
per-request app identity.
When ``app`` is ``None`` (no HTTP context), uses the process-level
:data:`_providers` cache only.
"""
should_log_init = provider_type not in _providers
if app is not None:
reg = getattr(app.state, "provider_registry", None)
if reg is None:
reg = ProviderRegistry()
app.state.provider_registry = reg
return _resolve_with_registry(reg, provider_type, settings)
return _resolve_with_registry(ProviderRegistry(_providers), provider_type, settings)
def _resolve_with_registry(
registry: ProviderRegistry, provider_type: str, settings: Settings
) -> BaseProvider:
should_log_init = not registry.is_cached(provider_type)
try:
provider = ProviderRegistry(_providers).get(provider_type, get_settings())
provider = registry.get(provider_type, settings)
except AuthenticationError as e:
raise HTTPException(
status_code=503, detail=get_user_facing_error_message(e)
) from e
except ValueError:
except UnknownProviderTypeError:
logger.error(
"Unknown provider_type: '{}'. Supported: {}",
provider_type,
@ -43,6 +70,15 @@ def get_provider_for_type(provider_type: str) -> BaseProvider:
return provider
def get_provider_for_type(provider_type: str) -> BaseProvider:
    """Get or create a provider in the process-level cache (no ``app``/Request).

    For server requests, use :func:`resolve_provider` with the active
    :attr:`request.app` so the app-scoped provider registry is used.

    Raises:
        HTTPException: ``503`` when provider authentication fails
            (propagated from :func:`resolve_provider`).
    """
    return resolve_provider(provider_type, app=None, settings=get_settings())
def require_api_key(
request: Request, settings: Settings = Depends(get_settings)
) -> None:
@ -78,9 +114,11 @@ def require_api_key(
def get_provider() -> BaseProvider:
"""Get or create the default provider (based on MODEL env var).
"""Get or create the default provider (``MODEL`` / ``provider_type``).
Backward-compatible convenience for health/root endpoints and tests.
Process-cache helper for scripts, unit tests, and non-FastAPI callers. HTTP
handlers must use :func:`resolve_provider` with :attr:`request.app` so the
app-scoped :class:`~providers.registry.ProviderRegistry` is used.
"""
return get_provider_for_type(get_settings().provider_type)

View file

@ -75,8 +75,11 @@ class Message(BaseModel):
class Tool(BaseModel):
name: str
# Anthropic server tools (e.g. web_search beta tools) include a ``type`` and
# may omit ``input_schema`` because the provider owns the schema.
type: str | None = None
description: str | None = None
input_schema: dict[str, Any]
input_schema: dict[str, Any] | None = None
class ThinkingConfig(BaseModel):

View file

@ -6,7 +6,8 @@ from loguru import logger
from config.settings import Settings
from core.anthropic import get_token_count
from .dependencies import get_provider_for_type, get_settings, require_api_key
from . import dependencies
from .dependencies import get_settings, require_api_key
from .models.anthropic import MessagesRequest, TokenCountRequest
from .models.responses import ModelResponse, ModelsListResponse
from .services import ClaudeProxyService
@ -54,12 +55,15 @@ SUPPORTED_CLAUDE_MODELS = [
def get_proxy_service(
request: Request,
settings: Settings = Depends(get_settings),
) -> ClaudeProxyService:
"""Build the request service for route handlers."""
return ClaudeProxyService(
settings,
provider_getter=get_provider_for_type,
provider_getter=lambda provider_type: dependencies.resolve_provider(
provider_type, app=request.app, settings=settings
),
token_counter=get_token_count,
)

View file

@ -4,16 +4,20 @@ from __future__ import annotations
import asyncio
import os
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from fastapi import FastAPI
from loguru import logger
from config.settings import Settings, get_settings
from providers.registry import ProviderRegistry
from .dependencies import cleanup_provider
if TYPE_CHECKING:
from cli.manager import CLISessionManager
from messaging.handler import ClaudeMessageHandler
from messaging.platforms.base import MessagingPlatform
from messaging.session import SessionStore
_SHUTDOWN_TIMEOUT_S = 5.0
@ -32,8 +36,7 @@ async def best_effort(
def warn_if_process_auth_token(settings: Settings) -> None:
"""Warn when server auth was implicitly inherited from the shell."""
uses_process_token = getattr(settings, "uses_process_anthropic_auth_token", None)
if callable(uses_process_token) and uses_process_token():
if settings.uses_process_anthropic_auth_token():
logger.warning(
"ANTHROPIC_AUTH_TOKEN is set in the process environment but not in "
"a configured .env file. The proxy will require that token. Add "
@ -48,32 +51,29 @@ class AppRuntime:
app: FastAPI
settings: Settings
provider_cleanup: Callable[[], Awaitable[None]] = cleanup_provider
messaging_platform: Any = None
message_handler: Any = None
cli_manager: Any = None
_provider_registry: ProviderRegistry | None = field(default=None, init=False)
messaging_platform: MessagingPlatform | None = None
message_handler: ClaudeMessageHandler | None = None
cli_manager: CLISessionManager | None = None
@classmethod
def for_app(
cls,
app: FastAPI,
settings: Settings | None = None,
provider_cleanup: Callable[[], Awaitable[None]] = cleanup_provider,
) -> AppRuntime:
return cls(
app=app,
settings=settings or get_settings(),
provider_cleanup=provider_cleanup,
)
return cls(app=app, settings=settings or get_settings())
async def startup(self) -> None:
logger.info("Starting Claude Code Proxy...")
self._provider_registry = ProviderRegistry()
self.app.state.provider_registry = self._provider_registry
warn_if_process_auth_token(self.settings)
await self._start_messaging_if_configured()
self._publish_state()
async def shutdown(self) -> None:
if self.message_handler and hasattr(self.message_handler, "session_store"):
if self.message_handler is not None:
try:
self.message_handler.session_store.flush_pending_save()
except Exception as e:
@ -84,20 +84,33 @@ class AppRuntime:
await best_effort("messaging_platform.stop", self.messaging_platform.stop())
if self.cli_manager:
await best_effort("cli_manager.stop_all", self.cli_manager.stop_all())
await best_effort("cleanup_provider", self.provider_cleanup())
if self._provider_registry is not None:
await best_effort(
"provider_registry.cleanup", self._provider_registry.cleanup()
)
await self._shutdown_limiter()
logger.info("Server shut down cleanly")
async def _start_messaging_if_configured(self) -> None:
try:
from messaging.platforms.factory import create_messaging_platform
from messaging.platforms.factory import (
MessagingPlatformOptions,
create_messaging_platform,
)
self.messaging_platform = create_messaging_platform(
platform_type=self.settings.messaging_platform,
bot_token=self.settings.telegram_bot_token,
allowed_user_id=self.settings.allowed_telegram_user_id,
self.settings.messaging_platform,
MessagingPlatformOptions(
telegram_bot_token=self.settings.telegram_bot_token,
allowed_telegram_user_id=self.settings.allowed_telegram_user_id,
discord_bot_token=self.settings.discord_bot_token,
allowed_discord_channels=self.settings.allowed_discord_channels,
voice_note_enabled=self.settings.voice_note_enabled,
whisper_model=self.settings.whisper_model,
whisper_device=self.settings.whisper_device,
hf_token=self.settings.hf_token,
nvidia_nim_api_key=self.settings.nvidia_nim_api_key,
),
)
if self.messaging_platform:
@ -137,29 +150,31 @@ class AppRuntime:
api_url=api_url,
allowed_dirs=allowed_dirs,
plans_directory=plans_directory,
claude_bin=getattr(self.settings, "claude_cli_bin", "claude"),
claude_bin=self.settings.claude_cli_bin,
)
session_store = SessionStore(
storage_path=os.path.join(data_path, "sessions.json")
)
platform = self.messaging_platform
assert platform is not None
self.message_handler = ClaudeMessageHandler(
platform=self.messaging_platform,
platform=platform,
cli_manager=self.cli_manager,
session_store=session_store,
)
self._restore_tree_state(session_store)
self.messaging_platform.on_message(self.message_handler.handle_message)
await self.messaging_platform.start()
logger.info(
f"{self.messaging_platform.name} platform started with message handler"
)
platform.on_message(self.message_handler.handle_message)
await platform.start()
logger.info(f"{platform.name} platform started with message handler")
def _restore_tree_state(self, session_store: Any) -> None:
def _restore_tree_state(self, session_store: SessionStore) -> None:
saved_trees = session_store.get_all_trees()
if not saved_trees:
return
if self.message_handler is None:
return
logger.info(f"Restoring {len(saved_trees)} conversation trees...")
from messaging.trees.queue_manager import TreeQueueManager
@ -188,11 +203,16 @@ class AppRuntime:
async def _shutdown_limiter(self) -> None:
try:
from messaging.limiter import MessagingRateLimiter
except Exception as e:
logger.debug(
"Rate limiter shutdown skipped (import failed): {}: {}",
type(e).__name__,
e,
)
return
await best_effort(
"MessagingRateLimiter.shutdown_instance",
MessagingRateLimiter.shutdown_instance(),
timeout_s=2.0,
)
except Exception:
pass

View file

@ -20,6 +20,10 @@ from .model_router import ModelRouter
from .models.anthropic import MessagesRequest, TokenCountRequest
from .models.responses import TokenCountResponse
from .optimization_handlers import try_optimizations
from .web_server_tools import (
is_web_server_tool_request,
stream_web_server_tool_response,
)
TokenCounter = Callable[[list[Any], str | list[Any] | None, list[Any] | None], int]
@ -48,6 +52,22 @@ class ClaudeProxyService:
raise InvalidRequestError("messages cannot be empty")
routed = self._model_router.resolve_messages_request(request_data)
if is_web_server_tool_request(routed.request):
input_tokens = self._token_counter(
routed.request.messages, routed.request.system, routed.request.tools
)
logger.info("Optimization: Handling Anthropic web server tool")
return StreamingResponse(
stream_web_server_tool_response(
routed.request, input_tokens=input_tokens
),
media_type="text/event-stream",
headers={
"X-Accel-Buffering": "no",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
)
optimized = try_optimizations(routed.request, self._settings)
if optimized is not None:

331
api/web_server_tools.py Normal file
View file

@ -0,0 +1,331 @@
"""Local handlers for Anthropic web server tools.
OpenAI-compatible upstreams can emit regular function calls, but Anthropic's
web tools are server-side: the API response itself must include the tool result.
"""
from __future__ import annotations
import html
import json
import re
import uuid
from collections.abc import AsyncIterator
from datetime import UTC, datetime
from html.parser import HTMLParser
from typing import Any
from urllib.parse import parse_qs, unquote, urlparse
import httpx
from .models.anthropic import MessagesRequest
_REQUEST_TIMEOUT_S = 20.0
_MAX_SEARCH_RESULTS = 10
_MAX_FETCH_CHARS = 24_000
class _SearchResultParser(HTMLParser):
def __init__(self) -> None:
super().__init__()
self.results: list[dict[str, str]] = []
self._href: str | None = None
self._title_parts: list[str] = []
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
if tag != "a":
return
href = dict(attrs).get("href")
if not href or "uddg=" not in href:
return
parsed = urlparse(href)
query = parse_qs(parsed.query)
uddg = query.get("uddg", [""])[0]
if not uddg:
return
self._href = unquote(uddg)
self._title_parts = []
def handle_data(self, data: str) -> None:
if self._href is not None:
self._title_parts.append(data)
def handle_endtag(self, tag: str) -> None:
if tag != "a" or self._href is None:
return
title = " ".join("".join(self._title_parts).split())
if title and not any(result["url"] == self._href for result in self.results):
self.results.append({"title": html.unescape(title), "url": self._href})
self._href = None
self._title_parts = []
class _HTMLTextParser(HTMLParser):
def __init__(self) -> None:
super().__init__()
self.title = ""
self.text_parts: list[str] = []
self._in_title = False
self._skip_depth = 0
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
if tag in {"script", "style", "noscript"}:
self._skip_depth += 1
elif tag == "title":
self._in_title = True
def handle_endtag(self, tag: str) -> None:
if tag in {"script", "style", "noscript"} and self._skip_depth:
self._skip_depth -= 1
elif tag == "title":
self._in_title = False
def handle_data(self, data: str) -> None:
text = " ".join(data.split())
if not text:
return
if self._in_title:
self.title = f"{self.title} {text}".strip()
elif not self._skip_depth:
self.text_parts.append(text)
def _format_event(event_type: str, data: dict[str, Any]) -> str:
return f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
def _content_text(content: Any) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for item in content:
if isinstance(item, dict):
parts.append(str(item.get("text", "")))
else:
parts.append(str(getattr(item, "text", "")))
return "\n".join(part for part in parts if part)
return str(content)
def _request_text(request: MessagesRequest) -> str:
    """Concatenate the flattened text of every message in *request*."""
    parts = [_content_text(message.content) for message in request.messages]
    return "\n".join(parts)
def _web_tool_name(request: MessagesRequest) -> str | None:
for tool in request.tools or []:
name = tool.name
tool_type = tool.type or ""
if name in {"web_search", "web_fetch"} or tool_type.startswith("web_"):
return name
return None
def is_web_server_tool_request(request: MessagesRequest) -> bool:
    """True when the request declares a local web_search/web_fetch server tool."""
    name = _web_tool_name(request)
    return name == "web_search" or name == "web_fetch"
def _extract_query(text: str) -> str:
match = re.search(r"query:\s*(.+)", text, flags=re.IGNORECASE | re.DOTALL)
if match:
return match.group(1).strip().strip("\"'")
return text.strip()
def _extract_url(text: str) -> str:
match = re.search(r"https?://\S+", text)
return match.group(0).rstrip(").,]") if match else text.strip()
async def _run_web_search(query: str) -> list[dict[str, str]]:
    """Query DuckDuckGo Lite and parse up to ``_MAX_SEARCH_RESULTS`` links."""
    client_options = {
        "timeout": _REQUEST_TIMEOUT_S,
        "follow_redirects": True,
        "headers": {"User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0"},
    }
    async with httpx.AsyncClient(**client_options) as client:
        response = await client.get(
            "https://lite.duckduckgo.com/lite/",
            params={"q": query},
        )
        response.raise_for_status()
        extractor = _SearchResultParser()
        extractor.feed(response.text)
        return extractor.results[:_MAX_SEARCH_RESULTS]
async def _run_web_fetch(url: str) -> dict[str, str]:
    """Fetch *url*; return ``{url, title, media_type, data}`` as plain text."""
    client_options = {
        "timeout": _REQUEST_TIMEOUT_S,
        "follow_redirects": True,
        "headers": {"User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0"},
    }
    async with httpx.AsyncClient(**client_options) as client:
        response = await client.get(url)
        response.raise_for_status()
        content_type = response.headers.get("content-type", "text/plain")
        if "html" in content_type.lower():
            extractor = _HTMLTextParser()
            extractor.feed(response.text)
            title = extractor.title or url
            data = "\n".join(extractor.text_parts)
        else:
            title, data = url, response.text
        return {
            "url": str(response.url),
            "title": title,
            # Extracted content is always delivered as plain text, even for HTML.
            "media_type": "text/plain",
            "data": data[:_MAX_FETCH_CHARS],
        }
def _search_summary(query: str, results: list[dict[str, str]]) -> str:
if not results:
return f"No web search results found for: {query}"
lines = [f"Search results for: {query}"]
for index, result in enumerate(results, start=1):
lines.append(f"{index}. {result['title']}\n{result['url']}")
return "\n\n".join(lines)
async def stream_web_server_tool_response(
    request: MessagesRequest, input_tokens: int
) -> AsyncIterator[str]:
    """Stream a complete Anthropic-style SSE response for a local web tool run.

    Emits, in order: ``message_start``; block 0 = ``server_tool_use`` with the
    heuristically extracted input; block 1 = the tool result (or an in-band
    error block); block 2 = a plain-text summary; ``message_delta`` carrying
    usage (including a ``server_tool_use`` request counter); ``message_stop``.

    Args:
        request: Incoming Anthropic messages request declaring a web tool.
        input_tokens: Pre-computed input token count for the usage payload.

    Yields:
        SSE-framed event strings (``event:``/``data:`` lines, see
        :func:`_format_event`).
    """
    tool_name = _web_tool_name(request)
    if tool_name is None:
        # Defensive: callers are expected to gate on is_web_server_tool_request().
        return
    text = _request_text(request)
    message_id = f"msg_{uuid.uuid4()}"
    tool_id = f"srvtoolu_{uuid.uuid4().hex}"
    output_tokens = 1  # recomputed below once the summary text is known
    usage_key = (
        "web_search_requests" if tool_name == "web_search" else "web_fetch_requests"
    )
    # Recover the tool input from the raw conversation text heuristically.
    tool_input = (
        {"query": _extract_query(text)}
        if tool_name == "web_search"
        else {"url": _extract_url(text)}
    )
    yield _format_event(
        "message_start",
        {
            "type": "message_start",
            "message": {
                "id": message_id,
                "type": "message",
                "role": "assistant",
                "content": [],
                "model": request.model,
                "stop_reason": None,
                "stop_sequence": None,
                "usage": {"input_tokens": input_tokens, "output_tokens": 1},
            },
        },
    )
    # Block 0: the server_tool_use invocation itself.
    yield _format_event(
        "content_block_start",
        {
            "type": "content_block_start",
            "index": 0,
            "content_block": {
                "type": "server_tool_use",
                "id": tool_id,
                "name": tool_name,
                "input": tool_input,
            },
        },
    )
    yield _format_event(
        "content_block_stop", {"type": "content_block_stop", "index": 0}
    )
    try:
        if tool_name == "web_search":
            query = str(tool_input["query"])
            results = await _run_web_search(query)
            result_content: Any = [
                {
                    "type": "web_search_result",
                    "title": result["title"],
                    "url": result["url"],
                }
                for result in results
            ]
            summary = _search_summary(query, results)
            result_block_type = "web_search_tool_result"
        else:
            fetched = await _run_web_fetch(str(tool_input["url"]))
            result_content = {
                "type": "web_fetch_result",
                "url": fetched["url"],
                "content": {
                    "type": "document",
                    "source": {
                        "type": "text",
                        "media_type": fetched["media_type"],
                        "data": fetched["data"],
                    },
                    "title": fetched["title"],
                    "citations": {"enabled": True},
                },
                "retrieved_at": datetime.now(UTC).isoformat(),
            }
            summary = fetched["data"][:_MAX_FETCH_CHARS]
            result_block_type = "web_fetch_tool_result"
    except Exception as error:
        # Network/parse failures become an in-band tool error block so the
        # stream still completes with the full Anthropic event sequence.
        result_block_type = (
            "web_search_tool_result"
            if tool_name == "web_search"
            else "web_fetch_tool_result"
        )
        error_type = (
            "web_search_tool_result_error"
            if tool_name == "web_search"
            else "web_fetch_tool_error"
        )
        result_content = {"type": error_type, "error_code": "unavailable"}
        summary = f"{tool_name} failed: {type(error).__name__}"
    # Rough ~4 chars/token estimate of the generated summary text.
    output_tokens = max(1, len(summary) // 4)
    # Block 1: the tool result (success payload or error payload).
    yield _format_event(
        "content_block_start",
        {
            "type": "content_block_start",
            "index": 1,
            "content_block": {
                "type": result_block_type,
                "tool_use_id": tool_id,
                "content": result_content,
            },
        },
    )
    yield _format_event(
        "content_block_stop", {"type": "content_block_stop", "index": 1}
    )
    # Block 2: a human-readable text summary of the tool outcome.
    yield _format_event(
        "content_block_start",
        {
            "type": "content_block_start",
            "index": 2,
            "content_block": {"type": "text", "text": summary},
        },
    )
    yield _format_event(
        "content_block_stop", {"type": "content_block_stop", "index": 2}
    )
    yield _format_event(
        "message_delta",
        {
            "type": "message_delta",
            "delta": {"stop_reason": "end_turn", "stop_sequence": None},
            "usage": {
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "server_tool_use": {usage_key: 1},
            },
        },
    )
    yield _format_event("message_stop", {"type": "message_stop"})

17
config/provider_ids.py Normal file
View file

@ -0,0 +1,17 @@
"""Canonical list of model provider type prefixes (provider_id values).

`providers.registry.PROVIDER_DESCRIPTORS` is the full metadata source; this
module holds only the id set used for config validation, and the two must
stay in sync (registries assert in `providers.registry`).
"""

from __future__ import annotations

# Order matches docs / historical error text; must match PROVIDER_DESCRIPTORS keys.
# Consumed by config.settings to validate the MODEL "provider/model" prefix.
SUPPORTED_PROVIDER_IDS: tuple[str, ...] = (
    "nvidia_nim",
    "open_router",
    "deepseek",
    "lmstudio",
    "llamacpp",
)

View file

@ -11,6 +11,7 @@ from pydantic import Field, field_validator, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from .nim import NimSettings
from .provider_ids import SUPPORTED_PROVIDER_IDS
def _env_files() -> tuple[Path, ...]:
@ -252,25 +253,16 @@ class Settings(BaseSettings):
def validate_model_format(cls, v: str | None) -> str | None:
if v is None:
return None
valid_providers = (
"nvidia_nim",
"open_router",
"deepseek",
"lmstudio",
"llamacpp",
)
if "/" not in v:
raise ValueError(
f"Model must be prefixed with provider type. "
f"Valid providers: {', '.join(valid_providers)}. "
f"Valid providers: {', '.join(SUPPORTED_PROVIDER_IDS)}. "
f"Format: provider_type/model/name"
)
provider = v.split("/", 1)[0]
if provider not in valid_providers:
raise ValueError(
f"Invalid provider: '{provider}'. "
f"Supported: 'nvidia_nim', 'open_router', 'deepseek', 'lmstudio', 'llamacpp'"
)
if provider not in SUPPORTED_PROVIDER_IDS:
supported = ", ".join(f"'{p}'" for p in SUPPORTED_PROVIDER_IDS)
raise ValueError(f"Invalid provider: '{provider}'. Supported: {supported}")
return v
@model_validator(mode="after")

View file

@ -7,6 +7,17 @@ from .content import get_block_attr, get_block_type
from .utils import set_if_not_none
def _tool_name(tool: Any) -> str:
return str(getattr(tool, "name", "") or "")
def _tool_input_schema(tool: Any) -> dict[str, Any]:
schema = getattr(tool, "input_schema", None)
if isinstance(schema, dict):
return schema
return {"type": "object", "properties": {}}
class AnthropicToOpenAIConverter:
"""Convert Anthropic message format to OpenAI-compatible format."""
@ -140,7 +151,7 @@ class AnthropicToOpenAIConverter:
"function": {
"name": tool.name,
"description": tool.description or "",
"parameters": tool.input_schema,
"parameters": _tool_input_schema(tool),
},
}
for tool in tools

View file

@ -0,0 +1,158 @@
"""Neutral SSE parsing and Anthropic stream shape assertions.
Used by default CI contract tests and re-exported from ``smoke.lib.sse`` for
opt-in smoke scenarios.
"""
from __future__ import annotations
import json
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any
@dataclass(frozen=True, slots=True)
class SSEEvent:
    """One parsed server-sent event: name, decoded JSON payload, raw lines."""

    # Event name from the ``event:`` field ("" when the field was absent).
    event: str
    # Parsed ``data:`` payload; non-dict JSON is wrapped as {"value": ...} and
    # unparseable text as {"raw": ...} (see ``_append_event``).
    data: dict[str, Any]
    # Original field lines of the event joined with newlines, without the
    # blank-line separator.
    raw: str
def parse_sse_lines(lines: Iterable[str]) -> list[SSEEvent]:
    """Parse SSE field lines into events; a blank line terminates each event."""
    events: list[SSEEvent] = []
    pending_name = ""
    pending_data: list[str] = []
    pending_raw: list[str] = []
    for raw_line in lines:
        line = raw_line.rstrip("\r\n")
        if not line:
            # Blank separator: flush whatever event accumulated so far.
            _append_event(events, pending_name, pending_data, pending_raw)
            pending_name, pending_data, pending_raw = "", [], []
            continue
        pending_raw.append(line)
        if line.startswith("event:"):
            pending_name = line[len("event:"):].strip()
        elif line.startswith("data:"):
            pending_data.append(line[len("data:"):].strip())
    # Flush a trailing event that was not followed by a blank line.
    _append_event(events, pending_name, pending_data, pending_raw)
    return events
def parse_sse_text(text: str) -> list[SSEEvent]:
    """Parse a whole SSE payload string (see :func:`parse_sse_lines`)."""
    return parse_sse_lines(iter(text.splitlines()))
def _append_event(
    events: list[SSEEvent],
    current_event: str,
    data_parts: list[str],
    raw_parts: list[str],
) -> None:
    """Append a finished event to *events*; no-op when nothing accumulated."""
    if not (current_event or data_parts):
        return
    data_text = "\n".join(data_parts)
    if data_text:
        try:
            decoded = json.loads(data_text)
        except json.JSONDecodeError:
            # Keep malformed payloads inspectable instead of dropping them.
            data = {"raw": data_text}
        else:
            data = decoded if isinstance(decoded, dict) else {"value": decoded}
    else:
        data = {}
    events.append(SSEEvent(current_event, data, "\n".join(raw_parts)))
def assert_anthropic_stream_contract(
    events: list[SSEEvent], *, allow_error: bool = False
) -> None:
    """Check minimal Anthropic-style SSE invariants: start/stop, block nesting.

    Does *not* assert strict event ordering (e.g. :class:`message_delta` vs
    content blocks) beyond presence of a final ``message_stop``; stricter
    ordering can be tested in product or transport-specific suites.

    Args:
        events: Parsed SSE events in stream order.
        allow_error: When ``True``, in-stream ``error`` events are tolerated.

    Raises:
        AssertionError: On any violated invariant.

    NOTE(review): the accepted block types here exclude ``server_tool_use``
    and web tool result blocks — confirm streams from the local web server
    tools are not validated with this helper.
    """
    assert events, "stream produced no SSE events"
    event_names = [event.event for event in events]
    assert "message_start" in event_names, event_names
    assert event_names[-1] == "message_stop", event_names
    # Currently-open blocks: index -> declared block type.
    open_blocks: dict[int, str] = {}
    # Every index ever opened; indices must not be reused after their stop.
    seen_blocks: set[int] = set()
    for event in events:
        if event.event == "error" and not allow_error:
            raise AssertionError(f"unexpected SSE error event: {event.data}")
        if event.event == "content_block_start":
            index = _event_index(event)
            block = event.data.get("content_block", {})
            assert isinstance(block, dict), event.data
            block_type = str(block.get("type", ""))
            assert block_type in {"text", "thinking", "tool_use"}, event.data
            assert index not in open_blocks, f"block {index} started twice"
            assert index not in seen_blocks, f"block {index} reused after stop"
            open_blocks[index] = block_type
            seen_blocks.add(index)
            continue
        if event.event == "content_block_delta":
            index = _event_index(event)
            assert index in open_blocks, f"delta for unopened block {index}"
            delta = event.data.get("delta", {})
            assert isinstance(delta, dict), event.data
            delta_type = str(delta.get("type", ""))
            # Each block type admits exactly one delta type.
            expected = {
                "text": "text_delta",
                "thinking": "thinking_delta",
                "tool_use": "input_json_delta",
            }[open_blocks[index]]
            assert delta_type == expected, (
                f"block {index} is {open_blocks[index]}, got {delta_type}"
            )
            continue
        if event.event == "content_block_stop":
            index = _event_index(event)
            assert index in open_blocks, f"stop for unopened block {index}"
            open_blocks.pop(index)
    assert not open_blocks, f"unclosed blocks: {open_blocks}"
    assert seen_blocks, "stream did not emit any content blocks"
def event_names(events: list[SSEEvent]) -> list[str]:
    """Return the ``event`` field of each SSE event, preserving order."""
    names: list[str] = [evt.event for evt in events]
    return names
def text_content(events: list[SSEEvent]) -> str:
    """Concatenate every ``text_delta`` payload emitted by the stream."""
    return "".join(
        str(delta.get("text", ""))
        for evt in events
        if isinstance(delta := evt.data.get("delta", {}), dict)
        and delta.get("type") == "text_delta"
    )
def thinking_content(events: list[SSEEvent]) -> str:
    """Concatenate every ``thinking_delta`` payload emitted by the stream."""
    return "".join(
        str(delta.get("thinking", ""))
        for evt in events
        if isinstance(delta := evt.data.get("delta", {}), dict)
        and delta.get("type") == "thinking_delta"
    )
def has_tool_use(events: list[SSEEvent]) -> bool:
    """Whether any event in the stream carries a ``tool_use`` content block."""
    return any(
        isinstance(block := evt.data.get("content_block", {}), dict)
        and block.get("type") == "tool_use"
        for evt in events
    )
def _event_index(event: SSEEvent) -> int:
    """Extract the mandatory integer ``index`` field from an event payload."""
    index = event.data.get("index")
    assert isinstance(index, int), event.data
    return index

View file

@ -1,5 +1,6 @@
"""Heuristic parser for text-emitted tool calls."""
import json
import re
import uuid
from enum import Enum
@ -31,6 +32,9 @@ class HeuristicToolParser:
_PARAM_PATTERN = re.compile(
r"<parameter=([^>]+)>(.*?)(?:</parameter>|$)", re.DOTALL
)
_WEB_TOOL_JSON_PATTERN = re.compile(
r"(?is)\b(?:use\s+)?(?P<tool>WebFetch|WebSearch)\b.*?(?P<json>\{.*?\})"
)
def __init__(self):
self._state = ParserState.TEXT
@ -39,6 +43,41 @@ class HeuristicToolParser:
self._current_function_name = None
self._current_parameters = {}
def _extract_web_tool_json_calls(self) -> tuple[str, list[dict[str, Any]]]:
detected_tools: list[dict[str, Any]] = []
for match in self._WEB_TOOL_JSON_PATTERN.finditer(self._buffer):
try:
tool_input = json.loads(match.group("json"))
except json.JSONDecodeError:
continue
if not isinstance(tool_input, dict):
continue
tool_name = match.group("tool")
if tool_name == "WebFetch" and "url" not in tool_input:
continue
if tool_name == "WebSearch" and "query" not in tool_input:
continue
detected_tools.append(
{
"type": "tool_use",
"id": f"toolu_heuristic_{uuid.uuid4().hex[:8]}",
"name": tool_name,
"input": tool_input,
}
)
logger.debug(
"Heuristic bypass: Detected JSON-style tool call '{}'",
tool_name,
)
if not detected_tools:
return self._buffer, []
return "", detected_tools
def _strip_control_tokens(self, text: str) -> str:
return _CONTROL_TOKEN_RE.sub("", text)
@ -58,7 +97,7 @@ class HeuristicToolParser:
"""Feed text and return safe text plus detected tool calls."""
self._buffer += text
self._buffer = self._strip_control_tokens(self._buffer)
detected_tools = []
self._buffer, detected_tools = self._extract_web_tool_json_calls()
filtered_output_parts: list[str] = []
while True:

View file

@ -114,23 +114,15 @@ async def _delete_message_ids(
numeric.sort(reverse=True)
ordered = [mid for _, mid in numeric] + non_numeric
batch_fn = getattr(handler.platform, "queue_delete_messages", None)
if callable(batch_fn):
try:
CHUNK = 100
for i in range(0, len(ordered), CHUNK):
chunk = ordered[i : i + CHUNK]
await batch_fn(chat_id, chunk, fire_and_forget=False)
except Exception as e:
logger.debug(f"Batch delete failed: {type(e).__name__}: {e}")
else:
for mid in ordered:
try:
await handler.platform.queue_delete_message(
chat_id, mid, fire_and_forget=False
await handler.platform.queue_delete_messages(
chat_id, chunk, fire_and_forget=False
)
except Exception as e:
logger.debug(f"Delete failed for msg {mid}: {type(e).__name__}: {e}")
logger.debug(f"Batch delete failed: {type(e).__name__}: {e}")
async def _handle_clear_branch(

View file

@ -467,7 +467,7 @@ class ClaudeMessageHandler:
status,
len(display),
)
if os.getenv("DEBUG_TELEGRAM_EDITS") == "1":
if os.getenv("DEBUG_PLATFORM_EDITS") == "1":
logger.debug("PLATFORM_EDIT_TEXT:\n{}", display)
else:
head = display[:500]

View file

@ -17,15 +17,10 @@ class CLISession(Protocol):
def start_task(
self, prompt: str, session_id: str | None = None, fork_session: bool = False
) -> AsyncGenerator[dict, Any]:
"""Start a task in the CLI session."""
...
) -> AsyncGenerator[dict, Any]: ...
@property
@abstractmethod
def is_busy(self) -> bool:
"""Check if session is busy."""
pass
def is_busy(self) -> bool: ...
@runtime_checkable
@ -101,7 +96,8 @@ class MessagingPlatform(ABC):
text: Message content
reply_to: Optional message ID to reply to
parse_mode: Optional formatting mode ("markdown", "html")
message_thread_id: Optional forum topic ID (Telegram)
message_thread_id: Optional thread or topic id for threaded channels
(e.g. forum topics); unused on platforms that do not support it.
Returns:
The message ID of the sent message
@ -192,6 +188,22 @@ class MessagingPlatform(ABC):
"""
pass
async def queue_delete_messages(
self,
chat_id: str,
message_ids: list[str],
*,
fire_and_forget: bool = True,
) -> None:
"""Delete many messages; default loops :meth:`queue_delete_message`.
Adapters with native bulk delete should override.
"""
for mid in message_ids:
await self.queue_delete_message(
chat_id, mid, fire_and_forget=fire_and_forget
)
@abstractmethod
def on_message(
self,

View file

@ -91,6 +91,12 @@ class DiscordPlatform(MessagingPlatform):
self,
bot_token: str | None = None,
allowed_channel_ids: str | None = None,
*,
voice_note_enabled: bool = True,
whisper_model: str = "base",
whisper_device: str = "cpu",
hf_token: str = "",
nvidia_nim_api_key: str = "",
):
if not DISCORD_AVAILABLE:
raise ImportError(
@ -117,7 +123,13 @@ class DiscordPlatform(MessagingPlatform):
self._limiter: Any | None = None
self._start_task: asyncio.Task | None = None
self._pending_voice = PendingVoiceRegistry()
self._voice_transcription = VoiceTranscriptionService()
self._voice_transcription = VoiceTranscriptionService(
hf_token=hf_token,
nvidia_nim_api_key=nvidia_nim_api_key,
)
self._voice_note_enabled = voice_note_enabled
self._whisper_model = whisper_model
self._whisper_device = whisper_device
async def _handle_client_message(self, message: Any) -> None:
"""Adapter entry point used by the internal discord client."""
@ -154,10 +166,7 @@ class DiscordPlatform(MessagingPlatform):
self, message: Any, attachment: Any, channel_id: str
) -> bool:
"""Handle voice/audio attachment. Returns True if handled."""
from config.settings import get_settings
settings = get_settings()
if not settings.voice_note_enabled:
if not self._voice_note_enabled:
await message.reply("Voice notes are disabled.")
return True
@ -201,8 +210,8 @@ class DiscordPlatform(MessagingPlatform):
transcribed = await self._voice_transcription.transcribe(
tmp_path,
ct,
whisper_model=settings.whisper_model,
whisper_device=settings.whisper_device,
whisper_model=self._whisper_model,
whisper_device=self._whisper_device,
)
if not await self._is_voice_still_pending(channel_id, message_id):

View file

@ -6,30 +6,50 @@ To add a new platform (e.g. Discord, Slack):
2. Add a case to create_messaging_platform() below
"""
from __future__ import annotations
from dataclasses import dataclass
from loguru import logger
from .base import MessagingPlatform
@dataclass(frozen=True, slots=True)
class MessagingPlatformOptions:
"""Typed wiring from :class:`~api.runtime.AppRuntime` into platform adapters."""
telegram_bot_token: str | None = None
allowed_telegram_user_id: str | None = None
discord_bot_token: str | None = None
allowed_discord_channels: str | None = None
voice_note_enabled: bool = True
whisper_model: str = "base"
whisper_device: str = "cpu"
hf_token: str = ""
nvidia_nim_api_key: str = ""
def create_messaging_platform(
platform_type: str,
**kwargs,
options: MessagingPlatformOptions | None = None,
) -> MessagingPlatform | None:
"""Create a messaging platform instance based on type.
Args:
platform_type: Platform identifier ("telegram", "discord", etc.)
**kwargs: Platform-specific configuration passed to the constructor.
platform_type: Platform identifier (``telegram``, ``discord``, ``none``).
options: Token, allowlist, and voice / transcription settings.
Returns:
Configured MessagingPlatform instance, or None if not configured.
Configured :class:`MessagingPlatform` instance, or None if not configured.
"""
opts = options or MessagingPlatformOptions()
if platform_type == "none":
logger.info("Messaging platform disabled by configuration")
return None
if platform_type == "telegram":
bot_token = kwargs.get("bot_token")
bot_token = opts.telegram_bot_token
if not bot_token:
logger.info("No Telegram bot token configured, skipping platform setup")
return None
@ -38,11 +58,16 @@ def create_messaging_platform(
return TelegramPlatform(
bot_token=bot_token,
allowed_user_id=kwargs.get("allowed_user_id"),
allowed_user_id=opts.allowed_telegram_user_id,
voice_note_enabled=opts.voice_note_enabled,
whisper_model=opts.whisper_model,
whisper_device=opts.whisper_device,
hf_token=opts.hf_token,
nvidia_nim_api_key=opts.nvidia_nim_api_key,
)
if platform_type == "discord":
bot_token = kwargs.get("discord_bot_token")
bot_token = opts.discord_bot_token
if not bot_token:
logger.info("No Discord bot token configured, skipping platform setup")
return None
@ -51,7 +76,12 @@ def create_messaging_platform(
return DiscordPlatform(
bot_token=bot_token,
allowed_channel_ids=kwargs.get("allowed_discord_channels"),
allowed_channel_ids=opts.allowed_discord_channels,
voice_note_enabled=opts.voice_note_enabled,
whisper_model=opts.whisper_model,
whisper_device=opts.whisper_device,
hf_token=opts.hf_token,
nvidia_nim_api_key=opts.nvidia_nim_api_key,
)
logger.warning(

View file

@ -62,6 +62,12 @@ class TelegramPlatform(MessagingPlatform):
self,
bot_token: str | None = None,
allowed_user_id: str | None = None,
*,
voice_note_enabled: bool = True,
whisper_model: str = "base",
whisper_device: str = "cpu",
hf_token: str = "",
nvidia_nim_api_key: str = "",
):
if not TELEGRAM_AVAILABLE:
raise ImportError(
@ -84,7 +90,13 @@ class TelegramPlatform(MessagingPlatform):
self._limiter: Any | None = None # Will be MessagingRateLimiter
# Pending voice transcriptions: (chat_id, msg_id) -> (voice_msg_id, status_msg_id)
self._pending_voice = PendingVoiceRegistry()
self._voice_transcription = VoiceTranscriptionService()
self._voice_transcription = VoiceTranscriptionService(
hf_token=hf_token,
nvidia_nim_api_key=nvidia_nim_api_key,
)
self._voice_note_enabled = voice_note_enabled
self._whisper_model = whisper_model
self._whisper_device = whisper_device
async def _register_pending_voice(
self, chat_id: str, voice_msg_id: str, status_msg_id: str
@ -544,10 +556,7 @@ class TelegramPlatform(MessagingPlatform):
):
return
from config.settings import get_settings
settings = get_settings()
if not settings.voice_note_enabled:
if not self._voice_note_enabled:
await update.message.reply_text("Voice notes are disabled.")
return
@ -600,8 +609,8 @@ class TelegramPlatform(MessagingPlatform):
transcribed = await self._voice_transcription.transcribe(
tmp_path,
voice.mime_type or "audio/ogg",
whisper_model=settings.whisper_model,
whisper_device=settings.whisper_device,
whisper_model=self._whisper_model,
whisper_device=self._whisper_device,
)
if not await self._is_voice_still_pending(chat_id, message_id):

View file

@ -42,8 +42,8 @@ _MODEL_MAP: dict[str, str] = {
"large-v3-turbo": "openai/whisper-large-v3-turbo",
}
# Lazy-loaded pipelines: (model_id, device) -> pipeline
_pipeline_cache: dict[tuple[str, str], Any] = {}
# Lazy-loaded pipelines: (model_id, device, hf_token_fingerprint) -> pipeline
_pipeline_cache: dict[tuple[str, str, str], Any] = {}
def _resolve_model_id(whisper_model: str) -> str:
@ -51,20 +51,22 @@ def _resolve_model_id(whisper_model: str) -> str:
return _MODEL_MAP.get(whisper_model, whisper_model)
def _get_pipeline(model_id: str, device: str) -> Any:
def _get_pipeline(model_id: str, device: str, hf_token: str | None = None) -> Any:
"""Lazy-load transformers Whisper pipeline. Raises ImportError if not installed."""
global _pipeline_cache
if device not in ("cpu", "cuda"):
raise ValueError(f"whisper_device must be 'cpu' or 'cuda', got {device!r}")
cache_key = (model_id, device)
resolved_token = (
hf_token if hf_token is not None else get_settings().hf_token
) or ""
cache_key = (model_id, device, resolved_token)
if cache_key not in _pipeline_cache:
try:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
token = get_settings().hf_token
if token:
os.environ["HF_TOKEN"] = token
if resolved_token:
os.environ["HF_TOKEN"] = resolved_token
use_cuda = device == "cuda" and torch.cuda.is_available()
pipe_device = "cuda:0" if use_cuda else "cpu"
@ -103,6 +105,8 @@ def transcribe_audio(
*,
whisper_model: str = "base",
whisper_device: str = "cpu",
hf_token: str = "",
nvidia_nim_api_key: str = "",
) -> str:
"""
Transcribe audio file to text.
@ -136,9 +140,12 @@ def transcribe_audio(
)
if whisper_device == "nvidia_nim":
return _transcribe_nim(file_path, whisper_model)
else:
return _transcribe_local(file_path, whisper_model, whisper_device)
return _transcribe_nim(
file_path, whisper_model, nvidia_nim_api_key=nvidia_nim_api_key
)
return _transcribe_local(
file_path, whisper_model, whisper_device, hf_token=hf_token
)
# Whisper expects 16 kHz sample rate
@ -153,10 +160,17 @@ def _load_audio(file_path: Path) -> dict[str, Any]:
return {"array": waveform, "sampling_rate": sr}
def _transcribe_local(file_path: Path, whisper_model: str, whisper_device: str) -> str:
def _transcribe_local(
file_path: Path,
whisper_model: str,
whisper_device: str,
*,
hf_token: str = "",
) -> str:
"""Transcribe using transformers Whisper pipeline."""
model_id = _resolve_model_id(whisper_model)
pipe = _get_pipeline(model_id, whisper_device)
token: str | None = hf_token if hf_token else None
pipe = _get_pipeline(model_id, whisper_device, hf_token=token)
audio = _load_audio(file_path)
result = pipe(audio, generate_kwargs={"language": "en", "task": "transcribe"})
text = result.get("text", "") or ""
@ -167,7 +181,9 @@ def _transcribe_local(file_path: Path, whisper_model: str, whisper_device: str)
return result_text or "(no speech detected)"
def _transcribe_nim(file_path: Path, model: str) -> str:
def _transcribe_nim(
file_path: Path, model: str, *, nvidia_nim_api_key: str = ""
) -> str:
"""Transcribe using NVIDIA NIM Whisper API via Riva gRPC client."""
try:
import riva.client
@ -177,8 +193,7 @@ def _transcribe_nim(file_path: Path, model: str) -> str:
"Install with: uv sync --extra voice"
) from e
settings = get_settings()
api_key = settings.nvidia_nim_api_key
api_key = nvidia_nim_api_key or get_settings().nvidia_nim_api_key
# Look up function ID and language code from model mapping
model_config = _NIM_MODEL_MAP.get(model)

View file

@ -46,6 +46,15 @@ class PendingVoiceRegistry:
class VoiceTranscriptionService:
"""Run configured transcription backends off the event loop."""
def __init__(
self,
*,
hf_token: str = "",
nvidia_nim_api_key: str = "",
) -> None:
self._hf_token = hf_token
self._nvidia_nim_api_key = nvidia_nim_api_key
async def transcribe(
self,
file_path: Path,
@ -62,4 +71,6 @@ class VoiceTranscriptionService:
mime_type,
whisper_model=whisper_model,
whisper_device=whisper_device,
hf_token=self._hf_token,
nvidia_nim_api_key=self._nvidia_nim_api_key,
)

View file

@ -1,8 +1,11 @@
"""Providers package - implement your own provider by extending BaseProvider."""
"""Providers package - implement your own provider by extending BaseProvider.
Concrete adapters (e.g. ``NvidiaNimProvider``) live in subpackages; import them
from ``providers.nvidia_nim`` etc. to avoid loading every adapter when the
``providers`` package is imported.
"""
from .anthropic_messages import AnthropicMessagesTransport
from .base import BaseProvider, ProviderConfig
from .deepseek import DeepSeekProvider
from .exceptions import (
APIError,
AuthenticationError,
@ -10,27 +13,17 @@ from .exceptions import (
OverloadedError,
ProviderError,
RateLimitError,
UnknownProviderTypeError,
)
from .llamacpp import LlamaCppProvider
from .lmstudio import LMStudioProvider
from .nvidia_nim import NvidiaNimProvider
from .open_router import OpenRouterProvider
from .openai_compat import OpenAIChatTransport
__all__ = [
"APIError",
"AnthropicMessagesTransport",
"AuthenticationError",
"BaseProvider",
"DeepSeekProvider",
"InvalidRequestError",
"LMStudioProvider",
"LlamaCppProvider",
"NvidiaNimProvider",
"OpenAIChatTransport",
"OpenRouterProvider",
"OverloadedError",
"ProviderConfig",
"ProviderError",
"RateLimitError",
"UnknownProviderTypeError",
]

View file

@ -3,12 +3,11 @@
from typing import Any
from providers.base import ProviderConfig
from providers.defaults import DEEPSEEK_BASE_URL
from providers.openai_compat import OpenAIChatTransport
from .request import build_request_body
DEEPSEEK_BASE_URL = "https://api.deepseek.com"
class DeepSeekProvider(OpenAIChatTransport):
"""DeepSeek provider using OpenAI-compatible chat completions."""

19
providers/defaults.py Normal file
View file

@ -0,0 +1,19 @@
"""Default upstream base URLs and shared provider constants.
Adapters and :mod:`providers.registry` import from here to avoid duplicating
literals and to keep ``providers.registry`` free of per-adapter eager imports.
"""
# OpenAI-compatible chat (NIM, DeepSeek) and local OpenAI-shaped endpoints
NVIDIA_NIM_DEFAULT_BASE = "https://integrate.api.nvidia.com/v1"
DEEPSEEK_DEFAULT_BASE = "https://api.deepseek.com"
OPENROUTER_DEFAULT_BASE = "https://openrouter.ai/api/v1"
LMSTUDIO_DEFAULT_BASE = "http://localhost:1234/v1"
LLAMACPP_DEFAULT_BASE = "http://localhost:8080/v1"
# Backward-compatible names used by existing adapter modules
NVIDIA_NIM_BASE_URL = NVIDIA_NIM_DEFAULT_BASE
DEEPSEEK_BASE_URL = DEEPSEEK_DEFAULT_BASE
OPENROUTER_BASE_URL = OPENROUTER_DEFAULT_BASE
LMSTUDIO_DEFAULT_BASE_URL = LMSTUDIO_DEFAULT_BASE
LLAMACPP_DEFAULT_BASE_URL = LLAMACPP_DEFAULT_BASE

View file

@ -88,3 +88,10 @@ class APIError(ProviderError):
error_type="api_error",
raw_error=raw_error,
)
class UnknownProviderTypeError(ValueError):
"""Raised when ``provider_id`` is not registered in the provider map."""
def __init__(self, message: str) -> None:
super().__init__(message)

View file

@ -2,8 +2,7 @@
from providers.anthropic_messages import AnthropicMessagesTransport
from providers.base import ProviderConfig
LLAMACPP_DEFAULT_BASE_URL = "http://localhost:8080/v1"
from providers.defaults import LLAMACPP_DEFAULT_BASE_URL
class LlamaCppProvider(AnthropicMessagesTransport):

View file

@ -2,8 +2,7 @@
from providers.anthropic_messages import AnthropicMessagesTransport
from providers.base import ProviderConfig
LMSTUDIO_DEFAULT_BASE_URL = "http://localhost:1234/v1"
from providers.defaults import LMSTUDIO_DEFAULT_BASE_URL
class LMStudioProvider(AnthropicMessagesTransport):

View file

@ -8,6 +8,7 @@ from loguru import logger
from config.nim import NimSettings
from providers.base import ProviderConfig
from providers.defaults import NVIDIA_NIM_BASE_URL
from providers.openai_compat import OpenAIChatTransport
from .request import (
@ -16,8 +17,6 @@ from .request import (
clone_body_without_reasoning_budget,
)
NVIDIA_NIM_BASE_URL = "https://integrate.api.nvidia.com/v1"
class NvidiaNimProvider(OpenAIChatTransport):
"""NVIDIA NIM provider using official OpenAI client."""

View file

@ -11,10 +11,10 @@ from typing import Any
from core.anthropic import SSEBuilder, append_request_id
from providers.anthropic_messages import AnthropicMessagesTransport, StreamChunkMode
from providers.base import ProviderConfig
from providers.defaults import OPENROUTER_BASE_URL
from .request import build_request_body
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
_ANTHROPIC_VERSION = "2023-06-01"

View file

@ -1,5 +1,10 @@
"""Shared base class for OpenAI-compatible providers (NIM, OpenRouter, LM Studio)."""
"""OpenAI-style chat base for :class:`OpenAIChatTransport` (NIM, DeepSeek, etc.).
``AnthropicMessagesTransport``-based providers (OpenRouter, LM Studio, …) live
in separate modules; do not list them as subclasses of this class.
"""
import asyncio
import json
import uuid
from abc import abstractmethod
@ -25,7 +30,7 @@ from providers.rate_limit import GlobalRateLimiter
class OpenAIChatTransport(BaseProvider):
"""Base class for providers using OpenAI-compatible chat completions API."""
"""Base for OpenAI-compatible ``/chat/completions`` adapters (NIM, DeepSeek, …)."""
def __init__(
self,
@ -114,6 +119,7 @@ class OpenAIChatTransport(BaseProvider):
fn_delta = tc.get("function", {})
incoming_name = fn_delta.get("name")
arguments = fn_delta.get("arguments", "")
if incoming_name is not None:
sse.blocks.register_tool_name(tc_index, incoming_name)
@ -124,7 +130,7 @@ class OpenAIChatTransport(BaseProvider):
tool_id = tc.get("id") or f"tool_{uuid.uuid4()}"
yield sse.start_tool_block(tc_index, tool_id, name)
args = fn_delta.get("arguments", "")
args = arguments
if args:
state = sse.blocks.tool_states.get(tc_index)
if state is None or not state.started:
@ -285,6 +291,8 @@ class OpenAIChatTransport(BaseProvider):
for event in self._process_tool_call(tc_info, sse):
yield event
except (asyncio.CancelledError, GeneratorExit):
raise
except Exception as e:
logger.error("{}_ERROR:{} {}: {}", tag, req_tag, type(e).__name__, e)
mapped_e = map_error(e, rate_limiter=self._global_rate_limiter)

View file

@ -6,17 +6,17 @@ from collections.abc import Callable, MutableMapping
from dataclasses import dataclass
from typing import Literal
from config.provider_ids import SUPPORTED_PROVIDER_IDS
from config.settings import Settings
from providers.base import BaseProvider, ProviderConfig
from providers.deepseek import DEEPSEEK_BASE_URL, DeepSeekProvider
from providers.exceptions import AuthenticationError
from providers.llamacpp import LlamaCppProvider
from providers.lmstudio import LMStudioProvider
from providers.nvidia_nim import NVIDIA_NIM_BASE_URL, NvidiaNimProvider
from providers.open_router import (
OPENROUTER_BASE_URL,
OpenRouterProvider,
from providers.defaults import (
DEEPSEEK_DEFAULT_BASE,
LLAMACPP_DEFAULT_BASE,
LMSTUDIO_DEFAULT_BASE,
NVIDIA_NIM_DEFAULT_BASE,
OPENROUTER_DEFAULT_BASE,
)
from providers.exceptions import AuthenticationError, UnknownProviderTypeError
TransportType = Literal["openai_chat", "anthropic_messages"]
ProviderFactory = Callable[[ProviderConfig, Settings], BaseProvider]
@ -24,11 +24,17 @@ ProviderFactory = Callable[[ProviderConfig, Settings], BaseProvider]
@dataclass(frozen=True, slots=True)
class ProviderDescriptor:
"""Metadata for building :class:`ProviderConfig` and factory wiring."""
provider_id: str
transport_type: TransportType
capabilities: tuple[str, ...]
credential_env: str | None = None
credential_url: str | None = None
# If set, read API key from this attribute on ``Settings`` (e.g. nvidia_nim_api_key).
credential_attr: str | None = None
# If set, use this fixed key for local adapters (e.g. lm-studio, llamacpp).
static_credential: str | None = None
default_base_url: str | None = None
base_url_attr: str | None = None
proxy_attr: str | None = None
@ -40,7 +46,8 @@ PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = {
transport_type="openai_chat",
credential_env="NVIDIA_NIM_API_KEY",
credential_url="https://build.nvidia.com/settings/api-keys",
default_base_url=NVIDIA_NIM_BASE_URL,
credential_attr="nvidia_nim_api_key",
default_base_url=NVIDIA_NIM_DEFAULT_BASE,
proxy_attr="nvidia_nim_proxy",
capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"),
),
@ -49,7 +56,8 @@ PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = {
transport_type="anthropic_messages",
credential_env="OPENROUTER_API_KEY",
credential_url="https://openrouter.ai/keys",
default_base_url=OPENROUTER_BASE_URL,
credential_attr="open_router_api_key",
default_base_url=OPENROUTER_DEFAULT_BASE,
proxy_attr="open_router_proxy",
capabilities=("chat", "streaming", "tools", "thinking", "native_anthropic"),
),
@ -58,13 +66,15 @@ PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = {
transport_type="openai_chat",
credential_env="DEEPSEEK_API_KEY",
credential_url="https://platform.deepseek.com/api_keys",
default_base_url=DEEPSEEK_BASE_URL,
credential_attr="deepseek_api_key",
default_base_url=DEEPSEEK_DEFAULT_BASE,
capabilities=("chat", "streaming", "thinking"),
),
"lmstudio": ProviderDescriptor(
provider_id="lmstudio",
transport_type="anthropic_messages",
default_base_url="http://localhost:1234/v1",
static_credential="lm-studio",
default_base_url=LMSTUDIO_DEFAULT_BASE,
base_url_attr="lm_studio_base_url",
proxy_attr="lmstudio_proxy",
capabilities=("chat", "streaming", "tools", "native_anthropic", "local"),
@ -72,7 +82,8 @@ PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = {
"llamacpp": ProviderDescriptor(
provider_id="llamacpp",
transport_type="anthropic_messages",
default_base_url="http://localhost:8080/v1",
static_credential="llamacpp",
default_base_url=LLAMACPP_DEFAULT_BASE,
base_url_attr="llamacpp_base_url",
proxy_attr="llamacpp_proxy",
capabilities=("chat", "streaming", "tools", "native_anthropic", "local"),
@ -81,22 +92,32 @@ PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = {
def _create_nvidia_nim(config: ProviderConfig, settings: Settings) -> BaseProvider:
from providers.nvidia_nim import NvidiaNimProvider
return NvidiaNimProvider(config, nim_settings=settings.nim)
def _create_open_router(config: ProviderConfig, _settings: Settings) -> BaseProvider:
from providers.open_router import OpenRouterProvider
return OpenRouterProvider(config)
def _create_deepseek(config: ProviderConfig, settings: Settings) -> BaseProvider:
from providers.deepseek import DeepSeekProvider
return DeepSeekProvider(config)
def _create_lmstudio(config: ProviderConfig, settings: Settings) -> BaseProvider:
from providers.lmstudio import LMStudioProvider
return LMStudioProvider(config)
def _create_llamacpp(config: ProviderConfig, settings: Settings) -> BaseProvider:
from providers.llamacpp import LlamaCppProvider
return LlamaCppProvider(config)
@ -108,6 +129,15 @@ PROVIDER_FACTORIES: dict[str, ProviderFactory] = {
"llamacpp": _create_llamacpp,
}
if set(PROVIDER_DESCRIPTORS) != set(SUPPORTED_PROVIDER_IDS) or set(
PROVIDER_FACTORIES
) != set(SUPPORTED_PROVIDER_IDS):
raise AssertionError(
"PROVIDER_DESCRIPTORS, PROVIDER_FACTORIES, and SUPPORTED_PROVIDER_IDS are out of sync: "
f"descriptors={set(PROVIDER_DESCRIPTORS)!r} factories={set(PROVIDER_FACTORIES)!r} "
f"ids={set(SUPPORTED_PROVIDER_IDS)!r}"
)
def _string_attr(settings: Settings, attr_name: str | None, default: str = "") -> str:
if attr_name is None:
@ -116,17 +146,11 @@ def _string_attr(settings: Settings, attr_name: str | None, default: str = "") -
return value if isinstance(value, str) else default
def _credential_for(provider_id: str, settings: Settings) -> str:
if provider_id == "nvidia_nim":
return settings.nvidia_nim_api_key
if provider_id == "open_router":
return settings.open_router_api_key
if provider_id == "deepseek":
return settings.deepseek_api_key
if provider_id == "lmstudio":
return "lm-studio"
if provider_id == "llamacpp":
return "llamacpp"
def _credential_for(descriptor: ProviderDescriptor, settings: Settings) -> str:
if descriptor.static_credential is not None:
return descriptor.static_credential
if descriptor.credential_attr:
return _string_attr(settings, descriptor.credential_attr)
return ""
@ -144,7 +168,7 @@ def _require_credential(descriptor: ProviderDescriptor, credential: str) -> None
def build_provider_config(
descriptor: ProviderDescriptor, settings: Settings
) -> ProviderConfig:
credential = _credential_for(descriptor.provider_id, settings)
credential = _credential_for(descriptor, settings)
_require_credential(descriptor, credential)
base_url = _string_attr(
settings, descriptor.base_url_attr, descriptor.default_base_url or ""
@ -168,7 +192,7 @@ def create_provider(provider_id: str, settings: Settings) -> BaseProvider:
descriptor = PROVIDER_DESCRIPTORS.get(provider_id)
if descriptor is None:
supported = "', '".join(PROVIDER_DESCRIPTORS)
raise ValueError(
raise UnknownProviderTypeError(
f"Unknown provider_type: '{provider_id}'. Supported: '{supported}'"
)
@ -185,12 +209,33 @@ class ProviderRegistry:
def __init__(self, providers: MutableMapping[str, BaseProvider] | None = None):
self._providers = providers if providers is not None else {}
def is_cached(self, provider_id: str) -> bool:
"""Return whether a provider for this id is already in the cache."""
return provider_id in self._providers
def get(self, provider_id: str, settings: Settings) -> BaseProvider:
if provider_id not in self._providers:
self._providers[provider_id] = create_provider(provider_id, settings)
return self._providers[provider_id]
async def cleanup(self) -> None:
for provider in self._providers.values():
"""Call ``cleanup`` on every cached provider, then clear the cache.
Attempts all providers even if one fails. A single failure is re-raised
as-is; multiple failures are wrapped in :exc:`ExceptionGroup`.
"""
items = list(self._providers.items())
errors: list[Exception] = []
try:
for _pid, provider in items:
try:
await provider.cleanup()
except Exception as e:
errors.append(e)
finally:
self._providers.clear()
if len(errors) == 1:
raise errors[0]
if len(errors) > 1:
msg = "One or more provider cleanups failed"
raise ExceptionGroup(msg, errors)

View file

@ -18,6 +18,7 @@ from typing import Any
import httpx
import pytest
from config.provider_ids import SUPPORTED_PROVIDER_IDS
from messaging.handler import ClaudeMessageHandler
from messaging.models import IncomingMessage
from messaging.platforms.base import MessagingPlatform
@ -153,7 +154,7 @@ class ConversationDriver:
class ProviderMatrixDriver:
"""Resolve provider models and enforce matrix semantics for product smoke."""
ALL_PROVIDERS = ("nvidia_nim", "open_router", "deepseek", "lmstudio", "llamacpp")
ALL_PROVIDERS: tuple[str, ...] = SUPPORTED_PROVIDER_IDS
def __init__(self, config: SmokeConfig) -> None:
self.config = config

View file

@ -1,148 +1,29 @@
"""SSE parsing and Anthropic stream assertions for smoke tests."""
"""SSE parsing and Anthropic stream assertions for smoke tests.
Canonical implementation lives in :mod:`core.anthropic.stream_contracts`; this
module re-exports it for smoke and historical import paths.
"""
from __future__ import annotations
import json
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any
@dataclass(frozen=True, slots=True)
class SSEEvent:
event: str
data: dict[str, Any]
raw: str
def parse_sse_lines(lines: Iterable[str]) -> list[SSEEvent]:
events: list[SSEEvent] = []
current_event = ""
data_parts: list[str] = []
raw_parts: list[str] = []
for line in lines:
stripped = line.rstrip("\r\n")
if stripped == "":
_append_event(events, current_event, data_parts, raw_parts)
current_event = ""
data_parts = []
raw_parts = []
continue
raw_parts.append(stripped)
if stripped.startswith("event:"):
current_event = stripped.split(":", 1)[1].strip()
elif stripped.startswith("data:"):
data_parts.append(stripped.split(":", 1)[1].strip())
_append_event(events, current_event, data_parts, raw_parts)
return events
def parse_sse_text(text: str) -> list[SSEEvent]:
return parse_sse_lines(text.splitlines())
def _append_event(
events: list[SSEEvent],
current_event: str,
data_parts: list[str],
raw_parts: list[str],
) -> None:
if not current_event and not data_parts:
return
data_text = "\n".join(data_parts)
data: dict[str, Any]
try:
parsed = json.loads(data_text) if data_text else {}
data = parsed if isinstance(parsed, dict) else {"value": parsed}
except json.JSONDecodeError:
data = {"raw": data_text}
events.append(SSEEvent(current_event, data, "\n".join(raw_parts)))
def assert_anthropic_stream_contract(
events: list[SSEEvent], *, allow_error: bool = False
) -> None:
assert events, "stream produced no SSE events"
event_names = [event.event for event in events]
assert "message_start" in event_names, event_names
assert event_names[-1] == "message_stop", event_names
open_blocks: dict[int, str] = {}
seen_blocks: set[int] = set()
for event in events:
if event.event == "error" and not allow_error:
raise AssertionError(f"unexpected SSE error event: {event.data}")
if event.event == "content_block_start":
index = _event_index(event)
block = event.data.get("content_block", {})
assert isinstance(block, dict), event.data
block_type = str(block.get("type", ""))
assert block_type in {"text", "thinking", "tool_use"}, event.data
assert index not in open_blocks, f"block {index} started twice"
assert index not in seen_blocks, f"block {index} reused after stop"
open_blocks[index] = block_type
seen_blocks.add(index)
continue
if event.event == "content_block_delta":
index = _event_index(event)
assert index in open_blocks, f"delta for unopened block {index}"
delta = event.data.get("delta", {})
assert isinstance(delta, dict), event.data
delta_type = str(delta.get("type", ""))
expected = {
"text": "text_delta",
"thinking": "thinking_delta",
"tool_use": "input_json_delta",
}[open_blocks[index]]
assert delta_type == expected, (
f"block {index} is {open_blocks[index]}, got {delta_type}"
from core.anthropic.stream_contracts import (
SSEEvent,
assert_anthropic_stream_contract,
event_names,
has_tool_use,
parse_sse_lines,
parse_sse_text,
text_content,
thinking_content,
)
continue
if event.event == "content_block_stop":
index = _event_index(event)
assert index in open_blocks, f"stop for unopened block {index}"
open_blocks.pop(index)
assert not open_blocks, f"unclosed blocks: {open_blocks}"
assert seen_blocks, "stream did not emit any content blocks"
def event_names(events: list[SSEEvent]) -> list[str]:
    """Return the ``event`` field of every SSE event, in stream order."""
    return [ev.event for ev in events]
def text_content(events: list[SSEEvent]) -> str:
    """Concatenate every ``text_delta`` fragment in stream order."""

    def _fragment(ev: SSEEvent) -> str:
        delta = ev.data.get("delta", {})
        if isinstance(delta, dict) and delta.get("type") == "text_delta":
            return str(delta.get("text", ""))
        return ""

    return "".join(_fragment(ev) for ev in events)
def thinking_content(events: list[SSEEvent]) -> str:
    """Concatenate every ``thinking_delta`` fragment in stream order."""
    deltas = (ev.data.get("delta", {}) for ev in events)
    return "".join(
        str(d.get("thinking", ""))
        for d in deltas
        if isinstance(d, dict) and d.get("type") == "thinking_delta"
    )
def has_tool_use(events: list[SSEEvent]) -> bool:
    """True if any event carries a ``tool_use`` content block."""
    return any(
        isinstance(block, dict) and block.get("type") == "tool_use"
        for block in (ev.data.get("content_block", {}) for ev in events)
    )
def _event_index(event: SSEEvent) -> int:
value = event.data.get("index")
assert isinstance(value, int), event.data
return value
# Explicit public surface of this module; the docstring above states these
# are re-exports of core.anthropic.stream_contracts, so the set of names
# here must track that module.
__all__ = [
    "SSEEvent",
    "assert_anthropic_stream_contract",
    "event_names",
    "has_tool_use",
    "parse_sse_lines",
    "parse_sse_text",
    "text_content",
    "thinking_content",
]

View file

@ -1,5 +1,6 @@
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from api.app import app
@ -9,7 +10,7 @@ from providers.nvidia_nim import NvidiaNimProvider
mock_provider = MagicMock(spec=NvidiaNimProvider)
# Track stream_response calls for test_model_mapping
_stream_response_calls = []
_stream_response_calls: list = []
async def _mock_stream_response(*args, **kwargs):
@ -21,26 +22,30 @@ async def _mock_stream_response(*args, **kwargs):
mock_provider.stream_response = _mock_stream_response
# Patch get_provider_for_type to always return mock_provider
_patcher = patch("api.routes.get_provider_for_type", return_value=mock_provider)
_patcher.start()
client = TestClient(app)
@pytest.fixture(scope="module")
def client():
    """HTTP client with provider resolution stubbed; patch only for this file."""
    with patch("api.dependencies.resolve_provider", return_value=mock_provider):
        with TestClient(app) as test_client:
            yield test_client
def test_root():
def test_root(client: TestClient):
response = client.get("/")
assert response.status_code == 200
assert response.json()["status"] == "ok"
def test_health():
def test_health(client: TestClient):
response = client.get("/health")
assert response.status_code == 200
assert response.json()["status"] == "healthy"
def test_models_list():
def test_models_list(client: TestClient):
response = client.get("/v1/models")
assert response.status_code == 200
data = response.json()
@ -51,7 +56,7 @@ def test_models_list():
assert data["last_id"] == ids[-1]
def test_probe_endpoints_return_204_with_allow_headers():
def test_probe_endpoints_return_204_with_allow_headers(client: TestClient):
responses = [
client.head("/"),
client.options("/"),
@ -68,7 +73,7 @@ def test_probe_endpoints_return_204_with_allow_headers():
assert "Allow" in response.headers
def test_create_message_stream():
def test_create_message_stream(client: TestClient):
"""Create message returns streaming response."""
payload = {
"model": "claude-3-sonnet",
@ -83,7 +88,7 @@ def test_create_message_stream():
assert b"message_start" in content or b"event:" in content
def test_model_mapping():
def test_model_mapping(client: TestClient):
# Test Haiku mapping
_stream_response_calls.clear()
payload_haiku = {
@ -98,7 +103,7 @@ def test_model_mapping():
assert args[0].model != "claude-3-haiku-20240307"
def test_error_fallbacks():
def test_error_fallbacks(client: TestClient):
from providers.exceptions import (
AuthenticationError,
OverloadedError,
@ -143,7 +148,7 @@ def test_error_fallbacks():
mock_provider.stream_response = _mock_stream_response
def test_generic_exception_returns_500():
def test_generic_exception_returns_500(client: TestClient):
"""Non-ProviderError exceptions are caught and returned as HTTPException(500)."""
def _raise_runtime(*args, **kwargs):
@ -163,7 +168,7 @@ def test_generic_exception_returns_500():
mock_provider.stream_response = _mock_stream_response
def test_generic_exception_with_status_code():
def test_generic_exception_with_status_code(client: TestClient):
"""Generic exception with status_code attribute uses that status (getattr fallback)."""
class ExceptionWithStatus(RuntimeError):
@ -188,7 +193,7 @@ def test_generic_exception_with_status_code():
mock_provider.stream_response = _mock_stream_response
def test_generic_exception_empty_message_returns_non_empty_detail():
def test_generic_exception_empty_message_returns_non_empty_detail(client: TestClient):
"""Exceptions with empty __str__ still return a readable HTTP detail."""
class SilentError(RuntimeError):
@ -213,7 +218,7 @@ def test_generic_exception_empty_message_returns_non_empty_detail():
mock_provider.stream_response = _mock_stream_response
def test_count_tokens_endpoint():
def test_count_tokens_endpoint(client: TestClient):
"""count_tokens endpoint returns token count."""
response = client.post(
"/v1/messages/count_tokens",
@ -223,7 +228,7 @@ def test_count_tokens_endpoint():
assert "input_tokens" in response.json()
def test_stop_endpoint_no_handler_no_cli_503():
def test_stop_endpoint_no_handler_no_cli_503(client: TestClient):
"""POST /stop without handler or cli_manager returns 503."""
# Ensure no handler or cli_manager on app state
if hasattr(app.state, "message_handler"):

View file

@ -7,6 +7,23 @@ import pytest
from fastapi.testclient import TestClient
from config.settings import Settings
from providers.registry import ProviderRegistry
_RUNTIME_EXTRAS = {
"voice_note_enabled": True,
"whisper_model": "base",
"whisper_device": "cpu",
"hf_token": "",
"nvidia_nim_api_key": "",
"claude_cli_bin": "claude",
"uses_process_anthropic_auth_token": lambda: False,
}
def _app_settings(**kwargs):
"""Minimal settings namespace for AppRuntime (matches typed :class:`Settings` fields used)."""
data = {**_RUNTIME_EXTRAS, **kwargs}
return SimpleNamespace(**data)
def test_warn_if_process_auth_token_logs_warning():
@ -45,7 +62,7 @@ def test_create_app_provider_error_handler_returns_anthropic_format():
raise AuthenticationError("bad key")
api_app_mod = importlib.import_module("api.app")
settings = SimpleNamespace(
settings = _app_settings(
messaging_platform="telegram",
telegram_bot_token=None,
allowed_telegram_user_id=None,
@ -59,7 +76,7 @@ def test_create_app_provider_error_handler_returns_anthropic_format():
)
with (
patch.object(api_app_mod, "get_settings", return_value=settings),
patch.object(api_app_mod, "cleanup_provider", new=AsyncMock()),
patch.object(ProviderRegistry, "cleanup", new=AsyncMock()),
):
with TestClient(app) as client:
resp = client.get("/raise_provider")
@ -79,7 +96,7 @@ def test_create_app_general_exception_handler_returns_500():
raise RuntimeError("boom")
api_app_mod = importlib.import_module("api.app")
settings = SimpleNamespace(
settings = _app_settings(
messaging_platform="telegram",
telegram_bot_token=None,
allowed_telegram_user_id=None,
@ -93,7 +110,7 @@ def test_create_app_general_exception_handler_returns_500():
)
with (
patch.object(api_app_mod, "get_settings", return_value=settings),
patch.object(api_app_mod, "cleanup_provider", new=AsyncMock()),
patch.object(ProviderRegistry, "cleanup", new=AsyncMock()),
):
with TestClient(app, raise_server_exceptions=False) as client:
resp = client.get("/raise_general")
@ -111,7 +128,7 @@ def test_app_lifespan_sets_state_and_cleans_up(tmp_path, messaging_enabled):
app = create_app()
settings = SimpleNamespace(
settings = _app_settings(
messaging_platform="telegram",
telegram_bot_token="token" if messaging_enabled else None,
allowed_telegram_user_id="123",
@ -147,10 +164,10 @@ def test_app_lifespan_sets_state_and_cleans_up(tmp_path, messaging_enabled):
api_app_mod = importlib.import_module("api.app")
cleanup_provider = AsyncMock()
registry_cleanup = AsyncMock()
with (
patch.object(api_app_mod, "get_settings", return_value=settings),
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
patch.object(ProviderRegistry, "cleanup", new=registry_cleanup),
patch(
"messaging.platforms.factory.create_messaging_platform",
return_value=fake_platform if messaging_enabled else None,
@ -182,7 +199,7 @@ def test_app_lifespan_sets_state_and_cleans_up(tmp_path, messaging_enabled):
cli_manager.stop_all.assert_not_awaited()
assert getattr(app.state, "messaging_platform", "missing") is None
cleanup_provider.assert_awaited_once()
registry_cleanup.assert_awaited_once()
def test_app_lifespan_cleanup_continues_if_platform_stop_raises(tmp_path):
@ -190,7 +207,7 @@ def test_app_lifespan_cleanup_continues_if_platform_stop_raises(tmp_path):
app = create_app()
settings = SimpleNamespace(
settings = _app_settings(
messaging_platform="telegram",
telegram_bot_token="token",
allowed_telegram_user_id="123",
@ -218,10 +235,10 @@ def test_app_lifespan_cleanup_continues_if_platform_stop_raises(tmp_path):
cli_manager.stop_all = AsyncMock()
api_app_mod = importlib.import_module("api.app")
cleanup_provider = AsyncMock()
registry_cleanup = AsyncMock()
with (
patch.object(api_app_mod, "get_settings", return_value=settings),
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
patch.object(ProviderRegistry, "cleanup", new=registry_cleanup),
patch(
"messaging.platforms.factory.create_messaging_platform",
return_value=fake_platform,
@ -234,7 +251,7 @@ def test_app_lifespan_cleanup_continues_if_platform_stop_raises(tmp_path):
fake_platform.stop.assert_awaited_once()
cli_manager.stop_all.assert_awaited_once()
cleanup_provider.assert_awaited_once()
registry_cleanup.assert_awaited_once()
def test_app_lifespan_messaging_import_error_no_crash(tmp_path, caplog):
@ -243,7 +260,7 @@ def test_app_lifespan_messaging_import_error_no_crash(tmp_path, caplog):
app = create_app()
settings = SimpleNamespace(
settings = _app_settings(
messaging_platform="telegram",
telegram_bot_token="token",
allowed_telegram_user_id="123",
@ -257,10 +274,10 @@ def test_app_lifespan_messaging_import_error_no_crash(tmp_path, caplog):
)
api_app_mod = importlib.import_module("api.app")
cleanup_provider = AsyncMock()
registry_cleanup = AsyncMock()
with (
patch.object(api_app_mod, "get_settings", return_value=settings),
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
patch.object(ProviderRegistry, "cleanup", new=registry_cleanup),
patch(
"messaging.platforms.factory.create_messaging_platform",
side_effect=ImportError("discord not installed"),
@ -270,7 +287,7 @@ def test_app_lifespan_messaging_import_error_no_crash(tmp_path, caplog):
pass
assert getattr(app.state, "messaging_platform", None) is None
cleanup_provider.assert_awaited_once()
registry_cleanup.assert_awaited_once()
def test_app_lifespan_platform_start_exception_cleanup_still_runs(tmp_path):
@ -279,7 +296,7 @@ def test_app_lifespan_platform_start_exception_cleanup_still_runs(tmp_path):
app = create_app()
settings = SimpleNamespace(
settings = _app_settings(
messaging_platform="telegram",
telegram_bot_token="token",
allowed_telegram_user_id="123",
@ -307,10 +324,10 @@ def test_app_lifespan_platform_start_exception_cleanup_still_runs(tmp_path):
cli_manager.stop_all = AsyncMock()
api_app_mod = importlib.import_module("api.app")
cleanup_provider = AsyncMock()
registry_cleanup = AsyncMock()
with (
patch.object(api_app_mod, "get_settings", return_value=settings),
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
patch.object(ProviderRegistry, "cleanup", new=registry_cleanup),
patch(
"messaging.platforms.factory.create_messaging_platform",
return_value=fake_platform,
@ -321,7 +338,7 @@ def test_app_lifespan_platform_start_exception_cleanup_still_runs(tmp_path):
):
pass
cleanup_provider.assert_awaited_once()
registry_cleanup.assert_awaited_once()
def test_app_lifespan_flush_pending_save_exception_warning_only(tmp_path):
@ -330,7 +347,7 @@ def test_app_lifespan_flush_pending_save_exception_warning_only(tmp_path):
app = create_app()
settings = SimpleNamespace(
settings = _app_settings(
messaging_platform="telegram",
telegram_bot_token="token",
allowed_telegram_user_id="123",
@ -359,10 +376,10 @@ def test_app_lifespan_flush_pending_save_exception_warning_only(tmp_path):
cli_manager.stop_all = AsyncMock()
api_app_mod = importlib.import_module("api.app")
cleanup_provider = AsyncMock()
registry_cleanup = AsyncMock()
with (
patch.object(api_app_mod, "get_settings", return_value=settings),
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
patch.object(ProviderRegistry, "cleanup", new=registry_cleanup),
patch(
"messaging.platforms.factory.create_messaging_platform",
return_value=fake_platform,
@ -374,4 +391,4 @@ def test_app_lifespan_flush_pending_save_exception_warning_only(tmp_path):
pass
session_store.flush_pending_save.assert_called_once()
cleanup_provider.assert_awaited_once()
registry_cleanup.assert_awaited_once()

View file

@ -1,19 +1,26 @@
from types import SimpleNamespace
from typing import cast
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException
from starlette.applications import Starlette
from starlette.datastructures import State
from api.dependencies import (
cleanup_provider,
get_provider,
get_provider_for_type,
get_settings,
resolve_provider,
)
from config.nim import NimSettings
from providers.deepseek import DeepSeekProvider
from providers.exceptions import UnknownProviderTypeError
from providers.lmstudio import LMStudioProvider
from providers.nvidia_nim import NvidiaNimProvider
from providers.open_router import OpenRouterProvider
from providers.registry import ProviderRegistry
def _make_mock_settings(**overrides):
@ -304,11 +311,11 @@ async def test_get_provider_deepseek_missing_api_key():
@pytest.mark.asyncio
async def test_get_provider_unknown_type():
"""Test that unknown provider_type raises ValueError."""
"""Unknown ``provider_type`` raises :exc:`~providers.exceptions.UnknownProviderTypeError`."""
with patch("api.dependencies.get_settings") as mock_settings:
mock_settings.return_value = _make_mock_settings(provider_type="unknown")
with pytest.raises(ValueError, match="Unknown provider_type"):
with pytest.raises(UnknownProviderTypeError, match="Unknown provider_type"):
get_provider()
@ -390,3 +397,55 @@ async def test_cleanup_provider_cleans_all():
nim._client.aclose.assert_called_once()
lmstudio._client.aclose.assert_called_once()
def test_resolve_provider_per_app_uses_separate_registries() -> None:
    """With app set, each app gets its own provider cache (not process _providers)."""

    def _app_with_registry():
        # Fresh app shell carrying its own, empty provider registry.
        shell = SimpleNamespace(state=State())
        shell.state.provider_registry = ProviderRegistry()
        return shell

    with patch("api.dependencies.get_settings") as mock_settings:
        mock_settings.return_value = _make_mock_settings()
        settings = _make_mock_settings()
        first_app = _app_with_registry()
        second_app = _app_with_registry()
        first = resolve_provider(
            "nvidia_nim", app=cast(Starlette, first_app), settings=settings
        )
        second = resolve_provider(
            "nvidia_nim", app=cast(Starlette, second_app), settings=settings
        )
        assert isinstance(first, NvidiaNimProvider)
        assert isinstance(second, NvidiaNimProvider)
        # Separate registries must yield separate provider instances.
        assert first is not second
def test_resolve_provider_lazily_installs_registry() -> None:
    """If app has no provider_registry, one is created on app.state."""
    with patch("api.dependencies.get_settings") as mock_settings:
        mock_settings.return_value = _make_mock_settings()
        settings = _make_mock_settings()
        bare_app = SimpleNamespace(state=State())
        # Precondition: no registry installed yet.
        assert getattr(bare_app.state, "provider_registry", None) is None
        resolve_provider("nvidia_nim", app=cast(Starlette, bare_app), settings=settings)
        registry = bare_app.state.provider_registry
        assert registry is not None
        again = resolve_provider(
            "nvidia_nim", app=cast(Starlette, bare_app), settings=settings
        )
        # Second resolve reuses the lazily installed registry's cached provider.
        assert again is registry.get("nvidia_nim", settings)
def test_resolve_provider_unrelated_value_error_is_not_unknown_provider_log() -> None:
    """Only :exc:`~providers.exceptions.UnknownProviderTypeError` logs unknown provider."""
    # Local import so patch.object can target the module's own attributes.
    import api.dependencies as deps

    with (
        patch.object(deps, "get_settings", return_value=_make_mock_settings()),
        # Simulate an unrelated configuration failure from the registry.
        patch.object(
            ProviderRegistry,
            "get",
            side_effect=ValueError("unrelated config"),
        ),
        patch.object(deps.logger, "error") as log_err,
        # The ValueError must propagate unchanged...
        pytest.raises(ValueError, match="unrelated config"),
    ):
        deps.resolve_provider("nvidia_nim", app=None, settings=_make_mock_settings())

    # ...without being logged as an unknown-provider error.
    log_err.assert_not_called()

View file

@ -91,6 +91,21 @@ def test_messages_request_accepts_adaptive_thinking_type():
assert dumped["thinking"]["type"] == "adaptive"
def test_messages_request_accepts_anthropic_server_tool_without_input_schema():
    """A type-only Anthropic server tool (no input_schema) validates and round-trips."""
    payload = {
        "model": "claude-opus-4-7",
        "max_tokens": 100,
        "messages": [{"role": "user", "content": "search"}],
        "tools": [{"type": "web_search_20250305", "name": "web_search"}],
    }
    request = MessagesRequest.model_validate(payload)
    dumped = request.model_dump(exclude_none=True)
    assert dumped["tools"] == [{"name": "web_search", "type": "web_search_20250305"}]
def test_messages_request_accepts_redacted_thinking_blocks():
request = MessagesRequest.model_validate(
{

View file

@ -0,0 +1,96 @@
import json
import pytest
from api.models.anthropic import Message, MessagesRequest, Tool
from api.web_server_tools import (
is_web_server_tool_request,
stream_web_server_tool_response,
)
def _event_data(event: str) -> dict:
data_line = next(line for line in event.splitlines() if line.startswith("data: "))
return json.loads(data_line.removeprefix("data: "))
def test_detects_web_search_server_tool_request():
    """A request carrying the Anthropic web_search server tool is detected."""
    server_tool = Tool(name="web_search", type="web_search_20250305")
    request = MessagesRequest(
        model="claude-haiku-4-5-20251001",
        max_tokens=100,
        messages=[Message(role="user", content="search")],
        tools=[server_tool],
    )
    assert is_web_server_tool_request(request)
@pytest.mark.asyncio
async def test_streams_web_search_server_tool_result(monkeypatch):
    """Local web_search server tool emits Anthropic-shaped SSE blocks and usage."""

    async def fake_search(query: str) -> list[dict[str, str]]:
        # The handler must forward the extracted query verbatim.
        assert query == "DeepSeek V4 model release 2026"
        return [{"title": "DeepSeek V4 Released", "url": "https://example.com/v4"}]

    # Stub the network-touching helper so only SSE assembly is exercised.
    monkeypatch.setattr("api.web_server_tools._run_web_search", fake_search)
    request = MessagesRequest(
        model="claude-haiku-4-5-20251001",
        max_tokens=100,
        messages=[
            Message(
                role="user",
                content=(
                    "Perform a web search for the query: DeepSeek V4 model release 2026"
                ),
            )
        ],
        tools=[Tool(name="web_search", type="web_search_20250305")],
        tool_choice={"type": "tool", "name": "web_search"},
    )
    events = [
        event
        async for event in stream_web_server_tool_response(request, input_tokens=42)
    ]
    payloads = [_event_data(event) for event in events]
    # Positional expectations: payloads[1] opens the server_tool_use block,
    # payloads[3] opens the web_search_tool_result block, and the
    # second-to-last payload carries the server_tool_use usage counters.
    assert payloads[1]["content_block"]["type"] == "server_tool_use"
    assert payloads[1]["content_block"]["name"] == "web_search"
    assert payloads[3]["content_block"]["type"] == "web_search_tool_result"
    assert payloads[3]["content_block"]["content"][0]["url"] == "https://example.com/v4"
    assert payloads[-2]["usage"]["server_tool_use"] == {"web_search_requests": 1}
@pytest.mark.asyncio
async def test_streams_web_fetch_server_tool_result(monkeypatch):
    """Local web_fetch server tool emits a fetch-result block and usage counter."""

    async def fake_fetch(url: str) -> dict[str, str]:
        # The handler must forward the URL extracted from the message verbatim.
        assert url == "https://example.com/article"
        return {
            "url": url,
            "title": "Example Article",
            "media_type": "text/plain",
            "data": "Article body",
        }

    # Stub the network-touching helper so only SSE assembly is exercised.
    monkeypatch.setattr("api.web_server_tools._run_web_fetch", fake_fetch)
    request = MessagesRequest(
        model="claude-haiku-4-5-20251001",
        max_tokens=100,
        messages=[
            Message(role="user", content="Fetch https://example.com/article please")
        ],
        tools=[Tool(name="web_fetch", type="web_fetch_20250910")],
        tool_choice={"type": "tool", "name": "web_fetch"},
    )
    events = [
        event
        async for event in stream_web_server_tool_response(request, input_tokens=42)
    ]
    payloads = [_event_data(event) for event in events]
    # payloads[1] opens the server_tool_use block; payloads[3] opens the
    # web_fetch_tool_result with a nested document; the second-to-last
    # payload carries the server_tool_use usage counters.
    assert payloads[1]["content_block"]["type"] == "server_tool_use"
    assert payloads[3]["content_block"]["type"] == "web_fetch_tool_result"
    assert payloads[3]["content_block"]["content"]["content"]["title"] == (
        "Example Article"
    )
    assert payloads[-2]["usage"]["server_tool_use"] == {"web_fetch_requests": 1}

View file

@ -118,6 +118,15 @@ def mock_platform():
platform.queue_edit_message = AsyncMock()
platform.queue_delete_message = AsyncMock()
async def _queue_delete_messages(
    chat_id: str, message_ids: list[str], *, fire_and_forget: bool = True
) -> None:
    # Batch helper: fans each id out to the single-message mock so existing
    # per-message call assertions keep working. `platform` is the mock being
    # assembled by the enclosing fixture (out of view here).
    qdm = platform.queue_delete_message
    for mid in message_ids:
        await qdm(chat_id, mid, fire_and_forget=fire_and_forget)

platform.queue_delete_messages = AsyncMock(side_effect=_queue_delete_messages)
def _fire_and_forget(task):
if asyncio.iscoroutine(task):
# Create a task to avoid "coroutine was never awaited" warning

View file

@ -1,8 +1,20 @@
"""Package import contract tests (static AST; dynamic ``importlib`` loads are not scanned)."""
from __future__ import annotations
import ast
from pathlib import Path
# `api` may only import this narrow ``providers`` surface (AGENTS/PLAN).
# Narrow facade: the package root plus base/exceptions/registry only;
# per-adapter subpackages stay off-limits to the HTTP layer. frozenset so
# the allowed surface cannot be widened at runtime by accident.
_API_ALLOWED_PROVIDER_MODULES = frozenset(
    {
        "providers",
        "providers.base",
        "providers.exceptions",
        "providers.registry",
    }
)
def test_api_and_messaging_do_not_import_provider_common() -> None:
repo_root = Path(__file__).resolve().parents[2]
@ -25,6 +37,66 @@ def test_provider_adapters_do_not_import_runtime_layers() -> None:
assert offenders == []
def test_core_does_not_import_product_packages() -> None:
    """Neutral ``core`` must stay independent of API, workers, and providers."""
    repo_root = Path(__file__).resolve().parents[2]
    forbidden = (
        "api.",
        "messaging.",
        "cli.",
        "smoke.",
        "providers.",
        "config.",
    )
    offenders = _imports_matching([repo_root / "core"], forbidden_prefixes=forbidden)
    assert offenders == []
def test_config_does_not_import_non_config_packages() -> None:
    """Settings and env handling must not depend on transport or protocol layers."""
    repo_root = Path(__file__).resolve().parents[2]
    forbidden = (
        "api.",
        "messaging.",
        "cli.",
        "smoke.",
        "providers.",
        "core.",
    )
    offenders = _imports_matching([repo_root / "config"], forbidden_prefixes=forbidden)
    assert offenders == []
def test_messaging_does_not_import_api_or_cli_or_providers() -> None:
    """Messaging is wired by ``api.runtime``; must not import server or provider adapters."""
    repo_root = Path(__file__).resolve().parents[2]
    messaging_root = repo_root / "messaging"
    offenders = _imports_matching(
        [messaging_root],
        forbidden_prefixes=("api.", "cli.", "providers.", "smoke."),
    )
    assert offenders == []
def test_api_may_only_import_narrow_provider_facade() -> None:
    """HTTP layer must not depend on per-adapter provider subpackages."""
    repo_root = Path(__file__).resolve().parents[2]
    offenders: list[str] = []
    for path in (repo_root / "api").rglob("*.py"):
        rel = path.relative_to(repo_root)
        for imported in _imports_from(path, repo_root):
            if imported is None:
                continue
            # Only dotted providers.* names can be offenders; the bare
            # "providers" root is always allowed.
            if not imported.startswith("providers."):
                continue
            if imported not in _API_ALLOWED_PROVIDER_MODULES:
                offenders.append(f"{rel}: {imported}")
    assert sorted(offenders) == []
def test_removed_openrouter_rollback_transport_stays_removed() -> None:
repo_root = Path(__file__).resolve().parents[2]
@ -35,6 +107,11 @@ def test_removed_openrouter_rollback_transport_stays_removed() -> None:
def test_architecture_doc_names_enforced_boundaries() -> None:
repo_root = Path(__file__).resolve().parents[2]
contract_test = repo_root / "tests" / "contracts" / "test_import_boundaries.py"
assert contract_test.is_file()
stream_contracts = repo_root / "core" / "anthropic" / "stream_contracts.py"
assert stream_contracts.is_file()
text = (repo_root / "PLAN.md").read_text(encoding="utf-8")
assert "core/anthropic/" in text
@ -46,26 +123,89 @@ def _imports_matching(
roots: list[Path], *, forbidden_prefixes: tuple[str, ...]
) -> list[str]:
offenders: list[str] = []
repo_root = roots[0].parent
for root in roots:
for path in root.rglob("*.py"):
rel = path.relative_to(root.parent)
offenders.extend(
f"{rel}: {imported}"
for imported in _imports_from(path)
if imported in forbidden_prefixes
or imported.startswith(forbidden_prefixes)
for imported in _imports_from(path, repo_root)
if imported is not None and _is_forbidden(imported, forbidden_prefixes)
)
return sorted(offenders)
def _imports_from(path: Path) -> list[str]:
def _is_forbidden(name: str, forbidden: tuple[str, ...]) -> bool:
"""Match root modules (``import api``) and submodules (``import api.x``)."""
for token in forbidden:
if not token:
continue
root = token.rstrip(".")
if name == root or name.startswith(f"{root}."):
return True
return False
def _module_fqn_from_path(repo_root: Path, path: Path) -> str:
rel = path.relative_to(repo_root)
if rel.name == "__init__.py":
return ".".join(rel.parent.parts) if rel.parent != Path() else rel.parent.name
return ".".join(rel.with_suffix("").parts)
def _importing_package_parts(repo_root: Path, path: Path) -> list[str]:
    """Package in which this file's module lives (for relative imports)."""
    rel = path.relative_to(repo_root)
    if rel.name == "__init__.py":
        # A package's __init__ belongs to the package itself.
        return list(rel.parent.parts)
    parts = _module_fqn_from_path(repo_root, path).split(".")
    # Top-level modules have no enclosing package.
    return parts[:-1] if len(parts) > 1 else []
def _resolve_relative_import(
    repo_root: Path, path: Path, node: ast.ImportFrom
) -> str | None:
    """Best-effort absolute name for ``from .x`` / ``from ..y`` (level >= 1)."""
    if node.level == 0 and node.module:
        return node.module
    base = _importing_package_parts(repo_root, path)
    # Each dot beyond the first climbs one package level.
    hops = node.level - 1
    while hops > 0:
        if not base:
            return None  # climbed past the repo root; unresolvable
        base.pop()
        hops -= 1
    if node.module:
        return ".".join([*base, *node.module.split(".")])
    return ".".join(base) if base else None
def _imports_from(path: Path, repo_root: Path) -> list[str]:
tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
imports: list[str] = []
for node in ast.walk(tree):
if isinstance(node, ast.Import):
imports.extend(alias.name for alias in node.names)
elif isinstance(node, ast.ImportFrom) and node.module:
elif isinstance(node, ast.ImportFrom):
if node.level == 0:
if node.module:
imports.append(node.module)
continue
if node.module is not None:
resolved = _resolve_relative_import(repo_root, path, node)
if resolved:
imports.append(resolved)
else:
base = _importing_package_parts(repo_root, path).copy()
for _ in range(node.level - 1):
if base:
base.pop()
for alias in node.names:
if base:
imports.append(".".join([*base, alias.name]))
else:
imports.append(alias.name)
return imports

View file

@ -0,0 +1,11 @@
"""Ensure smoke re-exports stay aligned with :mod:`core.anthropic.stream_contracts`."""
from __future__ import annotations
import core.anthropic.stream_contracts as core_sc
import smoke.lib.sse as smoke_sse
def test_smoke_lib_sse_reexports_core_stream_contracts() -> None:
    """Every public smoke SSE name must be the exact core object (no copies)."""
    for name in smoke_sse.__all__:
        core_obj = getattr(core_sc, name)
        assert getattr(smoke_sse, name) is core_obj

View file

@ -1,11 +1,14 @@
"""Stream/SSE contract tests. Strict transcript *ordering* is covered here for
``SSEBuilder`` output; for transport-integrated ordering, add messaging or API
integration tests.
"""
from __future__ import annotations
from collections.abc import Iterable
from core.anthropic import ContentType, HeuristicToolParser, SSEBuilder, ThinkTagParser
from messaging.event_parser import parse_cli_event
from messaging.transcript import RenderCtx, TranscriptBuffer
from smoke.lib.sse import (
from core.anthropic.stream_contracts import (
assert_anthropic_stream_contract,
event_names,
has_tool_use,
@ -13,6 +16,8 @@ from smoke.lib.sse import (
text_content,
thinking_content,
)
from messaging.event_parser import parse_cli_event
from messaging.transcript import RenderCtx, TranscriptBuffer
def test_interleaved_thinking_text_blocks_are_valid() -> None:

View file

@ -2,7 +2,10 @@
from unittest.mock import MagicMock, patch
from messaging.platforms.factory import create_messaging_platform
from messaging.platforms.factory import (
MessagingPlatformOptions,
create_messaging_platform,
)
class TestCreateMessagingPlatform:
@ -16,15 +19,29 @@ class TestCreateMessagingPlatform:
patch(
"messaging.platforms.telegram.TelegramPlatform",
return_value=mock_platform,
),
) as platform_cls,
):
result = create_messaging_platform(
"telegram",
bot_token="test_token",
allowed_user_id="12345",
MessagingPlatformOptions(
telegram_bot_token="test_token",
allowed_telegram_user_id="12345",
voice_note_enabled=False,
whisper_model="large-v3",
whisper_device="cuda",
),
)
assert result is mock_platform
platform_cls.assert_called_once_with(
bot_token="test_token",
allowed_user_id="12345",
voice_note_enabled=False,
whisper_model="large-v3",
whisper_device="cuda",
hf_token="",
nvidia_nim_api_key="",
)
def test_telegram_without_token(self):
"""Return None when no bot_token for Telegram."""
@ -33,7 +50,9 @@ class TestCreateMessagingPlatform:
def test_telegram_empty_token(self):
"""Return None when bot_token is empty string."""
result = create_messaging_platform("telegram", bot_token="")
result = create_messaging_platform(
"telegram", MessagingPlatformOptions(telegram_bot_token="")
)
assert result is None
def test_discord_with_token(self):
@ -44,15 +63,29 @@ class TestCreateMessagingPlatform:
patch(
"messaging.platforms.discord.DiscordPlatform",
return_value=mock_platform,
),
) as platform_cls,
):
result = create_messaging_platform(
"discord",
MessagingPlatformOptions(
discord_bot_token="test_token",
allowed_discord_channels="123,456",
voice_note_enabled=False,
whisper_model="small",
whisper_device="nvidia_nim",
),
)
assert result is mock_platform
platform_cls.assert_called_once_with(
bot_token="test_token",
allowed_channel_ids="123,456",
voice_note_enabled=False,
whisper_model="small",
whisper_device="nvidia_nim",
hf_token="",
nvidia_nim_api_key="",
)
def test_discord_without_token(self):
"""Return None when no discord_bot_token for Discord."""
@ -62,7 +95,11 @@ class TestCreateMessagingPlatform:
def test_discord_empty_token(self):
"""Return None when discord_bot_token is empty string."""
result = create_messaging_platform(
"discord", discord_bot_token="", allowed_discord_channels="123"
"discord",
MessagingPlatformOptions(
discord_bot_token="",
allowed_discord_channels="123",
),
)
assert result is None
@ -73,5 +110,7 @@ class TestCreateMessagingPlatform:
def test_unknown_platform_with_kwargs(self):
"""Return None for unknown platform even with kwargs."""
result = create_messaging_platform("slack", bot_token="token")
result = create_messaging_platform(
"slack", MessagingPlatformOptions(telegram_bot_token="token")
)
assert result is None

View file

@ -17,18 +17,20 @@ def telegram_platform():
@pytest.mark.asyncio
async def test_telegram_voice_disabled_sends_reply(telegram_platform):
async def test_telegram_voice_disabled_sends_reply():
"""When voice_note_enabled is False, reply with disabled message."""
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
telegram_platform = TelegramPlatform(
bot_token="test_token",
allowed_user_id="12345",
voice_note_enabled=False,
)
mock_update = MagicMock()
mock_update.message.voice = MagicMock(file_id="f1", mime_type="audio/ogg")
mock_update.effective_user.id = 12345
mock_update.effective_chat.id = 6789
mock_update.message.reply_text = AsyncMock()
with patch(
"config.settings.get_settings",
return_value=MagicMock(voice_note_enabled=False),
):
await telegram_platform._on_telegram_voice(mock_update, MagicMock())
mock_update.message.reply_text.assert_called_once_with("Voice notes are disabled.")
@ -42,10 +44,6 @@ async def test_telegram_voice_unauthorized_ignored(telegram_platform):
mock_update.effective_user.id = 99999 # Not 12345
mock_update.message.reply_text = AsyncMock()
with patch(
"config.settings.get_settings",
return_value=MagicMock(voice_note_enabled=True),
):
await telegram_platform._on_telegram_voice(mock_update, MagicMock())
mock_update.message.reply_text.assert_not_called()
@ -82,17 +80,8 @@ async def test_telegram_voice_success_invokes_handler(telegram_platform):
mock_file.download_to_drive = fake_download
mock_settings = MagicMock(
voice_note_enabled=True,
whisper_model="base",
)
mock_queue_send = AsyncMock(return_value="999")
with (
patch(
"config.settings.get_settings",
return_value=mock_settings,
),
patch(
"messaging.transcription.transcribe_audio",
return_value="Hello from voice",
@ -164,7 +153,11 @@ class TestDiscordGetAudioAttachment:
@pytest.mark.asyncio
async def test_discord_voice_disabled_sends_reply():
"""When voice_note_enabled is False, reply with disabled message."""
platform = DiscordPlatform(bot_token="token", allowed_channel_ids="123")
platform = DiscordPlatform(
bot_token="token",
allowed_channel_ids="123",
voice_note_enabled=False,
)
platform._message_handler = None
mock_message = MagicMock()
@ -178,10 +171,6 @@ async def test_discord_voice_disabled_sends_reply():
mock_att.filename = "voice.ogg"
mock_message.attachments = [mock_att]
with patch(
"config.settings.get_settings",
return_value=MagicMock(voice_note_enabled=False),
):
await platform._on_discord_message(mock_message)
mock_message.reply.assert_called_once_with("Voice notes are disabled.")

View file

@ -24,7 +24,7 @@ class MockBlock:
class MockTool:
def __init__(self, name, description, input_schema):
def __init__(self, name, description, input_schema=None):
self.name = name
self.description = description
self.input_schema = input_schema
@ -79,6 +79,23 @@ def test_convert_tools():
assert result[1]["function"]["description"] == "" # Check default empty string
def test_convert_tool_without_input_schema_uses_empty_object_schema():
tools = [MockTool("web_search", None)]
result = AnthropicToOpenAIConverter.convert_tools(tools)
assert result == [
{
"type": "function",
"function": {
"name": "web_search",
"description": "",
"parameters": {"type": "object", "properties": {}},
},
}
]
@pytest.mark.parametrize(
"tool_choice,expected",
[

View file

@ -449,6 +449,40 @@ def test_heuristic_tool_parser_flush_no_tool():
assert tools == []
def test_heuristic_tool_parser_json_style_web_fetch_tool_call():
parser = HeuristicToolParser()
text = (
"Use WebFetch on the article.\n\n"
"{\n"
' "url": "https://example.com/article",\n'
' "prompt": "Summarize it."\n'
"}\n"
)
filtered, tools = parser.feed(text)
tools.extend(parser.flush())
assert filtered == ""
assert len(tools) == 1
assert tools[0]["name"] == "WebFetch"
assert tools[0]["input"] == {
"url": "https://example.com/article",
"prompt": "Summarize it.",
}
def test_heuristic_tool_parser_json_style_web_search_tool_call():
parser = HeuristicToolParser()
filtered, tools = parser.feed('Use WebSearch {"query": "DeepSeek V4"}')
tools.extend(parser.flush())
assert filtered == ""
assert len(tools) == 1
assert tools[0]["name"] == "WebSearch"
assert tools[0]["input"] == {"query": "DeepSeek V4"}
def test_heuristic_tool_parser_unicode_function_name():
"""Unicode characters in function parameters."""
parser = HeuristicToolParser()

View file

@ -1,9 +1,13 @@
from unittest.mock import MagicMock, patch
import subprocess
import sys
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from config.nim import NimSettings
from config.provider_ids import SUPPORTED_PROVIDER_IDS
from providers.deepseek import DeepSeekProvider
from providers.exceptions import UnknownProviderTypeError
from providers.llamacpp import LlamaCppProvider
from providers.lmstudio import LMStudioProvider
from providers.nvidia_nim import NvidiaNimProvider
@ -41,14 +45,24 @@ def _make_settings(**overrides):
return mock
def test_importing_registry_does_not_eager_load_other_adapters() -> None:
"""Registry metadata must not import every provider adapter up front."""
code = (
"import sys\n"
"import providers.registry\n"
"assert 'providers.open_router' not in sys.modules\n"
)
proc = subprocess.run(
[sys.executable, "-c", code],
check=False,
capture_output=True,
text=True,
)
assert proc.returncode == 0, proc.stderr or proc.stdout
def test_descriptors_cover_advertised_provider_ids():
assert set(PROVIDER_DESCRIPTORS) == {
"nvidia_nim",
"open_router",
"deepseek",
"lmstudio",
"llamacpp",
}
assert set(PROVIDER_DESCRIPTORS) == set(SUPPORTED_PROVIDER_IDS)
for descriptor in PROVIDER_DESCRIPTORS.values():
assert descriptor.provider_id
assert descriptor.transport_type in {"openai_chat", "anthropic_messages"}
@ -90,6 +104,38 @@ def test_provider_registry_caches_by_provider_id():
assert first is second
def test_unknown_provider_raises_value_error():
with pytest.raises(ValueError, match="Unknown provider_type"):
def test_unknown_provider_raises_unknown_provider_type_error():
with pytest.raises(UnknownProviderTypeError, match="Unknown provider_type"):
create_provider("unknown", _make_settings())
@pytest.mark.asyncio
async def test_provider_registry_cleanup_runs_all_even_if_one_fails() -> None:
"""Every provider gets cleanup; cache is cleared even when one raises."""
reg = ProviderRegistry()
p1 = MagicMock()
p1.cleanup = AsyncMock(side_effect=RuntimeError("first"))
p2 = MagicMock()
p2.cleanup = AsyncMock()
reg._providers["a"] = p1
reg._providers["b"] = p2
with pytest.raises(RuntimeError, match="first"):
await reg.cleanup()
p1.cleanup.assert_awaited_once()
p2.cleanup.assert_awaited_once()
assert reg._providers == {}
@pytest.mark.asyncio
async def test_provider_registry_cleanup_exceptiongroup_on_multiple_failures() -> None:
reg = ProviderRegistry()
p1 = MagicMock()
p1.cleanup = AsyncMock(side_effect=RuntimeError("a"))
p2 = MagicMock()
p2.cleanup = AsyncMock(side_effect=RuntimeError("b"))
reg._providers["x"] = p1
reg._providers["y"] = p2
with pytest.raises(ExceptionGroup) as exc_info:
await reg.cleanup()
assert len(exc_info.value.exceptions) == 2
assert reg._providers == {}