mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-04-28 03:20:01 +00:00
feat: Anthropic web server tools, provider metadata, messaging hardening
- Add local web_search/web_fetch SSE handling and optional tool schemas - Extend HeuristicToolParser for JSON-style WebFetch/WebSearch text - Consolidate provider defaults, ids, and exception typing; stream contracts - Messaging: typed options, voice config injection, platform contract cleanup - Tests for web server tools, converters, parsers, contracts; ignore debug-*.log
This commit is contained in:
parent
4b89183ba0
commit
b926f60f64
50 changed files with 1658 additions and 439 deletions
|
|
@ -66,6 +66,7 @@ VOICE_NOTE_ENABLED=false
|
||||||
# WHISPER_DEVICE: "cpu" | "cuda" | "nvidia_nim"
|
# WHISPER_DEVICE: "cpu" | "cuda" | "nvidia_nim"
|
||||||
# - "cpu"/"cuda": Hugging Face transformers Whisper (offline, free; install with: uv sync --extra voice_local)
|
# - "cpu"/"cuda": Hugging Face transformers Whisper (offline, free; install with: uv sync --extra voice_local)
|
||||||
# - "nvidia_nim": NVIDIA NIM Whisper via Riva gRPC (requires NVIDIA_NIM_API_KEY; install with: uv sync --extra voice)
|
# - "nvidia_nim": NVIDIA NIM Whisper via Riva gRPC (requires NVIDIA_NIM_API_KEY; install with: uv sync --extra voice)
|
||||||
|
# (Independent of MODEL=nvidia_nim/...: that selects the *chat* provider; this selects voice STT only.)
|
||||||
WHISPER_DEVICE="nvidia_nim"
|
WHISPER_DEVICE="nvidia_nim"
|
||||||
# WHISPER_MODEL:
|
# WHISPER_MODEL:
|
||||||
# - For cpu/cuda: Hugging Face ID or short name (tiny, base, small, medium, large-v2, large-v3, large-v3-turbo)
|
# - For cpu/cuda: Hugging Face ID or short name (tiny, base, small, medium, large-v2, large-v3, large-v3-turbo)
|
||||||
|
|
|
||||||
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -8,6 +8,7 @@ __pycache__
|
||||||
agent_workspace
|
agent_workspace
|
||||||
.env
|
.env
|
||||||
server.log
|
server.log
|
||||||
|
debug-*.log
|
||||||
.coverage
|
.coverage
|
||||||
llama_cache
|
llama_cache
|
||||||
.smoke-results
|
.smoke-results
|
||||||
|
|
|
||||||
27
PLAN.md
27
PLAN.md
|
|
@ -33,20 +33,43 @@ flowchart TD
|
||||||
core --> providers
|
core --> providers
|
||||||
core --> messaging
|
core --> messaging
|
||||||
providers --> api
|
providers --> api
|
||||||
|
api --> cli[cli]
|
||||||
|
api --> messaging
|
||||||
cli --> messaging
|
cli --> messaging
|
||||||
messaging --> api
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Runtime note: `api.runtime` imports `cli` and `messaging` to wire the optional
|
||||||
|
messaging stack; `messaging` does not import `cli` (session/CLI access is passed
|
||||||
|
in from `api.runtime`).
|
||||||
|
|
||||||
The practical rule is simpler than the graph: shared protocol helpers belong in
|
The practical rule is simpler than the graph: shared protocol helpers belong in
|
||||||
neutral core modules, not under a provider package. Provider adapters may depend
|
neutral core modules, not under a provider package. Provider adapters may depend
|
||||||
on the neutral protocol layer, but API and messaging code should not import
|
on the neutral protocol layer, but API and messaging code should not import
|
||||||
provider internals.
|
provider internals.
|
||||||
|
|
||||||
|
The diagram above mixes **Python import direction** (e.g. `config` → `providers`)
|
||||||
|
with **runtime composition** (e.g. `api.runtime` constructs `cli` and `messaging`).
|
||||||
|
`PLAN.md` remains the product map; **encoded** rules (including root imports like
|
||||||
|
`import api`, relative imports, and `api` → `providers` facade allowlists) live in
|
||||||
|
`tests/contracts/test_import_boundaries.py`.
|
||||||
|
|
||||||
|
**Contract highlights:** `api/` may import only `providers.base`, `providers.exceptions`,
|
||||||
|
and `providers.registry` from the providers package (not per-adapter modules).
|
||||||
|
`core/` stays free of `api`, `messaging`, `cli`, `providers`, `config`, and `smoke`.
|
||||||
|
`messaging/` does not import `api`, `cli`, or `providers`. Neutral stream contract
|
||||||
|
assertions for default CI live under `core/anthropic/stream_contracts.py`;
|
||||||
|
`smoke.lib.sse` re-exports them for live smoke. Process-cached provider helpers
|
||||||
|
(`api.dependencies.get_provider` / `get_provider_for_type`) exist for scripts and
|
||||||
|
unit tests; production HTTP handlers must use `resolve_provider` with
|
||||||
|
`request.app` so the app-scoped `ProviderRegistry` is used. The `api` package
|
||||||
|
`__all__` exposes HTTP models and `create_app` only (not those helpers).
|
||||||
|
|
||||||
## Target Boundaries
|
## Target Boundaries
|
||||||
|
|
||||||
- `core/anthropic/`: Anthropic protocol helpers, stream primitives, content
|
- `core/anthropic/`: Anthropic protocol helpers, stream primitives, content
|
||||||
extraction, token estimation, user-facing error strings, request conversion,
|
extraction, token estimation, user-facing error strings, request conversion,
|
||||||
thinking, and tool helpers shared across API, providers, messaging, and tests.
|
thinking, tool helpers, and stream contract assertions
|
||||||
|
(`stream_contracts.py`) shared across API, providers, messaging, and tests.
|
||||||
- `api/runtime.py`: application composition, optional messaging startup,
|
- `api/runtime.py`: application composition, optional messaging startup,
|
||||||
session store restoration, and cleanup ownership.
|
session store restoration, and cleanup ownership.
|
||||||
- `providers/`: provider descriptors, credential resolution, transport
|
- `providers/`: provider descriptors, credential resolution, transport
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
"""API layer for Claude Code Proxy."""
|
"""API layer for Claude Code Proxy."""
|
||||||
|
|
||||||
from .app import app, create_app
|
from .app import app, create_app
|
||||||
from .dependencies import get_provider, get_provider_for_type
|
|
||||||
from .models import (
|
from .models import (
|
||||||
MessagesRequest,
|
MessagesRequest,
|
||||||
MessagesResponse,
|
MessagesResponse,
|
||||||
|
|
@ -16,6 +15,4 @@ __all__ = [
|
||||||
"TokenCountResponse",
|
"TokenCountResponse",
|
||||||
"app",
|
"app",
|
||||||
"create_app",
|
"create_app",
|
||||||
"get_provider",
|
|
||||||
"get_provider_for_type",
|
|
||||||
]
|
]
|
||||||
|
|
|
||||||
62
api/app.py
62
api/app.py
|
|
@ -2,8 +2,11 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.exception_handlers import request_validation_exception_handler
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
|
@ -11,7 +14,6 @@ from config.logging_config import configure_logging
|
||||||
from config.settings import get_settings
|
from config.settings import get_settings
|
||||||
from providers.exceptions import ProviderError
|
from providers.exceptions import ProviderError
|
||||||
|
|
||||||
from .dependencies import cleanup_provider
|
|
||||||
from .routes import router
|
from .routes import router
|
||||||
from .runtime import AppRuntime
|
from .runtime import AppRuntime
|
||||||
|
|
||||||
|
|
@ -26,9 +28,7 @@ configure_logging(_settings.log_file)
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
"""Application lifespan manager."""
|
"""Application lifespan manager."""
|
||||||
runtime = AppRuntime.for_app(
|
runtime = AppRuntime.for_app(app, settings=get_settings())
|
||||||
app, settings=get_settings(), provider_cleanup=cleanup_provider
|
|
||||||
)
|
|
||||||
await runtime.startup()
|
await runtime.startup()
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
@ -48,6 +48,60 @@ def create_app() -> FastAPI:
|
||||||
app.include_router(router)
|
app.include_router(router)
|
||||||
|
|
||||||
# Exception handlers
|
# Exception handlers
|
||||||
|
@app.exception_handler(RequestValidationError)
async def validation_error_handler(request: Request, exc: RequestValidationError):
    """Log the shape of a request that failed validation, then return the default 422.

    Only structural facts are logged (roles, content kinds, block types and
    keys, string lengths, tool names) -- never message content values.

    Args:
        request: The incoming request whose body failed validation.
        exc: The validation error raised for that request.

    Returns:
        The response built by FastAPI's ``request_validation_exception_handler``.
    """
    body: Any
    try:
        body = await request.json()
    except Exception as e:
        # The body may not be valid JSON at all; keep only the failure kind.
        body = {"_json_error": type(e).__name__}

    # A JSON body can be any JSON value (list, string, number...), so every
    # lookup below is guarded before ``.get`` is used.
    body_dict = body if isinstance(body, dict) else None
    messages = body_dict.get("messages") if body_dict is not None else None

    message_summary: list[dict[str, Any]] = []
    if isinstance(messages, list):
        for msg in messages:
            if not isinstance(msg, dict):
                message_summary.append({"message_kind": type(msg).__name__})
                continue
            content = msg.get("content")
            item: dict[str, Any] = {
                "role": msg.get("role"),
                "content_kind": type(content).__name__,
            }
            if isinstance(content, list):
                # Cap the number of blocks/keys so the log line stays bounded.
                item["block_types"] = [
                    block.get("type", "dict")
                    if isinstance(block, dict)
                    else type(block).__name__
                    for block in content[:12]
                ]
                item["block_keys"] = [
                    sorted(str(key) for key in block)[:12]
                    for block in content[:5]
                    if isinstance(block, dict)
                ]
            elif isinstance(content, str):
                item["content_length"] = len(content)
            message_summary.append(item)

    # Resolve ``tools`` defensively *before* iterating: the iterable of a
    # comprehension is evaluated before its ``if`` filters, so a guard placed
    # inside the comprehension cannot protect the ``.get`` call or a non-list
    # ``tools`` value.
    tools = body_dict.get("tools") if body_dict is not None else None
    tool_names: list[str] = (
        [str(tool.get("name", "")) for tool in tools if isinstance(tool, dict)]
        if isinstance(tools, list)
        else []
    )

    errors = exc.errors()
    logger.debug(
        "Request validation failed: path={} query={} error_locs={} error_types={} message_summary={} tool_names={}",
        request.url.path,
        str(request.url.query),
        [list(error.get("loc", ())) for error in errors],
        [str(error.get("type", "")) for error in errors],
        message_summary,
        tool_names,
    )
    return await request_validation_exception_handler(request, exc)
|
||||||
|
|
||||||
@app.exception_handler(ProviderError)
|
@app.exception_handler(ProviderError)
|
||||||
async def provider_error_handler(request: Request, exc: ProviderError):
|
async def provider_error_handler(request: Request, exc: ProviderError):
|
||||||
"""Handle provider-specific errors and return Anthropic format."""
|
"""Handle provider-specific errors and return Anthropic format."""
|
||||||
|
|
|
||||||
|
|
@ -2,15 +2,18 @@
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException, Request
|
from fastapi import Depends, HTTPException, Request
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from starlette.applications import Starlette
|
||||||
|
|
||||||
from config.settings import Settings
|
from config.settings import Settings
|
||||||
from config.settings import get_settings as _get_settings
|
from config.settings import get_settings as _get_settings
|
||||||
from core.anthropic import get_user_facing_error_message
|
from core.anthropic import get_user_facing_error_message
|
||||||
from providers.base import BaseProvider
|
from providers.base import BaseProvider
|
||||||
from providers.exceptions import AuthenticationError
|
from providers.exceptions import AuthenticationError, UnknownProviderTypeError
|
||||||
from providers.registry import PROVIDER_DESCRIPTORS, ProviderRegistry
|
from providers.registry import PROVIDER_DESCRIPTORS, ProviderRegistry
|
||||||
|
|
||||||
# Provider registry: keyed by provider type string, lazily populated
|
# Process-level cache: only for :func:`get_provider_for_type` / :func:`get_provider`
|
||||||
|
# when there is no ``Request``/``app`` (unit tests, scripts). HTTP handlers must pass
|
||||||
|
# ``app`` to :func:`resolve_provider` so the app-scoped registry is used.
|
||||||
_providers: dict[str, BaseProvider] = {}
|
_providers: dict[str, BaseProvider] = {}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -19,19 +22,43 @@ def get_settings() -> Settings:
|
||||||
return _get_settings()
|
return _get_settings()
|
||||||
|
|
||||||
|
|
||||||
def get_provider_for_type(provider_type: str) -> BaseProvider:
|
def resolve_provider(
|
||||||
"""Get or create a provider for the given provider type.
|
provider_type: str,
|
||||||
|
*,
|
||||||
|
app: Starlette | None,
|
||||||
|
settings: Settings,
|
||||||
|
) -> BaseProvider:
|
||||||
|
"""Resolve a provider using the app-scoped registry when ``app`` is set.
|
||||||
|
|
||||||
Providers are cached in the registry and reused across requests.
|
When ``app`` is not ``None``, the app-owned :attr:`app.state.provider_registry`
|
||||||
|
is always used. If the registry is missing (e.g. a test app without
|
||||||
|
:class:`~api.runtime.AppRuntime` startup), a new :class:`ProviderRegistry`
|
||||||
|
is installed on ``app.state`` so the process cache is never mixed with
|
||||||
|
per-request app identity.
|
||||||
|
|
||||||
|
When ``app`` is ``None`` (no HTTP context), uses the process-level
|
||||||
|
:data:`_providers` cache only.
|
||||||
"""
|
"""
|
||||||
should_log_init = provider_type not in _providers
|
if app is not None:
|
||||||
|
reg = getattr(app.state, "provider_registry", None)
|
||||||
|
if reg is None:
|
||||||
|
reg = ProviderRegistry()
|
||||||
|
app.state.provider_registry = reg
|
||||||
|
return _resolve_with_registry(reg, provider_type, settings)
|
||||||
|
return _resolve_with_registry(ProviderRegistry(_providers), provider_type, settings)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_with_registry(
|
||||||
|
registry: ProviderRegistry, provider_type: str, settings: Settings
|
||||||
|
) -> BaseProvider:
|
||||||
|
should_log_init = not registry.is_cached(provider_type)
|
||||||
try:
|
try:
|
||||||
provider = ProviderRegistry(_providers).get(provider_type, get_settings())
|
provider = registry.get(provider_type, settings)
|
||||||
except AuthenticationError as e:
|
except AuthenticationError as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=503, detail=get_user_facing_error_message(e)
|
status_code=503, detail=get_user_facing_error_message(e)
|
||||||
) from e
|
) from e
|
||||||
except ValueError:
|
except UnknownProviderTypeError:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Unknown provider_type: '{}'. Supported: {}",
|
"Unknown provider_type: '{}'. Supported: {}",
|
||||||
provider_type,
|
provider_type,
|
||||||
|
|
@ -43,6 +70,15 @@ def get_provider_for_type(provider_type: str) -> BaseProvider:
|
||||||
return provider
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
def get_provider_for_type(provider_type: str) -> BaseProvider:
    """Return the provider for ``provider_type`` from the process-level cache.

    This helper passes ``app=None`` to :func:`resolve_provider`, i.e. it is
    meant for callers without a ``Request``/``app``. Server request handlers
    should call :func:`resolve_provider` with the active :attr:`request.app`
    so the app-scoped provider registry is used instead.
    """
    current_settings = get_settings()
    return resolve_provider(provider_type, app=None, settings=current_settings)
|
||||||
|
|
||||||
|
|
||||||
def require_api_key(
|
def require_api_key(
|
||||||
request: Request, settings: Settings = Depends(get_settings)
|
request: Request, settings: Settings = Depends(get_settings)
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
@ -78,9 +114,11 @@ def require_api_key(
|
||||||
|
|
||||||
|
|
||||||
def get_provider() -> BaseProvider:
|
def get_provider() -> BaseProvider:
|
||||||
"""Get or create the default provider (based on MODEL env var).
|
"""Get or create the default provider (``MODEL`` / ``provider_type``).
|
||||||
|
|
||||||
Backward-compatible convenience for health/root endpoints and tests.
|
Process-cache helper for scripts, unit tests, and non-FastAPI callers. HTTP
|
||||||
|
handlers must use :func:`resolve_provider` with :attr:`request.app` so the
|
||||||
|
app-scoped :class:`~providers.registry.ProviderRegistry` is used.
|
||||||
"""
|
"""
|
||||||
return get_provider_for_type(get_settings().provider_type)
|
return get_provider_for_type(get_settings().provider_type)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -75,8 +75,11 @@ class Message(BaseModel):
|
||||||
|
|
||||||
class Tool(BaseModel):
|
class Tool(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
|
# Anthropic server tools (e.g. web_search beta tools) include a ``type`` and
|
||||||
|
# may omit ``input_schema`` because the provider owns the schema.
|
||||||
|
type: str | None = None
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
input_schema: dict[str, Any]
|
input_schema: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
class ThinkingConfig(BaseModel):
|
class ThinkingConfig(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,8 @@ from loguru import logger
|
||||||
from config.settings import Settings
|
from config.settings import Settings
|
||||||
from core.anthropic import get_token_count
|
from core.anthropic import get_token_count
|
||||||
|
|
||||||
from .dependencies import get_provider_for_type, get_settings, require_api_key
|
from . import dependencies
|
||||||
|
from .dependencies import get_settings, require_api_key
|
||||||
from .models.anthropic import MessagesRequest, TokenCountRequest
|
from .models.anthropic import MessagesRequest, TokenCountRequest
|
||||||
from .models.responses import ModelResponse, ModelsListResponse
|
from .models.responses import ModelResponse, ModelsListResponse
|
||||||
from .services import ClaudeProxyService
|
from .services import ClaudeProxyService
|
||||||
|
|
@ -54,12 +55,15 @@ SUPPORTED_CLAUDE_MODELS = [
|
||||||
|
|
||||||
|
|
||||||
def get_proxy_service(
|
def get_proxy_service(
|
||||||
|
request: Request,
|
||||||
settings: Settings = Depends(get_settings),
|
settings: Settings = Depends(get_settings),
|
||||||
) -> ClaudeProxyService:
|
) -> ClaudeProxyService:
|
||||||
"""Build the request service for route handlers."""
|
"""Build the request service for route handlers."""
|
||||||
return ClaudeProxyService(
|
return ClaudeProxyService(
|
||||||
settings,
|
settings,
|
||||||
provider_getter=get_provider_for_type,
|
provider_getter=lambda provider_type: dependencies.resolve_provider(
|
||||||
|
provider_type, app=request.app, settings=settings
|
||||||
|
),
|
||||||
token_counter=get_token_count,
|
token_counter=get_token_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,16 +4,20 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from collections.abc import Awaitable, Callable
|
from dataclasses import dataclass, field
|
||||||
from dataclasses import dataclass
|
from typing import TYPE_CHECKING, Any
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from config.settings import Settings, get_settings
|
from config.settings import Settings, get_settings
|
||||||
|
from providers.registry import ProviderRegistry
|
||||||
|
|
||||||
from .dependencies import cleanup_provider
|
if TYPE_CHECKING:
|
||||||
|
from cli.manager import CLISessionManager
|
||||||
|
from messaging.handler import ClaudeMessageHandler
|
||||||
|
from messaging.platforms.base import MessagingPlatform
|
||||||
|
from messaging.session import SessionStore
|
||||||
|
|
||||||
_SHUTDOWN_TIMEOUT_S = 5.0
|
_SHUTDOWN_TIMEOUT_S = 5.0
|
||||||
|
|
||||||
|
|
@ -32,8 +36,7 @@ async def best_effort(
|
||||||
|
|
||||||
def warn_if_process_auth_token(settings: Settings) -> None:
|
def warn_if_process_auth_token(settings: Settings) -> None:
|
||||||
"""Warn when server auth was implicitly inherited from the shell."""
|
"""Warn when server auth was implicitly inherited from the shell."""
|
||||||
uses_process_token = getattr(settings, "uses_process_anthropic_auth_token", None)
|
if settings.uses_process_anthropic_auth_token():
|
||||||
if callable(uses_process_token) and uses_process_token():
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"ANTHROPIC_AUTH_TOKEN is set in the process environment but not in "
|
"ANTHROPIC_AUTH_TOKEN is set in the process environment but not in "
|
||||||
"a configured .env file. The proxy will require that token. Add "
|
"a configured .env file. The proxy will require that token. Add "
|
||||||
|
|
@ -48,32 +51,29 @@ class AppRuntime:
|
||||||
|
|
||||||
app: FastAPI
|
app: FastAPI
|
||||||
settings: Settings
|
settings: Settings
|
||||||
provider_cleanup: Callable[[], Awaitable[None]] = cleanup_provider
|
_provider_registry: ProviderRegistry | None = field(default=None, init=False)
|
||||||
messaging_platform: Any = None
|
messaging_platform: MessagingPlatform | None = None
|
||||||
message_handler: Any = None
|
message_handler: ClaudeMessageHandler | None = None
|
||||||
cli_manager: Any = None
|
cli_manager: CLISessionManager | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def for_app(
|
def for_app(
|
||||||
cls,
|
cls,
|
||||||
app: FastAPI,
|
app: FastAPI,
|
||||||
settings: Settings | None = None,
|
settings: Settings | None = None,
|
||||||
provider_cleanup: Callable[[], Awaitable[None]] = cleanup_provider,
|
|
||||||
) -> AppRuntime:
|
) -> AppRuntime:
|
||||||
return cls(
|
return cls(app=app, settings=settings or get_settings())
|
||||||
app=app,
|
|
||||||
settings=settings or get_settings(),
|
|
||||||
provider_cleanup=provider_cleanup,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def startup(self) -> None:
|
async def startup(self) -> None:
|
||||||
logger.info("Starting Claude Code Proxy...")
|
logger.info("Starting Claude Code Proxy...")
|
||||||
|
self._provider_registry = ProviderRegistry()
|
||||||
|
self.app.state.provider_registry = self._provider_registry
|
||||||
warn_if_process_auth_token(self.settings)
|
warn_if_process_auth_token(self.settings)
|
||||||
await self._start_messaging_if_configured()
|
await self._start_messaging_if_configured()
|
||||||
self._publish_state()
|
self._publish_state()
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
if self.message_handler and hasattr(self.message_handler, "session_store"):
|
if self.message_handler is not None:
|
||||||
try:
|
try:
|
||||||
self.message_handler.session_store.flush_pending_save()
|
self.message_handler.session_store.flush_pending_save()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -84,20 +84,33 @@ class AppRuntime:
|
||||||
await best_effort("messaging_platform.stop", self.messaging_platform.stop())
|
await best_effort("messaging_platform.stop", self.messaging_platform.stop())
|
||||||
if self.cli_manager:
|
if self.cli_manager:
|
||||||
await best_effort("cli_manager.stop_all", self.cli_manager.stop_all())
|
await best_effort("cli_manager.stop_all", self.cli_manager.stop_all())
|
||||||
await best_effort("cleanup_provider", self.provider_cleanup())
|
if self._provider_registry is not None:
|
||||||
|
await best_effort(
|
||||||
|
"provider_registry.cleanup", self._provider_registry.cleanup()
|
||||||
|
)
|
||||||
await self._shutdown_limiter()
|
await self._shutdown_limiter()
|
||||||
logger.info("Server shut down cleanly")
|
logger.info("Server shut down cleanly")
|
||||||
|
|
||||||
async def _start_messaging_if_configured(self) -> None:
|
async def _start_messaging_if_configured(self) -> None:
|
||||||
try:
|
try:
|
||||||
from messaging.platforms.factory import create_messaging_platform
|
from messaging.platforms.factory import (
|
||||||
|
MessagingPlatformOptions,
|
||||||
|
create_messaging_platform,
|
||||||
|
)
|
||||||
|
|
||||||
self.messaging_platform = create_messaging_platform(
|
self.messaging_platform = create_messaging_platform(
|
||||||
platform_type=self.settings.messaging_platform,
|
self.settings.messaging_platform,
|
||||||
bot_token=self.settings.telegram_bot_token,
|
MessagingPlatformOptions(
|
||||||
allowed_user_id=self.settings.allowed_telegram_user_id,
|
telegram_bot_token=self.settings.telegram_bot_token,
|
||||||
|
allowed_telegram_user_id=self.settings.allowed_telegram_user_id,
|
||||||
discord_bot_token=self.settings.discord_bot_token,
|
discord_bot_token=self.settings.discord_bot_token,
|
||||||
allowed_discord_channels=self.settings.allowed_discord_channels,
|
allowed_discord_channels=self.settings.allowed_discord_channels,
|
||||||
|
voice_note_enabled=self.settings.voice_note_enabled,
|
||||||
|
whisper_model=self.settings.whisper_model,
|
||||||
|
whisper_device=self.settings.whisper_device,
|
||||||
|
hf_token=self.settings.hf_token,
|
||||||
|
nvidia_nim_api_key=self.settings.nvidia_nim_api_key,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.messaging_platform:
|
if self.messaging_platform:
|
||||||
|
|
@ -137,29 +150,31 @@ class AppRuntime:
|
||||||
api_url=api_url,
|
api_url=api_url,
|
||||||
allowed_dirs=allowed_dirs,
|
allowed_dirs=allowed_dirs,
|
||||||
plans_directory=plans_directory,
|
plans_directory=plans_directory,
|
||||||
claude_bin=getattr(self.settings, "claude_cli_bin", "claude"),
|
claude_bin=self.settings.claude_cli_bin,
|
||||||
)
|
)
|
||||||
|
|
||||||
session_store = SessionStore(
|
session_store = SessionStore(
|
||||||
storage_path=os.path.join(data_path, "sessions.json")
|
storage_path=os.path.join(data_path, "sessions.json")
|
||||||
)
|
)
|
||||||
|
platform = self.messaging_platform
|
||||||
|
assert platform is not None
|
||||||
self.message_handler = ClaudeMessageHandler(
|
self.message_handler = ClaudeMessageHandler(
|
||||||
platform=self.messaging_platform,
|
platform=platform,
|
||||||
cli_manager=self.cli_manager,
|
cli_manager=self.cli_manager,
|
||||||
session_store=session_store,
|
session_store=session_store,
|
||||||
)
|
)
|
||||||
self._restore_tree_state(session_store)
|
self._restore_tree_state(session_store)
|
||||||
|
|
||||||
self.messaging_platform.on_message(self.message_handler.handle_message)
|
platform.on_message(self.message_handler.handle_message)
|
||||||
await self.messaging_platform.start()
|
await platform.start()
|
||||||
logger.info(
|
logger.info(f"{platform.name} platform started with message handler")
|
||||||
f"{self.messaging_platform.name} platform started with message handler"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _restore_tree_state(self, session_store: Any) -> None:
|
def _restore_tree_state(self, session_store: SessionStore) -> None:
|
||||||
saved_trees = session_store.get_all_trees()
|
saved_trees = session_store.get_all_trees()
|
||||||
if not saved_trees:
|
if not saved_trees:
|
||||||
return
|
return
|
||||||
|
if self.message_handler is None:
|
||||||
|
return
|
||||||
|
|
||||||
logger.info(f"Restoring {len(saved_trees)} conversation trees...")
|
logger.info(f"Restoring {len(saved_trees)} conversation trees...")
|
||||||
from messaging.trees.queue_manager import TreeQueueManager
|
from messaging.trees.queue_manager import TreeQueueManager
|
||||||
|
|
@ -188,11 +203,16 @@ class AppRuntime:
|
||||||
async def _shutdown_limiter(self) -> None:
|
async def _shutdown_limiter(self) -> None:
|
||||||
try:
|
try:
|
||||||
from messaging.limiter import MessagingRateLimiter
|
from messaging.limiter import MessagingRateLimiter
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(
|
||||||
|
"Rate limiter shutdown skipped (import failed): {}: {}",
|
||||||
|
type(e).__name__,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
await best_effort(
|
await best_effort(
|
||||||
"MessagingRateLimiter.shutdown_instance",
|
"MessagingRateLimiter.shutdown_instance",
|
||||||
MessagingRateLimiter.shutdown_instance(),
|
MessagingRateLimiter.shutdown_instance(),
|
||||||
timeout_s=2.0,
|
timeout_s=2.0,
|
||||||
)
|
)
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,10 @@ from .model_router import ModelRouter
|
||||||
from .models.anthropic import MessagesRequest, TokenCountRequest
|
from .models.anthropic import MessagesRequest, TokenCountRequest
|
||||||
from .models.responses import TokenCountResponse
|
from .models.responses import TokenCountResponse
|
||||||
from .optimization_handlers import try_optimizations
|
from .optimization_handlers import try_optimizations
|
||||||
|
from .web_server_tools import (
|
||||||
|
is_web_server_tool_request,
|
||||||
|
stream_web_server_tool_response,
|
||||||
|
)
|
||||||
|
|
||||||
TokenCounter = Callable[[list[Any], str | list[Any] | None, list[Any] | None], int]
|
TokenCounter = Callable[[list[Any], str | list[Any] | None, list[Any] | None], int]
|
||||||
|
|
||||||
|
|
@ -48,6 +52,22 @@ class ClaudeProxyService:
|
||||||
raise InvalidRequestError("messages cannot be empty")
|
raise InvalidRequestError("messages cannot be empty")
|
||||||
|
|
||||||
routed = self._model_router.resolve_messages_request(request_data)
|
routed = self._model_router.resolve_messages_request(request_data)
|
||||||
|
if is_web_server_tool_request(routed.request):
|
||||||
|
input_tokens = self._token_counter(
|
||||||
|
routed.request.messages, routed.request.system, routed.request.tools
|
||||||
|
)
|
||||||
|
logger.info("Optimization: Handling Anthropic web server tool")
|
||||||
|
return StreamingResponse(
|
||||||
|
stream_web_server_tool_response(
|
||||||
|
routed.request, input_tokens=input_tokens
|
||||||
|
),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
optimized = try_optimizations(routed.request, self._settings)
|
optimized = try_optimizations(routed.request, self._settings)
|
||||||
if optimized is not None:
|
if optimized is not None:
|
||||||
|
|
|
||||||
331
api/web_server_tools.py
Normal file
331
api/web_server_tools.py
Normal file
|
|
@ -0,0 +1,331 @@
|
||||||
|
"""Local handlers for Anthropic web server tools.
|
||||||
|
|
||||||
|
OpenAI-compatible upstreams can emit regular function calls, but Anthropic's
|
||||||
|
web tools are server-side: the API response itself must include the tool result.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import html
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import uuid
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from html.parser import HTMLParser
|
||||||
|
from typing import Any
|
||||||
|
from urllib.parse import parse_qs, unquote, urlparse
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from .models.anthropic import MessagesRequest
|
||||||
|
|
||||||
|
_REQUEST_TIMEOUT_S = 20.0
|
||||||
|
_MAX_SEARCH_RESULTS = 10
|
||||||
|
_MAX_FETCH_CHARS = 24_000
|
||||||
|
|
||||||
|
|
||||||
|
class _SearchResultParser(HTMLParser):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.results: list[dict[str, str]] = []
|
||||||
|
self._href: str | None = None
|
||||||
|
self._title_parts: list[str] = []
|
||||||
|
|
||||||
|
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
|
||||||
|
if tag != "a":
|
||||||
|
return
|
||||||
|
href = dict(attrs).get("href")
|
||||||
|
if not href or "uddg=" not in href:
|
||||||
|
return
|
||||||
|
parsed = urlparse(href)
|
||||||
|
query = parse_qs(parsed.query)
|
||||||
|
uddg = query.get("uddg", [""])[0]
|
||||||
|
if not uddg:
|
||||||
|
return
|
||||||
|
self._href = unquote(uddg)
|
||||||
|
self._title_parts = []
|
||||||
|
|
||||||
|
def handle_data(self, data: str) -> None:
|
||||||
|
if self._href is not None:
|
||||||
|
self._title_parts.append(data)
|
||||||
|
|
||||||
|
def handle_endtag(self, tag: str) -> None:
|
||||||
|
if tag != "a" or self._href is None:
|
||||||
|
return
|
||||||
|
title = " ".join("".join(self._title_parts).split())
|
||||||
|
if title and not any(result["url"] == self._href for result in self.results):
|
||||||
|
self.results.append({"title": html.unescape(title), "url": self._href})
|
||||||
|
self._href = None
|
||||||
|
self._title_parts = []
|
||||||
|
|
||||||
|
|
||||||
|
class _HTMLTextParser(HTMLParser):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.title = ""
|
||||||
|
self.text_parts: list[str] = []
|
||||||
|
self._in_title = False
|
||||||
|
self._skip_depth = 0
|
||||||
|
|
||||||
|
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
|
||||||
|
if tag in {"script", "style", "noscript"}:
|
||||||
|
self._skip_depth += 1
|
||||||
|
elif tag == "title":
|
||||||
|
self._in_title = True
|
||||||
|
|
||||||
|
def handle_endtag(self, tag: str) -> None:
|
||||||
|
if tag in {"script", "style", "noscript"} and self._skip_depth:
|
||||||
|
self._skip_depth -= 1
|
||||||
|
elif tag == "title":
|
||||||
|
self._in_title = False
|
||||||
|
|
||||||
|
def handle_data(self, data: str) -> None:
|
||||||
|
text = " ".join(data.split())
|
||||||
|
if not text:
|
||||||
|
return
|
||||||
|
if self._in_title:
|
||||||
|
self.title = f"{self.title} {text}".strip()
|
||||||
|
elif not self._skip_depth:
|
||||||
|
self.text_parts.append(text)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_event(event_type: str, data: dict[str, Any]) -> str:
|
||||||
|
return f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
def _content_text(content: Any) -> str:
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
parts.append(str(item.get("text", "")))
|
||||||
|
else:
|
||||||
|
parts.append(str(getattr(item, "text", "")))
|
||||||
|
return "\n".join(part for part in parts if part)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
def _request_text(request: MessagesRequest) -> str:
    """Return the flattened text of every message, joined with newlines."""
    message_texts = [_content_text(entry.content) for entry in request.messages]
    return "\n".join(message_texts)
|
||||||
|
|
||||||
|
|
||||||
|
def _web_tool_name(request: MessagesRequest) -> str | None:
|
||||||
|
for tool in request.tools or []:
|
||||||
|
name = tool.name
|
||||||
|
tool_type = tool.type or ""
|
||||||
|
if name in {"web_search", "web_fetch"} or tool_type.startswith("web_"):
|
||||||
|
return name
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def is_web_server_tool_request(request: MessagesRequest) -> bool:
    """Tell whether the request's first web tool is ``web_search`` or ``web_fetch``."""
    tool_name = _web_tool_name(request)
    return tool_name == "web_search" or tool_name == "web_fetch"
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_query(text: str) -> str:
|
||||||
|
match = re.search(r"query:\s*(.+)", text, flags=re.IGNORECASE | re.DOTALL)
|
||||||
|
if match:
|
||||||
|
return match.group(1).strip().strip("\"'")
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_url(text: str) -> str:
|
||||||
|
match = re.search(r"https?://\S+", text)
|
||||||
|
return match.group(0).rstrip(").,]") if match else text.strip()
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_web_search(query: str) -> list[dict[str, str]]:
    """Query DuckDuckGo Lite and return at most ``_MAX_SEARCH_RESULTS`` hits.

    Each hit is a dict with ``title`` and ``url``. HTTP error statuses are
    raised via ``raise_for_status``.
    """
    client_options = {
        "timeout": _REQUEST_TIMEOUT_S,
        "follow_redirects": True,
        "headers": {"User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0"},
    }
    async with httpx.AsyncClient(**client_options) as client:
        response = await client.get(
            "https://lite.duckduckgo.com/lite/", params={"q": query}
        )
        response.raise_for_status()

    collector = _SearchResultParser()
    collector.feed(response.text)
    return collector.results[:_MAX_SEARCH_RESULTS]
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_web_fetch(url: str) -> dict[str, str]:
    """Download ``url`` and return its text in a small result dict.

    HTML responses (content type containing ``html``) are reduced to their
    visible text, with the page ``<title>`` as title. Other responses are
    returned as decoded text with the requested URL as title. The text is cut
    to ``_MAX_FETCH_CHARS`` characters.

    Returns:
        Dict with ``url`` (final URL after redirects), ``title``,
        ``media_type`` (always ``text/plain``) and ``data``.
    """
    async with httpx.AsyncClient(
        timeout=_REQUEST_TIMEOUT_S,
        follow_redirects=True,
        headers={"User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0"},
    ) as client:
        response = await client.get(url)
        response.raise_for_status()

    content_type = response.headers.get("content-type", "text/plain")
    page_title = url
    body = response.text
    if "html" in content_type.lower():
        extractor = _HTMLTextParser()
        extractor.feed(body)
        page_title = extractor.title or url
        body = "\n".join(extractor.text_parts)

    fetched: dict[str, str] = {}
    fetched["url"] = str(response.url)
    fetched["title"] = page_title
    fetched["media_type"] = "text/plain"
    fetched["data"] = body[:_MAX_FETCH_CHARS]
    return fetched
|
||||||
|
|
||||||
|
|
||||||
|
def _search_summary(query: str, results: list[dict[str, str]]) -> str:
|
||||||
|
if not results:
|
||||||
|
return f"No web search results found for: {query}"
|
||||||
|
lines = [f"Search results for: {query}"]
|
||||||
|
for index, result in enumerate(results, start=1):
|
||||||
|
lines.append(f"{index}. {result['title']}\n{result['url']}")
|
||||||
|
return "\n\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
async def stream_web_server_tool_response(
    request: MessagesRequest, input_tokens: int
) -> AsyncIterator[str]:
    """Run the request's web tool locally and stream the result as SSE frames.

    Yields nothing when the request declares no web tool. Otherwise the
    frames are emitted in this order: ``message_start``; a ``server_tool_use``
    block (index 0); the tool result block (index 1); a ``text`` block with a
    plain-text summary (index 2); ``message_delta``; ``message_stop``.

    Args:
        request: Incoming messages request; its tools select search vs. fetch
            and its message text supplies the query or URL.
        input_tokens: Token count echoed in the ``usage`` fields.

    Yields:
        Serialized SSE frames produced by ``_format_event``.
    """
    tool_name = _web_tool_name(request)
    if tool_name is None:
        return

    # The query/URL is extracted from the text of ALL messages in the request.
    text = _request_text(request)
    message_id = f"msg_{uuid.uuid4()}"
    tool_id = f"srvtoolu_{uuid.uuid4().hex}"
    # Placeholder; recomputed from the summary length before message_delta.
    output_tokens = 1
    # Any tool name other than "web_search" is handled on the fetch path below.
    usage_key = (
        "web_search_requests" if tool_name == "web_search" else "web_fetch_requests"
    )
    tool_input = (
        {"query": _extract_query(text)}
        if tool_name == "web_search"
        else {"url": _extract_url(text)}
    )

    yield _format_event(
        "message_start",
        {
            "type": "message_start",
            "message": {
                "id": message_id,
                "type": "message",
                "role": "assistant",
                "content": [],
                "model": request.model,
                "stop_reason": None,
                "stop_sequence": None,
                "usage": {"input_tokens": input_tokens, "output_tokens": 1},
            },
        },
    )
    # Block 0: announce the tool call. It is emitted before the tool runs, so
    # the client sees the call even if the network request below is slow.
    yield _format_event(
        "content_block_start",
        {
            "type": "content_block_start",
            "index": 0,
            "content_block": {
                "type": "server_tool_use",
                "id": tool_id,
                "name": tool_name,
                "input": tool_input,
            },
        },
    )
    yield _format_event(
        "content_block_stop", {"type": "content_block_stop", "index": 0}
    )

    try:
        if tool_name == "web_search":
            query = str(tool_input["query"])
            results = await _run_web_search(query)
            result_content: Any = [
                {
                    "type": "web_search_result",
                    "title": result["title"],
                    "url": result["url"],
                }
                for result in results
            ]
            summary = _search_summary(query, results)
            result_block_type = "web_search_tool_result"
        else:
            fetched = await _run_web_fetch(str(tool_input["url"]))
            result_content = {
                "type": "web_fetch_result",
                "url": fetched["url"],
                "content": {
                    "type": "document",
                    "source": {
                        "type": "text",
                        "media_type": fetched["media_type"],
                        "data": fetched["data"],
                    },
                    "title": fetched["title"],
                    "citations": {"enabled": True},
                },
                "retrieved_at": datetime.now(UTC).isoformat(),
            }
            summary = fetched["data"][:_MAX_FETCH_CHARS]
            result_block_type = "web_fetch_tool_result"
    except Exception as error:
        # Best effort: any failure becomes an in-band tool error block instead
        # of aborting the stream. Only the exception class name is surfaced.
        result_block_type = (
            "web_search_tool_result"
            if tool_name == "web_search"
            else "web_fetch_tool_result"
        )
        error_type = (
            "web_search_tool_result_error"
            if tool_name == "web_search"
            else "web_fetch_tool_error"
        )
        result_content = {"type": error_type, "error_code": "unavailable"}
        summary = f"{tool_name} failed: {type(error).__name__}"

    # Rough estimate: one output token per four characters of summary text.
    output_tokens = max(1, len(summary) // 4)

    # Block 1: the structured tool result (or error), tied to the call by id.
    yield _format_event(
        "content_block_start",
        {
            "type": "content_block_start",
            "index": 1,
            "content_block": {
                "type": result_block_type,
                "tool_use_id": tool_id,
                "content": result_content,
            },
        },
    )
    yield _format_event(
        "content_block_stop", {"type": "content_block_stop", "index": 1}
    )
    # Block 2: plain-text summary.
    # NOTE(review): the full text is carried inside content_block_start and no
    # content_block_delta frames are emitted — confirm that clients which only
    # accumulate text from deltas still display it.
    yield _format_event(
        "content_block_start",
        {
            "type": "content_block_start",
            "index": 2,
            "content_block": {"type": "text", "text": summary},
        },
    )
    yield _format_event(
        "content_block_stop", {"type": "content_block_stop", "index": 2}
    )
    yield _format_event(
        "message_delta",
        {
            "type": "message_delta",
            "delta": {"stop_reason": "end_turn", "stop_sequence": None},
            "usage": {
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "server_tool_use": {usage_key: 1},
            },
        },
    )
    yield _format_event("message_stop", {"type": "message_stop"})
|
||||||
17
config/provider_ids.py
Normal file
17
config/provider_ids.py
Normal file
|
|
@ -0,0 +1,17 @@
|
||||||
|
"""Canonical list of model provider type prefixes (provider_id values).
|
||||||
|
|
||||||
|
`providers.registry.PROVIDER_DESCRIPTORS` is the full metadata source; this
|
||||||
|
module holds the id set for config validation and must stay in sync
|
||||||
|
(registries assert in `providers.registry`).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# Order matches docs / historical error text; must match PROVIDER_DESCRIPTORS keys.
# The settings validator checks a model's provider prefix (the text before the
# first "/") against this tuple and joins it into its error messages, so the
# order here is the order users see.
SUPPORTED_PROVIDER_IDS: tuple[str, ...] = (
    "nvidia_nim",
    "open_router",
    "deepseek",
    "lmstudio",
    "llamacpp",
)
|
||||||
|
|
@ -11,6 +11,7 @@ from pydantic import Field, field_validator, model_validator
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
from .nim import NimSettings
|
from .nim import NimSettings
|
||||||
|
from .provider_ids import SUPPORTED_PROVIDER_IDS
|
||||||
|
|
||||||
|
|
||||||
def _env_files() -> tuple[Path, ...]:
|
def _env_files() -> tuple[Path, ...]:
|
||||||
|
|
@ -252,25 +253,16 @@ class Settings(BaseSettings):
|
||||||
def validate_model_format(cls, v: str | None) -> str | None:
|
def validate_model_format(cls, v: str | None) -> str | None:
|
||||||
if v is None:
|
if v is None:
|
||||||
return None
|
return None
|
||||||
valid_providers = (
|
|
||||||
"nvidia_nim",
|
|
||||||
"open_router",
|
|
||||||
"deepseek",
|
|
||||||
"lmstudio",
|
|
||||||
"llamacpp",
|
|
||||||
)
|
|
||||||
if "/" not in v:
|
if "/" not in v:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Model must be prefixed with provider type. "
|
f"Model must be prefixed with provider type. "
|
||||||
f"Valid providers: {', '.join(valid_providers)}. "
|
f"Valid providers: {', '.join(SUPPORTED_PROVIDER_IDS)}. "
|
||||||
f"Format: provider_type/model/name"
|
f"Format: provider_type/model/name"
|
||||||
)
|
)
|
||||||
provider = v.split("/", 1)[0]
|
provider = v.split("/", 1)[0]
|
||||||
if provider not in valid_providers:
|
if provider not in SUPPORTED_PROVIDER_IDS:
|
||||||
raise ValueError(
|
supported = ", ".join(f"'{p}'" for p in SUPPORTED_PROVIDER_IDS)
|
||||||
f"Invalid provider: '{provider}'. "
|
raise ValueError(f"Invalid provider: '{provider}'. Supported: {supported}")
|
||||||
f"Supported: 'nvidia_nim', 'open_router', 'deepseek', 'lmstudio', 'llamacpp'"
|
|
||||||
)
|
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,17 @@ from .content import get_block_attr, get_block_type
|
||||||
from .utils import set_if_not_none
|
from .utils import set_if_not_none
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_name(tool: Any) -> str:
|
||||||
|
return str(getattr(tool, "name", "") or "")
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_input_schema(tool: Any) -> dict[str, Any]:
|
||||||
|
schema = getattr(tool, "input_schema", None)
|
||||||
|
if isinstance(schema, dict):
|
||||||
|
return schema
|
||||||
|
return {"type": "object", "properties": {}}
|
||||||
|
|
||||||
|
|
||||||
class AnthropicToOpenAIConverter:
|
class AnthropicToOpenAIConverter:
|
||||||
"""Convert Anthropic message format to OpenAI-compatible format."""
|
"""Convert Anthropic message format to OpenAI-compatible format."""
|
||||||
|
|
||||||
|
|
@ -140,7 +151,7 @@ class AnthropicToOpenAIConverter:
|
||||||
"function": {
|
"function": {
|
||||||
"name": tool.name,
|
"name": tool.name,
|
||||||
"description": tool.description or "",
|
"description": tool.description or "",
|
||||||
"parameters": tool.input_schema,
|
"parameters": _tool_input_schema(tool),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for tool in tools
|
for tool in tools
|
||||||
|
|
|
||||||
158
core/anthropic/stream_contracts.py
Normal file
158
core/anthropic/stream_contracts.py
Normal file
|
|
@ -0,0 +1,158 @@
|
||||||
|
"""Neutral SSE parsing and Anthropic stream shape assertions.
|
||||||
|
|
||||||
|
Used by default CI contract tests and re-exported from ``smoke.lib.sse`` for
|
||||||
|
opt-in smoke scenarios.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
class SSEEvent:
    """One parsed server-sent event frame."""

    # Value of the frame's ``event:`` field; "" when the frame had none.
    event: str
    # Decoded ``data:`` payload. Non-dict JSON is wrapped as {"value": ...};
    # text that is not valid JSON is wrapped as {"raw": ...}.
    data: dict[str, Any]
    # The frame's non-blank lines, joined with newlines.
    raw: str
|
||||||
|
|
||||||
|
|
||||||
|
def parse_sse_lines(lines: Iterable[str]) -> list[SSEEvent]:
    """Group SSE lines into events.

    A blank line ends the current frame. ``event:`` sets the frame's name and
    each ``data:`` line adds one payload line. Other non-blank lines are kept
    only in the frame's raw text. A trailing frame without a closing blank
    line is still emitted.
    """
    parsed: list[SSEEvent] = []
    name, payload_lines, frame_lines = "", [], []

    for raw_line in lines:
        content = raw_line.rstrip("\r\n")
        if not content:
            # Frame boundary: flush what was collected and start over.
            _append_event(parsed, name, payload_lines, frame_lines)
            name, payload_lines, frame_lines = "", [], []
            continue

        frame_lines.append(content)
        if content.startswith("event:"):
            name = content.partition(":")[2].strip()
        elif content.startswith("data:"):
            payload_lines.append(content.partition(":")[2].strip())

    _append_event(parsed, name, payload_lines, frame_lines)
    return parsed
|
||||||
|
|
||||||
|
|
||||||
|
def parse_sse_text(text: str) -> list[SSEEvent]:
    """Parse a complete SSE payload held in one string."""
    lines = text.splitlines()
    return parse_sse_lines(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def _append_event(
    events: list[SSEEvent],
    current_event: str,
    data_parts: list[str],
    raw_parts: list[str],
) -> None:
    """Append the collected frame to ``events`` unless it is empty.

    A frame with neither an event name nor data lines is skipped. The data
    lines are joined with newlines and decoded as JSON: a dict is used as-is,
    any other JSON value is wrapped as ``{"value": ...}``, and undecodable
    text is wrapped as ``{"raw": ...}``. Empty data yields ``{}``.
    """
    if not (current_event or data_parts):
        return

    data_text = "\n".join(data_parts)
    payload: dict[str, Any]
    if not data_text:
        payload = {}
    else:
        try:
            decoded = json.loads(data_text)
        except json.JSONDecodeError:
            payload = {"raw": data_text}
        else:
            payload = decoded if isinstance(decoded, dict) else {"value": decoded}

    events.append(SSEEvent(current_event, payload, "\n".join(raw_parts)))
|
||||||
|
|
||||||
|
|
||||||
|
def assert_anthropic_stream_contract(
    events: list[SSEEvent], *, allow_error: bool = False
) -> None:
    """Assert the minimal shape of an Anthropic-style SSE stream.

    Checks that the stream is non-empty, contains ``message_start``, ends with
    ``message_stop``, and emits at least one content block. Every block must
    be of type ``text``, ``thinking`` or ``tool_use``, start once, receive
    only deltas of the matching type while open, and be stopped. ``error``
    events fail the check unless ``allow_error`` is set.

    Ordering between ``message_delta`` and content blocks is deliberately not
    checked; stricter suites can layer that on top.

    Raises:
        AssertionError: When any of the invariants above is violated.
    """
    assert events, "stream produced no SSE events"
    names = [item.event for item in events]
    assert "message_start" in names, names
    assert names[-1] == "message_stop", names

    # Delta type each block type is allowed to receive.
    delta_for_block = {
        "text": "text_delta",
        "thinking": "thinking_delta",
        "tool_use": "input_json_delta",
    }
    open_blocks: dict[int, str] = {}
    seen_blocks: set[int] = set()

    for event in events:
        kind = event.event
        if kind == "error" and not allow_error:
            raise AssertionError(f"unexpected SSE error event: {event.data}")

        if kind == "content_block_start":
            index = _event_index(event)
            block = event.data.get("content_block", {})
            assert isinstance(block, dict), event.data
            block_type = str(block.get("type", ""))
            assert block_type in delta_for_block, event.data
            assert index not in open_blocks, f"block {index} started twice"
            assert index not in seen_blocks, f"block {index} reused after stop"
            open_blocks[index] = block_type
            seen_blocks.add(index)
        elif kind == "content_block_delta":
            index = _event_index(event)
            assert index in open_blocks, f"delta for unopened block {index}"
            delta = event.data.get("delta", {})
            assert isinstance(delta, dict), event.data
            delta_type = str(delta.get("type", ""))
            expected = delta_for_block[open_blocks[index]]
            assert delta_type == expected, (
                f"block {index} is {open_blocks[index]}, got {delta_type}"
            )
        elif kind == "content_block_stop":
            index = _event_index(event)
            assert index in open_blocks, f"stop for unopened block {index}"
            del open_blocks[index]

    assert not open_blocks, f"unclosed blocks: {open_blocks}"
    assert seen_blocks, "stream did not emit any content blocks"
|
||||||
|
|
||||||
|
|
||||||
|
def event_names(events: list[SSEEvent]) -> list[str]:
    """Return the ``event`` name of every frame, in stream order."""
    names: list[str] = []
    for frame in events:
        names.append(frame.event)
    return names
|
||||||
|
|
||||||
|
|
||||||
|
def text_content(events: list[SSEEvent]) -> str:
    """Concatenate the ``text`` of every ``text_delta`` in the stream."""
    deltas = (frame.data.get("delta", {}) for frame in events)
    return "".join(
        str(delta.get("text", ""))
        for delta in deltas
        if isinstance(delta, dict) and delta.get("type") == "text_delta"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def thinking_content(events: list[SSEEvent]) -> str:
    """Concatenate the ``thinking`` text of every ``thinking_delta`` in the stream."""
    deltas = (frame.data.get("delta", {}) for frame in events)
    return "".join(
        str(delta.get("thinking", ""))
        for delta in deltas
        if isinstance(delta, dict) and delta.get("type") == "thinking_delta"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def has_tool_use(events: list[SSEEvent]) -> bool:
    """Tell whether any frame carries a ``tool_use`` content block."""
    blocks = (frame.data.get("content_block", {}) for frame in events)
    return any(
        isinstance(block, dict) and block.get("type") == "tool_use"
        for block in blocks
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _event_index(event: SSEEvent) -> int:
|
||||||
|
value = event.data.get("index")
|
||||||
|
assert isinstance(value, int), event.data
|
||||||
|
return value
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
"""Heuristic parser for text-emitted tool calls."""
|
"""Heuristic parser for text-emitted tool calls."""
|
||||||
|
|
||||||
|
import json
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
@ -31,6 +32,9 @@ class HeuristicToolParser:
|
||||||
_PARAM_PATTERN = re.compile(
|
_PARAM_PATTERN = re.compile(
|
||||||
r"<parameter=([^>]+)>(.*?)(?:</parameter>|$)", re.DOTALL
|
r"<parameter=([^>]+)>(.*?)(?:</parameter>|$)", re.DOTALL
|
||||||
)
|
)
|
||||||
|
_WEB_TOOL_JSON_PATTERN = re.compile(
|
||||||
|
r"(?is)\b(?:use\s+)?(?P<tool>WebFetch|WebSearch)\b.*?(?P<json>\{.*?\})"
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._state = ParserState.TEXT
|
self._state = ParserState.TEXT
|
||||||
|
|
@ -39,6 +43,41 @@ class HeuristicToolParser:
|
||||||
self._current_function_name = None
|
self._current_function_name = None
|
||||||
self._current_parameters = {}
|
self._current_parameters = {}
|
||||||
|
|
||||||
|
def _extract_web_tool_json_calls(self) -> tuple[str, list[dict[str, Any]]]:
|
||||||
|
detected_tools: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
for match in self._WEB_TOOL_JSON_PATTERN.finditer(self._buffer):
|
||||||
|
try:
|
||||||
|
tool_input = json.loads(match.group("json"))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
if not isinstance(tool_input, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
tool_name = match.group("tool")
|
||||||
|
if tool_name == "WebFetch" and "url" not in tool_input:
|
||||||
|
continue
|
||||||
|
if tool_name == "WebSearch" and "query" not in tool_input:
|
||||||
|
continue
|
||||||
|
|
||||||
|
detected_tools.append(
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": f"toolu_heuristic_{uuid.uuid4().hex[:8]}",
|
||||||
|
"name": tool_name,
|
||||||
|
"input": tool_input,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
"Heuristic bypass: Detected JSON-style tool call '{}'",
|
||||||
|
tool_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not detected_tools:
|
||||||
|
return self._buffer, []
|
||||||
|
|
||||||
|
return "", detected_tools
|
||||||
|
|
||||||
def _strip_control_tokens(self, text: str) -> str:
|
def _strip_control_tokens(self, text: str) -> str:
|
||||||
return _CONTROL_TOKEN_RE.sub("", text)
|
return _CONTROL_TOKEN_RE.sub("", text)
|
||||||
|
|
||||||
|
|
@ -58,7 +97,7 @@ class HeuristicToolParser:
|
||||||
"""Feed text and return safe text plus detected tool calls."""
|
"""Feed text and return safe text plus detected tool calls."""
|
||||||
self._buffer += text
|
self._buffer += text
|
||||||
self._buffer = self._strip_control_tokens(self._buffer)
|
self._buffer = self._strip_control_tokens(self._buffer)
|
||||||
detected_tools = []
|
self._buffer, detected_tools = self._extract_web_tool_json_calls()
|
||||||
filtered_output_parts: list[str] = []
|
filtered_output_parts: list[str] = []
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
|
|
||||||
|
|
@ -114,23 +114,15 @@ async def _delete_message_ids(
|
||||||
numeric.sort(reverse=True)
|
numeric.sort(reverse=True)
|
||||||
ordered = [mid for _, mid in numeric] + non_numeric
|
ordered = [mid for _, mid in numeric] + non_numeric
|
||||||
|
|
||||||
batch_fn = getattr(handler.platform, "queue_delete_messages", None)
|
|
||||||
if callable(batch_fn):
|
|
||||||
try:
|
try:
|
||||||
CHUNK = 100
|
CHUNK = 100
|
||||||
for i in range(0, len(ordered), CHUNK):
|
for i in range(0, len(ordered), CHUNK):
|
||||||
chunk = ordered[i : i + CHUNK]
|
chunk = ordered[i : i + CHUNK]
|
||||||
await batch_fn(chat_id, chunk, fire_and_forget=False)
|
await handler.platform.queue_delete_messages(
|
||||||
except Exception as e:
|
chat_id, chunk, fire_and_forget=False
|
||||||
logger.debug(f"Batch delete failed: {type(e).__name__}: {e}")
|
|
||||||
else:
|
|
||||||
for mid in ordered:
|
|
||||||
try:
|
|
||||||
await handler.platform.queue_delete_message(
|
|
||||||
chat_id, mid, fire_and_forget=False
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Delete failed for msg {mid}: {type(e).__name__}: {e}")
|
logger.debug(f"Batch delete failed: {type(e).__name__}: {e}")
|
||||||
|
|
||||||
|
|
||||||
async def _handle_clear_branch(
|
async def _handle_clear_branch(
|
||||||
|
|
|
||||||
|
|
@ -467,7 +467,7 @@ class ClaudeMessageHandler:
|
||||||
status,
|
status,
|
||||||
len(display),
|
len(display),
|
||||||
)
|
)
|
||||||
if os.getenv("DEBUG_TELEGRAM_EDITS") == "1":
|
if os.getenv("DEBUG_PLATFORM_EDITS") == "1":
|
||||||
logger.debug("PLATFORM_EDIT_TEXT:\n{}", display)
|
logger.debug("PLATFORM_EDIT_TEXT:\n{}", display)
|
||||||
else:
|
else:
|
||||||
head = display[:500]
|
head = display[:500]
|
||||||
|
|
|
||||||
|
|
@ -17,15 +17,10 @@ class CLISession(Protocol):
|
||||||
|
|
||||||
def start_task(
|
def start_task(
|
||||||
self, prompt: str, session_id: str | None = None, fork_session: bool = False
|
self, prompt: str, session_id: str | None = None, fork_session: bool = False
|
||||||
) -> AsyncGenerator[dict, Any]:
|
) -> AsyncGenerator[dict, Any]: ...
|
||||||
"""Start a task in the CLI session."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
def is_busy(self) -> bool: ...
|
||||||
def is_busy(self) -> bool:
|
|
||||||
"""Check if session is busy."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
|
@ -101,7 +96,8 @@ class MessagingPlatform(ABC):
|
||||||
text: Message content
|
text: Message content
|
||||||
reply_to: Optional message ID to reply to
|
reply_to: Optional message ID to reply to
|
||||||
parse_mode: Optional formatting mode ("markdown", "html")
|
parse_mode: Optional formatting mode ("markdown", "html")
|
||||||
message_thread_id: Optional forum topic ID (Telegram)
|
message_thread_id: Optional thread or topic id for threaded channels
|
||||||
|
(e.g. forum topics); unused on platforms that do not support it.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The message ID of the sent message
|
The message ID of the sent message
|
||||||
|
|
@ -192,6 +188,22 @@ class MessagingPlatform(ABC):
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def queue_delete_messages(
    self,
    chat_id: str,
    message_ids: list[str],
    *,
    fire_and_forget: bool = True,
) -> None:
    """Delete several messages from one chat.

    The default implementation awaits :meth:`queue_delete_message` once per
    id, in the given order, forwarding ``fire_and_forget``. Adapters with a
    native bulk-delete call should override this.
    """
    for message_id in message_ids:
        await self.queue_delete_message(
            chat_id,
            message_id,
            fire_and_forget=fire_and_forget,
        )
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def on_message(
|
def on_message(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -91,6 +91,12 @@ class DiscordPlatform(MessagingPlatform):
|
||||||
self,
|
self,
|
||||||
bot_token: str | None = None,
|
bot_token: str | None = None,
|
||||||
allowed_channel_ids: str | None = None,
|
allowed_channel_ids: str | None = None,
|
||||||
|
*,
|
||||||
|
voice_note_enabled: bool = True,
|
||||||
|
whisper_model: str = "base",
|
||||||
|
whisper_device: str = "cpu",
|
||||||
|
hf_token: str = "",
|
||||||
|
nvidia_nim_api_key: str = "",
|
||||||
):
|
):
|
||||||
if not DISCORD_AVAILABLE:
|
if not DISCORD_AVAILABLE:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
|
|
@ -117,7 +123,13 @@ class DiscordPlatform(MessagingPlatform):
|
||||||
self._limiter: Any | None = None
|
self._limiter: Any | None = None
|
||||||
self._start_task: asyncio.Task | None = None
|
self._start_task: asyncio.Task | None = None
|
||||||
self._pending_voice = PendingVoiceRegistry()
|
self._pending_voice = PendingVoiceRegistry()
|
||||||
self._voice_transcription = VoiceTranscriptionService()
|
self._voice_transcription = VoiceTranscriptionService(
|
||||||
|
hf_token=hf_token,
|
||||||
|
nvidia_nim_api_key=nvidia_nim_api_key,
|
||||||
|
)
|
||||||
|
self._voice_note_enabled = voice_note_enabled
|
||||||
|
self._whisper_model = whisper_model
|
||||||
|
self._whisper_device = whisper_device
|
||||||
|
|
||||||
async def _handle_client_message(self, message: Any) -> None:
|
async def _handle_client_message(self, message: Any) -> None:
|
||||||
"""Adapter entry point used by the internal discord client."""
|
"""Adapter entry point used by the internal discord client."""
|
||||||
|
|
@ -154,10 +166,7 @@ class DiscordPlatform(MessagingPlatform):
|
||||||
self, message: Any, attachment: Any, channel_id: str
|
self, message: Any, attachment: Any, channel_id: str
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Handle voice/audio attachment. Returns True if handled."""
|
"""Handle voice/audio attachment. Returns True if handled."""
|
||||||
from config.settings import get_settings
|
if not self._voice_note_enabled:
|
||||||
|
|
||||||
settings = get_settings()
|
|
||||||
if not settings.voice_note_enabled:
|
|
||||||
await message.reply("Voice notes are disabled.")
|
await message.reply("Voice notes are disabled.")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
@ -201,8 +210,8 @@ class DiscordPlatform(MessagingPlatform):
|
||||||
transcribed = await self._voice_transcription.transcribe(
|
transcribed = await self._voice_transcription.transcribe(
|
||||||
tmp_path,
|
tmp_path,
|
||||||
ct,
|
ct,
|
||||||
whisper_model=settings.whisper_model,
|
whisper_model=self._whisper_model,
|
||||||
whisper_device=settings.whisper_device,
|
whisper_device=self._whisper_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not await self._is_voice_still_pending(channel_id, message_id):
|
if not await self._is_voice_still_pending(channel_id, message_id):
|
||||||
|
|
|
||||||
|
|
@ -6,30 +6,50 @@ To add a new platform (e.g. Discord, Slack):
|
||||||
2. Add a case to create_messaging_platform() below
|
2. Add a case to create_messaging_platform() below
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from .base import MessagingPlatform
|
from .base import MessagingPlatform
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
class MessagingPlatformOptions:
    """Typed wiring from :class:`~api.runtime.AppRuntime` into platform adapters."""

    # Telegram bot token; when empty/None the Telegram platform is not created.
    telegram_bot_token: str | None = None
    # Passed to the Telegram adapter as ``allowed_user_id``.
    allowed_telegram_user_id: str | None = None
    # Discord credentials / allowlist — presumably consumed by the Discord
    # branch of create_messaging_platform; verify against that code.
    discord_bot_token: str | None = None
    allowed_discord_channels: str | None = None
    # Voice-note handling and transcription settings forwarded to the adapter.
    voice_note_enabled: bool = True
    whisper_model: str = "base"
    whisper_device: str = "cpu"
    # Credentials for transcription backends (Hugging Face / NVIDIA NIM).
    hf_token: str = ""
    nvidia_nim_api_key: str = ""
|
||||||
|
|
||||||
|
|
||||||
def create_messaging_platform(
|
def create_messaging_platform(
|
||||||
platform_type: str,
|
platform_type: str,
|
||||||
**kwargs,
|
options: MessagingPlatformOptions | None = None,
|
||||||
) -> MessagingPlatform | None:
|
) -> MessagingPlatform | None:
|
||||||
"""Create a messaging platform instance based on type.
|
"""Create a messaging platform instance based on type.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
platform_type: Platform identifier ("telegram", "discord", etc.)
|
platform_type: Platform identifier (``telegram``, ``discord``, ``none``).
|
||||||
**kwargs: Platform-specific configuration passed to the constructor.
|
options: Token, allowlist, and voice / transcription settings.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Configured MessagingPlatform instance, or None if not configured.
|
Configured :class:`MessagingPlatform` instance, or None if not configured.
|
||||||
"""
|
"""
|
||||||
|
opts = options or MessagingPlatformOptions()
|
||||||
if platform_type == "none":
|
if platform_type == "none":
|
||||||
logger.info("Messaging platform disabled by configuration")
|
logger.info("Messaging platform disabled by configuration")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if platform_type == "telegram":
|
if platform_type == "telegram":
|
||||||
bot_token = kwargs.get("bot_token")
|
bot_token = opts.telegram_bot_token
|
||||||
if not bot_token:
|
if not bot_token:
|
||||||
logger.info("No Telegram bot token configured, skipping platform setup")
|
logger.info("No Telegram bot token configured, skipping platform setup")
|
||||||
return None
|
return None
|
||||||
|
|
@ -38,11 +58,16 @@ def create_messaging_platform(
|
||||||
|
|
||||||
return TelegramPlatform(
|
return TelegramPlatform(
|
||||||
bot_token=bot_token,
|
bot_token=bot_token,
|
||||||
allowed_user_id=kwargs.get("allowed_user_id"),
|
allowed_user_id=opts.allowed_telegram_user_id,
|
||||||
|
voice_note_enabled=opts.voice_note_enabled,
|
||||||
|
whisper_model=opts.whisper_model,
|
||||||
|
whisper_device=opts.whisper_device,
|
||||||
|
hf_token=opts.hf_token,
|
||||||
|
nvidia_nim_api_key=opts.nvidia_nim_api_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
if platform_type == "discord":
|
if platform_type == "discord":
|
||||||
bot_token = kwargs.get("discord_bot_token")
|
bot_token = opts.discord_bot_token
|
||||||
if not bot_token:
|
if not bot_token:
|
||||||
logger.info("No Discord bot token configured, skipping platform setup")
|
logger.info("No Discord bot token configured, skipping platform setup")
|
||||||
return None
|
return None
|
||||||
|
|
@ -51,7 +76,12 @@ def create_messaging_platform(
|
||||||
|
|
||||||
return DiscordPlatform(
|
return DiscordPlatform(
|
||||||
bot_token=bot_token,
|
bot_token=bot_token,
|
||||||
allowed_channel_ids=kwargs.get("allowed_discord_channels"),
|
allowed_channel_ids=opts.allowed_discord_channels,
|
||||||
|
voice_note_enabled=opts.voice_note_enabled,
|
||||||
|
whisper_model=opts.whisper_model,
|
||||||
|
whisper_device=opts.whisper_device,
|
||||||
|
hf_token=opts.hf_token,
|
||||||
|
nvidia_nim_api_key=opts.nvidia_nim_api_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
|
||||||
|
|
@ -62,6 +62,12 @@ class TelegramPlatform(MessagingPlatform):
|
||||||
self,
|
self,
|
||||||
bot_token: str | None = None,
|
bot_token: str | None = None,
|
||||||
allowed_user_id: str | None = None,
|
allowed_user_id: str | None = None,
|
||||||
|
*,
|
||||||
|
voice_note_enabled: bool = True,
|
||||||
|
whisper_model: str = "base",
|
||||||
|
whisper_device: str = "cpu",
|
||||||
|
hf_token: str = "",
|
||||||
|
nvidia_nim_api_key: str = "",
|
||||||
):
|
):
|
||||||
if not TELEGRAM_AVAILABLE:
|
if not TELEGRAM_AVAILABLE:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
|
|
@ -84,7 +90,13 @@ class TelegramPlatform(MessagingPlatform):
|
||||||
self._limiter: Any | None = None # Will be MessagingRateLimiter
|
self._limiter: Any | None = None # Will be MessagingRateLimiter
|
||||||
# Pending voice transcriptions: (chat_id, msg_id) -> (voice_msg_id, status_msg_id)
|
# Pending voice transcriptions: (chat_id, msg_id) -> (voice_msg_id, status_msg_id)
|
||||||
self._pending_voice = PendingVoiceRegistry()
|
self._pending_voice = PendingVoiceRegistry()
|
||||||
self._voice_transcription = VoiceTranscriptionService()
|
self._voice_transcription = VoiceTranscriptionService(
|
||||||
|
hf_token=hf_token,
|
||||||
|
nvidia_nim_api_key=nvidia_nim_api_key,
|
||||||
|
)
|
||||||
|
self._voice_note_enabled = voice_note_enabled
|
||||||
|
self._whisper_model = whisper_model
|
||||||
|
self._whisper_device = whisper_device
|
||||||
|
|
||||||
async def _register_pending_voice(
|
async def _register_pending_voice(
|
||||||
self, chat_id: str, voice_msg_id: str, status_msg_id: str
|
self, chat_id: str, voice_msg_id: str, status_msg_id: str
|
||||||
|
|
@ -544,10 +556,7 @@ class TelegramPlatform(MessagingPlatform):
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
from config.settings import get_settings
|
if not self._voice_note_enabled:
|
||||||
|
|
||||||
settings = get_settings()
|
|
||||||
if not settings.voice_note_enabled:
|
|
||||||
await update.message.reply_text("Voice notes are disabled.")
|
await update.message.reply_text("Voice notes are disabled.")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -600,8 +609,8 @@ class TelegramPlatform(MessagingPlatform):
|
||||||
transcribed = await self._voice_transcription.transcribe(
|
transcribed = await self._voice_transcription.transcribe(
|
||||||
tmp_path,
|
tmp_path,
|
||||||
voice.mime_type or "audio/ogg",
|
voice.mime_type or "audio/ogg",
|
||||||
whisper_model=settings.whisper_model,
|
whisper_model=self._whisper_model,
|
||||||
whisper_device=settings.whisper_device,
|
whisper_device=self._whisper_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not await self._is_voice_still_pending(chat_id, message_id):
|
if not await self._is_voice_still_pending(chat_id, message_id):
|
||||||
|
|
|
||||||
|
|
@ -42,8 +42,8 @@ _MODEL_MAP: dict[str, str] = {
|
||||||
"large-v3-turbo": "openai/whisper-large-v3-turbo",
|
"large-v3-turbo": "openai/whisper-large-v3-turbo",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Lazy-loaded pipelines: (model_id, device) -> pipeline
|
# Lazy-loaded pipelines: (model_id, device, resolved HF token or "") -> pipeline
|
||||||
_pipeline_cache: dict[tuple[str, str], Any] = {}
|
_pipeline_cache: dict[tuple[str, str, str], Any] = {}
|
||||||
|
|
||||||
|
|
||||||
def _resolve_model_id(whisper_model: str) -> str:
|
def _resolve_model_id(whisper_model: str) -> str:
|
||||||
|
|
@ -51,20 +51,22 @@ def _resolve_model_id(whisper_model: str) -> str:
|
||||||
return _MODEL_MAP.get(whisper_model, whisper_model)
|
return _MODEL_MAP.get(whisper_model, whisper_model)
|
||||||
|
|
||||||
|
|
||||||
def _get_pipeline(model_id: str, device: str) -> Any:
|
def _get_pipeline(model_id: str, device: str, hf_token: str | None = None) -> Any:
|
||||||
"""Lazy-load transformers Whisper pipeline. Raises ImportError if not installed."""
|
"""Lazy-load transformers Whisper pipeline. Raises ImportError if not installed."""
|
||||||
global _pipeline_cache
|
global _pipeline_cache
|
||||||
if device not in ("cpu", "cuda"):
|
if device not in ("cpu", "cuda"):
|
||||||
raise ValueError(f"whisper_device must be 'cpu' or 'cuda', got {device!r}")
|
raise ValueError(f"whisper_device must be 'cpu' or 'cuda', got {device!r}")
|
||||||
cache_key = (model_id, device)
|
resolved_token = (
|
||||||
|
hf_token if hf_token is not None else get_settings().hf_token
|
||||||
|
) or ""
|
||||||
|
cache_key = (model_id, device, resolved_token)
|
||||||
if cache_key not in _pipeline_cache:
|
if cache_key not in _pipeline_cache:
|
||||||
try:
|
try:
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
|
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
|
||||||
|
|
||||||
token = get_settings().hf_token
|
if resolved_token:
|
||||||
if token:
|
os.environ["HF_TOKEN"] = resolved_token
|
||||||
os.environ["HF_TOKEN"] = token
|
|
||||||
|
|
||||||
use_cuda = device == "cuda" and torch.cuda.is_available()
|
use_cuda = device == "cuda" and torch.cuda.is_available()
|
||||||
pipe_device = "cuda:0" if use_cuda else "cpu"
|
pipe_device = "cuda:0" if use_cuda else "cpu"
|
||||||
|
|
@ -103,6 +105,8 @@ def transcribe_audio(
|
||||||
*,
|
*,
|
||||||
whisper_model: str = "base",
|
whisper_model: str = "base",
|
||||||
whisper_device: str = "cpu",
|
whisper_device: str = "cpu",
|
||||||
|
hf_token: str = "",
|
||||||
|
nvidia_nim_api_key: str = "",
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Transcribe audio file to text.
|
Transcribe audio file to text.
|
||||||
|
|
@ -136,9 +140,12 @@ def transcribe_audio(
|
||||||
)
|
)
|
||||||
|
|
||||||
if whisper_device == "nvidia_nim":
|
if whisper_device == "nvidia_nim":
|
||||||
return _transcribe_nim(file_path, whisper_model)
|
return _transcribe_nim(
|
||||||
else:
|
file_path, whisper_model, nvidia_nim_api_key=nvidia_nim_api_key
|
||||||
return _transcribe_local(file_path, whisper_model, whisper_device)
|
)
|
||||||
|
return _transcribe_local(
|
||||||
|
file_path, whisper_model, whisper_device, hf_token=hf_token
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Whisper expects 16 kHz sample rate
|
# Whisper expects 16 kHz sample rate
|
||||||
|
|
@ -153,10 +160,17 @@ def _load_audio(file_path: Path) -> dict[str, Any]:
|
||||||
return {"array": waveform, "sampling_rate": sr}
|
return {"array": waveform, "sampling_rate": sr}
|
||||||
|
|
||||||
|
|
||||||
def _transcribe_local(file_path: Path, whisper_model: str, whisper_device: str) -> str:
|
def _transcribe_local(
|
||||||
|
file_path: Path,
|
||||||
|
whisper_model: str,
|
||||||
|
whisper_device: str,
|
||||||
|
*,
|
||||||
|
hf_token: str = "",
|
||||||
|
) -> str:
|
||||||
"""Transcribe using transformers Whisper pipeline."""
|
"""Transcribe using transformers Whisper pipeline."""
|
||||||
model_id = _resolve_model_id(whisper_model)
|
model_id = _resolve_model_id(whisper_model)
|
||||||
pipe = _get_pipeline(model_id, whisper_device)
|
token: str | None = hf_token if hf_token else None
|
||||||
|
pipe = _get_pipeline(model_id, whisper_device, hf_token=token)
|
||||||
audio = _load_audio(file_path)
|
audio = _load_audio(file_path)
|
||||||
result = pipe(audio, generate_kwargs={"language": "en", "task": "transcribe"})
|
result = pipe(audio, generate_kwargs={"language": "en", "task": "transcribe"})
|
||||||
text = result.get("text", "") or ""
|
text = result.get("text", "") or ""
|
||||||
|
|
@ -167,7 +181,9 @@ def _transcribe_local(file_path: Path, whisper_model: str, whisper_device: str)
|
||||||
return result_text or "(no speech detected)"
|
return result_text or "(no speech detected)"
|
||||||
|
|
||||||
|
|
||||||
def _transcribe_nim(file_path: Path, model: str) -> str:
|
def _transcribe_nim(
|
||||||
|
file_path: Path, model: str, *, nvidia_nim_api_key: str = ""
|
||||||
|
) -> str:
|
||||||
"""Transcribe using NVIDIA NIM Whisper API via Riva gRPC client."""
|
"""Transcribe using NVIDIA NIM Whisper API via Riva gRPC client."""
|
||||||
try:
|
try:
|
||||||
import riva.client
|
import riva.client
|
||||||
|
|
@ -177,8 +193,7 @@ def _transcribe_nim(file_path: Path, model: str) -> str:
|
||||||
"Install with: uv sync --extra voice"
|
"Install with: uv sync --extra voice"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
settings = get_settings()
|
api_key = nvidia_nim_api_key or get_settings().nvidia_nim_api_key
|
||||||
api_key = settings.nvidia_nim_api_key
|
|
||||||
|
|
||||||
# Look up function ID and language code from model mapping
|
# Look up function ID and language code from model mapping
|
||||||
model_config = _NIM_MODEL_MAP.get(model)
|
model_config = _NIM_MODEL_MAP.get(model)
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,15 @@ class PendingVoiceRegistry:
|
||||||
class VoiceTranscriptionService:
|
class VoiceTranscriptionService:
|
||||||
"""Run configured transcription backends off the event loop."""
|
"""Run configured transcription backends off the event loop."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
hf_token: str = "",
|
||||||
|
nvidia_nim_api_key: str = "",
|
||||||
|
) -> None:
|
||||||
|
self._hf_token = hf_token
|
||||||
|
self._nvidia_nim_api_key = nvidia_nim_api_key
|
||||||
|
|
||||||
async def transcribe(
|
async def transcribe(
|
||||||
self,
|
self,
|
||||||
file_path: Path,
|
file_path: Path,
|
||||||
|
|
@ -62,4 +71,6 @@ class VoiceTranscriptionService:
|
||||||
mime_type,
|
mime_type,
|
||||||
whisper_model=whisper_model,
|
whisper_model=whisper_model,
|
||||||
whisper_device=whisper_device,
|
whisper_device=whisper_device,
|
||||||
|
hf_token=self._hf_token,
|
||||||
|
nvidia_nim_api_key=self._nvidia_nim_api_key,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,11 @@
|
||||||
"""Providers package - implement your own provider by extending BaseProvider."""
|
"""Providers package - implement your own provider by extending BaseProvider.
|
||||||
|
|
||||||
|
Concrete adapters (e.g. ``NvidiaNimProvider``) live in subpackages; import them
|
||||||
|
from ``providers.nvidia_nim`` etc. to avoid loading every adapter when the
|
||||||
|
``providers`` package is imported.
|
||||||
|
"""
|
||||||
|
|
||||||
from .anthropic_messages import AnthropicMessagesTransport
|
|
||||||
from .base import BaseProvider, ProviderConfig
|
from .base import BaseProvider, ProviderConfig
|
||||||
from .deepseek import DeepSeekProvider
|
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
APIError,
|
APIError,
|
||||||
AuthenticationError,
|
AuthenticationError,
|
||||||
|
|
@ -10,27 +13,17 @@ from .exceptions import (
|
||||||
OverloadedError,
|
OverloadedError,
|
||||||
ProviderError,
|
ProviderError,
|
||||||
RateLimitError,
|
RateLimitError,
|
||||||
|
UnknownProviderTypeError,
|
||||||
)
|
)
|
||||||
from .llamacpp import LlamaCppProvider
|
|
||||||
from .lmstudio import LMStudioProvider
|
|
||||||
from .nvidia_nim import NvidiaNimProvider
|
|
||||||
from .open_router import OpenRouterProvider
|
|
||||||
from .openai_compat import OpenAIChatTransport
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"APIError",
|
"APIError",
|
||||||
"AnthropicMessagesTransport",
|
|
||||||
"AuthenticationError",
|
"AuthenticationError",
|
||||||
"BaseProvider",
|
"BaseProvider",
|
||||||
"DeepSeekProvider",
|
|
||||||
"InvalidRequestError",
|
"InvalidRequestError",
|
||||||
"LMStudioProvider",
|
|
||||||
"LlamaCppProvider",
|
|
||||||
"NvidiaNimProvider",
|
|
||||||
"OpenAIChatTransport",
|
|
||||||
"OpenRouterProvider",
|
|
||||||
"OverloadedError",
|
"OverloadedError",
|
||||||
"ProviderConfig",
|
"ProviderConfig",
|
||||||
"ProviderError",
|
"ProviderError",
|
||||||
"RateLimitError",
|
"RateLimitError",
|
||||||
|
"UnknownProviderTypeError",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,11 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from providers.base import ProviderConfig
|
from providers.base import ProviderConfig
|
||||||
|
from providers.defaults import DEEPSEEK_BASE_URL
|
||||||
from providers.openai_compat import OpenAIChatTransport
|
from providers.openai_compat import OpenAIChatTransport
|
||||||
|
|
||||||
from .request import build_request_body
|
from .request import build_request_body
|
||||||
|
|
||||||
DEEPSEEK_BASE_URL = "https://api.deepseek.com"
|
|
||||||
|
|
||||||
|
|
||||||
class DeepSeekProvider(OpenAIChatTransport):
|
class DeepSeekProvider(OpenAIChatTransport):
|
||||||
"""DeepSeek provider using OpenAI-compatible chat completions."""
|
"""DeepSeek provider using OpenAI-compatible chat completions."""
|
||||||
|
|
|
||||||
19
providers/defaults.py
Normal file
19
providers/defaults.py
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
"""Default upstream base URLs and shared provider constants.
|
||||||
|
|
||||||
|
Adapters and :mod:`providers.registry` import from here to avoid duplicating
|
||||||
|
literals and to keep ``providers.registry`` free of per-adapter eager imports.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Default base URLs for every registered provider: OpenAI-chat transports
# (NIM, DeepSeek) and Anthropic-messages transports (OpenRouter, LM Studio, llama.cpp)
|
||||||
|
NVIDIA_NIM_DEFAULT_BASE = "https://integrate.api.nvidia.com/v1"
|
||||||
|
DEEPSEEK_DEFAULT_BASE = "https://api.deepseek.com"
|
||||||
|
OPENROUTER_DEFAULT_BASE = "https://openrouter.ai/api/v1"
|
||||||
|
LMSTUDIO_DEFAULT_BASE = "http://localhost:1234/v1"
|
||||||
|
LLAMACPP_DEFAULT_BASE = "http://localhost:8080/v1"
|
||||||
|
|
||||||
|
# Backward-compatible names used by existing adapter modules
|
||||||
|
NVIDIA_NIM_BASE_URL = NVIDIA_NIM_DEFAULT_BASE
|
||||||
|
DEEPSEEK_BASE_URL = DEEPSEEK_DEFAULT_BASE
|
||||||
|
OPENROUTER_BASE_URL = OPENROUTER_DEFAULT_BASE
|
||||||
|
LMSTUDIO_DEFAULT_BASE_URL = LMSTUDIO_DEFAULT_BASE
|
||||||
|
LLAMACPP_DEFAULT_BASE_URL = LLAMACPP_DEFAULT_BASE
|
||||||
|
|
@ -88,3 +88,10 @@ class APIError(ProviderError):
|
||||||
error_type="api_error",
|
error_type="api_error",
|
||||||
raw_error=raw_error,
|
raw_error=raw_error,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UnknownProviderTypeError(ValueError):
|
||||||
|
"""Raised when ``provider_id`` is not registered in the provider map."""
|
||||||
|
|
||||||
|
def __init__(self, message: str) -> None:
|
||||||
|
super().__init__(message)
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,7 @@
|
||||||
|
|
||||||
from providers.anthropic_messages import AnthropicMessagesTransport
|
from providers.anthropic_messages import AnthropicMessagesTransport
|
||||||
from providers.base import ProviderConfig
|
from providers.base import ProviderConfig
|
||||||
|
from providers.defaults import LLAMACPP_DEFAULT_BASE_URL
|
||||||
LLAMACPP_DEFAULT_BASE_URL = "http://localhost:8080/v1"
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaCppProvider(AnthropicMessagesTransport):
|
class LlamaCppProvider(AnthropicMessagesTransport):
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,7 @@
|
||||||
|
|
||||||
from providers.anthropic_messages import AnthropicMessagesTransport
|
from providers.anthropic_messages import AnthropicMessagesTransport
|
||||||
from providers.base import ProviderConfig
|
from providers.base import ProviderConfig
|
||||||
|
from providers.defaults import LMSTUDIO_DEFAULT_BASE_URL
|
||||||
LMSTUDIO_DEFAULT_BASE_URL = "http://localhost:1234/v1"
|
|
||||||
|
|
||||||
|
|
||||||
class LMStudioProvider(AnthropicMessagesTransport):
|
class LMStudioProvider(AnthropicMessagesTransport):
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ from loguru import logger
|
||||||
|
|
||||||
from config.nim import NimSettings
|
from config.nim import NimSettings
|
||||||
from providers.base import ProviderConfig
|
from providers.base import ProviderConfig
|
||||||
|
from providers.defaults import NVIDIA_NIM_BASE_URL
|
||||||
from providers.openai_compat import OpenAIChatTransport
|
from providers.openai_compat import OpenAIChatTransport
|
||||||
|
|
||||||
from .request import (
|
from .request import (
|
||||||
|
|
@ -16,8 +17,6 @@ from .request import (
|
||||||
clone_body_without_reasoning_budget,
|
clone_body_without_reasoning_budget,
|
||||||
)
|
)
|
||||||
|
|
||||||
NVIDIA_NIM_BASE_URL = "https://integrate.api.nvidia.com/v1"
|
|
||||||
|
|
||||||
|
|
||||||
class NvidiaNimProvider(OpenAIChatTransport):
|
class NvidiaNimProvider(OpenAIChatTransport):
|
||||||
"""NVIDIA NIM provider using official OpenAI client."""
|
"""NVIDIA NIM provider using official OpenAI client."""
|
||||||
|
|
|
||||||
|
|
@ -11,10 +11,10 @@ from typing import Any
|
||||||
from core.anthropic import SSEBuilder, append_request_id
|
from core.anthropic import SSEBuilder, append_request_id
|
||||||
from providers.anthropic_messages import AnthropicMessagesTransport, StreamChunkMode
|
from providers.anthropic_messages import AnthropicMessagesTransport, StreamChunkMode
|
||||||
from providers.base import ProviderConfig
|
from providers.base import ProviderConfig
|
||||||
|
from providers.defaults import OPENROUTER_BASE_URL
|
||||||
|
|
||||||
from .request import build_request_body
|
from .request import build_request_body
|
||||||
|
|
||||||
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
|
|
||||||
_ANTHROPIC_VERSION = "2023-06-01"
|
_ANTHROPIC_VERSION = "2023-06-01"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,10 @@
|
||||||
"""Shared base class for OpenAI-compatible providers (NIM, OpenRouter, LM Studio)."""
|
"""Base :class:`OpenAIChatTransport` for OpenAI-style chat providers (NIM, DeepSeek, etc.).
|
||||||
|
|
||||||
|
``AnthropicMessagesTransport``-based providers (OpenRouter, LM Studio, …) live
|
||||||
|
in separate modules; do not list them as subclasses of this class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
|
|
@ -25,7 +30,7 @@ from providers.rate_limit import GlobalRateLimiter
|
||||||
|
|
||||||
|
|
||||||
class OpenAIChatTransport(BaseProvider):
|
class OpenAIChatTransport(BaseProvider):
|
||||||
"""Base class for providers using OpenAI-compatible chat completions API."""
|
"""Base for OpenAI-compatible ``/chat/completions`` adapters (NIM, DeepSeek, …)."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -114,6 +119,7 @@ class OpenAIChatTransport(BaseProvider):
|
||||||
|
|
||||||
fn_delta = tc.get("function", {})
|
fn_delta = tc.get("function", {})
|
||||||
incoming_name = fn_delta.get("name")
|
incoming_name = fn_delta.get("name")
|
||||||
|
arguments = fn_delta.get("arguments", "")
|
||||||
if incoming_name is not None:
|
if incoming_name is not None:
|
||||||
sse.blocks.register_tool_name(tc_index, incoming_name)
|
sse.blocks.register_tool_name(tc_index, incoming_name)
|
||||||
|
|
||||||
|
|
@ -124,7 +130,7 @@ class OpenAIChatTransport(BaseProvider):
|
||||||
tool_id = tc.get("id") or f"tool_{uuid.uuid4()}"
|
tool_id = tc.get("id") or f"tool_{uuid.uuid4()}"
|
||||||
yield sse.start_tool_block(tc_index, tool_id, name)
|
yield sse.start_tool_block(tc_index, tool_id, name)
|
||||||
|
|
||||||
args = fn_delta.get("arguments", "")
|
args = arguments
|
||||||
if args:
|
if args:
|
||||||
state = sse.blocks.tool_states.get(tc_index)
|
state = sse.blocks.tool_states.get(tc_index)
|
||||||
if state is None or not state.started:
|
if state is None or not state.started:
|
||||||
|
|
@ -285,6 +291,8 @@ class OpenAIChatTransport(BaseProvider):
|
||||||
for event in self._process_tool_call(tc_info, sse):
|
for event in self._process_tool_call(tc_info, sse):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
|
except asyncio.CancelledError, GeneratorExit:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("{}_ERROR:{} {}: {}", tag, req_tag, type(e).__name__, e)
|
logger.error("{}_ERROR:{} {}: {}", tag, req_tag, type(e).__name__, e)
|
||||||
mapped_e = map_error(e, rate_limiter=self._global_rate_limiter)
|
mapped_e = map_error(e, rate_limiter=self._global_rate_limiter)
|
||||||
|
|
|
||||||
|
|
@ -6,17 +6,17 @@ from collections.abc import Callable, MutableMapping
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
|
from config.provider_ids import SUPPORTED_PROVIDER_IDS
|
||||||
from config.settings import Settings
|
from config.settings import Settings
|
||||||
from providers.base import BaseProvider, ProviderConfig
|
from providers.base import BaseProvider, ProviderConfig
|
||||||
from providers.deepseek import DEEPSEEK_BASE_URL, DeepSeekProvider
|
from providers.defaults import (
|
||||||
from providers.exceptions import AuthenticationError
|
DEEPSEEK_DEFAULT_BASE,
|
||||||
from providers.llamacpp import LlamaCppProvider
|
LLAMACPP_DEFAULT_BASE,
|
||||||
from providers.lmstudio import LMStudioProvider
|
LMSTUDIO_DEFAULT_BASE,
|
||||||
from providers.nvidia_nim import NVIDIA_NIM_BASE_URL, NvidiaNimProvider
|
NVIDIA_NIM_DEFAULT_BASE,
|
||||||
from providers.open_router import (
|
OPENROUTER_DEFAULT_BASE,
|
||||||
OPENROUTER_BASE_URL,
|
|
||||||
OpenRouterProvider,
|
|
||||||
)
|
)
|
||||||
|
from providers.exceptions import AuthenticationError, UnknownProviderTypeError
|
||||||
|
|
||||||
TransportType = Literal["openai_chat", "anthropic_messages"]
|
TransportType = Literal["openai_chat", "anthropic_messages"]
|
||||||
ProviderFactory = Callable[[ProviderConfig, Settings], BaseProvider]
|
ProviderFactory = Callable[[ProviderConfig, Settings], BaseProvider]
|
||||||
|
|
@ -24,11 +24,17 @@ ProviderFactory = Callable[[ProviderConfig, Settings], BaseProvider]
|
||||||
|
|
||||||
@dataclass(frozen=True, slots=True)
|
@dataclass(frozen=True, slots=True)
|
||||||
class ProviderDescriptor:
|
class ProviderDescriptor:
|
||||||
|
"""Metadata for building :class:`ProviderConfig` and factory wiring."""
|
||||||
|
|
||||||
provider_id: str
|
provider_id: str
|
||||||
transport_type: TransportType
|
transport_type: TransportType
|
||||||
capabilities: tuple[str, ...]
|
capabilities: tuple[str, ...]
|
||||||
credential_env: str | None = None
|
credential_env: str | None = None
|
||||||
credential_url: str | None = None
|
credential_url: str | None = None
|
||||||
|
# If set, read API key from this attribute on ``Settings`` (e.g. nvidia_nim_api_key).
|
||||||
|
credential_attr: str | None = None
|
||||||
|
# If set, use this fixed key for local adapters (e.g. lm-studio, llamacpp).
|
||||||
|
static_credential: str | None = None
|
||||||
default_base_url: str | None = None
|
default_base_url: str | None = None
|
||||||
base_url_attr: str | None = None
|
base_url_attr: str | None = None
|
||||||
proxy_attr: str | None = None
|
proxy_attr: str | None = None
|
||||||
|
|
@ -40,7 +46,8 @@ PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = {
|
||||||
transport_type="openai_chat",
|
transport_type="openai_chat",
|
||||||
credential_env="NVIDIA_NIM_API_KEY",
|
credential_env="NVIDIA_NIM_API_KEY",
|
||||||
credential_url="https://build.nvidia.com/settings/api-keys",
|
credential_url="https://build.nvidia.com/settings/api-keys",
|
||||||
default_base_url=NVIDIA_NIM_BASE_URL,
|
credential_attr="nvidia_nim_api_key",
|
||||||
|
default_base_url=NVIDIA_NIM_DEFAULT_BASE,
|
||||||
proxy_attr="nvidia_nim_proxy",
|
proxy_attr="nvidia_nim_proxy",
|
||||||
capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"),
|
capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"),
|
||||||
),
|
),
|
||||||
|
|
@ -49,7 +56,8 @@ PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = {
|
||||||
transport_type="anthropic_messages",
|
transport_type="anthropic_messages",
|
||||||
credential_env="OPENROUTER_API_KEY",
|
credential_env="OPENROUTER_API_KEY",
|
||||||
credential_url="https://openrouter.ai/keys",
|
credential_url="https://openrouter.ai/keys",
|
||||||
default_base_url=OPENROUTER_BASE_URL,
|
credential_attr="open_router_api_key",
|
||||||
|
default_base_url=OPENROUTER_DEFAULT_BASE,
|
||||||
proxy_attr="open_router_proxy",
|
proxy_attr="open_router_proxy",
|
||||||
capabilities=("chat", "streaming", "tools", "thinking", "native_anthropic"),
|
capabilities=("chat", "streaming", "tools", "thinking", "native_anthropic"),
|
||||||
),
|
),
|
||||||
|
|
@ -58,13 +66,15 @@ PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = {
|
||||||
transport_type="openai_chat",
|
transport_type="openai_chat",
|
||||||
credential_env="DEEPSEEK_API_KEY",
|
credential_env="DEEPSEEK_API_KEY",
|
||||||
credential_url="https://platform.deepseek.com/api_keys",
|
credential_url="https://platform.deepseek.com/api_keys",
|
||||||
default_base_url=DEEPSEEK_BASE_URL,
|
credential_attr="deepseek_api_key",
|
||||||
|
default_base_url=DEEPSEEK_DEFAULT_BASE,
|
||||||
capabilities=("chat", "streaming", "thinking"),
|
capabilities=("chat", "streaming", "thinking"),
|
||||||
),
|
),
|
||||||
"lmstudio": ProviderDescriptor(
|
"lmstudio": ProviderDescriptor(
|
||||||
provider_id="lmstudio",
|
provider_id="lmstudio",
|
||||||
transport_type="anthropic_messages",
|
transport_type="anthropic_messages",
|
||||||
default_base_url="http://localhost:1234/v1",
|
static_credential="lm-studio",
|
||||||
|
default_base_url=LMSTUDIO_DEFAULT_BASE,
|
||||||
base_url_attr="lm_studio_base_url",
|
base_url_attr="lm_studio_base_url",
|
||||||
proxy_attr="lmstudio_proxy",
|
proxy_attr="lmstudio_proxy",
|
||||||
capabilities=("chat", "streaming", "tools", "native_anthropic", "local"),
|
capabilities=("chat", "streaming", "tools", "native_anthropic", "local"),
|
||||||
|
|
@ -72,7 +82,8 @@ PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = {
|
||||||
"llamacpp": ProviderDescriptor(
|
"llamacpp": ProviderDescriptor(
|
||||||
provider_id="llamacpp",
|
provider_id="llamacpp",
|
||||||
transport_type="anthropic_messages",
|
transport_type="anthropic_messages",
|
||||||
default_base_url="http://localhost:8080/v1",
|
static_credential="llamacpp",
|
||||||
|
default_base_url=LLAMACPP_DEFAULT_BASE,
|
||||||
base_url_attr="llamacpp_base_url",
|
base_url_attr="llamacpp_base_url",
|
||||||
proxy_attr="llamacpp_proxy",
|
proxy_attr="llamacpp_proxy",
|
||||||
capabilities=("chat", "streaming", "tools", "native_anthropic", "local"),
|
capabilities=("chat", "streaming", "tools", "native_anthropic", "local"),
|
||||||
|
|
@ -81,22 +92,32 @@ PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = {
|
||||||
|
|
||||||
|
|
||||||
def _create_nvidia_nim(config: ProviderConfig, settings: Settings) -> BaseProvider:
|
def _create_nvidia_nim(config: ProviderConfig, settings: Settings) -> BaseProvider:
|
||||||
|
from providers.nvidia_nim import NvidiaNimProvider
|
||||||
|
|
||||||
return NvidiaNimProvider(config, nim_settings=settings.nim)
|
return NvidiaNimProvider(config, nim_settings=settings.nim)
|
||||||
|
|
||||||
|
|
||||||
def _create_open_router(config: ProviderConfig, _settings: Settings) -> BaseProvider:
|
def _create_open_router(config: ProviderConfig, _settings: Settings) -> BaseProvider:
|
||||||
|
from providers.open_router import OpenRouterProvider
|
||||||
|
|
||||||
return OpenRouterProvider(config)
|
return OpenRouterProvider(config)
|
||||||
|
|
||||||
|
|
||||||
def _create_deepseek(config: ProviderConfig, settings: Settings) -> BaseProvider:
|
def _create_deepseek(config: ProviderConfig, settings: Settings) -> BaseProvider:
|
||||||
|
from providers.deepseek import DeepSeekProvider
|
||||||
|
|
||||||
return DeepSeekProvider(config)
|
return DeepSeekProvider(config)
|
||||||
|
|
||||||
|
|
||||||
def _create_lmstudio(config: ProviderConfig, settings: Settings) -> BaseProvider:
|
def _create_lmstudio(config: ProviderConfig, settings: Settings) -> BaseProvider:
|
||||||
|
from providers.lmstudio import LMStudioProvider
|
||||||
|
|
||||||
return LMStudioProvider(config)
|
return LMStudioProvider(config)
|
||||||
|
|
||||||
|
|
||||||
def _create_llamacpp(config: ProviderConfig, settings: Settings) -> BaseProvider:
|
def _create_llamacpp(config: ProviderConfig, settings: Settings) -> BaseProvider:
|
||||||
|
from providers.llamacpp import LlamaCppProvider
|
||||||
|
|
||||||
return LlamaCppProvider(config)
|
return LlamaCppProvider(config)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -108,6 +129,15 @@ PROVIDER_FACTORIES: dict[str, ProviderFactory] = {
|
||||||
"llamacpp": _create_llamacpp,
|
"llamacpp": _create_llamacpp,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if set(PROVIDER_DESCRIPTORS) != set(SUPPORTED_PROVIDER_IDS) or set(
|
||||||
|
PROVIDER_FACTORIES
|
||||||
|
) != set(SUPPORTED_PROVIDER_IDS):
|
||||||
|
raise AssertionError(
|
||||||
|
"PROVIDER_DESCRIPTORS, PROVIDER_FACTORIES, and SUPPORTED_PROVIDER_IDS are out of sync: "
|
||||||
|
f"descriptors={set(PROVIDER_DESCRIPTORS)!r} factories={set(PROVIDER_FACTORIES)!r} "
|
||||||
|
f"ids={set(SUPPORTED_PROVIDER_IDS)!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _string_attr(settings: Settings, attr_name: str | None, default: str = "") -> str:
|
def _string_attr(settings: Settings, attr_name: str | None, default: str = "") -> str:
|
||||||
if attr_name is None:
|
if attr_name is None:
|
||||||
|
|
@ -116,17 +146,11 @@ def _string_attr(settings: Settings, attr_name: str | None, default: str = "") -
|
||||||
return value if isinstance(value, str) else default
|
return value if isinstance(value, str) else default
|
||||||
|
|
||||||
|
|
||||||
def _credential_for(descriptor: ProviderDescriptor, settings: Settings) -> str:
    """Resolve the credential string for the provider described by *descriptor*.

    Resolution order:

    1. ``descriptor.static_credential`` when it is not ``None`` (an empty
       string still counts as set).
    2. Otherwise, when ``descriptor.credential_attr`` is truthy, the value
       looked up on *settings* through ``_string_attr``.
    3. Otherwise the empty string.
    """
    static = descriptor.static_credential
    if static is not None:
        return static

    attr_name = descriptor.credential_attr
    return _string_attr(settings, attr_name) if attr_name else ""
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -144,7 +168,7 @@ def _require_credential(descriptor: ProviderDescriptor, credential: str) -> None
|
||||||
def build_provider_config(
|
def build_provider_config(
|
||||||
descriptor: ProviderDescriptor, settings: Settings
|
descriptor: ProviderDescriptor, settings: Settings
|
||||||
) -> ProviderConfig:
|
) -> ProviderConfig:
|
||||||
credential = _credential_for(descriptor.provider_id, settings)
|
credential = _credential_for(descriptor, settings)
|
||||||
_require_credential(descriptor, credential)
|
_require_credential(descriptor, credential)
|
||||||
base_url = _string_attr(
|
base_url = _string_attr(
|
||||||
settings, descriptor.base_url_attr, descriptor.default_base_url or ""
|
settings, descriptor.base_url_attr, descriptor.default_base_url or ""
|
||||||
|
|
@ -168,7 +192,7 @@ def create_provider(provider_id: str, settings: Settings) -> BaseProvider:
|
||||||
descriptor = PROVIDER_DESCRIPTORS.get(provider_id)
|
descriptor = PROVIDER_DESCRIPTORS.get(provider_id)
|
||||||
if descriptor is None:
|
if descriptor is None:
|
||||||
supported = "', '".join(PROVIDER_DESCRIPTORS)
|
supported = "', '".join(PROVIDER_DESCRIPTORS)
|
||||||
raise ValueError(
|
raise UnknownProviderTypeError(
|
||||||
f"Unknown provider_type: '{provider_id}'. Supported: '{supported}'"
|
f"Unknown provider_type: '{provider_id}'. Supported: '{supported}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -185,12 +209,33 @@ class ProviderRegistry:
|
||||||
def __init__(self, providers: MutableMapping[str, BaseProvider] | None = None):
    """Create a registry backed by *providers*, or by a fresh dict when omitted.

    A mapping passed in is stored as-is (not copied), so the caller and the
    registry share the same cache object.
    """
    if providers is None:
        providers = {}
    self._providers = providers
|
||||||
|
|
||||||
|
def is_cached(self, provider_id: str) -> bool:
    """Tell whether the cache already holds a provider under *provider_id*."""
    cache = self._providers
    return provider_id in cache
|
||||||
|
|
||||||
def get(self, provider_id: str, settings: Settings) -> BaseProvider:
    """Return the provider for *provider_id*, creating and caching it on first use.

    On a cache miss the provider is built with ``create_provider(provider_id,
    settings)`` and stored; later calls with the same id return the stored
    instance and do not consult *settings* again.
    """
    cache = self._providers
    already_built = provider_id in cache
    if not already_built:
        cache[provider_id] = create_provider(provider_id, settings)
    return cache[provider_id]
|
||||||
|
|
||||||
async def cleanup(self) -> None:
|
async def cleanup(self) -> None:
|
||||||
for provider in self._providers.values():
|
"""Call ``cleanup`` on every cached provider, then clear the cache.
|
||||||
|
|
||||||
|
Attempts all providers even if one fails. A single failure is re-raised
|
||||||
|
as-is; multiple failures are wrapped in :exc:`ExceptionGroup`.
|
||||||
|
"""
|
||||||
|
items = list(self._providers.items())
|
||||||
|
errors: list[Exception] = []
|
||||||
|
try:
|
||||||
|
for _pid, provider in items:
|
||||||
|
try:
|
||||||
await provider.cleanup()
|
await provider.cleanup()
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(e)
|
||||||
|
finally:
|
||||||
self._providers.clear()
|
self._providers.clear()
|
||||||
|
if len(errors) == 1:
|
||||||
|
raise errors[0]
|
||||||
|
if len(errors) > 1:
|
||||||
|
msg = "One or more provider cleanups failed"
|
||||||
|
raise ExceptionGroup(msg, errors)
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ from typing import Any
|
||||||
import httpx
|
import httpx
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from config.provider_ids import SUPPORTED_PROVIDER_IDS
|
||||||
from messaging.handler import ClaudeMessageHandler
|
from messaging.handler import ClaudeMessageHandler
|
||||||
from messaging.models import IncomingMessage
|
from messaging.models import IncomingMessage
|
||||||
from messaging.platforms.base import MessagingPlatform
|
from messaging.platforms.base import MessagingPlatform
|
||||||
|
|
@ -153,7 +154,7 @@ class ConversationDriver:
|
||||||
class ProviderMatrixDriver:
|
class ProviderMatrixDriver:
|
||||||
"""Resolve provider models and enforce matrix semantics for product smoke."""
|
"""Resolve provider models and enforce matrix semantics for product smoke."""
|
||||||
|
|
||||||
ALL_PROVIDERS = ("nvidia_nim", "open_router", "deepseek", "lmstudio", "llamacpp")
|
ALL_PROVIDERS: tuple[str, ...] = SUPPORTED_PROVIDER_IDS
|
||||||
|
|
||||||
def __init__(self, config: SmokeConfig) -> None:
|
def __init__(self, config: SmokeConfig) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
|
||||||
167
smoke/lib/sse.py
167
smoke/lib/sse.py
|
|
@ -1,148 +1,29 @@
|
||||||
"""SSE parsing and Anthropic stream assertions for smoke tests."""
|
"""SSE parsing and Anthropic stream assertions for smoke tests.
|
||||||
|
|
||||||
|
Canonical implementation lives in :mod:`core.anthropic.stream_contracts`; this
|
||||||
|
module re-exports it for smoke and historical import paths.
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
from core.anthropic.stream_contracts import (
|
||||||
from collections.abc import Iterable
|
SSEEvent,
|
||||||
from dataclasses import dataclass
|
assert_anthropic_stream_contract,
|
||||||
from typing import Any
|
event_names,
|
||||||
|
has_tool_use,
|
||||||
|
parse_sse_lines,
|
||||||
@dataclass(frozen=True, slots=True)
|
parse_sse_text,
|
||||||
class SSEEvent:
|
text_content,
|
||||||
event: str
|
thinking_content,
|
||||||
data: dict[str, Any]
|
|
||||||
raw: str
|
|
||||||
|
|
||||||
|
|
||||||
def parse_sse_lines(lines: Iterable[str]) -> list[SSEEvent]:
    """Group raw SSE lines into events.

    Trailing ``\\r``/``\\n`` characters are removed from each line. A blank
    line ends the current event. ``event:`` lines set the event name and
    ``data:`` lines contribute one data part each; every non-blank line is
    also kept as part of the event's raw text. Whatever is pending after the
    last line is flushed as a final event.
    """
    parsed: list[SSEEvent] = []
    name = ""
    data_chunks: list[str] = []
    raw_chunks: list[str] = []

    for raw_line in lines:
        text = raw_line.rstrip("\r\n")
        if not text:
            # Blank line: event boundary. Start fresh lists for the next event.
            _append_event(parsed, name, data_chunks, raw_chunks)
            name, data_chunks, raw_chunks = "", [], []
        else:
            raw_chunks.append(text)
            if text.startswith("event:"):
                name = text[len("event:"):].strip()
            elif text.startswith("data:"):
                data_chunks.append(text[len("data:"):].strip())

    _append_event(parsed, name, data_chunks, raw_chunks)
    return parsed
|
|
||||||
|
|
||||||
|
|
||||||
def parse_sse_text(text: str) -> list[SSEEvent]:
    """Parse a whole SSE payload held in one string.

    The text is split with ``str.splitlines`` and handed to
    ``parse_sse_lines``.
    """
    lines = text.splitlines()
    return parse_sse_lines(lines)
|
|
||||||
|
|
||||||
|
|
||||||
def _append_event(
    events: list[SSEEvent],
    current_event: str,
    data_parts: list[str],
    raw_parts: list[str],
) -> None:
    """Append one ``SSEEvent`` to *events* unless there is nothing to record.

    Nothing is appended when both the event name and the data parts are empty.
    Data parts are joined with newlines and decoded as JSON:

    * empty text        -> ``{}``
    * a JSON object     -> that dict
    * any other JSON    -> ``{"value": <decoded>}``
    * undecodable text  -> ``{"raw": <text>}``
    """
    if not (current_event or data_parts):
        return

    payload_text = "\n".join(data_parts)
    data: dict[str, Any]
    if not payload_text:
        data = {}
    else:
        try:
            decoded = json.loads(payload_text)
        except json.JSONDecodeError:
            data = {"raw": payload_text}
        else:
            data = decoded if isinstance(decoded, dict) else {"value": decoded}

    raw_text = "\n".join(raw_parts)
    events.append(SSEEvent(current_event, data, raw_text))
|
|
||||||
|
|
||||||
|
|
||||||
def assert_anthropic_stream_contract(
|
|
||||||
events: list[SSEEvent], *, allow_error: bool = False
|
|
||||||
) -> None:
|
|
||||||
assert events, "stream produced no SSE events"
|
|
||||||
event_names = [event.event for event in events]
|
|
||||||
assert "message_start" in event_names, event_names
|
|
||||||
assert event_names[-1] == "message_stop", event_names
|
|
||||||
|
|
||||||
open_blocks: dict[int, str] = {}
|
|
||||||
seen_blocks: set[int] = set()
|
|
||||||
for event in events:
|
|
||||||
if event.event == "error" and not allow_error:
|
|
||||||
raise AssertionError(f"unexpected SSE error event: {event.data}")
|
|
||||||
|
|
||||||
if event.event == "content_block_start":
|
|
||||||
index = _event_index(event)
|
|
||||||
block = event.data.get("content_block", {})
|
|
||||||
assert isinstance(block, dict), event.data
|
|
||||||
block_type = str(block.get("type", ""))
|
|
||||||
assert block_type in {"text", "thinking", "tool_use"}, event.data
|
|
||||||
assert index not in open_blocks, f"block {index} started twice"
|
|
||||||
assert index not in seen_blocks, f"block {index} reused after stop"
|
|
||||||
open_blocks[index] = block_type
|
|
||||||
seen_blocks.add(index)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if event.event == "content_block_delta":
|
|
||||||
index = _event_index(event)
|
|
||||||
assert index in open_blocks, f"delta for unopened block {index}"
|
|
||||||
delta = event.data.get("delta", {})
|
|
||||||
assert isinstance(delta, dict), event.data
|
|
||||||
delta_type = str(delta.get("type", ""))
|
|
||||||
expected = {
|
|
||||||
"text": "text_delta",
|
|
||||||
"thinking": "thinking_delta",
|
|
||||||
"tool_use": "input_json_delta",
|
|
||||||
}[open_blocks[index]]
|
|
||||||
assert delta_type == expected, (
|
|
||||||
f"block {index} is {open_blocks[index]}, got {delta_type}"
|
|
||||||
)
|
)
|
||||||
continue
|
|
||||||
|
|
||||||
if event.event == "content_block_stop":
|
__all__ = [
|
||||||
index = _event_index(event)
|
"SSEEvent",
|
||||||
assert index in open_blocks, f"stop for unopened block {index}"
|
"assert_anthropic_stream_contract",
|
||||||
open_blocks.pop(index)
|
"event_names",
|
||||||
|
"has_tool_use",
|
||||||
assert not open_blocks, f"unclosed blocks: {open_blocks}"
|
"parse_sse_lines",
|
||||||
assert seen_blocks, "stream did not emit any content blocks"
|
"parse_sse_text",
|
||||||
|
"text_content",
|
||||||
|
"thinking_content",
|
||||||
def event_names(events: list[SSEEvent]) -> list[str]:
|
]
|
||||||
return [event.event for event in events]
|
|
||||||
|
|
||||||
|
|
||||||
def text_content(events: list[SSEEvent]) -> str:
    """Concatenate the ``text`` of every ``text_delta`` found in *events*.

    An event counts only when its ``data["delta"]`` is a dict whose ``type``
    is ``"text_delta"``; a missing ``text`` contributes an empty string and
    non-string values are converted with ``str``.
    """
    deltas = (event.data.get("delta", {}) for event in events)
    return "".join(
        str(delta.get("text", ""))
        for delta in deltas
        if isinstance(delta, dict) and delta.get("type") == "text_delta"
    )
|
|
||||||
|
|
||||||
|
|
||||||
def thinking_content(events: list[SSEEvent]) -> str:
    """Concatenate the ``thinking`` of every ``thinking_delta`` found in *events*.

    An event counts only when its ``data["delta"]`` is a dict whose ``type``
    is ``"thinking_delta"``; a missing ``thinking`` contributes an empty
    string and non-string values are converted with ``str``.
    """
    deltas = (event.data.get("delta", {}) for event in events)
    return "".join(
        str(delta.get("thinking", ""))
        for delta in deltas
        if isinstance(delta, dict) and delta.get("type") == "thinking_delta"
    )
|
|
||||||
|
|
||||||
|
|
||||||
def has_tool_use(events: list[SSEEvent]) -> bool:
    """Report whether any event carries a ``content_block`` of type ``tool_use``.

    Events whose ``content_block`` is missing or is not a dict are ignored.
    """
    blocks = (event.data.get("content_block", {}) for event in events)
    return any(
        isinstance(block, dict) and block.get("type") == "tool_use"
        for block in blocks
    )
|
|
||||||
|
|
||||||
|
|
||||||
def _event_index(event: SSEEvent) -> int:
    """Return the integer ``index`` field of *event*'s data.

    Asserts that the field is present and is an ``int``; the assertion message
    is the event's data dict.
    """
    data = event.data
    index = data.get("index")
    assert isinstance(index, int), data
    return index
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from api.app import app
|
from api.app import app
|
||||||
|
|
@ -9,7 +10,7 @@ from providers.nvidia_nim import NvidiaNimProvider
|
||||||
mock_provider = MagicMock(spec=NvidiaNimProvider)
|
mock_provider = MagicMock(spec=NvidiaNimProvider)
|
||||||
|
|
||||||
# Track stream_response calls for test_model_mapping
|
# Track stream_response calls for test_model_mapping
|
||||||
_stream_response_calls = []
|
_stream_response_calls: list = []
|
||||||
|
|
||||||
|
|
||||||
async def _mock_stream_response(*args, **kwargs):
|
async def _mock_stream_response(*args, **kwargs):
|
||||||
|
|
@ -21,26 +22,30 @@ async def _mock_stream_response(*args, **kwargs):
|
||||||
|
|
||||||
mock_provider.stream_response = _mock_stream_response
|
mock_provider.stream_response = _mock_stream_response
|
||||||
|
|
||||||
# Patch get_provider_for_type to always return mock_provider
|
|
||||||
_patcher = patch("api.routes.get_provider_for_type", return_value=mock_provider)
|
|
||||||
_patcher.start()
|
|
||||||
|
|
||||||
client = TestClient(app)
|
@pytest.fixture(scope="module")
def client():
    """Yield a ``TestClient`` for ``app`` with provider resolution stubbed out.

    ``api.dependencies.resolve_provider`` is patched to return the module's
    ``mock_provider`` for as long as the fixture is active; the patch is
    entered before the client is built and undone after the client closes.
    """
    provider_patch = patch(
        "api.dependencies.resolve_provider", return_value=mock_provider
    )
    with provider_patch, TestClient(app) as test_client:
        yield test_client
|
||||||
|
|
||||||
|
|
||||||
def test_root(client: TestClient):
    """GET / answers 200 with a JSON body whose ``status`` is ``ok``."""
    resp = client.get("/")
    assert resp.status_code == 200
    body = resp.json()
    assert body["status"] == "ok"
|
||||||
|
|
||||||
|
|
||||||
def test_health(client: TestClient):
    """GET /health answers 200 with a JSON body whose ``status`` is ``healthy``."""
    resp = client.get("/health")
    assert resp.status_code == 200
    body = resp.json()
    assert body["status"] == "healthy"
|
||||||
|
|
||||||
|
|
||||||
def test_models_list():
|
def test_models_list(client: TestClient):
|
||||||
response = client.get("/v1/models")
|
response = client.get("/v1/models")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
@ -51,7 +56,7 @@ def test_models_list():
|
||||||
assert data["last_id"] == ids[-1]
|
assert data["last_id"] == ids[-1]
|
||||||
|
|
||||||
|
|
||||||
def test_probe_endpoints_return_204_with_allow_headers():
|
def test_probe_endpoints_return_204_with_allow_headers(client: TestClient):
|
||||||
responses = [
|
responses = [
|
||||||
client.head("/"),
|
client.head("/"),
|
||||||
client.options("/"),
|
client.options("/"),
|
||||||
|
|
@ -68,7 +73,7 @@ def test_probe_endpoints_return_204_with_allow_headers():
|
||||||
assert "Allow" in response.headers
|
assert "Allow" in response.headers
|
||||||
|
|
||||||
|
|
||||||
def test_create_message_stream():
|
def test_create_message_stream(client: TestClient):
|
||||||
"""Create message returns streaming response."""
|
"""Create message returns streaming response."""
|
||||||
payload = {
|
payload = {
|
||||||
"model": "claude-3-sonnet",
|
"model": "claude-3-sonnet",
|
||||||
|
|
@ -83,7 +88,7 @@ def test_create_message_stream():
|
||||||
assert b"message_start" in content or b"event:" in content
|
assert b"message_start" in content or b"event:" in content
|
||||||
|
|
||||||
|
|
||||||
def test_model_mapping():
|
def test_model_mapping(client: TestClient):
|
||||||
# Test Haiku mapping
|
# Test Haiku mapping
|
||||||
_stream_response_calls.clear()
|
_stream_response_calls.clear()
|
||||||
payload_haiku = {
|
payload_haiku = {
|
||||||
|
|
@ -98,7 +103,7 @@ def test_model_mapping():
|
||||||
assert args[0].model != "claude-3-haiku-20240307"
|
assert args[0].model != "claude-3-haiku-20240307"
|
||||||
|
|
||||||
|
|
||||||
def test_error_fallbacks():
|
def test_error_fallbacks(client: TestClient):
|
||||||
from providers.exceptions import (
|
from providers.exceptions import (
|
||||||
AuthenticationError,
|
AuthenticationError,
|
||||||
OverloadedError,
|
OverloadedError,
|
||||||
|
|
@ -143,7 +148,7 @@ def test_error_fallbacks():
|
||||||
mock_provider.stream_response = _mock_stream_response
|
mock_provider.stream_response = _mock_stream_response
|
||||||
|
|
||||||
|
|
||||||
def test_generic_exception_returns_500():
|
def test_generic_exception_returns_500(client: TestClient):
|
||||||
"""Non-ProviderError exceptions are caught and returned as HTTPException(500)."""
|
"""Non-ProviderError exceptions are caught and returned as HTTPException(500)."""
|
||||||
|
|
||||||
def _raise_runtime(*args, **kwargs):
|
def _raise_runtime(*args, **kwargs):
|
||||||
|
|
@ -163,7 +168,7 @@ def test_generic_exception_returns_500():
|
||||||
mock_provider.stream_response = _mock_stream_response
|
mock_provider.stream_response = _mock_stream_response
|
||||||
|
|
||||||
|
|
||||||
def test_generic_exception_with_status_code():
|
def test_generic_exception_with_status_code(client: TestClient):
|
||||||
"""Generic exception with status_code attribute uses that status (getattr fallback)."""
|
"""Generic exception with status_code attribute uses that status (getattr fallback)."""
|
||||||
|
|
||||||
class ExceptionWithStatus(RuntimeError):
|
class ExceptionWithStatus(RuntimeError):
|
||||||
|
|
@ -188,7 +193,7 @@ def test_generic_exception_with_status_code():
|
||||||
mock_provider.stream_response = _mock_stream_response
|
mock_provider.stream_response = _mock_stream_response
|
||||||
|
|
||||||
|
|
||||||
def test_generic_exception_empty_message_returns_non_empty_detail():
|
def test_generic_exception_empty_message_returns_non_empty_detail(client: TestClient):
|
||||||
"""Exceptions with empty __str__ still return a readable HTTP detail."""
|
"""Exceptions with empty __str__ still return a readable HTTP detail."""
|
||||||
|
|
||||||
class SilentError(RuntimeError):
|
class SilentError(RuntimeError):
|
||||||
|
|
@ -213,7 +218,7 @@ def test_generic_exception_empty_message_returns_non_empty_detail():
|
||||||
mock_provider.stream_response = _mock_stream_response
|
mock_provider.stream_response = _mock_stream_response
|
||||||
|
|
||||||
|
|
||||||
def test_count_tokens_endpoint():
|
def test_count_tokens_endpoint(client: TestClient):
|
||||||
"""count_tokens endpoint returns token count."""
|
"""count_tokens endpoint returns token count."""
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/messages/count_tokens",
|
"/v1/messages/count_tokens",
|
||||||
|
|
@ -223,7 +228,7 @@ def test_count_tokens_endpoint():
|
||||||
assert "input_tokens" in response.json()
|
assert "input_tokens" in response.json()
|
||||||
|
|
||||||
|
|
||||||
def test_stop_endpoint_no_handler_no_cli_503():
|
def test_stop_endpoint_no_handler_no_cli_503(client: TestClient):
|
||||||
"""POST /stop without handler or cli_manager returns 503."""
|
"""POST /stop without handler or cli_manager returns 503."""
|
||||||
# Ensure no handler or cli_manager on app state
|
# Ensure no handler or cli_manager on app state
|
||||||
if hasattr(app.state, "message_handler"):
|
if hasattr(app.state, "message_handler"):
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,23 @@ import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from config.settings import Settings
|
from config.settings import Settings
|
||||||
|
from providers.registry import ProviderRegistry
|
||||||
|
|
||||||
|
# Baseline attributes merged into every settings stand-in built by
# ``_app_settings``; individual tests override or extend them via keyword
# arguments.
_RUNTIME_EXTRAS = {
    "voice_note_enabled": True,
    "whisper_model": "base",
    "whisper_device": "cpu",
    "hf_token": "",
    "nvidia_nim_api_key": "",
    "claude_cli_bin": "claude",
    "uses_process_anthropic_auth_token": lambda: False,
}


def _app_settings(**kwargs):
    """Build a ``SimpleNamespace`` settings stand-in for the app tests.

    Starts from ``_RUNTIME_EXTRAS`` and applies *kwargs* on top, so a keyword
    with the same name as a baseline entry replaces it.
    """
    merged = dict(_RUNTIME_EXTRAS)
    merged.update(kwargs)
    return SimpleNamespace(**merged)
|
||||||
|
|
||||||
|
|
||||||
def test_warn_if_process_auth_token_logs_warning():
|
def test_warn_if_process_auth_token_logs_warning():
|
||||||
|
|
@ -45,7 +62,7 @@ def test_create_app_provider_error_handler_returns_anthropic_format():
|
||||||
raise AuthenticationError("bad key")
|
raise AuthenticationError("bad key")
|
||||||
|
|
||||||
api_app_mod = importlib.import_module("api.app")
|
api_app_mod = importlib.import_module("api.app")
|
||||||
settings = SimpleNamespace(
|
settings = _app_settings(
|
||||||
messaging_platform="telegram",
|
messaging_platform="telegram",
|
||||||
telegram_bot_token=None,
|
telegram_bot_token=None,
|
||||||
allowed_telegram_user_id=None,
|
allowed_telegram_user_id=None,
|
||||||
|
|
@ -59,7 +76,7 @@ def test_create_app_provider_error_handler_returns_anthropic_format():
|
||||||
)
|
)
|
||||||
with (
|
with (
|
||||||
patch.object(api_app_mod, "get_settings", return_value=settings),
|
patch.object(api_app_mod, "get_settings", return_value=settings),
|
||||||
patch.object(api_app_mod, "cleanup_provider", new=AsyncMock()),
|
patch.object(ProviderRegistry, "cleanup", new=AsyncMock()),
|
||||||
):
|
):
|
||||||
with TestClient(app) as client:
|
with TestClient(app) as client:
|
||||||
resp = client.get("/raise_provider")
|
resp = client.get("/raise_provider")
|
||||||
|
|
@ -79,7 +96,7 @@ def test_create_app_general_exception_handler_returns_500():
|
||||||
raise RuntimeError("boom")
|
raise RuntimeError("boom")
|
||||||
|
|
||||||
api_app_mod = importlib.import_module("api.app")
|
api_app_mod = importlib.import_module("api.app")
|
||||||
settings = SimpleNamespace(
|
settings = _app_settings(
|
||||||
messaging_platform="telegram",
|
messaging_platform="telegram",
|
||||||
telegram_bot_token=None,
|
telegram_bot_token=None,
|
||||||
allowed_telegram_user_id=None,
|
allowed_telegram_user_id=None,
|
||||||
|
|
@ -93,7 +110,7 @@ def test_create_app_general_exception_handler_returns_500():
|
||||||
)
|
)
|
||||||
with (
|
with (
|
||||||
patch.object(api_app_mod, "get_settings", return_value=settings),
|
patch.object(api_app_mod, "get_settings", return_value=settings),
|
||||||
patch.object(api_app_mod, "cleanup_provider", new=AsyncMock()),
|
patch.object(ProviderRegistry, "cleanup", new=AsyncMock()),
|
||||||
):
|
):
|
||||||
with TestClient(app, raise_server_exceptions=False) as client:
|
with TestClient(app, raise_server_exceptions=False) as client:
|
||||||
resp = client.get("/raise_general")
|
resp = client.get("/raise_general")
|
||||||
|
|
@ -111,7 +128,7 @@ def test_app_lifespan_sets_state_and_cleans_up(tmp_path, messaging_enabled):
|
||||||
|
|
||||||
app = create_app()
|
app = create_app()
|
||||||
|
|
||||||
settings = SimpleNamespace(
|
settings = _app_settings(
|
||||||
messaging_platform="telegram",
|
messaging_platform="telegram",
|
||||||
telegram_bot_token="token" if messaging_enabled else None,
|
telegram_bot_token="token" if messaging_enabled else None,
|
||||||
allowed_telegram_user_id="123",
|
allowed_telegram_user_id="123",
|
||||||
|
|
@ -147,10 +164,10 @@ def test_app_lifespan_sets_state_and_cleans_up(tmp_path, messaging_enabled):
|
||||||
|
|
||||||
api_app_mod = importlib.import_module("api.app")
|
api_app_mod = importlib.import_module("api.app")
|
||||||
|
|
||||||
cleanup_provider = AsyncMock()
|
registry_cleanup = AsyncMock()
|
||||||
with (
|
with (
|
||||||
patch.object(api_app_mod, "get_settings", return_value=settings),
|
patch.object(api_app_mod, "get_settings", return_value=settings),
|
||||||
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
|
patch.object(ProviderRegistry, "cleanup", new=registry_cleanup),
|
||||||
patch(
|
patch(
|
||||||
"messaging.platforms.factory.create_messaging_platform",
|
"messaging.platforms.factory.create_messaging_platform",
|
||||||
return_value=fake_platform if messaging_enabled else None,
|
return_value=fake_platform if messaging_enabled else None,
|
||||||
|
|
@ -182,7 +199,7 @@ def test_app_lifespan_sets_state_and_cleans_up(tmp_path, messaging_enabled):
|
||||||
cli_manager.stop_all.assert_not_awaited()
|
cli_manager.stop_all.assert_not_awaited()
|
||||||
assert getattr(app.state, "messaging_platform", "missing") is None
|
assert getattr(app.state, "messaging_platform", "missing") is None
|
||||||
|
|
||||||
cleanup_provider.assert_awaited_once()
|
registry_cleanup.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
def test_app_lifespan_cleanup_continues_if_platform_stop_raises(tmp_path):
|
def test_app_lifespan_cleanup_continues_if_platform_stop_raises(tmp_path):
|
||||||
|
|
@ -190,7 +207,7 @@ def test_app_lifespan_cleanup_continues_if_platform_stop_raises(tmp_path):
|
||||||
|
|
||||||
app = create_app()
|
app = create_app()
|
||||||
|
|
||||||
settings = SimpleNamespace(
|
settings = _app_settings(
|
||||||
messaging_platform="telegram",
|
messaging_platform="telegram",
|
||||||
telegram_bot_token="token",
|
telegram_bot_token="token",
|
||||||
allowed_telegram_user_id="123",
|
allowed_telegram_user_id="123",
|
||||||
|
|
@ -218,10 +235,10 @@ def test_app_lifespan_cleanup_continues_if_platform_stop_raises(tmp_path):
|
||||||
cli_manager.stop_all = AsyncMock()
|
cli_manager.stop_all = AsyncMock()
|
||||||
|
|
||||||
api_app_mod = importlib.import_module("api.app")
|
api_app_mod = importlib.import_module("api.app")
|
||||||
cleanup_provider = AsyncMock()
|
registry_cleanup = AsyncMock()
|
||||||
with (
|
with (
|
||||||
patch.object(api_app_mod, "get_settings", return_value=settings),
|
patch.object(api_app_mod, "get_settings", return_value=settings),
|
||||||
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
|
patch.object(ProviderRegistry, "cleanup", new=registry_cleanup),
|
||||||
patch(
|
patch(
|
||||||
"messaging.platforms.factory.create_messaging_platform",
|
"messaging.platforms.factory.create_messaging_platform",
|
||||||
return_value=fake_platform,
|
return_value=fake_platform,
|
||||||
|
|
@ -234,7 +251,7 @@ def test_app_lifespan_cleanup_continues_if_platform_stop_raises(tmp_path):
|
||||||
|
|
||||||
fake_platform.stop.assert_awaited_once()
|
fake_platform.stop.assert_awaited_once()
|
||||||
cli_manager.stop_all.assert_awaited_once()
|
cli_manager.stop_all.assert_awaited_once()
|
||||||
cleanup_provider.assert_awaited_once()
|
registry_cleanup.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
def test_app_lifespan_messaging_import_error_no_crash(tmp_path, caplog):
|
def test_app_lifespan_messaging_import_error_no_crash(tmp_path, caplog):
|
||||||
|
|
@ -243,7 +260,7 @@ def test_app_lifespan_messaging_import_error_no_crash(tmp_path, caplog):
|
||||||
|
|
||||||
app = create_app()
|
app = create_app()
|
||||||
|
|
||||||
settings = SimpleNamespace(
|
settings = _app_settings(
|
||||||
messaging_platform="telegram",
|
messaging_platform="telegram",
|
||||||
telegram_bot_token="token",
|
telegram_bot_token="token",
|
||||||
allowed_telegram_user_id="123",
|
allowed_telegram_user_id="123",
|
||||||
|
|
@ -257,10 +274,10 @@ def test_app_lifespan_messaging_import_error_no_crash(tmp_path, caplog):
|
||||||
)
|
)
|
||||||
|
|
||||||
api_app_mod = importlib.import_module("api.app")
|
api_app_mod = importlib.import_module("api.app")
|
||||||
cleanup_provider = AsyncMock()
|
registry_cleanup = AsyncMock()
|
||||||
with (
|
with (
|
||||||
patch.object(api_app_mod, "get_settings", return_value=settings),
|
patch.object(api_app_mod, "get_settings", return_value=settings),
|
||||||
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
|
patch.object(ProviderRegistry, "cleanup", new=registry_cleanup),
|
||||||
patch(
|
patch(
|
||||||
"messaging.platforms.factory.create_messaging_platform",
|
"messaging.platforms.factory.create_messaging_platform",
|
||||||
side_effect=ImportError("discord not installed"),
|
side_effect=ImportError("discord not installed"),
|
||||||
|
|
@ -270,7 +287,7 @@ def test_app_lifespan_messaging_import_error_no_crash(tmp_path, caplog):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
assert getattr(app.state, "messaging_platform", None) is None
|
assert getattr(app.state, "messaging_platform", None) is None
|
||||||
cleanup_provider.assert_awaited_once()
|
registry_cleanup.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
def test_app_lifespan_platform_start_exception_cleanup_still_runs(tmp_path):
|
def test_app_lifespan_platform_start_exception_cleanup_still_runs(tmp_path):
|
||||||
|
|
@ -279,7 +296,7 @@ def test_app_lifespan_platform_start_exception_cleanup_still_runs(tmp_path):
|
||||||
|
|
||||||
app = create_app()
|
app = create_app()
|
||||||
|
|
||||||
settings = SimpleNamespace(
|
settings = _app_settings(
|
||||||
messaging_platform="telegram",
|
messaging_platform="telegram",
|
||||||
telegram_bot_token="token",
|
telegram_bot_token="token",
|
||||||
allowed_telegram_user_id="123",
|
allowed_telegram_user_id="123",
|
||||||
|
|
@ -307,10 +324,10 @@ def test_app_lifespan_platform_start_exception_cleanup_still_runs(tmp_path):
|
||||||
cli_manager.stop_all = AsyncMock()
|
cli_manager.stop_all = AsyncMock()
|
||||||
|
|
||||||
api_app_mod = importlib.import_module("api.app")
|
api_app_mod = importlib.import_module("api.app")
|
||||||
cleanup_provider = AsyncMock()
|
registry_cleanup = AsyncMock()
|
||||||
with (
|
with (
|
||||||
patch.object(api_app_mod, "get_settings", return_value=settings),
|
patch.object(api_app_mod, "get_settings", return_value=settings),
|
||||||
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
|
patch.object(ProviderRegistry, "cleanup", new=registry_cleanup),
|
||||||
patch(
|
patch(
|
||||||
"messaging.platforms.factory.create_messaging_platform",
|
"messaging.platforms.factory.create_messaging_platform",
|
||||||
return_value=fake_platform,
|
return_value=fake_platform,
|
||||||
|
|
@ -321,7 +338,7 @@ def test_app_lifespan_platform_start_exception_cleanup_still_runs(tmp_path):
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
cleanup_provider.assert_awaited_once()
|
registry_cleanup.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
def test_app_lifespan_flush_pending_save_exception_warning_only(tmp_path):
|
def test_app_lifespan_flush_pending_save_exception_warning_only(tmp_path):
|
||||||
|
|
@ -330,7 +347,7 @@ def test_app_lifespan_flush_pending_save_exception_warning_only(tmp_path):
|
||||||
|
|
||||||
app = create_app()
|
app = create_app()
|
||||||
|
|
||||||
settings = SimpleNamespace(
|
settings = _app_settings(
|
||||||
messaging_platform="telegram",
|
messaging_platform="telegram",
|
||||||
telegram_bot_token="token",
|
telegram_bot_token="token",
|
||||||
allowed_telegram_user_id="123",
|
allowed_telegram_user_id="123",
|
||||||
|
|
@ -359,10 +376,10 @@ def test_app_lifespan_flush_pending_save_exception_warning_only(tmp_path):
|
||||||
cli_manager.stop_all = AsyncMock()
|
cli_manager.stop_all = AsyncMock()
|
||||||
|
|
||||||
api_app_mod = importlib.import_module("api.app")
|
api_app_mod = importlib.import_module("api.app")
|
||||||
cleanup_provider = AsyncMock()
|
registry_cleanup = AsyncMock()
|
||||||
with (
|
with (
|
||||||
patch.object(api_app_mod, "get_settings", return_value=settings),
|
patch.object(api_app_mod, "get_settings", return_value=settings),
|
||||||
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
|
patch.object(ProviderRegistry, "cleanup", new=registry_cleanup),
|
||||||
patch(
|
patch(
|
||||||
"messaging.platforms.factory.create_messaging_platform",
|
"messaging.platforms.factory.create_messaging_platform",
|
||||||
return_value=fake_platform,
|
return_value=fake_platform,
|
||||||
|
|
@ -374,4 +391,4 @@ def test_app_lifespan_flush_pending_save_exception_warning_only(tmp_path):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
session_store.flush_pending_save.assert_called_once()
|
session_store.flush_pending_save.assert_called_once()
|
||||||
cleanup_provider.assert_awaited_once()
|
registry_cleanup.assert_awaited_once()
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,26 @@
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from typing import cast
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
from starlette.applications import Starlette
|
||||||
|
from starlette.datastructures import State
|
||||||
|
|
||||||
from api.dependencies import (
|
from api.dependencies import (
|
||||||
cleanup_provider,
|
cleanup_provider,
|
||||||
get_provider,
|
get_provider,
|
||||||
get_provider_for_type,
|
get_provider_for_type,
|
||||||
get_settings,
|
get_settings,
|
||||||
|
resolve_provider,
|
||||||
)
|
)
|
||||||
from config.nim import NimSettings
|
from config.nim import NimSettings
|
||||||
from providers.deepseek import DeepSeekProvider
|
from providers.deepseek import DeepSeekProvider
|
||||||
|
from providers.exceptions import UnknownProviderTypeError
|
||||||
from providers.lmstudio import LMStudioProvider
|
from providers.lmstudio import LMStudioProvider
|
||||||
from providers.nvidia_nim import NvidiaNimProvider
|
from providers.nvidia_nim import NvidiaNimProvider
|
||||||
from providers.open_router import OpenRouterProvider
|
from providers.open_router import OpenRouterProvider
|
||||||
|
from providers.registry import ProviderRegistry
|
||||||
|
|
||||||
|
|
||||||
def _make_mock_settings(**overrides):
|
def _make_mock_settings(**overrides):
|
||||||
|
|
@ -304,11 +311,11 @@ async def test_get_provider_deepseek_missing_api_key():
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_provider_unknown_type():
|
async def test_get_provider_unknown_type():
|
||||||
"""Test that unknown provider_type raises ValueError."""
|
"""Unknown ``provider_type`` raises :exc:`~providers.exceptions.UnknownProviderTypeError`."""
|
||||||
with patch("api.dependencies.get_settings") as mock_settings:
|
with patch("api.dependencies.get_settings") as mock_settings:
|
||||||
mock_settings.return_value = _make_mock_settings(provider_type="unknown")
|
mock_settings.return_value = _make_mock_settings(provider_type="unknown")
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Unknown provider_type"):
|
with pytest.raises(UnknownProviderTypeError, match="Unknown provider_type"):
|
||||||
get_provider()
|
get_provider()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -390,3 +397,55 @@ async def test_cleanup_provider_cleans_all():
|
||||||
|
|
||||||
nim._client.aclose.assert_called_once()
|
nim._client.aclose.assert_called_once()
|
||||||
lmstudio._client.aclose.assert_called_once()
|
lmstudio._client.aclose.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_provider_per_app_uses_separate_registries() -> None:
|
||||||
|
"""With app set, each app gets its own provider cache (not process _providers)."""
|
||||||
|
with patch("api.dependencies.get_settings") as mock_settings:
|
||||||
|
mock_settings.return_value = _make_mock_settings()
|
||||||
|
settings = _make_mock_settings()
|
||||||
|
app1 = SimpleNamespace(state=State())
|
||||||
|
app2 = SimpleNamespace(state=State())
|
||||||
|
app1.state.provider_registry = ProviderRegistry()
|
||||||
|
app2.state.provider_registry = ProviderRegistry()
|
||||||
|
p1 = resolve_provider(
|
||||||
|
"nvidia_nim", app=cast(Starlette, app1), settings=settings
|
||||||
|
)
|
||||||
|
p2 = resolve_provider(
|
||||||
|
"nvidia_nim", app=cast(Starlette, app2), settings=settings
|
||||||
|
)
|
||||||
|
assert isinstance(p1, NvidiaNimProvider)
|
||||||
|
assert isinstance(p2, NvidiaNimProvider)
|
||||||
|
assert p1 is not p2
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_provider_lazily_installs_registry() -> None:
|
||||||
|
"""If app has no provider_registry, one is created on app.state."""
|
||||||
|
with patch("api.dependencies.get_settings") as mock_settings:
|
||||||
|
mock_settings.return_value = _make_mock_settings()
|
||||||
|
settings = _make_mock_settings()
|
||||||
|
app = SimpleNamespace(state=State())
|
||||||
|
assert getattr(app.state, "provider_registry", None) is None
|
||||||
|
resolve_provider("nvidia_nim", app=cast(Starlette, app), settings=settings)
|
||||||
|
reg = app.state.provider_registry
|
||||||
|
assert reg is not None
|
||||||
|
p2 = resolve_provider("nvidia_nim", app=cast(Starlette, app), settings=settings)
|
||||||
|
assert p2 is reg.get("nvidia_nim", settings) # same registry instance
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_provider_unrelated_value_error_is_not_unknown_provider_log() -> None:
|
||||||
|
"""Only :exc:`~providers.exceptions.UnknownProviderTypeError` logs unknown provider."""
|
||||||
|
import api.dependencies as deps
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(deps, "get_settings", return_value=_make_mock_settings()),
|
||||||
|
patch.object(
|
||||||
|
ProviderRegistry,
|
||||||
|
"get",
|
||||||
|
side_effect=ValueError("unrelated config"),
|
||||||
|
),
|
||||||
|
patch.object(deps.logger, "error") as log_err,
|
||||||
|
pytest.raises(ValueError, match="unrelated config"),
|
||||||
|
):
|
||||||
|
deps.resolve_provider("nvidia_nim", app=None, settings=_make_mock_settings())
|
||||||
|
log_err.assert_not_called()
|
||||||
|
|
|
||||||
|
|
@ -91,6 +91,21 @@ def test_messages_request_accepts_adaptive_thinking_type():
|
||||||
assert dumped["thinking"]["type"] == "adaptive"
|
assert dumped["thinking"]["type"] == "adaptive"
|
||||||
|
|
||||||
|
|
||||||
|
def test_messages_request_accepts_anthropic_server_tool_without_input_schema():
|
||||||
|
request = MessagesRequest.model_validate(
|
||||||
|
{
|
||||||
|
"model": "claude-opus-4-7",
|
||||||
|
"max_tokens": 100,
|
||||||
|
"messages": [{"role": "user", "content": "search"}],
|
||||||
|
"tools": [{"type": "web_search_20250305", "name": "web_search"}],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
dumped = request.model_dump(exclude_none=True)
|
||||||
|
|
||||||
|
assert dumped["tools"] == [{"name": "web_search", "type": "web_search_20250305"}]
|
||||||
|
|
||||||
|
|
||||||
def test_messages_request_accepts_redacted_thinking_blocks():
|
def test_messages_request_accepts_redacted_thinking_blocks():
|
||||||
request = MessagesRequest.model_validate(
|
request = MessagesRequest.model_validate(
|
||||||
{
|
{
|
||||||
|
|
|
||||||
96
tests/api/test_web_server_tools.py
Normal file
96
tests/api/test_web_server_tools.py
Normal file
|
|
@ -0,0 +1,96 @@
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from api.models.anthropic import Message, MessagesRequest, Tool
|
||||||
|
from api.web_server_tools import (
|
||||||
|
is_web_server_tool_request,
|
||||||
|
stream_web_server_tool_response,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _event_data(event: str) -> dict:
|
||||||
|
data_line = next(line for line in event.splitlines() if line.startswith("data: "))
|
||||||
|
return json.loads(data_line.removeprefix("data: "))
|
||||||
|
|
||||||
|
|
||||||
|
def test_detects_web_search_server_tool_request():
|
||||||
|
request = MessagesRequest(
|
||||||
|
model="claude-haiku-4-5-20251001",
|
||||||
|
max_tokens=100,
|
||||||
|
messages=[Message(role="user", content="search")],
|
||||||
|
tools=[Tool(name="web_search", type="web_search_20250305")],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert is_web_server_tool_request(request)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_streams_web_search_server_tool_result(monkeypatch):
|
||||||
|
async def fake_search(query: str) -> list[dict[str, str]]:
|
||||||
|
assert query == "DeepSeek V4 model release 2026"
|
||||||
|
return [{"title": "DeepSeek V4 Released", "url": "https://example.com/v4"}]
|
||||||
|
|
||||||
|
monkeypatch.setattr("api.web_server_tools._run_web_search", fake_search)
|
||||||
|
request = MessagesRequest(
|
||||||
|
model="claude-haiku-4-5-20251001",
|
||||||
|
max_tokens=100,
|
||||||
|
messages=[
|
||||||
|
Message(
|
||||||
|
role="user",
|
||||||
|
content=(
|
||||||
|
"Perform a web search for the query: DeepSeek V4 model release 2026"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
tools=[Tool(name="web_search", type="web_search_20250305")],
|
||||||
|
tool_choice={"type": "tool", "name": "web_search"},
|
||||||
|
)
|
||||||
|
|
||||||
|
events = [
|
||||||
|
event
|
||||||
|
async for event in stream_web_server_tool_response(request, input_tokens=42)
|
||||||
|
]
|
||||||
|
payloads = [_event_data(event) for event in events]
|
||||||
|
|
||||||
|
assert payloads[1]["content_block"]["type"] == "server_tool_use"
|
||||||
|
assert payloads[1]["content_block"]["name"] == "web_search"
|
||||||
|
assert payloads[3]["content_block"]["type"] == "web_search_tool_result"
|
||||||
|
assert payloads[3]["content_block"]["content"][0]["url"] == "https://example.com/v4"
|
||||||
|
assert payloads[-2]["usage"]["server_tool_use"] == {"web_search_requests": 1}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_streams_web_fetch_server_tool_result(monkeypatch):
|
||||||
|
async def fake_fetch(url: str) -> dict[str, str]:
|
||||||
|
assert url == "https://example.com/article"
|
||||||
|
return {
|
||||||
|
"url": url,
|
||||||
|
"title": "Example Article",
|
||||||
|
"media_type": "text/plain",
|
||||||
|
"data": "Article body",
|
||||||
|
}
|
||||||
|
|
||||||
|
monkeypatch.setattr("api.web_server_tools._run_web_fetch", fake_fetch)
|
||||||
|
request = MessagesRequest(
|
||||||
|
model="claude-haiku-4-5-20251001",
|
||||||
|
max_tokens=100,
|
||||||
|
messages=[
|
||||||
|
Message(role="user", content="Fetch https://example.com/article please")
|
||||||
|
],
|
||||||
|
tools=[Tool(name="web_fetch", type="web_fetch_20250910")],
|
||||||
|
tool_choice={"type": "tool", "name": "web_fetch"},
|
||||||
|
)
|
||||||
|
|
||||||
|
events = [
|
||||||
|
event
|
||||||
|
async for event in stream_web_server_tool_response(request, input_tokens=42)
|
||||||
|
]
|
||||||
|
payloads = [_event_data(event) for event in events]
|
||||||
|
|
||||||
|
assert payloads[1]["content_block"]["type"] == "server_tool_use"
|
||||||
|
assert payloads[3]["content_block"]["type"] == "web_fetch_tool_result"
|
||||||
|
assert payloads[3]["content_block"]["content"]["content"]["title"] == (
|
||||||
|
"Example Article"
|
||||||
|
)
|
||||||
|
assert payloads[-2]["usage"]["server_tool_use"] == {"web_fetch_requests": 1}
|
||||||
|
|
@ -118,6 +118,15 @@ def mock_platform():
|
||||||
platform.queue_edit_message = AsyncMock()
|
platform.queue_edit_message = AsyncMock()
|
||||||
platform.queue_delete_message = AsyncMock()
|
platform.queue_delete_message = AsyncMock()
|
||||||
|
|
||||||
|
async def _queue_delete_messages(
|
||||||
|
chat_id: str, message_ids: list[str], *, fire_and_forget: bool = True
|
||||||
|
) -> None:
|
||||||
|
qdm = platform.queue_delete_message
|
||||||
|
for mid in message_ids:
|
||||||
|
await qdm(chat_id, mid, fire_and_forget=fire_and_forget)
|
||||||
|
|
||||||
|
platform.queue_delete_messages = AsyncMock(side_effect=_queue_delete_messages)
|
||||||
|
|
||||||
def _fire_and_forget(task):
|
def _fire_and_forget(task):
|
||||||
if asyncio.iscoroutine(task):
|
if asyncio.iscoroutine(task):
|
||||||
# Create a task to avoid "coroutine was never awaited" warning
|
# Create a task to avoid "coroutine was never awaited" warning
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,20 @@
|
||||||
|
"""Package import contract tests (static AST; dynamic ``importlib`` loads are not scanned)."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import ast
|
import ast
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
# `api` may only import this narrow ``providers`` surface (AGENTS/PLAN).
|
||||||
|
_API_ALLOWED_PROVIDER_MODULES = frozenset(
|
||||||
|
{
|
||||||
|
"providers",
|
||||||
|
"providers.base",
|
||||||
|
"providers.exceptions",
|
||||||
|
"providers.registry",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_api_and_messaging_do_not_import_provider_common() -> None:
|
def test_api_and_messaging_do_not_import_provider_common() -> None:
|
||||||
repo_root = Path(__file__).resolve().parents[2]
|
repo_root = Path(__file__).resolve().parents[2]
|
||||||
|
|
@ -25,6 +37,66 @@ def test_provider_adapters_do_not_import_runtime_layers() -> None:
|
||||||
assert offenders == []
|
assert offenders == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_core_does_not_import_product_packages() -> None:
|
||||||
|
"""Neutral ``core`` must stay independent of API, workers, and providers."""
|
||||||
|
repo_root = Path(__file__).resolve().parents[2]
|
||||||
|
offenders = _imports_matching(
|
||||||
|
[repo_root / "core"],
|
||||||
|
forbidden_prefixes=(
|
||||||
|
"api.",
|
||||||
|
"messaging.",
|
||||||
|
"cli.",
|
||||||
|
"smoke.",
|
||||||
|
"providers.",
|
||||||
|
"config.",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert offenders == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_does_not_import_non_config_packages() -> None:
|
||||||
|
"""Settings and env handling must not depend on transport or protocol layers."""
|
||||||
|
repo_root = Path(__file__).resolve().parents[2]
|
||||||
|
offenders = _imports_matching(
|
||||||
|
[repo_root / "config"],
|
||||||
|
forbidden_prefixes=(
|
||||||
|
"api.",
|
||||||
|
"messaging.",
|
||||||
|
"cli.",
|
||||||
|
"smoke.",
|
||||||
|
"providers.",
|
||||||
|
"core.",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert offenders == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_messaging_does_not_import_api_or_cli_or_providers() -> None:
|
||||||
|
"""Messaging is wired by ``api.runtime``; must not import server or provider adapters."""
|
||||||
|
repo_root = Path(__file__).resolve().parents[2]
|
||||||
|
offenders = _imports_matching(
|
||||||
|
[repo_root / "messaging"],
|
||||||
|
forbidden_prefixes=("api.", "cli.", "providers.", "smoke."),
|
||||||
|
)
|
||||||
|
assert offenders == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_may_only_import_narrow_provider_facade() -> None:
|
||||||
|
"""HTTP layer must not depend on per-adapter provider subpackages."""
|
||||||
|
repo_root = Path(__file__).resolve().parents[2]
|
||||||
|
offenders: list[str] = []
|
||||||
|
for path in (repo_root / "api").rglob("*.py"):
|
||||||
|
for imported in _imports_from(path, repo_root):
|
||||||
|
if imported is None or not imported.startswith("providers"):
|
||||||
|
continue
|
||||||
|
if imported in _API_ALLOWED_PROVIDER_MODULES:
|
||||||
|
continue
|
||||||
|
if imported.startswith("providers."):
|
||||||
|
rel = path.relative_to(repo_root)
|
||||||
|
offenders.append(f"{rel}: {imported}")
|
||||||
|
assert sorted(offenders) == []
|
||||||
|
|
||||||
|
|
||||||
def test_removed_openrouter_rollback_transport_stays_removed() -> None:
|
def test_removed_openrouter_rollback_transport_stays_removed() -> None:
|
||||||
repo_root = Path(__file__).resolve().parents[2]
|
repo_root = Path(__file__).resolve().parents[2]
|
||||||
|
|
||||||
|
|
@ -35,6 +107,11 @@ def test_removed_openrouter_rollback_transport_stays_removed() -> None:
|
||||||
|
|
||||||
def test_architecture_doc_names_enforced_boundaries() -> None:
|
def test_architecture_doc_names_enforced_boundaries() -> None:
|
||||||
repo_root = Path(__file__).resolve().parents[2]
|
repo_root = Path(__file__).resolve().parents[2]
|
||||||
|
contract_test = repo_root / "tests" / "contracts" / "test_import_boundaries.py"
|
||||||
|
assert contract_test.is_file()
|
||||||
|
stream_contracts = repo_root / "core" / "anthropic" / "stream_contracts.py"
|
||||||
|
assert stream_contracts.is_file()
|
||||||
|
|
||||||
text = (repo_root / "PLAN.md").read_text(encoding="utf-8")
|
text = (repo_root / "PLAN.md").read_text(encoding="utf-8")
|
||||||
|
|
||||||
assert "core/anthropic/" in text
|
assert "core/anthropic/" in text
|
||||||
|
|
@ -46,26 +123,89 @@ def _imports_matching(
|
||||||
roots: list[Path], *, forbidden_prefixes: tuple[str, ...]
|
roots: list[Path], *, forbidden_prefixes: tuple[str, ...]
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
offenders: list[str] = []
|
offenders: list[str] = []
|
||||||
|
repo_root = roots[0].parent
|
||||||
for root in roots:
|
for root in roots:
|
||||||
for path in root.rglob("*.py"):
|
for path in root.rglob("*.py"):
|
||||||
rel = path.relative_to(root.parent)
|
rel = path.relative_to(root.parent)
|
||||||
offenders.extend(
|
offenders.extend(
|
||||||
f"{rel}: {imported}"
|
f"{rel}: {imported}"
|
||||||
for imported in _imports_from(path)
|
for imported in _imports_from(path, repo_root)
|
||||||
if imported in forbidden_prefixes
|
if imported is not None and _is_forbidden(imported, forbidden_prefixes)
|
||||||
or imported.startswith(forbidden_prefixes)
|
|
||||||
)
|
)
|
||||||
return sorted(offenders)
|
return sorted(offenders)
|
||||||
|
|
||||||
|
|
||||||
def _imports_from(path: Path) -> list[str]:
|
def _is_forbidden(name: str, forbidden: tuple[str, ...]) -> bool:
|
||||||
|
"""Match root modules (``import api``) and submodules (``import api.x``)."""
|
||||||
|
for token in forbidden:
|
||||||
|
if not token:
|
||||||
|
continue
|
||||||
|
root = token.rstrip(".")
|
||||||
|
if name == root or name.startswith(f"{root}."):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _module_fqn_from_path(repo_root: Path, path: Path) -> str:
|
||||||
|
rel = path.relative_to(repo_root)
|
||||||
|
if rel.name == "__init__.py":
|
||||||
|
return ".".join(rel.parent.parts) if rel.parent != Path() else rel.parent.name
|
||||||
|
return ".".join(rel.with_suffix("").parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _importing_package_parts(repo_root: Path, path: Path) -> list[str]:
|
||||||
|
"""Package in which this file's module lives (for relative imports)."""
|
||||||
|
rel = path.relative_to(repo_root)
|
||||||
|
if rel.name == "__init__.py":
|
||||||
|
return list(rel.parent.parts)
|
||||||
|
fqn = _module_fqn_from_path(repo_root, path)
|
||||||
|
parts = fqn.split(".")
|
||||||
|
if len(parts) <= 1:
|
||||||
|
return []
|
||||||
|
return parts[:-1]
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_relative_import(
|
||||||
|
repo_root: Path, path: Path, node: ast.ImportFrom
|
||||||
|
) -> str | None:
|
||||||
|
"""Best-effort absolute name for ``from .x`` / ``from ..y`` (level >= 1)."""
|
||||||
|
if node.level == 0 and node.module:
|
||||||
|
return node.module
|
||||||
|
base = _importing_package_parts(repo_root, path)
|
||||||
|
for _ in range(node.level - 1):
|
||||||
|
if not base:
|
||||||
|
return None
|
||||||
|
base.pop()
|
||||||
|
if not node.module:
|
||||||
|
return ".".join(base) if base else None
|
||||||
|
return ".".join(base + node.module.split("."))
|
||||||
|
|
||||||
|
|
||||||
|
def _imports_from(path: Path, repo_root: Path) -> list[str]:
|
||||||
tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
|
tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
|
||||||
imports: list[str] = []
|
imports: list[str] = []
|
||||||
for node in ast.walk(tree):
|
for node in ast.walk(tree):
|
||||||
if isinstance(node, ast.Import):
|
if isinstance(node, ast.Import):
|
||||||
imports.extend(alias.name for alias in node.names)
|
imports.extend(alias.name for alias in node.names)
|
||||||
elif isinstance(node, ast.ImportFrom) and node.module:
|
elif isinstance(node, ast.ImportFrom):
|
||||||
|
if node.level == 0:
|
||||||
|
if node.module:
|
||||||
imports.append(node.module)
|
imports.append(node.module)
|
||||||
|
continue
|
||||||
|
if node.module is not None:
|
||||||
|
resolved = _resolve_relative_import(repo_root, path, node)
|
||||||
|
if resolved:
|
||||||
|
imports.append(resolved)
|
||||||
|
else:
|
||||||
|
base = _importing_package_parts(repo_root, path).copy()
|
||||||
|
for _ in range(node.level - 1):
|
||||||
|
if base:
|
||||||
|
base.pop()
|
||||||
|
for alias in node.names:
|
||||||
|
if base:
|
||||||
|
imports.append(".".join([*base, alias.name]))
|
||||||
|
else:
|
||||||
|
imports.append(alias.name)
|
||||||
return imports
|
return imports
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
11
tests/contracts/test_smoke_sse_reexport.py
Normal file
11
tests/contracts/test_smoke_sse_reexport.py
Normal file
|
|
@ -0,0 +1,11 @@
|
||||||
|
"""Ensure smoke re-exports stay aligned with :mod:`core.anthropic.stream_contracts`."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import core.anthropic.stream_contracts as core_sc
|
||||||
|
import smoke.lib.sse as smoke_sse
|
||||||
|
|
||||||
|
|
||||||
|
def test_smoke_lib_sse_reexports_core_stream_contracts() -> None:
|
||||||
|
for name in smoke_sse.__all__:
|
||||||
|
assert getattr(smoke_sse, name) is getattr(core_sc, name)
|
||||||
|
|
@ -1,11 +1,14 @@
|
||||||
|
"""Stream/SSE contract tests. Strict transcript *ordering* is covered here for
|
||||||
|
``SSEBuilder`` output; for transport-integrated ordering, add messaging or API
|
||||||
|
integration tests.
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
|
|
||||||
from core.anthropic import ContentType, HeuristicToolParser, SSEBuilder, ThinkTagParser
|
from core.anthropic import ContentType, HeuristicToolParser, SSEBuilder, ThinkTagParser
|
||||||
from messaging.event_parser import parse_cli_event
|
from core.anthropic.stream_contracts import (
|
||||||
from messaging.transcript import RenderCtx, TranscriptBuffer
|
|
||||||
from smoke.lib.sse import (
|
|
||||||
assert_anthropic_stream_contract,
|
assert_anthropic_stream_contract,
|
||||||
event_names,
|
event_names,
|
||||||
has_tool_use,
|
has_tool_use,
|
||||||
|
|
@ -13,6 +16,8 @@ from smoke.lib.sse import (
|
||||||
text_content,
|
text_content,
|
||||||
thinking_content,
|
thinking_content,
|
||||||
)
|
)
|
||||||
|
from messaging.event_parser import parse_cli_event
|
||||||
|
from messaging.transcript import RenderCtx, TranscriptBuffer
|
||||||
|
|
||||||
|
|
||||||
def test_interleaved_thinking_text_blocks_are_valid() -> None:
|
def test_interleaved_thinking_text_blocks_are_valid() -> None:
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,10 @@
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from messaging.platforms.factory import create_messaging_platform
|
from messaging.platforms.factory import (
|
||||||
|
MessagingPlatformOptions,
|
||||||
|
create_messaging_platform,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestCreateMessagingPlatform:
|
class TestCreateMessagingPlatform:
|
||||||
|
|
@ -16,15 +19,29 @@ class TestCreateMessagingPlatform:
|
||||||
patch(
|
patch(
|
||||||
"messaging.platforms.telegram.TelegramPlatform",
|
"messaging.platforms.telegram.TelegramPlatform",
|
||||||
return_value=mock_platform,
|
return_value=mock_platform,
|
||||||
),
|
) as platform_cls,
|
||||||
):
|
):
|
||||||
result = create_messaging_platform(
|
result = create_messaging_platform(
|
||||||
"telegram",
|
"telegram",
|
||||||
bot_token="test_token",
|
MessagingPlatformOptions(
|
||||||
allowed_user_id="12345",
|
telegram_bot_token="test_token",
|
||||||
|
allowed_telegram_user_id="12345",
|
||||||
|
voice_note_enabled=False,
|
||||||
|
whisper_model="large-v3",
|
||||||
|
whisper_device="cuda",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result is mock_platform
|
assert result is mock_platform
|
||||||
|
platform_cls.assert_called_once_with(
|
||||||
|
bot_token="test_token",
|
||||||
|
allowed_user_id="12345",
|
||||||
|
voice_note_enabled=False,
|
||||||
|
whisper_model="large-v3",
|
||||||
|
whisper_device="cuda",
|
||||||
|
hf_token="",
|
||||||
|
nvidia_nim_api_key="",
|
||||||
|
)
|
||||||
|
|
||||||
def test_telegram_without_token(self):
|
def test_telegram_without_token(self):
|
||||||
"""Return None when no bot_token for Telegram."""
|
"""Return None when no bot_token for Telegram."""
|
||||||
|
|
@ -33,7 +50,9 @@ class TestCreateMessagingPlatform:
|
||||||
|
|
||||||
def test_telegram_empty_token(self):
|
def test_telegram_empty_token(self):
|
||||||
"""Return None when bot_token is empty string."""
|
"""Return None when bot_token is empty string."""
|
||||||
result = create_messaging_platform("telegram", bot_token="")
|
result = create_messaging_platform(
|
||||||
|
"telegram", MessagingPlatformOptions(telegram_bot_token="")
|
||||||
|
)
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
def test_discord_with_token(self):
|
def test_discord_with_token(self):
|
||||||
|
|
@ -44,15 +63,29 @@ class TestCreateMessagingPlatform:
|
||||||
patch(
|
patch(
|
||||||
"messaging.platforms.discord.DiscordPlatform",
|
"messaging.platforms.discord.DiscordPlatform",
|
||||||
return_value=mock_platform,
|
return_value=mock_platform,
|
||||||
),
|
) as platform_cls,
|
||||||
):
|
):
|
||||||
result = create_messaging_platform(
|
result = create_messaging_platform(
|
||||||
"discord",
|
"discord",
|
||||||
|
MessagingPlatformOptions(
|
||||||
discord_bot_token="test_token",
|
discord_bot_token="test_token",
|
||||||
allowed_discord_channels="123,456",
|
allowed_discord_channels="123,456",
|
||||||
|
voice_note_enabled=False,
|
||||||
|
whisper_model="small",
|
||||||
|
whisper_device="nvidia_nim",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result is mock_platform
|
assert result is mock_platform
|
||||||
|
platform_cls.assert_called_once_with(
|
||||||
|
bot_token="test_token",
|
||||||
|
allowed_channel_ids="123,456",
|
||||||
|
voice_note_enabled=False,
|
||||||
|
whisper_model="small",
|
||||||
|
whisper_device="nvidia_nim",
|
||||||
|
hf_token="",
|
||||||
|
nvidia_nim_api_key="",
|
||||||
|
)
|
||||||
|
|
||||||
def test_discord_without_token(self):
|
def test_discord_without_token(self):
|
||||||
"""Return None when no discord_bot_token for Discord."""
|
"""Return None when no discord_bot_token for Discord."""
|
||||||
|
|
@ -62,7 +95,11 @@ class TestCreateMessagingPlatform:
|
||||||
def test_discord_empty_token(self):
|
def test_discord_empty_token(self):
|
||||||
"""Return None when discord_bot_token is empty string."""
|
"""Return None when discord_bot_token is empty string."""
|
||||||
result = create_messaging_platform(
|
result = create_messaging_platform(
|
||||||
"discord", discord_bot_token="", allowed_discord_channels="123"
|
"discord",
|
||||||
|
MessagingPlatformOptions(
|
||||||
|
discord_bot_token="",
|
||||||
|
allowed_discord_channels="123",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
@ -73,5 +110,7 @@ class TestCreateMessagingPlatform:
|
||||||
|
|
||||||
def test_unknown_platform_with_kwargs(self):
|
def test_unknown_platform_with_kwargs(self):
|
||||||
"""Return None for unknown platform even with kwargs."""
|
"""Return None for unknown platform even with kwargs."""
|
||||||
result = create_messaging_platform("slack", bot_token="token")
|
result = create_messaging_platform(
|
||||||
|
"slack", MessagingPlatformOptions(telegram_bot_token="token")
|
||||||
|
)
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
|
||||||
|
|
@ -17,18 +17,20 @@ def telegram_platform():
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_telegram_voice_disabled_sends_reply(telegram_platform):
|
async def test_telegram_voice_disabled_sends_reply():
|
||||||
"""When voice_note_enabled is False, reply with disabled message."""
|
"""When voice_note_enabled is False, reply with disabled message."""
|
||||||
|
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||||
|
telegram_platform = TelegramPlatform(
|
||||||
|
bot_token="test_token",
|
||||||
|
allowed_user_id="12345",
|
||||||
|
voice_note_enabled=False,
|
||||||
|
)
|
||||||
mock_update = MagicMock()
|
mock_update = MagicMock()
|
||||||
mock_update.message.voice = MagicMock(file_id="f1", mime_type="audio/ogg")
|
mock_update.message.voice = MagicMock(file_id="f1", mime_type="audio/ogg")
|
||||||
mock_update.effective_user.id = 12345
|
mock_update.effective_user.id = 12345
|
||||||
mock_update.effective_chat.id = 6789
|
mock_update.effective_chat.id = 6789
|
||||||
mock_update.message.reply_text = AsyncMock()
|
mock_update.message.reply_text = AsyncMock()
|
||||||
|
|
||||||
with patch(
|
|
||||||
"config.settings.get_settings",
|
|
||||||
return_value=MagicMock(voice_note_enabled=False),
|
|
||||||
):
|
|
||||||
await telegram_platform._on_telegram_voice(mock_update, MagicMock())
|
await telegram_platform._on_telegram_voice(mock_update, MagicMock())
|
||||||
|
|
||||||
mock_update.message.reply_text.assert_called_once_with("Voice notes are disabled.")
|
mock_update.message.reply_text.assert_called_once_with("Voice notes are disabled.")
|
||||||
|
|
@ -42,10 +44,6 @@ async def test_telegram_voice_unauthorized_ignored(telegram_platform):
|
||||||
mock_update.effective_user.id = 99999 # Not 12345
|
mock_update.effective_user.id = 99999 # Not 12345
|
||||||
mock_update.message.reply_text = AsyncMock()
|
mock_update.message.reply_text = AsyncMock()
|
||||||
|
|
||||||
with patch(
|
|
||||||
"config.settings.get_settings",
|
|
||||||
return_value=MagicMock(voice_note_enabled=True),
|
|
||||||
):
|
|
||||||
await telegram_platform._on_telegram_voice(mock_update, MagicMock())
|
await telegram_platform._on_telegram_voice(mock_update, MagicMock())
|
||||||
|
|
||||||
mock_update.message.reply_text.assert_not_called()
|
mock_update.message.reply_text.assert_not_called()
|
||||||
|
|
@ -82,17 +80,8 @@ async def test_telegram_voice_success_invokes_handler(telegram_platform):
|
||||||
|
|
||||||
mock_file.download_to_drive = fake_download
|
mock_file.download_to_drive = fake_download
|
||||||
|
|
||||||
mock_settings = MagicMock(
|
|
||||||
voice_note_enabled=True,
|
|
||||||
whisper_model="base",
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_queue_send = AsyncMock(return_value="999")
|
mock_queue_send = AsyncMock(return_value="999")
|
||||||
with (
|
with (
|
||||||
patch(
|
|
||||||
"config.settings.get_settings",
|
|
||||||
return_value=mock_settings,
|
|
||||||
),
|
|
||||||
patch(
|
patch(
|
||||||
"messaging.transcription.transcribe_audio",
|
"messaging.transcription.transcribe_audio",
|
||||||
return_value="Hello from voice",
|
return_value="Hello from voice",
|
||||||
|
|
@ -164,7 +153,11 @@ class TestDiscordGetAudioAttachment:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_discord_voice_disabled_sends_reply():
|
async def test_discord_voice_disabled_sends_reply():
|
||||||
"""When voice_note_enabled is False, reply with disabled message."""
|
"""When voice_note_enabled is False, reply with disabled message."""
|
||||||
platform = DiscordPlatform(bot_token="token", allowed_channel_ids="123")
|
platform = DiscordPlatform(
|
||||||
|
bot_token="token",
|
||||||
|
allowed_channel_ids="123",
|
||||||
|
voice_note_enabled=False,
|
||||||
|
)
|
||||||
platform._message_handler = None
|
platform._message_handler = None
|
||||||
|
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
|
|
@ -178,10 +171,6 @@ async def test_discord_voice_disabled_sends_reply():
|
||||||
mock_att.filename = "voice.ogg"
|
mock_att.filename = "voice.ogg"
|
||||||
mock_message.attachments = [mock_att]
|
mock_message.attachments = [mock_att]
|
||||||
|
|
||||||
with patch(
|
|
||||||
"config.settings.get_settings",
|
|
||||||
return_value=MagicMock(voice_note_enabled=False),
|
|
||||||
):
|
|
||||||
await platform._on_discord_message(mock_message)
|
await platform._on_discord_message(mock_message)
|
||||||
|
|
||||||
mock_message.reply.assert_called_once_with("Voice notes are disabled.")
|
mock_message.reply.assert_called_once_with("Voice notes are disabled.")
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ class MockBlock:
|
||||||
|
|
||||||
|
|
||||||
class MockTool:
|
class MockTool:
|
||||||
def __init__(self, name, description, input_schema):
|
def __init__(self, name, description, input_schema=None):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.description = description
|
self.description = description
|
||||||
self.input_schema = input_schema
|
self.input_schema = input_schema
|
||||||
|
|
@ -79,6 +79,23 @@ def test_convert_tools():
|
||||||
assert result[1]["function"]["description"] == "" # Check default empty string
|
assert result[1]["function"]["description"] == "" # Check default empty string
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_tool_without_input_schema_uses_empty_object_schema():
|
||||||
|
tools = [MockTool("web_search", None)]
|
||||||
|
|
||||||
|
result = AnthropicToOpenAIConverter.convert_tools(tools)
|
||||||
|
|
||||||
|
assert result == [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "web_search",
|
||||||
|
"description": "",
|
||||||
|
"parameters": {"type": "object", "properties": {}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"tool_choice,expected",
|
"tool_choice,expected",
|
||||||
[
|
[
|
||||||
|
|
|
||||||
|
|
@ -449,6 +449,40 @@ def test_heuristic_tool_parser_flush_no_tool():
|
||||||
assert tools == []
|
assert tools == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_heuristic_tool_parser_json_style_web_fetch_tool_call():
|
||||||
|
parser = HeuristicToolParser()
|
||||||
|
text = (
|
||||||
|
"Use WebFetch on the article.\n\n"
|
||||||
|
"{\n"
|
||||||
|
' "url": "https://example.com/article",\n'
|
||||||
|
' "prompt": "Summarize it."\n'
|
||||||
|
"}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
filtered, tools = parser.feed(text)
|
||||||
|
tools.extend(parser.flush())
|
||||||
|
|
||||||
|
assert filtered == ""
|
||||||
|
assert len(tools) == 1
|
||||||
|
assert tools[0]["name"] == "WebFetch"
|
||||||
|
assert tools[0]["input"] == {
|
||||||
|
"url": "https://example.com/article",
|
||||||
|
"prompt": "Summarize it.",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_heuristic_tool_parser_json_style_web_search_tool_call():
|
||||||
|
parser = HeuristicToolParser()
|
||||||
|
|
||||||
|
filtered, tools = parser.feed('Use WebSearch {"query": "DeepSeek V4"}')
|
||||||
|
tools.extend(parser.flush())
|
||||||
|
|
||||||
|
assert filtered == ""
|
||||||
|
assert len(tools) == 1
|
||||||
|
assert tools[0]["name"] == "WebSearch"
|
||||||
|
assert tools[0]["input"] == {"query": "DeepSeek V4"}
|
||||||
|
|
||||||
|
|
||||||
def test_heuristic_tool_parser_unicode_function_name():
|
def test_heuristic_tool_parser_unicode_function_name():
|
||||||
"""Unicode characters in function parameters."""
|
"""Unicode characters in function parameters."""
|
||||||
parser = HeuristicToolParser()
|
parser = HeuristicToolParser()
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,13 @@
|
||||||
from unittest.mock import MagicMock, patch
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from config.nim import NimSettings
|
from config.nim import NimSettings
|
||||||
|
from config.provider_ids import SUPPORTED_PROVIDER_IDS
|
||||||
from providers.deepseek import DeepSeekProvider
|
from providers.deepseek import DeepSeekProvider
|
||||||
|
from providers.exceptions import UnknownProviderTypeError
|
||||||
from providers.llamacpp import LlamaCppProvider
|
from providers.llamacpp import LlamaCppProvider
|
||||||
from providers.lmstudio import LMStudioProvider
|
from providers.lmstudio import LMStudioProvider
|
||||||
from providers.nvidia_nim import NvidiaNimProvider
|
from providers.nvidia_nim import NvidiaNimProvider
|
||||||
|
|
@ -41,14 +45,24 @@ def _make_settings(**overrides):
|
||||||
return mock
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
def test_importing_registry_does_not_eager_load_other_adapters() -> None:
|
||||||
|
"""Registry metadata must not import every provider adapter up front."""
|
||||||
|
code = (
|
||||||
|
"import sys\n"
|
||||||
|
"import providers.registry\n"
|
||||||
|
"assert 'providers.open_router' not in sys.modules\n"
|
||||||
|
)
|
||||||
|
proc = subprocess.run(
|
||||||
|
[sys.executable, "-c", code],
|
||||||
|
check=False,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
assert proc.returncode == 0, proc.stderr or proc.stdout
|
||||||
|
|
||||||
|
|
||||||
def test_descriptors_cover_advertised_provider_ids():
|
def test_descriptors_cover_advertised_provider_ids():
|
||||||
assert set(PROVIDER_DESCRIPTORS) == {
|
assert set(PROVIDER_DESCRIPTORS) == set(SUPPORTED_PROVIDER_IDS)
|
||||||
"nvidia_nim",
|
|
||||||
"open_router",
|
|
||||||
"deepseek",
|
|
||||||
"lmstudio",
|
|
||||||
"llamacpp",
|
|
||||||
}
|
|
||||||
for descriptor in PROVIDER_DESCRIPTORS.values():
|
for descriptor in PROVIDER_DESCRIPTORS.values():
|
||||||
assert descriptor.provider_id
|
assert descriptor.provider_id
|
||||||
assert descriptor.transport_type in {"openai_chat", "anthropic_messages"}
|
assert descriptor.transport_type in {"openai_chat", "anthropic_messages"}
|
||||||
|
|
@ -90,6 +104,38 @@ def test_provider_registry_caches_by_provider_id():
|
||||||
assert first is second
|
assert first is second
|
||||||
|
|
||||||
|
|
||||||
def test_unknown_provider_raises_value_error():
|
def test_unknown_provider_raises_unknown_provider_type_error():
|
||||||
with pytest.raises(ValueError, match="Unknown provider_type"):
|
with pytest.raises(UnknownProviderTypeError, match="Unknown provider_type"):
|
||||||
create_provider("unknown", _make_settings())
|
create_provider("unknown", _make_settings())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_provider_registry_cleanup_runs_all_even_if_one_fails() -> None:
|
||||||
|
"""Every provider gets cleanup; cache is cleared even when one raises."""
|
||||||
|
reg = ProviderRegistry()
|
||||||
|
p1 = MagicMock()
|
||||||
|
p1.cleanup = AsyncMock(side_effect=RuntimeError("first"))
|
||||||
|
p2 = MagicMock()
|
||||||
|
p2.cleanup = AsyncMock()
|
||||||
|
reg._providers["a"] = p1
|
||||||
|
reg._providers["b"] = p2
|
||||||
|
with pytest.raises(RuntimeError, match="first"):
|
||||||
|
await reg.cleanup()
|
||||||
|
p1.cleanup.assert_awaited_once()
|
||||||
|
p2.cleanup.assert_awaited_once()
|
||||||
|
assert reg._providers == {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_provider_registry_cleanup_exceptiongroup_on_multiple_failures() -> None:
|
||||||
|
reg = ProviderRegistry()
|
||||||
|
p1 = MagicMock()
|
||||||
|
p1.cleanup = AsyncMock(side_effect=RuntimeError("a"))
|
||||||
|
p2 = MagicMock()
|
||||||
|
p2.cleanup = AsyncMock(side_effect=RuntimeError("b"))
|
||||||
|
reg._providers["x"] = p1
|
||||||
|
reg._providers["y"] = p2
|
||||||
|
with pytest.raises(ExceptionGroup) as exc_info:
|
||||||
|
await reg.cleanup()
|
||||||
|
assert len(exc_info.value.exceptions) == 2
|
||||||
|
assert reg._providers == {}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue