mirror of
https://github.com/agent0ai/agent-zero.git
synced 2026-05-23 12:44:31 +00:00
Deep-merge model preset slots with the active configuration so custom context windows, rate limits, and nested kwargs survive preset switches. Treat legacy utility preset defaults as implicit values, allow omitted utility and embedding slots to inherit configured models, and document the partial-preset behavior.
572 lines
19 KiB
Python
572 lines
19 KiB
Python
import os
|
|
from copy import deepcopy
|
|
|
|
import models
|
|
from helpers import plugins, files
|
|
from helpers import yaml as yaml_helper
|
|
from helpers.providers import get_provider_config, get_providers
|
|
|
|
PRESETS_FILE = "presets.yaml"
|
|
DEFAULT_PRESETS_FILE = "default_presets.yaml"
|
|
PROVIDER_METADATA_FILE = "provider_metadata.yaml"
|
|
PRESET_SCOPE_GLOBAL = "global"
|
|
PRESET_SCOPE_PROJECT = "project"
|
|
PRESET_SLOT_CONFIG_SECTIONS = {
|
|
"chat": "chat_model",
|
|
"utility": "utility_model",
|
|
"embedding": "embedding_model",
|
|
}
|
|
IMPLICIT_PRESET_SLOT_DEFAULTS = {
|
|
"utility": {
|
|
"ctx_length": 128000,
|
|
"ctx_input": 0.7,
|
|
"rl_requests": 0,
|
|
"rl_input": 0,
|
|
"rl_output": 0,
|
|
"kwargs": {},
|
|
},
|
|
"embedding": {
|
|
"rl_requests": 0,
|
|
"rl_input": 0,
|
|
"kwargs": {},
|
|
},
|
|
}
|
|
LOCAL_PROVIDERS = {"ollama", "lm_studio"}
|
|
LOCAL_EMBEDDING = {"huggingface"}
|
|
_PROVIDER_METADATA_CACHE: dict | None = None
|
|
|
|
|
|
def _get_provider_metadata_path() -> str:
|
|
plugin_dir = plugins.find_plugin_dir("_model_config")
|
|
return files.get_abs_path(plugin_dir, PROVIDER_METADATA_FILE) if plugin_dir else ""
|
|
|
|
|
|
def get_provider_metadata(model_type: str = "chat", provider: str = "") -> dict:
|
|
"""Get plugin-owned provider metadata that does not belong in conf/model_providers.yaml."""
|
|
global _PROVIDER_METADATA_CACHE
|
|
if _PROVIDER_METADATA_CACHE is None:
|
|
path = _get_provider_metadata_path()
|
|
if path and files.exists(path):
|
|
data = yaml_helper.loads(files.read_file(path))
|
|
_PROVIDER_METADATA_CACHE = data if isinstance(data, dict) else {}
|
|
else:
|
|
_PROVIDER_METADATA_CACHE = {}
|
|
|
|
section = _PROVIDER_METADATA_CACHE.get(model_type, {})
|
|
if not isinstance(section, dict):
|
|
return {}
|
|
meta = section.get(str(provider or "").strip().lower(), {})
|
|
return meta if isinstance(meta, dict) else {}
|
|
|
|
|
|
def _model_type_for_label(label: str) -> str:
|
|
return "embedding" if label == "Embedding Model" else "chat"
|
|
|
|
|
|
def provider_requires_api_key(provider: str, model_type: str = "chat") -> bool:
|
|
provider_id = str(provider or "").strip().lower()
|
|
if not provider_id:
|
|
return False
|
|
cfg = get_provider_config(model_type, provider_id) or get_provider_config("chat", provider_id) or {}
|
|
meta = get_provider_metadata(model_type, provider_id) or get_provider_metadata("chat", provider_id)
|
|
mode = str(meta.get("api_key_mode") or cfg.get("api_key_mode") or "required").strip().lower()
|
|
return mode not in {"none", "optional", "oauth"}
|
|
|
|
|
|
def _get_presets_path(project_name: str | None = None) -> str:
|
|
"""Return the user-editable presets path for the requested scope."""
|
|
if project_name:
|
|
return plugins.determine_plugin_asset_path(
|
|
"_model_config", project_name, "", PRESETS_FILE
|
|
)
|
|
return files.get_abs_path(files.USER_DIR, files.PLUGINS_DIR, "_model_config", PRESETS_FILE)
|
|
|
|
|
|
def _get_default_presets_path() -> str:
|
|
"""Return the path to the default presets file shipped with the plugin."""
|
|
plugin_dir = plugins.find_plugin_dir("_model_config")
|
|
return files.get_abs_path(plugin_dir, DEFAULT_PRESETS_FILE) if plugin_dir else ""
|
|
|
|
|
|
def get_config(agent=None, project_name=None, agent_profile=None):
|
|
"""Get the full model config dict for the given agent/scope."""
|
|
return plugins.get_plugin_config(
|
|
"_model_config",
|
|
agent=agent,
|
|
project_name=project_name,
|
|
agent_profile=agent_profile,
|
|
) or {}
|
|
|
|
|
|
def _load_presets_from_path(path: str) -> list | None:
|
|
if files.exists(path):
|
|
data = yaml_helper.loads(files.read_file(path))
|
|
if isinstance(data, list):
|
|
return data
|
|
return None
|
|
|
|
|
|
def _strip_ui_fields(value: dict, *, strip_api_key: bool) -> dict:
|
|
cleaned = deepcopy(value)
|
|
for key in list(cleaned.keys()):
|
|
if key.startswith("_"):
|
|
cleaned.pop(key, None)
|
|
if strip_api_key:
|
|
cleaned.pop("api_key", None)
|
|
return cleaned
|
|
|
|
|
|
def _preset_default_values_equal(value, default) -> bool:
|
|
if isinstance(default, float):
|
|
try:
|
|
return float(value) == default
|
|
except (TypeError, ValueError):
|
|
return False
|
|
return value == default
|
|
|
|
|
|
def _strip_implicit_preset_defaults(slot: str, slot_config: dict) -> dict:
|
|
cleaned = deepcopy(slot_config)
|
|
defaults = IMPLICIT_PRESET_SLOT_DEFAULTS.get(slot, {})
|
|
for key, default in defaults.items():
|
|
if key in cleaned and _preset_default_values_equal(cleaned[key], default):
|
|
cleaned.pop(key, None)
|
|
return cleaned
|
|
|
|
|
|
def _clean_preset_for_file(preset: dict) -> dict:
|
|
cleaned = {
|
|
"name": str(preset.get("name", "") or ""),
|
|
}
|
|
for slot in PRESET_SLOT_CONFIG_SECTIONS:
|
|
slot_config = preset.get(slot)
|
|
if isinstance(slot_config, dict):
|
|
slot_clean = _strip_ui_fields(slot_config, strip_api_key=False)
|
|
cleaned[slot] = _strip_implicit_preset_defaults(slot, slot_clean)
|
|
return cleaned
|
|
|
|
|
|
def clean_presets_for_file(presets: list) -> list:
|
|
"""Return presets without API/UI metadata, preserving the plain YAML schema."""
|
|
cleaned = []
|
|
for preset in presets:
|
|
if isinstance(preset, dict):
|
|
cleaned.append(_clean_preset_for_file(preset))
|
|
return cleaned
|
|
|
|
|
|
def normalize_config_for_save(config: dict) -> dict:
|
|
"""Remove UI-only fields and inline API keys before storing scoped config."""
|
|
cleaned = deepcopy(config or {})
|
|
for section_name in ("chat_model", "utility_model", "embedding_model"):
|
|
section = cleaned.get(section_name)
|
|
if isinstance(section, dict):
|
|
cleaned[section_name] = _strip_ui_fields(section, strip_api_key=True)
|
|
return cleaned
|
|
|
|
|
|
def get_presets(project_name: str | None = None) -> list:
|
|
"""Get global model presets list (not scoped to project/agent)."""
|
|
if project_name:
|
|
return get_project_presets(project_name)
|
|
|
|
path = _get_presets_path()
|
|
presets = _load_presets_from_path(path)
|
|
if presets is not None:
|
|
return presets
|
|
|
|
# Fall back to defaults bundled with the plugin
|
|
default_path = _get_default_presets_path()
|
|
default_presets = _load_presets_from_path(default_path) if default_path else None
|
|
return default_presets or []
|
|
|
|
|
|
def get_project_presets(project_name: str) -> list:
|
|
"""Get presets stored beside the project-scoped _model_config config."""
|
|
return _load_presets_from_path(_get_presets_path(project_name)) or []
|
|
|
|
|
|
def _with_preset_metadata(preset: dict, scope: str, project_name: str = "") -> dict:
|
|
item = deepcopy(preset)
|
|
item["scope"] = scope
|
|
item["project_name"] = project_name if scope == PRESET_SCOPE_PROJECT else ""
|
|
item["name"] = str(item.get("name", "") or "")
|
|
return item
|
|
|
|
|
|
def get_combined_presets(project_name: str | None = None) -> list:
|
|
"""Get presets with explicit API metadata for unambiguous UI/API selection."""
|
|
presets = [
|
|
_with_preset_metadata(preset, PRESET_SCOPE_GLOBAL)
|
|
for preset in get_presets()
|
|
if isinstance(preset, dict)
|
|
]
|
|
if project_name:
|
|
presets.extend(
|
|
_with_preset_metadata(preset, PRESET_SCOPE_PROJECT, project_name)
|
|
for preset in get_project_presets(project_name)
|
|
if isinstance(preset, dict)
|
|
)
|
|
return presets
|
|
|
|
|
|
def save_presets(presets: list, project_name: str | None = None) -> None:
|
|
"""Save the presets list for the requested scope."""
|
|
path = _get_presets_path(project_name)
|
|
files.write_file(path, yaml_helper.dumps(clean_presets_for_file(presets)))
|
|
|
|
|
|
def reset_presets(project_name: str | None = None) -> list:
|
|
"""Delete user presets for the scope. Global reset falls back to bundled defaults."""
|
|
path = _get_presets_path(project_name)
|
|
if os.path.exists(path):
|
|
os.remove(path)
|
|
return get_project_presets(project_name) if project_name else get_presets()
|
|
|
|
|
|
def resolve_preset(
|
|
name: str,
|
|
*,
|
|
scope: str = PRESET_SCOPE_GLOBAL,
|
|
project_name: str | None = None,
|
|
) -> dict | None:
|
|
"""Resolve a preset by explicit scope so same-name presets are unambiguous."""
|
|
if scope == PRESET_SCOPE_PROJECT:
|
|
if not project_name:
|
|
return None
|
|
presets = get_project_presets(project_name)
|
|
else:
|
|
presets = get_presets()
|
|
|
|
for p in presets:
|
|
if p.get("name") == name:
|
|
return p
|
|
return None
|
|
|
|
|
|
def resolve_preset_selection(selection: dict | str, project_name: str | None = None) -> dict | None:
|
|
"""Resolve a UI/API preset selection payload to a preset dict."""
|
|
if isinstance(selection, str):
|
|
return resolve_preset(selection)
|
|
if not isinstance(selection, dict):
|
|
return None
|
|
|
|
scope = str(selection.get("scope") or PRESET_SCOPE_GLOBAL)
|
|
if scope == "current":
|
|
return None
|
|
name = str(selection.get("name") or "")
|
|
selected_project = str(selection.get("project_name") or project_name or "")
|
|
return resolve_preset(name, scope=scope, project_name=selected_project or None)
|
|
|
|
|
|
def get_preset_by_name(
|
|
name: str,
|
|
*,
|
|
scope: str = PRESET_SCOPE_GLOBAL,
|
|
project_name: str | None = None,
|
|
) -> dict | None:
|
|
"""Find a preset by name. Defaults to global presets for legacy callers."""
|
|
return resolve_preset(name, scope=scope, project_name=project_name)
|
|
|
|
|
|
def _deep_merge_dict(base: dict, override: dict) -> dict:
|
|
"""Recursively overlay override onto base without mutating either input."""
|
|
result = deepcopy(base) if isinstance(base, dict) else {}
|
|
for key, value in override.items():
|
|
if (
|
|
isinstance(value, dict)
|
|
and isinstance(result.get(key), dict)
|
|
):
|
|
result[key] = _deep_merge_dict(result[key], value)
|
|
else:
|
|
result[key] = deepcopy(value)
|
|
return result
|
|
|
|
|
|
def _slot_has_identity(slot_config: dict) -> bool:
|
|
return bool(slot_config.get("provider") or slot_config.get("name"))
|
|
|
|
|
|
def _get_preset_slot_config(preset: dict, slot: str) -> dict | None:
|
|
"""Return the preset payload for a slot.
|
|
|
|
Legacy raw overrides store the main/chat model directly at the top level,
|
|
while named presets store it under the "chat" key.
|
|
"""
|
|
if not isinstance(preset, dict):
|
|
return None
|
|
|
|
slot_config = preset.get(slot)
|
|
if isinstance(slot_config, dict):
|
|
return slot_config
|
|
|
|
if slot == "chat" and not any(key in preset for key in PRESET_SLOT_CONFIG_SECTIONS):
|
|
if _slot_has_identity(preset):
|
|
return preset
|
|
|
|
return None
|
|
|
|
|
|
def _should_apply_preset_slot(slot: str, slot_config: dict | None) -> bool:
|
|
if not isinstance(slot_config, dict):
|
|
return False
|
|
|
|
cleaned = _strip_implicit_preset_defaults(
|
|
slot,
|
|
_strip_ui_fields(slot_config, strip_api_key=False),
|
|
)
|
|
meaningful = {
|
|
key: value
|
|
for key, value in cleaned.items()
|
|
if key != "api_key"
|
|
}
|
|
if not meaningful:
|
|
return False
|
|
|
|
# Slots inherit the configured model unless the preset declares a model
|
|
# identity for that slot. This keeps empty UI placeholders from accidentally
|
|
# overriding context/rate-limit settings.
|
|
return _slot_has_identity(cleaned)
|
|
|
|
|
|
def _merge_model_slot(
|
|
slot: str,
|
|
base_slot: dict,
|
|
preset_slot: dict,
|
|
*,
|
|
strip_api_key: bool,
|
|
) -> dict:
|
|
cleaned = _strip_implicit_preset_defaults(
|
|
slot,
|
|
_strip_ui_fields(preset_slot, strip_api_key=strip_api_key),
|
|
)
|
|
if not strip_api_key and not str(cleaned.get("api_key") or "").strip():
|
|
cleaned.pop("api_key", None)
|
|
return _deep_merge_dict(base_slot if isinstance(base_slot, dict) else {}, cleaned)
|
|
|
|
|
|
def build_config_from_preset(
|
|
preset: dict,
|
|
base_config: dict,
|
|
*,
|
|
strip_api_key: bool = True,
|
|
slots: tuple[str, ...] | None = None,
|
|
) -> dict:
|
|
"""Overlay preset settings onto a standalone model config.
|
|
|
|
Presets are intentionally partial: omitted fields inherit from the current
|
|
config, so selecting a preset does not reset tuned values such as context
|
|
windows, rate limits, or nested kwargs unless the preset explicitly defines
|
|
them.
|
|
"""
|
|
config = (
|
|
normalize_config_for_save(base_config)
|
|
if strip_api_key
|
|
else deepcopy(base_config or {})
|
|
)
|
|
|
|
for slot in slots or tuple(PRESET_SLOT_CONFIG_SECTIONS):
|
|
section = PRESET_SLOT_CONFIG_SECTIONS.get(slot)
|
|
if not section:
|
|
continue
|
|
slot_config = _get_preset_slot_config(preset, slot)
|
|
if not _should_apply_preset_slot(slot, slot_config):
|
|
continue
|
|
config[section] = _merge_model_slot(
|
|
slot,
|
|
config.get(section, {}),
|
|
slot_config,
|
|
strip_api_key=strip_api_key,
|
|
)
|
|
|
|
return config
|
|
|
|
|
|
def _resolve_override(agent) -> dict | None:
|
|
"""Resolve the active per-chat override config dict.
|
|
Supports both raw override dicts and preset-based overrides.
|
|
Returns None if no override is active or if override is not allowed."""
|
|
if not agent:
|
|
return None
|
|
if not is_chat_override_allowed(agent):
|
|
return None
|
|
override = agent.context.get_data("chat_model_override")
|
|
if not override:
|
|
return None
|
|
|
|
# If this is a preset reference, resolve it
|
|
if "preset_name" in override:
|
|
preset = get_preset_by_name(override["preset_name"])
|
|
if not preset:
|
|
return None
|
|
return preset
|
|
|
|
return override
|
|
|
|
|
|
def get_chat_model_config(agent=None) -> dict:
|
|
"""Get chat model config, with per-chat override if active."""
|
|
cfg = get_config(agent)
|
|
override = _resolve_override(agent)
|
|
if override:
|
|
config = build_config_from_preset(
|
|
override,
|
|
cfg,
|
|
strip_api_key=False,
|
|
slots=("chat",),
|
|
)
|
|
return config.get("chat_model", {})
|
|
return cfg.get("chat_model", {})
|
|
|
|
|
|
def get_utility_model_config(agent=None) -> dict:
|
|
"""Get utility model config, with per-chat override if active."""
|
|
cfg = get_config(agent)
|
|
override = _resolve_override(agent)
|
|
if override:
|
|
config = build_config_from_preset(
|
|
override,
|
|
cfg,
|
|
strip_api_key=False,
|
|
slots=("utility",),
|
|
)
|
|
return config.get("utility_model", {})
|
|
return cfg.get("utility_model", {})
|
|
|
|
|
|
def get_embedding_model_config(agent=None) -> dict:
|
|
"""Get embedding model config."""
|
|
cfg = get_config(agent)
|
|
return cfg.get("embedding_model", {})
|
|
|
|
def is_chat_override_allowed(agent=None) -> bool:
|
|
"""Check if per-chat model override is enabled."""
|
|
cfg = get_config(agent)
|
|
return bool(cfg.get("allow_chat_override", False))
|
|
|
|
|
|
def get_ctx_history(agent=None) -> float:
|
|
"""Get the chat model context history ratio."""
|
|
cfg = get_chat_model_config(agent)
|
|
return float(cfg.get("ctx_history", 0.7))
|
|
|
|
|
|
def get_ctx_input(agent=None) -> float:
|
|
"""Get the utility model context input ratio."""
|
|
cfg = get_utility_model_config(agent)
|
|
return float(cfg.get("ctx_input", 0.7))
|
|
|
|
|
|
def _normalize_kwargs(kwargs: dict) -> dict:
|
|
"""Convert string values that are valid numbers to numeric types."""
|
|
result = {}
|
|
for key, value in kwargs.items():
|
|
if isinstance(value, str):
|
|
try:
|
|
result[key] = int(value)
|
|
except ValueError:
|
|
try:
|
|
result[key] = float(value)
|
|
except ValueError:
|
|
result[key] = value
|
|
else:
|
|
result[key] = value
|
|
return result
|
|
|
|
|
|
def build_model_config(cfg: dict, model_type: models.ModelType) -> models.ModelConfig:
|
|
"""Build a ModelConfig from a config dict section."""
|
|
return models.ModelConfig(
|
|
type=model_type,
|
|
provider=cfg.get("provider", ""),
|
|
name=cfg.get("name", ""),
|
|
api_key=cfg.get("api_key", ""),
|
|
api_base=cfg.get("api_base", ""),
|
|
ctx_length=int(cfg.get("ctx_length", 0)),
|
|
vision=bool(cfg.get("vision", False)),
|
|
limit_requests=int(cfg.get("rl_requests", 0)),
|
|
limit_input=int(cfg.get("rl_input", 0)),
|
|
limit_output=int(cfg.get("rl_output", 0)),
|
|
kwargs=_normalize_kwargs(cfg.get("kwargs", {})),
|
|
)
|
|
|
|
|
|
def build_chat_model(agent=None):
|
|
"""Build and return a LiteLLMChatWrapper from config."""
|
|
cfg = get_chat_model_config(agent)
|
|
mc = build_model_config(cfg, models.ModelType.CHAT)
|
|
return models.get_chat_model(
|
|
mc.provider, mc.name, model_config=mc, **mc.build_kwargs()
|
|
)
|
|
|
|
|
|
def build_utility_model(agent=None):
|
|
"""Build and return a LiteLLMChatWrapper for utility tasks."""
|
|
cfg = get_utility_model_config(agent)
|
|
mc = build_model_config(cfg, models.ModelType.CHAT)
|
|
return models.get_chat_model(
|
|
mc.provider, mc.name, model_config=mc, **mc.build_kwargs()
|
|
)
|
|
|
|
|
|
def build_embedding_model(agent=None):
|
|
"""Build and return an embedding model wrapper."""
|
|
cfg = get_embedding_model_config(agent)
|
|
mc = build_model_config(cfg, models.ModelType.EMBEDDING)
|
|
return models.get_embedding_model(
|
|
mc.provider, mc.name, model_config=mc, **mc.build_kwargs()
|
|
)
|
|
|
|
|
|
def get_embedding_model_config_object(agent=None) -> models.ModelConfig:
|
|
"""Get a ModelConfig object for embeddings (needed by memory plugin)."""
|
|
cfg = get_embedding_model_config(agent)
|
|
return build_model_config(cfg, models.ModelType.EMBEDDING)
|
|
|
|
|
|
def get_chat_providers():
|
|
"""Get list of chat providers for UI dropdowns."""
|
|
return get_providers("chat")
|
|
|
|
|
|
def get_embedding_providers():
|
|
"""Get list of embedding providers for UI dropdowns."""
|
|
return get_providers("embedding")
|
|
|
|
|
|
def has_provider_api_key(provider: str, configured_api_key: str = "", model_type: str = "chat") -> bool:
|
|
if not provider_requires_api_key(provider, model_type):
|
|
return True
|
|
configured_value = (configured_api_key or "").strip()
|
|
if configured_value and configured_value != "None":
|
|
return True
|
|
|
|
api_key = models.get_api_key(provider.lower())
|
|
return bool(api_key and api_key.strip() and api_key != "None")
|
|
|
|
|
|
def get_missing_api_key_providers(agent=None) -> list[dict]:
|
|
"""Check which configured providers are missing API keys."""
|
|
cfg = get_config(agent)
|
|
missing = []
|
|
|
|
checks = [
|
|
("Chat Model", cfg.get("chat_model", {})),
|
|
("Utility Model", cfg.get("utility_model", {})),
|
|
("Embedding Model", cfg.get("embedding_model", {})),
|
|
]
|
|
|
|
for label, model_cfg in checks:
|
|
provider = model_cfg.get("provider", "")
|
|
if not provider:
|
|
continue
|
|
provider_lower = provider.lower()
|
|
if provider_lower in LOCAL_PROVIDERS:
|
|
continue
|
|
if label == "Embedding Model" and provider_lower in LOCAL_EMBEDDING:
|
|
continue
|
|
|
|
if not has_provider_api_key(provider_lower, model_cfg.get("api_key", ""), _model_type_for_label(label)):
|
|
missing.append({"model_type": label, "provider": provider})
|
|
|
|
return missing
|