diff --git a/models.py b/models.py index 3606d3b02..02997e1c5 100644 --- a/models.py +++ b/models.py @@ -18,6 +18,7 @@ from litellm import completion, acompletion, embedding import litellm from python.helpers import dotenv +from python.helpers import settings from python.helpers.dotenv import load_dotenv from python.helpers.providers import get_provider_config from python.helpers.rate_limiter import RateLimiter @@ -434,7 +435,18 @@ class LocalSentenceTransformerWrapper(Embeddings): if model.startswith("sentence-transformers/"): model = model[len("sentence-transformers/") :] - self.model = SentenceTransformer(model, **kwargs) + # Filter kwargs for SentenceTransformer only (no LiteLLM params like 'stream_timeout') + st_allowed_keys = { + "device", + "cache_folder", + "use_auth_token", + "revision", + "trust_remote_code", + "model_kwargs", + } + st_kwargs = {k: v for k, v in (kwargs or {}).items() if k in st_allowed_keys} + + self.model = SentenceTransformer(model, **st_kwargs) self.model_name = model self.a0_model_conf = model_config @@ -542,6 +554,22 @@ def _adjust_call_args(provider_name: str, model_name: str, kwargs: dict): def _merge_provider_defaults( provider_type: str, original_provider: str, kwargs: dict ) -> tuple[str, dict]: + # Normalize .env-style numeric strings (e.g., "timeout=30") into ints/floats for LiteLLM + def _normalize_values(values: dict) -> dict: + result: dict[str, Any] = {} + for k, v in values.items(): + if isinstance(v, str): + try: + result[k] = int(v) + except ValueError: + try: + result[k] = float(v) + except ValueError: + result[k] = v + else: + result[k] = v + return result + provider_name = original_provider # default: unchanged cfg = get_provider_config(provider_type, original_provider) if cfg: @@ -559,6 +587,15 @@ def _merge_provider_defaults( if key and key not in ("None", "NA"): kwargs["api_key"] = key + # Merge LiteLLM global kwargs (timeouts, stream_timeout, etc.) + try: + global_kwargs = settings.get_settings().get("litellm_global_kwargs", {}) # type: ignore[union-attr] + except Exception: + global_kwargs = {} + if isinstance(global_kwargs, dict): + for k, v in _normalize_values(global_kwargs).items(): + kwargs.setdefault(k, v) + return provider_name, kwargs diff --git a/python/helpers/settings.py b/python/helpers/settings.py index 4fb19ed6e..eb346ab2c 100644 --- a/python/helpers/settings.py +++ b/python/helpers/settings.py @@ -106,6 +106,9 @@ class Settings(TypedDict): variables: str secrets: str + # LiteLLM global kwargs applied to all model calls + litellm_global_kwargs: dict[str, str] + class PartialSettings(Settings, total=False): pass @@ -583,6 +586,28 @@ def convert_out(settings: Settings) -> SettingsOutput: "tab": "external", } + # LiteLLM global config section + litellm_fields: list[SettingsField] = [] + + litellm_fields.append( + { + "id": "litellm_global_kwargs", + "title": "LiteLLM global parameters", + "description": "Global LiteLLM params (e.g. timeout, stream_timeout) in .env format: one KEY=VALUE per line. Example: stream_timeout=30. Applied to all LiteLLM calls unless overridden. See LiteLLM and timeouts.", + "type": "textarea", + "value": _dict_to_env(settings["litellm_global_kwargs"]), + "style": "height: 12em", + } + ) + + litellm_section: SettingsSection = { + "id": "litellm", + "title": "LiteLLM Global Settings", + "description": "Configure global parameters passed to LiteLLM for all providers.", + "fields": litellm_fields, + "tab": "external", + } + # Agent config section agent_fields: list[SettingsField] = [] @@ -1215,6 +1240,7 @@ def convert_out(settings: Settings) -> SettingsOutput: memory_section, speech_section, api_keys_section, + litellm_section, secrets_section, auth_section, mcp_client_section, @@ -1452,6 +1478,7 @@ def get_default_settings() -> Settings: a2a_server_enabled=False, variables="", secrets="", + litellm_global_kwargs={}, ) diff --git a/webui/public/litellm.svg b/webui/public/litellm.svg new file mode 100644 index 000000000..14a43d01d --- /dev/null +++ b/webui/public/litellm.svg @@ -0,0 +1,11 @@ + + + + + + + + + + +