Implement LLMRouter (#127)

Kerem Yilmaz 2024-03-27 14:44:25 -07:00 committed by GitHub
parent c58aaba4bb
commit 1c397a13af
3 changed files with 146 additions and 8 deletions

skyvern/forge/sdk/api/llm/api_handler_factory.py

@@ -1,3 +1,4 @@
import dataclasses
import json
from typing import Any
@@ -7,8 +8,12 @@ import structlog
from skyvern.forge import app
from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
from skyvern.forge.sdk.api.llm.exceptions import (
    DuplicateCustomLLMProviderError,
    InvalidLLMConfigError,
    LLMProviderError,
)
from skyvern.forge.sdk.api.llm.models import LLMAPIHandler, LLMRouterConfig
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, parse_api_response
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.models import Step
@@ -20,10 +25,112 @@ LOG = structlog.get_logger()
class LLMAPIHandlerFactory:
    _custom_handlers: dict[str, LLMAPIHandler] = {}
    @staticmethod
    def get_llm_api_handler_with_router(llm_key: str) -> LLMAPIHandler:
        llm_config = LLMConfigRegistry.get_config(llm_key)
        if not isinstance(llm_config, LLMRouterConfig):
            raise InvalidLLMConfigError(llm_key)

        router = litellm.Router(
            model_list=[dataclasses.asdict(model) for model in llm_config.model_list],
            redis_host=llm_config.redis_host,
            redis_port=llm_config.redis_port,
            routing_strategy=llm_config.routing_strategy,
            fallbacks=[{llm_config.main_model_group: llm_config.fallback_model_group}]
            if llm_config.fallback_model_group
            else [],
            num_retries=llm_config.num_retries,
            retry_after=llm_config.retry_delay_seconds,
            set_verbose=False if SettingsManager.get_settings().is_cloud_environment() else llm_config.set_verbose,
        )
        main_model_group = llm_config.main_model_group
        async def llm_api_handler_with_router_and_fallback(
            prompt: str,
            step: Step | None = None,
            screenshots: list[bytes] | None = None,
            parameters: dict[str, Any] | None = None,
        ) -> dict[str, Any]:
            """
            Custom LLM API handler that uses the LiteLLM router and falls back to OpenAI GPT-4 Vision.

            Args:
                prompt: The prompt to generate completions for.
                step: The step object associated with the prompt.
                screenshots: The screenshots associated with the prompt.
                parameters: Additional parameters to be passed to the LLM router.

            Returns:
                The response from the LLM router.
            """
            if parameters is None:
                parameters = LLMAPIHandlerFactory.get_api_parameters()

            if step:
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.LLM_PROMPT,
                    data=prompt.encode("utf-8"),
                )
                for screenshot in screenshots or []:
                    await app.ARTIFACT_MANAGER.create_artifact(
                        step=step,
                        artifact_type=ArtifactType.SCREENSHOT_LLM,
                        data=screenshot,
                    )

            messages = await llm_messages_builder(prompt, screenshots)
            if step:
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.LLM_REQUEST,
                    data=json.dumps(
                        {
                            "model": llm_key,
                            "messages": messages,
                            **parameters,
                        }
                    ).encode("utf-8"),
                )

            try:
                response = await router.acompletion(model=main_model_group, messages=messages, **parameters)
            except openai.OpenAIError as e:
                raise LLMProviderError(llm_key) from e
            except Exception as e:
                LOG.exception("LLM request failed unexpectedly", llm_key=llm_key)
                raise LLMProviderError(llm_key) from e

            if step:
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.LLM_RESPONSE,
                    data=response.model_dump_json(indent=2).encode("utf-8"),
                )
                llm_cost = litellm.completion_cost(completion_response=response)
                await app.DATABASE.update_step(
                    task_id=step.task_id,
                    step_id=step.step_id,
                    organization_id=step.organization_id,
                    incremental_cost=llm_cost,
                )

            parsed_response = parse_api_response(response)
            if step:
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
                    data=json.dumps(parsed_response, indent=2).encode("utf-8"),
                )
            return parsed_response

        return llm_api_handler_with_router_and_fallback
    @staticmethod
    def get_llm_api_handler(llm_key: str) -> LLMAPIHandler:
        llm_config = LLMConfigRegistry.get_config(llm_key)
        if LLMConfigRegistry.is_router_config(llm_key):
            return LLMAPIHandlerFactory.get_llm_api_handler_with_router(llm_key)

        async def llm_api_handler(
            prompt: str,
            step: Step | None = None,
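Taken together, callers keep resolving handlers through get_llm_api_handler, which now transparently returns the router-backed handler for router keys. A minimal usage sketch, assuming the factory lives in skyvern.forge.sdk.api.llm.api_handler_factory and that a hypothetical key "GPT4V_ROUTER" has been registered as an LLMRouterConfig:

import asyncio

from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory


async def main() -> None:
    # "GPT4V_ROUTER" is a placeholder key; it must map to a registered LLMRouterConfig.
    handler = LLMAPIHandlerFactory.get_llm_api_handler("GPT4V_ROUTER")
    # With step=None the handler skips artifact creation and cost tracking,
    # performing only the routed completion call.
    parsed = await handler(prompt="Summarize this page as JSON.")
    print(parsed)


asyncio.run(main())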

skyvern/forge/sdk/api/llm/config_registry.py

@@ -6,23 +6,27 @@ from skyvern.forge.sdk.api.llm.exceptions import (
    MissingLLMProviderEnvVarsError,
    NoProviderEnabledError,
)
from skyvern.forge.sdk.api.llm.models import LLMConfig, LLMRouterConfig
from skyvern.forge.sdk.settings_manager import SettingsManager

LOG = structlog.get_logger()


class LLMConfigRegistry:
    _configs: dict[str, LLMRouterConfig | LLMConfig] = {}

    @staticmethod
    def is_router_config(llm_key: str) -> bool:
        return isinstance(LLMConfigRegistry.get_config(llm_key), LLMRouterConfig)

    @staticmethod
    def validate_config(llm_key: str, config: LLMRouterConfig | LLMConfig) -> None:
        missing_env_vars = config.get_missing_env_vars()
        if missing_env_vars:
            raise MissingLLMProviderEnvVarsError(llm_key, missing_env_vars)

    @classmethod
    def register_config(cls, llm_key: str, config: LLMRouterConfig | LLMConfig) -> None:
        if llm_key in cls._configs:
            raise DuplicateLLMConfigError(llm_key)
@@ -32,7 +36,7 @@ class LLMConfigRegistry:
        cls._configs[llm_key] = config

    @classmethod
    def get_config(cls, llm_key: str) -> LLMRouterConfig | LLMConfig:
        if llm_key not in cls._configs:
            raise InvalidLLMConfigError(llm_key)
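A sketch of the intended registry flow; the helper below and its llm_key are hypothetical, and the LLMRouterConfig is taken as already constructed since LLMConfig's base fields are not shown in this diff:

from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
from skyvern.forge.sdk.api.llm.models import LLMRouterConfig


def register_router_config(llm_key: str, config: LLMRouterConfig) -> None:
    # register_config raises DuplicateLLMConfigError for a repeated key;
    # validate_config separately reports missing provider env vars.
    LLMConfigRegistry.register_config(llm_key, config)
    # The factory dispatches on this flag to choose the router-backed handler
    # over the plain single-model handler.
    assert LLMConfigRegistry.is_router_config(llm_key)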

skyvern/forge/sdk/api/llm/models.py

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any, Awaitable, Literal, Protocol

from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.settings_manager import SettingsManager
@@ -21,6 +21,33 @@ class LLMConfig:
        return missing_env_vars


@dataclass(frozen=True)
class LLMRouterModelConfig:
    model_name: str
    # https://litellm.vercel.app/docs/routing
    litellm_params: dict[str, Any]
    tpm: int | None = None
    rpm: int | None = None


@dataclass(frozen=True)
class LLMRouterConfig(LLMConfig):
    model_list: list[LLMRouterModelConfig]
    redis_host: str
    redis_port: int
    main_model_group: str
    fallback_model_group: str | None = None
    routing_strategy: Literal[
        "simple-shuffle",
        "least-busy",
        "usage-based-routing",
        "latency-based-routing",
    ] = "usage-based-routing"
    num_retries: int = 2
    retry_delay_seconds: int = 15
    set_verbose: bool = True
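For illustration, the dict shape one model_list entry takes once get_llm_api_handler_with_router serializes it with dataclasses.asdict; every value below is a placeholder assumption, not something defined in this commit:

import dataclasses

from skyvern.forge.sdk.api.llm.models import LLMRouterModelConfig

# Hypothetical entry: an Azure GPT-4 Vision deployment in the "gpt-4-vision"
# model group. All names, keys, and limits are placeholders.
entry = LLMRouterModelConfig(
    model_name="gpt-4-vision",
    litellm_params={
        "model": "azure/gpt-4-vision-deployment",
        "api_key": "<placeholder>",
    },
    tpm=100_000,
    rpm=1_000,
)
print(dataclasses.asdict(entry))
# {'model_name': 'gpt-4-vision', 'litellm_params': {...}, 'tpm': 100000, 'rpm': 1000}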
class LLMAPIHandler(Protocol):
    def __call__(
        self,