diff --git a/.env.example b/.env.example index 27716f6..4ac2911 100644 --- a/.env.example +++ b/.env.example @@ -1,5 +1,6 @@ # NVIDIA NIM Config NVIDIA_NIM_API_KEY="" +NVIDIA_NIM_BASE_URL="https://integrate.api.nvidia.com/v1" # OpenRouter Config @@ -26,22 +27,28 @@ OPENCODE_API_KEY="" ZAI_API_KEY="" ZAI_BASE_URL="https://api.z.ai/api/coding/paas/v4" +# OpenAI Config (official Chat Completions API) +OPENAI_API_KEY="" +OPENAI_BASE_URL="https://api.openai.com/v1" + # LM Studio Config (local provider, no API key required) LM_STUDIO_BASE_URL="http://localhost:1234/v1" -# Llama.cpp Config (local provider, no API key required) +# Llama.cpp Config (local provider, API key optional) LLAMACPP_BASE_URL="http://localhost:8080/v1" +LLAMACPP_API_KEY="" -# Ollama Config (local provider, no API key required) +# Ollama Config (local provider, API key optional) OLLAMA_BASE_URL="http://localhost:11434" +OLLAMA_API_KEY="" # All Claude model requests are mapped to these models, plain model is fallback # Format: provider_type/model/name -# Valid providers: "nvidia_nim" | "open_router" | "deepseek" | "lmstudio" | "llamacpp" | "ollama" | "kimi" | "wafer" | "opencode" | "zai" +# Valid providers: "nvidia_nim" | "open_router" | "deepseek" | "lmstudio" | "llamacpp" | "ollama" | "kimi" | "wafer" | "opencode" | "zai" | "openai" MODEL_OPUS= MODEL_SONNET= MODEL_HAIKU= @@ -60,6 +67,7 @@ FCC_SMOKE_MODEL_KIMI= FCC_SMOKE_MODEL_WAFER= FCC_SMOKE_MODEL_OPENCODE= FCC_SMOKE_MODEL_ZAI= +FCC_SMOKE_MODEL_OPENAI= FCC_SMOKE_NIM_MODELS= FCC_SMOKE_NIM_EXTRA_MODELS= FCC_SMOKE_OPENROUTER_FREE_MODELS= @@ -85,6 +93,8 @@ KIMI_PROXY="" WAFER_PROXY="" OPENCODE_PROXY="" ZAI_PROXY="" +OPENAI_PROXY="" +OLLAMA_PROXY="" PROVIDER_RATE_LIMIT=1 PROVIDER_RATE_WINDOW=3 @@ -100,6 +110,12 @@ HTTP_CONNECT_TIMEOUT=60 # Optional server API key (Anthropic-style) ANTHROPIC_AUTH_TOKEN="freecc" +# API Key Pass-through +# When true, the client's auth token is used as the upstream provider bearer +# token, eliminating the need for separate per-provider API keys. +# Requires ANTHROPIC_AUTH_TOKEN to be set (clients must authenticate). +ENABLE_API_KEY_PASSTHROUGH=false + # Messaging Platform: "telegram" | "discord" | "none" MESSAGING_PLATFORM="discord" diff --git a/api/admin_config.py b/api/admin_config.py index cb4f7ab..246d0c4 100644 --- a/api/admin_config.py +++ b/api/admin_config.py @@ -184,6 +184,51 @@ FIELDS: tuple[ConfigFieldSpec, ...] = ( default="https://api.z.ai/api/coding/paas/v4", description="Z.ai OpenAI-compatible Coding Plan endpoint.", ), + ConfigFieldSpec( + "OPENAI_API_KEY", + "OpenAI API Key", + "providers", + "secret", + settings_attr="openai_api_key", + secret=True, + description="OpenAI API key for direct Chat Completions access.", + ), + ConfigFieldSpec( + "OPENAI_BASE_URL", + "OpenAI Base URL", + "providers", + settings_attr="openai_base_url", + default="https://api.openai.com/v1", + description="OpenAI-compatible Chat Completions endpoint.", + ), + ConfigFieldSpec( + "NVIDIA_NIM_BASE_URL", + "NVIDIA NIM Base URL", + "providers", + settings_attr="nvidia_nim_base_url", + default="https://integrate.api.nvidia.com/v1", + description="NVIDIA NIM API endpoint override.", + ), + ConfigFieldSpec( + "LLAMACPP_API_KEY", + "llama.cpp API Key", + "providers", + "secret", + settings_attr="llamacpp_api_key", + secret=True, + advanced=True, + description="Optional API key for llama.cpp server authentication.", + ), + ConfigFieldSpec( + "OLLAMA_API_KEY", + "Ollama API Key", + "providers", + "secret", + settings_attr="ollama_api_key", + secret=True, + advanced=True, + description="Optional API key for Ollama server authentication.", + ), ConfigFieldSpec( "LM_STUDIO_BASE_URL", "LM Studio Base URL", @@ -277,6 +322,24 @@ FIELDS: tuple[ConfigFieldSpec, ...] = ( secret=True, advanced=True, ), + ConfigFieldSpec( + "OPENAI_PROXY", + "OpenAI Proxy", + "providers", + "secret", + settings_attr="openai_proxy", + secret=True, + advanced=True, + ), + ConfigFieldSpec( + "OLLAMA_PROXY", + "Ollama Proxy", + "providers", + "secret", + settings_attr="ollama_proxy", + secret=True, + advanced=True, + ), ConfigFieldSpec( "MODEL", "Default Model", diff --git a/api/dependencies.py b/api/dependencies.py index fbc37b0..ded5aae 100644 --- a/api/dependencies.py +++ b/api/dependencies.py @@ -33,6 +33,7 @@ def resolve_provider( *, app: Starlette | None, settings: Settings, + passthrough_api_key: str = "", ) -> BaseProvider: """Resolve a provider using the app-scoped registry when ``app`` is set. @@ -43,6 +44,10 @@ def resolve_provider( When ``app`` is ``None`` (no HTTP context), uses the process-level :data:`_providers` cache only. + + When ``passthrough_api_key`` is non-empty and + ``settings.enable_api_key_passthrough`` is True, the client's auth token + replaces per-provider credentials in the upstream request. """ if app is not None: reg = getattr(app.state, "provider_registry", None) @@ -51,16 +56,27 @@ def resolve_provider( "Provider registry is not configured. Ensure AppRuntime startup ran " "or assign app.state.provider_registry for test apps." ) - return _resolve_with_registry(reg, provider_type, settings) + return _resolve_with_registry( + reg, provider_type, settings, passthrough_api_key=passthrough_api_key + ) return _resolve_with_registry(ProviderRegistry(_providers), provider_type, settings) def _resolve_with_registry( - registry: ProviderRegistry, provider_type: str, settings: Settings + registry: ProviderRegistry, + provider_type: str, + settings: Settings, + *, + passthrough_api_key: str = "", ) -> BaseProvider: - should_log_init = not registry.is_cached(provider_type) + cache_key = provider_type + if passthrough_api_key and settings.enable_api_key_passthrough: + cache_key = f"{provider_type}:{passthrough_api_key}" + should_log_init = not registry.is_cached(cache_key) try: - provider = registry.get(provider_type, settings) + provider = registry.get( + provider_type, settings, passthrough_api_key=passthrough_api_key + ) except AuthenticationError as e: # Provider :class:`~providers.exceptions.AuthenticationError` messages are # curated configuration hints (env var names, docs links), not upstream noise. @@ -88,18 +104,46 @@ def get_provider_for_type(provider_type: str) -> BaseProvider: return resolve_provider(provider_type, app=None, settings=get_settings()) +def _parse_bearer_token(header: str) -> str: + """Extract and clean a token from an Authorization or x-api-key header value.""" + token = header + if header.lower().startswith("bearer "): + token = header.split(" ", 1)[1] + if token and ":" in token: + token = token.split(":", 1)[0] + return token + + +def _extract_client_token(request: Request) -> str: + """Extract a client auth token from the request for passthrough mode.""" + header = ( + request.headers.get("x-api-key") + or request.headers.get("authorization") + or request.headers.get("anthropic-auth-token") + ) + if not header: + raise HTTPException(status_code=401, detail="Missing API key") + return _parse_bearer_token(header) + + def require_api_key( request: Request, settings: Settings = Depends(get_settings) -) -> None: +) -> str: """Require a server API key (Anthropic-style). Checks `x-api-key` header or `Authorization: Bearer ...` against `Settings.anthropic_auth_token`. If `ANTHROPIC_AUTH_TOKEN` is empty, this is a no-op. + + Returns the extracted client token for pass-through when + ``ENABLE_API_KEY_PASSTHROUGH`` is active. """ anthropic_auth_token = settings.anthropic_auth_token if not anthropic_auth_token: - # No API key configured -> allow - return + # No server auth configured. In passthrough mode, still extract the + # client token from the request header so it can be forwarded upstream. + if settings.enable_api_key_passthrough: + return _extract_client_token(request) + return "" header = ( request.headers.get("x-api-key") @@ -109,14 +153,7 @@ def require_api_key( if not header: raise HTTPException(status_code=401, detail="Missing API key") - # Support both raw key in X-API-Key and Bearer token in Authorization - token = header - if header.lower().startswith("bearer "): - token = header.split(" ", 1)[1] - - # Strip anything after the first colon to handle tokens with appended model names - if token and ":" in token: - token = token.split(":", 1)[0] + token = _parse_bearer_token(header) # Constant-time comparison to avoid leaking the configured token via # response-time differences on a per-byte mismatch (CWE-208). @@ -125,6 +162,8 @@ def require_api_key( ): raise HTTPException(status_code=401, detail="Invalid API key") + return token + def get_provider() -> BaseProvider: """Get or create the default provider (``MODEL`` / ``provider_type``). diff --git a/api/routes.py b/api/routes.py index 6049494..654212a 100644 --- a/api/routes.py +++ b/api/routes.py @@ -62,12 +62,19 @@ SUPPORTED_CLAUDE_MODELS = [ def get_proxy_service( request: Request, settings: Settings = Depends(get_settings), + passthrough_api_key: str = Depends(require_api_key), ) -> ClaudeProxyService: """Build the request service for route handlers.""" + effective_passthrough = ( + passthrough_api_key if settings.enable_api_key_passthrough else "" + ) return ClaudeProxyService( settings, provider_getter=lambda provider_type: dependencies.resolve_provider( - provider_type, app=request.app, settings=settings + provider_type, + app=request.app, + settings=settings, + passthrough_api_key=effective_passthrough, ), token_counter=get_token_count, ) @@ -167,7 +174,6 @@ def _build_models_list_response( async def create_message( request_data: MessagesRequest, service: ClaudeProxyService = Depends(get_proxy_service), - _auth=Depends(require_api_key), ): """Create a message (always streaming).""" return service.create_message(request_data) @@ -183,7 +189,6 @@ async def probe_messages(_auth=Depends(require_api_key)): async def count_tokens( request_data: TokenCountRequest, service: ClaudeProxyService = Depends(get_proxy_service), - _auth=Depends(require_api_key), ): """Count tokens for a request.""" return service.count_tokens(request_data) diff --git a/api/services.py b/api/services.py index 6566248..1f648bc 100644 --- a/api/services.py +++ b/api/services.py @@ -34,7 +34,7 @@ TokenCounter = Callable[[list[Any], str | list[Any] | None, list[Any] | None], i ProviderGetter = Callable[[str], BaseProvider] # Providers that use ``/chat/completions`` + Anthropic-to-OpenAI conversion (not native Messages). -_OPENAI_CHAT_UPSTREAM_IDS = frozenset({"nvidia_nim", "opencode", "zai"}) +_OPENAI_CHAT_UPSTREAM_IDS = frozenset({"nvidia_nim", "opencode", "zai", "openai"}) def anthropic_sse_streaming_response( diff --git a/config/provider_catalog.py b/config/provider_catalog.py index e2a53c8..45d8e53 100644 --- a/config/provider_catalog.py +++ b/config/provider_catalog.py @@ -25,6 +25,7 @@ LLAMACPP_DEFAULT_BASE = "http://localhost:8080/v1" OLLAMA_DEFAULT_BASE = "http://localhost:11434" OPENCODE_DEFAULT_BASE = "https://opencode.ai/zen/v1" ZAI_DEFAULT_BASE = "https://api.z.ai/api/coding/paas/v4" +OPENAI_DEFAULT_BASE = "https://api.openai.com/v1" @dataclass(frozen=True, slots=True) @@ -51,6 +52,7 @@ PROVIDER_CATALOG: dict[str, ProviderDescriptor] = { credential_url="https://build.nvidia.com/settings/api-keys", credential_attr="nvidia_nim_api_key", default_base_url=NVIDIA_NIM_DEFAULT_BASE, + base_url_attr="nvidia_nim_base_url", proxy_attr="nvidia_nim_proxy", capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"), ), @@ -85,7 +87,8 @@ PROVIDER_CATALOG: dict[str, ProviderDescriptor] = { "llamacpp": ProviderDescriptor( provider_id="llamacpp", transport_type="anthropic_messages", - static_credential="llamacpp", + credential_env=None, + credential_attr="llamacpp_api_key", default_base_url=LLAMACPP_DEFAULT_BASE, base_url_attr="llamacpp_base_url", proxy_attr="llamacpp_proxy", @@ -94,9 +97,11 @@ PROVIDER_CATALOG: dict[str, ProviderDescriptor] = { "ollama": ProviderDescriptor( provider_id="ollama", transport_type="anthropic_messages", - static_credential="ollama", + credential_env=None, + credential_attr="ollama_api_key", default_base_url=OLLAMA_DEFAULT_BASE, base_url_attr="ollama_base_url", + proxy_attr="ollama_proxy", capabilities=( "chat", "streaming", @@ -146,6 +151,17 @@ PROVIDER_CATALOG: dict[str, ProviderDescriptor] = { proxy_attr="zai_proxy", capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"), ), + "openai": ProviderDescriptor( + provider_id="openai", + transport_type="openai_chat", + credential_env="OPENAI_API_KEY", + credential_url="https://platform.openai.com/api-keys", + credential_attr="openai_api_key", + default_base_url=OPENAI_DEFAULT_BASE, + base_url_attr="openai_base_url", + proxy_attr="openai_proxy", + capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"), + ), } # Order matches docs / historical error text; must match PROVIDER_CATALOG keys. diff --git a/config/settings.py b/config/settings.py index e451761..7f799f6 100644 --- a/config/settings.py +++ b/config/settings.py @@ -129,6 +129,14 @@ class Settings(BaseSettings): validation_alias="ZAI_BASE_URL", ) + # ==================== OpenAI Config ==================== + openai_api_key: str = Field(default="", validation_alias="OPENAI_API_KEY") + openai_base_url: str = Field( + default="https://api.openai.com/v1", + validation_alias="OPENAI_BASE_URL", + ) + openai_proxy: str = Field(default="", validation_alias="OPENAI_PROXY") + # ==================== Messaging Platform Selection ==================== # Valid: "telegram" | "discord" | "none" messaging_platform: str = Field( @@ -143,6 +151,10 @@ class Settings(BaseSettings): # ==================== NVIDIA NIM Config ==================== nvidia_nim_api_key: str = "" + nvidia_nim_base_url: str = Field( + default="https://integrate.api.nvidia.com/v1", + validation_alias="NVIDIA_NIM_BASE_URL", + ) # ==================== LM Studio Config ==================== lm_studio_base_url: str = Field( @@ -155,12 +167,15 @@ class Settings(BaseSettings): default="http://localhost:8080/v1", validation_alias="LLAMACPP_BASE_URL", ) + llamacpp_api_key: str = Field(default="", validation_alias="LLAMACPP_API_KEY") # ==================== Ollama Config ==================== ollama_base_url: str = Field( default="http://localhost:11434", validation_alias="OLLAMA_BASE_URL", ) + ollama_api_key: str = Field(default="", validation_alias="OLLAMA_API_KEY") + ollama_proxy: str = Field(default="", validation_alias="OLLAMA_PROXY") # ==================== Model ==================== # All Claude model requests are mapped to this single model (fallback) @@ -315,6 +330,12 @@ class Settings(BaseSettings): anthropic_auth_token: str = Field( default="", validation_alias="ANTHROPIC_AUTH_TOKEN" ) + # When true, the client's auth token is passed through as the upstream + # provider bearer token, eliminating the need for separate per-provider + # API keys. Requires ANTHROPIC_AUTH_TOKEN to be set. + enable_api_key_passthrough: bool = Field( + default=False, validation_alias="ENABLE_API_KEY_PASSTHROUGH" + ) @model_validator(mode="before") @classmethod @@ -432,6 +453,8 @@ class Settings(BaseSettings): @model_validator(mode="after") def check_nvidia_nim_api_key(self) -> Settings: + if self.enable_api_key_passthrough: + return self if ( self.voice_note_enabled and self.whisper_device == "nvidia_nim" diff --git a/providers/defaults.py b/providers/defaults.py index a0ec814..16155a6 100644 --- a/providers/defaults.py +++ b/providers/defaults.py @@ -8,6 +8,7 @@ from config.provider_catalog import ( LMSTUDIO_DEFAULT_BASE, NVIDIA_NIM_DEFAULT_BASE, OLLAMA_DEFAULT_BASE, + OPENAI_DEFAULT_BASE, OPENCODE_DEFAULT_BASE, OPENROUTER_DEFAULT_BASE, WAFER_DEFAULT_BASE, @@ -22,6 +23,7 @@ __all__ = ( "LMSTUDIO_DEFAULT_BASE", "NVIDIA_NIM_DEFAULT_BASE", "OLLAMA_DEFAULT_BASE", + "OPENAI_DEFAULT_BASE", "OPENCODE_DEFAULT_BASE", "OPENROUTER_DEFAULT_BASE", "WAFER_DEFAULT_BASE", diff --git a/providers/llamacpp/client.py b/providers/llamacpp/client.py index 891022d..ea427f7 100644 --- a/providers/llamacpp/client.py +++ b/providers/llamacpp/client.py @@ -14,3 +14,15 @@ class LlamaCppProvider(AnthropicMessagesTransport): provider_name="LLAMACPP", default_base_url=LLAMACPP_DEFAULT_BASE, ) + self._api_key = config.api_key or "llamacpp" + + def _request_headers(self) -> dict[str, str]: + headers = {"Content-Type": "application/json"} + if self._api_key and self._api_key != "llamacpp": + headers["Authorization"] = f"Bearer {self._api_key}" + return headers + + def _model_list_headers(self) -> dict[str, str]: + if self._api_key and self._api_key != "llamacpp": + return {"Authorization": f"Bearer {self._api_key}"} + return {} diff --git a/providers/ollama/client.py b/providers/ollama/client.py index 86fa542..e5edf14 100644 --- a/providers/ollama/client.py +++ b/providers/ollama/client.py @@ -19,6 +19,17 @@ class OllamaProvider(AnthropicMessagesTransport): ) self._api_key = config.api_key or "ollama" + def _request_headers(self) -> dict[str, str]: + headers = {"Content-Type": "application/json"} + if self._api_key and self._api_key != "ollama": + headers["Authorization"] = f"Bearer {self._api_key}" + return headers + + def _model_list_headers(self) -> dict[str, str]: + if self._api_key and self._api_key != "ollama": + return {"Authorization": f"Bearer {self._api_key}"} + return {} + async def _send_stream_request(self, body: dict) -> httpx.Response: """Create a streaming native Anthropic messages response.""" request = self._client.build_request( diff --git a/providers/openai/__init__.py b/providers/openai/__init__.py new file mode 100644 index 0000000..8b30d46 --- /dev/null +++ b/providers/openai/__init__.py @@ -0,0 +1,10 @@ +"""OpenAI provider exports.""" + +from providers.defaults import OPENAI_DEFAULT_BASE + +from .client import OpenAIProvider + +__all__ = [ + "OPENAI_DEFAULT_BASE", + "OpenAIProvider", +] diff --git a/providers/openai/client.py b/providers/openai/client.py new file mode 100644 index 0000000..4dc5284 --- /dev/null +++ b/providers/openai/client.py @@ -0,0 +1,31 @@ +"""OpenAI provider implementation.""" + +from __future__ import annotations + +from typing import Any + +from providers.base import ProviderConfig +from providers.defaults import OPENAI_DEFAULT_BASE +from providers.openai_compat import OpenAIChatTransport + +from .request import build_request_body + + +class OpenAIProvider(OpenAIChatTransport): + """OpenAI provider using the official Chat Completions API.""" + + def __init__(self, config: ProviderConfig): + super().__init__( + config, + provider_name="OPENAI", + base_url=config.base_url or OPENAI_DEFAULT_BASE, + api_key=config.api_key, + ) + + def _build_request_body( + self, request: Any, thinking_enabled: bool | None = None + ) -> dict: + return build_request_body( + request, + thinking_enabled=self._is_thinking_enabled(request, thinking_enabled), + ) diff --git a/providers/openai/request.py b/providers/openai/request.py new file mode 100644 index 0000000..cdadb17 --- /dev/null +++ b/providers/openai/request.py @@ -0,0 +1,35 @@ +"""Request builder for OpenAI provider.""" + +from typing import Any + +from loguru import logger + +from core.anthropic import ReasoningReplayMode, build_base_request_body +from core.anthropic.conversion import OpenAIConversionError +from providers.exceptions import InvalidRequestError + + +def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict: + """Build OpenAI-format request body from Anthropic request for OpenAI.""" + logger.debug( + "OPENAI_REQUEST: conversion start model={} msgs={}", + getattr(request_data, "model", "?"), + len(getattr(request_data, "messages", [])), + ) + try: + body = build_base_request_body( + request_data, + reasoning_replay=ReasoningReplayMode.REASONING_CONTENT + if thinking_enabled + else ReasoningReplayMode.DISABLED, + ) + except OpenAIConversionError as exc: + raise InvalidRequestError(str(exc)) from exc + + logger.debug( + "OPENAI_REQUEST: conversion done model={} msgs={} tools={}", + body.get("model"), + len(body.get("messages", [])), + len(body.get("tools", [])), + ) + return body diff --git a/providers/registry.py b/providers/registry.py index 4145702..f73718c 100644 --- a/providers/registry.py +++ b/providers/registry.py @@ -92,6 +92,12 @@ def _create_zai(config: ProviderConfig, _settings: Settings) -> BaseProvider: return ZaiProvider(config) +def _create_openai(config: ProviderConfig, _settings: Settings) -> BaseProvider: + from providers.openai import OpenAIProvider + + return OpenAIProvider(config) + + PROVIDER_FACTORIES: dict[str, ProviderFactory] = { "nvidia_nim": _create_nvidia_nim, "open_router": _create_open_router, @@ -103,6 +109,7 @@ PROVIDER_FACTORIES: dict[str, ProviderFactory] = { "wafer": _create_wafer, "opencode": _create_opencode, "zai": _create_zai, + "openai": _create_openai, } if set(PROVIDER_DESCRIPTORS) != set(SUPPORTED_PROVIDER_IDS) or set( @@ -130,11 +137,22 @@ def _credential_for(descriptor: ProviderDescriptor, settings: Settings) -> str: return "" -def _require_credential(descriptor: ProviderDescriptor, credential: str) -> None: +def _require_credential( + descriptor: ProviderDescriptor, + credential: str, + *, + passthrough_enabled: bool = False, +) -> None: if descriptor.credential_env is None: return if credential and credential.strip(): return + if passthrough_enabled: + raise AuthenticationError( + "ENABLE_API_KEY_PASSTHROUGH is active but no client auth token was " + "provided. Ensure the request includes a valid Authorization or " + "x-api-key header." + ) message = f"{descriptor.credential_env} is not set. Add it to your .env file." if descriptor.credential_url: message = f"{message} Get a key at {descriptor.credential_url}" @@ -142,10 +160,19 @@ def _require_credential(descriptor: ProviderDescriptor, credential: str) -> None def build_provider_config( - descriptor: ProviderDescriptor, settings: Settings + descriptor: ProviderDescriptor, + settings: Settings, + *, + passthrough_api_key: str = "", ) -> ProviderConfig: - credential = _credential_for(descriptor, settings) - _require_credential(descriptor, credential) + passthrough_enabled = settings.enable_api_key_passthrough + if passthrough_api_key and passthrough_enabled: + credential = passthrough_api_key + elif passthrough_enabled: + credential = "" + else: + credential = _credential_for(descriptor, settings) + _require_credential(descriptor, credential, passthrough_enabled=passthrough_enabled) base_url = _string_attr( settings, descriptor.base_url_attr, descriptor.default_base_url or "" ) @@ -166,7 +193,12 @@ def build_provider_config( ) -def create_provider(provider_id: str, settings: Settings) -> BaseProvider: +def create_provider( + provider_id: str, + settings: Settings, + *, + passthrough_api_key: str = "", +) -> BaseProvider: descriptor = PROVIDER_DESCRIPTORS.get(provider_id) if descriptor is None: supported = "', '".join(PROVIDER_DESCRIPTORS) @@ -174,7 +206,9 @@ def create_provider(provider_id: str, settings: Settings) -> BaseProvider: f"Unknown provider_type: '{provider_id}'. Supported: '{supported}'" ) - config = build_provider_config(descriptor, settings) + config = build_provider_config( + descriptor, settings, passthrough_api_key=passthrough_api_key + ) factory = PROVIDER_FACTORIES.get(provider_id) if factory is None: raise AssertionError(f"Unhandled provider descriptor: {provider_id}") @@ -260,10 +294,23 @@ class ProviderRegistry: """Return whether a provider for this id is already in the cache.""" return provider_id in self._providers - def get(self, provider_id: str, settings: Settings) -> BaseProvider: - if provider_id not in self._providers: - self._providers[provider_id] = create_provider(provider_id, settings) - return self._providers[provider_id] + def get( + self, + provider_id: str, + settings: Settings, + *, + passthrough_api_key: str = "", + ) -> BaseProvider: + # When pass-through is active, include the API key in the cache key + # so different client tokens get distinct provider instances. + cache_key = provider_id + if passthrough_api_key and settings.enable_api_key_passthrough: + cache_key = f"{provider_id}:{passthrough_api_key}" + if cache_key not in self._providers: + self._providers[cache_key] = create_provider( + provider_id, settings, passthrough_api_key=passthrough_api_key + ) + return self._providers[cache_key] def cache_model_ids(self, provider_id: str, model_ids: Iterable[str]) -> None: """Store a provider model-list result for later instant API responses.""" diff --git a/smoke/lib/config.py b/smoke/lib/config.py index f69267c..7309d84 100644 --- a/smoke/lib/config.py +++ b/smoke/lib/config.py @@ -51,6 +51,7 @@ PROVIDER_SMOKE_DEFAULT_MODELS: dict[str, str] = { "wafer": "wafer/DeepSeek-V4-Pro", "opencode": "opencode/gpt-5.3-codex", "zai": "zai/glm-5.1", + "openai": "openai/gpt-4o", } NVIDIA_NIM_CLI_DEFAULT_MODELS: tuple[str, ...] = ( @@ -237,6 +238,8 @@ class SmokeConfig: return bool(self.settings.opencode_api_key.strip()) if provider == "zai": return bool(self.settings.zai_api_key.strip()) + if provider == "openai": + return bool(self.settings.openai_api_key.strip()) return False diff --git a/tests/api/test_api_key_passthrough.py b/tests/api/test_api_key_passthrough.py new file mode 100644 index 0000000..3474978 --- /dev/null +++ b/tests/api/test_api_key_passthrough.py @@ -0,0 +1,284 @@ +"""Tests for API key pass-through (ENABLE_API_KEY_PASSTHROUGH).""" + +from types import SimpleNamespace +from typing import cast +from unittest.mock import MagicMock, patch + +import pytest +from fastapi import HTTPException +from starlette.applications import Starlette +from starlette.datastructures import State + +from api.dependencies import ( + require_api_key, + resolve_provider, +) +from config.provider_catalog import PROVIDER_CATALOG +from providers.nvidia_nim import NvidiaNimProvider +from providers.registry import ProviderRegistry, build_provider_config, create_provider + + +def _make_settings(**overrides): + """Create a mock settings object with all required fields.""" + mock = MagicMock() + mock.model = "nvidia_nim/meta/llama3" + mock.provider_type = "nvidia_nim" + mock.nvidia_nim_api_key = "test_key" + mock.open_router_api_key = "test_openrouter_key" + mock.deepseek_api_key = "test_deepseek_key" + mock.wafer_api_key = "test_wafer_key" + mock.opencode_api_key = "test_opencode_key" + mock.zai_api_key = "test_zai_key" + mock.openai_api_key = "test_openai_key" + mock.kimi_api_key = "test_kimi_key" + mock.lm_studio_base_url = "http://localhost:1234/v1" + mock.llamacpp_base_url = "http://localhost:8080/v1" + mock.ollama_base_url = "http://localhost:11434" + mock.llamacpp_api_key = "" + mock.ollama_api_key = "" + mock.nvidia_nim_proxy = "" + mock.open_router_proxy = "" + mock.lmstudio_proxy = "" + mock.llamacpp_proxy = "" + mock.kimi_proxy = "" + mock.wafer_proxy = "" + mock.opencode_proxy = "" + mock.zai_proxy = "" + mock.openai_proxy = "" + mock.provider_rate_limit = 40 + mock.provider_rate_window = 60 + mock.provider_max_concurrency = 5 + mock.http_read_timeout = 300.0 + mock.http_write_timeout = 10.0 + mock.http_connect_timeout = 10.0 + mock.enable_model_thinking = True + mock.log_raw_sse_events = False + mock.log_api_error_tracebacks = False + mock.enable_api_key_passthrough = False + mock.anthropic_auth_token = "" + for key, value in overrides.items(): + setattr(mock, key, value) + return mock + + +class TestBuildProviderConfigPassthrough: + """Tests for build_provider_config with passthrough_api_key.""" + + def test_passthrough_disabled_uses_configured_credential(self): + """When passthrough is disabled, the configured provider key is used.""" + settings = _make_settings(enable_api_key_passthrough=False) + descriptor = PROVIDER_CATALOG["nvidia_nim"] + config = build_provider_config(descriptor, settings) + assert config.api_key == "test_key" + + def test_passthrough_enabled_uses_passthrough_key(self): + """When passthrough is enabled and key provided, it overrides credential.""" + settings = _make_settings(enable_api_key_passthrough=True) + descriptor = PROVIDER_CATALOG["nvidia_nim"] + config = build_provider_config( + descriptor, settings, passthrough_api_key="sk-passthrough-123" + ) + assert config.api_key == "sk-passthrough-123" + + def test_passthrough_enabled_empty_key_raises(self): + """When passthrough is enabled but key is empty, raises AuthenticationError.""" + from providers.exceptions import AuthenticationError + + settings = _make_settings(enable_api_key_passthrough=True) + descriptor = PROVIDER_CATALOG["nvidia_nim"] + with pytest.raises(AuthenticationError, match="ENABLE_API_KEY_PASSTHROUGH"): + build_provider_config(descriptor, settings, passthrough_api_key="") + + def test_passthrough_enabled_key_present_but_passthrough_disabled(self): + """Passthrough key is ignored when enable_api_key_passthrough is False.""" + settings = _make_settings(enable_api_key_passthrough=False) + descriptor = PROVIDER_CATALOG["nvidia_nim"] + config = build_provider_config( + descriptor, settings, passthrough_api_key="sk-ignored" + ) + assert config.api_key == "test_key" + + def test_passthrough_skips_credential_validation(self): + """With passthrough active, missing per-provider key does not raise.""" + settings = _make_settings( + enable_api_key_passthrough=True, nvidia_nim_api_key="" + ) + descriptor = PROVIDER_CATALOG["nvidia_nim"] + config = build_provider_config( + descriptor, settings, passthrough_api_key="sk-valid" + ) + assert config.api_key == "sk-valid" + + def test_passthrough_disabled_missing_key_raises(self): + """Without passthrough, missing required credential still raises.""" + from providers.exceptions import AuthenticationError + + settings = _make_settings( + enable_api_key_passthrough=False, nvidia_nim_api_key="" + ) + descriptor = PROVIDER_CATALOG["nvidia_nim"] + with pytest.raises(AuthenticationError, match="NVIDIA_NIM_API_KEY"): + build_provider_config(descriptor, settings) + + def test_passthrough_with_static_credential_provider(self): + """Passthrough overrides even static credentials (e.g. lmstudio).""" + settings = _make_settings(enable_api_key_passthrough=True) + descriptor = PROVIDER_CATALOG["lmstudio"] + config = build_provider_config( + descriptor, settings, passthrough_api_key="sk-lmstudio-passthrough" + ) + assert config.api_key == "sk-lmstudio-passthrough" + + +class TestCreateProviderPassthrough: + """Tests for create_provider with passthrough_api_key.""" + + def test_create_provider_with_passthrough(self): + """create_provider passes passthrough key through to provider config.""" + settings = _make_settings(enable_api_key_passthrough=True) + with patch("providers.openai_compat.AsyncOpenAI"): + provider = create_provider( + "nvidia_nim", settings, passthrough_api_key="sk-nim-passthrough" + ) + assert isinstance(provider, NvidiaNimProvider) + assert provider._api_key == "sk-nim-passthrough" + + def test_create_provider_without_passthrough(self): + """create_provider uses configured key when passthrough is empty.""" + settings = _make_settings(enable_api_key_passthrough=False) + with patch("providers.openai_compat.AsyncOpenAI"): + provider = create_provider("nvidia_nim", settings) + assert isinstance(provider, NvidiaNimProvider) + assert provider._api_key == "test_key" + + +class TestProviderRegistryPassthrough: + """Tests for ProviderRegistry.get with passthrough_api_key.""" + + def test_different_passthrough_keys_get_distinct_instances(self): + """Different passthrough keys produce separate cached provider instances.""" + registry = ProviderRegistry() + settings = _make_settings(enable_api_key_passthrough=True) + + with patch("providers.openai_compat.AsyncOpenAI"): + p1 = registry.get("nvidia_nim", settings, passthrough_api_key="sk-key-1") + p2 = registry.get("nvidia_nim", settings, passthrough_api_key="sk-key-2") + + assert p1 is not p2 + assert isinstance(p1, NvidiaNimProvider) + assert isinstance(p2, NvidiaNimProvider) + assert p1._api_key == "sk-key-1" + assert p2._api_key == "sk-key-2" + + def test_same_passthrough_key_returns_cached_instance(self): + """Same passthrough key returns the cached provider instance.""" + registry = ProviderRegistry() + settings = _make_settings(enable_api_key_passthrough=True) + + with patch("providers.openai_compat.AsyncOpenAI"): + p1 = registry.get("nvidia_nim", settings, passthrough_api_key="sk-same") + p2 = registry.get("nvidia_nim", settings, passthrough_api_key="sk-same") + + assert p1 is p2 + + def test_no_passthrough_caches_by_provider_id(self): + """Without passthrough, caching works by provider_id as before.""" + registry = ProviderRegistry() + settings = _make_settings(enable_api_key_passthrough=False) + + with patch("providers.openai_compat.AsyncOpenAI"): + p1 = registry.get("nvidia_nim", settings) + p2 = registry.get("nvidia_nim", settings) + + assert p1 is p2 + + def test_passthrough_empty_key_raises_in_registry(self): + """A passthrough-enabled registry raises when no client token given.""" + from providers.exceptions import AuthenticationError + + registry = ProviderRegistry() + settings = _make_settings(enable_api_key_passthrough=True) + + with ( + pytest.raises(AuthenticationError, match="ENABLE_API_KEY_PASSTHROUGH"), + patch("providers.openai_compat.AsyncOpenAI"), + ): + registry.get("nvidia_nim", settings) + + def test_different_passthrough_keys_are_distinct(self): + """Providers with different passthrough keys are separate instances.""" + registry = ProviderRegistry() + settings = _make_settings(enable_api_key_passthrough=True) + + with patch("providers.openai_compat.AsyncOpenAI"): + p_a = registry.get("nvidia_nim", settings, passthrough_api_key="sk-aaa") + p_b = registry.get("nvidia_nim", settings, passthrough_api_key="sk-bbb") + + assert p_a is not p_b + + +class TestRequireApiKeyReturnsToken: + """Tests that require_api_key returns the extracted client token.""" + + def test_returns_token_on_valid_x_api_key(self): + """Valid x-api-key header returns the extracted token.""" + request = MagicMock() + request.headers = {"x-api-key": "my-secret-key"} + settings = _make_settings(anthropic_auth_token="my-secret-key") + token = require_api_key(request, settings) + assert token == "my-secret-key" + + def test_returns_token_on_valid_bearer(self): + """Valid Bearer authorization returns the token (without Bearer prefix).""" + request = MagicMock() + request.headers = {"authorization": "Bearer my-bearer-key"} + settings = _make_settings(anthropic_auth_token="my-bearer-key") + token = require_api_key(request, settings) + assert token == "my-bearer-key" + + def test_returns_empty_string_when_no_auth_configured(self): + """When ANTHROPIC_AUTH_TOKEN is empty, returns empty string.""" + request = MagicMock() + request.headers = {} + settings = _make_settings(anthropic_auth_token="") + token = require_api_key(request, settings) + assert token == "" + + def test_raises_401_on_invalid_key(self): + """Invalid API key raises 401.""" + request = MagicMock() + request.headers = {"x-api-key": "wrong-key"} + settings = _make_settings(anthropic_auth_token="correct-key") + with pytest.raises(HTTPException) as exc_info: + require_api_key(request, settings) + assert exc_info.value.status_code == 401 + + def test_strips_colon_suffix(self): + """Token with colon suffix (model name append) gets stripped.""" + request = MagicMock() + request.headers = {"x-api-key": "my-key:claude-3-opus"} + settings = _make_settings(anthropic_auth_token="my-key") + token = require_api_key(request, settings) + assert token == "my-key" + + +class TestResolveProviderPassthrough: + """Tests for resolve_provider with passthrough_api_key.""" + + def test_resolve_provider_passes_passthrough_key(self): + """resolve_provider forwards passthrough_api_key to the registry.""" + settings = _make_settings(enable_api_key_passthrough=True) + app = SimpleNamespace(state=State()) + registry = ProviderRegistry() + app.state.provider_registry = registry + + with patch("providers.openai_compat.AsyncOpenAI"): + provider = resolve_provider( + "nvidia_nim", + app=cast(Starlette, app), + settings=settings, + passthrough_api_key="sk-resolve-test", + ) + + assert isinstance(provider, NvidiaNimProvider) + assert provider._api_key == "sk-resolve-test" diff --git a/tests/api/test_dependencies.py b/tests/api/test_dependencies.py index 0264b8f..06786cf 100644 --- a/tests/api/test_dependencies.py +++ b/tests/api/test_dependencies.py @@ -52,6 +52,8 @@ def _make_mock_settings(**overrides): mock.http_write_timeout = 10.0 mock.http_connect_timeout = 10.0 mock.enable_model_thinking = True + mock.enable_api_key_passthrough = False + mock.anthropic_auth_token = "" for key, value in overrides.items(): setattr(mock, key, value) return mock diff --git a/tests/providers/test_registry.py b/tests/providers/test_registry.py index 2d24b25..2f1ecf1 100644 --- a/tests/providers/test_registry.py +++ b/tests/providers/test_registry.py @@ -51,6 +51,8 @@ def _make_settings(**overrides): mock.http_write_timeout = 10.0 mock.http_connect_timeout = 10.0 mock.enable_model_thinking = True + mock.enable_api_key_passthrough = False + mock.anthropic_auth_token = "" mock.nim = NimSettings() for key, value in overrides.items(): setattr(mock, key, value)