# Skyvern/skyvern/forge/sdk/api/llm/models.py

from dataclasses import dataclass, field
from typing import Any, Awaitable, Literal, Optional, Protocol, TypedDict

from litellm import AllowedFailsPolicy

from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought
from skyvern.forge.sdk.settings_manager import SettingsManager


class LiteLLMParams(TypedDict):
    api_key: str | None
    api_version: str | None
    api_base: str | None
    model_info: dict[str, Any] | None
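
# A minimal sketch (placeholder values, not taken from Skyvern's config) of the
# dict shape this TypedDict describes; all four keys are required, so unused
# ones are set to None:
#
# azure_params: LiteLLMParams = {
#     "api_key": "sk-...",
#     "api_version": "2024-02-01",
#     "api_base": "https://my-endpoint.openai.azure.com",
#     "model_info": None,
# }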


@dataclass(frozen=True)
class LLMConfigBase:
    model_name: str
    required_env_vars: list[str]
    supports_vision: bool
    add_assistant_prefix: bool

    def get_missing_env_vars(self) -> list[str]:
        """Return the names of required settings that are unset or empty."""
        missing_env_vars = []
        for env_var in self.required_env_vars:
            env_var_value = getattr(SettingsManager.get_settings(), env_var, None)
            if not env_var_value:
                missing_env_vars.append(env_var)
        return missing_env_vars
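
# Note: entries in required_env_vars are resolved via getattr against
# SettingsManager.get_settings(), not os.environ, so each entry must match a
# settings attribute name exactly.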


@dataclass(frozen=True)
class LLMConfig(LLMConfigBase):
    litellm_params: Optional[LiteLLMParams] = field(default=None)
    max_output_tokens: int = SettingsManager.get_settings().LLM_CONFIG_MAX_TOKENS
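
# Hedged example of a single-model config; "gpt-4o" and "OPENAI_API_KEY" are
# illustrative values, not definitions from this file:
#
# config = LLMConfig(
#     model_name="gpt-4o",
#     required_env_vars=["OPENAI_API_KEY"],
#     supports_vision=True,
#     add_assistant_prefix=False,
# )
# missing = config.get_missing_env_vars()  # ["OPENAI_API_KEY"] if unset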


@dataclass(frozen=True)
class LLMRouterModelConfig:
    model_name: str
    # https://litellm.vercel.app/docs/routing
    litellm_params: dict[str, Any]
    model_info: dict[str, Any] = field(default_factory=dict)
    tpm: int | None = None
    rpm: int | None = None
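
# Hedged sketch of one deployment entry, mirroring litellm's model_list format
# (see the routing docs linked above); the endpoint values are placeholders:
#
# deployment = LLMRouterModelConfig(
#     model_name="azure-gpt-4o",
#     litellm_params={
#         "model": "azure/gpt-4o",
#         "api_key": "...",
#         "api_base": "https://my-endpoint.openai.azure.com",
#     },
#     tpm=100_000,
#     rpm=1_000,
# )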


@dataclass(frozen=True)
class LLMRouterConfig(LLMConfigBase):
    model_list: list[LLMRouterModelConfig]
    main_model_group: str
    # All three redis parameters are required. Even if there isn't a password, it should be an empty string.
    redis_host: str | None = None
    redis_port: int | None = None
    redis_password: str | None = None
    fallback_model_group: str | None = None
    routing_strategy: Literal[
        "simple-shuffle",
        "least-busy",
        "usage-based-routing",
        "usage-based-routing-v2",
        "latency-based-routing",
    ] = "usage-based-routing"
    num_retries: int = 1
    retry_delay_seconds: int = 15
    set_verbose: bool = False
    disable_cooldowns: bool | None = None
    allowed_fails: int | None = None
    allowed_fails_policy: AllowedFailsPolicy | None = None
    cooldown_time: float | None = None
    max_output_tokens: int = SettingsManager.get_settings().LLM_CONFIG_MAX_TOKENS
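
# Hedged example (placeholder values) routing the deployment sketched above;
# main_model_group names the model_name shared by the deployments litellm
# should load-balance across:
#
# router_config = LLMRouterConfig(
#     model_name="azure-gpt-4o",
#     required_env_vars=["AZURE_API_KEY"],
#     supports_vision=True,
#     add_assistant_prefix=False,
#     model_list=[deployment],
#     main_model_group="azure-gpt-4o",
#     redis_host="localhost",
#     redis_port=6379,
#     redis_password="",  # empty string rather than None, per the note above
# )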


class LLMAPIHandler(Protocol):
    def __call__(
        self,
        prompt: str,
        step: Step | None = None,
        observer_cruise: ObserverCruise | None = None,
        observer_thought: ObserverThought | None = None,
        screenshots: list[bytes] | None = None,
        parameters: dict[str, Any] | None = None,
    ) -> Awaitable[dict[str, Any]]: ...
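
# Illustrative stub (not Skyvern's real handler): any coroutine function with
# this signature satisfies LLMAPIHandler structurally, since calling it
# returns an Awaitable[dict[str, Any]]:
#
# async def echo_handler(
#     prompt: str,
#     step: Step | None = None,
#     observer_cruise: ObserverCruise | None = None,
#     observer_thought: ObserverThought | None = None,
#     screenshots: list[bytes] | None = None,
#     parameters: dict[str, Any] | None = None,
# ) -> dict[str, Any]:
#     return {"response": prompt}
#
# handler: LLMAPIHandler = echo_handler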