feat: add _model_config plugin with call-time model resolution

This commit is contained in:
keyboardstaff 2026-03-14 09:41:19 -07:00
parent 6d067a12d0
commit d570c629c2
19 changed files with 793 additions and 207 deletions


@@ -309,16 +309,9 @@ class AgentContext:
 @dataclass
 class AgentConfig:
-    chat_model: models.ModelConfig
-    utility_model: models.ModelConfig
-    embeddings_model: models.ModelConfig
-    browser_model: models.ModelConfig
     mcp_servers: str
     profile: str = ""
     knowledge_subdirs: list[str] = field(default_factory=lambda: ["default", "custom"])
-    browser_http_headers: dict[str, str] = field(
-        default_factory=dict
-    )  # Custom HTTP headers for browser requests
     additional: Dict[str, Any] = field(default_factory=dict)
@@ -713,39 +706,19 @@ class Agent:
     @extension.extensible
     def get_chat_model(self):
-        return models.get_chat_model(
-            self.config.chat_model.provider,
-            self.config.chat_model.name,
-            model_config=self.config.chat_model,
-            **self.config.chat_model.build_kwargs(),
-        )
+        return None

     @extension.extensible
     def get_utility_model(self):
-        return models.get_chat_model(
-            self.config.utility_model.provider,
-            self.config.utility_model.name,
-            model_config=self.config.utility_model,
-            **self.config.utility_model.build_kwargs(),
-        )
+        return None

     @extension.extensible
     def get_browser_model(self):
-        return models.get_browser_model(
-            self.config.browser_model.provider,
-            self.config.browser_model.name,
-            model_config=self.config.browser_model,
-            **self.config.browser_model.build_kwargs(),
-        )
+        return None

     @extension.extensible
     def get_embedding_model(self):
-        return models.get_embedding_model(
-            self.config.embeddings_model.provider,
-            self.config.embeddings_model.name,
-            model_config=self.config.embeddings_model,
-            **self.config.embeddings_model.build_kwargs(),
-        )
+        return None

     @extension.extensible
     async def call_utility_model(
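With these bodies reduced to `return None`, the actual models now come from the `*_model_provider` extensions added later in this commit, which set `data["result"]`. The decorator itself lives in helpers/extension.py and is not part of this diff, so the following is only a minimal sketch of the contract those assignments imply, with a hypothetical registry standing in for the real plugin loader:

import functools

_EXTENSIONS: dict[str, list] = {}  # hook name -> extension callables (hypothetical registry)

def extensible(fn):
    # Sketch of the assumed contract: registered extensions run first and may
    # set data["result"]; only if none of them does is the (now trivial)
    # method body used as a fallback.
    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        data: dict = {}
        for ext in _EXTENSIONS.get(fn.__name__, []):
            ext(agent=self, data=data)
        if data.get("result") is not None:
            return data["result"]          # plugin-resolved model wins
        return fn(self, *args, **kwargs)   # fallback: the `return None` body
    return wrapper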
@@ -803,15 +776,32 @@ class Agent:
         # model class
         model = self.get_chat_model()
+        # call extensions before
+        call_data = {
+            "model": model,
+            "messages": messages,
+            "response_callback": response_callback,
+            "reasoning_callback": reasoning_callback,
+            "background": background,
+            "explicit_caching": explicit_caching,
+        }
+        await extension.call_extensions_async(
+            "chat_model_call_before", self, call_data=call_data
+        )
         # call model
-        response, reasoning = await model.unified_call(
-            messages=messages,
-            reasoning_callback=reasoning_callback,
-            response_callback=response_callback,
+        response, reasoning = await call_data["model"].unified_call(
+            messages=call_data["messages"],
+            reasoning_callback=call_data["reasoning_callback"],
+            response_callback=call_data["response_callback"],
             rate_limiter_callback=(
-                self.rate_limiter_callback if not background else None
+                self.rate_limiter_callback if not call_data["background"] else None
             ),
-            explicit_caching=explicit_caching,
+            explicit_caching=call_data["explicit_caching"],
         )
+        await extension.call_extensions_async(
+            "chat_model_call_after", self, call_data=call_data, response=response, reasoning=reasoning
+        )
         return response, reasoning
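The new before/after hooks make the whole call mutable from plugins. As an illustration only (this extension does not exist in the commit; the hook name and call_data keys do), a plugin could reroute short exchanges to the cheaper utility model:

from helpers.extension import Extension

class RouteShortChatsToUtilityModel(Extension):

    async def execute(self, call_data: dict = {}, **kwargs):
        # call_data carries "model", "messages", "background", etc. (see the
        # hunk above); it is mutated in place, so whatever is stored here is
        # what unified_call() receives.
        messages = call_data.get("messages") or []
        if self.agent and len(messages) <= 2:
            call_data["model"] = self.agent.get_utility_model()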


@@ -159,10 +159,13 @@ class Topic(Record):
         return self.summary

     def compress_large_messages(self, message_ratio: float = CURRENT_TOPIC_RATIO * LARGE_MESSAGE_TO_CURRENT_TOPIC_RATIO) -> bool:
-        set = settings.get_settings()
+        from plugins._model_config.helpers.model_config import get_chat_model_config
+        chat_cfg = get_chat_model_config()
+        ctx_length = int(chat_cfg.get("ctx_length", 128000))
+        ctx_history = float(chat_cfg.get("ctx_history", 0.7))
         msg_max_size = (
-            set["chat_model_ctx_length"]
-            * set["chat_model_ctx_history"]
+            ctx_length
+            * ctx_history
             * message_ratio
         )
         large_msgs = []
@@ -479,8 +482,11 @@ def deserialize_history(json_data: str, agent) -> History:

 def _get_ctx_size_for_history() -> int:
-    set = settings.get_settings()
-    return int(set["chat_model_ctx_length"] * set["chat_model_ctx_history"])
+    from plugins._model_config.helpers.model_config import get_chat_model_config
+    chat_cfg = get_chat_model_config()
+    ctx_length = int(chat_cfg.get("ctx_length", 128000))
+    ctx_history = float(chat_cfg.get("ctx_history", 0.7))
+    return int(ctx_length * ctx_history)

 def _stringify_output(output: OutputMessage, ai_label="ai", human_label="human"):
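Both call sites reduce to the same budget formula. A quick arithmetic check with the plugin's shipped defaults (ctx_length 128000, ctx_history 0.7, per the default config.yaml later in this commit); the message_ratio value is made up for the example, since the real default is a product of constants not shown in this hunk:

ctx_length = 128000
ctx_history = 0.7
message_ratio = 0.25  # hypothetical; the real default is
                      # CURRENT_TOPIC_RATIO * LARGE_MESSAGE_TO_CURRENT_TOPIC_RATIO

history_budget = int(ctx_length * ctx_history)            # 89600 tokens for history
msg_max_size = ctx_length * ctx_history * message_ratio   # 22400.0 per large message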


@@ -53,44 +53,6 @@ def get_default_value(name: str, value: T) -> T:

 class Settings(TypedDict):
     version: str
-    chat_model_provider: str
-    chat_model_name: str
-    chat_model_api_base: str
-    chat_model_kwargs: dict[str, Any]
-    chat_model_ctx_length: int
-    chat_model_ctx_history: float
-    chat_model_vision: bool
-    chat_model_rl_requests: int
-    chat_model_rl_input: int
-    chat_model_rl_output: int
-    util_model_provider: str
-    util_model_name: str
-    util_model_api_base: str
-    util_model_kwargs: dict[str, Any]
-    util_model_ctx_length: int
-    util_model_ctx_input: float
-    util_model_rl_requests: int
-    util_model_rl_input: int
-    util_model_rl_output: int
-    embed_model_provider: str
-    embed_model_name: str
-    embed_model_api_base: str
-    embed_model_kwargs: dict[str, Any]
-    embed_model_rl_requests: int
-    embed_model_rl_input: int
-    browser_model_provider: str
-    browser_model_name: str
-    browser_model_api_base: str
-    browser_model_vision: bool
-    browser_model_rl_requests: int
-    browser_model_rl_input: int
-    browser_model_rl_output: int
-    browser_model_kwargs: dict[str, Any]
-    browser_http_headers: dict[str, Any]
     agent_profile: str
     agent_knowledge_subdir: str
@@ -261,10 +223,6 @@ def convert_out(settings: Settings) -> SettingsOutput:
         ),
     }
-    additional["chat_providers"] = _ensure_option_present(additional.get("chat_providers"), current.get("chat_model_provider"))
-    additional["chat_providers"] = _ensure_option_present(additional.get("chat_providers"), current.get("util_model_provider"))
-    additional["chat_providers"] = _ensure_option_present(additional.get("chat_providers"), current.get("browser_model_provider"))
-    additional["embedding_providers"] = _ensure_option_present(additional.get("embedding_providers"), current.get("embed_model_provider"))
     additional["agent_subdirs"] = _ensure_option_present(additional.get("agent_subdirs"), current.get("agent_profile"))
     additional["knowledge_subdirs"] = _ensure_option_present(additional.get("knowledge_subdirs"), current.get("agent_knowledge_subdir"))
     additional["stt_models"] = _ensure_option_present(additional.get("stt_models"), current.get("stt_model_size"))
@@ -493,40 +451,6 @@ def get_default_settings() -> Settings:
     gitignore = files.read_file(files.get_abs_path("conf/workdir.gitignore"))
     return Settings(
         version=_get_version(),
-        chat_model_provider=get_default_value("chat_model_provider", "openrouter"),
-        chat_model_name=get_default_value("chat_model_name", "anthropic/claude-sonnet-4.6"),
-        chat_model_api_base=get_default_value("chat_model_api_base", ""),
-        chat_model_kwargs=get_default_value("chat_model_kwargs", {}),
-        chat_model_ctx_length=get_default_value("chat_model_ctx_length", 100000),
-        chat_model_ctx_history=get_default_value("chat_model_ctx_history", 0.7),
-        chat_model_vision=get_default_value("chat_model_vision", True),
-        chat_model_rl_requests=get_default_value("chat_model_rl_requests", 0),
-        chat_model_rl_input=get_default_value("chat_model_rl_input", 0),
-        chat_model_rl_output=get_default_value("chat_model_rl_output", 0),
-        util_model_provider=get_default_value("util_model_provider", "openrouter"),
-        util_model_name=get_default_value("util_model_name", "google/gemini-3-flash-preview"),
-        util_model_api_base=get_default_value("util_model_api_base", ""),
-        util_model_ctx_length=get_default_value("util_model_ctx_length", 100000),
-        util_model_ctx_input=get_default_value("util_model_ctx_input", 0.7),
-        util_model_kwargs=get_default_value("util_model_kwargs", {}),
-        util_model_rl_requests=get_default_value("util_model_rl_requests", 0),
-        util_model_rl_input=get_default_value("util_model_rl_input", 0),
-        util_model_rl_output=get_default_value("util_model_rl_output", 0),
-        embed_model_provider=get_default_value("embed_model_provider", "huggingface"),
-        embed_model_name=get_default_value("embed_model_name", "sentence-transformers/all-MiniLM-L6-v2"),
-        embed_model_api_base=get_default_value("embed_model_api_base", ""),
-        embed_model_kwargs=get_default_value("embed_model_kwargs", {}),
-        embed_model_rl_requests=get_default_value("embed_model_rl_requests", 0),
-        embed_model_rl_input=get_default_value("embed_model_rl_input", 0),
-        browser_model_provider=get_default_value("browser_model_provider", "openrouter"),
-        browser_model_name=get_default_value("browser_model_name", "anthropic/claude-sonnet-4.6"),
-        browser_model_api_base=get_default_value("browser_model_api_base", ""),
-        browser_model_vision=get_default_value("browser_model_vision", True),
-        browser_model_rl_requests=get_default_value("browser_model_rl_requests", 0),
-        browser_model_rl_input=get_default_value("browser_model_rl_input", 0),
-        browser_model_rl_output=get_default_value("browser_model_rl_output", 0),
-        browser_model_kwargs=get_default_value("browser_model_kwargs", {}),
-        browser_http_headers=get_default_value("browser_http_headers", {}),
         api_keys={},
         auth_login="",
         auth_password="",
@@ -587,18 +511,6 @@ def _apply_settings(previous: Settings | None):
             whisper.preload, _settings["stt_model_size"]
         )  # TODO overkill, replace with background task

-    # notify plugins of embedding model change
-    if not previous or (
-        _settings["embed_model_name"] != previous["embed_model_name"]
-        or _settings["embed_model_provider"] != previous["embed_model_provider"]
-        or _settings["embed_model_kwargs"] != previous["embed_model_kwargs"]
-    ):
-        from helpers.extension import call_extensions_async
-        defer.DeferredTask().start_task(
-            call_extensions_async, "embedding_model_changed"
-        )
-
     # update mcp settings if necessary
     if not previous or _settings["mcp_servers"] != previous["mcp_servers"]:
         from helpers.mcp_handler import MCPConfig


@@ -1,5 +1,4 @@
 from agent import AgentConfig
-import models
 from helpers import runtime, settings, defer, extension
 from helpers.print_style import PrintStyle
@@ -10,78 +9,11 @@ def initialize_agent(override_settings: dict | None = None):
     if override_settings:
         current_settings = settings.merge_settings(current_settings, override_settings)

-    def _normalize_model_kwargs(kwargs: dict) -> dict:
-        # convert string values that represent valid Python numbers to numeric types
-        result = {}
-        for key, value in kwargs.items():
-            if isinstance(value, str):
-                # try to convert string to number if it's a valid Python number
-                try:
-                    # try int first, then float
-                    result[key] = int(value)
-                except ValueError:
-                    try:
-                        result[key] = float(value)
-                    except ValueError:
-                        result[key] = value
-            else:
-                result[key] = value
-        return result
-
-    # chat model from user settings
-    chat_llm = models.ModelConfig(
-        type=models.ModelType.CHAT,
-        provider=current_settings["chat_model_provider"],
-        name=current_settings["chat_model_name"],
-        api_base=current_settings["chat_model_api_base"],
-        ctx_length=current_settings["chat_model_ctx_length"],
-        vision=current_settings["chat_model_vision"],
-        limit_requests=current_settings["chat_model_rl_requests"],
-        limit_input=current_settings["chat_model_rl_input"],
-        limit_output=current_settings["chat_model_rl_output"],
-        kwargs=_normalize_model_kwargs(current_settings["chat_model_kwargs"]),
-    )
-
-    # utility model from user settings
-    utility_llm = models.ModelConfig(
-        type=models.ModelType.CHAT,
-        provider=current_settings["util_model_provider"],
-        name=current_settings["util_model_name"],
-        api_base=current_settings["util_model_api_base"],
-        ctx_length=current_settings["util_model_ctx_length"],
-        limit_requests=current_settings["util_model_rl_requests"],
-        limit_input=current_settings["util_model_rl_input"],
-        limit_output=current_settings["util_model_rl_output"],
-        kwargs=_normalize_model_kwargs(current_settings["util_model_kwargs"]),
-    )
-
-    # embedding model from user settings
-    embedding_llm = models.ModelConfig(
-        type=models.ModelType.EMBEDDING,
-        provider=current_settings["embed_model_provider"],
-        name=current_settings["embed_model_name"],
-        api_base=current_settings["embed_model_api_base"],
-        limit_requests=current_settings["embed_model_rl_requests"],
-        kwargs=_normalize_model_kwargs(current_settings["embed_model_kwargs"]),
-    )
-
-    # browser model from user settings
-    browser_llm = models.ModelConfig(
-        type=models.ModelType.CHAT,
-        provider=current_settings["browser_model_provider"],
-        name=current_settings["browser_model_name"],
-        api_base=current_settings["browser_model_api_base"],
-        vision=current_settings["browser_model_vision"],
-        kwargs=_normalize_model_kwargs(current_settings["browser_model_kwargs"]),
-    )
-
-    # agent configuration
+    # agent configuration - models are now resolved at call time via _model_config plugin
     config = AgentConfig(
-        chat_model=chat_llm,
-        utility_model=utility_llm,
-        embeddings_model=embedding_llm,
-        browser_model=browser_llm,
         profile=current_settings["agent_profile"],
         knowledge_subdirs=[current_settings["agent_knowledge_subdir"], "default"],
         mcp_servers=current_settings["mcp_servers"],
-        browser_http_headers=current_settings["browser_http_headers"],
     )
     # update config with runtime args


@@ -60,6 +60,11 @@ class Memory:
     index: dict[str, "MyFaiss"] = {}

+    @staticmethod
+    def _get_embedding_config(agent=None):
+        from plugins._model_config.helpers.model_config import get_embedding_model_config_object
+        return get_embedding_model_config_object(agent)
+
     @staticmethod
     async def get(agent: Agent):
         memory_subdir = get_agent_memory_subdir(agent)
@@ -70,7 +75,7 @@
         )
         db, created = Memory.initialize(
             log_item,
-            agent.config.embeddings_model,
+            Memory._get_embedding_config(agent),
             memory_subdir,
             False,
         )
@@ -98,7 +103,7 @@
         import initialize

         agent_config = initialize.initialize_agent()
-        model_config = agent_config.embeddings_model
+        model_config = Memory._get_embedding_config()
         db, _created = Memory.initialize(
             log_item=log_item,
             model_config=model_config,


@@ -0,0 +1,60 @@
from helpers.api import ApiHandler, Request, Response
from helpers import dotenv
import models

API_KEY_PLACEHOLDER = "************"


class ApiKeys(ApiHandler):

    async def process(self, input: dict, request: Request) -> dict | Response:
        action = input.get("action", "get")  # get | set | reveal
        if action == "get":
            return self._get_keys()
        elif action == "set":
            return self._set_keys(input)
        elif action == "reveal":
            return self._reveal_key(input)
        return Response(status=400, response=f"Unknown action: {action}")

    def _get_keys(self) -> dict:
        from helpers.providers import get_providers

        providers = get_providers("chat") + get_providers("embedding")
        seen = set()
        keys = {}
        for p in providers:
            pid = p.get("value", "")
            if pid and pid not in seen:
                seen.add(pid)
                api_key = models.get_api_key(pid)
                has_key = bool(api_key and api_key.strip() and api_key != "None")
                keys[pid] = {
                    "label": p.get("label", pid),
                    "has_key": has_key,
                    "masked": API_KEY_PLACEHOLDER if has_key else "",
                }
        return {"keys": keys}

    def _set_keys(self, input: dict) -> dict:
        updates = input.get("keys", {})
        if not isinstance(updates, dict):
            return {"ok": False, "error": "Invalid keys format"}
        for provider, value in updates.items():
            if isinstance(value, str) and value != API_KEY_PLACEHOLDER:
                dotenv.save_dotenv_value(f"API_KEY_{provider.upper()}", value)
        return {"ok": True}

    def _reveal_key(self, input: dict) -> dict:
        provider = input.get("provider", "")
        if not provider:
            return {"ok": False, "error": "Missing provider"}
        api_key = models.get_api_key(provider)
        if api_key and api_key.strip() and api_key != "None":
            return {"ok": True, "value": api_key}
        return {"ok": True, "value": ""}
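A hedged usage sketch for the handler above, assuming it is mounted at an /api_keys route (route registration is not part of this diff). Because _set_keys skips values equal to the placeholder, a UI can echo masked fields back verbatim without overwriting stored keys:

import httpx  # same client library the model_search handler below uses

async def set_provider_key(base_url: str, provider: str, key: str) -> bool:
    # POST {"action": "set", ...}; re-sending "************" would be a no-op
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            f"{base_url}/api_keys",  # assumed mount point
            json={"action": "set", "keys": {provider: key}},
        )
        return resp.json().get("ok", False)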


@@ -0,0 +1,41 @@
from helpers.api import ApiHandler, Request, Response
from helpers import plugins
from plugins._model_config.helpers import model_config
import models


class ModelConfigGet(ApiHandler):

    async def process(self, input: dict, request: Request) -> dict | Response:
        project_name = input.get("project_name", "")
        agent_profile = input.get("agent_profile", "")
        config = model_config.get_config(
            project_name=project_name or None,
            agent_profile=agent_profile or None,
        )
        # Provide default if no config found
        if not config:
            config = plugins.get_default_plugin_config("_model_config") or {}
        # Add provider lists for UI dropdowns
        chat_providers = model_config.get_chat_providers()
        embedding_providers = model_config.get_embedding_providers()
        # Mask API keys - show status only
        api_key_status = {}
        all_providers = chat_providers + embedding_providers
        seen = set()
        for p in all_providers:
            pid = p.get("value", "")
            if pid and pid not in seen:
                seen.add(pid)
                key = models.get_api_key(pid)
                api_key_status[pid] = bool(key and key.strip() and key != "None")
        return {
            "config": config,
            "chat_providers": chat_providers,
            "embedding_providers": embedding_providers,
            "api_key_status": api_key_status,
        }
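For reference, the response shape this handler produces. Values are illustrative; the value/label fields on provider entries are inferred from the p.get("value") and p.get("label") accesses above, not from a schema in this diff:

example_response = {
    "config": {"chat_model": {"provider": "openrouter", "name": "anthropic/claude-sonnet-4.6"}},
    "chat_providers": [{"value": "openrouter", "label": "OpenRouter"}],
    "embedding_providers": [{"value": "huggingface", "label": "HuggingFace"}],
    "api_key_status": {"openrouter": True, "huggingface": False},
}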


@@ -0,0 +1,35 @@
from helpers.api import ApiHandler, Request, Response
from helpers import plugins, defer
from helpers.extension import call_extensions_async


class ModelConfigSet(ApiHandler):

    async def process(self, input: dict, request: Request) -> dict | Response:
        project_name = input.get("project_name", "")
        agent_profile = input.get("agent_profile", "")
        config = input.get("config")
        if not config or not isinstance(config, dict):
            return Response(status=400, response="Missing or invalid config")
        # Snapshot the previous config *before* saving; reading it back after
        # the save would compare the new config against itself and the change
        # notification could never fire
        prev_config = plugins.get_plugin_config("_model_config") or {}
        plugins.save_plugin_config(
            "_model_config",
            project_name=project_name,
            agent_profile=agent_profile,
            settings=config,
        )
        # Check if embedding model changed and notify
        prev_embed = prev_config.get("embedding_model", {})
        new_embed = config.get("embedding_model", {})
        if (
            prev_embed.get("provider") != new_embed.get("provider")
            or prev_embed.get("name") != new_embed.get("name")
            or prev_embed.get("kwargs") != new_embed.get("kwargs")
        ):
            defer.DeferredTask().start_task(
                call_extensions_async, "embedding_model_changed"
            )
        return {"ok": True}
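A hypothetical request body for this handler; the nested sections mirror the plugin's default config.yaml added later in this commit:

payload = {
    "project_name": "",   # "" selects the global scope
    "agent_profile": "",
    "config": {
        "chat_model": {
            "provider": "openrouter",
            "name": "anthropic/claude-sonnet-4.6",
            "ctx_length": 128000,
            "ctx_history": 0.7,
        },
        "embedding_model": {
            "provider": "huggingface",
            "name": "sentence-transformers/all-MiniLM-L6-v2",
        },
    },
}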


@@ -0,0 +1,201 @@
import httpx

from helpers.api import ApiHandler, Request, Response
from helpers.providers import get_provider_config
import models

_CLOUD_ENDPOINTS: dict[str, str] = {
    "openai": "https://api.openai.com/v1/models",
    "anthropic": "https://api.anthropic.com/v1/models",
    "groq": "https://api.groq.com/openai/v1/models",
    "deepseek": "https://api.deepseek.com/models",
    "mistral": "https://api.mistral.ai/v1/models",
    "openrouter": "https://openrouter.ai/api/v1/models",
    "xai": "https://api.x.ai/v1/models",
    "sambanova": "https://api.sambanova.ai/v1/models",
    "moonshot": "https://api.moonshot.cn/v1/models",
    "google": "https://generativelanguage.googleapis.com/v1beta/models",
    "a0_venice": "https://api.venice.ai/api/v1/models",
    "venice": "https://api.venice.ai/api/v1/models",
}

# Local providers with default base URLs (no auth required).
_LOCAL_DEFAULTS: dict[str, str] = {
    "ollama": "http://host.docker.internal:11434",
    "lm_studio": "http://host.docker.internal:1234",
}

# Providers with hardcoded model lists (no listing API available).
_STATIC_MODELS: dict[str, list[str]] = {
    "github_copilot": [
        "gpt-4.1", "gpt-4o", "gpt-5-mini", "oswe-vscode-prime",
    ],
    "zai": [
        "glm-4-plus", "glm-4-air-250414", "glm-4-airx",
        "glm-4-long", "glm-4-flashx", "glm-4-flash-250414",
        "glm-4v-plus", "glm-4v", "glm-3-turbo",
    ],
    "zai_coding": [
        "codegeex-4",
        "glm-4-plus", "glm-4-air-250414", "glm-4-airx",
        "glm-4-flashx", "glm-4-flash-250414",
    ],
}

# Model name substrings to exclude from litellm fallback results
_LITELLM_EXCLUDE = frozenset({
    "dall-e", "gpt-image", "tts", "whisper", "audio",
    "realtime", "davinci", "babbage", "ada", "vision-preview",
})


class ModelSearch(ApiHandler):

    async def process(self, input: dict, request: Request) -> dict | Response:
        provider = input.get("provider", "")
        query = input.get("query", "").lower()
        model_type = input.get("model_type", "chat")
        user_api_base = input.get("api_base", "")
        if not provider:
            return {"models": []}
        if provider in _STATIC_MODELS:
            all_models = list(_STATIC_MODELS[provider])
        else:
            provider_cfg = get_provider_config(model_type, provider)
            all_models = await self._fetch_models(provider, provider_cfg, user_api_base) or []
            if not all_models:
                litellm_provider = (provider_cfg or {}).get("litellm_provider", provider)
                if litellm_provider == provider:
                    all_models = self._litellm_fallback(provider, provider_cfg)
        if query:
            all_models = [m for m in all_models if query in m.lower()]
        return {"models": sorted(all_models)[:50], "provider": provider}

    async def _fetch_models(self, provider: str, cfg: dict | None, user_api_base: str = "") -> list[str] | None:
        api_base = user_api_base or (cfg or {}).get("kwargs", {}).get("api_base", "")
        api_key = models.get_api_key(provider)
        url, fmt = self._resolve_url(provider, api_base)
        if not url:
            return None
        headers = self._build_headers(provider, api_key, cfg)
        params: dict[str, str] = {}
        if provider == "google":
            if api_key and api_key != "None":
                params["key"] = api_key
            params["pageSize"] = "1000"
        elif provider == "anthropic":
            params["limit"] = "1000"
        elif provider == "azure":
            params["api-version"] = "2024-10-21"
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.get(url, headers=headers, params=params)
                if resp.status_code == 200:
                    result = self._parse(resp.json(), fmt)
                    if result:
                        return result
        except Exception:
            pass
        return None

    def _resolve_url(self, provider: str, api_base: str) -> tuple[str | None, str]:
        if provider == "ollama":
            base = api_base or _LOCAL_DEFAULTS.get("ollama", "")
            return (base.rstrip("/") + "/api/tags" if base else None), "ollama"
        if provider == "google":
            if api_base:
                return api_base.rstrip("/") + "/models", "google"
            return _CLOUD_ENDPOINTS["google"], "google"
        if provider == "azure":
            if not api_base:
                return None, "openai"
            return api_base.rstrip("/") + "/openai/models", "openai"
        if provider in _CLOUD_ENDPOINTS:
            return _CLOUD_ENDPOINTS[provider], "openai"
        if api_base:
            return api_base.rstrip("/") + "/models", "openai"
        if provider in _LOCAL_DEFAULTS:
            return _LOCAL_DEFAULTS[provider] + "/v1/models", "openai"
        return None, "openai"

    def _build_headers(self, provider: str, api_key: str, cfg: dict | None) -> dict[str, str]:
        headers: dict[str, str] = {}
        has_key = api_key and api_key != "None"
        if provider == "anthropic":
            if has_key:
                headers["x-api-key"] = api_key
            headers["anthropic-version"] = "2023-06-01"
        elif provider == "google":
            pass
        elif provider == "azure":
            if has_key:
                headers["api-key"] = api_key
        elif provider not in ("ollama", "lm_studio"):
            if has_key:
                headers["Authorization"] = f"Bearer {api_key}"
        extra = (cfg or {}).get("kwargs", {}).get("extra_headers", {})
        if isinstance(extra, dict):
            for k, v in extra.items():
                if isinstance(v, str):
                    headers[k] = v
        return headers

    def _litellm_fallback(self, provider: str, cfg: dict | None) -> list[str]:
        try:
            import litellm

            registry = getattr(litellm, "models_by_provider", None)
            if not registry:
                return []
            litellm_provider = (cfg or {}).get("litellm_provider", provider)
            raw_models: set = registry.get(litellm_provider, set())
            if not raw_models:
                return []
            prefix = litellm_provider + "/"
            result: list[str] = []
            for name in raw_models:
                clean = name[len(prefix):] if name.startswith(prefix) else name
                low = clean.lower()
                if any(exc in low for exc in _LITELLM_EXCLUDE):
                    continue
                if clean:
                    result.append(clean)
            return result
        except Exception:
            return []

    def _parse(self, data: dict | list, fmt: str) -> list[str]:
        if fmt == "ollama":
            return [m.get("name", "") for m in data.get("models", []) if m.get("name")]
        if fmt == "google":
            result = []
            for m in data.get("models", []):
                name = m.get("name", "")
                if name.startswith("models/"):
                    name = name[7:]
                if name:
                    result.append(name)
            return result
        if isinstance(data, dict) and "data" in data:
            return [m.get("id", "") for m in data["data"] if m.get("id")]
        return []
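A quick, network-free illustration of the URL routing in _resolve_url; these expectations follow directly from its branches (the __new__ trick just skips ApiHandler's constructor, whose signature is outside this diff):

ms = ModelSearch.__new__(ModelSearch)  # _resolve_url uses no instance state

assert ms._resolve_url("openrouter", "") == ("https://openrouter.ai/api/v1/models", "openai")
assert ms._resolve_url("ollama", "") == ("http://host.docker.internal:11434/api/tags", "ollama")
assert ms._resolve_url("azure", "") == (None, "openai")  # azure requires an api_base
assert ms._resolve_url("custom", "https://llm.internal/v1") == ("https://llm.internal/v1/models", "openai")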


@@ -0,0 +1,35 @@
chat_model:
  provider: "openrouter"
  name: "anthropic/claude-sonnet-4.6"
  api_base: ""
  ctx_length: 128000
  ctx_history: 0.7
  vision: true
  rl_requests: 0
  rl_input: 0
  rl_output: 0
  kwargs: {}
  allow_chat_override: false

utility_model:
  provider: "openrouter"
  name: "google/gemini-3-flash-preview"
  api_base: ""
  ctx_length: 128000
  ctx_input: 0.7
  rl_requests: 0
  rl_input: 0
  rl_output: 0
  kwargs: {}

embedding_model:
  provider: "huggingface"
  name: "sentence-transformers/all-MiniLM-L6-v2"
  api_base: ""
  rl_requests: 0
  rl_input: 0
  kwargs: {}

browser_http_headers: {}

model_presets: []


@@ -0,0 +1,8 @@
from helpers.extension import Extension
from plugins._model_config.helpers.model_config import build_browser_model


class BrowserModelProvider(Extension):

    def execute(self, data: dict = {}, **kwargs):
        if self.agent:
            data["result"] = build_browser_model(self.agent)


@@ -0,0 +1,8 @@
from helpers.extension import Extension
from plugins._model_config.helpers.model_config import build_chat_model


class ChatModelProvider(Extension):

    def execute(self, data: dict = {}, **kwargs):
        if self.agent:
            data["result"] = build_chat_model(self.agent)


@@ -0,0 +1,8 @@
from helpers.extension import Extension
from plugins._model_config.helpers.model_config import build_embedding_model


class EmbeddingModelProvider(Extension):

    def execute(self, data: dict = {}, **kwargs):
        if self.agent:
            data["result"] = build_embedding_model(self.agent)


@@ -0,0 +1,8 @@
from helpers.extension import Extension
from plugins._model_config.helpers.model_config import build_utility_model


class UtilityModelProvider(Extension):

    def execute(self, data: dict = {}, **kwargs):
        if self.agent:
            data["result"] = build_utility_model(self.agent)


@@ -0,0 +1,102 @@
import json
import os

from helpers.extension import Extension
from helpers import settings as settings_helper, files, plugins
from helpers.print_style import PrintStyle


class MigrateModelConfig(Extension):
    """
    One-time migration: copy legacy model settings into _model_config plugin config.
    Runs during initialize_migration. Only migrates if no global plugin config exists yet
    and the settings file contains legacy model fields.
    """

    LEGACY_FIELDS = [
        "chat_model_provider", "chat_model_name", "chat_model_api_base",
        "chat_model_kwargs", "chat_model_ctx_length", "chat_model_vision",
        "chat_model_rl_requests", "chat_model_rl_input", "chat_model_rl_output",
        "chat_model_ctx_history",
        "util_model_provider", "util_model_name", "util_model_api_base",
        "util_model_kwargs", "util_model_ctx_length",
        "util_model_rl_requests", "util_model_rl_input", "util_model_rl_output",
        "util_model_ctx_input",
        "embed_model_provider", "embed_model_name", "embed_model_api_base",
        "embed_model_kwargs", "embed_model_rl_requests", "embed_model_rl_input",
        "browser_model_provider", "browser_model_name", "browser_model_api_base",
        "browser_model_vision", "browser_model_rl_requests", "browser_model_rl_input",
        "browser_model_rl_output", "browser_model_kwargs", "browser_http_headers",
    ]

    async def execute(self, **kwargs):
        # Check if global plugin config already exists
        global_config_path = files.get_abs_path("plugins/_model_config/config.json")
        if os.path.exists(global_config_path):
            return  # already migrated or manually configured
        # Read raw settings file to check for legacy model fields
        settings_file = files.get_abs_path("usr/settings.json")
        if not os.path.exists(settings_file):
            return
        try:
            raw = json.loads(files.read_file(settings_file))
        except Exception:
            return
        # Check if any legacy model field exists in the raw settings
        has_legacy = any(field in raw for field in self.LEGACY_FIELDS)
        if not has_legacy:
            return
        # Build plugin config from legacy settings
        plugin_config = {
            "chat_model": {
                "provider": raw.get("chat_model_provider", "openrouter"),
                "name": raw.get("chat_model_name", ""),
                "api_base": raw.get("chat_model_api_base", ""),
                "ctx_length": raw.get("chat_model_ctx_length", 128000),
                "ctx_history": raw.get("chat_model_ctx_history", 0.7),
                "vision": raw.get("chat_model_vision", True),
                "rl_requests": raw.get("chat_model_rl_requests", 0),
                "rl_input": raw.get("chat_model_rl_input", 0),
                "rl_output": raw.get("chat_model_rl_output", 0),
                "kwargs": raw.get("chat_model_kwargs", {}),
                "allow_chat_override": False,
            },
            "utility_model": {
                "provider": raw.get("util_model_provider", "openrouter"),
                "name": raw.get("util_model_name", ""),
                "api_base": raw.get("util_model_api_base", ""),
                "ctx_length": raw.get("util_model_ctx_length", 128000),
                "ctx_input": raw.get("util_model_ctx_input", 0.7),
                "rl_requests": raw.get("util_model_rl_requests", 0),
                "rl_input": raw.get("util_model_rl_input", 0),
                "rl_output": raw.get("util_model_rl_output", 0),
                "kwargs": raw.get("util_model_kwargs", {}),
            },
            "embedding_model": {
                "provider": raw.get("embed_model_provider", "huggingface"),
                "name": raw.get("embed_model_name", "sentence-transformers/all-MiniLM-L6-v2"),
                "api_base": raw.get("embed_model_api_base", ""),
                "rl_requests": raw.get("embed_model_rl_requests", 0),
                "rl_input": raw.get("embed_model_rl_input", 0),
                "kwargs": raw.get("embed_model_kwargs", {}),
            },
            "browser_http_headers": raw.get("browser_http_headers", {}),
        }
        # Ensure kwargs are dicts (might be strings from .env format)
        for section in ["chat_model", "utility_model", "embedding_model"]:
            kw = plugin_config[section].get("kwargs")
            if isinstance(kw, str):
                plugin_config[section]["kwargs"] = {}
        if isinstance(plugin_config["browser_http_headers"], str):
            plugin_config["browser_http_headers"] = {}
        # Save as global plugin config
        plugins.save_plugin_config("_model_config", "", "", plugin_config)
        PrintStyle(background_color="#6734C3", font_color="white", padding=True).print(
            "Migrated legacy model settings to _model_config plugin config."
        )


@@ -0,0 +1,217 @@
import models
from helpers import plugins, settings, projects
from helpers.providers import get_providers, get_raw_providers


def get_config(agent=None, project_name=None, agent_profile=None):
    """Get the full model config dict for the given agent/scope."""
    return plugins.get_plugin_config(
        "_model_config",
        agent=agent,
        project_name=project_name,
        agent_profile=agent_profile,
    ) or {}


def get_presets(agent=None, project_name=None, agent_profile=None) -> list:
    """Get model presets list from config."""
    cfg = get_config(agent, project_name, agent_profile)
    return cfg.get("model_presets", [])


def get_preset_by_name(name: str, agent=None) -> dict | None:
    """Find a preset by name."""
    for p in get_presets(agent):
        if p.get("name") == name:
            return p
    return None


def _resolve_override(agent) -> dict | None:
    """Resolve the active per-chat override config dict.
    Supports both raw override dicts and preset-based overrides.
    Returns None if no override is active."""
    if not agent:
        return None
    override = agent.context.get_data("chat_model_override")
    if not override:
        return None
    # If this is a preset reference, resolve it
    if "preset_name" in override:
        preset = get_preset_by_name(override["preset_name"], agent)
        if not preset:
            return None
        return preset
    return override


def get_chat_model_config(agent=None) -> dict:
    """Get chat model config, with per-chat override if active."""
    override = _resolve_override(agent)
    if override:
        # Preset has a nested 'chat' key; raw override is flat
        chat_cfg = override.get("chat", override)
        if chat_cfg.get("provider") or chat_cfg.get("name"):
            return chat_cfg
    cfg = get_config(agent)
    return cfg.get("chat_model", {})


def get_utility_model_config(agent=None) -> dict:
    """Get utility model config, with per-chat override if active."""
    override = _resolve_override(agent)
    if override:
        util_cfg = override.get("utility", {})
        if util_cfg.get("provider") or util_cfg.get("name"):
            return util_cfg
    cfg = get_config(agent)
    return cfg.get("utility_model", {})


def get_embedding_model_config(agent=None) -> dict:
    """Get embedding model config."""
    cfg = get_config(agent)
    return cfg.get("embedding_model", {})


def get_browser_http_headers(agent=None) -> dict:
    """Get browser HTTP headers from config."""
    cfg = get_config(agent)
    return cfg.get("browser_http_headers", {})


def is_chat_override_allowed(agent=None) -> bool:
    """Check if per-chat model override is enabled."""
    cfg = get_config(agent)
    chat_cfg = cfg.get("chat_model", {})
    return bool(chat_cfg.get("allow_chat_override", False))


def get_ctx_history(agent=None) -> float:
    """Get the chat model context history ratio."""
    cfg = get_chat_model_config(agent)
    return float(cfg.get("ctx_history", 0.7))


def get_ctx_input(agent=None) -> float:
    """Get the utility model context input ratio."""
    cfg = get_utility_model_config(agent)
    return float(cfg.get("ctx_input", 0.7))


def _normalize_kwargs(kwargs: dict) -> dict:
    """Convert string values that are valid numbers to numeric types."""
    result = {}
    for key, value in kwargs.items():
        if isinstance(value, str):
            try:
                result[key] = int(value)
            except ValueError:
                try:
                    result[key] = float(value)
                except ValueError:
                    result[key] = value
        else:
            result[key] = value
    return result


def build_model_config(cfg: dict, model_type: models.ModelType) -> models.ModelConfig:
    """Build a ModelConfig from a config dict section."""
    return models.ModelConfig(
        type=model_type,
        provider=cfg.get("provider", ""),
        name=cfg.get("name", ""),
        api_base=cfg.get("api_base", ""),
        ctx_length=int(cfg.get("ctx_length", 0)),
        vision=bool(cfg.get("vision", False)),
        limit_requests=int(cfg.get("rl_requests", 0)),
        limit_input=int(cfg.get("rl_input", 0)),
        limit_output=int(cfg.get("rl_output", 0)),
        kwargs=_normalize_kwargs(cfg.get("kwargs", {})),
    )


def build_chat_model(agent=None):
    """Build and return a LiteLLMChatWrapper from config."""
    cfg = get_chat_model_config(agent)
    mc = build_model_config(cfg, models.ModelType.CHAT)
    return models.get_chat_model(
        mc.provider, mc.name, model_config=mc, **mc.build_kwargs()
    )


def build_utility_model(agent=None):
    """Build and return a LiteLLMChatWrapper for utility tasks."""
    cfg = get_utility_model_config(agent)
    mc = build_model_config(cfg, models.ModelType.CHAT)
    return models.get_chat_model(
        mc.provider, mc.name, model_config=mc, **mc.build_kwargs()
    )


def build_browser_model(agent=None):
    """Build and return a BrowserCompatibleChatWrapper using chat model config."""
    cfg = get_chat_model_config(agent)
    mc = build_model_config(cfg, models.ModelType.CHAT)
    return models.get_browser_model(
        mc.provider, mc.name, model_config=mc, **mc.build_kwargs()
    )


def build_embedding_model(agent=None):
    """Build and return an embedding model wrapper."""
    cfg = get_embedding_model_config(agent)
    mc = build_model_config(cfg, models.ModelType.EMBEDDING)
    return models.get_embedding_model(
        mc.provider, mc.name, model_config=mc, **mc.build_kwargs()
    )


def get_embedding_model_config_object(agent=None) -> models.ModelConfig:
    """Get a ModelConfig object for embeddings (needed by memory plugin)."""
    cfg = get_embedding_model_config(agent)
    return build_model_config(cfg, models.ModelType.EMBEDDING)


def get_chat_providers():
    """Get list of chat providers for UI dropdowns."""
    return get_providers("chat")


def get_embedding_providers():
    """Get list of embedding providers for UI dropdowns."""
    return get_providers("embedding")


def get_missing_api_key_providers(agent=None) -> list[dict]:
    """Check which configured providers are missing API keys."""
    cfg = get_config(agent)
    missing = []
    LOCAL_PROVIDERS = {"ollama", "lm_studio"}
    LOCAL_EMBEDDING = {"huggingface"}
    checks = [
        ("Chat Model", cfg.get("chat_model", {})),
        ("Utility Model", cfg.get("utility_model", {})),
        ("Embedding Model", cfg.get("embedding_model", {})),
    ]
    for label, model_cfg in checks:
        provider = model_cfg.get("provider", "")
        if not provider:
            continue
        provider_lower = provider.lower()
        if provider_lower in LOCAL_PROVIDERS:
            continue
        if label == "Embedding Model" and provider_lower in LOCAL_EMBEDDING:
            continue
        api_key = models.get_api_key(provider_lower)
        if not (api_key and api_key.strip() and api_key != "None"):
            missing.append({"model_type": label, "provider": provider})
    return missing
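A hedged sketch of the per-chat override flow implied by _resolve_override: some caller stores either a raw config dict or a {"preset_name": ...} reference on the agent context. The context.set_data call is assumed symmetrical to the context.get_data call above; it does not appear in this diff:

def enable_preset_override(agent, preset_name: str) -> bool:
    # Gate on the allow_chat_override flag from the chat_model section
    if not is_chat_override_allowed(agent):
        return False
    if not get_preset_by_name(preset_name, agent):
        return False  # unknown preset; _resolve_override would ignore it anyway
    agent.context.set_data("chat_model_override", {"preset_name": preset_name})  # assumed setter
    return True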


@@ -0,0 +1,8 @@
def save_plugin_config(result=None, settings=None, **kwargs):
    if settings and isinstance(settings, dict):
        # Remove transient UI-only fields before persisting
        settings.pop("_browser_headers_text", None)
        for section in ("chat_model", "utility_model", "embedding_model"):
            if section in settings and isinstance(settings[section], dict):
                settings[section].pop("_kwargs_text", None)
    return settings


@@ -0,0 +1,9 @@
name: _model_config
title: Model Configuration
description: Manages LLM model selection and configuration for chat, utility, and embedding models. Supports per-project and per-agent overrides with optional per-chat model switching.
version: 1.0.0
always_enabled: true
settings_sections:
  - agent
per_project_config: true
per_agent_config: true


@@ -18,16 +18,17 @@ async def preload():
     # preload embedding model
     async def preload_embedding():
-        if set["embed_model_provider"].lower() == "huggingface":
-            try:
+        # Use the new LiteLLM-based model system
+        try:
+            from plugins._model_config.helpers.model_config import get_embedding_model_config_object
+            emb_cfg = get_embedding_model_config_object()
+            if emb_cfg.provider.lower() == "huggingface":
                 emb_mod = models.get_embedding_model(
-                    "huggingface", set["embed_model_name"]
+                    "huggingface", emb_cfg.name
                 )
                 emb_txt = await emb_mod.aembed_query("test")
                 return emb_txt
-            except Exception as e:
-                PrintStyle().error(f"Error in preload_embedding: {e}")
+        except Exception as e:
+            PrintStyle().error(f"Error in preload_embedding: {e}")
# preload kokoro tts model if enabled
async def preload_kokoro():