mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-04-28 03:20:01 +00:00
Per claude model mapping (#66)
This commit is contained in:
parent
763c8b62b7
commit
0b324e0421
15 changed files with 454 additions and 81 deletions
|
|
@ -1,7 +1,7 @@
|
|||
"""API layer for Claude Code Proxy."""
|
||||
|
||||
from .app import app, create_app
|
||||
from .dependencies import get_provider
|
||||
from .dependencies import get_provider, get_provider_for_type
|
||||
from .models import (
|
||||
MessagesRequest,
|
||||
MessagesResponse,
|
||||
|
|
@ -17,4 +17,5 @@ __all__ = [
|
|||
"app",
|
||||
"create_app",
|
||||
"get_provider",
|
||||
"get_provider_for_type",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -12,8 +12,8 @@ from providers.lmstudio import LMStudioProvider
|
|||
from providers.nvidia_nim import NVIDIA_NIM_BASE_URL, NvidiaNimProvider
|
||||
from providers.open_router import OPENROUTER_BASE_URL, OpenRouterProvider
|
||||
|
||||
# Global provider instance (singleton)
|
||||
_provider: BaseProvider | None = None
|
||||
# Provider registry: keyed by provider type string, lazily populated
|
||||
_providers: dict[str, BaseProvider] = {}
|
||||
|
||||
|
||||
def get_settings() -> Settings:
|
||||
|
|
@ -21,9 +21,9 @@ def get_settings() -> Settings:
|
|||
return _get_settings()
|
||||
|
||||
|
||||
def _create_provider(settings: Settings) -> BaseProvider:
|
||||
"""Construct and return a new provider instance from settings."""
|
||||
if settings.provider_type == "nvidia_nim":
|
||||
def _create_provider_for_type(provider_type: str, settings: Settings) -> BaseProvider:
|
||||
"""Construct and return a new provider instance for the given provider type."""
|
||||
if provider_type == "nvidia_nim":
|
||||
if not settings.nvidia_nim_api_key or not settings.nvidia_nim_api_key.strip():
|
||||
raise AuthenticationError(
|
||||
"NVIDIA_NIM_API_KEY is not set. Add it to your .env file. "
|
||||
|
|
@ -39,8 +39,8 @@ def _create_provider(settings: Settings) -> BaseProvider:
|
|||
http_write_timeout=settings.http_write_timeout,
|
||||
http_connect_timeout=settings.http_connect_timeout,
|
||||
)
|
||||
provider = NvidiaNimProvider(config, nim_settings=settings.nim)
|
||||
elif settings.provider_type == "open_router":
|
||||
return NvidiaNimProvider(config, nim_settings=settings.nim)
|
||||
if provider_type == "open_router":
|
||||
if not settings.open_router_api_key or not settings.open_router_api_key.strip():
|
||||
raise AuthenticationError(
|
||||
"OPENROUTER_API_KEY is not set. Add it to your .env file. "
|
||||
|
|
@ -56,8 +56,8 @@ def _create_provider(settings: Settings) -> BaseProvider:
|
|||
http_write_timeout=settings.http_write_timeout,
|
||||
http_connect_timeout=settings.http_connect_timeout,
|
||||
)
|
||||
provider = OpenRouterProvider(config)
|
||||
elif settings.provider_type == "lmstudio":
|
||||
return OpenRouterProvider(config)
|
||||
if provider_type == "lmstudio":
|
||||
config = ProviderConfig(
|
||||
api_key="lm-studio",
|
||||
base_url=settings.lm_studio_base_url,
|
||||
|
|
@ -68,37 +68,47 @@ def _create_provider(settings: Settings) -> BaseProvider:
|
|||
http_write_timeout=settings.http_write_timeout,
|
||||
http_connect_timeout=settings.http_connect_timeout,
|
||||
)
|
||||
provider = LMStudioProvider(config)
|
||||
else:
|
||||
logger.error(
|
||||
"Unknown provider_type: '{}'. Supported: 'nvidia_nim', 'open_router', 'lmstudio'",
|
||||
settings.provider_type,
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unknown provider_type: '{settings.provider_type}'. "
|
||||
f"Supported: 'nvidia_nim', 'open_router', 'lmstudio'"
|
||||
)
|
||||
logger.info("Provider initialized: {}", settings.provider_type)
|
||||
return provider
|
||||
return LMStudioProvider(config)
|
||||
logger.error(
|
||||
"Unknown provider_type: '{}'. Supported: 'nvidia_nim', 'open_router', 'lmstudio'",
|
||||
provider_type,
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unknown provider_type: '{provider_type}'. "
|
||||
f"Supported: 'nvidia_nim', 'open_router', 'lmstudio'"
|
||||
)
|
||||
|
||||
|
||||
def get_provider() -> BaseProvider:
|
||||
"""Get or create the provider instance based on settings.provider_type."""
|
||||
global _provider
|
||||
if _provider is None:
|
||||
def get_provider_for_type(provider_type: str) -> BaseProvider:
|
||||
"""Get or create a provider for the given provider type.
|
||||
|
||||
Providers are cached in the registry and reused across requests.
|
||||
"""
|
||||
if provider_type not in _providers:
|
||||
try:
|
||||
_provider = _create_provider(get_settings())
|
||||
_providers[provider_type] = _create_provider_for_type(
|
||||
provider_type, get_settings()
|
||||
)
|
||||
except AuthenticationError as e:
|
||||
raise HTTPException(
|
||||
status_code=503, detail=get_user_facing_error_message(e)
|
||||
) from e
|
||||
return _provider
|
||||
logger.info("Provider initialized: {}", provider_type)
|
||||
return _providers[provider_type]
|
||||
|
||||
|
||||
def get_provider() -> BaseProvider:
|
||||
"""Get or create the default provider (based on MODEL env var).
|
||||
|
||||
Backward-compatible convenience for health/root endpoints and tests.
|
||||
"""
|
||||
return get_provider_for_type(get_settings().provider_type)
|
||||
|
||||
|
||||
async def cleanup_provider():
|
||||
"""Cleanup provider resources."""
|
||||
global _provider
|
||||
if _provider:
|
||||
await _provider.cleanup()
|
||||
_provider = None
|
||||
"""Cleanup all provider resources."""
|
||||
global _providers
|
||||
for provider in _providers.values():
|
||||
await provider.cleanup()
|
||||
_providers = {}
|
||||
logger.debug("Provider cleanup completed")
|
||||
|
|
|
|||
|
|
@ -6,13 +6,12 @@ from typing import Any, Literal
|
|||
from loguru import logger
|
||||
from pydantic import BaseModel, field_validator, model_validator
|
||||
|
||||
from config.settings import get_settings
|
||||
from config.settings import Settings, get_settings
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Content Block Types
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class Role(StrEnum):
|
||||
user = "user"
|
||||
assistant = "assistant"
|
||||
|
|
@ -55,8 +54,6 @@ class SystemContent(BaseModel):
|
|||
# =============================================================================
|
||||
# Message Types
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
role: Literal["user", "assistant"]
|
||||
content: (
|
||||
|
|
@ -85,8 +82,6 @@ class ThinkingConfig(BaseModel):
|
|||
# =============================================================================
|
||||
# Request Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class MessagesRequest(BaseModel):
|
||||
model: str
|
||||
max_tokens: int | None = None
|
||||
|
|
@ -103,15 +98,18 @@ class MessagesRequest(BaseModel):
|
|||
thinking: ThinkingConfig | None = None
|
||||
extra_body: dict[str, Any] | None = None
|
||||
original_model: str | None = None
|
||||
resolved_provider_model: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def map_model(self) -> MessagesRequest:
|
||||
"""Map any Claude model name to the configured model."""
|
||||
"""Map any Claude model name to the configured model (tier-aware)."""
|
||||
settings = get_settings()
|
||||
if self.original_model is None:
|
||||
self.original_model = self.model
|
||||
|
||||
self.model = settings.model_name
|
||||
resolved_full = settings.resolve_model(self.original_model)
|
||||
self.resolved_provider_model = resolved_full
|
||||
self.model = Settings.parse_model_name(resolved_full)
|
||||
|
||||
if self.model != self.original_model:
|
||||
logger.debug(f"MODEL MAPPING: '{self.original_model}' -> '{self.model}'")
|
||||
|
|
@ -129,7 +127,8 @@ class TokenCountRequest(BaseModel):
|
|||
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model_field(cls, v, info):
|
||||
"""Map any Claude model name to the configured model."""
|
||||
def validate_model_field(cls, v: str, info) -> str:
|
||||
"""Map any Claude model name to the configured model (tier-aware)."""
|
||||
settings = get_settings()
|
||||
return settings.model_name
|
||||
resolved_full = settings.resolve_model(v)
|
||||
return Settings.parse_model_name(resolved_full)
|
||||
|
|
|
|||
|
|
@ -9,12 +9,11 @@ from fastapi.responses import StreamingResponse
|
|||
from loguru import logger
|
||||
|
||||
from config.settings import Settings
|
||||
from providers.base import BaseProvider
|
||||
from providers.common import get_user_facing_error_message
|
||||
from providers.exceptions import InvalidRequestError, ProviderError
|
||||
from providers.logging_utils import build_request_summary, log_request_compact
|
||||
|
||||
from .dependencies import get_provider, get_settings
|
||||
from .dependencies import get_provider_for_type, get_settings
|
||||
from .models.anthropic import MessagesRequest, TokenCountRequest
|
||||
from .models.responses import TokenCountResponse
|
||||
from .optimization_handlers import try_optimizations
|
||||
|
|
@ -26,13 +25,10 @@ router = APIRouter()
|
|||
# =============================================================================
|
||||
# Routes
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/v1/messages")
|
||||
async def create_message(
|
||||
request_data: MessagesRequest,
|
||||
raw_request: Request,
|
||||
provider: BaseProvider = Depends(get_provider),
|
||||
settings: Settings = Depends(get_settings),
|
||||
):
|
||||
"""Create a message (always streaming)."""
|
||||
|
|
@ -45,6 +41,12 @@ async def create_message(
|
|||
if optimized is not None:
|
||||
return optimized
|
||||
|
||||
# Resolve provider from the tier-aware model mapping
|
||||
provider_type = Settings.parse_provider_type(
|
||||
request_data.resolved_provider_model or settings.model
|
||||
)
|
||||
provider = get_provider_for_type(provider_type)
|
||||
|
||||
request_id = f"req_{uuid.uuid4().hex[:12]}"
|
||||
log_request_compact(logger, request_id, request_data)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue