mirror of https://github.com/Skyvern-AI/skyvern.git (synced 2025-09-02 10:41:04 +00:00)
Implement LLMRouter (#127)
commit 1c397a13af (parent c58aaba4bb)
3 changed files with 146 additions and 8 deletions
skyvern/forge/sdk/api/llm/api_handler_factory.py (path inferred from the module's contents):

```diff
@@ -1,3 +1,4 @@
+import dataclasses
 import json
 from typing import Any
 
@@ -7,8 +8,12 @@ import structlog
 from skyvern.forge import app
 from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
-from skyvern.forge.sdk.api.llm.exceptions import DuplicateCustomLLMProviderError, LLMProviderError
-from skyvern.forge.sdk.api.llm.models import LLMAPIHandler
+from skyvern.forge.sdk.api.llm.exceptions import (
+    DuplicateCustomLLMProviderError,
+    InvalidLLMConfigError,
+    LLMProviderError,
+)
+from skyvern.forge.sdk.api.llm.models import LLMAPIHandler, LLMRouterConfig
 from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, parse_api_response
 from skyvern.forge.sdk.artifact.models import ArtifactType
 from skyvern.forge.sdk.models import Step
 
@@ -20,10 +25,112 @@ LOG = structlog.get_logger()
 class LLMAPIHandlerFactory:
     _custom_handlers: dict[str, LLMAPIHandler] = {}
 
+    @staticmethod
+    def get_llm_api_handler_with_router(llm_key: str) -> LLMAPIHandler:
+        llm_config = LLMConfigRegistry.get_config(llm_key)
+        if not isinstance(llm_config, LLMRouterConfig):
+            raise InvalidLLMConfigError(llm_key)
+
+        router = litellm.Router(
+            model_list=[dataclasses.asdict(model) for model in llm_config.model_list],
+            redis_host=llm_config.redis_host,
+            redis_port=llm_config.redis_port,
+            routing_strategy=llm_config.routing_strategy,
+            fallbacks=[{llm_config.main_model_group: llm_config.fallback_model_group}]
+            if llm_config.fallback_model_group
+            else [],
+            num_retries=llm_config.num_retries,
+            retry_after=llm_config.retry_delay_seconds,
+            set_verbose=False if SettingsManager.get_settings().is_cloud_environment() else llm_config.set_verbose,
+        )
+        main_model_group = llm_config.main_model_group
+
+        async def llm_api_handler_with_router_and_fallback(
+            prompt: str,
+            step: Step | None = None,
+            screenshots: list[bytes] | None = None,
+            parameters: dict[str, Any] | None = None,
+        ) -> dict[str, Any]:
+            """
+            Custom LLM API handler that utilizes the LiteLLM router and fallbacks to OpenAI GPT-4 Vision.
+
+            Args:
+                prompt: The prompt to generate completions for.
+                step: The step object associated with the prompt.
+                screenshots: The screenshots associated with the prompt.
+                parameters: Additional parameters to be passed to the LLM router.
+
+            Returns:
+                The response from the LLM router.
+            """
+            if parameters is None:
+                parameters = LLMAPIHandlerFactory.get_api_parameters()
+
+            if step:
+                await app.ARTIFACT_MANAGER.create_artifact(
+                    step=step,
+                    artifact_type=ArtifactType.LLM_PROMPT,
+                    data=prompt.encode("utf-8"),
+                )
+                for screenshot in screenshots or []:
+                    await app.ARTIFACT_MANAGER.create_artifact(
+                        step=step,
+                        artifact_type=ArtifactType.SCREENSHOT_LLM,
+                        data=screenshot,
+                    )
+
+            messages = await llm_messages_builder(prompt, screenshots)
+            if step:
+                await app.ARTIFACT_MANAGER.create_artifact(
+                    step=step,
+                    artifact_type=ArtifactType.LLM_REQUEST,
+                    data=json.dumps(
+                        {
+                            "model": llm_key,
+                            "messages": messages,
+                            **parameters,
+                        }
+                    ).encode("utf-8"),
+                )
+            try:
+                response = await router.acompletion(model=main_model_group, messages=messages, **parameters)
+            except openai.OpenAIError as e:
+                raise LLMProviderError(llm_key) from e
+            except Exception as e:
+                LOG.exception("LLM request failed unexpectedly", llm_key=llm_key)
+                raise LLMProviderError(llm_key) from e
+
+            if step:
+                await app.ARTIFACT_MANAGER.create_artifact(
+                    step=step,
+                    artifact_type=ArtifactType.LLM_RESPONSE,
+                    data=response.model_dump_json(indent=2).encode("utf-8"),
+                )
+                llm_cost = litellm.completion_cost(completion_response=response)
+                await app.DATABASE.update_step(
+                    task_id=step.task_id,
+                    step_id=step.step_id,
+                    organization_id=step.organization_id,
+                    incremental_cost=llm_cost,
+                )
+            parsed_response = parse_api_response(response)
+            if step:
+                await app.ARTIFACT_MANAGER.create_artifact(
+                    step=step,
+                    artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
+                    data=json.dumps(parsed_response, indent=2).encode("utf-8"),
+                )
+            return parsed_response
+
+        return llm_api_handler_with_router_and_fallback
+
     @staticmethod
     def get_llm_api_handler(llm_key: str) -> LLMAPIHandler:
         llm_config = LLMConfigRegistry.get_config(llm_key)
 
+        if LLMConfigRegistry.is_router_config(llm_key):
+            return LLMAPIHandlerFactory.get_llm_api_handler_with_router(llm_key)
+
         async def llm_api_handler(
             prompt: str,
             step: Step | None = None,
```
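With the dispatch added to `get_llm_api_handler`, call sites keep asking the factory for a handler by key and the router is resolved transparently. A minimal usage sketch follows; the key name and prompt are made-up illustrations, not part of this commit:

```python
# Hypothetical usage sketch: "GPT4V_ROUTER" is a placeholder key and
# assumes a matching LLMRouterConfig was registered at startup.
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory


async def run_prompt() -> dict:
    # Because the key resolves to an LLMRouterConfig, the factory returns
    # the router-backed handler; the call site cannot tell the difference.
    handler = LLMAPIHandlerFactory.get_llm_api_handler("GPT4V_ROUTER")
    # With step=None the handler skips artifact creation and cost tracking.
    return await handler(prompt="Summarize the page in one sentence.")
```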
skyvern/forge/sdk/api/llm/config_registry.py (path inferred from the imports):

```diff
@@ -6,23 +6,27 @@ from skyvern.forge.sdk.api.llm.exceptions import (
     MissingLLMProviderEnvVarsError,
     NoProviderEnabledError,
 )
-from skyvern.forge.sdk.api.llm.models import LLMConfig
+from skyvern.forge.sdk.api.llm.models import LLMConfig, LLMRouterConfig
 from skyvern.forge.sdk.settings_manager import SettingsManager
 
 LOG = structlog.get_logger()
 
 
 class LLMConfigRegistry:
-    _configs: dict[str, LLMConfig] = {}
+    _configs: dict[str, LLMRouterConfig | LLMConfig] = {}
 
     @staticmethod
-    def validate_config(llm_key: str, config: LLMConfig) -> None:
+    def is_router_config(llm_key: str) -> bool:
+        return isinstance(LLMConfigRegistry.get_config(llm_key), LLMRouterConfig)
+
+    @staticmethod
+    def validate_config(llm_key: str, config: LLMRouterConfig | LLMConfig) -> None:
         missing_env_vars = config.get_missing_env_vars()
         if missing_env_vars:
             raise MissingLLMProviderEnvVarsError(llm_key, missing_env_vars)
 
     @classmethod
-    def register_config(cls, llm_key: str, config: LLMConfig) -> None:
+    def register_config(cls, llm_key: str, config: LLMRouterConfig | LLMConfig) -> None:
         if llm_key in cls._configs:
             raise DuplicateLLMConfigError(llm_key)
 
@@ -32,7 +36,7 @@ class LLMConfigRegistry:
         cls._configs[llm_key] = config
 
     @classmethod
-    def get_config(cls, llm_key: str) -> LLMConfig:
+    def get_config(cls, llm_key: str) -> LLMRouterConfig | LLMConfig:
         if llm_key not in cls._configs:
             raise InvalidLLMConfigError(llm_key)
 
```
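The registry now holds plain and router configs in one key space, and `is_router_config` is the check the factory branches on. A small sketch of that dispatch, assuming a config has already been registered for the key passed in:

```python
# Sketch only: llm_key must have been registered beforehand via
# LLMConfigRegistry.register_config; get_config raises
# InvalidLLMConfigError for unknown keys.
from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
from skyvern.forge.sdk.api.llm.models import LLMRouterConfig


def describe(llm_key: str) -> str:
    config = LLMConfigRegistry.get_config(llm_key)
    if LLMConfigRegistry.is_router_config(llm_key):
        # is_router_config is just an isinstance check on the stored config.
        assert isinstance(config, LLMRouterConfig)
        return f"{llm_key}: routed, main group '{config.main_model_group}'"
    return f"{llm_key}: single-provider config"
```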
skyvern/forge/sdk/api/llm/models.py (path inferred from the imports):

```diff
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Any, Awaitable, Protocol
+from typing import Any, Awaitable, Literal, Protocol
 
 from skyvern.forge.sdk.models import Step
 from skyvern.forge.sdk.settings_manager import SettingsManager
 
@@ -21,6 +21,33 @@ class LLMConfig:
         return missing_env_vars
 
 
+@dataclass(frozen=True)
+class LLMRouterModelConfig:
+    model_name: str
+    # https://litellm.vercel.app/docs/routing
+    litellm_params: dict[str, Any]
+    tpm: int | None = None
+    rpm: int | None = None
+
+
+@dataclass(frozen=True)
+class LLMRouterConfig(LLMConfig):
+    model_list: list[LLMRouterModelConfig]
+    redis_host: str
+    redis_port: int
+    main_model_group: str
+    fallback_model_group: str | None = None
+    routing_strategy: Literal[
+        "simple-shuffle",
+        "least-busy",
+        "usage-based-routing",
+        "latency-based-routing",
+    ] = "usage-based-routing"
+    num_retries: int = 2
+    retry_delay_seconds: int = 15
+    set_verbose: bool = True
+
+
 class LLMAPIHandler(Protocol):
     def __call__(
         self,
```
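Each `LLMRouterModelConfig` maps onto one litellm deployment entry: entries sharing a `model_name` form a model group the router balances across, and `dataclasses.asdict` converts them into the dict shape `litellm.Router` expects, exactly as `get_llm_api_handler_with_router` does above. A sketch with two assumed Azure deployments; endpoints, keys, and rate limits are placeholders:

```python
import dataclasses

from skyvern.forge.sdk.api.llm.models import LLMRouterModelConfig

# Two deployments sharing one model group ("azure-gpt4v"); the router
# spreads traffic across them per the configured routing_strategy.
# All endpoint values below are illustrative, not real credentials.
deployments = [
    LLMRouterModelConfig(
        model_name="azure-gpt4v",
        litellm_params={
            "model": "azure/gpt-4-vision",
            "api_base": "https://eastus.example.azure.com",
            "api_key": "<AZURE_API_KEY_EAST>",
        },
        tpm=40_000,
        rpm=240,
    ),
    LLMRouterModelConfig(
        model_name="azure-gpt4v",
        litellm_params={
            "model": "azure/gpt-4-vision",
            "api_base": "https://westus.example.azure.com",
            "api_key": "<AZURE_API_KEY_WEST>",
        },
        tpm=40_000,
        rpm=240,
    ),
]

# The same conversion get_llm_api_handler_with_router performs before
# handing model_list to litellm.Router.
model_list = [dataclasses.asdict(m) for m in deployments]
```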