ModelSelector: add 'Skyvern Optimized' as hard choice (#2558)

This commit is contained in:
Shuchang Zheng 2025-05-31 18:42:57 -07:00 committed by GitHub
parent 800a26d323
commit b4d5837196
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 59 additions and 21 deletions

View file

@ -21,6 +21,10 @@ type Props = {
onChange: (value: WorkflowModel | null) => void; onChange: (value: WorkflowModel | null) => void;
}; };
const constants = {
SkyvernOptimized: "Skyvern Optimized",
} as const;
function ModelSelector({ function ModelSelector({
clearable = true, clearable = true,
value, value,
@ -38,6 +42,7 @@ function ModelSelector({
}); });
const models = availableModels?.models ?? []; const models = availableModels?.models ?? [];
const choices = [constants.SkyvernOptimized, ...models];
return ( return (
<div className="flex items-center justify-between"> <div className="flex items-center justify-between">
@ -49,18 +54,23 @@ function ModelSelector({
<Select <Select
value={value?.model ?? ""} value={value?.model ?? ""}
onValueChange={(v) => { onValueChange={(v) => {
onChange({ model: v }); const newValue = v === constants.SkyvernOptimized ? null : v;
onChange(newValue ? { model: newValue } : null);
}} }}
> >
<SelectTrigger <SelectTrigger
className={(className || "") + (value && clearable ? " pr-10" : "")} className={(className || "") + (value && clearable ? " pr-10" : "")}
> >
<SelectValue placeholder="Skyvern Optimized" /> <SelectValue placeholder={constants.SkyvernOptimized} />
</SelectTrigger> </SelectTrigger>
<SelectContent> <SelectContent>
{models.map((m) => ( {choices.map((m) => (
<SelectItem key={m} value={m}> <SelectItem key={m} value={m}>
{m} {m === constants.SkyvernOptimized ? (
<span>Skyvern Optimized </span>
) : (
m
)}
</SelectItem> </SelectItem>
))} ))}
</SelectContent> </SelectContent>

View file

@ -272,22 +272,22 @@ class Settings(BaseSettings):
if self.is_cloud_environment(): if self.is_cloud_environment():
return { return {
"gemini-2.5-pro-preview-05-06": "VERTEX_GEMINI_2.5_PRO_PREVIEW", "Gemini 2.5": "GEMINI_2.5_PRO_PREVIEW",
"gemini-2.5-flash-preview-05-20": "VERTEX_GEMINI_2.5_FLASH_PREVIEW_05_20", "Gemini 2.5 Flash": "VERTEX_GEMINI_2.5_FLASH_PREVIEW_05_20",
"azure/gpt-4.1": "AZURE_OPENAI_GPT4_1", "GPT 4.1": "OPENAI_GPT4_1",
"azure/o3-mini": "AZURE_OPENAI_O3_MINI", "GPT o3-mini": "OPENAI_O3_MINI",
"us.anthropic.claude-opus-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_OPUS_INFERENCE_PROFILE", "bedrock/us.anthropic.claude-opus-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_OPUS_INFERENCE_PROFILE",
"us.anthropic.claude-sonnet-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_SONNET_INFERENCE_PROFILE", "bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_SONNET_INFERENCE_PROFILE",
} }
else: else:
# TODO: apparently the list for OSS is to be much larger # TODO: apparently the list for OSS is to be much larger
return { return {
"gemini-2.5-pro-preview-05-06": "VERTEX_GEMINI_2.5_PRO_PREVIEW", "Gemini 2.5": "GEMINI_2.5_PRO_PREVIEW",
"gemini-2.5-flash-preview-05-20": "VERTEX_GEMINI_2.5_FLASH_PREVIEW_05_20", "Gemini 2.5 Flash": "VERTEX_GEMINI_2.5_FLASH_PREVIEW_05_20",
"azure/gpt-4.1": "AZURE_OPENAI_GPT4_1", "GPT 4.1": "OPENAI_GPT4_1",
"azure/o3-mini": "AZURE_OPENAI_O3_MINI", "GPT o3-mini": "OPENAI_O3_MINI",
"us.anthropic.claude-opus-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_OPUS_INFERENCE_PROFILE", "bedrock/us.anthropic.claude-opus-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_OPUS_INFERENCE_PROFILE",
"us.anthropic.claude-sonnet-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_SONNET_INFERENCE_PROFILE", "bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_SONNET_INFERENCE_PROFILE",
} }
def is_cloud_environment(self) -> bool: def is_cloud_environment(self) -> bool:

View file

@ -867,13 +867,12 @@ class ForgeAgent:
else: else:
if engine in CUA_ENGINES: if engine in CUA_ENGINES:
self.async_operation_pool.run_operation(task.task_id, AgentPhase.llm) self.async_operation_pool.run_operation(task.task_id, AgentPhase.llm)
json_response = await app.LLM_API_HANDLER( json_response = await app.LLM_API_HANDLER(
prompt=extract_action_prompt, prompt=extract_action_prompt,
prompt_name="extract-actions", prompt_name="extract-actions",
step=step, step=step,
screenshots=scraped_page.screenshots, screenshots=scraped_page.screenshots,
llm_key_override=task.llm_key, llm_key_override=llm_caller.llm_key if llm_caller else None,
) )
try: try:
json_response = await self.handle_potential_verification_code( json_response = await self.handle_potential_verification_code(

View file

@ -57,7 +57,7 @@ class TaskV2(BaseModel):
""" """
if self.model: if self.model:
model_name = self.model.get("model_name") model_name = self.model.get("name")
if model_name: if model_name:
mapping = settings.get_model_name_to_llm_key() mapping = settings.get_model_name_to_llm_key()
llm_key = mapping.get(model_name) llm_key = mapping.get(model_name)

View file

@ -248,7 +248,7 @@ class Task(TaskBase):
Otherwise return `None`. Otherwise return `None`.
""" """
if self.model: if self.model:
model_name = self.model.get("model_name") model_name = self.model.get("name")
if model_name: if model_name:
mapping = settings.get_model_name_to_llm_key() mapping = settings.get_model_name_to_llm_key()
return mapping.get(model_name) return mapping.get(model_name)

View file

@ -5,10 +5,11 @@ from typing import Any, List
from pydantic import BaseModel, field_validator from pydantic import BaseModel, field_validator
from typing_extensions import deprecated from typing_extensions import deprecated
from skyvern.config import settings
from skyvern.forge.sdk.schemas.files import FileInfo from skyvern.forge.sdk.schemas.files import FileInfo
from skyvern.forge.sdk.schemas.task_v2 import TaskV2 from skyvern.forge.sdk.schemas.task_v2 import TaskV2
from skyvern.forge.sdk.workflow.exceptions import WorkflowDefinitionHasDuplicateBlockLabels from skyvern.forge.sdk.workflow.exceptions import WorkflowDefinitionHasDuplicateBlockLabels
from skyvern.forge.sdk.workflow.models.block import BlockTypeVar from skyvern.forge.sdk.workflow.models.block import Block, BlockTypeVar
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE
from skyvern.schemas.runs import ProxyLocation from skyvern.schemas.runs import ProxyLocation
from skyvern.utils.url_validators import validate_url from skyvern.utils.url_validators import validate_url
@ -81,6 +82,34 @@ class Workflow(BaseModel):
modified_at: datetime modified_at: datetime
deleted_at: datetime | None = None deleted_at: datetime | None = None
def determine_llm_key(self, *, block: Block | None = None) -> str | None:
"""
Determine the LLM key override to use for a block, if it has one.
It has one if:
- it defines one, or
- the workflow it is a part of (if applicable) defines one
"""
mapping = settings.get_model_name_to_llm_key()
if block:
model_name = (block.model or {}).get("name")
if model_name:
llm_key = mapping.get(model_name)
if llm_key:
return llm_key
workflow_model_name = (self.model or {}).get("name")
if workflow_model_name:
llm_key = mapping.get(workflow_model_name)
if llm_key:
return llm_key
return None
class WorkflowRunStatus(StrEnum): class WorkflowRunStatus(StrEnum):
created = "created" created = "created"