mirror of
https://github.com/lfnovo/open-notebook.git
synced 2026-04-28 11:30:00 +00:00
* feat: replace provider config with credential-based system (#477)

  Introduce a new credential management system replacing the old ProviderConfig
  singleton and standalone Models page. Each credential stores encrypted API
  keys and provider-specific configuration with full CRUD support via a
  unified settings UI.

  Backend:
  - Add Credential domain model with encrypted API key storage
  - Add credentials API router (CRUD, discovery, registration, testing)
  - Add encryption utilities for secure key storage
  - Add key_provider for DB-first env-var fallback provisioning
  - Add connection tester and model discovery services
  - Integrate ModelManager with credential-based config
  - Add provider name normalization for Esperanto compatibility
  - Add database migrations 11-12 for credential schema

  Frontend:
  - Rewrite settings/api-keys page with credential management UI
  - Add model discovery dialog with search and custom model support
  - Add compact default model assignments (primary/advanced layout)
  - Add inline model testing and credential connection testing
  - Add env-var migration banner
  - Update navigation to unified settings page
  - Remove standalone models page and old settings components

  i18n:
  - Update all 7 locale files with credential and model management keys

  Closes #477

  Co-Authored-By: JFMD <git@jfmd.us>
  Co-Authored-By: OraCatQAQ <570768706@qq.com>

* fix: address PR #540 review comments
  - Fix docs referencing removed Models page
  - Fix error-handler returning raw messages instead of i18n keys
  - Fix auth.py misleading docstring and missing no-password guard
  - Fix connection_tester using wrong env var for openai_compatible
  - Add provision_provider_keys before model discovery/sync
  - Update CLAUDE.md to reflect credential-based system
  - Fix missing closing brace in api-keys page useEffect

* fix: add logging to credential migration and surface errors in UI
  - Add comprehensive logging to migrate-from-env and
    migrate-from-provider-config endpoints (start, per-provider progress,
    success/failure with stack traces, final summary)
  - Fix frontend migration hooks ignoring errors array from response
  - Show error toast when migration fails instead of "nothing to migrate"
  - Invalidate status/envStatus queries after migration so banner updates

* docs: update CLAUDE.md files for credential system

  Replace stale ProviderConfig and /api-keys/ references across 8 CLAUDE.md
  files to reflect the new Credential-based system from PR #540.

* docs: update user documentation for credential-based system

  Replace env var API key instructions with Settings UI credential workflow
  across all user-facing documentation. The new flow is: set
  OPEN_NOTEBOOK_ENCRYPTION_KEY → start services → add credential in Settings
  UI → test → discover models → register.
  - Rewrite ai-providers.md, api-configuration.md, environment-reference.md
  - Update all quick-start guides and installation docs
  - Update ollama.md, openai-compatible.md, local-tts/stt networking sections
  - Update reverse-proxy.md, development-setup.md, security.md
  - Fix broken links to non-existent docs/deployment/ paths
  - Add credentials endpoints to api-reference.md
  - Move all API key env vars to deprecated/legacy sections

* chore: bump version to 1.7.0-rc1

  Release candidate for credential-based provider management system.

* fix: initialize provider before try block in test_credential

  Prevents UnboundLocalError when Credential.get() throws (e.g., invalid
  credential_id) before provider is assigned.
* fix: reorder down migration to drop index before table

  Removes duplicate REMOVE FIELD statement and reorders so the index is
  dropped before the table, preventing rollback failures.

* refactor: simplify encryption key to always derive via SHA-256

  Remove the dual code path in _ensure_fernet_key() that detected native
  Fernet keys. Since the credential system is new, always deriving via
  SHA-256 removes unnecessary complexity. Also removes the generate_key()
  function and Fernet.generate_key() references from docs.

* fix: correct mock patch targets in embedding tests and URL validation

  Fix embedding tests patching the wrong module path for model_manager (they
  targeted open_notebook.utils.embedding.model_manager, but it is imported
  locally from open_notebook.ai.models). Also fix URL validation to allow
  unresolvable hostnames, since they may be valid in the deployment
  environment (e.g., Azure endpoints, internal DNS).

* feat: add global setup banner for encryption and migration status

  Show a persistent banner in AppShell when the encryption key is missing
  (red) or env var API keys can be migrated (amber), so users see these
  prompts on every page instead of only on Settings > API Keys. Includes a
  docs link for the encryption banner and i18n support across all 7 locales.

* docs: several improvements to docker-compose and env examples

* Update README.md

  Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>

* docs: fix env var format in README and update model setup instructions

  Align the encryption key snippet in README Step 2 with the list format used
  in the compose file. Replace deprecated "Settings → Models" instructions
  with the credential-based Discover Models flow.

* fix: address credential system review issues
  - Fix SSRF bypass via IPv4-mapped IPv6 addresses (::ffff:169.254.x.x)
  - Fix TTS connection test missing config parameter
  - Add Azure-specific model discovery using api-key auth header
  - Add Vertex static model list for credential-based discovery
  - Fix PROVIDER_DISCOVERY_FUNCTIONS incorrect azure/vertex mapping
  - Extract business logic to api/credentials_service.py (service layer)
  - Move credential Pydantic schemas to api/models.py
  - Update tests to use new service imports and ValueError assertions

* fix: sanitize error responses and migrate key_provider to Credential
  - Replace raw exception messages in all credential router 500 responses
    with generic error strings (internal details logged server-side only)
  - Refactor key_provider.py to use Credential.get_by_provider() instead of
    deprecated ProviderConfig.get_instance()
  - Remove unused functions (get_provider_configs, get_default_api_key,
    get_provider_config) that were dead code

---------

Co-authored-by: JFMD <git@jfmd.us>
Co-authored-by: OraCatQAQ <570768706@qq.com>
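A minimal sketch of the SHA-256 key derivation described in the refactor commit above (illustrative only; apart from _ensure_fernet_key() and OPEN_NOTEBOOK_ENCRYPTION_KEY, the names below are assumptions rather than the project's actual code):

    import base64
    import hashlib
    import os

    from cryptography.fernet import Fernet

    def _ensure_fernet_key() -> bytes:
        # Hash the user-supplied secret to exactly 32 bytes, then
        # urlsafe-base64 encode it, which is the format Fernet requires.
        secret = os.environ["OPEN_NOTEBOOK_ENCRYPTION_KEY"]
        digest = hashlib.sha256(secret.encode("utf-8")).digest()
        return base64.urlsafe_b64encode(digest)

    fernet = Fernet(_ensure_fernet_key())
    token = fernet.encrypt(b"sk-example-api-key")  # ciphertext stored in the DB
    assert fernet.decrypt(token) == b"sk-example-api-key"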
756 lines
23 KiB
Python
"""
|
|
Model Discovery - Automatic model fetching from AI providers.
|
|
|
|
This module provides functionality to discover available models from configured
|
|
AI providers and automatically register them in the database.
|
|
"""
|
|
|
|
import asyncio
|
|
import os
|
|
from dataclasses import dataclass
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import httpx
|
|
from loguru import logger
|
|
|
|
from open_notebook.ai.models import Model
|
|
from open_notebook.domain.credential import Credential
|
|
from open_notebook.database.repository import repo_query
|
|
|
|
|
|
@dataclass
|
|
class DiscoveredModel:
|
|
"""Represents a model discovered from a provider."""
|
|
|
|
name: str
|
|
provider: str
|
|
model_type: str # language, embedding, speech_to_text, text_to_speech
|
|
description: Optional[str] = None
|
|
|
|
|
|


# =============================================================================
# Provider-Specific Model Type Classification
# =============================================================================
# These mappings help classify models by their capabilities based on naming patterns

OPENAI_MODEL_TYPES = {
    "language": [
        "gpt-4",
        "gpt-3.5",
        "o1",
        "o3",
        "chatgpt",
        "text-davinci",
        "davinci",
        "curie",
        "babbage",
        "ada",
    ],
    "embedding": ["text-embedding", "embedding"],
    "speech_to_text": ["whisper"],
    "text_to_speech": ["tts"],
}

ANTHROPIC_MODELS = {
    # Static list since Anthropic doesn't have a model listing API
    "language": [
        "claude-opus-4-20250514",
        "claude-sonnet-4-20250514",
        "claude-3-5-sonnet-20241022",
        "claude-3-5-haiku-20241022",
        "claude-3-opus-20240229",
        "claude-3-sonnet-20240229",
        "claude-3-haiku-20240307",
    ],
}

GOOGLE_MODEL_TYPES = {
    "language": ["gemini", "palm", "bison", "chat"],
    "embedding": ["embedding", "textembedding"],
}

OLLAMA_MODEL_TYPES = {
    # Ollama models can do multiple things, classify by common names
    "language": [
        "llama",
        "mistral",
        "mixtral",
        "codellama",
        "phi",
        "gemma",
        "qwen",
        "deepseek",
        "vicuna",
        "falcon",
        "orca",
        "neural",
        "dolphin",
        "openchat",
        "starling",
        "solar",
        "yi",
        "nous",
        "wizard",
        "zephyr",
        "tinyllama",
    ],
    "embedding": ["nomic-embed", "mxbai-embed", "all-minilm", "bge-", "e5-"],
}

MISTRAL_MODEL_TYPES = {
    "language": [
        "mistral",
        "mixtral",
        "codestral",
        "ministral",
        "pixtral",
        "open-mistral",
        "open-mixtral",
    ],
    "embedding": ["mistral-embed"],
}

GROQ_MODEL_TYPES = {
    "language": ["llama", "mixtral", "gemma", "whisper"],
    "speech_to_text": ["whisper"],
}

DEEPSEEK_MODEL_TYPES = {
    "language": ["deepseek-chat", "deepseek-reasoner", "deepseek-coder"],
}

XAI_MODEL_TYPES = {
    "language": ["grok"],
}

VOYAGE_MODEL_TYPES = {
    "embedding": ["voyage"],
}

ELEVENLABS_MODEL_TYPES = {
    "text_to_speech": ["eleven"],
}


def classify_model_type(model_name: str, provider: str) -> str:
    """
    Classify a model into a type based on its name and provider.

    Returns one of: language, embedding, speech_to_text, text_to_speech
    """
    name_lower = model_name.lower()

    type_mappings = {
        "openai": OPENAI_MODEL_TYPES,
        "google": GOOGLE_MODEL_TYPES,
        "ollama": OLLAMA_MODEL_TYPES,
        "mistral": MISTRAL_MODEL_TYPES,
        "groq": GROQ_MODEL_TYPES,
        "deepseek": DEEPSEEK_MODEL_TYPES,
        "xai": XAI_MODEL_TYPES,
        "voyage": VOYAGE_MODEL_TYPES,
        "elevenlabs": ELEVENLABS_MODEL_TYPES,
    }

    mapping = type_mappings.get(provider, {})

    # Check each type in order of specificity
    for model_type in ["speech_to_text", "text_to_speech", "embedding", "language"]:
        patterns = mapping.get(model_type, [])
        for pattern in patterns:
            if pattern in name_lower:
                return model_type

    # Default to language for unknown models
    return "language"
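
# A few illustrative classifications, given the mappings above:
#     classify_model_type("whisper-1", "openai")              -> "speech_to_text"
#     classify_model_type("text-embedding-3-small", "openai") -> "embedding"
#     classify_model_type("gpt-4o-mini", "openai")            -> "language"
#     classify_model_type("totally-unknown-model", "ollama")  -> "language" (default)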


# =============================================================================
# Provider-Specific Model Discovery Functions
# =============================================================================


async def discover_openai_models() -> List[DiscoveredModel]:
    """Fetch available models from OpenAI API."""
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        return []

    models = []
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://api.openai.com/v1/models",
                headers={"Authorization": f"Bearer {api_key}"},
                timeout=30.0,
            )
            response.raise_for_status()
            data = response.json()

            for model in data.get("data", []):
                model_id = model.get("id", "")
                if model_id:
                    model_type = classify_model_type(model_id, "openai")
                    models.append(
                        DiscoveredModel(
                            name=model_id,
                            provider="openai",
                            model_type=model_type,
                        )
                    )
    except Exception as e:
        logger.warning(f"Failed to discover OpenAI models: {e}")

    return models
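
# The /v1/models response is JSON shaped roughly like (abridged, illustrative):
#     {"object": "list", "data": [{"id": "gpt-4o", "object": "model", ...}]}
# Only the "id" of each entry is used above; other fields are ignored.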


async def discover_anthropic_models() -> List[DiscoveredModel]:
    """Return static list of Anthropic models (no discovery API available)."""
    api_key = os.environ.get("ANTHROPIC_API_KEY")
    if not api_key:
        return []

    # Anthropic doesn't have a model listing API, so we use a static list
    models = []
    for model_name in ANTHROPIC_MODELS.get("language", []):
        models.append(
            DiscoveredModel(
                name=model_name,
                provider="anthropic",
                model_type="language",
            )
        )
    return models


async def discover_google_models() -> List[DiscoveredModel]:
    """Fetch available models from Google Gemini API."""
    api_key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY")
    if not api_key:
        return []

    models = []
    try:
        async with httpx.AsyncClient() as client:
            # Send the key via header rather than the URL to avoid exposure in logs
            url = "https://generativelanguage.googleapis.com/v1/models"
            headers = {"X-Goog-Api-Key": api_key}
            response = await client.get(url, headers=headers, timeout=30.0)
            response.raise_for_status()
            data = response.json()

            for model in data.get("models", []):
                # Google returns full path like "models/gemini-1.5-flash"
                model_name = model.get("name", "").replace("models/", "")
                if model_name:
                    model_type = classify_model_type(model_name, "google")
                    # Check supported generation methods for better classification
                    methods = model.get("supportedGenerationMethods", [])
                    if "embedContent" in methods:
                        model_type = "embedding"
                    elif "generateContent" in methods:
                        model_type = "language"

                    models.append(
                        DiscoveredModel(
                            name=model_name,
                            provider="google",
                            model_type=model_type,
                            description=model.get("displayName"),
                        )
                    )
    except Exception as e:
        # Log without exposing the API key in the message
        logger.warning(f"Failed to discover Google models: {type(e).__name__}")

    return models


async def discover_ollama_models() -> List[DiscoveredModel]:
    """Fetch available models from local Ollama instance."""
    base_url = os.environ.get("OLLAMA_API_BASE", "http://localhost:11434")
    if not base_url:
        return []

    models = []
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{base_url}/api/tags",
                timeout=10.0,
            )
            response.raise_for_status()
            data = response.json()

            for model in data.get("models", []):
                model_name = model.get("name", "")
                if model_name:
                    model_type = classify_model_type(model_name, "ollama")
                    models.append(
                        DiscoveredModel(
                            name=model_name,
                            provider="ollama",
                            model_type=model_type,
                        )
                    )
    except Exception as e:
        logger.warning(f"Failed to discover Ollama models: {e}")

    return models
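
# Ollama's /api/tags responds with JSON shaped roughly like (abridged, illustrative):
#     {"models": [{"name": "llama3:8b", "modified_at": ..., "size": ...}, ...]}
# Each "name" (model:tag) is classified by classify_model_type above.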


async def discover_groq_models() -> List[DiscoveredModel]:
    """Fetch available models from Groq API."""
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        return []

    models = []
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://api.groq.com/openai/v1/models",
                headers={"Authorization": f"Bearer {api_key}"},
                timeout=30.0,
            )
            response.raise_for_status()
            data = response.json()

            for model in data.get("data", []):
                model_id = model.get("id", "")
                if model_id:
                    model_type = classify_model_type(model_id, "groq")
                    models.append(
                        DiscoveredModel(
                            name=model_id,
                            provider="groq",
                            model_type=model_type,
                        )
                    )
    except Exception as e:
        logger.warning(f"Failed to discover Groq models: {e}")

    return models


async def discover_mistral_models() -> List[DiscoveredModel]:
    """Fetch available models from Mistral API."""
    api_key = os.environ.get("MISTRAL_API_KEY")
    if not api_key:
        return []

    models = []
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://api.mistral.ai/v1/models",
                headers={"Authorization": f"Bearer {api_key}"},
                timeout=30.0,
            )
            response.raise_for_status()
            data = response.json()

            for model in data.get("data", []):
                model_id = model.get("id", "")
                if model_id:
                    model_type = classify_model_type(model_id, "mistral")
                    # Check capabilities if available
                    capabilities = model.get("capabilities", {})
                    if capabilities.get("completion_chat"):
                        model_type = "language"

                    models.append(
                        DiscoveredModel(
                            name=model_id,
                            provider="mistral",
                            model_type=model_type,
                        )
                    )
    except Exception as e:
        logger.warning(f"Failed to discover Mistral models: {e}")

    return models


async def discover_deepseek_models() -> List[DiscoveredModel]:
    """Fetch available models from DeepSeek API."""
    api_key = os.environ.get("DEEPSEEK_API_KEY")
    if not api_key:
        return []

    models = []
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://api.deepseek.com/models",
                headers={"Authorization": f"Bearer {api_key}"},
                timeout=30.0,
            )
            response.raise_for_status()
            data = response.json()

            for model in data.get("data", []):
                model_id = model.get("id", "")
                if model_id:
                    model_type = classify_model_type(model_id, "deepseek")
                    models.append(
                        DiscoveredModel(
                            name=model_id,
                            provider="deepseek",
                            model_type=model_type,
                        )
                    )
    except Exception as e:
        logger.warning(f"Failed to discover DeepSeek models: {e}")

    return models


async def discover_xai_models() -> List[DiscoveredModel]:
    """Fetch available models from xAI API."""
    api_key = os.environ.get("XAI_API_KEY")
    if not api_key:
        return []

    models = []
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://api.x.ai/v1/models",
                headers={"Authorization": f"Bearer {api_key}"},
                timeout=30.0,
            )
            response.raise_for_status()
            data = response.json()

            for model in data.get("data", []):
                model_id = model.get("id", "")
                if model_id:
                    model_type = classify_model_type(model_id, "xai")
                    models.append(
                        DiscoveredModel(
                            name=model_id,
                            provider="xai",
                            model_type=model_type,
                        )
                    )
    except Exception as e:
        logger.warning(f"Failed to discover xAI models: {e}")

    return models


async def discover_openrouter_models() -> List[DiscoveredModel]:
    """Fetch available models from OpenRouter API."""
    api_key = os.environ.get("OPENROUTER_API_KEY")
    if not api_key:
        return []

    models = []
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://openrouter.ai/api/v1/models",
                headers={"Authorization": f"Bearer {api_key}"},
                timeout=30.0,
            )
            response.raise_for_status()
            data = response.json()

            for model in data.get("data", []):
                model_id = model.get("id", "")
                if model_id:
                    # OpenRouter models are typically language models
                    models.append(
                        DiscoveredModel(
                            name=model_id,
                            provider="openrouter",
                            model_type="language",
                            description=model.get("name"),
                        )
                    )
    except Exception as e:
        logger.warning(f"Failed to discover OpenRouter models: {e}")

    return models


async def discover_voyage_models() -> List[DiscoveredModel]:
    """Return static list of Voyage AI models (embedding only)."""
    api_key = os.environ.get("VOYAGE_API_KEY")
    if not api_key:
        return []

    # Voyage AI specializes in embeddings
    voyage_models = [
        "voyage-3",
        "voyage-3-lite",
        "voyage-code-3",
        "voyage-finance-2",
        "voyage-law-2",
        "voyage-multilingual-2",
    ]

    return [
        DiscoveredModel(name=m, provider="voyage", model_type="embedding")
        for m in voyage_models
    ]


async def discover_elevenlabs_models() -> List[DiscoveredModel]:
    """Return static list of ElevenLabs TTS models."""
    api_key = os.environ.get("ELEVENLABS_API_KEY")
    if not api_key:
        return []

    # ElevenLabs specializes in TTS
    elevenlabs_models = [
        "eleven_multilingual_v2",
        "eleven_turbo_v2_5",
        "eleven_turbo_v2",
        "eleven_monolingual_v1",
        "eleven_multilingual_v1",
    ]

    return [
        DiscoveredModel(name=m, provider="elevenlabs", model_type="text_to_speech")
        for m in elevenlabs_models
    ]


async def discover_openai_compatible_models() -> List[DiscoveredModel]:
    """
    Fetch available models from an OpenAI-compatible API endpoint.

    Uses the configured base_url from the database or environment variable.
    """
    api_key = None
    base_url = None

    # Try to get config from Credential database first
    try:
        credentials = await Credential.get_by_provider("openai_compatible")
        if credentials:
            cred = credentials[0]
            config = cred.to_esperanto_config()
            api_key = config.get("api_key")
            base_url = config.get("base_url", "").rstrip("/")
    except Exception as e:
        logger.warning(f"Failed to read openai_compatible config from Credential: {e}")

    # Fall back to environment variables
    if not api_key:
        api_key = os.environ.get("OPENAI_COMPATIBLE_API_KEY")
    if not base_url:
        base_url = os.environ.get("OPENAI_COMPATIBLE_BASE_URL", "").rstrip("/")

    if not base_url:
        logger.warning("No base_url configured for openai_compatible provider")
        return []

    models = []
    try:
        async with httpx.AsyncClient() as client:
            headers = {}
            if api_key:
                headers["Authorization"] = f"Bearer {api_key}"

            response = await client.get(
                f"{base_url}/models",
                headers=headers,
                timeout=30.0,
            )
            response.raise_for_status()
            data = response.json()

            for model in data.get("data", []):
                model_id = model.get("id", "")
                if model_id:
                    # Classify based on model name patterns
                    model_type = classify_model_type(model_id, "openai")
                    models.append(
                        DiscoveredModel(
                            name=model_id,
                            provider="openai_compatible",
                            model_type=model_type,
                        )
                    )
    except httpx.HTTPStatusError as e:
        logger.warning(
            f"Failed to discover openai_compatible models: HTTP {e.response.status_code}"
        )
    except Exception as e:
        logger.warning(f"Failed to discover openai_compatible models: {e}")

    return models


# =============================================================================
# Main Discovery Functions
# =============================================================================

# Map provider names to their discovery functions
PROVIDER_DISCOVERY_FUNCTIONS = {
    "openai": discover_openai_models,
    "anthropic": discover_anthropic_models,
    "google": discover_google_models,
    "ollama": discover_ollama_models,
    "groq": discover_groq_models,
    "mistral": discover_mistral_models,
    "deepseek": discover_deepseek_models,
    "xai": discover_xai_models,
    "openrouter": discover_openrouter_models,
    "voyage": discover_voyage_models,
    "elevenlabs": discover_elevenlabs_models,
    "openai_compatible": discover_openai_compatible_models,
    "azure": None,  # Azure requires credential-based discovery (different auth)
    "vertex": None,  # Vertex requires credential-based discovery (service account)
}


async def discover_provider_models(provider: str) -> List[DiscoveredModel]:
    """
    Discover available models for a specific provider.

    Args:
        provider: Provider name (openai, anthropic, etc.)

    Returns:
        List of discovered models
    """
    discover_func = PROVIDER_DISCOVERY_FUNCTIONS.get(provider)
    if discover_func is None:
        if provider in PROVIDER_DISCOVERY_FUNCTIONS:
            logger.info(
                f"Provider '{provider}' requires credential-based discovery. "
                f"Use the /credentials/{{id}}/discover endpoint instead."
            )
        else:
            logger.warning(f"No discovery function for provider: {provider}")
        return []

    return await discover_func()
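
# Example (illustrative):
#     models = await discover_provider_models("openai")
#     # -> [DiscoveredModel(name="gpt-4o", provider="openai", model_type="language"), ...]
#     await discover_provider_models("azure")
#     # -> [] and logs a hint to use credential-based discovery instead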


async def sync_provider_models(
    provider: str, auto_register: bool = True
) -> Tuple[int, int, int]:
    """
    Sync models for a provider: discover and optionally register in database.

    Args:
        provider: Provider name
        auto_register: If True, automatically create Model records in database

    Returns:
        Tuple of (discovered_count, new_count, existing_count)
    """
    discovered = await discover_provider_models(provider)
    discovered_count = len(discovered)
    new_count = 0
    existing_count = 0

    if not auto_register:
        return discovered_count, 0, 0

    if not discovered:
        return 0, 0, 0

    # Batch fetch existing models to avoid N+1 query pattern
    try:
        existing_models = await repo_query(
            "SELECT string::lowercase(name) as name, string::lowercase(type) as type FROM model "
            "WHERE string::lowercase(provider) = $provider",
            {"provider": provider.lower()},
        )
        # Create a set of (name, type) tuples for O(1) lookup
        existing_keys = set()
        for m in existing_models:
            existing_keys.add((m.get("name", ""), m.get("type", "")))
    except Exception as e:
        logger.warning(f"Failed to fetch existing models for {provider}: {e}")
        existing_keys = set()

    for model in discovered:
        model_key = (model.name.lower(), model.model_type.lower())

        # Check if model already exists using pre-fetched data
        if model_key in existing_keys:
            existing_count += 1
            continue

        # Create new model
        try:
            new_model = Model(
                name=model.name,
                provider=model.provider,
                type=model.model_type,
            )
            await new_model.save()
            new_count += 1
            logger.info(
                f"Registered new model: {model.provider}/{model.name} ({model.model_type})"
            )
        except Exception as e:
            logger.warning(f"Failed to register model {model.name}: {e}")

    logger.info(
        f"Synced {provider}: {discovered_count} discovered, "
        f"{new_count} new, {existing_count} existing"
    )
    return discovered_count, new_count, existing_count
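
# Example (illustrative):
#     discovered, new, existing = await sync_provider_models("ollama")
#     # e.g. (5, 2, 3): Ollama reported 5 models, 2 were newly registered,
#     # and 3 were already present in the database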


async def sync_all_providers() -> Dict[str, Tuple[int, int, int]]:
    """
    Sync models for all configured providers.

    Returns:
        Dict mapping provider names to (discovered, new, existing) tuples
    """
    results = {}

    # Run discovery for all providers in parallel
    tasks = []
    providers = list(PROVIDER_DISCOVERY_FUNCTIONS.keys())

    for provider in providers:
        tasks.append(sync_provider_models(provider, auto_register=True))

    task_results = await asyncio.gather(*tasks, return_exceptions=True)

    for provider, result in zip(providers, task_results):
        if isinstance(result, Exception):
            logger.error(f"Error syncing {provider}: {result}")
            results[provider] = (0, 0, 0)
        else:
            results[provider] = result

    return results
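
# Example (illustrative): run a full sync from a script or a startup hook.
#     results = asyncio.run(sync_all_providers())
#     for provider, (found, new, existing) in results.items():
#         print(f"{provider}: {found} discovered, {new} new, {existing} existing")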


async def get_provider_model_count(provider: str) -> Dict[str, int]:
    """
    Get count of registered models for a provider, grouped by type.

    Args:
        provider: Provider name (case-insensitive)

    Returns:
        Dict mapping model type to count
    """
    # Use case-insensitive comparison by lowercasing the provider
    result = await repo_query(
        "SELECT type, count() as count FROM model "
        "WHERE string::lowercase(provider) = string::lowercase($provider) "
        "GROUP BY type",
        {"provider": provider},
    )

    counts = {
        "language": 0,
        "embedding": 0,
        "speech_to_text": 0,
        "text_to_speech": 0,
    }

    for row in result:
        model_type = row.get("type")
        count = row.get("count", 0)
        if model_type in counts:
            counts[model_type] = count

    return counts
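
# Example return value (illustrative):
#     {"language": 12, "embedding": 3, "speech_to_text": 1, "text_to_speech": 0}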