def get_model_name_to_llm_key(self) -> dict[str, str]:
    """
    Return the mapping from user-facing model names to LLM config keys.

    Keys are model names available to blocks in the frontend. Values are the key
    names they map to in LLMConfigRegistry._configs.

    :return: a freshly built dict on every call
    """
    # Cloud and OSS currently expose the same models, so the table is written once
    # instead of being duplicated per branch, where the two copies could drift apart.
    shared_models = {
        "Gemini 2.5": "GEMINI_2.5_PRO_PREVIEW",
        "Gemini 2.5 Flash": "VERTEX_GEMINI_2.5_FLASH_PREVIEW_05_20",
        "GPT 4.1": "OPENAI_GPT4_1",
        "GPT o3-mini": "OPENAI_O3_MINI",
    }

    if self.is_cloud_environment():
        return shared_models

    # TODO: apparently the list for OSS is to be much larger
    # OSS-only entries should be added to this copy so the cloud list stays untouched.
    oss_models = dict(shared_models)
    return oss_models
@classmethod
def get_model_names(cls) -> list[str]:
    """
    Return the keys of every config currently held in ``cls._configs``.

    :return: a new list of the registered keys, in the mapping's iteration order
    """
    # Iterating a mapping yields its keys, so no explicit .keys() view is needed.
    registered_keys = [*cls._configs]
    return registered_keys
@legacy_base_router.get(
    "/models",
    tags=["agent"],
    openapi_extra={
        "x-fern-sdk-group-name": "agent",
    },
)
@legacy_base_router.get("/models/", include_in_schema=False)
async def models() -> ModelsResponse:
    """
    Get a list of available models.
    """
    # Only the names (the mapping's keys) are exposed; the LLM keys they map to
    # are left out of the response.
    available_names = [*settings.get_model_name_to_llm_key()]

    return ModelsResponse(models=available_names)
a/skyvern/forge/sdk/workflow/models/block.py +++ b/skyvern/forge/sdk/workflow/models/block.py @@ -46,7 +46,7 @@ from skyvern.forge.sdk.api.files import ( download_from_s3, get_path_for_workflow_download_directory, ) -from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory +from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory, LLMCaller from skyvern.forge.sdk.artifact.models import ArtifactType from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.db.enums import TaskType @@ -126,6 +126,7 @@ class Block(BaseModel, abc.ABC): block_type: BlockType output_parameter: OutputParameter continue_on_failure: bool = False + model: dict[str, Any] | None = None async def record_output_parameter_value( self, @@ -618,6 +619,9 @@ class BaseTaskBlock(Block): try: current_context = skyvern_context.ensure_context() current_context.task_id = task.task_id + llm_key = workflow.determine_llm_key(block=self) + llm_caller = None if not llm_key else LLMCaller(llm_key=llm_key) + await app.agent.execute_step( organization=organization, task=task, @@ -627,6 +631,7 @@ class BaseTaskBlock(Block): close_browser_on_completion=browser_session_id is None, complete_verification=self.complete_verification, engine=self.engine, + llm_caller=llm_caller, ) except Exception as e: # Make sure the task is marked as failed in the database before raising the exception diff --git a/skyvern/forge/sdk/workflow/models/workflow.py b/skyvern/forge/sdk/workflow/models/workflow.py index f8d4a466..61d55b42 100644 --- a/skyvern/forge/sdk/workflow/models/workflow.py +++ b/skyvern/forge/sdk/workflow/models/workflow.py @@ -5,10 +5,11 @@ from typing import Any, List from pydantic import BaseModel, field_validator from typing_extensions import deprecated +from skyvern.config import settings from skyvern.forge.sdk.schemas.files import FileInfo from skyvern.forge.sdk.schemas.task_v2 import TaskV2 from skyvern.forge.sdk.workflow.exceptions import 
def determine_llm_key(self, *, block: Block | None = None) -> str | None:
    """
    Resolve the LLM key override that applies to a block, if any.

    The block's own ``model`` setting is consulted first, then the workflow's.
    A setting only counts when its ``"model"`` entry is a name present in
    ``settings.get_model_name_to_llm_key()``; otherwise the next candidate is tried.

    :param block: the block to resolve for; when omitted, only the workflow's
        setting is considered
    :return: the mapped LLM key, or None when neither level yields one
    """
    name_to_key = settings.get_model_name_to_llm_key()

    # Candidates in order of precedence: block-level first, workflow-level second.
    candidates = []
    if block:
        candidates.append(block.model)
    candidates.append(self.model)

    for model_setting in candidates:
        chosen_name = (model_setting or {}).get("model")
        if not chosen_name:
            continue

        resolved_key = name_to_key.get(chosen_name)
        if resolved_key:
            return resolved_key

    return None
None = None workflow_definition: WorkflowDefinitionYAML is_saved_task: bool = False status: WorkflowStatus = WorkflowStatus.published diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index 02811bb7..14314491 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -180,7 +180,7 @@ class WorkflowService: else: raise MissingValueForParameter( parameter_key=workflow_parameter.key, - workflow_id=workflow.workflow_permanent_id, + workflow_id=workflow.workflow_id, workflow_run_id=workflow_run.workflow_run_id, ) except Exception as e: @@ -336,6 +336,7 @@ class WorkflowService: block_idx=block_idx, block_type_var=block.block_type, block_label=block.label, + model=block.model, ) block_result = await block.execute_safe( workflow_run_id=workflow_run_id, @@ -550,6 +551,7 @@ class WorkflowService: totp_verification_url: str | None = None, totp_identifier: str | None = None, persist_browser_session: bool = False, + model: dict[str, Any] | None = None, workflow_permanent_id: str | None = None, version: int | None = None, is_saved_task: bool = False, @@ -565,6 +567,7 @@ class WorkflowService: totp_verification_url=totp_verification_url, totp_identifier=totp_identifier, persist_browser_session=persist_browser_session, + model=model, workflow_permanent_id=workflow_permanent_id, version=version, is_saved_task=is_saved_task, @@ -592,6 +595,7 @@ class WorkflowService: ) if not workflow: raise WorkflowNotFound(workflow_permanent_id=workflow_permanent_id, version=version) + return workflow async def get_workflows_by_permanent_ids( @@ -1403,6 +1407,7 @@ class WorkflowService: totp_verification_url=request.totp_verification_url, totp_identifier=request.totp_identifier, persist_browser_session=request.persist_browser_session, + model=request.model, workflow_permanent_id=workflow_permanent_id, version=existing_version + 1, is_saved_task=request.is_saved_task, @@ -1419,6 +1424,7 @@ class WorkflowService: 
totp_verification_url=request.totp_verification_url, totp_identifier=request.totp_identifier, persist_browser_session=request.persist_browser_session, + model=request.model, is_saved_task=request.is_saved_task, status=request.status, ) @@ -1657,6 +1663,7 @@ class WorkflowService: error_code_mapping=block_yaml.error_code_mapping, max_steps_per_run=block_yaml.max_steps_per_run, max_retries=block_yaml.max_retries, + model=block_yaml.model, complete_on_download=block_yaml.complete_on_download, download_suffix=block_yaml.download_suffix, continue_on_failure=block_yaml.continue_on_failure, @@ -1724,6 +1731,7 @@ class WorkflowService: json_schema=block_yaml.json_schema, output_parameter=output_parameter, continue_on_failure=block_yaml.continue_on_failure, + model=block_yaml.model, ) elif block_yaml.block_type == BlockType.DOWNLOAD_TO_S3: return DownloadToS3Block( @@ -1781,6 +1789,7 @@ class WorkflowService: file_url=block_yaml.file_url, json_schema=block_yaml.json_schema, continue_on_failure=block_yaml.continue_on_failure, + model=block_yaml.model, ) elif block_yaml.block_type == BlockType.VALIDATION: validation_block_parameters = ( @@ -1803,6 +1812,7 @@ class WorkflowService: continue_on_failure=block_yaml.continue_on_failure, # only need one step for validation block max_steps_per_run=1, + model=block_yaml.model, ) elif block_yaml.block_type == BlockType.ACTION: @@ -1826,6 +1836,7 @@ class WorkflowService: navigation_goal=block_yaml.navigation_goal, error_code_mapping=block_yaml.error_code_mapping, max_retries=block_yaml.max_retries, + model=block_yaml.model, complete_on_download=block_yaml.complete_on_download, download_suffix=block_yaml.download_suffix, continue_on_failure=block_yaml.continue_on_failure, @@ -1854,6 +1865,7 @@ class WorkflowService: error_code_mapping=block_yaml.error_code_mapping, max_steps_per_run=block_yaml.max_steps_per_run, max_retries=block_yaml.max_retries, + model=block_yaml.model, complete_on_download=block_yaml.complete_on_download, 
download_suffix=block_yaml.download_suffix, continue_on_failure=block_yaml.continue_on_failure, @@ -1883,6 +1895,7 @@ class WorkflowService: data_schema=block_yaml.data_schema, max_steps_per_run=block_yaml.max_steps_per_run, max_retries=block_yaml.max_retries, + model=block_yaml.model, continue_on_failure=block_yaml.continue_on_failure, cache_actions=block_yaml.cache_actions, complete_verification=False, @@ -1905,6 +1918,7 @@ class WorkflowService: error_code_mapping=block_yaml.error_code_mapping, max_steps_per_run=block_yaml.max_steps_per_run, max_retries=block_yaml.max_retries, + model=block_yaml.model, continue_on_failure=block_yaml.continue_on_failure, totp_verification_url=block_yaml.totp_verification_url, totp_identifier=block_yaml.totp_identifier, @@ -1942,6 +1956,7 @@ class WorkflowService: error_code_mapping=block_yaml.error_code_mapping, max_steps_per_run=block_yaml.max_steps_per_run, max_retries=block_yaml.max_retries, + model=block_yaml.model, download_suffix=block_yaml.download_suffix, continue_on_failure=block_yaml.continue_on_failure, totp_verification_url=block_yaml.totp_verification_url, @@ -1959,6 +1974,7 @@ class WorkflowService: totp_identifier=block_yaml.totp_identifier, max_iterations=block_yaml.max_iterations, max_steps=block_yaml.max_steps, + model=block_yaml.model, output_parameter=output_parameter, ) elif block_yaml.block_type == BlockType.GOTO_URL: