agent-zero/plugins/_model_config/api/model_search.py

from __future__ import annotations
from typing import Any
import httpx
from helpers.api import ApiHandler, Request, Response
from helpers.providers import get_provider_config
import models

# Model name substrings to exclude from chat dropdowns and LiteLLM fallback results.
_NON_CHAT_EXCLUDE = frozenset({
    "dall-e",
    "gpt-image",
    "image",
    "tts",
    "text-to-speech",
    "whisper",
    "audio",
    "transcribe",
    "transcription",
    "speech",
    "realtime",
    "embedding",
    "embed",
    "moderation",
    "omni-moderation",
    "vision-preview",
})
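# Illustrative check against the set above (model names are examples only):
# "gpt-4o-audio-preview" and "text-embedding-3-small" contain excluded
# substrings ("audio", "embedding") and are dropped; "gpt-4o" passes.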


class ModelSearch(ApiHandler):
    async def process(self, input: dict, request: Request) -> dict | Response:
        """Search a provider's model catalog and return a sorted, de-duplicated
        list filtered by model type and query, plus the data source and any error."""
        provider = str(input.get("provider", "") or "").strip().lower()
        model_type = str(input.get("model_type", "chat") or "chat").strip().lower()
        query = str(input.get("query", "") or "").strip().lower()
        user_api_base = str(input.get("api_base", "") or "").strip()
        if not provider:
            return {"models": [], "provider": "", "source": "none", "error": ""}
        cfg = self._get_provider_cfg(model_type, provider)
        ml = self._get_models_list(cfg)
        models_list, source, error = await self._fetch_models(provider, cfg, ml, user_api_base)
        if not models_list:
            fallback = self._litellm_fallback(provider, cfg)
            if fallback:
                models_list = fallback
                source = "litellm_registry"
            elif not source:
                source = "none"
        models_list = self._filter_models(models_list, model_type)
        if query:
            models_list = [name for name in models_list if query in name.lower()]
        return {
            "models": sorted(set(models_list), key=str.lower),
            "provider": provider,
            "source": source,
            "error": error,
        }

    @staticmethod
    def _get_provider_cfg(model_type: str, provider: str) -> dict:
        """Get provider config, falling back to chat config for models_list."""
        cfg = get_provider_config(model_type, provider) or {}
        if model_type != "chat" and not cfg.get("models_list"):
            chat_cfg = get_provider_config("chat", provider) or {}
            if chat_cfg.get("models_list"):
                merged = dict(cfg)
                merged["models_list"] = chat_cfg["models_list"]
                return merged
        return cfg

    @staticmethod
    def _get_models_list(cfg: dict) -> dict:
        """Extract models_list sub-config."""
        return cfg.get("models_list") or {}
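
    # Assumed provider-config shape (only the keys this module consumes; values
    # are illustrative, not the full runtime schema):
    # {"kwargs": {"api_base": "...", "extra_headers": {...}},
    #  "litellm_provider": "openai",
    #  "models_list": {"endpoint_url": "/v1/models", "format": "openai",
    #                  "default_base": "https://api.openai.com/v1", "params": {}}}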

    async def _fetch_models(
        self,
        provider: str,
        cfg: dict,
        ml: dict,
        user_api_base: str = "",
    ) -> tuple[list[str], str, str]:
        """Query the provider's model-list endpoint(s) and return a tuple of
        (model names, source label, joined error string)."""
        api_key = models.get_api_key(provider)
        kwargs = (cfg or {}).get("kwargs", {}) or {}
        api_base = user_api_base or kwargs.get("api_base", "") or ml.get("default_base", "")
        effective_ml = dict(ml or {})
        # Ollama's native endpoint is /api/tags, but user-supplied /v1 bases usually
        # mean the OpenAI-compatible /v1/models endpoint.
        if provider == "ollama" and user_api_base.rstrip("/").endswith("/v1"):
            effective_ml["endpoint_url"] = "/models"
            effective_ml["format"] = "openai"
        url, fmt = self._resolve_url(effective_ml, api_base)
        if not url:
            return [], "none", ""
        headers = self._build_headers(provider, api_key, cfg)
        params = dict(effective_ml.get("params", {}) or {})
        # Google uses query-param auth for the public models list endpoint.
        if provider == "google" and api_key and api_key != "None":
            params.setdefault("key", api_key)
        urls: list[tuple[str, str]] = [(url, fmt)]
        if provider == "ollama" and fmt == "ollama":
            ps_url = self._ollama_ps_url(url)
            if ps_url and ps_url != url:
                urls.append((ps_url, "ollama"))
        combined: list[str] = []
        errors: list[str] = []
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                for candidate_url, candidate_fmt in urls:
                    resp = await client.get(candidate_url, headers=headers, params=params)
                    if resp.status_code == 200:
                        combined.extend(self._parse(resp.json(), candidate_fmt))
                    else:
                        errors.append(f"{candidate_url}: HTTP {resp.status_code}")
        except Exception as exc:
            errors.append(str(exc))
        if combined:
            return combined, "provider_endpoint", ""
        return [], "provider_endpoint", "; ".join(errors)

    @staticmethod
    def _resolve_url(ml: dict, api_base: str) -> tuple[str | None, str]:
        """Join the endpoint path onto the API base, returning (url, format)."""
        fmt = ml.get("format", "openai")
        endpoint = str(ml.get("endpoint_url", "") or "")
        default_base = str(ml.get("default_base", "") or "")
        if endpoint.startswith("http://") or endpoint.startswith("https://"):
            return endpoint, fmt
        base = str(api_base or default_base or "").strip()
        if not base:
            return None, fmt
        endpoint = endpoint or "/models"
        base = base.rstrip("/")
        if not endpoint.startswith("/"):
            endpoint = "/" + endpoint
        # Avoid doubled /v1/v1 when users enter a base ending in /v1 and metadata
        # also contains a versioned endpoint.
        if base.endswith("/v1") and endpoint.startswith("/v1/"):
            endpoint = endpoint[3:]
        return base + endpoint, fmt
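
    # Example resolutions (inputs illustrative):
    #   base "http://localhost:11434", endpoint "/api/tags"
    #     -> "http://localhost:11434/api/tags"
    #   base "https://api.openai.com/v1", endpoint "/v1/models"
    #     -> "https://api.openai.com/v1/models" (the /v1/v1 doubling is stripped)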

    @staticmethod
    def _ollama_ps_url(resolved_url: str) -> str:
        """Return the Ollama running-model endpoint for a resolved native URL."""
        marker = "/api/"
        if marker not in resolved_url:
            return ""
        return resolved_url.split(marker, 1)[0].rstrip("/") + "/api/ps"
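
    # For example, "http://localhost:11434/api/tags" maps to
    # "http://localhost:11434/api/ps" (illustrative default Ollama base).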

    def _build_headers(self, provider: str, api_key: str, cfg: dict | None) -> dict[str, str]:
        """Build provider-specific auth headers plus any configured extra headers."""
        headers: dict[str, str] = {}
        has_key = bool(api_key and api_key.strip() and api_key != "None")
        if provider == "anthropic":
            if has_key:
                headers["x-api-key"] = api_key
            headers["anthropic-version"] = "2023-06-01"
        elif provider == "google":
            pass  # Google authenticates via the ?key= query param instead.
        elif provider == "azure":
            if has_key:
                headers["api-key"] = api_key
        elif provider not in ("ollama", "lm_studio"):
            if has_key:
                headers["Authorization"] = f"Bearer {api_key}"
        extra = (cfg or {}).get("kwargs", {}).get("extra_headers", {})
        if isinstance(extra, dict):
            for key, value in extra.items():
                if isinstance(value, str):
                    headers[key] = value
        return headers
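
    # Illustrative results: provider="anthropic" with a key set ->
    # {"x-api-key": "<key>", "anthropic-version": "2023-06-01"};
    # provider="openai" -> {"Authorization": "Bearer <key>"}.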

    def _litellm_fallback(self, provider: str, cfg: dict | None) -> list[str]:
        """Fall back to LiteLLM's static model registry when the live endpoint fails."""
        try:
            import litellm

            registry = getattr(litellm, "models_by_provider", None)
            if not registry:
                return []
            litellm_provider = (cfg or {}).get("litellm_provider", provider)
            raw_models = registry.get(litellm_provider, set()) or set()
            if not raw_models:
                return []
            prefix = litellm_provider + "/"
            result: list[str] = []
            for name in raw_models:
                clean = str(name or "")
                clean = clean[len(prefix):] if clean.startswith(prefix) else clean
                if clean and not self._is_non_chat_model(clean):
                    result.append(clean)
            return result
        except Exception:
            return []
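
    # LiteLLM's models_by_provider registry keys model names by provider; any
    # entries carrying a "provider/" prefix are stripped above. Registry
    # contents vary with the installed litellm version.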

    def _parse(self, data: dict | list, fmt: str) -> list[str]:
        """Extract model names from a provider response in the given format."""
        if isinstance(data, list):
            return self._parse_list(data)
        if not isinstance(data, dict):
            return []
        if fmt == "ollama":
            return self._parse_models_array(data.get("models", []), "name")
        if fmt == "google":
            result = []
            for item in data.get("models", []) or []:
                if not isinstance(item, dict):
                    continue
                name = str(item.get("name", "") or "")
                if name.startswith("models/"):
                    name = name[7:]
                if name:
                    result.append(name)
            return result
        if "data" in data:
            return self._parse_models_array(data.get("data", []), "id")
        if "models" in data:
            return self._parse_models_array(data.get("models", []), "id")
        return []
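
    # Illustrative payloads: OpenAI-style {"data": [{"id": "gpt-4o"}]} -> ["gpt-4o"];
    # Google-style {"models": [{"name": "models/gemini-1.5-pro"}]} -> ["gemini-1.5-pro"];
    # Ollama-style {"models": [{"name": "llama3.1"}]} -> ["llama3.1"].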

    @staticmethod
    def _parse_models_array(items: Any, primary_key: str) -> list[str]:
        """Pull model names from a list of strings or dicts, preferring primary_key."""
        if not isinstance(items, list):
            return []
        result = []
        for item in items:
            if isinstance(item, str):
                result.append(item)
            elif isinstance(item, dict):
                value = item.get(primary_key) or item.get("id") or item.get("name")
                if value:
                    result.append(str(value))
        return result

    def _parse_list(self, data: list) -> list[str]:
        """Pull model names from a bare top-level list response."""
        result = []
        for item in data:
            if isinstance(item, str):
                result.append(item)
            elif isinstance(item, dict):
                value = item.get("id") or item.get("name")
                if value:
                    result.append(str(value))
        return result

    def _filter_models(self, model_names: list[str], model_type: str) -> list[str]:
        """Drop empty names and, for chat, anything that looks like a non-chat model."""
        cleaned = []
        for name in model_names or []:
            value = str(name or "").strip()
            if not value:
                continue
            if model_type == "chat" and self._is_non_chat_model(value):
                continue
            cleaned.append(value)
        return cleaned

    @staticmethod
    def _is_non_chat_model(name: str) -> bool:
        low = name.lower()
        return any(token in low for token in _NON_CHAT_EXCLUDE)
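

# Minimal usage sketch (request fields as read by process(); all values are
# illustrative and assume a configured OpenAI key):
#   input = {"provider": "openai", "model_type": "chat", "query": "gpt"}
#   -> {"models": ["gpt-4.1", "gpt-4o"], "provider": "openai",
#       "source": "provider_endpoint", "error": ""}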