From 05d50f269c64027c6b2c3f03dd93e782c2e299eb Mon Sep 17 00:00:00 2001 From: Chinfeng Date: Sun, 17 May 2026 16:03:54 +0800 Subject: [PATCH 1/4] Add NVIDIA_NIM_BASE_URL, OLLAMA_API_KEY, LLAMACPP_API_KEY, and OpenAI provider support - Add NVIDIA_NIM_BASE_URL env var for custom NIM endpoint - Add optional OLLAMA_API_KEY and LLAMACPP_API_KEY for remote auth - Add openai provider with OPENAI_API_KEY and OPENAI_BASE_URL - Send Bearer auth in Ollama/llamacpp when non-default API key is set --- .env.example | 16 +++++++-- api/admin_config.py | 63 ++++++++++++++++++++++++++++++++++++ api/services.py | 2 +- config/provider_catalog.py | 20 ++++++++++-- config/settings.py | 15 +++++++++ providers/defaults.py | 2 ++ providers/llamacpp/client.py | 12 +++++++ providers/ollama/client.py | 11 +++++++ providers/openai/__init__.py | 10 ++++++ providers/openai/client.py | 31 ++++++++++++++++++ providers/openai/request.py | 35 ++++++++++++++++++++ providers/registry.py | 7 ++++ smoke/lib/config.py | 3 ++ 13 files changed, 221 insertions(+), 6 deletions(-) create mode 100644 providers/openai/__init__.py create mode 100644 providers/openai/client.py create mode 100644 providers/openai/request.py diff --git a/.env.example b/.env.example index 27716f6..39bdf2b 100644 --- a/.env.example +++ b/.env.example @@ -1,5 +1,6 @@ # NVIDIA NIM Config NVIDIA_NIM_API_KEY="" +NVIDIA_NIM_BASE_URL="https://integrate.api.nvidia.com/v1" # OpenRouter Config @@ -26,22 +27,28 @@ OPENCODE_API_KEY="" ZAI_API_KEY="" ZAI_BASE_URL="https://api.z.ai/api/coding/paas/v4" +# OpenAI Config (official Chat Completions API) +OPENAI_API_KEY="" +OPENAI_BASE_URL="https://api.openai.com/v1" + # LM Studio Config (local provider, no API key required) LM_STUDIO_BASE_URL="http://localhost:1234/v1" -# Llama.cpp Config (local provider, no API key required) +# Llama.cpp Config (local provider, API key optional) LLAMACPP_BASE_URL="http://localhost:8080/v1" +LLAMACPP_API_KEY="" -# Ollama Config (local provider, no API key required) +# Ollama Config (local provider, API key optional) OLLAMA_BASE_URL="http://localhost:11434" +OLLAMA_API_KEY="" # All Claude model requests are mapped to these models, plain model is fallback # Format: provider_type/model/name -# Valid providers: "nvidia_nim" | "open_router" | "deepseek" | "lmstudio" | "llamacpp" | "ollama" | "kimi" | "wafer" | "opencode" | "zai" +# Valid providers: "nvidia_nim" | "open_router" | "deepseek" | "lmstudio" | "llamacpp" | "ollama" | "kimi" | "wafer" | "opencode" | "zai" | "openai" MODEL_OPUS= MODEL_SONNET= MODEL_HAIKU= @@ -60,6 +67,7 @@ FCC_SMOKE_MODEL_KIMI= FCC_SMOKE_MODEL_WAFER= FCC_SMOKE_MODEL_OPENCODE= FCC_SMOKE_MODEL_ZAI= +FCC_SMOKE_MODEL_OPENAI= FCC_SMOKE_NIM_MODELS= FCC_SMOKE_NIM_EXTRA_MODELS= FCC_SMOKE_OPENROUTER_FREE_MODELS= @@ -85,6 +93,8 @@ KIMI_PROXY="" WAFER_PROXY="" OPENCODE_PROXY="" ZAI_PROXY="" +OPENAI_PROXY="" +OLLAMA_PROXY="" PROVIDER_RATE_LIMIT=1 PROVIDER_RATE_WINDOW=3 diff --git a/api/admin_config.py b/api/admin_config.py index cb4f7ab..246d0c4 100644 --- a/api/admin_config.py +++ b/api/admin_config.py @@ -184,6 +184,51 @@ FIELDS: tuple[ConfigFieldSpec, ...] = ( default="https://api.z.ai/api/coding/paas/v4", description="Z.ai OpenAI-compatible Coding Plan endpoint.", ), + ConfigFieldSpec( + "OPENAI_API_KEY", + "OpenAI API Key", + "providers", + "secret", + settings_attr="openai_api_key", + secret=True, + description="OpenAI API key for direct Chat Completions access.", + ), + ConfigFieldSpec( + "OPENAI_BASE_URL", + "OpenAI Base URL", + "providers", + settings_attr="openai_base_url", + default="https://api.openai.com/v1", + description="OpenAI-compatible Chat Completions endpoint.", + ), + ConfigFieldSpec( + "NVIDIA_NIM_BASE_URL", + "NVIDIA NIM Base URL", + "providers", + settings_attr="nvidia_nim_base_url", + default="https://integrate.api.nvidia.com/v1", + description="NVIDIA NIM API endpoint override.", + ), + ConfigFieldSpec( + "LLAMACPP_API_KEY", + "llama.cpp API Key", + "providers", + "secret", + settings_attr="llamacpp_api_key", + secret=True, + advanced=True, + description="Optional API key for llama.cpp server authentication.", + ), + ConfigFieldSpec( + "OLLAMA_API_KEY", + "Ollama API Key", + "providers", + "secret", + settings_attr="ollama_api_key", + secret=True, + advanced=True, + description="Optional API key for Ollama server authentication.", + ), ConfigFieldSpec( "LM_STUDIO_BASE_URL", "LM Studio Base URL", @@ -277,6 +322,24 @@ FIELDS: tuple[ConfigFieldSpec, ...] = ( secret=True, advanced=True, ), + ConfigFieldSpec( + "OPENAI_PROXY", + "OpenAI Proxy", + "providers", + "secret", + settings_attr="openai_proxy", + secret=True, + advanced=True, + ), + ConfigFieldSpec( + "OLLAMA_PROXY", + "Ollama Proxy", + "providers", + "secret", + settings_attr="ollama_proxy", + secret=True, + advanced=True, + ), ConfigFieldSpec( "MODEL", "Default Model", diff --git a/api/services.py b/api/services.py index 6566248..1f648bc 100644 --- a/api/services.py +++ b/api/services.py @@ -34,7 +34,7 @@ TokenCounter = Callable[[list[Any], str | list[Any] | None, list[Any] | None], i ProviderGetter = Callable[[str], BaseProvider] # Providers that use ``/chat/completions`` + Anthropic-to-OpenAI conversion (not native Messages). -_OPENAI_CHAT_UPSTREAM_IDS = frozenset({"nvidia_nim", "opencode", "zai"}) +_OPENAI_CHAT_UPSTREAM_IDS = frozenset({"nvidia_nim", "opencode", "zai", "openai"}) def anthropic_sse_streaming_response( diff --git a/config/provider_catalog.py b/config/provider_catalog.py index e2a53c8..45d8e53 100644 --- a/config/provider_catalog.py +++ b/config/provider_catalog.py @@ -25,6 +25,7 @@ LLAMACPP_DEFAULT_BASE = "http://localhost:8080/v1" OLLAMA_DEFAULT_BASE = "http://localhost:11434" OPENCODE_DEFAULT_BASE = "https://opencode.ai/zen/v1" ZAI_DEFAULT_BASE = "https://api.z.ai/api/coding/paas/v4" +OPENAI_DEFAULT_BASE = "https://api.openai.com/v1" @dataclass(frozen=True, slots=True) @@ -51,6 +52,7 @@ PROVIDER_CATALOG: dict[str, ProviderDescriptor] = { credential_url="https://build.nvidia.com/settings/api-keys", credential_attr="nvidia_nim_api_key", default_base_url=NVIDIA_NIM_DEFAULT_BASE, + base_url_attr="nvidia_nim_base_url", proxy_attr="nvidia_nim_proxy", capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"), ), @@ -85,7 +87,8 @@ PROVIDER_CATALOG: dict[str, ProviderDescriptor] = { "llamacpp": ProviderDescriptor( provider_id="llamacpp", transport_type="anthropic_messages", - static_credential="llamacpp", + credential_env=None, + credential_attr="llamacpp_api_key", default_base_url=LLAMACPP_DEFAULT_BASE, base_url_attr="llamacpp_base_url", proxy_attr="llamacpp_proxy", @@ -94,9 +97,11 @@ PROVIDER_CATALOG: dict[str, ProviderDescriptor] = { "ollama": ProviderDescriptor( provider_id="ollama", transport_type="anthropic_messages", - static_credential="ollama", + credential_env=None, + credential_attr="ollama_api_key", default_base_url=OLLAMA_DEFAULT_BASE, base_url_attr="ollama_base_url", + proxy_attr="ollama_proxy", capabilities=( "chat", "streaming", @@ -146,6 +151,17 @@ PROVIDER_CATALOG: dict[str, ProviderDescriptor] = { proxy_attr="zai_proxy", capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"), ), + "openai": ProviderDescriptor( + provider_id="openai", + transport_type="openai_chat", + credential_env="OPENAI_API_KEY", + credential_url="https://platform.openai.com/api-keys", + credential_attr="openai_api_key", + default_base_url=OPENAI_DEFAULT_BASE, + base_url_attr="openai_base_url", + proxy_attr="openai_proxy", + capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"), + ), } # Order matches docs / historical error text; must match PROVIDER_CATALOG keys. diff --git a/config/settings.py b/config/settings.py index e451761..7504def 100644 --- a/config/settings.py +++ b/config/settings.py @@ -129,6 +129,14 @@ class Settings(BaseSettings): validation_alias="ZAI_BASE_URL", ) + # ==================== OpenAI Config ==================== + openai_api_key: str = Field(default="", validation_alias="OPENAI_API_KEY") + openai_base_url: str = Field( + default="https://api.openai.com/v1", + validation_alias="OPENAI_BASE_URL", + ) + openai_proxy: str = Field(default="", validation_alias="OPENAI_PROXY") + # ==================== Messaging Platform Selection ==================== # Valid: "telegram" | "discord" | "none" messaging_platform: str = Field( @@ -143,6 +151,10 @@ class Settings(BaseSettings): # ==================== NVIDIA NIM Config ==================== nvidia_nim_api_key: str = "" + nvidia_nim_base_url: str = Field( + default="https://integrate.api.nvidia.com/v1", + validation_alias="NVIDIA_NIM_BASE_URL", + ) # ==================== LM Studio Config ==================== lm_studio_base_url: str = Field( @@ -155,12 +167,15 @@ class Settings(BaseSettings): default="http://localhost:8080/v1", validation_alias="LLAMACPP_BASE_URL", ) + llamacpp_api_key: str = Field(default="", validation_alias="LLAMACPP_API_KEY") # ==================== Ollama Config ==================== ollama_base_url: str = Field( default="http://localhost:11434", validation_alias="OLLAMA_BASE_URL", ) + ollama_api_key: str = Field(default="", validation_alias="OLLAMA_API_KEY") + ollama_proxy: str = Field(default="", validation_alias="OLLAMA_PROXY") # ==================== Model ==================== # All Claude model requests are mapped to this single model (fallback) diff --git a/providers/defaults.py b/providers/defaults.py index a0ec814..16155a6 100644 --- a/providers/defaults.py +++ b/providers/defaults.py @@ -8,6 +8,7 @@ from config.provider_catalog import ( LMSTUDIO_DEFAULT_BASE, NVIDIA_NIM_DEFAULT_BASE, OLLAMA_DEFAULT_BASE, + OPENAI_DEFAULT_BASE, OPENCODE_DEFAULT_BASE, OPENROUTER_DEFAULT_BASE, WAFER_DEFAULT_BASE, @@ -22,6 +23,7 @@ __all__ = ( "LMSTUDIO_DEFAULT_BASE", "NVIDIA_NIM_DEFAULT_BASE", "OLLAMA_DEFAULT_BASE", + "OPENAI_DEFAULT_BASE", "OPENCODE_DEFAULT_BASE", "OPENROUTER_DEFAULT_BASE", "WAFER_DEFAULT_BASE", diff --git a/providers/llamacpp/client.py b/providers/llamacpp/client.py index 891022d..ea427f7 100644 --- a/providers/llamacpp/client.py +++ b/providers/llamacpp/client.py @@ -14,3 +14,15 @@ class LlamaCppProvider(AnthropicMessagesTransport): provider_name="LLAMACPP", default_base_url=LLAMACPP_DEFAULT_BASE, ) + self._api_key = config.api_key or "llamacpp" + + def _request_headers(self) -> dict[str, str]: + headers = {"Content-Type": "application/json"} + if self._api_key and self._api_key != "llamacpp": + headers["Authorization"] = f"Bearer {self._api_key}" + return headers + + def _model_list_headers(self) -> dict[str, str]: + if self._api_key and self._api_key != "llamacpp": + return {"Authorization": f"Bearer {self._api_key}"} + return {} diff --git a/providers/ollama/client.py b/providers/ollama/client.py index 86fa542..e5edf14 100644 --- a/providers/ollama/client.py +++ b/providers/ollama/client.py @@ -19,6 +19,17 @@ class OllamaProvider(AnthropicMessagesTransport): ) self._api_key = config.api_key or "ollama" + def _request_headers(self) -> dict[str, str]: + headers = {"Content-Type": "application/json"} + if self._api_key and self._api_key != "ollama": + headers["Authorization"] = f"Bearer {self._api_key}" + return headers + + def _model_list_headers(self) -> dict[str, str]: + if self._api_key and self._api_key != "ollama": + return {"Authorization": f"Bearer {self._api_key}"} + return {} + async def _send_stream_request(self, body: dict) -> httpx.Response: """Create a streaming native Anthropic messages response.""" request = self._client.build_request( diff --git a/providers/openai/__init__.py b/providers/openai/__init__.py new file mode 100644 index 0000000..8b30d46 --- /dev/null +++ b/providers/openai/__init__.py @@ -0,0 +1,10 @@ +"""OpenAI provider exports.""" + +from providers.defaults import OPENAI_DEFAULT_BASE + +from .client import OpenAIProvider + +__all__ = [ + "OPENAI_DEFAULT_BASE", + "OpenAIProvider", +] diff --git a/providers/openai/client.py b/providers/openai/client.py new file mode 100644 index 0000000..4dc5284 --- /dev/null +++ b/providers/openai/client.py @@ -0,0 +1,31 @@ +"""OpenAI provider implementation.""" + +from __future__ import annotations + +from typing import Any + +from providers.base import ProviderConfig +from providers.defaults import OPENAI_DEFAULT_BASE +from providers.openai_compat import OpenAIChatTransport + +from .request import build_request_body + + +class OpenAIProvider(OpenAIChatTransport): + """OpenAI provider using the official Chat Completions API.""" + + def __init__(self, config: ProviderConfig): + super().__init__( + config, + provider_name="OPENAI", + base_url=config.base_url or OPENAI_DEFAULT_BASE, + api_key=config.api_key, + ) + + def _build_request_body( + self, request: Any, thinking_enabled: bool | None = None + ) -> dict: + return build_request_body( + request, + thinking_enabled=self._is_thinking_enabled(request, thinking_enabled), + ) diff --git a/providers/openai/request.py b/providers/openai/request.py new file mode 100644 index 0000000..cdadb17 --- /dev/null +++ b/providers/openai/request.py @@ -0,0 +1,35 @@ +"""Request builder for OpenAI provider.""" + +from typing import Any + +from loguru import logger + +from core.anthropic import ReasoningReplayMode, build_base_request_body +from core.anthropic.conversion import OpenAIConversionError +from providers.exceptions import InvalidRequestError + + +def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict: + """Build OpenAI-format request body from Anthropic request for OpenAI.""" + logger.debug( + "OPENAI_REQUEST: conversion start model={} msgs={}", + getattr(request_data, "model", "?"), + len(getattr(request_data, "messages", [])), + ) + try: + body = build_base_request_body( + request_data, + reasoning_replay=ReasoningReplayMode.REASONING_CONTENT + if thinking_enabled + else ReasoningReplayMode.DISABLED, + ) + except OpenAIConversionError as exc: + raise InvalidRequestError(str(exc)) from exc + + logger.debug( + "OPENAI_REQUEST: conversion done model={} msgs={} tools={}", + body.get("model"), + len(body.get("messages", [])), + len(body.get("tools", [])), + ) + return body diff --git a/providers/registry.py b/providers/registry.py index 4145702..9f04760 100644 --- a/providers/registry.py +++ b/providers/registry.py @@ -92,6 +92,12 @@ def _create_zai(config: ProviderConfig, _settings: Settings) -> BaseProvider: return ZaiProvider(config) +def _create_openai(config: ProviderConfig, _settings: Settings) -> BaseProvider: + from providers.openai import OpenAIProvider + + return OpenAIProvider(config) + + PROVIDER_FACTORIES: dict[str, ProviderFactory] = { "nvidia_nim": _create_nvidia_nim, "open_router": _create_open_router, @@ -103,6 +109,7 @@ PROVIDER_FACTORIES: dict[str, ProviderFactory] = { "wafer": _create_wafer, "opencode": _create_opencode, "zai": _create_zai, + "openai": _create_openai, } if set(PROVIDER_DESCRIPTORS) != set(SUPPORTED_PROVIDER_IDS) or set( diff --git a/smoke/lib/config.py b/smoke/lib/config.py index f69267c..7309d84 100644 --- a/smoke/lib/config.py +++ b/smoke/lib/config.py @@ -51,6 +51,7 @@ PROVIDER_SMOKE_DEFAULT_MODELS: dict[str, str] = { "wafer": "wafer/DeepSeek-V4-Pro", "opencode": "opencode/gpt-5.3-codex", "zai": "zai/glm-5.1", + "openai": "openai/gpt-4o", } NVIDIA_NIM_CLI_DEFAULT_MODELS: tuple[str, ...] = ( @@ -237,6 +238,8 @@ class SmokeConfig: return bool(self.settings.opencode_api_key.strip()) if provider == "zai": return bool(self.settings.zai_api_key.strip()) + if provider == "openai": + return bool(self.settings.openai_api_key.strip()) return False From 2287ace66184f76cadc2fac8cc67ac7624c709b5 Mon Sep 17 00:00:00 2001 From: chinfeng Date: Sun, 17 May 2026 20:54:16 +0800 Subject: [PATCH 2/4] Add API key pass-through feature (ENABLE_API_KEY_PASSTHROUGH) When enabled, the client's auth token is used as the upstream provider bearer token, eliminating the need for separate per-provider API keys. Provider registry uses composite cache keys (provider_id:api_key) to isolate distinct client tokens. --- .env.example | 6 + api/dependencies.py | 33 +++- api/routes.py | 11 +- config/settings.py | 6 + providers/registry.py | 44 ++++- tests/api/test_api_key_passthrough.py | 271 ++++++++++++++++++++++++++ 6 files changed, 353 insertions(+), 18 deletions(-) create mode 100644 tests/api/test_api_key_passthrough.py diff --git a/.env.example b/.env.example index 39bdf2b..4ac2911 100644 --- a/.env.example +++ b/.env.example @@ -110,6 +110,12 @@ HTTP_CONNECT_TIMEOUT=60 # Optional server API key (Anthropic-style) ANTHROPIC_AUTH_TOKEN="freecc" +# API Key Pass-through +# When true, the client's auth token is used as the upstream provider bearer +# token, eliminating the need for separate per-provider API keys. +# Requires ANTHROPIC_AUTH_TOKEN to be set (clients must authenticate). +ENABLE_API_KEY_PASSTHROUGH=false + # Messaging Platform: "telegram" | "discord" | "none" MESSAGING_PLATFORM="discord" diff --git a/api/dependencies.py b/api/dependencies.py index fbc37b0..1700a57 100644 --- a/api/dependencies.py +++ b/api/dependencies.py @@ -33,6 +33,7 @@ def resolve_provider( *, app: Starlette | None, settings: Settings, + passthrough_api_key: str = "", ) -> BaseProvider: """Resolve a provider using the app-scoped registry when ``app`` is set. @@ -43,6 +44,10 @@ def resolve_provider( When ``app`` is ``None`` (no HTTP context), uses the process-level :data:`_providers` cache only. + + When ``passthrough_api_key`` is non-empty and + ``settings.enable_api_key_passthrough`` is True, the client's auth token + replaces per-provider credentials in the upstream request. """ if app is not None: reg = getattr(app.state, "provider_registry", None) @@ -51,16 +56,27 @@ def resolve_provider( "Provider registry is not configured. Ensure AppRuntime startup ran " "or assign app.state.provider_registry for test apps." ) - return _resolve_with_registry(reg, provider_type, settings) + return _resolve_with_registry( + reg, provider_type, settings, passthrough_api_key=passthrough_api_key + ) return _resolve_with_registry(ProviderRegistry(_providers), provider_type, settings) def _resolve_with_registry( - registry: ProviderRegistry, provider_type: str, settings: Settings + registry: ProviderRegistry, + provider_type: str, + settings: Settings, + *, + passthrough_api_key: str = "", ) -> BaseProvider: - should_log_init = not registry.is_cached(provider_type) + cache_key = provider_type + if passthrough_api_key and settings.enable_api_key_passthrough: + cache_key = f"{provider_type}:{passthrough_api_key}" + should_log_init = not registry.is_cached(cache_key) try: - provider = registry.get(provider_type, settings) + provider = registry.get( + provider_type, settings, passthrough_api_key=passthrough_api_key + ) except AuthenticationError as e: # Provider :class:`~providers.exceptions.AuthenticationError` messages are # curated configuration hints (env var names, docs links), not upstream noise. @@ -90,16 +106,19 @@ def get_provider_for_type(provider_type: str) -> BaseProvider: def require_api_key( request: Request, settings: Settings = Depends(get_settings) -) -> None: +) -> str: """Require a server API key (Anthropic-style). Checks `x-api-key` header or `Authorization: Bearer ...` against `Settings.anthropic_auth_token`. If `ANTHROPIC_AUTH_TOKEN` is empty, this is a no-op. + + Returns the extracted client token for pass-through when + ``ENABLE_API_KEY_PASSTHROUGH`` is active. """ anthropic_auth_token = settings.anthropic_auth_token if not anthropic_auth_token: # No API key configured -> allow - return + return "" header = ( request.headers.get("x-api-key") @@ -125,6 +144,8 @@ def require_api_key( ): raise HTTPException(status_code=401, detail="Invalid API key") + return token + def get_provider() -> BaseProvider: """Get or create the default provider (``MODEL`` / ``provider_type``). diff --git a/api/routes.py b/api/routes.py index 6049494..654212a 100644 --- a/api/routes.py +++ b/api/routes.py @@ -62,12 +62,19 @@ SUPPORTED_CLAUDE_MODELS = [ def get_proxy_service( request: Request, settings: Settings = Depends(get_settings), + passthrough_api_key: str = Depends(require_api_key), ) -> ClaudeProxyService: """Build the request service for route handlers.""" + effective_passthrough = ( + passthrough_api_key if settings.enable_api_key_passthrough else "" + ) return ClaudeProxyService( settings, provider_getter=lambda provider_type: dependencies.resolve_provider( - provider_type, app=request.app, settings=settings + provider_type, + app=request.app, + settings=settings, + passthrough_api_key=effective_passthrough, ), token_counter=get_token_count, ) @@ -167,7 +174,6 @@ def _build_models_list_response( async def create_message( request_data: MessagesRequest, service: ClaudeProxyService = Depends(get_proxy_service), - _auth=Depends(require_api_key), ): """Create a message (always streaming).""" return service.create_message(request_data) @@ -183,7 +189,6 @@ async def probe_messages(_auth=Depends(require_api_key)): async def count_tokens( request_data: TokenCountRequest, service: ClaudeProxyService = Depends(get_proxy_service), - _auth=Depends(require_api_key), ): """Count tokens for a request.""" return service.count_tokens(request_data) diff --git a/config/settings.py b/config/settings.py index 7504def..9c2bf1d 100644 --- a/config/settings.py +++ b/config/settings.py @@ -330,6 +330,12 @@ class Settings(BaseSettings): anthropic_auth_token: str = Field( default="", validation_alias="ANTHROPIC_AUTH_TOKEN" ) + # When true, the client's auth token is passed through as the upstream + # provider bearer token, eliminating the need for separate per-provider + # API keys. Requires ANTHROPIC_AUTH_TOKEN to be set. + enable_api_key_passthrough: bool = Field( + default=False, validation_alias="ENABLE_API_KEY_PASSTHROUGH" + ) @model_validator(mode="before") @classmethod diff --git a/providers/registry.py b/providers/registry.py index 9f04760..38de129 100644 --- a/providers/registry.py +++ b/providers/registry.py @@ -149,10 +149,16 @@ def _require_credential(descriptor: ProviderDescriptor, credential: str) -> None def build_provider_config( - descriptor: ProviderDescriptor, settings: Settings + descriptor: ProviderDescriptor, + settings: Settings, + *, + passthrough_api_key: str = "", ) -> ProviderConfig: - credential = _credential_for(descriptor, settings) - _require_credential(descriptor, credential) + if passthrough_api_key and settings.enable_api_key_passthrough: + credential = passthrough_api_key + else: + credential = _credential_for(descriptor, settings) + _require_credential(descriptor, credential) base_url = _string_attr( settings, descriptor.base_url_attr, descriptor.default_base_url or "" ) @@ -173,7 +179,12 @@ def build_provider_config( ) -def create_provider(provider_id: str, settings: Settings) -> BaseProvider: +def create_provider( + provider_id: str, + settings: Settings, + *, + passthrough_api_key: str = "", +) -> BaseProvider: descriptor = PROVIDER_DESCRIPTORS.get(provider_id) if descriptor is None: supported = "', '".join(PROVIDER_DESCRIPTORS) @@ -181,7 +192,9 @@ def create_provider(provider_id: str, settings: Settings) -> BaseProvider: f"Unknown provider_type: '{provider_id}'. Supported: '{supported}'" ) - config = build_provider_config(descriptor, settings) + config = build_provider_config( + descriptor, settings, passthrough_api_key=passthrough_api_key + ) factory = PROVIDER_FACTORIES.get(provider_id) if factory is None: raise AssertionError(f"Unhandled provider descriptor: {provider_id}") @@ -267,10 +280,23 @@ class ProviderRegistry: """Return whether a provider for this id is already in the cache.""" return provider_id in self._providers - def get(self, provider_id: str, settings: Settings) -> BaseProvider: - if provider_id not in self._providers: - self._providers[provider_id] = create_provider(provider_id, settings) - return self._providers[provider_id] + def get( + self, + provider_id: str, + settings: Settings, + *, + passthrough_api_key: str = "", + ) -> BaseProvider: + # When pass-through is active, include the API key in the cache key + # so different client tokens get distinct provider instances. + cache_key = provider_id + if passthrough_api_key and settings.enable_api_key_passthrough: + cache_key = f"{provider_id}:{passthrough_api_key}" + if cache_key not in self._providers: + self._providers[cache_key] = create_provider( + provider_id, settings, passthrough_api_key=passthrough_api_key + ) + return self._providers[cache_key] def cache_model_ids(self, provider_id: str, model_ids: Iterable[str]) -> None: """Store a provider model-list result for later instant API responses.""" diff --git a/tests/api/test_api_key_passthrough.py b/tests/api/test_api_key_passthrough.py new file mode 100644 index 0000000..e5839ee --- /dev/null +++ b/tests/api/test_api_key_passthrough.py @@ -0,0 +1,271 @@ +"""Tests for API key pass-through (ENABLE_API_KEY_PASSTHROUGH).""" + +from types import SimpleNamespace +from typing import cast +from unittest.mock import MagicMock, patch + +import pytest +from fastapi import HTTPException +from starlette.applications import Starlette +from starlette.datastructures import State + +from api.dependencies import ( + require_api_key, + resolve_provider, +) +from config.provider_catalog import PROVIDER_CATALOG +from providers.nvidia_nim import NvidiaNimProvider +from providers.registry import ProviderRegistry, build_provider_config, create_provider + + +def _make_settings(**overrides): + """Create a mock settings object with all required fields.""" + mock = MagicMock() + mock.model = "nvidia_nim/meta/llama3" + mock.provider_type = "nvidia_nim" + mock.nvidia_nim_api_key = "test_key" + mock.open_router_api_key = "test_openrouter_key" + mock.deepseek_api_key = "test_deepseek_key" + mock.wafer_api_key = "test_wafer_key" + mock.opencode_api_key = "test_opencode_key" + mock.zai_api_key = "test_zai_key" + mock.openai_api_key = "test_openai_key" + mock.kimi_api_key = "test_kimi_key" + mock.lm_studio_base_url = "http://localhost:1234/v1" + mock.llamacpp_base_url = "http://localhost:8080/v1" + mock.ollama_base_url = "http://localhost:11434" + mock.llamacpp_api_key = "" + mock.ollama_api_key = "" + mock.nvidia_nim_proxy = "" + mock.open_router_proxy = "" + mock.lmstudio_proxy = "" + mock.llamacpp_proxy = "" + mock.kimi_proxy = "" + mock.wafer_proxy = "" + mock.opencode_proxy = "" + mock.zai_proxy = "" + mock.openai_proxy = "" + mock.provider_rate_limit = 40 + mock.provider_rate_window = 60 + mock.provider_max_concurrency = 5 + mock.http_read_timeout = 300.0 + mock.http_write_timeout = 10.0 + mock.http_connect_timeout = 10.0 + mock.enable_model_thinking = True + mock.log_raw_sse_events = False + mock.log_api_error_tracebacks = False + mock.enable_api_key_passthrough = False + mock.anthropic_auth_token = "" + for key, value in overrides.items(): + setattr(mock, key, value) + return mock + + +class TestBuildProviderConfigPassthrough: + """Tests for build_provider_config with passthrough_api_key.""" + + def test_passthrough_disabled_uses_configured_credential(self): + """When passthrough is disabled, the configured provider key is used.""" + settings = _make_settings(enable_api_key_passthrough=False) + descriptor = PROVIDER_CATALOG["nvidia_nim"] + config = build_provider_config(descriptor, settings) + assert config.api_key == "test_key" + + def test_passthrough_enabled_uses_passthrough_key(self): + """When passthrough is enabled and key provided, it overrides credential.""" + settings = _make_settings(enable_api_key_passthrough=True) + descriptor = PROVIDER_CATALOG["nvidia_nim"] + config = build_provider_config( + descriptor, settings, passthrough_api_key="sk-passthrough-123" + ) + assert config.api_key == "sk-passthrough-123" + + def test_passthrough_enabled_empty_key_uses_configured_credential(self): + """When passthrough is enabled but key is empty, falls back to credential.""" + settings = _make_settings(enable_api_key_passthrough=True) + descriptor = PROVIDER_CATALOG["nvidia_nim"] + config = build_provider_config(descriptor, settings, passthrough_api_key="") + assert config.api_key == "test_key" + + def test_passthrough_enabled_key_present_but_passthrough_disabled(self): + """Passthrough key is ignored when enable_api_key_passthrough is False.""" + settings = _make_settings(enable_api_key_passthrough=False) + descriptor = PROVIDER_CATALOG["nvidia_nim"] + config = build_provider_config( + descriptor, settings, passthrough_api_key="sk-ignored" + ) + assert config.api_key == "test_key" + + def test_passthrough_skips_credential_validation(self): + """With passthrough active, missing per-provider key does not raise.""" + settings = _make_settings( + enable_api_key_passthrough=True, nvidia_nim_api_key="" + ) + descriptor = PROVIDER_CATALOG["nvidia_nim"] + config = build_provider_config( + descriptor, settings, passthrough_api_key="sk-valid" + ) + assert config.api_key == "sk-valid" + + def test_passthrough_disabled_missing_key_raises(self): + """Without passthrough, missing required credential still raises.""" + from providers.exceptions import AuthenticationError + + settings = _make_settings( + enable_api_key_passthrough=False, nvidia_nim_api_key="" + ) + descriptor = PROVIDER_CATALOG["nvidia_nim"] + with pytest.raises(AuthenticationError, match="NVIDIA_NIM_API_KEY"): + build_provider_config(descriptor, settings) + + def test_passthrough_with_static_credential_provider(self): + """Passthrough overrides even static credentials (e.g. lmstudio).""" + settings = _make_settings(enable_api_key_passthrough=True) + descriptor = PROVIDER_CATALOG["lmstudio"] + config = build_provider_config( + descriptor, settings, passthrough_api_key="sk-lmstudio-passthrough" + ) + assert config.api_key == "sk-lmstudio-passthrough" + + +class TestCreateProviderPassthrough: + """Tests for create_provider with passthrough_api_key.""" + + def test_create_provider_with_passthrough(self): + """create_provider passes passthrough key through to provider config.""" + settings = _make_settings(enable_api_key_passthrough=True) + with patch("providers.openai_compat.AsyncOpenAI"): + provider = create_provider( + "nvidia_nim", settings, passthrough_api_key="sk-nim-passthrough" + ) + assert isinstance(provider, NvidiaNimProvider) + assert provider._api_key == "sk-nim-passthrough" + + def test_create_provider_without_passthrough(self): + """create_provider uses configured key when passthrough is empty.""" + settings = _make_settings(enable_api_key_passthrough=False) + with patch("providers.openai_compat.AsyncOpenAI"): + provider = create_provider("nvidia_nim", settings) + assert isinstance(provider, NvidiaNimProvider) + assert provider._api_key == "test_key" + + +class TestProviderRegistryPassthrough: + """Tests for ProviderRegistry.get with passthrough_api_key.""" + + def test_different_passthrough_keys_get_distinct_instances(self): + """Different passthrough keys produce separate cached provider instances.""" + registry = ProviderRegistry() + settings = _make_settings(enable_api_key_passthrough=True) + + with patch("providers.openai_compat.AsyncOpenAI"): + p1 = registry.get("nvidia_nim", settings, passthrough_api_key="sk-key-1") + p2 = registry.get("nvidia_nim", settings, passthrough_api_key="sk-key-2") + + assert p1 is not p2 + assert isinstance(p1, NvidiaNimProvider) + assert isinstance(p2, NvidiaNimProvider) + assert p1._api_key == "sk-key-1" + assert p2._api_key == "sk-key-2" + + def test_same_passthrough_key_returns_cached_instance(self): + """Same passthrough key returns the cached provider instance.""" + registry = ProviderRegistry() + settings = _make_settings(enable_api_key_passthrough=True) + + with patch("providers.openai_compat.AsyncOpenAI"): + p1 = registry.get("nvidia_nim", settings, passthrough_api_key="sk-same") + p2 = registry.get("nvidia_nim", settings, passthrough_api_key="sk-same") + + assert p1 is p2 + + def test_no_passthrough_caches_by_provider_id(self): + """Without passthrough, caching works by provider_id as before.""" + registry = ProviderRegistry() + settings = _make_settings(enable_api_key_passthrough=False) + + with patch("providers.openai_compat.AsyncOpenAI"): + p1 = registry.get("nvidia_nim", settings) + p2 = registry.get("nvidia_nim", settings) + + assert p1 is p2 + + def test_passthrough_and_non_passthrough_are_distinct(self): + """A passthrough provider and non-passthrough provider are separate.""" + registry = ProviderRegistry() + settings = _make_settings(enable_api_key_passthrough=True) + + with patch("providers.openai_compat.AsyncOpenAI"): + p_no_key = registry.get("nvidia_nim", settings) + p_with_key = registry.get( + "nvidia_nim", settings, passthrough_api_key="sk-passthrough" + ) + + assert p_no_key is not p_with_key + + +class TestRequireApiKeyReturnsToken: + """Tests that require_api_key returns the extracted client token.""" + + def test_returns_token_on_valid_x_api_key(self): + """Valid x-api-key header returns the extracted token.""" + request = MagicMock() + request.headers = {"x-api-key": "my-secret-key"} + settings = _make_settings(anthropic_auth_token="my-secret-key") + token = require_api_key(request, settings) + assert token == "my-secret-key" + + def test_returns_token_on_valid_bearer(self): + """Valid Bearer authorization returns the token (without Bearer prefix).""" + request = MagicMock() + request.headers = {"authorization": "Bearer my-bearer-key"} + settings = _make_settings(anthropic_auth_token="my-bearer-key") + token = require_api_key(request, settings) + assert token == "my-bearer-key" + + def test_returns_empty_string_when_no_auth_configured(self): + """When ANTHROPIC_AUTH_TOKEN is empty, returns empty string.""" + request = MagicMock() + request.headers = {} + settings = _make_settings(anthropic_auth_token="") + token = require_api_key(request, settings) + assert token == "" + + def test_raises_401_on_invalid_key(self): + """Invalid API key raises 401.""" + request = MagicMock() + request.headers = {"x-api-key": "wrong-key"} + settings = _make_settings(anthropic_auth_token="correct-key") + with pytest.raises(HTTPException) as exc_info: + require_api_key(request, settings) + assert exc_info.value.status_code == 401 + + def test_strips_colon_suffix(self): + """Token with colon suffix (model name append) gets stripped.""" + request = MagicMock() + request.headers = {"x-api-key": "my-key:claude-3-opus"} + settings = _make_settings(anthropic_auth_token="my-key") + token = require_api_key(request, settings) + assert token == "my-key" + + +class TestResolveProviderPassthrough: + """Tests for resolve_provider with passthrough_api_key.""" + + def test_resolve_provider_passes_passthrough_key(self): + """resolve_provider forwards passthrough_api_key to the registry.""" + settings = _make_settings(enable_api_key_passthrough=True) + app = SimpleNamespace(state=State()) + registry = ProviderRegistry() + app.state.provider_registry = registry + + with patch("providers.openai_compat.AsyncOpenAI"): + provider = resolve_provider( + "nvidia_nim", + app=cast(Starlette, app), + settings=settings, + passthrough_api_key="sk-resolve-test", + ) + + assert isinstance(provider, NvidiaNimProvider) + assert provider._api_key == "sk-resolve-test" From e62cbf04eb0370d1f8873416b7d9023713907bbf Mon Sep 17 00:00:00 2001 From: chinfeng Date: Mon, 18 May 2026 00:00:40 +0800 Subject: [PATCH 3/4] fix: respect ENABLE_API_KEY_PASSTHROUGH in credential validation and startup checks When passthrough is enabled, skip the provider API key validation at startup and in build_provider_config, instead requiring the client auth token from the incoming request. Previously the credential check ignored the passthrough flag, causing 503 errors even when passthrough was correctly configured. --- config/settings.py | 2 ++ providers/registry.py | 20 ++++++++++++--- tests/api/test_api_key_passthrough.py | 35 ++++++++++++++++++--------- tests/api/test_dependencies.py | 2 ++ tests/providers/test_registry.py | 2 ++ 5 files changed, 47 insertions(+), 14 deletions(-) diff --git a/config/settings.py b/config/settings.py index 9c2bf1d..7f799f6 100644 --- a/config/settings.py +++ b/config/settings.py @@ -453,6 +453,8 @@ class Settings(BaseSettings): @model_validator(mode="after") def check_nvidia_nim_api_key(self) -> Settings: + if self.enable_api_key_passthrough: + return self if ( self.voice_note_enabled and self.whisper_device == "nvidia_nim" diff --git a/providers/registry.py b/providers/registry.py index 38de129..f73718c 100644 --- a/providers/registry.py +++ b/providers/registry.py @@ -137,11 +137,22 @@ def _credential_for(descriptor: ProviderDescriptor, settings: Settings) -> str: return "" -def _require_credential(descriptor: ProviderDescriptor, credential: str) -> None: +def _require_credential( + descriptor: ProviderDescriptor, + credential: str, + *, + passthrough_enabled: bool = False, +) -> None: if descriptor.credential_env is None: return if credential and credential.strip(): return + if passthrough_enabled: + raise AuthenticationError( + "ENABLE_API_KEY_PASSTHROUGH is active but no client auth token was " + "provided. Ensure the request includes a valid Authorization or " + "x-api-key header." + ) message = f"{descriptor.credential_env} is not set. Add it to your .env file." if descriptor.credential_url: message = f"{message} Get a key at {descriptor.credential_url}" @@ -154,11 +165,14 @@ def build_provider_config( *, passthrough_api_key: str = "", ) -> ProviderConfig: - if passthrough_api_key and settings.enable_api_key_passthrough: + passthrough_enabled = settings.enable_api_key_passthrough + if passthrough_api_key and passthrough_enabled: credential = passthrough_api_key + elif passthrough_enabled: + credential = "" else: credential = _credential_for(descriptor, settings) - _require_credential(descriptor, credential) + _require_credential(descriptor, credential, passthrough_enabled=passthrough_enabled) base_url = _string_attr( settings, descriptor.base_url_attr, descriptor.default_base_url or "" ) diff --git a/tests/api/test_api_key_passthrough.py b/tests/api/test_api_key_passthrough.py index e5839ee..3474978 100644 --- a/tests/api/test_api_key_passthrough.py +++ b/tests/api/test_api_key_passthrough.py @@ -80,12 +80,14 @@ class TestBuildProviderConfigPassthrough: ) assert config.api_key == "sk-passthrough-123" - def test_passthrough_enabled_empty_key_uses_configured_credential(self): - """When passthrough is enabled but key is empty, falls back to credential.""" + def test_passthrough_enabled_empty_key_raises(self): + """When passthrough is enabled but key is empty, raises AuthenticationError.""" + from providers.exceptions import AuthenticationError + settings = _make_settings(enable_api_key_passthrough=True) descriptor = PROVIDER_CATALOG["nvidia_nim"] - config = build_provider_config(descriptor, settings, passthrough_api_key="") - assert config.api_key == "test_key" + with pytest.raises(AuthenticationError, match="ENABLE_API_KEY_PASSTHROUGH"): + build_provider_config(descriptor, settings, passthrough_api_key="") def test_passthrough_enabled_key_present_but_passthrough_disabled(self): """Passthrough key is ignored when enable_api_key_passthrough is False.""" @@ -190,18 +192,29 @@ class TestProviderRegistryPassthrough: assert p1 is p2 - def test_passthrough_and_non_passthrough_are_distinct(self): - """A passthrough provider and non-passthrough provider are separate.""" + def test_passthrough_empty_key_raises_in_registry(self): + """A passthrough-enabled registry raises when no client token given.""" + from providers.exceptions import AuthenticationError + + registry = ProviderRegistry() + settings = _make_settings(enable_api_key_passthrough=True) + + with ( + pytest.raises(AuthenticationError, match="ENABLE_API_KEY_PASSTHROUGH"), + patch("providers.openai_compat.AsyncOpenAI"), + ): + registry.get("nvidia_nim", settings) + + def test_different_passthrough_keys_are_distinct(self): + """Providers with different passthrough keys are separate instances.""" registry = ProviderRegistry() settings = _make_settings(enable_api_key_passthrough=True) with patch("providers.openai_compat.AsyncOpenAI"): - p_no_key = registry.get("nvidia_nim", settings) - p_with_key = registry.get( - "nvidia_nim", settings, passthrough_api_key="sk-passthrough" - ) + p_a = registry.get("nvidia_nim", settings, passthrough_api_key="sk-aaa") + p_b = registry.get("nvidia_nim", settings, passthrough_api_key="sk-bbb") - assert p_no_key is not p_with_key + assert p_a is not p_b class TestRequireApiKeyReturnsToken: diff --git a/tests/api/test_dependencies.py b/tests/api/test_dependencies.py index 0264b8f..06786cf 100644 --- a/tests/api/test_dependencies.py +++ b/tests/api/test_dependencies.py @@ -52,6 +52,8 @@ def _make_mock_settings(**overrides): mock.http_write_timeout = 10.0 mock.http_connect_timeout = 10.0 mock.enable_model_thinking = True + mock.enable_api_key_passthrough = False + mock.anthropic_auth_token = "" for key, value in overrides.items(): setattr(mock, key, value) return mock diff --git a/tests/providers/test_registry.py b/tests/providers/test_registry.py index 2d24b25..2f1ecf1 100644 --- a/tests/providers/test_registry.py +++ b/tests/providers/test_registry.py @@ -51,6 +51,8 @@ def _make_settings(**overrides): mock.http_write_timeout = 10.0 mock.http_connect_timeout = 10.0 mock.enable_model_thinking = True + mock.enable_api_key_passthrough = False + mock.anthropic_auth_token = "" mock.nim = NimSettings() for key, value in overrides.items(): setattr(mock, key, value) From b07892fbec7f403285bc3a86fd5d6cf4fb7a5982 Mon Sep 17 00:00:00 2001 From: chinfeng Date: Mon, 18 May 2026 00:51:01 +0800 Subject: [PATCH 4/4] fix: extract client token from request when passthrough is enabled without server auth When ENABLE_API_KEY_PASSTHROUGH is active but ANTHROPIC_AUTH_TOKEN is not set, require_api_key must still extract the client token from x-api-key or Authorization headers so it can be forwarded to the upstream provider. Previously it returned an empty string, causing "no client auth token was provided" errors even when the client sent a valid header. --- api/dependencies.py | 36 +++++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/api/dependencies.py b/api/dependencies.py index 1700a57..ded5aae 100644 --- a/api/dependencies.py +++ b/api/dependencies.py @@ -104,6 +104,28 @@ def get_provider_for_type(provider_type: str) -> BaseProvider: return resolve_provider(provider_type, app=None, settings=get_settings()) +def _parse_bearer_token(header: str) -> str: + """Extract and clean a token from an Authorization or x-api-key header value.""" + token = header + if header.lower().startswith("bearer "): + token = header.split(" ", 1)[1] + if token and ":" in token: + token = token.split(":", 1)[0] + return token + + +def _extract_client_token(request: Request) -> str: + """Extract a client auth token from the request for passthrough mode.""" + header = ( + request.headers.get("x-api-key") + or request.headers.get("authorization") + or request.headers.get("anthropic-auth-token") + ) + if not header: + raise HTTPException(status_code=401, detail="Missing API key") + return _parse_bearer_token(header) + + def require_api_key( request: Request, settings: Settings = Depends(get_settings) ) -> str: @@ -117,7 +139,10 @@ def require_api_key( """ anthropic_auth_token = settings.anthropic_auth_token if not anthropic_auth_token: - # No API key configured -> allow + # No server auth configured. In passthrough mode, still extract the + # client token from the request header so it can be forwarded upstream. + if settings.enable_api_key_passthrough: + return _extract_client_token(request) return "" header = ( @@ -128,14 +153,7 @@ def require_api_key( if not header: raise HTTPException(status_code=401, detail="Missing API key") - # Support both raw key in X-API-Key and Bearer token in Authorization - token = header - if header.lower().startswith("bearer "): - token = header.split(" ", 1)[1] - - # Strip anything after the first colon to handle tokens with appended model names - if token and ":" in token: - token = token.split(":", 1)[0] + token = _parse_bearer_token(header) # Constant-time comparison to avoid leaking the configured token via # response-time differences on a per-byte mismatch (CWE-208).