mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2025-09-02 02:30:07 +00:00
Implement LLMRouter (#127)
This commit is contained in:
parent
c58aaba4bb
commit
1c397a13af
3 changed files with 146 additions and 8 deletions
|
@ -1,3 +1,4 @@
|
|||
import dataclasses
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
|
@ -7,8 +8,12 @@ import structlog
|
|||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
|
||||
from skyvern.forge.sdk.api.llm.exceptions import DuplicateCustomLLMProviderError, LLMProviderError
|
||||
from skyvern.forge.sdk.api.llm.models import LLMAPIHandler
|
||||
from skyvern.forge.sdk.api.llm.exceptions import (
|
||||
DuplicateCustomLLMProviderError,
|
||||
InvalidLLMConfigError,
|
||||
LLMProviderError,
|
||||
)
|
||||
from skyvern.forge.sdk.api.llm.models import LLMAPIHandler, LLMRouterConfig
|
||||
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, parse_api_response
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.models import Step
|
||||
|
@ -20,10 +25,112 @@ LOG = structlog.get_logger()
|
|||
class LLMAPIHandlerFactory:
|
||||
_custom_handlers: dict[str, LLMAPIHandler] = {}
|
||||
|
||||
@staticmethod
def get_llm_api_handler_with_router(llm_key: str) -> LLMAPIHandler:
    """Build an async LLM API handler backed by a litellm Router.

    Args:
        llm_key: Registry key; must resolve to an LLMRouterConfig.

    Returns:
        An async handler that sends completion requests through the litellm
        Router (load balancing, retries, and an optional fallback model group).

    Raises:
        InvalidLLMConfigError: if the config registered under ``llm_key`` is
            not a router config.
    """
    llm_config = LLMConfigRegistry.get_config(llm_key)
    if not isinstance(llm_config, LLMRouterConfig):
        raise InvalidLLMConfigError(llm_key)

    router = litellm.Router(
        model_list=[dataclasses.asdict(model) for model in llm_config.model_list],
        redis_host=llm_config.redis_host,
        redis_port=llm_config.redis_port,
        routing_strategy=llm_config.routing_strategy,
        # NOTE(review): litellm fallbacks are documented as
        # {"group": ["fallback_group", ...]} (list-valued) — confirm a bare
        # string value for fallback_model_group is accepted here.
        fallbacks=[{llm_config.main_model_group: llm_config.fallback_model_group}]
        if llm_config.fallback_model_group
        else [],
        num_retries=llm_config.num_retries,
        retry_after=llm_config.retry_delay_seconds,
        # Verbose router logging is always suppressed in cloud deployments.
        set_verbose=False if SettingsManager.get_settings().is_cloud_environment() else llm_config.set_verbose,
    )
    main_model_group = llm_config.main_model_group

    async def llm_api_handler_with_router_and_fallback(
        prompt: str,
        step: Step | None = None,
        screenshots: list[bytes] | None = None,
        parameters: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        Custom LLM API handler that utilizes the LiteLLM router and fallbacks to OpenAI GPT-4 Vision.

        Args:
            prompt: The prompt to generate completions for.
            step: The step object associated with the prompt.
            screenshots: The screenshots associated with the prompt.
            parameters: Additional parameters to be passed to the LLM router.

        Returns:
            The response from the LLM router.
        """
        if parameters is None:
            parameters = LLMAPIHandlerFactory.get_api_parameters()

        # Persist the raw prompt and screenshots as artifacts for debugging.
        if step:
            await app.ARTIFACT_MANAGER.create_artifact(
                step=step,
                artifact_type=ArtifactType.LLM_PROMPT,
                data=prompt.encode("utf-8"),
            )
            for screenshot in screenshots or []:
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.SCREENSHOT_LLM,
                    data=screenshot,
                )

        messages = await llm_messages_builder(prompt, screenshots)
        if step:
            # Record the full request payload before calling the router.
            await app.ARTIFACT_MANAGER.create_artifact(
                step=step,
                artifact_type=ArtifactType.LLM_REQUEST,
                data=json.dumps(
                    {
                        "model": llm_key,
                        "messages": messages,
                        **parameters,
                    }
                ).encode("utf-8"),
            )
        try:
            response = await router.acompletion(model=main_model_group, messages=messages, **parameters)
        except openai.OpenAIError as e:
            raise LLMProviderError(llm_key) from e
        except Exception as e:
            # Wrap anything unexpected so callers only ever see LLMProviderError.
            LOG.exception("LLM request failed unexpectedly", llm_key=llm_key)
            raise LLMProviderError(llm_key) from e

        if step:
            await app.ARTIFACT_MANAGER.create_artifact(
                step=step,
                artifact_type=ArtifactType.LLM_RESPONSE,
                data=response.model_dump_json(indent=2).encode("utf-8"),
            )
            # Attribute this call's spend to the step's running cost.
            llm_cost = litellm.completion_cost(completion_response=response)
            await app.DATABASE.update_step(
                task_id=step.task_id,
                step_id=step.step_id,
                organization_id=step.organization_id,
                incremental_cost=llm_cost,
            )
        parsed_response = parse_api_response(response)
        if step:
            await app.ARTIFACT_MANAGER.create_artifact(
                step=step,
                artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
                data=json.dumps(parsed_response, indent=2).encode("utf-8"),
            )
        return parsed_response

    return llm_api_handler_with_router_and_fallback
|
||||
|
||||
@staticmethod
|
||||
def get_llm_api_handler(llm_key: str) -> LLMAPIHandler:
|
||||
llm_config = LLMConfigRegistry.get_config(llm_key)
|
||||
|
||||
if LLMConfigRegistry.is_router_config(llm_key):
|
||||
return LLMAPIHandlerFactory.get_llm_api_handler_with_router(llm_key)
|
||||
|
||||
async def llm_api_handler(
|
||||
prompt: str,
|
||||
step: Step | None = None,
|
||||
|
|
|
@ -6,23 +6,27 @@ from skyvern.forge.sdk.api.llm.exceptions import (
|
|||
MissingLLMProviderEnvVarsError,
|
||||
NoProviderEnabledError,
|
||||
)
|
||||
from skyvern.forge.sdk.api.llm.models import LLMConfig
|
||||
from skyvern.forge.sdk.api.llm.models import LLMConfig, LLMRouterConfig
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class LLMConfigRegistry:
|
||||
_configs: dict[str, LLMConfig] = {}
|
||||
_configs: dict[str, LLMRouterConfig | LLMConfig] = {}
|
||||
|
||||
@staticmethod
def is_router_config(llm_key: str) -> bool:
    """Tell whether ``llm_key`` resolves to a router-based configuration.

    Raises whatever ``get_config`` raises for unknown keys
    (InvalidLLMConfigError).
    """
    config = LLMConfigRegistry.get_config(llm_key)
    return isinstance(config, LLMRouterConfig)
|
||||
|
||||
@staticmethod
def validate_config(llm_key: str, config: LLMRouterConfig | LLMConfig) -> None:
    """Check that every environment variable the config needs is present.

    Raises:
        MissingLLMProviderEnvVarsError: if any required env var is unset.
    """
    if missing_env_vars := config.get_missing_env_vars():
        raise MissingLLMProviderEnvVarsError(llm_key, missing_env_vars)
|
||||
|
||||
@classmethod
|
||||
def register_config(cls, llm_key: str, config: LLMConfig) -> None:
|
||||
def register_config(cls, llm_key: str, config: LLMRouterConfig | LLMConfig) -> None:
|
||||
if llm_key in cls._configs:
|
||||
raise DuplicateLLMConfigError(llm_key)
|
||||
|
||||
|
@ -32,7 +36,7 @@ class LLMConfigRegistry:
|
|||
cls._configs[llm_key] = config
|
||||
|
||||
@classmethod
|
||||
def get_config(cls, llm_key: str) -> LLMConfig:
|
||||
def get_config(cls, llm_key: str) -> LLMRouterConfig | LLMConfig:
|
||||
if llm_key not in cls._configs:
|
||||
raise InvalidLLMConfigError(llm_key)
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Protocol
|
||||
from typing import Any, Awaitable, Literal, Protocol
|
||||
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
@ -21,6 +21,33 @@ class LLMConfig:
|
|||
return missing_env_vars
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LLMRouterModelConfig:
    """One deployment entry in a litellm Router ``model_list``.

    Instances are converted with ``dataclasses.asdict`` before being handed
    to ``litellm.Router``, so field names must match litellm's expectations.
    """

    # Model group name used by the router for routing/fallback decisions.
    model_name: str
    # Passed through verbatim to litellm.
    # https://litellm.vercel.app/docs/routing
    litellm_params: dict[str, Any]
    # Optional tokens-per-minute limit for this deployment.
    tpm: int | None = None
    # Optional requests-per-minute limit for this deployment.
    rpm: int | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LLMRouterConfig(LLMConfig):
    """LLMConfig variant that drives a litellm Router-backed handler.

    Registered in LLMConfigRegistry alongside plain LLMConfig entries;
    the handler factory dispatches on this type.
    """

    # Deployments the router load-balances across.
    model_list: list[LLMRouterModelConfig]
    # Redis connection used by litellm for cross-process usage tracking.
    redis_host: str
    redis_port: int
    # Model group requests are sent to first.
    main_model_group: str
    # Optional group to fall back to when the main group fails.
    fallback_model_group: str | None = None
    # Strategy names accepted by litellm.Router.
    routing_strategy: Literal[
        "simple-shuffle",
        "least-busy",
        "usage-based-routing",
        "latency-based-routing",
    ] = "usage-based-routing"
    num_retries: int = 2
    # Seconds to wait between retries (mapped to Router's retry_after).
    retry_delay_seconds: int = 15
    # Verbose router logging; forced to False in cloud environments by the
    # handler factory regardless of this value.
    set_verbose: bool = True
|
||||
|
||||
|
||||
class LLMAPIHandler(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
|
|
Loading…
Add table
Reference in a new issue