Implement LLMRouter (#127)

Kerem Yilmaz 2024-03-27 14:44:25 -07:00, committed via GitHub
parent c58aaba4bb
commit 1c397a13af
3 changed files with 146 additions and 8 deletions

View file

@@ -1,3 +1,4 @@
+import dataclasses
 import json
 from typing import Any
 
@@ -7,8 +8,12 @@ import structlog
 from skyvern.forge import app
 from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
-from skyvern.forge.sdk.api.llm.exceptions import DuplicateCustomLLMProviderError, LLMProviderError
-from skyvern.forge.sdk.api.llm.models import LLMAPIHandler
+from skyvern.forge.sdk.api.llm.exceptions import (
+    DuplicateCustomLLMProviderError,
+    InvalidLLMConfigError,
+    LLMProviderError,
+)
+from skyvern.forge.sdk.api.llm.models import LLMAPIHandler, LLMRouterConfig
 from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, parse_api_response
 from skyvern.forge.sdk.artifact.models import ArtifactType
 from skyvern.forge.sdk.models import Step
 
@@ -20,10 +25,112 @@ LOG = structlog.get_logger()
 class LLMAPIHandlerFactory:
     _custom_handlers: dict[str, LLMAPIHandler] = {}
 
+    @staticmethod
+    def get_llm_api_handler_with_router(llm_key: str) -> LLMAPIHandler:
+        llm_config = LLMConfigRegistry.get_config(llm_key)
+        if not isinstance(llm_config, LLMRouterConfig):
+            raise InvalidLLMConfigError(llm_key)
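+
+        # fallbacks is a list of {primary_group: fallback_group} mappings: when a
+        # completion against the main model group fails, the router retries it on
+        # the fallback group.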
+        router = litellm.Router(
+            model_list=[dataclasses.asdict(model) for model in llm_config.model_list],
+            redis_host=llm_config.redis_host,
+            redis_port=llm_config.redis_port,
+            routing_strategy=llm_config.routing_strategy,
+            fallbacks=[{llm_config.main_model_group: llm_config.fallback_model_group}]
+            if llm_config.fallback_model_group
+            else [],
+            num_retries=llm_config.num_retries,
+            retry_after=llm_config.retry_delay_seconds,
+            set_verbose=False if SettingsManager.get_settings().is_cloud_environment() else llm_config.set_verbose,
+        )
+        main_model_group = llm_config.main_model_group
+
+        async def llm_api_handler_with_router_and_fallback(
+            prompt: str,
+            step: Step | None = None,
+            screenshots: list[bytes] | None = None,
+            parameters: dict[str, Any] | None = None,
+        ) -> dict[str, Any]:
+            """
+            Custom LLM API handler that uses the LiteLLM router and falls back to OpenAI GPT-4 Vision.
+
+            Args:
+                prompt: The prompt to generate completions for.
+                step: The step object associated with the prompt.
+                screenshots: The screenshots associated with the prompt.
+                parameters: Additional parameters to be passed to the LLM router.
+
+            Returns:
+                The response from the LLM router.
+            """
+            if parameters is None:
+                parameters = LLMAPIHandlerFactory.get_api_parameters()
+
+            if step:
+                await app.ARTIFACT_MANAGER.create_artifact(
+                    step=step,
+                    artifact_type=ArtifactType.LLM_PROMPT,
+                    data=prompt.encode("utf-8"),
+                )
+                for screenshot in screenshots or []:
+                    await app.ARTIFACT_MANAGER.create_artifact(
+                        step=step,
+                        artifact_type=ArtifactType.SCREENSHOT_LLM,
+                        data=screenshot,
+                    )
+
+            messages = await llm_messages_builder(prompt, screenshots)
+            if step:
+                await app.ARTIFACT_MANAGER.create_artifact(
+                    step=step,
+                    artifact_type=ArtifactType.LLM_REQUEST,
+                    data=json.dumps(
+                        {
+                            "model": llm_key,
+                            "messages": messages,
+                            **parameters,
+                        }
+                    ).encode("utf-8"),
+                )
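+            # The router resolves main_model_group to a concrete deployment using the
+            # configured routing strategy, with retries and fallbacks applied by litellm.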
+            try:
+                response = await router.acompletion(model=main_model_group, messages=messages, **parameters)
+            except openai.OpenAIError as e:
+                raise LLMProviderError(llm_key) from e
+            except Exception as e:
+                LOG.exception("LLM request failed unexpectedly", llm_key=llm_key)
+                raise LLMProviderError(llm_key) from e
+
+            if step:
+                await app.ARTIFACT_MANAGER.create_artifact(
+                    step=step,
+                    artifact_type=ArtifactType.LLM_RESPONSE,
+                    data=response.model_dump_json(indent=2).encode("utf-8"),
+                )
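+                # litellm.completion_cost estimates the dollar cost of this completion,
+                # which is then accumulated onto the step record.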
+                llm_cost = litellm.completion_cost(completion_response=response)
+                await app.DATABASE.update_step(
+                    task_id=step.task_id,
+                    step_id=step.step_id,
+                    organization_id=step.organization_id,
+                    incremental_cost=llm_cost,
+                )
+
+            parsed_response = parse_api_response(response)
+            if step:
+                await app.ARTIFACT_MANAGER.create_artifact(
+                    step=step,
+                    artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
+                    data=json.dumps(parsed_response, indent=2).encode("utf-8"),
+                )
+
+            return parsed_response
+
+        return llm_api_handler_with_router_and_fallback
+
     @staticmethod
     def get_llm_api_handler(llm_key: str) -> LLMAPIHandler:
         llm_config = LLMConfigRegistry.get_config(llm_key)
+        if LLMConfigRegistry.is_router_config(llm_key):
+            return LLMAPIHandlerFactory.get_llm_api_handler_with_router(llm_key)
+
         async def llm_api_handler(
             prompt: str,
             step: Step | None = None,

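For illustration, here is a minimal sketch of obtaining and calling the router-backed handler. The "GPT4_ROUTER" key, the asyncio scaffolding, and the api_handler_factory module path are assumptions for this sketch, not part of the commit:

import asyncio

from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory


async def main() -> None:
    # Hypothetical llm_key, assumed registered with an LLMRouterConfig;
    # get_llm_api_handler dispatches router keys to the router-backed handler.
    handler = LLMAPIHandlerFactory.get_llm_api_handler("GPT4_ROUTER")
    # With no step, the handler skips artifact creation and cost tracking.
    result = await handler("Extract the page title.")
    print(result)


asyncio.run(main())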
View file

@@ -6,23 +6,27 @@ from skyvern.forge.sdk.api.llm.exceptions import (
     MissingLLMProviderEnvVarsError,
     NoProviderEnabledError,
 )
-from skyvern.forge.sdk.api.llm.models import LLMConfig
+from skyvern.forge.sdk.api.llm.models import LLMConfig, LLMRouterConfig
 from skyvern.forge.sdk.settings_manager import SettingsManager
 
 LOG = structlog.get_logger()
 
 
 class LLMConfigRegistry:
-    _configs: dict[str, LLMConfig] = {}
+    _configs: dict[str, LLMRouterConfig | LLMConfig] = {}
 
     @staticmethod
-    def validate_config(llm_key: str, config: LLMConfig) -> None:
+    def is_router_config(llm_key: str) -> bool:
+        return isinstance(LLMConfigRegistry.get_config(llm_key), LLMRouterConfig)
+
+    @staticmethod
+    def validate_config(llm_key: str, config: LLMRouterConfig | LLMConfig) -> None:
         missing_env_vars = config.get_missing_env_vars()
         if missing_env_vars:
             raise MissingLLMProviderEnvVarsError(llm_key, missing_env_vars)
 
     @classmethod
-    def register_config(cls, llm_key: str, config: LLMConfig) -> None:
+    def register_config(cls, llm_key: str, config: LLMRouterConfig | LLMConfig) -> None:
         if llm_key in cls._configs:
             raise DuplicateLLMConfigError(llm_key)
@@ -32,7 +36,7 @@ class LLMConfigRegistry:
         cls._configs[llm_key] = config
 
     @classmethod
-    def get_config(cls, llm_key: str) -> LLMConfig:
+    def get_config(cls, llm_key: str) -> LLMRouterConfig | LLMConfig:
         if llm_key not in cls._configs:
             raise InvalidLLMConfigError(llm_key)
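
A brief sketch of how the new is_router_config check drives handler selection. The llm_key value and the api_handler_factory module path are assumptions for this sketch:

from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry

llm_key = "GPT4_ROUTER"  # hypothetical key, assumed already registered
if LLMConfigRegistry.is_router_config(llm_key):
    # An LLMRouterConfig entry gets the litellm.Router-backed handler.
    handler = LLMAPIHandlerFactory.get_llm_api_handler_with_router(llm_key)
else:
    # A plain LLMConfig entry gets the direct provider handler.
    handler = LLMAPIHandlerFactory.get_llm_api_handler(llm_key)

In practice callers can simply use get_llm_api_handler, which performs this dispatch internally.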

View file

@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Any, Awaitable, Protocol
+from typing import Any, Awaitable, Literal, Protocol
 
 from skyvern.forge.sdk.models import Step
 from skyvern.forge.sdk.settings_manager import SettingsManager
@@ -21,6 +21,33 @@ class LLMConfig:
         return missing_env_vars
 
 
+@dataclass(frozen=True)
+class LLMRouterModelConfig:
+    model_name: str
+    # https://litellm.vercel.app/docs/routing
+    litellm_params: dict[str, Any]
+    tpm: int | None = None
+    rpm: int | None = None
+
+
+@dataclass(frozen=True)
+class LLMRouterConfig(LLMConfig):
+    model_list: list[LLMRouterModelConfig]
+    redis_host: str
+    redis_port: int
+    main_model_group: str
+    fallback_model_group: str | None = None
+    routing_strategy: Literal[
+        "simple-shuffle",
+        "least-busy",
+        "usage-based-routing",
+        "latency-based-routing",
+    ] = "usage-based-routing"
+    num_retries: int = 2
+    retry_delay_seconds: int = 15
+    set_verbose: bool = True
+
+
 class LLMAPIHandler(Protocol):
     def __call__(
         self,
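
As a sketch of the data this dataclass carries, one LLMRouterModelConfig entry might look like the following. The deployment name, key, and endpoint are invented placeholders, and the litellm_params keys (model, api_key, api_base) follow litellm's routing documentation linked in the comment above:

import dataclasses

from skyvern.forge.sdk.api.llm.models import LLMRouterModelConfig

model = LLMRouterModelConfig(
    model_name="gpt-4-vision",  # the model group this deployment belongs to
    litellm_params={
        "model": "azure/gpt-4-vision",  # provider/deployment identifier
        "api_key": "<AZURE_OPENAI_API_KEY>",
        "api_base": "https://example.openai.azure.com",
    },
    tpm=80_000,
    rpm=480,
)

# dataclasses.asdict(model) produces the dict shape that
# LLMAPIHandlerFactory passes into litellm.Router(model_list=...).
print(dataclasses.asdict(model))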