This commit is contained in:
Chinfeng Chung 2026-05-17 17:02:43 +00:00 committed by GitHub
commit a9597ab359
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 635 additions and 34 deletions

View file

@ -1,5 +1,6 @@
# NVIDIA NIM Config
NVIDIA_NIM_API_KEY=""
NVIDIA_NIM_BASE_URL="https://integrate.api.nvidia.com/v1"
# OpenRouter Config
@ -26,22 +27,28 @@ OPENCODE_API_KEY=""
ZAI_API_KEY=""
ZAI_BASE_URL="https://api.z.ai/api/coding/paas/v4"
# OpenAI Config (official Chat Completions API)
OPENAI_API_KEY=""
OPENAI_BASE_URL="https://api.openai.com/v1"
# LM Studio Config (local provider, no API key required)
LM_STUDIO_BASE_URL="http://localhost:1234/v1"
# Llama.cpp Config (local provider, no API key required)
# Llama.cpp Config (local provider, API key optional)
LLAMACPP_BASE_URL="http://localhost:8080/v1"
LLAMACPP_API_KEY=""
# Ollama Config (local provider, no API key required)
# Ollama Config (local provider, API key optional)
OLLAMA_BASE_URL="http://localhost:11434"
OLLAMA_API_KEY=""
# All Claude model requests are mapped to these models, plain model is fallback
# Format: provider_type/model/name
# Valid providers: "nvidia_nim" | "open_router" | "deepseek" | "lmstudio" | "llamacpp" | "ollama" | "kimi" | "wafer" | "opencode" | "zai"
# Valid providers: "nvidia_nim" | "open_router" | "deepseek" | "lmstudio" | "llamacpp" | "ollama" | "kimi" | "wafer" | "opencode" | "zai" | "openai"
MODEL_OPUS=
MODEL_SONNET=
MODEL_HAIKU=
@ -60,6 +67,7 @@ FCC_SMOKE_MODEL_KIMI=
FCC_SMOKE_MODEL_WAFER=
FCC_SMOKE_MODEL_OPENCODE=
FCC_SMOKE_MODEL_ZAI=
FCC_SMOKE_MODEL_OPENAI=
FCC_SMOKE_NIM_MODELS=
FCC_SMOKE_NIM_EXTRA_MODELS=
FCC_SMOKE_OPENROUTER_FREE_MODELS=
@ -85,6 +93,8 @@ KIMI_PROXY=""
WAFER_PROXY=""
OPENCODE_PROXY=""
ZAI_PROXY=""
OPENAI_PROXY=""
OLLAMA_PROXY=""
PROVIDER_RATE_LIMIT=1
PROVIDER_RATE_WINDOW=3
@ -100,6 +110,12 @@ HTTP_CONNECT_TIMEOUT=60
# Optional server API key (Anthropic-style)
ANTHROPIC_AUTH_TOKEN="freecc"
# API Key Pass-through
# When true, the client's auth token is used as the upstream provider bearer
# token, eliminating the need for separate per-provider API keys.
# Requires ANTHROPIC_AUTH_TOKEN to be set (clients must authenticate).
ENABLE_API_KEY_PASSTHROUGH=false
# Messaging Platform: "telegram" | "discord" | "none"
MESSAGING_PLATFORM="discord"

View file

@ -184,6 +184,51 @@ FIELDS: tuple[ConfigFieldSpec, ...] = (
default="https://api.z.ai/api/coding/paas/v4",
description="Z.ai OpenAI-compatible Coding Plan endpoint.",
),
ConfigFieldSpec(
"OPENAI_API_KEY",
"OpenAI API Key",
"providers",
"secret",
settings_attr="openai_api_key",
secret=True,
description="OpenAI API key for direct Chat Completions access.",
),
ConfigFieldSpec(
"OPENAI_BASE_URL",
"OpenAI Base URL",
"providers",
settings_attr="openai_base_url",
default="https://api.openai.com/v1",
description="OpenAI-compatible Chat Completions endpoint.",
),
ConfigFieldSpec(
"NVIDIA_NIM_BASE_URL",
"NVIDIA NIM Base URL",
"providers",
settings_attr="nvidia_nim_base_url",
default="https://integrate.api.nvidia.com/v1",
description="NVIDIA NIM API endpoint override.",
),
ConfigFieldSpec(
"LLAMACPP_API_KEY",
"llama.cpp API Key",
"providers",
"secret",
settings_attr="llamacpp_api_key",
secret=True,
advanced=True,
description="Optional API key for llama.cpp server authentication.",
),
ConfigFieldSpec(
"OLLAMA_API_KEY",
"Ollama API Key",
"providers",
"secret",
settings_attr="ollama_api_key",
secret=True,
advanced=True,
description="Optional API key for Ollama server authentication.",
),
ConfigFieldSpec(
"LM_STUDIO_BASE_URL",
"LM Studio Base URL",
@ -277,6 +322,24 @@ FIELDS: tuple[ConfigFieldSpec, ...] = (
secret=True,
advanced=True,
),
ConfigFieldSpec(
"OPENAI_PROXY",
"OpenAI Proxy",
"providers",
"secret",
settings_attr="openai_proxy",
secret=True,
advanced=True,
),
ConfigFieldSpec(
"OLLAMA_PROXY",
"Ollama Proxy",
"providers",
"secret",
settings_attr="ollama_proxy",
secret=True,
advanced=True,
),
ConfigFieldSpec(
"MODEL",
"Default Model",

View file

@ -33,6 +33,7 @@ def resolve_provider(
*,
app: Starlette | None,
settings: Settings,
passthrough_api_key: str = "",
) -> BaseProvider:
"""Resolve a provider using the app-scoped registry when ``app`` is set.
@ -43,6 +44,10 @@ def resolve_provider(
When ``app`` is ``None`` (no HTTP context), uses the process-level
:data:`_providers` cache only.
When ``passthrough_api_key`` is non-empty and
``settings.enable_api_key_passthrough`` is True, the client's auth token
replaces per-provider credentials in the upstream request.
"""
if app is not None:
reg = getattr(app.state, "provider_registry", None)
@ -51,16 +56,27 @@ def resolve_provider(
"Provider registry is not configured. Ensure AppRuntime startup ran "
"or assign app.state.provider_registry for test apps."
)
return _resolve_with_registry(reg, provider_type, settings)
return _resolve_with_registry(
reg, provider_type, settings, passthrough_api_key=passthrough_api_key
)
return _resolve_with_registry(ProviderRegistry(_providers), provider_type, settings)
def _resolve_with_registry(
registry: ProviderRegistry, provider_type: str, settings: Settings
registry: ProviderRegistry,
provider_type: str,
settings: Settings,
*,
passthrough_api_key: str = "",
) -> BaseProvider:
should_log_init = not registry.is_cached(provider_type)
cache_key = provider_type
if passthrough_api_key and settings.enable_api_key_passthrough:
cache_key = f"{provider_type}:{passthrough_api_key}"
should_log_init = not registry.is_cached(cache_key)
try:
provider = registry.get(provider_type, settings)
provider = registry.get(
provider_type, settings, passthrough_api_key=passthrough_api_key
)
except AuthenticationError as e:
# Provider :class:`~providers.exceptions.AuthenticationError` messages are
# curated configuration hints (env var names, docs links), not upstream noise.
@ -88,18 +104,46 @@ def get_provider_for_type(provider_type: str) -> BaseProvider:
return resolve_provider(provider_type, app=None, settings=get_settings())
def _parse_bearer_token(header: str) -> str:
"""Extract and clean a token from an Authorization or x-api-key header value."""
token = header
if header.lower().startswith("bearer "):
token = header.split(" ", 1)[1]
if token and ":" in token:
token = token.split(":", 1)[0]
return token
def _extract_client_token(request: Request) -> str:
"""Extract a client auth token from the request for passthrough mode."""
header = (
request.headers.get("x-api-key")
or request.headers.get("authorization")
or request.headers.get("anthropic-auth-token")
)
if not header:
raise HTTPException(status_code=401, detail="Missing API key")
return _parse_bearer_token(header)
def require_api_key(
request: Request, settings: Settings = Depends(get_settings)
) -> None:
) -> str:
"""Require a server API key (Anthropic-style).
Checks `x-api-key` header or `Authorization: Bearer ...` against
`Settings.anthropic_auth_token`. If `ANTHROPIC_AUTH_TOKEN` is empty, this is a no-op.
Returns the extracted client token for pass-through when
``ENABLE_API_KEY_PASSTHROUGH`` is active.
"""
anthropic_auth_token = settings.anthropic_auth_token
if not anthropic_auth_token:
# No API key configured -> allow
return
# No server auth configured. In passthrough mode, still extract the
# client token from the request header so it can be forwarded upstream.
if settings.enable_api_key_passthrough:
return _extract_client_token(request)
return ""
header = (
request.headers.get("x-api-key")
@ -109,14 +153,7 @@ def require_api_key(
if not header:
raise HTTPException(status_code=401, detail="Missing API key")
# Support both raw key in X-API-Key and Bearer token in Authorization
token = header
if header.lower().startswith("bearer "):
token = header.split(" ", 1)[1]
# Strip anything after the first colon to handle tokens with appended model names
if token and ":" in token:
token = token.split(":", 1)[0]
token = _parse_bearer_token(header)
# Constant-time comparison to avoid leaking the configured token via
# response-time differences on a per-byte mismatch (CWE-208).
@ -125,6 +162,8 @@ def require_api_key(
):
raise HTTPException(status_code=401, detail="Invalid API key")
return token
def get_provider() -> BaseProvider:
"""Get or create the default provider (``MODEL`` / ``provider_type``).

View file

@ -62,12 +62,19 @@ SUPPORTED_CLAUDE_MODELS = [
def get_proxy_service(
request: Request,
settings: Settings = Depends(get_settings),
passthrough_api_key: str = Depends(require_api_key),
) -> ClaudeProxyService:
"""Build the request service for route handlers."""
effective_passthrough = (
passthrough_api_key if settings.enable_api_key_passthrough else ""
)
return ClaudeProxyService(
settings,
provider_getter=lambda provider_type: dependencies.resolve_provider(
provider_type, app=request.app, settings=settings
provider_type,
app=request.app,
settings=settings,
passthrough_api_key=effective_passthrough,
),
token_counter=get_token_count,
)
@ -167,7 +174,6 @@ def _build_models_list_response(
async def create_message(
request_data: MessagesRequest,
service: ClaudeProxyService = Depends(get_proxy_service),
_auth=Depends(require_api_key),
):
"""Create a message (always streaming)."""
return service.create_message(request_data)
@ -183,7 +189,6 @@ async def probe_messages(_auth=Depends(require_api_key)):
async def count_tokens(
request_data: TokenCountRequest,
service: ClaudeProxyService = Depends(get_proxy_service),
_auth=Depends(require_api_key),
):
"""Count tokens for a request."""
return service.count_tokens(request_data)

View file

@ -34,7 +34,7 @@ TokenCounter = Callable[[list[Any], str | list[Any] | None, list[Any] | None], i
ProviderGetter = Callable[[str], BaseProvider]
# Providers that use ``/chat/completions`` + Anthropic-to-OpenAI conversion (not native Messages).
_OPENAI_CHAT_UPSTREAM_IDS = frozenset({"nvidia_nim", "opencode", "zai"})
_OPENAI_CHAT_UPSTREAM_IDS = frozenset({"nvidia_nim", "opencode", "zai", "openai"})
def anthropic_sse_streaming_response(

View file

@ -25,6 +25,7 @@ LLAMACPP_DEFAULT_BASE = "http://localhost:8080/v1"
OLLAMA_DEFAULT_BASE = "http://localhost:11434"
OPENCODE_DEFAULT_BASE = "https://opencode.ai/zen/v1"
ZAI_DEFAULT_BASE = "https://api.z.ai/api/coding/paas/v4"
OPENAI_DEFAULT_BASE = "https://api.openai.com/v1"
@dataclass(frozen=True, slots=True)
@ -51,6 +52,7 @@ PROVIDER_CATALOG: dict[str, ProviderDescriptor] = {
credential_url="https://build.nvidia.com/settings/api-keys",
credential_attr="nvidia_nim_api_key",
default_base_url=NVIDIA_NIM_DEFAULT_BASE,
base_url_attr="nvidia_nim_base_url",
proxy_attr="nvidia_nim_proxy",
capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"),
),
@ -85,7 +87,8 @@ PROVIDER_CATALOG: dict[str, ProviderDescriptor] = {
"llamacpp": ProviderDescriptor(
provider_id="llamacpp",
transport_type="anthropic_messages",
static_credential="llamacpp",
credential_env=None,
credential_attr="llamacpp_api_key",
default_base_url=LLAMACPP_DEFAULT_BASE,
base_url_attr="llamacpp_base_url",
proxy_attr="llamacpp_proxy",
@ -94,9 +97,11 @@ PROVIDER_CATALOG: dict[str, ProviderDescriptor] = {
"ollama": ProviderDescriptor(
provider_id="ollama",
transport_type="anthropic_messages",
static_credential="ollama",
credential_env=None,
credential_attr="ollama_api_key",
default_base_url=OLLAMA_DEFAULT_BASE,
base_url_attr="ollama_base_url",
proxy_attr="ollama_proxy",
capabilities=(
"chat",
"streaming",
@ -146,6 +151,17 @@ PROVIDER_CATALOG: dict[str, ProviderDescriptor] = {
proxy_attr="zai_proxy",
capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"),
),
"openai": ProviderDescriptor(
provider_id="openai",
transport_type="openai_chat",
credential_env="OPENAI_API_KEY",
credential_url="https://platform.openai.com/api-keys",
credential_attr="openai_api_key",
default_base_url=OPENAI_DEFAULT_BASE,
base_url_attr="openai_base_url",
proxy_attr="openai_proxy",
capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"),
),
}
# Order matches docs / historical error text; must match PROVIDER_CATALOG keys.

View file

@ -129,6 +129,14 @@ class Settings(BaseSettings):
validation_alias="ZAI_BASE_URL",
)
# ==================== OpenAI Config ====================
openai_api_key: str = Field(default="", validation_alias="OPENAI_API_KEY")
openai_base_url: str = Field(
default="https://api.openai.com/v1",
validation_alias="OPENAI_BASE_URL",
)
openai_proxy: str = Field(default="", validation_alias="OPENAI_PROXY")
# ==================== Messaging Platform Selection ====================
# Valid: "telegram" | "discord" | "none"
messaging_platform: str = Field(
@ -143,6 +151,10 @@ class Settings(BaseSettings):
# ==================== NVIDIA NIM Config ====================
nvidia_nim_api_key: str = ""
nvidia_nim_base_url: str = Field(
default="https://integrate.api.nvidia.com/v1",
validation_alias="NVIDIA_NIM_BASE_URL",
)
# ==================== LM Studio Config ====================
lm_studio_base_url: str = Field(
@ -155,12 +167,15 @@ class Settings(BaseSettings):
default="http://localhost:8080/v1",
validation_alias="LLAMACPP_BASE_URL",
)
llamacpp_api_key: str = Field(default="", validation_alias="LLAMACPP_API_KEY")
# ==================== Ollama Config ====================
ollama_base_url: str = Field(
default="http://localhost:11434",
validation_alias="OLLAMA_BASE_URL",
)
ollama_api_key: str = Field(default="", validation_alias="OLLAMA_API_KEY")
ollama_proxy: str = Field(default="", validation_alias="OLLAMA_PROXY")
# ==================== Model ====================
# All Claude model requests are mapped to this single model (fallback)
@ -315,6 +330,12 @@ class Settings(BaseSettings):
anthropic_auth_token: str = Field(
default="", validation_alias="ANTHROPIC_AUTH_TOKEN"
)
# When true, the client's auth token is passed through as the upstream
# provider bearer token, eliminating the need for separate per-provider
# API keys. Requires ANTHROPIC_AUTH_TOKEN to be set.
enable_api_key_passthrough: bool = Field(
default=False, validation_alias="ENABLE_API_KEY_PASSTHROUGH"
)
@model_validator(mode="before")
@classmethod
@ -432,6 +453,8 @@ class Settings(BaseSettings):
@model_validator(mode="after")
def check_nvidia_nim_api_key(self) -> Settings:
if self.enable_api_key_passthrough:
return self
if (
self.voice_note_enabled
and self.whisper_device == "nvidia_nim"

View file

@ -8,6 +8,7 @@ from config.provider_catalog import (
LMSTUDIO_DEFAULT_BASE,
NVIDIA_NIM_DEFAULT_BASE,
OLLAMA_DEFAULT_BASE,
OPENAI_DEFAULT_BASE,
OPENCODE_DEFAULT_BASE,
OPENROUTER_DEFAULT_BASE,
WAFER_DEFAULT_BASE,
@ -22,6 +23,7 @@ __all__ = (
"LMSTUDIO_DEFAULT_BASE",
"NVIDIA_NIM_DEFAULT_BASE",
"OLLAMA_DEFAULT_BASE",
"OPENAI_DEFAULT_BASE",
"OPENCODE_DEFAULT_BASE",
"OPENROUTER_DEFAULT_BASE",
"WAFER_DEFAULT_BASE",

View file

@ -14,3 +14,15 @@ class LlamaCppProvider(AnthropicMessagesTransport):
provider_name="LLAMACPP",
default_base_url=LLAMACPP_DEFAULT_BASE,
)
self._api_key = config.api_key or "llamacpp"
def _request_headers(self) -> dict[str, str]:
headers = {"Content-Type": "application/json"}
if self._api_key and self._api_key != "llamacpp":
headers["Authorization"] = f"Bearer {self._api_key}"
return headers
def _model_list_headers(self) -> dict[str, str]:
if self._api_key and self._api_key != "llamacpp":
return {"Authorization": f"Bearer {self._api_key}"}
return {}

View file

@ -19,6 +19,17 @@ class OllamaProvider(AnthropicMessagesTransport):
)
self._api_key = config.api_key or "ollama"
def _request_headers(self) -> dict[str, str]:
headers = {"Content-Type": "application/json"}
if self._api_key and self._api_key != "ollama":
headers["Authorization"] = f"Bearer {self._api_key}"
return headers
def _model_list_headers(self) -> dict[str, str]:
if self._api_key and self._api_key != "ollama":
return {"Authorization": f"Bearer {self._api_key}"}
return {}
async def _send_stream_request(self, body: dict) -> httpx.Response:
"""Create a streaming native Anthropic messages response."""
request = self._client.build_request(

View file

@ -0,0 +1,10 @@
"""OpenAI provider exports."""
from providers.defaults import OPENAI_DEFAULT_BASE
from .client import OpenAIProvider
__all__ = [
"OPENAI_DEFAULT_BASE",
"OpenAIProvider",
]

View file

@ -0,0 +1,31 @@
"""OpenAI provider implementation."""
from __future__ import annotations
from typing import Any
from providers.base import ProviderConfig
from providers.defaults import OPENAI_DEFAULT_BASE
from providers.openai_compat import OpenAIChatTransport
from .request import build_request_body
class OpenAIProvider(OpenAIChatTransport):
"""OpenAI provider using the official Chat Completions API."""
def __init__(self, config: ProviderConfig):
super().__init__(
config,
provider_name="OPENAI",
base_url=config.base_url or OPENAI_DEFAULT_BASE,
api_key=config.api_key,
)
def _build_request_body(
self, request: Any, thinking_enabled: bool | None = None
) -> dict:
return build_request_body(
request,
thinking_enabled=self._is_thinking_enabled(request, thinking_enabled),
)

View file

@ -0,0 +1,35 @@
"""Request builder for OpenAI provider."""
from typing import Any
from loguru import logger
from core.anthropic import ReasoningReplayMode, build_base_request_body
from core.anthropic.conversion import OpenAIConversionError
from providers.exceptions import InvalidRequestError
def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict:
"""Build OpenAI-format request body from Anthropic request for OpenAI."""
logger.debug(
"OPENAI_REQUEST: conversion start model={} msgs={}",
getattr(request_data, "model", "?"),
len(getattr(request_data, "messages", [])),
)
try:
body = build_base_request_body(
request_data,
reasoning_replay=ReasoningReplayMode.REASONING_CONTENT
if thinking_enabled
else ReasoningReplayMode.DISABLED,
)
except OpenAIConversionError as exc:
raise InvalidRequestError(str(exc)) from exc
logger.debug(
"OPENAI_REQUEST: conversion done model={} msgs={} tools={}",
body.get("model"),
len(body.get("messages", [])),
len(body.get("tools", [])),
)
return body

View file

@ -92,6 +92,12 @@ def _create_zai(config: ProviderConfig, _settings: Settings) -> BaseProvider:
return ZaiProvider(config)
def _create_openai(config: ProviderConfig, _settings: Settings) -> BaseProvider:
from providers.openai import OpenAIProvider
return OpenAIProvider(config)
PROVIDER_FACTORIES: dict[str, ProviderFactory] = {
"nvidia_nim": _create_nvidia_nim,
"open_router": _create_open_router,
@ -103,6 +109,7 @@ PROVIDER_FACTORIES: dict[str, ProviderFactory] = {
"wafer": _create_wafer,
"opencode": _create_opencode,
"zai": _create_zai,
"openai": _create_openai,
}
if set(PROVIDER_DESCRIPTORS) != set(SUPPORTED_PROVIDER_IDS) or set(
@ -130,11 +137,22 @@ def _credential_for(descriptor: ProviderDescriptor, settings: Settings) -> str:
return ""
def _require_credential(descriptor: ProviderDescriptor, credential: str) -> None:
def _require_credential(
descriptor: ProviderDescriptor,
credential: str,
*,
passthrough_enabled: bool = False,
) -> None:
if descriptor.credential_env is None:
return
if credential and credential.strip():
return
if passthrough_enabled:
raise AuthenticationError(
"ENABLE_API_KEY_PASSTHROUGH is active but no client auth token was "
"provided. Ensure the request includes a valid Authorization or "
"x-api-key header."
)
message = f"{descriptor.credential_env} is not set. Add it to your .env file."
if descriptor.credential_url:
message = f"{message} Get a key at {descriptor.credential_url}"
@ -142,10 +160,19 @@ def _require_credential(descriptor: ProviderDescriptor, credential: str) -> None
def build_provider_config(
descriptor: ProviderDescriptor, settings: Settings
descriptor: ProviderDescriptor,
settings: Settings,
*,
passthrough_api_key: str = "",
) -> ProviderConfig:
credential = _credential_for(descriptor, settings)
_require_credential(descriptor, credential)
passthrough_enabled = settings.enable_api_key_passthrough
if passthrough_api_key and passthrough_enabled:
credential = passthrough_api_key
elif passthrough_enabled:
credential = ""
else:
credential = _credential_for(descriptor, settings)
_require_credential(descriptor, credential, passthrough_enabled=passthrough_enabled)
base_url = _string_attr(
settings, descriptor.base_url_attr, descriptor.default_base_url or ""
)
@ -166,7 +193,12 @@ def build_provider_config(
)
def create_provider(provider_id: str, settings: Settings) -> BaseProvider:
def create_provider(
provider_id: str,
settings: Settings,
*,
passthrough_api_key: str = "",
) -> BaseProvider:
descriptor = PROVIDER_DESCRIPTORS.get(provider_id)
if descriptor is None:
supported = "', '".join(PROVIDER_DESCRIPTORS)
@ -174,7 +206,9 @@ def create_provider(provider_id: str, settings: Settings) -> BaseProvider:
f"Unknown provider_type: '{provider_id}'. Supported: '{supported}'"
)
config = build_provider_config(descriptor, settings)
config = build_provider_config(
descriptor, settings, passthrough_api_key=passthrough_api_key
)
factory = PROVIDER_FACTORIES.get(provider_id)
if factory is None:
raise AssertionError(f"Unhandled provider descriptor: {provider_id}")
@ -260,10 +294,23 @@ class ProviderRegistry:
"""Return whether a provider for this id is already in the cache."""
return provider_id in self._providers
def get(self, provider_id: str, settings: Settings) -> BaseProvider:
if provider_id not in self._providers:
self._providers[provider_id] = create_provider(provider_id, settings)
return self._providers[provider_id]
def get(
self,
provider_id: str,
settings: Settings,
*,
passthrough_api_key: str = "",
) -> BaseProvider:
# When pass-through is active, include the API key in the cache key
# so different client tokens get distinct provider instances.
cache_key = provider_id
if passthrough_api_key and settings.enable_api_key_passthrough:
cache_key = f"{provider_id}:{passthrough_api_key}"
if cache_key not in self._providers:
self._providers[cache_key] = create_provider(
provider_id, settings, passthrough_api_key=passthrough_api_key
)
return self._providers[cache_key]
def cache_model_ids(self, provider_id: str, model_ids: Iterable[str]) -> None:
"""Store a provider model-list result for later instant API responses."""

View file

@ -51,6 +51,7 @@ PROVIDER_SMOKE_DEFAULT_MODELS: dict[str, str] = {
"wafer": "wafer/DeepSeek-V4-Pro",
"opencode": "opencode/gpt-5.3-codex",
"zai": "zai/glm-5.1",
"openai": "openai/gpt-4o",
}
NVIDIA_NIM_CLI_DEFAULT_MODELS: tuple[str, ...] = (
@ -237,6 +238,8 @@ class SmokeConfig:
return bool(self.settings.opencode_api_key.strip())
if provider == "zai":
return bool(self.settings.zai_api_key.strip())
if provider == "openai":
return bool(self.settings.openai_api_key.strip())
return False

View file

@ -0,0 +1,284 @@
"""Tests for API key pass-through (ENABLE_API_KEY_PASSTHROUGH)."""
from types import SimpleNamespace
from typing import cast
from unittest.mock import MagicMock, patch
import pytest
from fastapi import HTTPException
from starlette.applications import Starlette
from starlette.datastructures import State
from api.dependencies import (
require_api_key,
resolve_provider,
)
from config.provider_catalog import PROVIDER_CATALOG
from providers.nvidia_nim import NvidiaNimProvider
from providers.registry import ProviderRegistry, build_provider_config, create_provider
def _make_settings(**overrides):
"""Create a mock settings object with all required fields."""
mock = MagicMock()
mock.model = "nvidia_nim/meta/llama3"
mock.provider_type = "nvidia_nim"
mock.nvidia_nim_api_key = "test_key"
mock.open_router_api_key = "test_openrouter_key"
mock.deepseek_api_key = "test_deepseek_key"
mock.wafer_api_key = "test_wafer_key"
mock.opencode_api_key = "test_opencode_key"
mock.zai_api_key = "test_zai_key"
mock.openai_api_key = "test_openai_key"
mock.kimi_api_key = "test_kimi_key"
mock.lm_studio_base_url = "http://localhost:1234/v1"
mock.llamacpp_base_url = "http://localhost:8080/v1"
mock.ollama_base_url = "http://localhost:11434"
mock.llamacpp_api_key = ""
mock.ollama_api_key = ""
mock.nvidia_nim_proxy = ""
mock.open_router_proxy = ""
mock.lmstudio_proxy = ""
mock.llamacpp_proxy = ""
mock.kimi_proxy = ""
mock.wafer_proxy = ""
mock.opencode_proxy = ""
mock.zai_proxy = ""
mock.openai_proxy = ""
mock.provider_rate_limit = 40
mock.provider_rate_window = 60
mock.provider_max_concurrency = 5
mock.http_read_timeout = 300.0
mock.http_write_timeout = 10.0
mock.http_connect_timeout = 10.0
mock.enable_model_thinking = True
mock.log_raw_sse_events = False
mock.log_api_error_tracebacks = False
mock.enable_api_key_passthrough = False
mock.anthropic_auth_token = ""
for key, value in overrides.items():
setattr(mock, key, value)
return mock
class TestBuildProviderConfigPassthrough:
"""Tests for build_provider_config with passthrough_api_key."""
def test_passthrough_disabled_uses_configured_credential(self):
"""When passthrough is disabled, the configured provider key is used."""
settings = _make_settings(enable_api_key_passthrough=False)
descriptor = PROVIDER_CATALOG["nvidia_nim"]
config = build_provider_config(descriptor, settings)
assert config.api_key == "test_key"
def test_passthrough_enabled_uses_passthrough_key(self):
"""When passthrough is enabled and key provided, it overrides credential."""
settings = _make_settings(enable_api_key_passthrough=True)
descriptor = PROVIDER_CATALOG["nvidia_nim"]
config = build_provider_config(
descriptor, settings, passthrough_api_key="sk-passthrough-123"
)
assert config.api_key == "sk-passthrough-123"
def test_passthrough_enabled_empty_key_raises(self):
"""When passthrough is enabled but key is empty, raises AuthenticationError."""
from providers.exceptions import AuthenticationError
settings = _make_settings(enable_api_key_passthrough=True)
descriptor = PROVIDER_CATALOG["nvidia_nim"]
with pytest.raises(AuthenticationError, match="ENABLE_API_KEY_PASSTHROUGH"):
build_provider_config(descriptor, settings, passthrough_api_key="")
def test_passthrough_enabled_key_present_but_passthrough_disabled(self):
"""Passthrough key is ignored when enable_api_key_passthrough is False."""
settings = _make_settings(enable_api_key_passthrough=False)
descriptor = PROVIDER_CATALOG["nvidia_nim"]
config = build_provider_config(
descriptor, settings, passthrough_api_key="sk-ignored"
)
assert config.api_key == "test_key"
def test_passthrough_skips_credential_validation(self):
"""With passthrough active, missing per-provider key does not raise."""
settings = _make_settings(
enable_api_key_passthrough=True, nvidia_nim_api_key=""
)
descriptor = PROVIDER_CATALOG["nvidia_nim"]
config = build_provider_config(
descriptor, settings, passthrough_api_key="sk-valid"
)
assert config.api_key == "sk-valid"
def test_passthrough_disabled_missing_key_raises(self):
"""Without passthrough, missing required credential still raises."""
from providers.exceptions import AuthenticationError
settings = _make_settings(
enable_api_key_passthrough=False, nvidia_nim_api_key=""
)
descriptor = PROVIDER_CATALOG["nvidia_nim"]
with pytest.raises(AuthenticationError, match="NVIDIA_NIM_API_KEY"):
build_provider_config(descriptor, settings)
def test_passthrough_with_static_credential_provider(self):
"""Passthrough overrides even static credentials (e.g. lmstudio)."""
settings = _make_settings(enable_api_key_passthrough=True)
descriptor = PROVIDER_CATALOG["lmstudio"]
config = build_provider_config(
descriptor, settings, passthrough_api_key="sk-lmstudio-passthrough"
)
assert config.api_key == "sk-lmstudio-passthrough"
class TestCreateProviderPassthrough:
"""Tests for create_provider with passthrough_api_key."""
def test_create_provider_with_passthrough(self):
"""create_provider passes passthrough key through to provider config."""
settings = _make_settings(enable_api_key_passthrough=True)
with patch("providers.openai_compat.AsyncOpenAI"):
provider = create_provider(
"nvidia_nim", settings, passthrough_api_key="sk-nim-passthrough"
)
assert isinstance(provider, NvidiaNimProvider)
assert provider._api_key == "sk-nim-passthrough"
def test_create_provider_without_passthrough(self):
"""create_provider uses configured key when passthrough is empty."""
settings = _make_settings(enable_api_key_passthrough=False)
with patch("providers.openai_compat.AsyncOpenAI"):
provider = create_provider("nvidia_nim", settings)
assert isinstance(provider, NvidiaNimProvider)
assert provider._api_key == "test_key"
class TestProviderRegistryPassthrough:
"""Tests for ProviderRegistry.get with passthrough_api_key."""
def test_different_passthrough_keys_get_distinct_instances(self):
"""Different passthrough keys produce separate cached provider instances."""
registry = ProviderRegistry()
settings = _make_settings(enable_api_key_passthrough=True)
with patch("providers.openai_compat.AsyncOpenAI"):
p1 = registry.get("nvidia_nim", settings, passthrough_api_key="sk-key-1")
p2 = registry.get("nvidia_nim", settings, passthrough_api_key="sk-key-2")
assert p1 is not p2
assert isinstance(p1, NvidiaNimProvider)
assert isinstance(p2, NvidiaNimProvider)
assert p1._api_key == "sk-key-1"
assert p2._api_key == "sk-key-2"
def test_same_passthrough_key_returns_cached_instance(self):
"""Same passthrough key returns the cached provider instance."""
registry = ProviderRegistry()
settings = _make_settings(enable_api_key_passthrough=True)
with patch("providers.openai_compat.AsyncOpenAI"):
p1 = registry.get("nvidia_nim", settings, passthrough_api_key="sk-same")
p2 = registry.get("nvidia_nim", settings, passthrough_api_key="sk-same")
assert p1 is p2
def test_no_passthrough_caches_by_provider_id(self):
"""Without passthrough, caching works by provider_id as before."""
registry = ProviderRegistry()
settings = _make_settings(enable_api_key_passthrough=False)
with patch("providers.openai_compat.AsyncOpenAI"):
p1 = registry.get("nvidia_nim", settings)
p2 = registry.get("nvidia_nim", settings)
assert p1 is p2
def test_passthrough_empty_key_raises_in_registry(self):
"""A passthrough-enabled registry raises when no client token given."""
from providers.exceptions import AuthenticationError
registry = ProviderRegistry()
settings = _make_settings(enable_api_key_passthrough=True)
with (
pytest.raises(AuthenticationError, match="ENABLE_API_KEY_PASSTHROUGH"),
patch("providers.openai_compat.AsyncOpenAI"),
):
registry.get("nvidia_nim", settings)
def test_different_passthrough_keys_are_distinct(self):
"""Providers with different passthrough keys are separate instances."""
registry = ProviderRegistry()
settings = _make_settings(enable_api_key_passthrough=True)
with patch("providers.openai_compat.AsyncOpenAI"):
p_a = registry.get("nvidia_nim", settings, passthrough_api_key="sk-aaa")
p_b = registry.get("nvidia_nim", settings, passthrough_api_key="sk-bbb")
assert p_a is not p_b
class TestRequireApiKeyReturnsToken:
"""Tests that require_api_key returns the extracted client token."""
def test_returns_token_on_valid_x_api_key(self):
"""Valid x-api-key header returns the extracted token."""
request = MagicMock()
request.headers = {"x-api-key": "my-secret-key"}
settings = _make_settings(anthropic_auth_token="my-secret-key")
token = require_api_key(request, settings)
assert token == "my-secret-key"
def test_returns_token_on_valid_bearer(self):
"""Valid Bearer authorization returns the token (without Bearer prefix)."""
request = MagicMock()
request.headers = {"authorization": "Bearer my-bearer-key"}
settings = _make_settings(anthropic_auth_token="my-bearer-key")
token = require_api_key(request, settings)
assert token == "my-bearer-key"
def test_returns_empty_string_when_no_auth_configured(self):
"""When ANTHROPIC_AUTH_TOKEN is empty, returns empty string."""
request = MagicMock()
request.headers = {}
settings = _make_settings(anthropic_auth_token="")
token = require_api_key(request, settings)
assert token == ""
def test_raises_401_on_invalid_key(self):
"""Invalid API key raises 401."""
request = MagicMock()
request.headers = {"x-api-key": "wrong-key"}
settings = _make_settings(anthropic_auth_token="correct-key")
with pytest.raises(HTTPException) as exc_info:
require_api_key(request, settings)
assert exc_info.value.status_code == 401
def test_strips_colon_suffix(self):
"""Token with colon suffix (model name append) gets stripped."""
request = MagicMock()
request.headers = {"x-api-key": "my-key:claude-3-opus"}
settings = _make_settings(anthropic_auth_token="my-key")
token = require_api_key(request, settings)
assert token == "my-key"
class TestResolveProviderPassthrough:
"""Tests for resolve_provider with passthrough_api_key."""
def test_resolve_provider_passes_passthrough_key(self):
"""resolve_provider forwards passthrough_api_key to the registry."""
settings = _make_settings(enable_api_key_passthrough=True)
app = SimpleNamespace(state=State())
registry = ProviderRegistry()
app.state.provider_registry = registry
with patch("providers.openai_compat.AsyncOpenAI"):
provider = resolve_provider(
"nvidia_nim",
app=cast(Starlette, app),
settings=settings,
passthrough_api_key="sk-resolve-test",
)
assert isinstance(provider, NvidiaNimProvider)
assert provider._api_key == "sk-resolve-test"

View file

@ -52,6 +52,8 @@ def _make_mock_settings(**overrides):
mock.http_write_timeout = 10.0
mock.http_connect_timeout = 10.0
mock.enable_model_thinking = True
mock.enable_api_key_passthrough = False
mock.anthropic_auth_token = ""
for key, value in overrides.items():
setattr(mock, key, value)
return mock

View file

@ -51,6 +51,8 @@ def _make_settings(**overrides):
mock.http_write_timeout = 10.0
mock.http_connect_timeout = 10.0
mock.enable_model_thinking = True
mock.enable_api_key_passthrough = False
mock.anthropic_auth_token = ""
mock.nim = NimSettings()
for key, value in overrides.items():
setattr(mock, key, value)