Jon/model name massage (#2559)

This commit is contained in:
Shuchang Zheng 2025-05-31 19:34:30 -07:00 committed by GitHub
parent b4d5837196
commit 2167d88c20
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 67 additions and 60 deletions

View file

@ -393,7 +393,7 @@ export type CreditCardCredential = {
}; };
export type ModelsResponse = { export type ModelsResponse = {
models: string[]; models: Record<string, string>;
}; };
export const RunEngine = { export const RunEngine = {

View file

@ -41,8 +41,17 @@ function ModelSelector({
}, },
}); });
const models = availableModels?.models ?? []; const models = availableModels?.models ?? {};
const choices = [constants.SkyvernOptimized, ...models]; const reverseMap = Object.entries(models).reduce(
(acc, [key, value]) => {
acc[value] = key;
return acc;
},
{} as Record<string, string>,
);
const labels = Object.keys(reverseMap);
const chosen = value ? models[value.model_name] : constants.SkyvernOptimized;
const choices = [constants.SkyvernOptimized, ...labels];
return ( return (
<div className="flex items-center justify-between"> <div className="flex items-center justify-between">
@ -52,10 +61,13 @@ function ModelSelector({
</div> </div>
<div className="relative flex items-center"> <div className="relative flex items-center">
<Select <Select
value={value?.model ?? ""} value={chosen}
onValueChange={(v) => { onValueChange={(v) => {
const newValue = v === constants.SkyvernOptimized ? null : v; const newValue = v === constants.SkyvernOptimized ? null : v;
onChange(newValue ? { model: newValue } : null); const modelName = newValue ? reverseMap[newValue] : null;
const value = modelName ? { model_name: modelName } : null;
console.log({ v, newValue, modelName, value });
onChange(value);
}} }}
> >
<SelectTrigger <SelectTrigger

View file

@ -29,7 +29,7 @@ export const navigationNodeDefaultData: NavigationNodeData = {
completeCriterion: "", completeCriterion: "",
terminateCriterion: "", terminateCriterion: "",
errorCodeMapping: "null", errorCodeMapping: "null",
model: { model: "" }, model: { model_name: "" },
engine: RunEngine.SkyvernV1, engine: RunEngine.SkyvernV1,
maxRetries: null, maxRetries: null,
maxStepsOverride: null, maxStepsOverride: null,

View file

@ -19,6 +19,7 @@ import { Switch } from "@/components/ui/switch";
import { Separator } from "@/components/ui/separator"; import { Separator } from "@/components/ui/separator";
import { ModelsResponse } from "@/api/types"; import { ModelsResponse } from "@/api/types";
import { ModelSelector } from "@/components/ModelSelector"; import { ModelSelector } from "@/components/ModelSelector";
import { WorkflowModel } from "@/routes/workflows/types/workflowTypes";
function StartNode({ id, data }: NodeProps<StartNode>) { function StartNode({ id, data }: NodeProps<StartNode>) {
const credentialGetter = useCredentialGetter(); const credentialGetter = useCredentialGetter();
@ -33,7 +34,11 @@ function StartNode({ id, data }: NodeProps<StartNode>) {
}, },
}); });
const models = availableModels?.models ?? []; const modelNames = availableModels?.models ?? {};
const firstKey = Object.keys(modelNames)[0];
const workflowModel: WorkflowModel | null = firstKey
? { model_name: modelNames[firstKey] || "" }
: null;
const [inputs, setInputs] = useState({ const [inputs, setInputs] = useState({
webhookCallbackUrl: data.withWorkflowSettings webhookCallbackUrl: data.withWorkflowSettings
@ -45,7 +50,7 @@ function StartNode({ id, data }: NodeProps<StartNode>) {
persistBrowserSession: data.withWorkflowSettings persistBrowserSession: data.withWorkflowSettings
? data.persistBrowserSession ? data.persistBrowserSession
: false, : false,
model: data.withWorkflowSettings ? data.model : { model: models[0] || "" }, model: data.withWorkflowSettings ? data.model : workflowModel,
}); });
function handleChange(key: string, value: unknown) { function handleChange(key: string, value: unknown) {

View file

@ -466,7 +466,7 @@ export type WorkflowSettings = {
model: WorkflowModel | null; model: WorkflowModel | null;
}; };
export type WorkflowModel = JsonObjectExtendable<{ model: string }>; export type WorkflowModel = JsonObjectExtendable<{ model_name: string }>;
export function isOutputParameter( export function isOutputParameter(
parameter: Parameter, parameter: Parameter,

View file

@ -264,7 +264,7 @@ class Settings(BaseSettings):
SKYVERN_BASE_URL: str = "https://api.skyvern.com" SKYVERN_BASE_URL: str = "https://api.skyvern.com"
SKYVERN_API_KEY: str = "PLACEHOLDER" SKYVERN_API_KEY: str = "PLACEHOLDER"
def get_model_name_to_llm_key(self) -> dict[str, str]: def get_model_name_to_llm_key(self) -> dict[str, dict[str, str]]:
""" """
Keys are model names available to blocks in the frontend. These map to key names Keys are model names available to blocks in the frontend. These map to key names
in LLMConfigRegistry._configs. in LLMConfigRegistry._configs.
@ -272,22 +272,40 @@ class Settings(BaseSettings):
if self.is_cloud_environment(): if self.is_cloud_environment():
return { return {
"Gemini 2.5": "GEMINI_2.5_PRO_PREVIEW", "gemini-2.5-pro-preview-05-06": {"llm_key": "VERTEX_GEMINI_2.5_PRO_PREVIEW", "label": "Gemini 2.5 Pro"},
"Gemini 2.5 Flash": "VERTEX_GEMINI_2.5_FLASH_PREVIEW_05_20", "gemini-2.5-flash-preview-05-20": {
"GPT 4.1": "OPENAI_GPT4_1", "llm_key": "VERTEX_GEMINI_2.5_FLASH_PREVIEW_05_20",
"GPT o3-mini": "OPENAI_O3_MINI", "label": "Gemini 2.5 Flash",
"bedrock/us.anthropic.claude-opus-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_OPUS_INFERENCE_PROFILE", },
"bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_SONNET_INFERENCE_PROFILE", "azure/gpt-4.1": {"llm_key": "AZURE_OPENAI_GPT4_1", "label": "GPT 4.1"},
"azure/o3-mini": {"llm_key": "AZURE_OPENAI_O3_MINI", "label": "GPT O3 Mini"},
"us.anthropic.claude-opus-4-20250514-v1:0": {
"llm_key": "BEDROCK_ANTHROPIC_CLAUDE4_OPUS_INFERENCE_PROFILE",
"label": "Anthropic Claude 4 Opus",
},
"us.anthropic.claude-sonnet-4-20250514-v1:0": {
"llm_key": "BEDROCK_ANTHROPIC_CLAUDE4_SONNET_INFERENCE_PROFILE",
"label": "Anthropic Claude 4 Sonnet",
},
} }
else: else:
# TODO: apparently the list for OSS is to be much larger # TODO: apparently the list for OSS is to be much larger
return { return {
"Gemini 2.5": "GEMINI_2.5_PRO_PREVIEW", "gemini-2.5-pro-preview-05-06": {"llm_key": "VERTEX_GEMINI_2.5_PRO_PREVIEW", "label": "Gemini 2.5 Pro"},
"Gemini 2.5 Flash": "VERTEX_GEMINI_2.5_FLASH_PREVIEW_05_20", "gemini-2.5-flash-preview-05-20": {
"GPT 4.1": "OPENAI_GPT4_1", "llm_key": "VERTEX_GEMINI_2.5_FLASH_PREVIEW_05_20",
"GPT o3-mini": "OPENAI_O3_MINI", "label": "Gemini 2.5 Flash",
"bedrock/us.anthropic.claude-opus-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_OPUS_INFERENCE_PROFILE", },
"bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_SONNET_INFERENCE_PROFILE", "azure/gpt-4.1": {"llm_key": "AZURE_OPENAI_GPT4_1", "label": "GPT 4.1"},
"azure/o3-mini": {"llm_key": "AZURE_OPENAI_O3_MINI", "label": "GPT O3 Mini"},
"us.anthropic.claude-opus-4-20250514-v1:0": {
"llm_key": "BEDROCK_ANTHROPIC_CLAUDE4_OPUS_INFERENCE_PROFILE",
"label": "Anthropic Claude 4 Opus",
},
"us.anthropic.claude-sonnet-4-20250514-v1:0": {
"llm_key": "BEDROCK_ANTHROPIC_CLAUDE4_SONNET_INFERENCE_PROFILE",
"label": "Anthropic Claude 4 Sonnet",
},
} }
def is_cloud_environment(self) -> bool: def is_cloud_environment(self) -> bool:

View file

@ -867,12 +867,13 @@ class ForgeAgent:
else: else:
if engine in CUA_ENGINES: if engine in CUA_ENGINES:
self.async_operation_pool.run_operation(task.task_id, AgentPhase.llm) self.async_operation_pool.run_operation(task.task_id, AgentPhase.llm)
json_response = await app.LLM_API_HANDLER( json_response = await app.LLM_API_HANDLER(
prompt=extract_action_prompt, prompt=extract_action_prompt,
prompt_name="extract-actions", prompt_name="extract-actions",
step=step, step=step,
screenshots=scraped_page.screenshots, screenshots=scraped_page.screenshots,
llm_key_override=llm_caller.llm_key if llm_caller else None, llm_key_override=task.llm_key,
) )
try: try:
json_response = await self.handle_potential_verification_code( json_response = await self.handle_potential_verification_code(

View file

@ -810,9 +810,9 @@ async def models() -> ModelsResponse:
Get a list of available models. Get a list of available models.
""" """
mapping = settings.get_model_name_to_llm_key() mapping = settings.get_model_name_to_llm_key()
models = list(mapping.keys()) just_labels = {k: v["label"] for k, v in mapping.items() if "anthropic" not in k.lower()}
return ModelsResponse(models=models) return ModelsResponse(models=just_labels)
@legacy_base_router.post( @legacy_base_router.post(

View file

@ -57,10 +57,10 @@ class TaskV2(BaseModel):
""" """
if self.model: if self.model:
model_name = self.model.get("name") model_name = self.model.get("model_name")
if model_name: if model_name:
mapping = settings.get_model_name_to_llm_key() mapping = settings.get_model_name_to_llm_key()
llm_key = mapping.get(model_name) llm_key = mapping.get(model_name, {}).get("llm_key")
if llm_key: if llm_key:
return llm_key return llm_key

View file

@ -248,10 +248,10 @@ class Task(TaskBase):
Otherwise return `None`. Otherwise return `None`.
""" """
if self.model: if self.model:
model_name = self.model.get("name") model_name = self.model.get("model_name")
if model_name: if model_name:
mapping = settings.get_model_name_to_llm_key() mapping = settings.get_model_name_to_llm_key()
return mapping.get(model_name) return mapping.get(model_name, {}).get("llm_key")
return None return None
@ -370,4 +370,4 @@ class SortDirection(StrEnum):
class ModelsResponse(BaseModel): class ModelsResponse(BaseModel):
models: list[str] models: dict[str, str]

View file

@ -5,11 +5,10 @@ from typing import Any, List
from pydantic import BaseModel, field_validator from pydantic import BaseModel, field_validator
from typing_extensions import deprecated from typing_extensions import deprecated
from skyvern.config import settings
from skyvern.forge.sdk.schemas.files import FileInfo from skyvern.forge.sdk.schemas.files import FileInfo
from skyvern.forge.sdk.schemas.task_v2 import TaskV2 from skyvern.forge.sdk.schemas.task_v2 import TaskV2
from skyvern.forge.sdk.workflow.exceptions import WorkflowDefinitionHasDuplicateBlockLabels from skyvern.forge.sdk.workflow.exceptions import WorkflowDefinitionHasDuplicateBlockLabels
from skyvern.forge.sdk.workflow.models.block import Block, BlockTypeVar from skyvern.forge.sdk.workflow.models.block import BlockTypeVar
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE
from skyvern.schemas.runs import ProxyLocation from skyvern.schemas.runs import ProxyLocation
from skyvern.utils.url_validators import validate_url from skyvern.utils.url_validators import validate_url
@ -82,34 +81,6 @@ class Workflow(BaseModel):
modified_at: datetime modified_at: datetime
deleted_at: datetime | None = None deleted_at: datetime | None = None
def determine_llm_key(self, *, block: Block | None = None) -> str | None:
"""
Determine the LLM key override to use for a block, if it has one.
It has one if:
- it defines one, or
- the workflow it is a part of (if applicable) defines one
"""
mapping = settings.get_model_name_to_llm_key()
if block:
model_name = (block.model or {}).get("name")
if model_name:
llm_key = mapping.get(model_name)
if llm_key:
return llm_key
workflow_model_name = (self.model or {}).get("name")
if workflow_model_name:
llm_key = mapping.get(workflow_model_name)
if llm_key:
return llm_key
return None
class WorkflowRunStatus(StrEnum): class WorkflowRunStatus(StrEnum):
created = "created" created = "created"