mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-04-28 03:20:01 +00:00
Support using Ollama in the same way as LM Studio --------- Co-authored-by: Alishahryar1 <alishahryar2@gmail.com> Co-authored-by: u011436427 <u011436427@noreply.gitcode.com>
264 lines
9.3 KiB
Python
264 lines
9.3 KiB
Python
"""Provider descriptors, factory, and runtime registry."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Callable, MutableMapping
|
|
from dataclasses import dataclass
|
|
from typing import Literal
|
|
|
|
from config.provider_ids import SUPPORTED_PROVIDER_IDS
|
|
from config.settings import Settings
|
|
from providers.base import BaseProvider, ProviderConfig
|
|
from providers.defaults import (
|
|
DEEPSEEK_DEFAULT_BASE,
|
|
LLAMACPP_DEFAULT_BASE,
|
|
LMSTUDIO_DEFAULT_BASE,
|
|
NVIDIA_NIM_DEFAULT_BASE,
|
|
OLLAMA_DEFAULT_BASE,
|
|
OPENROUTER_DEFAULT_BASE,
|
|
)
|
|
from providers.exceptions import AuthenticationError, UnknownProviderTypeError
|
|
|
|
TransportType = Literal["openai_chat", "anthropic_messages"]
|
|
ProviderFactory = Callable[[ProviderConfig, Settings], BaseProvider]
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
class ProviderDescriptor:
    """Metadata for building :class:`ProviderConfig` and factory wiring.

    Instances are immutable (``frozen=True``) and slot-based (``slots=True``).
    The first three fields are required; every other field defaults to ``None``.
    """

    # Provider identifier; in this module it equals the key under which the
    # descriptor is stored in PROVIDER_DESCRIPTORS.
    provider_id: str
    # One of the ``TransportType`` literals ("openai_chat" / "anthropic_messages").
    transport_type: TransportType
    # Capability tags such as "chat", "streaming", "tools".
    capabilities: tuple[str, ...]
    # Name of the environment variable quoted in the missing-credential error.
    # When None, ``_require_credential`` performs no check for this provider.
    credential_env: str | None = None
    # Optional URL appended to the missing-credential error ("Get a key at ...").
    credential_url: str | None = None
    # If set, read API key from this attribute on ``Settings`` (e.g. nvidia_nim_api_key).
    credential_attr: str | None = None
    # If set, use this fixed key for local adapters (e.g. lm-studio, llamacpp).
    # ``_credential_for`` checks this before ``credential_attr``.
    static_credential: str | None = None
    # Base URL used when ``base_url_attr`` is None or resolves to an empty /
    # non-string value on ``Settings``.
    default_base_url: str | None = None
    # If set, name of the ``Settings`` attribute that overrides the base URL.
    base_url_attr: str | None = None
    # If set, name of the ``Settings`` attribute holding the proxy string;
    # when None the built config gets an empty proxy string.
    proxy_attr: str | None = None
|
|
|
|
|
|
# Static table of every known provider, keyed by provider id.
# ``create_provider`` looks descriptors up here, and the import-time check in
# this module requires the key set to equal SUPPORTED_PROVIDER_IDS.
PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = {
    # --- Providers whose API key is read from Settings (credential_env set,
    # --- so a missing/blank key raises AuthenticationError).
    "nvidia_nim": ProviderDescriptor(
        provider_id="nvidia_nim",
        transport_type="openai_chat",
        credential_env="NVIDIA_NIM_API_KEY",
        credential_url="https://build.nvidia.com/settings/api-keys",
        credential_attr="nvidia_nim_api_key",
        default_base_url=NVIDIA_NIM_DEFAULT_BASE,
        proxy_attr="nvidia_nim_proxy",
        capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"),
    ),
    "open_router": ProviderDescriptor(
        provider_id="open_router",
        transport_type="anthropic_messages",
        credential_env="OPENROUTER_API_KEY",
        credential_url="https://openrouter.ai/keys",
        credential_attr="open_router_api_key",
        default_base_url=OPENROUTER_DEFAULT_BASE,
        proxy_attr="open_router_proxy",
        capabilities=("chat", "streaming", "tools", "thinking", "native_anthropic"),
    ),
    # No proxy_attr here: the built config carries an empty proxy string.
    "deepseek": ProviderDescriptor(
        provider_id="deepseek",
        transport_type="openai_chat",
        credential_env="DEEPSEEK_API_KEY",
        credential_url="https://platform.deepseek.com/api_keys",
        credential_attr="deepseek_api_key",
        default_base_url=DEEPSEEK_DEFAULT_BASE,
        capabilities=("chat", "streaming", "thinking"),
    ),
    # --- Providers tagged "local": a fixed static_credential is used, no
    # --- credential_env is set, and the base URL can be overridden from Settings.
    # NOTE(review): base_url_attr is spelled "lm_studio_..." while proxy_attr is
    # spelled "lmstudio_..." — confirm both match the Settings field names; a
    # name that does not exist on Settings silently falls back to the default.
    "lmstudio": ProviderDescriptor(
        provider_id="lmstudio",
        transport_type="anthropic_messages",
        static_credential="lm-studio",
        default_base_url=LMSTUDIO_DEFAULT_BASE,
        base_url_attr="lm_studio_base_url",
        proxy_attr="lmstudio_proxy",
        capabilities=("chat", "streaming", "tools", "native_anthropic", "local"),
    ),
    "llamacpp": ProviderDescriptor(
        provider_id="llamacpp",
        transport_type="anthropic_messages",
        static_credential="llamacpp",
        default_base_url=LLAMACPP_DEFAULT_BASE,
        base_url_attr="llamacpp_base_url",
        proxy_attr="llamacpp_proxy",
        capabilities=("chat", "streaming", "tools", "native_anthropic", "local"),
    ),
    # Unlike the other two local providers, no proxy_attr is declared, so the
    # built config carries an empty proxy string.
    "ollama": ProviderDescriptor(
        provider_id="ollama",
        transport_type="anthropic_messages",
        static_credential="ollama",
        default_base_url=OLLAMA_DEFAULT_BASE,
        base_url_attr="ollama_base_url",
        capabilities=(
            "chat",
            "streaming",
            "tools",
            "thinking",
            "native_anthropic",
            "local",
        ),
    ),
}
|
|
|
|
|
|
def _create_nvidia_nim(config: ProviderConfig, settings: Settings) -> BaseProvider:
    """Instantiate the NVIDIA NIM provider.

    ``settings.nim`` is forwarded to the provider as ``nim_settings``.
    """
    # The provider module is imported inside the function, so the import runs
    # only when this factory is actually called.
    from providers.nvidia_nim import NvidiaNimProvider as provider_cls

    nim_settings = settings.nim
    return provider_cls(config, nim_settings=nim_settings)
|
|
|
|
|
|
def _create_open_router(config: ProviderConfig, _settings: Settings) -> BaseProvider:
    """Instantiate the OpenRouter provider from ``config``.

    ``_settings`` is not read; it is accepted so the function matches the
    ``ProviderFactory`` call signature.
    """
    # The provider module is imported inside the function, so the import runs
    # only when this factory is actually called.
    from providers.open_router import OpenRouterProvider as provider_cls

    provider = provider_cls(config)
    return provider
|
|
|
|
|
|
def _create_deepseek(config: ProviderConfig, settings: Settings) -> BaseProvider:
    """Instantiate the DeepSeek provider from ``config``.

    ``settings`` is not read; it is accepted so the function matches the
    ``ProviderFactory`` call signature.
    """
    # The provider module is imported inside the function, so the import runs
    # only when this factory is actually called.
    from providers.deepseek import DeepSeekProvider as provider_cls

    provider = provider_cls(config)
    return provider
|
|
|
|
|
|
def _create_lmstudio(config: ProviderConfig, settings: Settings) -> BaseProvider:
    """Instantiate the LM Studio provider from ``config``.

    ``settings`` is not read; it is accepted so the function matches the
    ``ProviderFactory`` call signature.
    """
    # The provider module is imported inside the function, so the import runs
    # only when this factory is actually called.
    from providers.lmstudio import LMStudioProvider as provider_cls

    provider = provider_cls(config)
    return provider
|
|
|
|
|
|
def _create_llamacpp(config: ProviderConfig, settings: Settings) -> BaseProvider:
    """Instantiate the llama.cpp provider from ``config``.

    ``settings`` is not read; it is accepted so the function matches the
    ``ProviderFactory`` call signature.
    """
    # The provider module is imported inside the function, so the import runs
    # only when this factory is actually called.
    from providers.llamacpp import LlamaCppProvider as provider_cls

    provider = provider_cls(config)
    return provider
|
|
|
|
|
|
def _create_ollama(config: ProviderConfig, settings: Settings) -> BaseProvider:
    """Instantiate the Ollama provider from ``config``.

    ``settings`` is not read; it is accepted so the function matches the
    ``ProviderFactory`` call signature.
    """
    # The provider module is imported inside the function, so the import runs
    # only when this factory is actually called.
    from providers.ollama import OllamaProvider as provider_cls

    provider = provider_cls(config)
    return provider
|
|
|
|
|
|
# Dispatch table from provider id to the factory that instantiates it.
# ``create_provider`` looks the factory up here after the ProviderConfig has
# been built; keys are expected to mirror PROVIDER_DESCRIPTORS.
PROVIDER_FACTORIES: dict[str, ProviderFactory] = {
    "nvidia_nim": _create_nvidia_nim,
    "open_router": _create_open_router,
    "deepseek": _create_deepseek,
    "lmstudio": _create_lmstudio,
    "llamacpp": _create_llamacpp,
    "ollama": _create_ollama,
}
|
|
|
|
# Import-time guard: the descriptor table, the factory table and the list of
# supported ids must name exactly the same providers. An explicit ``raise`` is
# used (rather than an ``assert`` statement) so the check also runs under -O.
if not (
    set(PROVIDER_DESCRIPTORS) == set(SUPPORTED_PROVIDER_IDS)
    and set(PROVIDER_FACTORIES) == set(SUPPORTED_PROVIDER_IDS)
):
    raise AssertionError(
        "PROVIDER_DESCRIPTORS, PROVIDER_FACTORIES, and SUPPORTED_PROVIDER_IDS are out of sync: "
        f"descriptors={set(PROVIDER_DESCRIPTORS)!r} factories={set(PROVIDER_FACTORIES)!r} "
        f"ids={set(SUPPORTED_PROVIDER_IDS)!r}"
    )
|
|
|
|
|
|
def _string_attr(settings: Settings, attr_name: str | None, default: str = "") -> str:
|
|
if attr_name is None:
|
|
return default
|
|
value = getattr(settings, attr_name, default)
|
|
return value if isinstance(value, str) else default
|
|
|
|
|
|
def _credential_for(descriptor: ProviderDescriptor, settings: Settings) -> str:
    """Resolve the API key for ``descriptor``.

    Precedence:
        1. ``descriptor.static_credential`` when it is not ``None`` (an empty
           string is still returned as-is).
        2. The ``Settings`` attribute named by ``descriptor.credential_attr``,
           read through ``_string_attr``.
        3. The empty string.
    """
    fixed_key = descriptor.static_credential
    if fixed_key is not None:
        return fixed_key

    key_attr = descriptor.credential_attr
    return _string_attr(settings, key_attr) if key_attr else ""
|
|
|
|
|
|
def _require_credential(descriptor: ProviderDescriptor, credential: str) -> None:
|
|
if descriptor.credential_env is None:
|
|
return
|
|
if credential and credential.strip():
|
|
return
|
|
message = f"{descriptor.credential_env} is not set. Add it to your .env file."
|
|
if descriptor.credential_url:
|
|
message = f"{message} Get a key at {descriptor.credential_url}"
|
|
raise AuthenticationError(message)
|
|
|
|
|
|
def build_provider_config(
    descriptor: ProviderDescriptor, settings: Settings
) -> ProviderConfig:
    """Assemble the :class:`ProviderConfig` for one provider.

    The credential is resolved and validated first. The base URL and proxy
    are then read from ``settings`` using the attribute names declared on
    ``descriptor``. Rate-limit, concurrency, timeout and thinking options are
    copied from ``settings`` unchanged.

    Raises:
        AuthenticationError: Propagated from ``_require_credential`` when the
            provider needs a credential and none is configured.
    """
    api_key = _credential_for(descriptor, settings)
    _require_credential(descriptor, api_key)

    fallback_url = descriptor.default_base_url
    resolved_url = _string_attr(settings, descriptor.base_url_attr, fallback_url or "")
    if not resolved_url:
        # An empty override falls back to the descriptor default (may be None).
        resolved_url = fallback_url

    proxy_url = _string_attr(settings, descriptor.proxy_attr)

    return ProviderConfig(
        api_key=api_key,
        base_url=resolved_url,
        rate_limit=settings.provider_rate_limit,
        rate_window=settings.provider_rate_window,
        max_concurrency=settings.provider_max_concurrency,
        http_read_timeout=settings.http_read_timeout,
        http_write_timeout=settings.http_write_timeout,
        http_connect_timeout=settings.http_connect_timeout,
        enable_thinking=settings.enable_model_thinking,
        proxy=proxy_url,
    )
|
|
|
|
|
|
def create_provider(provider_id: str, settings: Settings) -> BaseProvider:
    """Build a new provider instance for ``provider_id``.

    Steps, in order: look up the descriptor, build its ``ProviderConfig``
    from ``settings``, then call the registered factory.

    Raises:
        UnknownProviderTypeError: ``provider_id`` has no descriptor; the
            message lists the known ids.
        AuthenticationError: Propagated from ``build_provider_config`` when a
            required credential is missing.
        AssertionError: A descriptor exists but no factory is registered.
    """
    descriptor = PROVIDER_DESCRIPTORS.get(provider_id)
    if descriptor is None:
        known_ids = "', '".join(PROVIDER_DESCRIPTORS)
        message = f"Unknown provider_type: '{provider_id}'. Supported: '{known_ids}'"
        raise UnknownProviderTypeError(message)

    # The config is built before the factory lookup, so credential problems
    # surface first.
    provider_config = build_provider_config(descriptor, settings)

    build = PROVIDER_FACTORIES.get(provider_id)
    if build is not None:
        return build(provider_config, settings)
    raise AssertionError(f"Unhandled provider descriptor: {provider_id}")
|
|
|
|
|
|
class ProviderRegistry:
|
|
"""Cache and clean up provider instances by provider id."""
|
|
|
|
def __init__(self, providers: MutableMapping[str, BaseProvider] | None = None):
|
|
self._providers = providers if providers is not None else {}
|
|
|
|
def is_cached(self, provider_id: str) -> bool:
|
|
"""Return whether a provider for this id is already in the cache."""
|
|
return provider_id in self._providers
|
|
|
|
def get(self, provider_id: str, settings: Settings) -> BaseProvider:
|
|
if provider_id not in self._providers:
|
|
self._providers[provider_id] = create_provider(provider_id, settings)
|
|
return self._providers[provider_id]
|
|
|
|
async def cleanup(self) -> None:
|
|
"""Call ``cleanup`` on every cached provider, then clear the cache.
|
|
|
|
Attempts all providers even if one fails. A single failure is re-raised
|
|
as-is; multiple failures are wrapped in :exc:`ExceptionGroup`.
|
|
"""
|
|
items = list(self._providers.items())
|
|
errors: list[Exception] = []
|
|
try:
|
|
for _pid, provider in items:
|
|
try:
|
|
await provider.cleanup()
|
|
except Exception as e:
|
|
errors.append(e)
|
|
finally:
|
|
self._providers.clear()
|
|
if len(errors) == 1:
|
|
raise errors[0]
|
|
if len(errors) > 1:
|
|
msg = "One or more provider cleanups failed"
|
|
raise ExceptionGroup(msg, errors)
|