Implement LLMRouter (#127)

This commit is contained in:
Kerem Yilmaz 2024-03-27 14:44:25 -07:00 committed by GitHub
parent c58aaba4bb
commit 1c397a13af
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 146 additions and 8 deletions

View file

@ -6,23 +6,27 @@ from skyvern.forge.sdk.api.llm.exceptions import (
MissingLLMProviderEnvVarsError,
NoProviderEnabledError,
)
from skyvern.forge.sdk.api.llm.models import LLMConfig
from skyvern.forge.sdk.api.llm.models import LLMConfig, LLMRouterConfig
from skyvern.forge.sdk.settings_manager import SettingsManager
LOG = structlog.get_logger()
class LLMConfigRegistry:
_configs: dict[str, LLMConfig] = {}
_configs: dict[str, LLMRouterConfig | LLMConfig] = {}
@staticmethod
def validate_config(llm_key: str, config: LLMConfig) -> None:
def is_router_config(llm_key: str) -> bool:
return isinstance(LLMConfigRegistry.get_config(llm_key), LLMRouterConfig)
@staticmethod
def validate_config(llm_key: str, config: LLMRouterConfig | LLMConfig) -> None:
missing_env_vars = config.get_missing_env_vars()
if missing_env_vars:
raise MissingLLMProviderEnvVarsError(llm_key, missing_env_vars)
@classmethod
def register_config(cls, llm_key: str, config: LLMConfig) -> None:
def register_config(cls, llm_key: str, config: LLMRouterConfig | LLMConfig) -> None:
if llm_key in cls._configs:
raise DuplicateLLMConfigError(llm_key)
@ -32,7 +36,7 @@ class LLMConfigRegistry:
cls._configs[llm_key] = config
@classmethod
def get_config(cls, llm_key: str) -> LLMConfig:
def get_config(cls, llm_key: str) -> LLMRouterConfig | LLMConfig:
if llm_key not in cls._configs:
raise InvalidLLMConfigError(llm_key)