open-notebook/open_notebook/ai/model_discovery.py
Luis Novo adc03e56bb feat: add DashScope (Qwen) and MiniMax provider support
- Bump esperanto dependency to >=2.20.0 for new provider profiles
- Register both providers in credentials, key provider, connection tester, model discovery, and models router
- Add frontend provider entries (display names, modalities, docs links)
- Add documentation sections for both providers in ai-providers.md, environment-reference.md, and provider comparison
2026-04-06 10:54:37 -03:00

836 lines
26 KiB
Python

"""
Model Discovery - Automatic model fetching from AI providers.
This module provides functionality to discover available models from configured
AI providers and automatically register them in the database.
"""
import asyncio
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import httpx
from loguru import logger
from open_notebook.ai.models import Model
from open_notebook.domain.credential import Credential
from open_notebook.database.repository import repo_query
@dataclass
class DiscoveredModel:
    """Represents a model discovered from a provider."""

    # Model identifier exactly as reported by the provider (e.g. "gpt-4o").
    name: str
    # Provider key used throughout this module (e.g. "openai", "ollama").
    provider: str
    model_type: str  # language, embedding, speech_to_text, text_to_speech
    # Optional provider-supplied display name (e.g. Google's displayName,
    # OpenRouter's "name" field); None when the provider gives no label.
    description: Optional[str] = None
# =============================================================================
# Provider-Specific Model Type Classification
# =============================================================================
# These mappings help classify models by their capabilities based on naming patterns
# Each table maps a model type to substring patterns; classify_model_type()
# matches them against the lowercased model id, checking the most specific
# types (speech_to_text, text_to_speech) before embedding and language.
OPENAI_MODEL_TYPES = {
    "language": [
        "gpt-4",
        "gpt-3.5",
        "o1",
        "o3",
        "chatgpt",
        "text-davinci",
        "davinci",
        "curie",
        "babbage",
        "ada",
    ],
    "embedding": ["text-embedding", "embedding"],
    "speech_to_text": ["whisper"],
    "text_to_speech": ["tts"],
}

ANTHROPIC_MODELS = {
    # Static list since Anthropic doesn't have a model listing API
    "language": [
        "claude-opus-4-20250514",
        "claude-sonnet-4-20250514",
        "claude-3-5-sonnet-20241022",
        "claude-3-5-haiku-20241022",
        "claude-3-opus-20240229",
        "claude-3-sonnet-20240229",
        "claude-3-haiku-20240307",
    ],
}

GOOGLE_MODEL_TYPES = {
    # Name heuristics only; discover_google_models() refines these using
    # each model's supportedGenerationMethods.
    "language": ["gemini", "palm", "bison", "chat"],
    "embedding": ["embedding", "textembedding"],
}

OLLAMA_MODEL_TYPES = {
    # Ollama models can do multiple things, classify by common names
    "language": [
        "llama",
        "mistral",
        "mixtral",
        "codellama",
        "phi",
        "gemma",
        "qwen",
        "deepseek",
        "vicuna",
        "falcon",
        "orca",
        "neural",
        "dolphin",
        "openchat",
        "starling",
        "solar",
        "yi",
        "nous",
        "wizard",
        "zephyr",
        "tinyllama",
    ],
    "embedding": ["nomic-embed", "mxbai-embed", "all-minilm", "bge-", "e5-"],
}

MISTRAL_MODEL_TYPES = {
    "language": [
        "mistral",
        "mixtral",
        "codestral",
        "ministral",
        "pixtral",
        "open-mistral",
        "open-mixtral",
    ],
    # "mistral-embed" must win over the broader "mistral" language pattern;
    # classify_model_type() checks embedding before language, so it does.
    "embedding": ["mistral-embed"],
}

GROQ_MODEL_TYPES = {
    # NOTE: "whisper" appears in both lists. classify_model_type() checks
    # speech_to_text before language, so whisper models are classified as
    # speech_to_text and the "language" entry is effectively unreachable.
    "language": ["llama", "mixtral", "gemma", "whisper"],
    "speech_to_text": ["whisper"],
}

DEEPSEEK_MODEL_TYPES = {
    "language": ["deepseek-chat", "deepseek-reasoner", "deepseek-coder"],
}

XAI_MODEL_TYPES = {
    "language": ["grok"],
}

VOYAGE_MODEL_TYPES = {
    "embedding": ["voyage"],
}

ELEVENLABS_MODEL_TYPES = {
    "text_to_speech": ["eleven"],
}

DASHSCOPE_MODEL_TYPES = {
    "language": ["qwen"],
}

MINIMAX_MODEL_TYPES = {
    "language": ["minimax", "abab"],
}
def classify_model_type(model_name: str, provider: str) -> str:
    """
    Classify a model into a type based on its name and provider.
    Returns one of: language, embedding, speech_to_text, text_to_speech
    """
    lowered = model_name.lower()
    provider_tables = {
        "openai": OPENAI_MODEL_TYPES,
        "google": GOOGLE_MODEL_TYPES,
        "ollama": OLLAMA_MODEL_TYPES,
        "mistral": MISTRAL_MODEL_TYPES,
        "groq": GROQ_MODEL_TYPES,
        "deepseek": DEEPSEEK_MODEL_TYPES,
        "xai": XAI_MODEL_TYPES,
        "voyage": VOYAGE_MODEL_TYPES,
        "elevenlabs": ELEVENLABS_MODEL_TYPES,
        "dashscope": DASHSCOPE_MODEL_TYPES,
        "minimax": MINIMAX_MODEL_TYPES,
    }
    patterns_by_type = provider_tables.get(provider, {})
    # Most specific categories first, so e.g. a whisper model resolves to
    # speech_to_text even when a broader language pattern also matches.
    for candidate in ("speech_to_text", "text_to_speech", "embedding", "language"):
        if any(p in lowered for p in patterns_by_type.get(candidate, ())):
            return candidate
    # Unknown names (and unknown providers) default to language.
    return "language"
# =============================================================================
# Provider-Specific Model Discovery Functions
# =============================================================================
async def discover_openai_models() -> List[DiscoveredModel]:
    """Fetch available models from OpenAI API."""
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        return []
    discovered: List[DiscoveredModel] = []
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://api.openai.com/v1/models",
                headers={"Authorization": f"Bearer {api_key}"},
                timeout=30.0,
            )
            response.raise_for_status()
            payload = response.json()
        for entry in payload.get("data", []):
            model_id = entry.get("id", "")
            if not model_id:
                continue
            discovered.append(
                DiscoveredModel(
                    name=model_id,
                    provider="openai",
                    model_type=classify_model_type(model_id, "openai"),
                )
            )
    except Exception as e:
        logger.warning(f"Failed to discover OpenAI models: {e}")
    return discovered
async def discover_anthropic_models() -> List[DiscoveredModel]:
    """Return static list of Anthropic models (no discovery API available)."""
    if not os.environ.get("ANTHROPIC_API_KEY"):
        return []
    # Anthropic offers no model-listing endpoint, so fall back to the
    # curated ANTHROPIC_MODELS catalog defined at module level.
    return [
        DiscoveredModel(
            name=model_name,
            provider="anthropic",
            model_type="language",
        )
        for model_name in ANTHROPIC_MODELS.get("language", [])
    ]
async def discover_google_models() -> List[DiscoveredModel]:
    """Fetch available models from Google Gemini API."""
    api_key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY")
    if not api_key:
        return []
    discovered: List[DiscoveredModel] = []
    try:
        async with httpx.AsyncClient() as client:
            # The key travels in a header rather than the URL so it cannot
            # leak through logged URLs.
            response = await client.get(
                "https://generativelanguage.googleapis.com/v1/models",
                headers={"X-Goog-Api-Key": api_key},
                timeout=30.0,
            )
            response.raise_for_status()
            payload = response.json()
        for entry in payload.get("models", []):
            # Google returns full resource paths like "models/gemini-1.5-flash".
            model_name = entry.get("name", "").replace("models/", "")
            if not model_name:
                continue
            model_type = classify_model_type(model_name, "google")
            # The declared generation methods are more reliable than name
            # heuristics, so let them override the pattern-based guess.
            methods = entry.get("supportedGenerationMethods", [])
            if "embedContent" in methods:
                model_type = "embedding"
            elif "generateContent" in methods:
                model_type = "language"
            discovered.append(
                DiscoveredModel(
                    name=model_name,
                    provider="google",
                    model_type=model_type,
                    description=entry.get("displayName"),
                )
            )
    except Exception as e:
        # Log only the exception class name so the API key can never leak.
        logger.warning(f"Failed to discover Google models: {type(e).__name__}")
    return discovered
async def discover_ollama_models() -> List[DiscoveredModel]:
    """Fetch available models from local Ollama instance.

    Reads the endpoint from OLLAMA_API_BASE (default http://localhost:11434).
    Setting the variable to an empty string disables discovery.

    Returns:
        List of discovered models; empty when disabled or on any error.
    """
    base_url = os.environ.get("OLLAMA_API_BASE", "http://localhost:11434")
    if not base_url:
        return []
    # Normalize so a configured trailing slash ("http://host:11434/") does
    # not produce a double slash in the request URL; mirrors the rstrip
    # done for the openai_compatible base_url.
    base_url = base_url.rstrip("/")
    models = []
    try:
        async with httpx.AsyncClient() as client:
            # Shorter timeout than remote providers: this is a local call.
            response = await client.get(
                f"{base_url}/api/tags",
                timeout=10.0,
            )
            response.raise_for_status()
            data = response.json()
            for model in data.get("models", []):
                model_name = model.get("name", "")
                if model_name:
                    model_type = classify_model_type(model_name, "ollama")
                    models.append(
                        DiscoveredModel(
                            name=model_name,
                            provider="ollama",
                            model_type=model_type,
                        )
                    )
    except Exception as e:
        logger.warning(f"Failed to discover Ollama models: {e}")
    return models
async def discover_groq_models() -> List[DiscoveredModel]:
    """Fetch available models from Groq API."""
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        return []
    found: List[DiscoveredModel] = []
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://api.groq.com/openai/v1/models",
                headers={"Authorization": f"Bearer {api_key}"},
                timeout=30.0,
            )
            response.raise_for_status()
            payload = response.json()
        for item in payload.get("data", []):
            model_id = item.get("id", "")
            if not model_id:
                continue
            found.append(
                DiscoveredModel(
                    name=model_id,
                    provider="groq",
                    model_type=classify_model_type(model_id, "groq"),
                )
            )
    except Exception as e:
        logger.warning(f"Failed to discover Groq models: {e}")
    return found
async def discover_mistral_models() -> List[DiscoveredModel]:
    """Fetch available models from Mistral API."""
    api_key = os.environ.get("MISTRAL_API_KEY")
    if not api_key:
        return []
    discovered: List[DiscoveredModel] = []
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://api.mistral.ai/v1/models",
                headers={"Authorization": f"Bearer {api_key}"},
                timeout=30.0,
            )
            response.raise_for_status()
            payload = response.json()
        for entry in payload.get("data", []):
            model_id = entry.get("id", "")
            if not model_id:
                continue
            model_type = classify_model_type(model_id, "mistral")
            # The API advertises capabilities; anything chat-capable is
            # treated as a language model regardless of name heuristics.
            if entry.get("capabilities", {}).get("completion_chat"):
                model_type = "language"
            discovered.append(
                DiscoveredModel(
                    name=model_id,
                    provider="mistral",
                    model_type=model_type,
                )
            )
    except Exception as e:
        logger.warning(f"Failed to discover Mistral models: {e}")
    return discovered
async def discover_deepseek_models() -> List[DiscoveredModel]:
    """Fetch available models from DeepSeek API."""
    api_key = os.environ.get("DEEPSEEK_API_KEY")
    if not api_key:
        return []
    discovered: List[DiscoveredModel] = []
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://api.deepseek.com/models",
                headers={"Authorization": f"Bearer {api_key}"},
                timeout=30.0,
            )
            response.raise_for_status()
            payload = response.json()
        for entry in payload.get("data", []):
            model_id = entry.get("id", "")
            if not model_id:
                continue
            discovered.append(
                DiscoveredModel(
                    name=model_id,
                    provider="deepseek",
                    model_type=classify_model_type(model_id, "deepseek"),
                )
            )
    except Exception as e:
        logger.warning(f"Failed to discover DeepSeek models: {e}")
    return discovered
async def discover_xai_models() -> List[DiscoveredModel]:
    """Fetch available models from xAI API."""
    api_key = os.environ.get("XAI_API_KEY")
    if not api_key:
        return []
    discovered: List[DiscoveredModel] = []
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://api.x.ai/v1/models",
                headers={"Authorization": f"Bearer {api_key}"},
                timeout=30.0,
            )
            response.raise_for_status()
            payload = response.json()
        for entry in payload.get("data", []):
            model_id = entry.get("id", "")
            if not model_id:
                continue
            discovered.append(
                DiscoveredModel(
                    name=model_id,
                    provider="xai",
                    model_type=classify_model_type(model_id, "xai"),
                )
            )
    except Exception as e:
        logger.warning(f"Failed to discover xAI models: {e}")
    return discovered
async def discover_openrouter_models() -> List[DiscoveredModel]:
    """Fetch available models from OpenRouter API."""
    api_key = os.environ.get("OPENROUTER_API_KEY")
    if not api_key:
        return []
    discovered: List[DiscoveredModel] = []
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://openrouter.ai/api/v1/models",
                headers={"Authorization": f"Bearer {api_key}"},
                timeout=30.0,
            )
            response.raise_for_status()
            payload = response.json()
        for entry in payload.get("data", []):
            model_id = entry.get("id", "")
            if not model_id:
                continue
            # OpenRouter entries are registered as language models; the
            # human-readable "name" field becomes the description.
            discovered.append(
                DiscoveredModel(
                    name=model_id,
                    provider="openrouter",
                    model_type="language",
                    description=entry.get("name"),
                )
            )
    except Exception as e:
        logger.warning(f"Failed to discover OpenRouter models: {e}")
    return discovered
async def discover_voyage_models() -> List[DiscoveredModel]:
    """Return static list of Voyage AI models (embedding only)."""
    if not os.environ.get("VOYAGE_API_KEY"):
        return []
    # Voyage AI specializes in embeddings; the catalog is maintained by hand.
    catalog = (
        "voyage-3",
        "voyage-3-lite",
        "voyage-code-3",
        "voyage-finance-2",
        "voyage-law-2",
        "voyage-multilingual-2",
    )
    result: List[DiscoveredModel] = []
    for model_name in catalog:
        result.append(
            DiscoveredModel(name=model_name, provider="voyage", model_type="embedding")
        )
    return result
async def discover_elevenlabs_models() -> List[DiscoveredModel]:
    """Return static list of ElevenLabs TTS models."""
    if not os.environ.get("ELEVENLABS_API_KEY"):
        return []
    # ElevenLabs specializes in TTS; the catalog is maintained by hand.
    catalog = (
        "eleven_multilingual_v2",
        "eleven_turbo_v2_5",
        "eleven_turbo_v2",
        "eleven_monolingual_v1",
        "eleven_multilingual_v1",
    )
    result: List[DiscoveredModel] = []
    for model_name in catalog:
        result.append(
            DiscoveredModel(
                name=model_name,
                provider="elevenlabs",
                model_type="text_to_speech",
            )
        )
    return result
async def discover_dashscope_models() -> List[DiscoveredModel]:
    """Fetch available models from DashScope (Qwen) API."""
    api_key = os.environ.get("DASHSCOPE_API_KEY")
    if not api_key:
        return []
    discovered: List[DiscoveredModel] = []
    try:
        async with httpx.AsyncClient() as client:
            # DashScope exposes an OpenAI-compatible listing endpoint.
            response = await client.get(
                "https://dashscope.aliyuncs.com/compatible-mode/v1/models",
                headers={"Authorization": f"Bearer {api_key}"},
                timeout=30.0,
            )
            response.raise_for_status()
            payload = response.json()
        for entry in payload.get("data", []):
            model_id = entry.get("id", "")
            if not model_id:
                continue
            discovered.append(
                DiscoveredModel(
                    name=model_id,
                    provider="dashscope",
                    model_type=classify_model_type(model_id, "dashscope"),
                )
            )
    except Exception as e:
        logger.warning(f"Failed to discover DashScope models: {e}")
    return discovered
async def discover_minimax_models() -> List[DiscoveredModel]:
    """Fetch available models from MiniMax API."""
    api_key = os.environ.get("MINIMAX_API_KEY")
    if not api_key:
        return []
    discovered: List[DiscoveredModel] = []
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://api.minimax.io/v1/models",
                headers={"Authorization": f"Bearer {api_key}"},
                timeout=30.0,
            )
            response.raise_for_status()
            payload = response.json()
        for entry in payload.get("data", []):
            model_id = entry.get("id", "")
            if not model_id:
                continue
            discovered.append(
                DiscoveredModel(
                    name=model_id,
                    provider="minimax",
                    model_type=classify_model_type(model_id, "minimax"),
                )
            )
    except Exception as e:
        logger.warning(f"Failed to discover MiniMax models: {e}")
    return discovered
async def discover_openai_compatible_models() -> List[DiscoveredModel]:
    """
    Fetch available models from an OpenAI-compatible API endpoint.

    Configuration resolution order:
      1. The "openai_compatible" Credential record in the database.
      2. OPENAI_COMPATIBLE_API_KEY / OPENAI_COMPATIBLE_BASE_URL env vars.

    Returns:
        List of discovered models; empty when no base_url is configured
        or the endpoint cannot be reached.
    """
    api_key = None
    base_url = None
    # Try to get config from Credential database first
    try:
        credentials = await Credential.get_by_provider("openai_compatible")
        if credentials:
            cred = credentials[0]
            config = cred.to_esperanto_config()
            api_key = config.get("api_key")
            # Guard with `or ""`: the stored config may hold an explicit
            # None for base_url, which previously raised AttributeError on
            # .rstrip() and spuriously logged a credential-read failure.
            base_url = (config.get("base_url") or "").rstrip("/")
    except Exception as e:
        logger.warning(f"Failed to read openai_compatible config from Credential: {e}")
    # Fall back to environment variables
    if not api_key:
        api_key = os.environ.get("OPENAI_COMPATIBLE_API_KEY")
    if not base_url:
        base_url = os.environ.get("OPENAI_COMPATIBLE_BASE_URL", "").rstrip("/")
    if not base_url:
        logger.warning("No base_url configured for openai_compatible provider")
        return []
    models = []
    try:
        async with httpx.AsyncClient() as client:
            # Authorization is optional: some compatible gateways accept
            # unauthenticated requests.
            headers = {}
            if api_key:
                headers["Authorization"] = f"Bearer {api_key}"
            response = await client.get(
                f"{base_url}/models",
                headers=headers,
                timeout=30.0,
            )
            response.raise_for_status()
            data = response.json()
            for model in data.get("data", []):
                model_id = model.get("id", "")
                if model_id:
                    # No dedicated pattern table exists for arbitrary
                    # endpoints; OpenAI's name patterns are the closest fit.
                    model_type = classify_model_type(model_id, "openai")
                    models.append(
                        DiscoveredModel(
                            name=model_id,
                            provider="openai_compatible",
                            model_type=model_type,
                        )
                    )
    except httpx.HTTPStatusError as e:
        logger.warning(f"Failed to discover openai_compatible models: HTTP {e.response.status_code}")
    except Exception as e:
        logger.warning(f"Failed to discover openai_compatible models: {e}")
    return models
# =============================================================================
# Main Discovery Functions
# =============================================================================
# Map provider names to their discovery functions.
# A value of None marks a provider that is known but cannot be discovered
# with a plain API-key call; discover_provider_models() logs a hint for
# those instead of attempting discovery.
PROVIDER_DISCOVERY_FUNCTIONS = {
    "openai": discover_openai_models,
    "anthropic": discover_anthropic_models,
    "google": discover_google_models,
    "ollama": discover_ollama_models,
    "groq": discover_groq_models,
    "mistral": discover_mistral_models,
    "deepseek": discover_deepseek_models,
    "xai": discover_xai_models,
    "openrouter": discover_openrouter_models,
    "voyage": discover_voyage_models,
    "elevenlabs": discover_elevenlabs_models,
    "openai_compatible": discover_openai_compatible_models,
    "dashscope": discover_dashscope_models,
    "minimax": discover_minimax_models,
    "azure": None,  # Azure requires credential-based discovery (different auth)
    "vertex": None,  # Vertex requires credential-based discovery (service account)
}
async def discover_provider_models(provider: str) -> List[DiscoveredModel]:
    """
    Discover available models for a specific provider.

    Args:
        provider: Provider name (openai, anthropic, etc.)

    Returns:
        List of discovered models
    """
    discover_func = PROVIDER_DISCOVERY_FUNCTIONS.get(provider)
    if discover_func is not None:
        return await discover_func()
    if provider in PROVIDER_DISCOVERY_FUNCTIONS:
        # Registered with a None function: needs credential-based discovery.
        logger.info(
            f"Provider '{provider}' requires credential-based discovery. "
            f"Use the /credentials/{{id}}/discover endpoint instead."
        )
    else:
        logger.warning(f"No discovery function for provider: {provider}")
    return []
async def sync_provider_models(
    provider: str, auto_register: bool = True
) -> Tuple[int, int, int]:
    """
    Sync models for a provider: discover and optionally register in database.

    Args:
        provider: Provider name
        auto_register: If True, automatically create Model records in database

    Returns:
        Tuple of (discovered_count, new_count, existing_count)
    """
    discovered = await discover_provider_models(provider)
    discovered_count = len(discovered)
    if not auto_register:
        return discovered_count, 0, 0
    if not discovered:
        return 0, 0, 0
    # Batch-fetch existing models once to avoid an N+1 query pattern.
    try:
        rows = await repo_query(
            "SELECT string::lowercase(name) as name, string::lowercase(type) as type FROM model "
            "WHERE string::lowercase(provider) = $provider",
            {"provider": provider.lower()},
        )
        # (name, type) tuples give O(1) duplicate checks below.
        existing_keys = {(row.get("name", ""), row.get("type", "")) for row in rows}
    except Exception as e:
        logger.warning(f"Failed to fetch existing models for {provider}: {e}")
        existing_keys = set()
    new_count = 0
    existing_count = 0
    for candidate in discovered:
        if (candidate.name.lower(), candidate.model_type.lower()) in existing_keys:
            existing_count += 1
            continue
        # Not yet registered: persist a new Model record.
        try:
            record = Model(
                name=candidate.name,
                provider=candidate.provider,
                type=candidate.model_type,
            )
            await record.save()
            new_count += 1
            logger.info(f"Registered new model: {candidate.provider}/{candidate.name} ({candidate.model_type})")
        except Exception as e:
            logger.warning(f"Failed to register model {candidate.name}: {e}")
    logger.info(
        f"Synced {provider}: {discovered_count} discovered, "
        f"{new_count} new, {existing_count} existing"
    )
    return discovered_count, new_count, existing_count
async def sync_all_providers() -> Dict[str, Tuple[int, int, int]]:
    """
    Sync models for all configured providers.

    Returns:
        Dict mapping provider names to (discovered, new, existing) tuples
    """
    providers = list(PROVIDER_DISCOVERY_FUNCTIONS.keys())
    # Fan out all provider syncs concurrently; capture exceptions so one
    # failing provider cannot abort the rest.
    outcomes = await asyncio.gather(
        *(sync_provider_models(provider, auto_register=True) for provider in providers),
        return_exceptions=True,
    )
    results: Dict[str, Tuple[int, int, int]] = {}
    for provider, outcome in zip(providers, outcomes):
        if isinstance(outcome, Exception):
            logger.error(f"Error syncing {provider}: {outcome}")
            results[provider] = (0, 0, 0)
        else:
            results[provider] = outcome
    return results
async def get_provider_model_count(provider: str) -> Dict[str, int]:
    """
    Get count of registered models for a provider, grouped by type.

    Args:
        provider: Provider name (case-insensitive)

    Returns:
        Dict mapping model type to count
    """
    # Lowercase both sides in the query for a case-insensitive match.
    rows = await repo_query(
        "SELECT type, count() as count FROM model WHERE string::lowercase(provider) = string::lowercase($provider) GROUP BY type",
        {"provider": provider},
    )
    counts = dict.fromkeys(
        ("language", "embedding", "speech_to_text", "text_to_speech"), 0
    )
    for row in rows:
        kind = row.get("type")
        # Ignore any unexpected type values coming back from the database.
        if kind in counts:
            counts[kind] = row.get("count", 0)
    return counts