mirror of
https://github.com/agent0ai/agent-zero.git
synced 2026-04-28 03:30:23 +00:00
feat: add _model_config plugin with call-time model resolution
This commit is contained in:
parent
6d067a12d0
commit
d570c629c2
19 changed files with 793 additions and 207 deletions
64
agent.py
64
agent.py
|
|
@ -309,16 +309,9 @@ class AgentContext:
|
|||
|
||||
@dataclass
|
||||
class AgentConfig:
|
||||
chat_model: models.ModelConfig
|
||||
utility_model: models.ModelConfig
|
||||
embeddings_model: models.ModelConfig
|
||||
browser_model: models.ModelConfig
|
||||
mcp_servers: str
|
||||
profile: str = ""
|
||||
knowledge_subdirs: list[str] = field(default_factory=lambda: ["default", "custom"])
|
||||
browser_http_headers: dict[str, str] = field(
|
||||
default_factory=dict
|
||||
) # Custom HTTP headers for browser requests
|
||||
additional: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
|
|
@ -713,39 +706,19 @@ class Agent:
|
|||
|
||||
@extension.extensible
|
||||
def get_chat_model(self):
|
||||
return models.get_chat_model(
|
||||
self.config.chat_model.provider,
|
||||
self.config.chat_model.name,
|
||||
model_config=self.config.chat_model,
|
||||
**self.config.chat_model.build_kwargs(),
|
||||
)
|
||||
return None
|
||||
|
||||
@extension.extensible
|
||||
def get_utility_model(self):
|
||||
return models.get_chat_model(
|
||||
self.config.utility_model.provider,
|
||||
self.config.utility_model.name,
|
||||
model_config=self.config.utility_model,
|
||||
**self.config.utility_model.build_kwargs(),
|
||||
)
|
||||
return None
|
||||
|
||||
@extension.extensible
|
||||
def get_browser_model(self):
|
||||
return models.get_browser_model(
|
||||
self.config.browser_model.provider,
|
||||
self.config.browser_model.name,
|
||||
model_config=self.config.browser_model,
|
||||
**self.config.browser_model.build_kwargs(),
|
||||
)
|
||||
return None
|
||||
|
||||
@extension.extensible
|
||||
def get_embedding_model(self):
|
||||
return models.get_embedding_model(
|
||||
self.config.embeddings_model.provider,
|
||||
self.config.embeddings_model.name,
|
||||
model_config=self.config.embeddings_model,
|
||||
**self.config.embeddings_model.build_kwargs(),
|
||||
)
|
||||
return None
|
||||
|
||||
@extension.extensible
|
||||
async def call_utility_model(
|
||||
|
|
@ -803,15 +776,32 @@ class Agent:
|
|||
# model class
|
||||
model = self.get_chat_model()
|
||||
|
||||
# call extensions before
|
||||
call_data = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"response_callback": response_callback,
|
||||
"reasoning_callback": reasoning_callback,
|
||||
"background": background,
|
||||
"explicit_caching": explicit_caching,
|
||||
}
|
||||
await extension.call_extensions_async(
|
||||
"chat_model_call_before", self, call_data=call_data
|
||||
)
|
||||
|
||||
# call model
|
||||
response, reasoning = await model.unified_call(
|
||||
messages=messages,
|
||||
reasoning_callback=reasoning_callback,
|
||||
response_callback=response_callback,
|
||||
response, reasoning = await call_data["model"].unified_call(
|
||||
messages=call_data["messages"],
|
||||
reasoning_callback=call_data["reasoning_callback"],
|
||||
response_callback=call_data["response_callback"],
|
||||
rate_limiter_callback=(
|
||||
self.rate_limiter_callback if not background else None
|
||||
self.rate_limiter_callback if not call_data["background"] else None
|
||||
),
|
||||
explicit_caching=explicit_caching,
|
||||
explicit_caching=call_data["explicit_caching"],
|
||||
)
|
||||
|
||||
await extension.call_extensions_async(
|
||||
"chat_model_call_after", self, call_data=call_data, response=response, reasoning=reasoning
|
||||
)
|
||||
|
||||
return response, reasoning
|
||||
|
|
|
|||
|
|
@ -159,10 +159,13 @@ class Topic(Record):
|
|||
return self.summary
|
||||
|
||||
def compress_large_messages(self, message_ratio: float = CURRENT_TOPIC_RATIO * LARGE_MESSAGE_TO_CURRENT_TOPIC_RATIO) -> bool:
|
||||
set = settings.get_settings()
|
||||
from plugins._model_config.helpers.model_config import get_chat_model_config
|
||||
chat_cfg = get_chat_model_config()
|
||||
ctx_length = int(chat_cfg.get("ctx_length", 128000))
|
||||
ctx_history = float(chat_cfg.get("ctx_history", 0.7))
|
||||
msg_max_size = (
|
||||
set["chat_model_ctx_length"]
|
||||
* set["chat_model_ctx_history"]
|
||||
ctx_length
|
||||
* ctx_history
|
||||
* message_ratio
|
||||
)
|
||||
large_msgs = []
|
||||
|
|
@ -479,8 +482,11 @@ def deserialize_history(json_data: str, agent) -> History:
|
|||
|
||||
|
||||
def _get_ctx_size_for_history() -> int:
|
||||
set = settings.get_settings()
|
||||
return int(set["chat_model_ctx_length"] * set["chat_model_ctx_history"])
|
||||
from plugins._model_config.helpers.model_config import get_chat_model_config
|
||||
chat_cfg = get_chat_model_config()
|
||||
ctx_length = int(chat_cfg.get("ctx_length", 128000))
|
||||
ctx_history = float(chat_cfg.get("ctx_history", 0.7))
|
||||
return int(ctx_length * ctx_history)
|
||||
|
||||
|
||||
def _stringify_output(output: OutputMessage, ai_label="ai", human_label="human"):
|
||||
|
|
|
|||
|
|
@ -53,44 +53,6 @@ def get_default_value(name: str, value: T) -> T:
|
|||
class Settings(TypedDict):
|
||||
version: str
|
||||
|
||||
chat_model_provider: str
|
||||
chat_model_name: str
|
||||
chat_model_api_base: str
|
||||
chat_model_kwargs: dict[str, Any]
|
||||
chat_model_ctx_length: int
|
||||
chat_model_ctx_history: float
|
||||
chat_model_vision: bool
|
||||
chat_model_rl_requests: int
|
||||
chat_model_rl_input: int
|
||||
chat_model_rl_output: int
|
||||
|
||||
util_model_provider: str
|
||||
util_model_name: str
|
||||
util_model_api_base: str
|
||||
util_model_kwargs: dict[str, Any]
|
||||
util_model_ctx_length: int
|
||||
util_model_ctx_input: float
|
||||
util_model_rl_requests: int
|
||||
util_model_rl_input: int
|
||||
util_model_rl_output: int
|
||||
|
||||
embed_model_provider: str
|
||||
embed_model_name: str
|
||||
embed_model_api_base: str
|
||||
embed_model_kwargs: dict[str, Any]
|
||||
embed_model_rl_requests: int
|
||||
embed_model_rl_input: int
|
||||
|
||||
browser_model_provider: str
|
||||
browser_model_name: str
|
||||
browser_model_api_base: str
|
||||
browser_model_vision: bool
|
||||
browser_model_rl_requests: int
|
||||
browser_model_rl_input: int
|
||||
browser_model_rl_output: int
|
||||
browser_model_kwargs: dict[str, Any]
|
||||
browser_http_headers: dict[str, Any]
|
||||
|
||||
agent_profile: str
|
||||
agent_knowledge_subdir: str
|
||||
|
||||
|
|
@ -261,10 +223,6 @@ def convert_out(settings: Settings) -> SettingsOutput:
|
|||
),
|
||||
}
|
||||
|
||||
additional["chat_providers"] = _ensure_option_present(additional.get("chat_providers"), current.get("chat_model_provider"))
|
||||
additional["chat_providers"] = _ensure_option_present(additional.get("chat_providers"), current.get("util_model_provider"))
|
||||
additional["chat_providers"] = _ensure_option_present(additional.get("chat_providers"), current.get("browser_model_provider"))
|
||||
additional["embedding_providers"] = _ensure_option_present(additional.get("embedding_providers"), current.get("embed_model_provider"))
|
||||
additional["agent_subdirs"] = _ensure_option_present(additional.get("agent_subdirs"), current.get("agent_profile"))
|
||||
additional["knowledge_subdirs"] = _ensure_option_present(additional.get("knowledge_subdirs"), current.get("agent_knowledge_subdir"))
|
||||
additional["stt_models"] = _ensure_option_present(additional.get("stt_models"), current.get("stt_model_size"))
|
||||
|
|
@ -493,40 +451,6 @@ def get_default_settings() -> Settings:
|
|||
gitignore = files.read_file(files.get_abs_path("conf/workdir.gitignore"))
|
||||
return Settings(
|
||||
version=_get_version(),
|
||||
chat_model_provider=get_default_value("chat_model_provider", "openrouter"),
|
||||
chat_model_name=get_default_value("chat_model_name", "anthropic/claude-sonnet-4.6"),
|
||||
chat_model_api_base=get_default_value("chat_model_api_base", ""),
|
||||
chat_model_kwargs=get_default_value("chat_model_kwargs", {}),
|
||||
chat_model_ctx_length=get_default_value("chat_model_ctx_length", 100000),
|
||||
chat_model_ctx_history=get_default_value("chat_model_ctx_history", 0.7),
|
||||
chat_model_vision=get_default_value("chat_model_vision", True),
|
||||
chat_model_rl_requests=get_default_value("chat_model_rl_requests", 0),
|
||||
chat_model_rl_input=get_default_value("chat_model_rl_input", 0),
|
||||
chat_model_rl_output=get_default_value("chat_model_rl_output", 0),
|
||||
util_model_provider=get_default_value("util_model_provider", "openrouter"),
|
||||
util_model_name=get_default_value("util_model_name", "google/gemini-3-flash-preview"),
|
||||
util_model_api_base=get_default_value("util_model_api_base", ""),
|
||||
util_model_ctx_length=get_default_value("util_model_ctx_length", 100000),
|
||||
util_model_ctx_input=get_default_value("util_model_ctx_input", 0.7),
|
||||
util_model_kwargs=get_default_value("util_model_kwargs", {}),
|
||||
util_model_rl_requests=get_default_value("util_model_rl_requests", 0),
|
||||
util_model_rl_input=get_default_value("util_model_rl_input", 0),
|
||||
util_model_rl_output=get_default_value("util_model_rl_output", 0),
|
||||
embed_model_provider=get_default_value("embed_model_provider", "huggingface"),
|
||||
embed_model_name=get_default_value("embed_model_name", "sentence-transformers/all-MiniLM-L6-v2"),
|
||||
embed_model_api_base=get_default_value("embed_model_api_base", ""),
|
||||
embed_model_kwargs=get_default_value("embed_model_kwargs", {}),
|
||||
embed_model_rl_requests=get_default_value("embed_model_rl_requests", 0),
|
||||
embed_model_rl_input=get_default_value("embed_model_rl_input", 0),
|
||||
browser_model_provider=get_default_value("browser_model_provider", "openrouter"),
|
||||
browser_model_name=get_default_value("browser_model_name", "anthropic/claude-sonnet-4.6"),
|
||||
browser_model_api_base=get_default_value("browser_model_api_base", ""),
|
||||
browser_model_vision=get_default_value("browser_model_vision", True),
|
||||
browser_model_rl_requests=get_default_value("browser_model_rl_requests", 0),
|
||||
browser_model_rl_input=get_default_value("browser_model_rl_input", 0),
|
||||
browser_model_rl_output=get_default_value("browser_model_rl_output", 0),
|
||||
browser_model_kwargs=get_default_value("browser_model_kwargs", {}),
|
||||
browser_http_headers=get_default_value("browser_http_headers", {}),
|
||||
api_keys={},
|
||||
auth_login="",
|
||||
auth_password="",
|
||||
|
|
@ -587,18 +511,6 @@ def _apply_settings(previous: Settings | None):
|
|||
whisper.preload, _settings["stt_model_size"]
|
||||
) # TODO overkill, replace with background task
|
||||
|
||||
# notify plugins of embedding model change
|
||||
if not previous or (
|
||||
_settings["embed_model_name"] != previous["embed_model_name"]
|
||||
or _settings["embed_model_provider"] != previous["embed_model_provider"]
|
||||
or _settings["embed_model_kwargs"] != previous["embed_model_kwargs"]
|
||||
):
|
||||
from helpers.extension import call_extensions_async
|
||||
|
||||
defer.DeferredTask().start_task(
|
||||
call_extensions_async, "embedding_model_changed"
|
||||
)
|
||||
|
||||
# update mcp settings if necessary
|
||||
if not previous or _settings["mcp_servers"] != previous["mcp_servers"]:
|
||||
from helpers.mcp_handler import MCPConfig
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from agent import AgentConfig
|
||||
import models
|
||||
from helpers import runtime, settings, defer, extension
|
||||
from helpers.print_style import PrintStyle
|
||||
|
||||
|
|
@ -10,78 +9,11 @@ def initialize_agent(override_settings: dict | None = None):
|
|||
if override_settings:
|
||||
current_settings = settings.merge_settings(current_settings, override_settings)
|
||||
|
||||
def _normalize_model_kwargs(kwargs: dict) -> dict:
|
||||
# convert string values that represent valid Python numbers to numeric types
|
||||
result = {}
|
||||
for key, value in kwargs.items():
|
||||
if isinstance(value, str):
|
||||
# try to convert string to number if it's a valid Python number
|
||||
try:
|
||||
# try int first, then float
|
||||
result[key] = int(value)
|
||||
except ValueError:
|
||||
try:
|
||||
result[key] = float(value)
|
||||
except ValueError:
|
||||
result[key] = value
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
# chat model from user settings
|
||||
chat_llm = models.ModelConfig(
|
||||
type=models.ModelType.CHAT,
|
||||
provider=current_settings["chat_model_provider"],
|
||||
name=current_settings["chat_model_name"],
|
||||
api_base=current_settings["chat_model_api_base"],
|
||||
ctx_length=current_settings["chat_model_ctx_length"],
|
||||
vision=current_settings["chat_model_vision"],
|
||||
limit_requests=current_settings["chat_model_rl_requests"],
|
||||
limit_input=current_settings["chat_model_rl_input"],
|
||||
limit_output=current_settings["chat_model_rl_output"],
|
||||
kwargs=_normalize_model_kwargs(current_settings["chat_model_kwargs"]),
|
||||
)
|
||||
|
||||
# utility model from user settings
|
||||
utility_llm = models.ModelConfig(
|
||||
type=models.ModelType.CHAT,
|
||||
provider=current_settings["util_model_provider"],
|
||||
name=current_settings["util_model_name"],
|
||||
api_base=current_settings["util_model_api_base"],
|
||||
ctx_length=current_settings["util_model_ctx_length"],
|
||||
limit_requests=current_settings["util_model_rl_requests"],
|
||||
limit_input=current_settings["util_model_rl_input"],
|
||||
limit_output=current_settings["util_model_rl_output"],
|
||||
kwargs=_normalize_model_kwargs(current_settings["util_model_kwargs"]),
|
||||
)
|
||||
# embedding model from user settings
|
||||
embedding_llm = models.ModelConfig(
|
||||
type=models.ModelType.EMBEDDING,
|
||||
provider=current_settings["embed_model_provider"],
|
||||
name=current_settings["embed_model_name"],
|
||||
api_base=current_settings["embed_model_api_base"],
|
||||
limit_requests=current_settings["embed_model_rl_requests"],
|
||||
kwargs=_normalize_model_kwargs(current_settings["embed_model_kwargs"]),
|
||||
)
|
||||
# browser model from user settings
|
||||
browser_llm = models.ModelConfig(
|
||||
type=models.ModelType.CHAT,
|
||||
provider=current_settings["browser_model_provider"],
|
||||
name=current_settings["browser_model_name"],
|
||||
api_base=current_settings["browser_model_api_base"],
|
||||
vision=current_settings["browser_model_vision"],
|
||||
kwargs=_normalize_model_kwargs(current_settings["browser_model_kwargs"]),
|
||||
)
|
||||
# agent configuration
|
||||
# agent configuration - models are now resolved at call time via _model_config plugin
|
||||
config = AgentConfig(
|
||||
chat_model=chat_llm,
|
||||
utility_model=utility_llm,
|
||||
embeddings_model=embedding_llm,
|
||||
browser_model=browser_llm,
|
||||
profile=current_settings["agent_profile"],
|
||||
knowledge_subdirs=[current_settings["agent_knowledge_subdir"], "default"],
|
||||
mcp_servers=current_settings["mcp_servers"],
|
||||
browser_http_headers=current_settings["browser_http_headers"],
|
||||
)
|
||||
|
||||
# update config with runtime args
|
||||
|
|
|
|||
|
|
@ -60,6 +60,11 @@ class Memory:
|
|||
|
||||
index: dict[str, "MyFaiss"] = {}
|
||||
|
||||
@staticmethod
|
||||
def _get_embedding_config(agent=None):
|
||||
from plugins._model_config.helpers.model_config import get_embedding_model_config_object
|
||||
return get_embedding_model_config_object(agent)
|
||||
|
||||
@staticmethod
|
||||
async def get(agent: Agent):
|
||||
memory_subdir = get_agent_memory_subdir(agent)
|
||||
|
|
@ -70,7 +75,7 @@ class Memory:
|
|||
)
|
||||
db, created = Memory.initialize(
|
||||
log_item,
|
||||
agent.config.embeddings_model,
|
||||
Memory._get_embedding_config(agent),
|
||||
memory_subdir,
|
||||
False,
|
||||
)
|
||||
|
|
@ -98,7 +103,7 @@ class Memory:
|
|||
import initialize
|
||||
|
||||
agent_config = initialize.initialize_agent()
|
||||
model_config = agent_config.embeddings_model
|
||||
model_config = Memory._get_embedding_config()
|
||||
db, _created = Memory.initialize(
|
||||
log_item=log_item,
|
||||
model_config=model_config,
|
||||
|
|
|
|||
60
plugins/_model_config/api/api_keys.py
Normal file
60
plugins/_model_config/api/api_keys.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
from helpers.api import ApiHandler, Request, Response
|
||||
from helpers import dotenv
|
||||
import models
|
||||
|
||||
API_KEY_PLACEHOLDER = "************"
|
||||
|
||||
|
||||
class ApiKeys(ApiHandler):
|
||||
async def process(self, input: dict, request: Request) -> dict | Response:
|
||||
action = input.get("action", "get") # get | set | reveal
|
||||
|
||||
if action == "get":
|
||||
return self._get_keys()
|
||||
elif action == "set":
|
||||
return self._set_keys(input)
|
||||
elif action == "reveal":
|
||||
return self._reveal_key(input)
|
||||
|
||||
return Response(status=400, response=f"Unknown action: {action}")
|
||||
|
||||
def _get_keys(self) -> dict:
|
||||
from helpers.providers import get_providers
|
||||
|
||||
providers = get_providers("chat") + get_providers("embedding")
|
||||
seen = set()
|
||||
keys = {}
|
||||
|
||||
for p in providers:
|
||||
pid = p.get("value", "")
|
||||
if pid and pid not in seen:
|
||||
seen.add(pid)
|
||||
api_key = models.get_api_key(pid)
|
||||
has_key = bool(api_key and api_key.strip() and api_key != "None")
|
||||
keys[pid] = {
|
||||
"label": p.get("label", pid),
|
||||
"has_key": has_key,
|
||||
"masked": API_KEY_PLACEHOLDER if has_key else "",
|
||||
}
|
||||
|
||||
return {"keys": keys}
|
||||
|
||||
def _set_keys(self, input: dict) -> dict:
|
||||
updates = input.get("keys", {})
|
||||
if not isinstance(updates, dict):
|
||||
return {"ok": False, "error": "Invalid keys format"}
|
||||
|
||||
for provider, value in updates.items():
|
||||
if isinstance(value, str) and value != API_KEY_PLACEHOLDER:
|
||||
dotenv.save_dotenv_value(f"API_KEY_{provider.upper()}", value)
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
def _reveal_key(self, input: dict) -> dict:
|
||||
provider = input.get("provider", "")
|
||||
if not provider:
|
||||
return {"ok": False, "error": "Missing provider"}
|
||||
api_key = models.get_api_key(provider)
|
||||
if api_key and api_key.strip() and api_key != "None":
|
||||
return {"ok": True, "value": api_key}
|
||||
return {"ok": True, "value": ""}
|
||||
41
plugins/_model_config/api/model_config_get.py
Normal file
41
plugins/_model_config/api/model_config_get.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
from helpers.api import ApiHandler, Request, Response
|
||||
from helpers import plugins
|
||||
from plugins._model_config.helpers import model_config
|
||||
import models
|
||||
|
||||
|
||||
class ModelConfigGet(ApiHandler):
|
||||
async def process(self, input: dict, request: Request) -> dict | Response:
|
||||
project_name = input.get("project_name", "")
|
||||
agent_profile = input.get("agent_profile", "")
|
||||
|
||||
config = model_config.get_config(
|
||||
project_name=project_name or None,
|
||||
agent_profile=agent_profile or None,
|
||||
)
|
||||
|
||||
# Provide default if no config found
|
||||
if not config:
|
||||
config = plugins.get_default_plugin_config("_model_config") or {}
|
||||
|
||||
# Add provider lists for UI dropdowns
|
||||
chat_providers = model_config.get_chat_providers()
|
||||
embedding_providers = model_config.get_embedding_providers()
|
||||
|
||||
# Mask API keys - show status only
|
||||
api_key_status = {}
|
||||
all_providers = chat_providers + embedding_providers
|
||||
seen = set()
|
||||
for p in all_providers:
|
||||
pid = p.get("value", "")
|
||||
if pid and pid not in seen:
|
||||
seen.add(pid)
|
||||
key = models.get_api_key(pid)
|
||||
api_key_status[pid] = bool(key and key.strip() and key != "None")
|
||||
|
||||
return {
|
||||
"config": config,
|
||||
"chat_providers": chat_providers,
|
||||
"embedding_providers": embedding_providers,
|
||||
"api_key_status": api_key_status,
|
||||
}
|
||||
35
plugins/_model_config/api/model_config_set.py
Normal file
35
plugins/_model_config/api/model_config_set.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
from helpers.api import ApiHandler, Request, Response
|
||||
from helpers import plugins, defer
|
||||
from helpers.extension import call_extensions_async
|
||||
|
||||
|
||||
class ModelConfigSet(ApiHandler):
|
||||
async def process(self, input: dict, request: Request) -> dict | Response:
|
||||
project_name = input.get("project_name", "")
|
||||
agent_profile = input.get("agent_profile", "")
|
||||
config = input.get("config")
|
||||
|
||||
if not config or not isinstance(config, dict):
|
||||
return Response(status=400, response="Missing or invalid config")
|
||||
|
||||
plugins.save_plugin_config(
|
||||
"_model_config",
|
||||
project_name=project_name,
|
||||
agent_profile=agent_profile,
|
||||
settings=config,
|
||||
)
|
||||
|
||||
# Check if embedding model changed and notify
|
||||
prev_config = plugins.get_plugin_config("_model_config") or {}
|
||||
prev_embed = prev_config.get("embedding_model", {})
|
||||
new_embed = config.get("embedding_model", {})
|
||||
if (
|
||||
prev_embed.get("provider") != new_embed.get("provider")
|
||||
or prev_embed.get("name") != new_embed.get("name")
|
||||
or prev_embed.get("kwargs") != new_embed.get("kwargs")
|
||||
):
|
||||
defer.DeferredTask().start_task(
|
||||
call_extensions_async, "embedding_model_changed"
|
||||
)
|
||||
|
||||
return {"ok": True}
|
||||
201
plugins/_model_config/api/model_search.py
Normal file
201
plugins/_model_config/api/model_search.py
Normal file
|
|
@ -0,0 +1,201 @@
|
|||
import httpx
|
||||
from helpers.api import ApiHandler, Request, Response
|
||||
from helpers.providers import get_provider_config
|
||||
import models
|
||||
|
||||
_CLOUD_ENDPOINTS: dict[str, str] = {
|
||||
"openai": "https://api.openai.com/v1/models",
|
||||
"anthropic": "https://api.anthropic.com/v1/models",
|
||||
"groq": "https://api.groq.com/openai/v1/models",
|
||||
"deepseek": "https://api.deepseek.com/models",
|
||||
"mistral": "https://api.mistral.ai/v1/models",
|
||||
"openrouter": "https://openrouter.ai/api/v1/models",
|
||||
"xai": "https://api.x.ai/v1/models",
|
||||
"sambanova": "https://api.sambanova.ai/v1/models",
|
||||
"moonshot": "https://api.moonshot.cn/v1/models",
|
||||
"google": "https://generativelanguage.googleapis.com/v1beta/models",
|
||||
"a0_venice": "https://api.venice.ai/api/v1/models",
|
||||
"venice": "https://api.venice.ai/api/v1/models",
|
||||
}
|
||||
|
||||
# Local providers with default base URLs (no auth required).
|
||||
_LOCAL_DEFAULTS: dict[str, str] = {
|
||||
"ollama": "http://host.docker.internal:11434",
|
||||
"lm_studio": "http://host.docker.internal:1234",
|
||||
}
|
||||
|
||||
# Providers with hardcoded model lists (no listing API available).
|
||||
_STATIC_MODELS: dict[str, list[str]] = {
|
||||
"github_copilot": [
|
||||
"gpt-4.1", "gpt-4o", "gpt-5-mini", "oswe-vscode-prime",
|
||||
],
|
||||
"zai": [
|
||||
"glm-4-plus", "glm-4-air-250414", "glm-4-airx",
|
||||
"glm-4-long", "glm-4-flashx", "glm-4-flash-250414",
|
||||
"glm-4v-plus", "glm-4v", "glm-3-turbo",
|
||||
],
|
||||
"zai_coding": [
|
||||
"codegeex-4",
|
||||
"glm-4-plus", "glm-4-air-250414", "glm-4-airx",
|
||||
"glm-4-flashx", "glm-4-flash-250414",
|
||||
],
|
||||
}
|
||||
|
||||
# Model name substrings to exclude from litellm fallback results
|
||||
_LITELLM_EXCLUDE = frozenset({
|
||||
"dall-e", "gpt-image", "tts", "whisper", "audio",
|
||||
"realtime", "davinci", "babbage", "ada", "vision-preview",
|
||||
})
|
||||
|
||||
|
||||
class ModelSearch(ApiHandler):
|
||||
async def process(self, input: dict, request: Request) -> dict | Response:
|
||||
provider = input.get("provider", "")
|
||||
query = input.get("query", "").lower()
|
||||
model_type = input.get("model_type", "chat")
|
||||
user_api_base = input.get("api_base", "")
|
||||
|
||||
if not provider:
|
||||
return {"models": []}
|
||||
|
||||
if provider in _STATIC_MODELS:
|
||||
all_models = list(_STATIC_MODELS[provider])
|
||||
else:
|
||||
provider_cfg = get_provider_config(model_type, provider)
|
||||
all_models = await self._fetch_models(provider, provider_cfg, user_api_base) or []
|
||||
|
||||
if not all_models:
|
||||
litellm_provider = (provider_cfg or {}).get("litellm_provider", provider)
|
||||
if litellm_provider == provider:
|
||||
all_models = self._litellm_fallback(provider, provider_cfg)
|
||||
|
||||
if query:
|
||||
all_models = [m for m in all_models if query in m.lower()]
|
||||
|
||||
return {"models": sorted(all_models)[:50], "provider": provider}
|
||||
|
||||
async def _fetch_models(self, provider: str, cfg: dict | None, user_api_base: str = "") -> list[str] | None:
|
||||
api_base = user_api_base or (cfg or {}).get("kwargs", {}).get("api_base", "")
|
||||
api_key = models.get_api_key(provider)
|
||||
|
||||
url, fmt = self._resolve_url(provider, api_base)
|
||||
if not url:
|
||||
return None
|
||||
|
||||
headers = self._build_headers(provider, api_key, cfg)
|
||||
params: dict[str, str] = {}
|
||||
if provider == "google":
|
||||
if api_key and api_key != "None":
|
||||
params["key"] = api_key
|
||||
params["pageSize"] = "1000"
|
||||
elif provider == "anthropic":
|
||||
params["limit"] = "1000"
|
||||
elif provider == "azure":
|
||||
params["api-version"] = "2024-10-21"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.get(url, headers=headers, params=params)
|
||||
if resp.status_code == 200:
|
||||
result = self._parse(resp.json(), fmt)
|
||||
if result:
|
||||
return result
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def _resolve_url(self, provider: str, api_base: str) -> tuple[str | None, str]:
|
||||
if provider == "ollama":
|
||||
base = api_base or _LOCAL_DEFAULTS.get("ollama", "")
|
||||
return (base.rstrip("/") + "/api/tags" if base else None), "ollama"
|
||||
|
||||
if provider == "google":
|
||||
if api_base:
|
||||
return api_base.rstrip("/") + "/models", "google"
|
||||
return _CLOUD_ENDPOINTS["google"], "google"
|
||||
|
||||
if provider == "azure":
|
||||
if not api_base:
|
||||
return None, "openai"
|
||||
return api_base.rstrip("/") + "/openai/models", "openai"
|
||||
|
||||
if provider in _CLOUD_ENDPOINTS:
|
||||
return _CLOUD_ENDPOINTS[provider], "openai"
|
||||
|
||||
if api_base:
|
||||
return api_base.rstrip("/") + "/models", "openai"
|
||||
|
||||
if provider in _LOCAL_DEFAULTS:
|
||||
return _LOCAL_DEFAULTS[provider] + "/v1/models", "openai"
|
||||
|
||||
return None, "openai"
|
||||
|
||||
def _build_headers(self, provider: str, api_key: str, cfg: dict | None) -> dict[str, str]:
|
||||
headers: dict[str, str] = {}
|
||||
has_key = api_key and api_key != "None"
|
||||
|
||||
if provider == "anthropic":
|
||||
if has_key:
|
||||
headers["x-api-key"] = api_key
|
||||
headers["anthropic-version"] = "2023-06-01"
|
||||
elif provider == "google":
|
||||
pass
|
||||
elif provider == "azure":
|
||||
if has_key:
|
||||
headers["api-key"] = api_key
|
||||
elif provider not in ("ollama", "lm_studio"):
|
||||
if has_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
extra = (cfg or {}).get("kwargs", {}).get("extra_headers", {})
|
||||
if isinstance(extra, dict):
|
||||
for k, v in extra.items():
|
||||
if isinstance(v, str):
|
||||
headers[k] = v
|
||||
|
||||
return headers
|
||||
|
||||
def _litellm_fallback(self, provider: str, cfg: dict | None) -> list[str]:
|
||||
try:
|
||||
import litellm
|
||||
registry = getattr(litellm, "models_by_provider", None)
|
||||
if not registry:
|
||||
return []
|
||||
|
||||
litellm_provider = (cfg or {}).get("litellm_provider", provider)
|
||||
raw_models: set = registry.get(litellm_provider, set())
|
||||
if not raw_models:
|
||||
return []
|
||||
|
||||
prefix = litellm_provider + "/"
|
||||
result: list[str] = []
|
||||
for name in raw_models:
|
||||
clean = name[len(prefix):] if name.startswith(prefix) else name
|
||||
low = clean.lower()
|
||||
if any(exc in low for exc in _LITELLM_EXCLUDE):
|
||||
continue
|
||||
if clean:
|
||||
result.append(clean)
|
||||
return result
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def _parse(self, data: dict | list, fmt: str) -> list[str]:
|
||||
if fmt == "ollama":
|
||||
return [m.get("name", "") for m in data.get("models", []) if m.get("name")]
|
||||
|
||||
if fmt == "google":
|
||||
result = []
|
||||
for m in data.get("models", []):
|
||||
name = m.get("name", "")
|
||||
if name.startswith("models/"):
|
||||
name = name[7:]
|
||||
if name:
|
||||
result.append(name)
|
||||
return result
|
||||
|
||||
if isinstance(data, dict) and "data" in data:
|
||||
return [m.get("id", "") for m in data["data"] if m.get("id")]
|
||||
|
||||
return []
|
||||
35
plugins/_model_config/default_config.yaml
Normal file
35
plugins/_model_config/default_config.yaml
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
chat_model:
|
||||
provider: "openrouter"
|
||||
name: "anthropic/claude-sonnet-4.6"
|
||||
api_base: ""
|
||||
ctx_length: 128000
|
||||
ctx_history: 0.7
|
||||
vision: true
|
||||
rl_requests: 0
|
||||
rl_input: 0
|
||||
rl_output: 0
|
||||
kwargs: {}
|
||||
allow_chat_override: false
|
||||
|
||||
utility_model:
|
||||
provider: "openrouter"
|
||||
name: "google/gemini-3-flash-preview"
|
||||
api_base: ""
|
||||
ctx_length: 128000
|
||||
ctx_input: 0.7
|
||||
rl_requests: 0
|
||||
rl_input: 0
|
||||
rl_output: 0
|
||||
kwargs: {}
|
||||
|
||||
embedding_model:
|
||||
provider: "huggingface"
|
||||
name: "sentence-transformers/all-MiniLM-L6-v2"
|
||||
api_base: ""
|
||||
rl_requests: 0
|
||||
rl_input: 0
|
||||
kwargs: {}
|
||||
|
||||
browser_http_headers: {}
|
||||
|
||||
model_presets: []
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
from helpers.extension import Extension
|
||||
from plugins._model_config.helpers.model_config import build_browser_model
|
||||
|
||||
|
||||
class BrowserModelProvider(Extension):
|
||||
def execute(self, data: dict = {}, **kwargs):
|
||||
if self.agent:
|
||||
data["result"] = build_browser_model(self.agent)
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
from helpers.extension import Extension
|
||||
from plugins._model_config.helpers.model_config import build_chat_model
|
||||
|
||||
|
||||
class ChatModelProvider(Extension):
|
||||
def execute(self, data: dict = {}, **kwargs):
|
||||
if self.agent:
|
||||
data["result"] = build_chat_model(self.agent)
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
from helpers.extension import Extension
|
||||
from plugins._model_config.helpers.model_config import build_embedding_model
|
||||
|
||||
|
||||
class EmbeddingModelProvider(Extension):
|
||||
def execute(self, data: dict = {}, **kwargs):
|
||||
if self.agent:
|
||||
data["result"] = build_embedding_model(self.agent)
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
from helpers.extension import Extension
|
||||
from plugins._model_config.helpers.model_config import build_utility_model
|
||||
|
||||
|
||||
class UtilityModelProvider(Extension):
|
||||
def execute(self, data: dict = {}, **kwargs):
|
||||
if self.agent:
|
||||
data["result"] = build_utility_model(self.agent)
|
||||
|
|
@ -0,0 +1,102 @@
|
|||
import json
|
||||
import os
|
||||
from helpers.extension import Extension
|
||||
from helpers import settings as settings_helper, files, plugins
|
||||
from helpers.print_style import PrintStyle
|
||||
|
||||
|
||||
class MigrateModelConfig(Extension):
    """
    One-time migration: copy legacy model settings into _model_config plugin config.
    Runs during initialize_migration. Only migrates if no global plugin config exists yet
    and the settings file contains legacy model fields.
    """

    # Flat keys from the pre-plugin settings schema; the presence of any one of
    # them in usr/settings.json marks the file as a candidate for migration.
    LEGACY_FIELDS = [
        "chat_model_provider", "chat_model_name", "chat_model_api_base",
        "chat_model_kwargs", "chat_model_ctx_length", "chat_model_vision",
        "chat_model_rl_requests", "chat_model_rl_input", "chat_model_rl_output",
        "chat_model_ctx_history",
        "util_model_provider", "util_model_name", "util_model_api_base",
        "util_model_kwargs", "util_model_ctx_length",
        "util_model_rl_requests", "util_model_rl_input", "util_model_rl_output",
        "util_model_ctx_input",
        "embed_model_provider", "embed_model_name", "embed_model_api_base",
        "embed_model_kwargs", "embed_model_rl_requests", "embed_model_rl_input",
        "browser_model_provider", "browser_model_name", "browser_model_api_base",
        "browser_model_vision", "browser_model_rl_requests", "browser_model_rl_input",
        "browser_model_rl_output", "browser_model_kwargs", "browser_http_headers",
    ]

    async def execute(self, **kwargs):
        """Run the migration if (and only if) it has not happened yet.

        Silently returns on any of: plugin config already present, settings
        file missing or unparseable, or no legacy fields found.
        """
        # Check if global plugin config already exists
        global_config_path = files.get_abs_path("plugins/_model_config/config.json")
        if os.path.exists(global_config_path):
            return  # already migrated or manually configured

        # Read raw settings file to check for legacy model fields
        settings_file = files.get_abs_path("usr/settings.json")
        if not os.path.exists(settings_file):
            return

        try:
            raw = json.loads(files.read_file(settings_file))
        except Exception:
            # Corrupt/unreadable settings: nothing safe to migrate.
            return

        # Check if any legacy model field exists in the raw settings
        has_legacy = any(field in raw for field in self.LEGACY_FIELDS)
        if not has_legacy:
            return

        # Build plugin config from legacy settings.
        # Defaults mirror the legacy schema's fallbacks (e.g. openrouter
        # provider, 128000 ctx, 0.7 context ratios).
        plugin_config = {
            "chat_model": {
                "provider": raw.get("chat_model_provider", "openrouter"),
                "name": raw.get("chat_model_name", ""),
                "api_base": raw.get("chat_model_api_base", ""),
                "ctx_length": raw.get("chat_model_ctx_length", 128000),
                "ctx_history": raw.get("chat_model_ctx_history", 0.7),
                "vision": raw.get("chat_model_vision", True),
                "rl_requests": raw.get("chat_model_rl_requests", 0),
                "rl_input": raw.get("chat_model_rl_input", 0),
                "rl_output": raw.get("chat_model_rl_output", 0),
                "kwargs": raw.get("chat_model_kwargs", {}),
                # New feature, off by default: per-chat model switching.
                "allow_chat_override": False,
            },
            "utility_model": {
                "provider": raw.get("util_model_provider", "openrouter"),
                "name": raw.get("util_model_name", ""),
                "api_base": raw.get("util_model_api_base", ""),
                "ctx_length": raw.get("util_model_ctx_length", 128000),
                "ctx_input": raw.get("util_model_ctx_input", 0.7),
                "rl_requests": raw.get("util_model_rl_requests", 0),
                "rl_input": raw.get("util_model_rl_input", 0),
                "rl_output": raw.get("util_model_rl_output", 0),
                "kwargs": raw.get("util_model_kwargs", {}),
            },
            "embedding_model": {
                "provider": raw.get("embed_model_provider", "huggingface"),
                "name": raw.get("embed_model_name", "sentence-transformers/all-MiniLM-L6-v2"),
                "api_base": raw.get("embed_model_api_base", ""),
                "rl_requests": raw.get("embed_model_rl_requests", 0),
                "rl_input": raw.get("embed_model_rl_input", 0),
                "kwargs": raw.get("embed_model_kwargs", {}),
            },
            "browser_http_headers": raw.get("browser_http_headers", {}),
        }

        # Ensure kwargs are dicts (might be strings from .env format)
        for section in ["chat_model", "utility_model", "embedding_model"]:
            kw = plugin_config[section].get("kwargs")
            if isinstance(kw, str):
                # Discard string-form kwargs rather than attempt to parse them.
                plugin_config[section]["kwargs"] = {}

        if isinstance(plugin_config["browser_http_headers"], str):
            plugin_config["browser_http_headers"] = {}

        # Save as global plugin config (empty project/profile scope).
        plugins.save_plugin_config("_model_config", "", "", plugin_config)
        PrintStyle(background_color="#6734C3", font_color="white", padding=True).print(
            "Migrated legacy model settings to _model_config plugin config."
        )
217
plugins/_model_config/helpers/model_config.py
Normal file
217
plugins/_model_config/helpers/model_config.py
Normal file
|
|
@ -0,0 +1,217 @@
|
|||
import models
|
||||
from helpers import plugins, settings, projects
|
||||
from helpers.providers import get_providers, get_raw_providers
|
||||
|
||||
|
||||
def get_config(agent=None, project_name=None, agent_profile=None):
    """Return the full _model_config plugin config dict for the given scope.

    Falls back to an empty dict when no config has been stored yet.
    """
    cfg = plugins.get_plugin_config(
        "_model_config",
        agent=agent,
        project_name=project_name,
        agent_profile=agent_profile,
    )
    return cfg if cfg else {}
|
||||
|
||||
def get_presets(agent=None, project_name=None, agent_profile=None) -> list:
    """Return the list of saved model presets (empty when none exist)."""
    return get_config(agent, project_name, agent_profile).get("model_presets", [])
||||
|
||||
def get_preset_by_name(name: str, agent=None) -> dict | None:
    """Look up a model preset by its name; returns None when not found."""
    return next(
        (preset for preset in get_presets(agent) if preset.get("name") == name),
        None,
    )
||||
|
||||
def _resolve_override(agent) -> dict | None:
|
||||
"""Resolve the active per-chat override config dict.
|
||||
Supports both raw override dicts and preset-based overrides.
|
||||
Returns None if no override is active."""
|
||||
if not agent:
|
||||
return None
|
||||
override = agent.context.get_data("chat_model_override")
|
||||
if not override:
|
||||
return None
|
||||
|
||||
# If this is a preset reference, resolve it
|
||||
if "preset_name" in override:
|
||||
preset = get_preset_by_name(override["preset_name"], agent)
|
||||
if not preset:
|
||||
return None
|
||||
return preset
|
||||
|
||||
return override
|
||||
|
||||
|
||||
def get_chat_model_config(agent=None) -> dict:
    """Return the chat model config, honoring an active per-chat override."""
    override = _resolve_override(agent)
    if override:
        # Presets nest the chat settings under "chat"; raw overrides are flat.
        candidate = override.get("chat", override)
        if candidate.get("provider") or candidate.get("name"):
            return candidate
    return get_config(agent).get("chat_model", {})
||||
|
||||
def get_utility_model_config(agent=None) -> dict:
    """Return the utility model config, honoring an active per-chat override."""
    override = _resolve_override(agent)
    if override:
        # Overrides carry utility settings under the "utility" key only.
        candidate = override.get("utility", {})
        if candidate.get("provider") or candidate.get("name"):
            return candidate
    return get_config(agent).get("utility_model", {})
||||
|
||||
def get_embedding_model_config(agent=None) -> dict:
    """Return the embedding model config section (no per-chat override)."""
    return get_config(agent).get("embedding_model", {})
||||
|
||||
def get_browser_http_headers(agent=None) -> dict:
    """Return the custom HTTP headers configured for browser requests."""
    return get_config(agent).get("browser_http_headers", {})
||||
|
||||
def is_chat_override_allowed(agent=None) -> bool:
    """Return True when the config allows switching the chat model per chat."""
    chat_cfg = get_config(agent).get("chat_model", {})
    return bool(chat_cfg.get("allow_chat_override", False))
||||
|
||||
def get_ctx_history(agent=None) -> float:
    """Return the chat model's context-history ratio (defaults to 0.7)."""
    return float(get_chat_model_config(agent).get("ctx_history", 0.7))
||||
|
||||
def get_ctx_input(agent=None) -> float:
    """Return the utility model's context-input ratio (defaults to 0.7)."""
    return float(get_utility_model_config(agent).get("ctx_input", 0.7))
||||
|
||||
def _normalize_kwargs(kwargs: dict) -> dict:
|
||||
"""Convert string values that are valid numbers to numeric types."""
|
||||
result = {}
|
||||
for key, value in kwargs.items():
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
result[key] = int(value)
|
||||
except ValueError:
|
||||
try:
|
||||
result[key] = float(value)
|
||||
except ValueError:
|
||||
result[key] = value
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
|
||||
def build_model_config(cfg: dict, model_type: models.ModelType) -> models.ModelConfig:
    """Translate a plugin-config section dict into a models.ModelConfig."""
    read = cfg.get
    return models.ModelConfig(
        type=model_type,
        provider=read("provider", ""),
        name=read("name", ""),
        api_base=read("api_base", ""),
        ctx_length=int(read("ctx_length", 0)),
        vision=bool(read("vision", False)),
        limit_requests=int(read("rl_requests", 0)),
        limit_input=int(read("rl_input", 0)),
        limit_output=int(read("rl_output", 0)),
        kwargs=_normalize_kwargs(read("kwargs", {})),
    )
||||
|
||||
def build_chat_model(agent=None):
    """Instantiate the chat model wrapper from the resolved config."""
    mc = build_model_config(get_chat_model_config(agent), models.ModelType.CHAT)
    return models.get_chat_model(
        mc.provider,
        mc.name,
        model_config=mc,
        **mc.build_kwargs(),
    )
||||
|
||||
def build_utility_model(agent=None):
    """Instantiate the chat-model wrapper used for utility tasks."""
    mc = build_model_config(get_utility_model_config(agent), models.ModelType.CHAT)
    return models.get_chat_model(
        mc.provider,
        mc.name,
        model_config=mc,
        **mc.build_kwargs(),
    )
||||
|
||||
def build_browser_model(agent=None):
    """Instantiate the browser-compatible wrapper (shares the chat config)."""
    mc = build_model_config(get_chat_model_config(agent), models.ModelType.CHAT)
    return models.get_browser_model(
        mc.provider,
        mc.name,
        model_config=mc,
        **mc.build_kwargs(),
    )
||||
|
||||
def build_embedding_model(agent=None):
    """Instantiate the embedding model wrapper from the resolved config."""
    mc = build_model_config(get_embedding_model_config(agent), models.ModelType.EMBEDDING)
    return models.get_embedding_model(
        mc.provider,
        mc.name,
        model_config=mc,
        **mc.build_kwargs(),
    )
||||
|
||||
def get_embedding_model_config_object(agent=None) -> models.ModelConfig:
    """Return the embedding config as a ModelConfig (used by the memory plugin)."""
    return build_model_config(
        get_embedding_model_config(agent), models.ModelType.EMBEDDING
    )
||||
|
||||
def get_chat_providers():
    """Return the chat-capable provider list for UI dropdowns.

    Thin wrapper over helpers.providers.get_providers with a fixed type.
    """
    return get_providers("chat")
||||
|
||||
def get_embedding_providers():
    """Return the embedding-capable provider list for UI dropdowns.

    Thin wrapper over helpers.providers.get_providers with a fixed type.
    """
    return get_providers("embedding")
||||
|
||||
def get_missing_api_key_providers(agent=None) -> list[dict]:
    """Return the configured non-local providers whose API key is missing.

    Each entry is {"model_type": <label>, "provider": <provider name>}.
    """
    cfg = get_config(agent)

    # Local backends that never require an API key.
    local_providers = {"ollama", "lm_studio"}
    local_embedding = {"huggingface"}

    missing = []
    for label, section_key in (
        ("Chat Model", "chat_model"),
        ("Utility Model", "utility_model"),
        ("Embedding Model", "embedding_model"),
    ):
        provider = cfg.get(section_key, {}).get("provider", "")
        if not provider:
            continue

        lowered = provider.lower()
        if lowered in local_providers:
            continue
        if label == "Embedding Model" and lowered in local_embedding:
            continue

        # Empty, whitespace-only, or the literal string "None" counts as unset.
        api_key = models.get_api_key(lowered)
        if not (api_key and api_key.strip() and api_key != "None"):
            missing.append({"model_type": label, "provider": provider})

    return missing
|
||||
8
plugins/_model_config/hooks.py
Normal file
8
plugins/_model_config/hooks.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
def save_plugin_config(result=None, settings=None, **kwargs):
    """Strip transient UI-only fields from *settings* before persisting.

    Mutates *settings* in place and returns it; non-dict/empty values are
    returned untouched.
    """
    if not (settings and isinstance(settings, dict)):
        return settings

    settings.pop("_browser_headers_text", None)
    for section_name in ("chat_model", "utility_model", "embedding_model"):
        section = settings.get(section_name)
        if isinstance(section, dict):
            section.pop("_kwargs_text", None)
    return settings
9
plugins/_model_config/plugin.yaml
Normal file
9
plugins/_model_config/plugin.yaml
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
name: _model_config
|
||||
title: Model Configuration
|
||||
description: Manages LLM model selection and configuration for chat, utility, and embedding models. Supports per-project and per-agent overrides with optional per-chat model switching.
|
||||
version: 1.0.0
|
||||
always_enabled: true
|
||||
settings_sections:
|
||||
- agent
|
||||
per_project_config: true
|
||||
per_agent_config: true
|
||||
13
preload.py
13
preload.py
|
|
@ -18,16 +18,17 @@ async def preload():
|
|||
|
||||
# preload embedding model
|
||||
async def preload_embedding():
|
||||
if set["embed_model_provider"].lower() == "huggingface":
|
||||
try:
|
||||
# Use the new LiteLLM-based model system
|
||||
try:
|
||||
from plugins._model_config.helpers.model_config import get_embedding_model_config_object
|
||||
emb_cfg = get_embedding_model_config_object()
|
||||
if emb_cfg.provider.lower() == "huggingface":
|
||||
emb_mod = models.get_embedding_model(
|
||||
"huggingface", set["embed_model_name"]
|
||||
"huggingface", emb_cfg.name
|
||||
)
|
||||
emb_txt = await emb_mod.aembed_query("test")
|
||||
return emb_txt
|
||||
except Exception as e:
|
||||
PrintStyle().error(f"Error in preload_embedding: {e}")
|
||||
except Exception as e:
|
||||
PrintStyle().error(f"Error in preload_embedding: {e}")
|
||||
|
||||
# preload kokoro tts model if enabled
|
||||
async def preload_kokoro():
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue