Per-Claude model mapping (#66)

This commit is contained in:
Ali Khokhar 2026-03-01 21:32:23 -08:00 committed by GitHub
parent 763c8b62b7
commit 0b324e0421
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 454 additions and 81 deletions

View file

@@ -1,7 +1,7 @@
"""API layer for Claude Code Proxy."""
from .app import app, create_app
from .dependencies import get_provider
from .dependencies import get_provider, get_provider_for_type
from .models import (
MessagesRequest,
MessagesResponse,
@@ -17,4 +17,5 @@ __all__ = [
"app",
"create_app",
"get_provider",
"get_provider_for_type",
]

View file

@@ -12,8 +12,8 @@ from providers.lmstudio import LMStudioProvider
from providers.nvidia_nim import NVIDIA_NIM_BASE_URL, NvidiaNimProvider
from providers.open_router import OPENROUTER_BASE_URL, OpenRouterProvider
# Global provider instance (singleton)
_provider: BaseProvider | None = None
# Provider registry: keyed by provider type string, lazily populated
_providers: dict[str, BaseProvider] = {}
def get_settings() -> Settings:
@@ -21,9 +21,9 @@ def get_settings() -> Settings:
return _get_settings()
def _create_provider(settings: Settings) -> BaseProvider:
"""Construct and return a new provider instance from settings."""
if settings.provider_type == "nvidia_nim":
def _create_provider_for_type(provider_type: str, settings: Settings) -> BaseProvider:
"""Construct and return a new provider instance for the given provider type."""
if provider_type == "nvidia_nim":
if not settings.nvidia_nim_api_key or not settings.nvidia_nim_api_key.strip():
raise AuthenticationError(
"NVIDIA_NIM_API_KEY is not set. Add it to your .env file. "
@@ -39,8 +39,8 @@ def _create_provider(settings: Settings) -> BaseProvider:
http_write_timeout=settings.http_write_timeout,
http_connect_timeout=settings.http_connect_timeout,
)
provider = NvidiaNimProvider(config, nim_settings=settings.nim)
elif settings.provider_type == "open_router":
return NvidiaNimProvider(config, nim_settings=settings.nim)
if provider_type == "open_router":
if not settings.open_router_api_key or not settings.open_router_api_key.strip():
raise AuthenticationError(
"OPENROUTER_API_KEY is not set. Add it to your .env file. "
@@ -56,8 +56,8 @@ def _create_provider(settings: Settings) -> BaseProvider:
http_write_timeout=settings.http_write_timeout,
http_connect_timeout=settings.http_connect_timeout,
)
provider = OpenRouterProvider(config)
elif settings.provider_type == "lmstudio":
return OpenRouterProvider(config)
if provider_type == "lmstudio":
config = ProviderConfig(
api_key="lm-studio",
base_url=settings.lm_studio_base_url,
@@ -68,37 +68,47 @@ def _create_provider(settings: Settings) -> BaseProvider:
http_write_timeout=settings.http_write_timeout,
http_connect_timeout=settings.http_connect_timeout,
)
provider = LMStudioProvider(config)
else:
logger.error(
"Unknown provider_type: '{}'. Supported: 'nvidia_nim', 'open_router', 'lmstudio'",
settings.provider_type,
)
raise ValueError(
f"Unknown provider_type: '{settings.provider_type}'. "
f"Supported: 'nvidia_nim', 'open_router', 'lmstudio'"
)
logger.info("Provider initialized: {}", settings.provider_type)
return provider
return LMStudioProvider(config)
logger.error(
"Unknown provider_type: '{}'. Supported: 'nvidia_nim', 'open_router', 'lmstudio'",
provider_type,
)
raise ValueError(
f"Unknown provider_type: '{provider_type}'. "
f"Supported: 'nvidia_nim', 'open_router', 'lmstudio'"
)
def get_provider() -> BaseProvider:
"""Get or create the provider instance based on settings.provider_type."""
global _provider
if _provider is None:
def get_provider_for_type(provider_type: str) -> BaseProvider:
"""Get or create a provider for the given provider type.
Providers are cached in the registry and reused across requests.
"""
if provider_type not in _providers:
try:
_provider = _create_provider(get_settings())
_providers[provider_type] = _create_provider_for_type(
provider_type, get_settings()
)
except AuthenticationError as e:
raise HTTPException(
status_code=503, detail=get_user_facing_error_message(e)
) from e
return _provider
logger.info("Provider initialized: {}", provider_type)
return _providers[provider_type]
def get_provider() -> BaseProvider:
"""Get or create the default provider (based on MODEL env var).
Backward-compatible convenience for health/root endpoints and tests.
"""
return get_provider_for_type(get_settings().provider_type)
async def cleanup_provider():
"""Cleanup provider resources."""
global _provider
if _provider:
await _provider.cleanup()
_provider = None
"""Cleanup all provider resources."""
global _providers
for provider in _providers.values():
await provider.cleanup()
_providers = {}
logger.debug("Provider cleanup completed")

View file

@@ -6,13 +6,12 @@ from typing import Any, Literal
from loguru import logger
from pydantic import BaseModel, field_validator, model_validator
from config.settings import get_settings
from config.settings import Settings, get_settings
# =============================================================================
# Content Block Types
# =============================================================================
class Role(StrEnum):
user = "user"
assistant = "assistant"
@@ -55,8 +54,6 @@ class SystemContent(BaseModel):
# =============================================================================
# Message Types
# =============================================================================
class Message(BaseModel):
role: Literal["user", "assistant"]
content: (
@@ -85,8 +82,6 @@ class ThinkingConfig(BaseModel):
# =============================================================================
# Request Models
# =============================================================================
class MessagesRequest(BaseModel):
model: str
max_tokens: int | None = None
@@ -103,15 +98,18 @@ class MessagesRequest(BaseModel):
thinking: ThinkingConfig | None = None
extra_body: dict[str, Any] | None = None
original_model: str | None = None
resolved_provider_model: str | None = None
@model_validator(mode="after")
def map_model(self) -> MessagesRequest:
"""Map any Claude model name to the configured model."""
"""Map any Claude model name to the configured model (tier-aware)."""
settings = get_settings()
if self.original_model is None:
self.original_model = self.model
self.model = settings.model_name
resolved_full = settings.resolve_model(self.original_model)
self.resolved_provider_model = resolved_full
self.model = Settings.parse_model_name(resolved_full)
if self.model != self.original_model:
logger.debug(f"MODEL MAPPING: '{self.original_model}' -> '{self.model}'")
@@ -129,7 +127,8 @@ class TokenCountRequest(BaseModel):
@field_validator("model")
@classmethod
def validate_model_field(cls, v, info):
"""Map any Claude model name to the configured model."""
def validate_model_field(cls, v: str, info) -> str:
"""Map any Claude model name to the configured model (tier-aware)."""
settings = get_settings()
return settings.model_name
resolved_full = settings.resolve_model(v)
return Settings.parse_model_name(resolved_full)

View file

@@ -9,12 +9,11 @@ from fastapi.responses import StreamingResponse
from loguru import logger
from config.settings import Settings
from providers.base import BaseProvider
from providers.common import get_user_facing_error_message
from providers.exceptions import InvalidRequestError, ProviderError
from providers.logging_utils import build_request_summary, log_request_compact
from .dependencies import get_provider, get_settings
from .dependencies import get_provider_for_type, get_settings
from .models.anthropic import MessagesRequest, TokenCountRequest
from .models.responses import TokenCountResponse
from .optimization_handlers import try_optimizations
@@ -26,13 +25,10 @@ router = APIRouter()
# =============================================================================
# Routes
# =============================================================================
@router.post("/v1/messages")
async def create_message(
request_data: MessagesRequest,
raw_request: Request,
provider: BaseProvider = Depends(get_provider),
settings: Settings = Depends(get_settings),
):
"""Create a message (always streaming)."""
@@ -45,6 +41,12 @@ async def create_message(
if optimized is not None:
return optimized
# Resolve provider from the tier-aware model mapping
provider_type = Settings.parse_provider_type(
request_data.resolved_provider_model or settings.model
)
provider = get_provider_for_type(provider_type)
request_id = f"req_{uuid.uuid4().hex[:12]}"
log_request_compact(logger, request_id, request_data)