diff --git a/agent.py b/agent.py index 4d0052848..15f40bf40 100644 --- a/agent.py +++ b/agent.py @@ -309,16 +309,9 @@ class AgentContext: @dataclass class AgentConfig: - chat_model: models.ModelConfig - utility_model: models.ModelConfig - embeddings_model: models.ModelConfig - browser_model: models.ModelConfig mcp_servers: str profile: str = "" knowledge_subdirs: list[str] = field(default_factory=lambda: ["default", "custom"]) - browser_http_headers: dict[str, str] = field( - default_factory=dict - ) # Custom HTTP headers for browser requests additional: Dict[str, Any] = field(default_factory=dict) @@ -713,39 +706,19 @@ class Agent: @extension.extensible def get_chat_model(self): - return models.get_chat_model( - self.config.chat_model.provider, - self.config.chat_model.name, - model_config=self.config.chat_model, - **self.config.chat_model.build_kwargs(), - ) + return None @extension.extensible def get_utility_model(self): - return models.get_chat_model( - self.config.utility_model.provider, - self.config.utility_model.name, - model_config=self.config.utility_model, - **self.config.utility_model.build_kwargs(), - ) + return None @extension.extensible def get_browser_model(self): - return models.get_browser_model( - self.config.browser_model.provider, - self.config.browser_model.name, - model_config=self.config.browser_model, - **self.config.browser_model.build_kwargs(), - ) + return None @extension.extensible def get_embedding_model(self): - return models.get_embedding_model( - self.config.embeddings_model.provider, - self.config.embeddings_model.name, - model_config=self.config.embeddings_model, - **self.config.embeddings_model.build_kwargs(), - ) + return None @extension.extensible async def call_utility_model( @@ -803,15 +776,32 @@ class Agent: # model class model = self.get_chat_model() + # call extensions before + call_data = { + "model": model, + "messages": messages, + "response_callback": response_callback, + "reasoning_callback": reasoning_callback, + "background": background, + "explicit_caching": explicit_caching, + } + await extension.call_extensions_async( + "chat_model_call_before", self, call_data=call_data + ) + # call model - response, reasoning = await model.unified_call( - messages=messages, - reasoning_callback=reasoning_callback, - response_callback=response_callback, + response, reasoning = await call_data["model"].unified_call( + messages=call_data["messages"], + reasoning_callback=call_data["reasoning_callback"], + response_callback=call_data["response_callback"], rate_limiter_callback=( - self.rate_limiter_callback if not background else None + self.rate_limiter_callback if not call_data["background"] else None ), - explicit_caching=explicit_caching, + explicit_caching=call_data["explicit_caching"], + ) + + await extension.call_extensions_async( + "chat_model_call_after", self, call_data=call_data, response=response, reasoning=reasoning ) return response, reasoning diff --git a/helpers/history.py b/helpers/history.py index be51f19f3..16b57f92a 100644 --- a/helpers/history.py +++ b/helpers/history.py @@ -159,10 +159,13 @@ class Topic(Record): return self.summary def compress_large_messages(self, message_ratio: float = CURRENT_TOPIC_RATIO * LARGE_MESSAGE_TO_CURRENT_TOPIC_RATIO) -> bool: - set = settings.get_settings() + from plugins._model_config.helpers.model_config import get_chat_model_config + chat_cfg = get_chat_model_config() + ctx_length = int(chat_cfg.get("ctx_length", 128000)) + ctx_history = float(chat_cfg.get("ctx_history", 0.7)) msg_max_size = ( - 
set["chat_model_ctx_length"] - * set["chat_model_ctx_history"] + ctx_length + * ctx_history * message_ratio ) large_msgs = [] @@ -479,8 +482,11 @@ def deserialize_history(json_data: str, agent) -> History: def _get_ctx_size_for_history() -> int: - set = settings.get_settings() - return int(set["chat_model_ctx_length"] * set["chat_model_ctx_history"]) + from plugins._model_config.helpers.model_config import get_chat_model_config + chat_cfg = get_chat_model_config() + ctx_length = int(chat_cfg.get("ctx_length", 128000)) + ctx_history = float(chat_cfg.get("ctx_history", 0.7)) + return int(ctx_length * ctx_history) def _stringify_output(output: OutputMessage, ai_label="ai", human_label="human"): diff --git a/helpers/settings.py b/helpers/settings.py index 64f34a552..e85a741c2 100644 --- a/helpers/settings.py +++ b/helpers/settings.py @@ -53,44 +53,6 @@ def get_default_value(name: str, value: T) -> T: class Settings(TypedDict): version: str - chat_model_provider: str - chat_model_name: str - chat_model_api_base: str - chat_model_kwargs: dict[str, Any] - chat_model_ctx_length: int - chat_model_ctx_history: float - chat_model_vision: bool - chat_model_rl_requests: int - chat_model_rl_input: int - chat_model_rl_output: int - - util_model_provider: str - util_model_name: str - util_model_api_base: str - util_model_kwargs: dict[str, Any] - util_model_ctx_length: int - util_model_ctx_input: float - util_model_rl_requests: int - util_model_rl_input: int - util_model_rl_output: int - - embed_model_provider: str - embed_model_name: str - embed_model_api_base: str - embed_model_kwargs: dict[str, Any] - embed_model_rl_requests: int - embed_model_rl_input: int - - browser_model_provider: str - browser_model_name: str - browser_model_api_base: str - browser_model_vision: bool - browser_model_rl_requests: int - browser_model_rl_input: int - browser_model_rl_output: int - browser_model_kwargs: dict[str, Any] - browser_http_headers: dict[str, Any] - agent_profile: str agent_knowledge_subdir: str @@ -261,10 +223,6 @@ def convert_out(settings: Settings) -> SettingsOutput: ), } - additional["chat_providers"] = _ensure_option_present(additional.get("chat_providers"), current.get("chat_model_provider")) - additional["chat_providers"] = _ensure_option_present(additional.get("chat_providers"), current.get("util_model_provider")) - additional["chat_providers"] = _ensure_option_present(additional.get("chat_providers"), current.get("browser_model_provider")) - additional["embedding_providers"] = _ensure_option_present(additional.get("embedding_providers"), current.get("embed_model_provider")) additional["agent_subdirs"] = _ensure_option_present(additional.get("agent_subdirs"), current.get("agent_profile")) additional["knowledge_subdirs"] = _ensure_option_present(additional.get("knowledge_subdirs"), current.get("agent_knowledge_subdir")) additional["stt_models"] = _ensure_option_present(additional.get("stt_models"), current.get("stt_model_size")) @@ -493,40 +451,6 @@ def get_default_settings() -> Settings: gitignore = files.read_file(files.get_abs_path("conf/workdir.gitignore")) return Settings( version=_get_version(), - chat_model_provider=get_default_value("chat_model_provider", "openrouter"), - chat_model_name=get_default_value("chat_model_name", "anthropic/claude-sonnet-4.6"), - chat_model_api_base=get_default_value("chat_model_api_base", ""), - chat_model_kwargs=get_default_value("chat_model_kwargs", {}), - chat_model_ctx_length=get_default_value("chat_model_ctx_length", 100000), - 
chat_model_ctx_history=get_default_value("chat_model_ctx_history", 0.7), - chat_model_vision=get_default_value("chat_model_vision", True), - chat_model_rl_requests=get_default_value("chat_model_rl_requests", 0), - chat_model_rl_input=get_default_value("chat_model_rl_input", 0), - chat_model_rl_output=get_default_value("chat_model_rl_output", 0), - util_model_provider=get_default_value("util_model_provider", "openrouter"), - util_model_name=get_default_value("util_model_name", "google/gemini-3-flash-preview"), - util_model_api_base=get_default_value("util_model_api_base", ""), - util_model_ctx_length=get_default_value("util_model_ctx_length", 100000), - util_model_ctx_input=get_default_value("util_model_ctx_input", 0.7), - util_model_kwargs=get_default_value("util_model_kwargs", {}), - util_model_rl_requests=get_default_value("util_model_rl_requests", 0), - util_model_rl_input=get_default_value("util_model_rl_input", 0), - util_model_rl_output=get_default_value("util_model_rl_output", 0), - embed_model_provider=get_default_value("embed_model_provider", "huggingface"), - embed_model_name=get_default_value("embed_model_name", "sentence-transformers/all-MiniLM-L6-v2"), - embed_model_api_base=get_default_value("embed_model_api_base", ""), - embed_model_kwargs=get_default_value("embed_model_kwargs", {}), - embed_model_rl_requests=get_default_value("embed_model_rl_requests", 0), - embed_model_rl_input=get_default_value("embed_model_rl_input", 0), - browser_model_provider=get_default_value("browser_model_provider", "openrouter"), - browser_model_name=get_default_value("browser_model_name", "anthropic/claude-sonnet-4.6"), - browser_model_api_base=get_default_value("browser_model_api_base", ""), - browser_model_vision=get_default_value("browser_model_vision", True), - browser_model_rl_requests=get_default_value("browser_model_rl_requests", 0), - browser_model_rl_input=get_default_value("browser_model_rl_input", 0), - browser_model_rl_output=get_default_value("browser_model_rl_output", 0), - browser_model_kwargs=get_default_value("browser_model_kwargs", {}), - browser_http_headers=get_default_value("browser_http_headers", {}), api_keys={}, auth_login="", auth_password="", @@ -587,18 +511,6 @@ def _apply_settings(previous: Settings | None): whisper.preload, _settings["stt_model_size"] ) # TODO overkill, replace with background task - # notify plugins of embedding model change - if not previous or ( - _settings["embed_model_name"] != previous["embed_model_name"] - or _settings["embed_model_provider"] != previous["embed_model_provider"] - or _settings["embed_model_kwargs"] != previous["embed_model_kwargs"] - ): - from helpers.extension import call_extensions_async - - defer.DeferredTask().start_task( - call_extensions_async, "embedding_model_changed" - ) - # update mcp settings if necessary if not previous or _settings["mcp_servers"] != previous["mcp_servers"]: from helpers.mcp_handler import MCPConfig diff --git a/initialize.py b/initialize.py index 8cc2d1790..dbb773f85 100644 --- a/initialize.py +++ b/initialize.py @@ -1,5 +1,4 @@ from agent import AgentConfig -import models from helpers import runtime, settings, defer, extension from helpers.print_style import PrintStyle @@ -10,78 +9,11 @@ def initialize_agent(override_settings: dict | None = None): if override_settings: current_settings = settings.merge_settings(current_settings, override_settings) - def _normalize_model_kwargs(kwargs: dict) -> dict: - # convert string values that represent valid Python numbers to numeric types - result = {} - for 
key, value in kwargs.items(): - if isinstance(value, str): - # try to convert string to number if it's a valid Python number - try: - # try int first, then float - result[key] = int(value) - except ValueError: - try: - result[key] = float(value) - except ValueError: - result[key] = value - else: - result[key] = value - return result - - # chat model from user settings - chat_llm = models.ModelConfig( - type=models.ModelType.CHAT, - provider=current_settings["chat_model_provider"], - name=current_settings["chat_model_name"], - api_base=current_settings["chat_model_api_base"], - ctx_length=current_settings["chat_model_ctx_length"], - vision=current_settings["chat_model_vision"], - limit_requests=current_settings["chat_model_rl_requests"], - limit_input=current_settings["chat_model_rl_input"], - limit_output=current_settings["chat_model_rl_output"], - kwargs=_normalize_model_kwargs(current_settings["chat_model_kwargs"]), - ) - - # utility model from user settings - utility_llm = models.ModelConfig( - type=models.ModelType.CHAT, - provider=current_settings["util_model_provider"], - name=current_settings["util_model_name"], - api_base=current_settings["util_model_api_base"], - ctx_length=current_settings["util_model_ctx_length"], - limit_requests=current_settings["util_model_rl_requests"], - limit_input=current_settings["util_model_rl_input"], - limit_output=current_settings["util_model_rl_output"], - kwargs=_normalize_model_kwargs(current_settings["util_model_kwargs"]), - ) - # embedding model from user settings - embedding_llm = models.ModelConfig( - type=models.ModelType.EMBEDDING, - provider=current_settings["embed_model_provider"], - name=current_settings["embed_model_name"], - api_base=current_settings["embed_model_api_base"], - limit_requests=current_settings["embed_model_rl_requests"], - kwargs=_normalize_model_kwargs(current_settings["embed_model_kwargs"]), - ) - # browser model from user settings - browser_llm = models.ModelConfig( - type=models.ModelType.CHAT, - provider=current_settings["browser_model_provider"], - name=current_settings["browser_model_name"], - api_base=current_settings["browser_model_api_base"], - vision=current_settings["browser_model_vision"], - kwargs=_normalize_model_kwargs(current_settings["browser_model_kwargs"]), - ) - # agent configuration + # agent configuration - models are now resolved at call time via _model_config plugin config = AgentConfig( - chat_model=chat_llm, - utility_model=utility_llm, - embeddings_model=embedding_llm, - browser_model=browser_llm, profile=current_settings["agent_profile"], knowledge_subdirs=[current_settings["agent_knowledge_subdir"], "default"], mcp_servers=current_settings["mcp_servers"], - browser_http_headers=current_settings["browser_http_headers"], ) # update config with runtime args diff --git a/plugins/_memory/helpers/memory.py b/plugins/_memory/helpers/memory.py index b0e2a4e70..43518338e 100644 --- a/plugins/_memory/helpers/memory.py +++ b/plugins/_memory/helpers/memory.py @@ -60,6 +60,11 @@ class Memory: index: dict[str, "MyFaiss"] = {} + @staticmethod + def _get_embedding_config(agent=None): + from plugins._model_config.helpers.model_config import get_embedding_model_config_object + return get_embedding_model_config_object(agent) + @staticmethod async def get(agent: Agent): memory_subdir = get_agent_memory_subdir(agent) @@ -70,7 +75,7 @@ class Memory: ) db, created = Memory.initialize( log_item, - agent.config.embeddings_model, + Memory._get_embedding_config(agent), memory_subdir, False, ) @@ -98,7 +103,7 @@ class 
Memory: import initialize agent_config = initialize.initialize_agent() - model_config = agent_config.embeddings_model + model_config = Memory._get_embedding_config() db, _created = Memory.initialize( log_item=log_item, model_config=model_config, diff --git a/plugins/_model_config/api/api_keys.py b/plugins/_model_config/api/api_keys.py new file mode 100644 index 000000000..0fe7410bd --- /dev/null +++ b/plugins/_model_config/api/api_keys.py @@ -0,0 +1,60 @@ +from helpers.api import ApiHandler, Request, Response +from helpers import dotenv +import models + +API_KEY_PLACEHOLDER = "************" + + +class ApiKeys(ApiHandler): + async def process(self, input: dict, request: Request) -> dict | Response: + action = input.get("action", "get") # get | set | reveal + + if action == "get": + return self._get_keys() + elif action == "set": + return self._set_keys(input) + elif action == "reveal": + return self._reveal_key(input) + + return Response(status=400, response=f"Unknown action: {action}") + + def _get_keys(self) -> dict: + from helpers.providers import get_providers + + providers = get_providers("chat") + get_providers("embedding") + seen = set() + keys = {} + + for p in providers: + pid = p.get("value", "") + if pid and pid not in seen: + seen.add(pid) + api_key = models.get_api_key(pid) + has_key = bool(api_key and api_key.strip() and api_key != "None") + keys[pid] = { + "label": p.get("label", pid), + "has_key": has_key, + "masked": API_KEY_PLACEHOLDER if has_key else "", + } + + return {"keys": keys} + + def _set_keys(self, input: dict) -> dict: + updates = input.get("keys", {}) + if not isinstance(updates, dict): + return {"ok": False, "error": "Invalid keys format"} + + for provider, value in updates.items(): + if isinstance(value, str) and value != API_KEY_PLACEHOLDER: + dotenv.save_dotenv_value(f"API_KEY_{provider.upper()}", value) + + return {"ok": True} + + def _reveal_key(self, input: dict) -> dict: + provider = input.get("provider", "") + if not provider: + return {"ok": False, "error": "Missing provider"} + api_key = models.get_api_key(provider) + if api_key and api_key.strip() and api_key != "None": + return {"ok": True, "value": api_key} + return {"ok": True, "value": ""} diff --git a/plugins/_model_config/api/model_config_get.py b/plugins/_model_config/api/model_config_get.py new file mode 100644 index 000000000..3350d0d0f --- /dev/null +++ b/plugins/_model_config/api/model_config_get.py @@ -0,0 +1,41 @@ +from helpers.api import ApiHandler, Request, Response +from helpers import plugins +from plugins._model_config.helpers import model_config +import models + + +class ModelConfigGet(ApiHandler): + async def process(self, input: dict, request: Request) -> dict | Response: + project_name = input.get("project_name", "") + agent_profile = input.get("agent_profile", "") + + config = model_config.get_config( + project_name=project_name or None, + agent_profile=agent_profile or None, + ) + + # Provide default if no config found + if not config: + config = plugins.get_default_plugin_config("_model_config") or {} + + # Add provider lists for UI dropdowns + chat_providers = model_config.get_chat_providers() + embedding_providers = model_config.get_embedding_providers() + + # Mask API keys - show status only + api_key_status = {} + all_providers = chat_providers + embedding_providers + seen = set() + for p in all_providers: + pid = p.get("value", "") + if pid and pid not in seen: + seen.add(pid) + key = models.get_api_key(pid) + api_key_status[pid] = bool(key and key.strip() and key != 
"None") + + return { + "config": config, + "chat_providers": chat_providers, + "embedding_providers": embedding_providers, + "api_key_status": api_key_status, + } diff --git a/plugins/_model_config/api/model_config_set.py b/plugins/_model_config/api/model_config_set.py new file mode 100644 index 000000000..f8d55e150 --- /dev/null +++ b/plugins/_model_config/api/model_config_set.py @@ -0,0 +1,35 @@ +from helpers.api import ApiHandler, Request, Response +from helpers import plugins, defer +from helpers.extension import call_extensions_async + + +class ModelConfigSet(ApiHandler): + async def process(self, input: dict, request: Request) -> dict | Response: + project_name = input.get("project_name", "") + agent_profile = input.get("agent_profile", "") + config = input.get("config") + + if not config or not isinstance(config, dict): + return Response(status=400, response="Missing or invalid config") + + plugins.save_plugin_config( + "_model_config", + project_name=project_name, + agent_profile=agent_profile, + settings=config, + ) + + # Check if embedding model changed and notify + prev_config = plugins.get_plugin_config("_model_config") or {} + prev_embed = prev_config.get("embedding_model", {}) + new_embed = config.get("embedding_model", {}) + if ( + prev_embed.get("provider") != new_embed.get("provider") + or prev_embed.get("name") != new_embed.get("name") + or prev_embed.get("kwargs") != new_embed.get("kwargs") + ): + defer.DeferredTask().start_task( + call_extensions_async, "embedding_model_changed" + ) + + return {"ok": True} diff --git a/plugins/_model_config/api/model_search.py b/plugins/_model_config/api/model_search.py new file mode 100644 index 000000000..9c9a3d4e0 --- /dev/null +++ b/plugins/_model_config/api/model_search.py @@ -0,0 +1,201 @@ +import httpx +from helpers.api import ApiHandler, Request, Response +from helpers.providers import get_provider_config +import models + +_CLOUD_ENDPOINTS: dict[str, str] = { + "openai": "https://api.openai.com/v1/models", + "anthropic": "https://api.anthropic.com/v1/models", + "groq": "https://api.groq.com/openai/v1/models", + "deepseek": "https://api.deepseek.com/models", + "mistral": "https://api.mistral.ai/v1/models", + "openrouter": "https://openrouter.ai/api/v1/models", + "xai": "https://api.x.ai/v1/models", + "sambanova": "https://api.sambanova.ai/v1/models", + "moonshot": "https://api.moonshot.cn/v1/models", + "google": "https://generativelanguage.googleapis.com/v1beta/models", + "a0_venice": "https://api.venice.ai/api/v1/models", + "venice": "https://api.venice.ai/api/v1/models", +} + +# Local providers with default base URLs (no auth required). +_LOCAL_DEFAULTS: dict[str, str] = { + "ollama": "http://host.docker.internal:11434", + "lm_studio": "http://host.docker.internal:1234", +} + +# Providers with hardcoded model lists (no listing API available). 
+_STATIC_MODELS: dict[str, list[str]] = { + "github_copilot": [ + "gpt-4.1", "gpt-4o", "gpt-5-mini", "oswe-vscode-prime", + ], + "zai": [ + "glm-4-plus", "glm-4-air-250414", "glm-4-airx", + "glm-4-long", "glm-4-flashx", "glm-4-flash-250414", + "glm-4v-plus", "glm-4v", "glm-3-turbo", + ], + "zai_coding": [ + "codegeex-4", + "glm-4-plus", "glm-4-air-250414", "glm-4-airx", + "glm-4-flashx", "glm-4-flash-250414", + ], +} + +# Model name substrings to exclude from litellm fallback results +_LITELLM_EXCLUDE = frozenset({ + "dall-e", "gpt-image", "tts", "whisper", "audio", + "realtime", "davinci", "babbage", "ada", "vision-preview", +}) + + +class ModelSearch(ApiHandler): + async def process(self, input: dict, request: Request) -> dict | Response: + provider = input.get("provider", "") + query = input.get("query", "").lower() + model_type = input.get("model_type", "chat") + user_api_base = input.get("api_base", "") + + if not provider: + return {"models": []} + + if provider in _STATIC_MODELS: + all_models = list(_STATIC_MODELS[provider]) + else: + provider_cfg = get_provider_config(model_type, provider) + all_models = await self._fetch_models(provider, provider_cfg, user_api_base) or [] + + if not all_models: + litellm_provider = (provider_cfg or {}).get("litellm_provider", provider) + if litellm_provider == provider: + all_models = self._litellm_fallback(provider, provider_cfg) + + if query: + all_models = [m for m in all_models if query in m.lower()] + + return {"models": sorted(all_models)[:50], "provider": provider} + + async def _fetch_models(self, provider: str, cfg: dict | None, user_api_base: str = "") -> list[str] | None: + api_base = user_api_base or (cfg or {}).get("kwargs", {}).get("api_base", "") + api_key = models.get_api_key(provider) + + url, fmt = self._resolve_url(provider, api_base) + if not url: + return None + + headers = self._build_headers(provider, api_key, cfg) + params: dict[str, str] = {} + if provider == "google": + if api_key and api_key != "None": + params["key"] = api_key + params["pageSize"] = "1000" + elif provider == "anthropic": + params["limit"] = "1000" + elif provider == "azure": + params["api-version"] = "2024-10-21" + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.get(url, headers=headers, params=params) + if resp.status_code == 200: + result = self._parse(resp.json(), fmt) + if result: + return result + except Exception: + pass + + return None + + def _resolve_url(self, provider: str, api_base: str) -> tuple[str | None, str]: + if provider == "ollama": + base = api_base or _LOCAL_DEFAULTS.get("ollama", "") + return (base.rstrip("/") + "/api/tags" if base else None), "ollama" + + if provider == "google": + if api_base: + return api_base.rstrip("/") + "/models", "google" + return _CLOUD_ENDPOINTS["google"], "google" + + if provider == "azure": + if not api_base: + return None, "openai" + return api_base.rstrip("/") + "/openai/models", "openai" + + if provider in _CLOUD_ENDPOINTS: + return _CLOUD_ENDPOINTS[provider], "openai" + + if api_base: + return api_base.rstrip("/") + "/models", "openai" + + if provider in _LOCAL_DEFAULTS: + return _LOCAL_DEFAULTS[provider] + "/v1/models", "openai" + + return None, "openai" + + def _build_headers(self, provider: str, api_key: str, cfg: dict | None) -> dict[str, str]: + headers: dict[str, str] = {} + has_key = api_key and api_key != "None" + + if provider == "anthropic": + if has_key: + headers["x-api-key"] = api_key + headers["anthropic-version"] = "2023-06-01" + elif provider == 
"google": + pass + elif provider == "azure": + if has_key: + headers["api-key"] = api_key + elif provider not in ("ollama", "lm_studio"): + if has_key: + headers["Authorization"] = f"Bearer {api_key}" + + extra = (cfg or {}).get("kwargs", {}).get("extra_headers", {}) + if isinstance(extra, dict): + for k, v in extra.items(): + if isinstance(v, str): + headers[k] = v + + return headers + + def _litellm_fallback(self, provider: str, cfg: dict | None) -> list[str]: + try: + import litellm + registry = getattr(litellm, "models_by_provider", None) + if not registry: + return [] + + litellm_provider = (cfg or {}).get("litellm_provider", provider) + raw_models: set = registry.get(litellm_provider, set()) + if not raw_models: + return [] + + prefix = litellm_provider + "/" + result: list[str] = [] + for name in raw_models: + clean = name[len(prefix):] if name.startswith(prefix) else name + low = clean.lower() + if any(exc in low for exc in _LITELLM_EXCLUDE): + continue + if clean: + result.append(clean) + return result + except Exception: + return [] + + def _parse(self, data: dict | list, fmt: str) -> list[str]: + if fmt == "ollama": + return [m.get("name", "") for m in data.get("models", []) if m.get("name")] + + if fmt == "google": + result = [] + for m in data.get("models", []): + name = m.get("name", "") + if name.startswith("models/"): + name = name[7:] + if name: + result.append(name) + return result + + if isinstance(data, dict) and "data" in data: + return [m.get("id", "") for m in data["data"] if m.get("id")] + + return [] diff --git a/plugins/_model_config/default_config.yaml b/plugins/_model_config/default_config.yaml new file mode 100644 index 000000000..5aae03841 --- /dev/null +++ b/plugins/_model_config/default_config.yaml @@ -0,0 +1,35 @@ +chat_model: + provider: "openrouter" + name: "anthropic/claude-sonnet-4.6" + api_base: "" + ctx_length: 128000 + ctx_history: 0.7 + vision: true + rl_requests: 0 + rl_input: 0 + rl_output: 0 + kwargs: {} + allow_chat_override: false + +utility_model: + provider: "openrouter" + name: "google/gemini-3-flash-preview" + api_base: "" + ctx_length: 128000 + ctx_input: 0.7 + rl_requests: 0 + rl_input: 0 + rl_output: 0 + kwargs: {} + +embedding_model: + provider: "huggingface" + name: "sentence-transformers/all-MiniLM-L6-v2" + api_base: "" + rl_requests: 0 + rl_input: 0 + kwargs: {} + +browser_http_headers: {} + +model_presets: [] diff --git a/plugins/_model_config/extensions/python/agent_Agent_get_browser_model_start/_10_model_config.py b/plugins/_model_config/extensions/python/agent_Agent_get_browser_model_start/_10_model_config.py new file mode 100644 index 000000000..f4797110b --- /dev/null +++ b/plugins/_model_config/extensions/python/agent_Agent_get_browser_model_start/_10_model_config.py @@ -0,0 +1,8 @@ +from helpers.extension import Extension +from plugins._model_config.helpers.model_config import build_browser_model + + +class BrowserModelProvider(Extension): + def execute(self, data: dict = {}, **kwargs): + if self.agent: + data["result"] = build_browser_model(self.agent) diff --git a/plugins/_model_config/extensions/python/agent_Agent_get_chat_model_start/_10_model_config.py b/plugins/_model_config/extensions/python/agent_Agent_get_chat_model_start/_10_model_config.py new file mode 100644 index 000000000..867b64448 --- /dev/null +++ b/plugins/_model_config/extensions/python/agent_Agent_get_chat_model_start/_10_model_config.py @@ -0,0 +1,8 @@ +from helpers.extension import Extension +from plugins._model_config.helpers.model_config import 
build_chat_model + + +class ChatModelProvider(Extension): + def execute(self, data: dict = {}, **kwargs): + if self.agent: + data["result"] = build_chat_model(self.agent) diff --git a/plugins/_model_config/extensions/python/agent_Agent_get_embedding_model_start/_10_model_config.py b/plugins/_model_config/extensions/python/agent_Agent_get_embedding_model_start/_10_model_config.py new file mode 100644 index 000000000..19e9a99c1 --- /dev/null +++ b/plugins/_model_config/extensions/python/agent_Agent_get_embedding_model_start/_10_model_config.py @@ -0,0 +1,8 @@ +from helpers.extension import Extension +from plugins._model_config.helpers.model_config import build_embedding_model + + +class EmbeddingModelProvider(Extension): + def execute(self, data: dict = {}, **kwargs): + if self.agent: + data["result"] = build_embedding_model(self.agent) diff --git a/plugins/_model_config/extensions/python/agent_Agent_get_utility_model_start/_10_model_config.py b/plugins/_model_config/extensions/python/agent_Agent_get_utility_model_start/_10_model_config.py new file mode 100644 index 000000000..b979ba7f8 --- /dev/null +++ b/plugins/_model_config/extensions/python/agent_Agent_get_utility_model_start/_10_model_config.py @@ -0,0 +1,8 @@ +from helpers.extension import Extension +from plugins._model_config.helpers.model_config import build_utility_model + + +class UtilityModelProvider(Extension): + def execute(self, data: dict = {}, **kwargs): + if self.agent: + data["result"] = build_utility_model(self.agent) diff --git a/plugins/_model_config/extensions/python/initialize_migration_start/_10_migrate_model_config.py b/plugins/_model_config/extensions/python/initialize_migration_start/_10_migrate_model_config.py new file mode 100644 index 000000000..2e530c8ca --- /dev/null +++ b/plugins/_model_config/extensions/python/initialize_migration_start/_10_migrate_model_config.py @@ -0,0 +1,102 @@ +import json +import os +from helpers.extension import Extension +from helpers import settings as settings_helper, files, plugins +from helpers.print_style import PrintStyle + + +class MigrateModelConfig(Extension): + """ + One-time migration: copy legacy model settings into _model_config plugin config. + Runs during initialize_migration. Only migrates if no global plugin config exists yet + and the settings file contains legacy model fields. 
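+    Fields absent from the legacy settings file fall back to sensible defaults,
+    so a partial legacy file still yields a complete plugin config.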
+ """ + + LEGACY_FIELDS = [ + "chat_model_provider", "chat_model_name", "chat_model_api_base", + "chat_model_kwargs", "chat_model_ctx_length", "chat_model_vision", + "chat_model_rl_requests", "chat_model_rl_input", "chat_model_rl_output", + "chat_model_ctx_history", + "util_model_provider", "util_model_name", "util_model_api_base", + "util_model_kwargs", "util_model_ctx_length", + "util_model_rl_requests", "util_model_rl_input", "util_model_rl_output", + "util_model_ctx_input", + "embed_model_provider", "embed_model_name", "embed_model_api_base", + "embed_model_kwargs", "embed_model_rl_requests", "embed_model_rl_input", + "browser_model_provider", "browser_model_name", "browser_model_api_base", + "browser_model_vision", "browser_model_rl_requests", "browser_model_rl_input", + "browser_model_rl_output", "browser_model_kwargs", "browser_http_headers", + ] + + async def execute(self, **kwargs): + # Check if global plugin config already exists + global_config_path = files.get_abs_path("plugins/_model_config/config.json") + if os.path.exists(global_config_path): + return # already migrated or manually configured + + # Read raw settings file to check for legacy model fields + settings_file = files.get_abs_path("usr/settings.json") + if not os.path.exists(settings_file): + return + + try: + raw = json.loads(files.read_file(settings_file)) + except Exception: + return + + # Check if any legacy model field exists in the raw settings + has_legacy = any(field in raw for field in self.LEGACY_FIELDS) + if not has_legacy: + return + + # Build plugin config from legacy settings + plugin_config = { + "chat_model": { + "provider": raw.get("chat_model_provider", "openrouter"), + "name": raw.get("chat_model_name", ""), + "api_base": raw.get("chat_model_api_base", ""), + "ctx_length": raw.get("chat_model_ctx_length", 128000), + "ctx_history": raw.get("chat_model_ctx_history", 0.7), + "vision": raw.get("chat_model_vision", True), + "rl_requests": raw.get("chat_model_rl_requests", 0), + "rl_input": raw.get("chat_model_rl_input", 0), + "rl_output": raw.get("chat_model_rl_output", 0), + "kwargs": raw.get("chat_model_kwargs", {}), + "allow_chat_override": False, + }, + "utility_model": { + "provider": raw.get("util_model_provider", "openrouter"), + "name": raw.get("util_model_name", ""), + "api_base": raw.get("util_model_api_base", ""), + "ctx_length": raw.get("util_model_ctx_length", 128000), + "ctx_input": raw.get("util_model_ctx_input", 0.7), + "rl_requests": raw.get("util_model_rl_requests", 0), + "rl_input": raw.get("util_model_rl_input", 0), + "rl_output": raw.get("util_model_rl_output", 0), + "kwargs": raw.get("util_model_kwargs", {}), + }, + "embedding_model": { + "provider": raw.get("embed_model_provider", "huggingface"), + "name": raw.get("embed_model_name", "sentence-transformers/all-MiniLM-L6-v2"), + "api_base": raw.get("embed_model_api_base", ""), + "rl_requests": raw.get("embed_model_rl_requests", 0), + "rl_input": raw.get("embed_model_rl_input", 0), + "kwargs": raw.get("embed_model_kwargs", {}), + }, + "browser_http_headers": raw.get("browser_http_headers", {}), + } + + # Ensure kwargs are dicts (might be strings from .env format) + for section in ["chat_model", "utility_model", "embedding_model"]: + kw = plugin_config[section].get("kwargs") + if isinstance(kw, str): + plugin_config[section]["kwargs"] = {} + + if isinstance(plugin_config["browser_http_headers"], str): + plugin_config["browser_http_headers"] = {} + + # Save as global plugin config + plugins.save_plugin_config("_model_config", "", "", 
plugin_config) + PrintStyle(background_color="#6734C3", font_color="white", padding=True).print( + "Migrated legacy model settings to _model_config plugin config." + ) diff --git a/plugins/_model_config/helpers/model_config.py b/plugins/_model_config/helpers/model_config.py new file mode 100644 index 000000000..88cce94da --- /dev/null +++ b/plugins/_model_config/helpers/model_config.py @@ -0,0 +1,217 @@ +import models +from helpers import plugins, settings, projects +from helpers.providers import get_providers, get_raw_providers + + +def get_config(agent=None, project_name=None, agent_profile=None): + """Get the full model config dict for the given agent/scope.""" + return plugins.get_plugin_config( + "_model_config", + agent=agent, + project_name=project_name, + agent_profile=agent_profile, + ) or {} + + +def get_presets(agent=None, project_name=None, agent_profile=None) -> list: + """Get model presets list from config.""" + cfg = get_config(agent, project_name, agent_profile) + return cfg.get("model_presets", []) + + +def get_preset_by_name(name: str, agent=None) -> dict | None: + """Find a preset by name.""" + for p in get_presets(agent): + if p.get("name") == name: + return p + return None + + +def _resolve_override(agent) -> dict | None: + """Resolve the active per-chat override config dict. + Supports both raw override dicts and preset-based overrides. + Returns None if no override is active.""" + if not agent: + return None + override = agent.context.get_data("chat_model_override") + if not override: + return None + + # If this is a preset reference, resolve it + if "preset_name" in override: + preset = get_preset_by_name(override["preset_name"], agent) + if not preset: + return None + return preset + + return override + + +def get_chat_model_config(agent=None) -> dict: + """Get chat model config, with per-chat override if active.""" + override = _resolve_override(agent) + if override: + # Preset has a nested 'chat' key; raw override is flat + chat_cfg = override.get("chat", override) + if chat_cfg.get("provider") or chat_cfg.get("name"): + return chat_cfg + cfg = get_config(agent) + return cfg.get("chat_model", {}) + + +def get_utility_model_config(agent=None) -> dict: + """Get utility model config, with per-chat override if active.""" + override = _resolve_override(agent) + if override: + util_cfg = override.get("utility", {}) + if util_cfg.get("provider") or util_cfg.get("name"): + return util_cfg + cfg = get_config(agent) + return cfg.get("utility_model", {}) + + +def get_embedding_model_config(agent=None) -> dict: + """Get embedding model config.""" + cfg = get_config(agent) + return cfg.get("embedding_model", {}) + + +def get_browser_http_headers(agent=None) -> dict: + """Get browser HTTP headers from config.""" + cfg = get_config(agent) + return cfg.get("browser_http_headers", {}) + + +def is_chat_override_allowed(agent=None) -> bool: + """Check if per-chat model override is enabled.""" + cfg = get_config(agent) + chat_cfg = cfg.get("chat_model", {}) + return bool(chat_cfg.get("allow_chat_override", False)) + + +def get_ctx_history(agent=None) -> float: + """Get the chat model context history ratio.""" + cfg = get_chat_model_config(agent) + return float(cfg.get("ctx_history", 0.7)) + + +def get_ctx_input(agent=None) -> float: + """Get the utility model context input ratio.""" + cfg = get_utility_model_config(agent) + return float(cfg.get("ctx_input", 0.7)) + + +def _normalize_kwargs(kwargs: dict) -> dict: + """Convert string values that are valid numbers to numeric types.""" 
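+    # e.g. {"temperature": "0.7", "seed": "42"} -> {"temperature": 0.7, "seed": 42};
+    # non-numeric strings are passed through unchanged.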
+ result = {} + for key, value in kwargs.items(): + if isinstance(value, str): + try: + result[key] = int(value) + except ValueError: + try: + result[key] = float(value) + except ValueError: + result[key] = value + else: + result[key] = value + return result + + +def build_model_config(cfg: dict, model_type: models.ModelType) -> models.ModelConfig: + """Build a ModelConfig from a config dict section.""" + return models.ModelConfig( + type=model_type, + provider=cfg.get("provider", ""), + name=cfg.get("name", ""), + api_base=cfg.get("api_base", ""), + ctx_length=int(cfg.get("ctx_length", 0)), + vision=bool(cfg.get("vision", False)), + limit_requests=int(cfg.get("rl_requests", 0)), + limit_input=int(cfg.get("rl_input", 0)), + limit_output=int(cfg.get("rl_output", 0)), + kwargs=_normalize_kwargs(cfg.get("kwargs", {})), + ) + + +def build_chat_model(agent=None): + """Build and return a LiteLLMChatWrapper from config.""" + cfg = get_chat_model_config(agent) + mc = build_model_config(cfg, models.ModelType.CHAT) + return models.get_chat_model( + mc.provider, mc.name, model_config=mc, **mc.build_kwargs() + ) + + +def build_utility_model(agent=None): + """Build and return a LiteLLMChatWrapper for utility tasks.""" + cfg = get_utility_model_config(agent) + mc = build_model_config(cfg, models.ModelType.CHAT) + return models.get_chat_model( + mc.provider, mc.name, model_config=mc, **mc.build_kwargs() + ) + + +def build_browser_model(agent=None): + """Build and return a BrowserCompatibleChatWrapper using chat model config.""" + cfg = get_chat_model_config(agent) + mc = build_model_config(cfg, models.ModelType.CHAT) + return models.get_browser_model( + mc.provider, mc.name, model_config=mc, **mc.build_kwargs() + ) + + +def build_embedding_model(agent=None): + """Build and return an embedding model wrapper.""" + cfg = get_embedding_model_config(agent) + mc = build_model_config(cfg, models.ModelType.EMBEDDING) + return models.get_embedding_model( + mc.provider, mc.name, model_config=mc, **mc.build_kwargs() + ) + + +def get_embedding_model_config_object(agent=None) -> models.ModelConfig: + """Get a ModelConfig object for embeddings (needed by memory plugin).""" + cfg = get_embedding_model_config(agent) + return build_model_config(cfg, models.ModelType.EMBEDDING) + + +def get_chat_providers(): + """Get list of chat providers for UI dropdowns.""" + return get_providers("chat") + + +def get_embedding_providers(): + """Get list of embedding providers for UI dropdowns.""" + return get_providers("embedding") + + +def get_missing_api_key_providers(agent=None) -> list[dict]: + """Check which configured providers are missing API keys.""" + cfg = get_config(agent) + missing = [] + + LOCAL_PROVIDERS = {"ollama", "lm_studio"} + LOCAL_EMBEDDING = {"huggingface"} + + checks = [ + ("Chat Model", cfg.get("chat_model", {})), + ("Utility Model", cfg.get("utility_model", {})), + ("Embedding Model", cfg.get("embedding_model", {})), + ] + + for label, model_cfg in checks: + provider = model_cfg.get("provider", "") + if not provider: + continue + provider_lower = provider.lower() + if provider_lower in LOCAL_PROVIDERS: + continue + if label == "Embedding Model" and provider_lower in LOCAL_EMBEDDING: + continue + + api_key = models.get_api_key(provider_lower) + if not (api_key and api_key.strip() and api_key != "None"): + missing.append({"model_type": label, "provider": provider}) + + return missing diff --git a/plugins/_model_config/hooks.py b/plugins/_model_config/hooks.py new file mode 100644 index 000000000..48e9d53b1 --- 
/dev/null +++ b/plugins/_model_config/hooks.py @@ -0,0 +1,8 @@ +def save_plugin_config(result=None, settings=None, **kwargs): + if settings and isinstance(settings, dict): + # Remove transient UI-only fields before persisting + settings.pop("_browser_headers_text", None) + for section in ("chat_model", "utility_model", "embedding_model"): + if section in settings and isinstance(settings[section], dict): + settings[section].pop("_kwargs_text", None) + return settings diff --git a/plugins/_model_config/plugin.yaml b/plugins/_model_config/plugin.yaml new file mode 100644 index 000000000..b16891d93 --- /dev/null +++ b/plugins/_model_config/plugin.yaml @@ -0,0 +1,9 @@ +name: _model_config +title: Model Configuration +description: Manages LLM model selection and configuration for chat, utility, and embedding models. Supports per-project and per-agent overrides with optional per-chat model switching. +version: 1.0.0 +always_enabled: true +settings_sections: + - agent +per_project_config: true +per_agent_config: true diff --git a/preload.py b/preload.py index a4cda4208..84639b43d 100644 --- a/preload.py +++ b/preload.py @@ -18,16 +18,17 @@ async def preload(): # preload embedding model async def preload_embedding(): - if set["embed_model_provider"].lower() == "huggingface": - try: - # Use the new LiteLLM-based model system + try: + from plugins._model_config.helpers.model_config import get_embedding_model_config_object + emb_cfg = get_embedding_model_config_object() + if emb_cfg.provider.lower() == "huggingface": emb_mod = models.get_embedding_model( - "huggingface", set["embed_model_name"] + "huggingface", emb_cfg.name ) emb_txt = await emb_mod.aembed_query("test") return emb_txt - except Exception as e: - PrintStyle().error(f"Error in preload_embedding: {e}") + except Exception as e: + PrintStyle().error(f"Error in preload_embedding: {e}") # preload kokoro tts model if enabled async def preload_kokoro():