free-claude-code/providers/registry.py
Wang Ji b525217633
[feat] ollama method support (#129)
Support using Ollama in the same way as LM Studio

---------

Co-authored-by: Alishahryar1 <alishahryar2@gmail.com>
Co-authored-by: u011436427 <u011436427@noreply.gitcode.com>
2026-04-25 22:06:36 -07:00

264 lines
9.3 KiB
Python

"""Provider descriptors, factory, and runtime registry."""
from __future__ import annotations
from collections.abc import Callable, MutableMapping
from dataclasses import dataclass
from typing import Literal
from config.provider_ids import SUPPORTED_PROVIDER_IDS
from config.settings import Settings
from providers.base import BaseProvider, ProviderConfig
from providers.defaults import (
DEEPSEEK_DEFAULT_BASE,
LLAMACPP_DEFAULT_BASE,
LMSTUDIO_DEFAULT_BASE,
NVIDIA_NIM_DEFAULT_BASE,
OLLAMA_DEFAULT_BASE,
OPENROUTER_DEFAULT_BASE,
)
from providers.exceptions import AuthenticationError, UnknownProviderTypeError
# Values a descriptor may declare in ``ProviderDescriptor.transport_type``.
TransportType = Literal["openai_chat", "anthropic_messages"]
# Call signature shared by the ``_create_*`` factories in this module:
# ``(ProviderConfig, Settings) -> BaseProvider``.
ProviderFactory = Callable[[ProviderConfig, Settings], BaseProvider]
@dataclass(frozen=True, slots=True)
class ProviderDescriptor:
    """Metadata for building :class:`ProviderConfig` and factory wiring.

    Instances are immutable (``frozen=True``). The credential and URL fields
    are consumed by ``_credential_for``, ``_require_credential`` and
    ``build_provider_config`` in this module.
    """

    # Identifier of the provider; entries in PROVIDER_DESCRIPTORS use it as their key.
    provider_id: str
    # One of the ``TransportType`` literals.
    transport_type: TransportType
    # Capability tags for the provider; not interpreted anywhere in this module.
    capabilities: tuple[str, ...]
    # Name of the environment variable quoted in the "is not set" error message.
    # When None, ``_require_credential`` does not demand a credential at all.
    credential_env: str | None = None
    # Optional URL appended to the missing-credential error message.
    credential_url: str | None = None
    # If set, read API key from this attribute on ``Settings`` (e.g. nvidia_nim_api_key).
    credential_attr: str | None = None
    # If set, use this fixed key for local adapters (e.g. lm-studio, llamacpp).
    # Takes precedence over ``credential_attr``.
    static_credential: str | None = None
    # Base URL used when no non-empty override is found on ``Settings``.
    default_base_url: str | None = None
    # If set, name of the ``Settings`` attribute holding a base URL override.
    base_url_attr: str | None = None
    # If set, name of the ``Settings`` attribute holding the proxy string.
    proxy_attr: str | None = None
# Descriptor table keyed by provider id. The key set is compared against
# SUPPORTED_PROVIDER_IDS and PROVIDER_FACTORIES at import time further down
# in this module, so adding a provider means updating all three.
PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = {
    # Entries with ``credential_env`` set require a non-blank API key read
    # from ``Settings`` (enforced by ``_require_credential``).
    "nvidia_nim": ProviderDescriptor(
        provider_id="nvidia_nim",
        transport_type="openai_chat",
        credential_env="NVIDIA_NIM_API_KEY",
        credential_url="https://build.nvidia.com/settings/api-keys",
        credential_attr="nvidia_nim_api_key",
        default_base_url=NVIDIA_NIM_DEFAULT_BASE,
        proxy_attr="nvidia_nim_proxy",
        capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"),
    ),
    "open_router": ProviderDescriptor(
        provider_id="open_router",
        transport_type="anthropic_messages",
        credential_env="OPENROUTER_API_KEY",
        credential_url="https://openrouter.ai/keys",
        credential_attr="open_router_api_key",
        default_base_url=OPENROUTER_DEFAULT_BASE,
        proxy_attr="open_router_proxy",
        capabilities=("chat", "streaming", "tools", "thinking", "native_anthropic"),
    ),
    # No ``proxy_attr``: the proxy passed to ProviderConfig is always "" here.
    "deepseek": ProviderDescriptor(
        provider_id="deepseek",
        transport_type="openai_chat",
        credential_env="DEEPSEEK_API_KEY",
        credential_url="https://platform.deepseek.com/api_keys",
        credential_attr="deepseek_api_key",
        default_base_url=DEEPSEEK_DEFAULT_BASE,
        capabilities=("chat", "streaming", "thinking"),
    ),
    # Entries below use ``static_credential`` (a fixed key) and expose a
    # ``base_url_attr`` so the base URL can be overridden from ``Settings``.
    "lmstudio": ProviderDescriptor(
        provider_id="lmstudio",
        transport_type="anthropic_messages",
        static_credential="lm-studio",
        default_base_url=LMSTUDIO_DEFAULT_BASE,
        base_url_attr="lm_studio_base_url",
        proxy_attr="lmstudio_proxy",
        capabilities=("chat", "streaming", "tools", "native_anthropic", "local"),
    ),
    "llamacpp": ProviderDescriptor(
        provider_id="llamacpp",
        transport_type="anthropic_messages",
        static_credential="llamacpp",
        default_base_url=LLAMACPP_DEFAULT_BASE,
        base_url_attr="llamacpp_base_url",
        proxy_attr="llamacpp_proxy",
        capabilities=("chat", "streaming", "tools", "native_anthropic", "local"),
    ),
    # No ``proxy_attr``: the proxy passed to ProviderConfig is always "" here.
    "ollama": ProviderDescriptor(
        provider_id="ollama",
        transport_type="anthropic_messages",
        static_credential="ollama",
        default_base_url=OLLAMA_DEFAULT_BASE,
        base_url_attr="ollama_base_url",
        capabilities=(
            "chat",
            "streaming",
            "tools",
            "thinking",
            "native_anthropic",
            "local",
        ),
    ),
}
def _create_nvidia_nim(config: ProviderConfig, settings: Settings) -> BaseProvider:
    """Instantiate the ``nvidia_nim`` provider.

    Args:
        config: Provider configuration handed to the provider constructor.
        settings: Application settings; ``settings.nim`` is forwarded as
            ``nim_settings``.

    Returns:
        A new ``NvidiaNimProvider``.
    """
    # Function-scope import: providers.nvidia_nim is loaded only when this
    # factory is actually called.
    from providers.nvidia_nim import NvidiaNimProvider

    nim_settings = settings.nim
    return NvidiaNimProvider(config, nim_settings=nim_settings)
def _create_open_router(config: ProviderConfig, _settings: Settings) -> BaseProvider:
    """Instantiate the ``open_router`` provider.

    Args:
        config: Provider configuration handed to the provider constructor.
        _settings: Accepted to match ``ProviderFactory``; not used.

    Returns:
        A new ``OpenRouterProvider``.
    """
    # Function-scope import: providers.open_router is loaded only when this
    # factory is actually called.
    from providers.open_router import OpenRouterProvider

    provider = OpenRouterProvider(config)
    return provider
def _create_deepseek(config: ProviderConfig, settings: Settings) -> BaseProvider:
    """Instantiate the ``deepseek`` provider.

    Args:
        config: Provider configuration handed to the provider constructor.
        settings: Accepted to match ``ProviderFactory``; not used.

    Returns:
        A new ``DeepSeekProvider``.
    """
    # Function-scope import: providers.deepseek is loaded only when this
    # factory is actually called.
    from providers.deepseek import DeepSeekProvider

    provider = DeepSeekProvider(config)
    return provider
def _create_lmstudio(config: ProviderConfig, settings: Settings) -> BaseProvider:
    """Instantiate the ``lmstudio`` provider.

    Args:
        config: Provider configuration handed to the provider constructor.
        settings: Accepted to match ``ProviderFactory``; not used.

    Returns:
        A new ``LMStudioProvider``.
    """
    # Function-scope import: providers.lmstudio is loaded only when this
    # factory is actually called.
    from providers.lmstudio import LMStudioProvider

    provider = LMStudioProvider(config)
    return provider
def _create_llamacpp(config: ProviderConfig, settings: Settings) -> BaseProvider:
    """Instantiate the ``llamacpp`` provider.

    Args:
        config: Provider configuration handed to the provider constructor.
        settings: Accepted to match ``ProviderFactory``; not used.

    Returns:
        A new ``LlamaCppProvider``.
    """
    # Function-scope import: providers.llamacpp is loaded only when this
    # factory is actually called.
    from providers.llamacpp import LlamaCppProvider

    provider = LlamaCppProvider(config)
    return provider
def _create_ollama(config: ProviderConfig, settings: Settings) -> BaseProvider:
    """Instantiate the ``ollama`` provider.

    Args:
        config: Provider configuration handed to the provider constructor.
        settings: Accepted to match ``ProviderFactory``; not used.

    Returns:
        A new ``OllamaProvider``.
    """
    # Function-scope import: providers.ollama is loaded only when this
    # factory is actually called.
    from providers.ollama import OllamaProvider

    provider = OllamaProvider(config)
    return provider
# Maps each provider id to the function that instantiates its provider.
# ``create_provider`` looks factories up here; the key set is compared with
# PROVIDER_DESCRIPTORS and SUPPORTED_PROVIDER_IDS right after this definition.
PROVIDER_FACTORIES: dict[str, ProviderFactory] = {
    "nvidia_nim": _create_nvidia_nim,
    "open_router": _create_open_router,
    "deepseek": _create_deepseek,
    "lmstudio": _create_lmstudio,
    "llamacpp": _create_llamacpp,
    "ollama": _create_ollama,
}
# Import-time consistency guard: the descriptor table, the factory table and
# the supported-id list must name exactly the same providers. Raised
# explicitly (not via ``assert``) so the check also runs under ``python -O``.
if not (
    set(PROVIDER_DESCRIPTORS)
    == set(SUPPORTED_PROVIDER_IDS)
    == set(PROVIDER_FACTORIES)
):
    raise AssertionError(
        "PROVIDER_DESCRIPTORS, PROVIDER_FACTORIES, and SUPPORTED_PROVIDER_IDS are out of sync: "
        f"descriptors={set(PROVIDER_DESCRIPTORS)!r} factories={set(PROVIDER_FACTORIES)!r} "
        f"ids={set(SUPPORTED_PROVIDER_IDS)!r}"
    )
def _string_attr(settings: Settings, attr_name: str | None, default: str = "") -> str:
if attr_name is None:
return default
value = getattr(settings, attr_name, default)
return value if isinstance(value, str) else default
def _credential_for(descriptor: ProviderDescriptor, settings: Settings) -> str:
if descriptor.static_credential is not None:
return descriptor.static_credential
if descriptor.credential_attr:
return _string_attr(settings, descriptor.credential_attr)
return ""
def _require_credential(descriptor: ProviderDescriptor, credential: str) -> None:
if descriptor.credential_env is None:
return
if credential and credential.strip():
return
message = f"{descriptor.credential_env} is not set. Add it to your .env file."
if descriptor.credential_url:
message = f"{message} Get a key at {descriptor.credential_url}"
raise AuthenticationError(message)
def build_provider_config(
    descriptor: ProviderDescriptor, settings: Settings
) -> ProviderConfig:
    """Assemble the :class:`ProviderConfig` for one provider.

    Args:
        descriptor: Static metadata describing the provider.
        settings: Application settings supplying the credential, URL, proxy,
            rate-limit, concurrency, timeout and thinking values.

    Returns:
        A ``ProviderConfig`` populated from ``descriptor`` and ``settings``.

    Raises:
        AuthenticationError: If the provider requires a credential and none
            is configured.
    """
    api_key = _credential_for(descriptor, settings)
    _require_credential(descriptor, api_key)

    fallback_url = descriptor.default_base_url
    configured_url = _string_attr(
        settings, descriptor.base_url_attr, fallback_url or ""
    )
    # An empty override falls back to the descriptor default (which may be None).
    resolved_url = configured_url if configured_url else fallback_url
    proxy_url = _string_attr(settings, descriptor.proxy_attr)

    return ProviderConfig(
        api_key=api_key,
        base_url=resolved_url,
        rate_limit=settings.provider_rate_limit,
        rate_window=settings.provider_rate_window,
        max_concurrency=settings.provider_max_concurrency,
        http_read_timeout=settings.http_read_timeout,
        http_write_timeout=settings.http_write_timeout,
        http_connect_timeout=settings.http_connect_timeout,
        enable_thinking=settings.enable_model_thinking,
        proxy=proxy_url,
    )
def create_provider(provider_id: str, settings: Settings) -> BaseProvider:
    """Build a new provider instance for ``provider_id``.

    Args:
        provider_id: Key into ``PROVIDER_DESCRIPTORS`` / ``PROVIDER_FACTORIES``.
        settings: Application settings used to build the provider config and
            passed on to the factory.

    Returns:
        The provider produced by the matching factory.

    Raises:
        UnknownProviderTypeError: If ``provider_id`` has no descriptor.
        AuthenticationError: If a required credential is missing (raised
            while building the config).
        AssertionError: If a descriptor exists but no factory is registered.
    """
    descriptor = PROVIDER_DESCRIPTORS.get(provider_id)
    if descriptor is None:
        known_ids = "', '".join(PROVIDER_DESCRIPTORS)
        raise UnknownProviderTypeError(
            f"Unknown provider_type: '{provider_id}'. Supported: '{known_ids}'"
        )

    # The config is built before the factory lookup, so credential errors
    # surface ahead of any wiring error.
    provider_config = build_provider_config(descriptor, settings)

    build = PROVIDER_FACTORIES.get(provider_id)
    if build is not None:
        return build(provider_config, settings)
    raise AssertionError(f"Unhandled provider descriptor: {provider_id}")
class ProviderRegistry:
"""Cache and clean up provider instances by provider id."""
def __init__(self, providers: MutableMapping[str, BaseProvider] | None = None):
self._providers = providers if providers is not None else {}
def is_cached(self, provider_id: str) -> bool:
"""Return whether a provider for this id is already in the cache."""
return provider_id in self._providers
def get(self, provider_id: str, settings: Settings) -> BaseProvider:
if provider_id not in self._providers:
self._providers[provider_id] = create_provider(provider_id, settings)
return self._providers[provider_id]
async def cleanup(self) -> None:
"""Call ``cleanup`` on every cached provider, then clear the cache.
Attempts all providers even if one fails. A single failure is re-raised
as-is; multiple failures are wrapped in :exc:`ExceptionGroup`.
"""
items = list(self._providers.items())
errors: list[Exception] = []
try:
for _pid, provider in items:
try:
await provider.cleanup()
except Exception as e:
errors.append(e)
finally:
self._providers.clear()
if len(errors) == 1:
raise errors[0]
if len(errors) > 1:
msg = "One or more provider cleanups failed"
raise ExceptionGroup(msg, errors)