mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2025-09-15 09:49:46 +00:00
Jon/model name massage (#2559)
This commit is contained in:
parent
b4d5837196
commit
2167d88c20
11 changed files with 67 additions and 60 deletions
|
@ -393,7 +393,7 @@ export type CreditCardCredential = {
|
||||||
};
|
};
|
||||||
|
|
||||||
export type ModelsResponse = {
|
export type ModelsResponse = {
|
||||||
models: string[];
|
models: Record<string, string>;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const RunEngine = {
|
export const RunEngine = {
|
||||||
|
|
|
@ -41,8 +41,17 @@ function ModelSelector({
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
const models = availableModels?.models ?? [];
|
const models = availableModels?.models ?? {};
|
||||||
const choices = [constants.SkyvernOptimized, ...models];
|
const reverseMap = Object.entries(models).reduce(
|
||||||
|
(acc, [key, value]) => {
|
||||||
|
acc[value] = key;
|
||||||
|
return acc;
|
||||||
|
},
|
||||||
|
{} as Record<string, string>,
|
||||||
|
);
|
||||||
|
const labels = Object.keys(reverseMap);
|
||||||
|
const chosen = value ? models[value.model_name] : constants.SkyvernOptimized;
|
||||||
|
const choices = [constants.SkyvernOptimized, ...labels];
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="flex items-center justify-between">
|
<div className="flex items-center justify-between">
|
||||||
|
@ -52,10 +61,13 @@ function ModelSelector({
|
||||||
</div>
|
</div>
|
||||||
<div className="relative flex items-center">
|
<div className="relative flex items-center">
|
||||||
<Select
|
<Select
|
||||||
value={value?.model ?? ""}
|
value={chosen}
|
||||||
onValueChange={(v) => {
|
onValueChange={(v) => {
|
||||||
const newValue = v === constants.SkyvernOptimized ? null : v;
|
const newValue = v === constants.SkyvernOptimized ? null : v;
|
||||||
onChange(newValue ? { model: newValue } : null);
|
const modelName = newValue ? reverseMap[newValue] : null;
|
||||||
|
const value = modelName ? { model_name: modelName } : null;
|
||||||
|
console.log({ v, newValue, modelName, value });
|
||||||
|
onChange(value);
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<SelectTrigger
|
<SelectTrigger
|
||||||
|
|
|
@ -29,7 +29,7 @@ export const navigationNodeDefaultData: NavigationNodeData = {
|
||||||
completeCriterion: "",
|
completeCriterion: "",
|
||||||
terminateCriterion: "",
|
terminateCriterion: "",
|
||||||
errorCodeMapping: "null",
|
errorCodeMapping: "null",
|
||||||
model: { model: "" },
|
model: { model_name: "" },
|
||||||
engine: RunEngine.SkyvernV1,
|
engine: RunEngine.SkyvernV1,
|
||||||
maxRetries: null,
|
maxRetries: null,
|
||||||
maxStepsOverride: null,
|
maxStepsOverride: null,
|
||||||
|
|
|
@ -19,6 +19,7 @@ import { Switch } from "@/components/ui/switch";
|
||||||
import { Separator } from "@/components/ui/separator";
|
import { Separator } from "@/components/ui/separator";
|
||||||
import { ModelsResponse } from "@/api/types";
|
import { ModelsResponse } from "@/api/types";
|
||||||
import { ModelSelector } from "@/components/ModelSelector";
|
import { ModelSelector } from "@/components/ModelSelector";
|
||||||
|
import { WorkflowModel } from "@/routes/workflows/types/workflowTypes";
|
||||||
|
|
||||||
function StartNode({ id, data }: NodeProps<StartNode>) {
|
function StartNode({ id, data }: NodeProps<StartNode>) {
|
||||||
const credentialGetter = useCredentialGetter();
|
const credentialGetter = useCredentialGetter();
|
||||||
|
@ -33,7 +34,11 @@ function StartNode({ id, data }: NodeProps<StartNode>) {
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
const models = availableModels?.models ?? [];
|
const modelNames = availableModels?.models ?? {};
|
||||||
|
const firstKey = Object.keys(modelNames)[0];
|
||||||
|
const workflowModel: WorkflowModel | null = firstKey
|
||||||
|
? { model_name: modelNames[firstKey] || "" }
|
||||||
|
: null;
|
||||||
|
|
||||||
const [inputs, setInputs] = useState({
|
const [inputs, setInputs] = useState({
|
||||||
webhookCallbackUrl: data.withWorkflowSettings
|
webhookCallbackUrl: data.withWorkflowSettings
|
||||||
|
@ -45,7 +50,7 @@ function StartNode({ id, data }: NodeProps<StartNode>) {
|
||||||
persistBrowserSession: data.withWorkflowSettings
|
persistBrowserSession: data.withWorkflowSettings
|
||||||
? data.persistBrowserSession
|
? data.persistBrowserSession
|
||||||
: false,
|
: false,
|
||||||
model: data.withWorkflowSettings ? data.model : { model: models[0] || "" },
|
model: data.withWorkflowSettings ? data.model : workflowModel,
|
||||||
});
|
});
|
||||||
|
|
||||||
function handleChange(key: string, value: unknown) {
|
function handleChange(key: string, value: unknown) {
|
||||||
|
|
|
@ -466,7 +466,7 @@ export type WorkflowSettings = {
|
||||||
model: WorkflowModel | null;
|
model: WorkflowModel | null;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type WorkflowModel = JsonObjectExtendable<{ model: string }>;
|
export type WorkflowModel = JsonObjectExtendable<{ model_name: string }>;
|
||||||
|
|
||||||
export function isOutputParameter(
|
export function isOutputParameter(
|
||||||
parameter: Parameter,
|
parameter: Parameter,
|
||||||
|
|
|
@ -264,7 +264,7 @@ class Settings(BaseSettings):
|
||||||
SKYVERN_BASE_URL: str = "https://api.skyvern.com"
|
SKYVERN_BASE_URL: str = "https://api.skyvern.com"
|
||||||
SKYVERN_API_KEY: str = "PLACEHOLDER"
|
SKYVERN_API_KEY: str = "PLACEHOLDER"
|
||||||
|
|
||||||
def get_model_name_to_llm_key(self) -> dict[str, str]:
|
def get_model_name_to_llm_key(self) -> dict[str, dict[str, str]]:
|
||||||
"""
|
"""
|
||||||
Keys are model names available to blocks in the frontend. These map to key names
|
Keys are model names available to blocks in the frontend. These map to key names
|
||||||
in LLMConfigRegistry._configs.
|
in LLMConfigRegistry._configs.
|
||||||
|
@ -272,22 +272,40 @@ class Settings(BaseSettings):
|
||||||
|
|
||||||
if self.is_cloud_environment():
|
if self.is_cloud_environment():
|
||||||
return {
|
return {
|
||||||
"Gemini 2.5": "GEMINI_2.5_PRO_PREVIEW",
|
"gemini-2.5-pro-preview-05-06": {"llm_key": "VERTEX_GEMINI_2.5_PRO_PREVIEW", "label": "Gemini 2.5 Pro"},
|
||||||
"Gemini 2.5 Flash": "VERTEX_GEMINI_2.5_FLASH_PREVIEW_05_20",
|
"gemini-2.5-flash-preview-05-20": {
|
||||||
"GPT 4.1": "OPENAI_GPT4_1",
|
"llm_key": "VERTEX_GEMINI_2.5_FLASH_PREVIEW_05_20",
|
||||||
"GPT o3-mini": "OPENAI_O3_MINI",
|
"label": "Gemini 2.5 Flash",
|
||||||
"bedrock/us.anthropic.claude-opus-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_OPUS_INFERENCE_PROFILE",
|
},
|
||||||
"bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_SONNET_INFERENCE_PROFILE",
|
"azure/gpt-4.1": {"llm_key": "AZURE_OPENAI_GPT4_1", "label": "GPT 4.1"},
|
||||||
|
"azure/o3-mini": {"llm_key": "AZURE_OPENAI_O3_MINI", "label": "GPT O3 Mini"},
|
||||||
|
"us.anthropic.claude-opus-4-20250514-v1:0": {
|
||||||
|
"llm_key": "BEDROCK_ANTHROPIC_CLAUDE4_OPUS_INFERENCE_PROFILE",
|
||||||
|
"label": "Anthropic Claude 4 Opus",
|
||||||
|
},
|
||||||
|
"us.anthropic.claude-sonnet-4-20250514-v1:0": {
|
||||||
|
"llm_key": "BEDROCK_ANTHROPIC_CLAUDE4_SONNET_INFERENCE_PROFILE",
|
||||||
|
"label": "Anthropic Claude 4 Sonnet",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
# TODO: apparently the list for OSS is to be much larger
|
# TODO: apparently the list for OSS is to be much larger
|
||||||
return {
|
return {
|
||||||
"Gemini 2.5": "GEMINI_2.5_PRO_PREVIEW",
|
"gemini-2.5-pro-preview-05-06": {"llm_key": "VERTEX_GEMINI_2.5_PRO_PREVIEW", "label": "Gemini 2.5 Pro"},
|
||||||
"Gemini 2.5 Flash": "VERTEX_GEMINI_2.5_FLASH_PREVIEW_05_20",
|
"gemini-2.5-flash-preview-05-20": {
|
||||||
"GPT 4.1": "OPENAI_GPT4_1",
|
"llm_key": "VERTEX_GEMINI_2.5_FLASH_PREVIEW_05_20",
|
||||||
"GPT o3-mini": "OPENAI_O3_MINI",
|
"label": "Gemini 2.5 Flash",
|
||||||
"bedrock/us.anthropic.claude-opus-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_OPUS_INFERENCE_PROFILE",
|
},
|
||||||
"bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_SONNET_INFERENCE_PROFILE",
|
"azure/gpt-4.1": {"llm_key": "AZURE_OPENAI_GPT4_1", "label": "GPT 4.1"},
|
||||||
|
"azure/o3-mini": {"llm_key": "AZURE_OPENAI_O3_MINI", "label": "GPT O3 Mini"},
|
||||||
|
"us.anthropic.claude-opus-4-20250514-v1:0": {
|
||||||
|
"llm_key": "BEDROCK_ANTHROPIC_CLAUDE4_OPUS_INFERENCE_PROFILE",
|
||||||
|
"label": "Anthropic Claude 4 Opus",
|
||||||
|
},
|
||||||
|
"us.anthropic.claude-sonnet-4-20250514-v1:0": {
|
||||||
|
"llm_key": "BEDROCK_ANTHROPIC_CLAUDE4_SONNET_INFERENCE_PROFILE",
|
||||||
|
"label": "Anthropic Claude 4 Sonnet",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def is_cloud_environment(self) -> bool:
|
def is_cloud_environment(self) -> bool:
|
||||||
|
|
|
@ -867,12 +867,13 @@ class ForgeAgent:
|
||||||
else:
|
else:
|
||||||
if engine in CUA_ENGINES:
|
if engine in CUA_ENGINES:
|
||||||
self.async_operation_pool.run_operation(task.task_id, AgentPhase.llm)
|
self.async_operation_pool.run_operation(task.task_id, AgentPhase.llm)
|
||||||
|
|
||||||
json_response = await app.LLM_API_HANDLER(
|
json_response = await app.LLM_API_HANDLER(
|
||||||
prompt=extract_action_prompt,
|
prompt=extract_action_prompt,
|
||||||
prompt_name="extract-actions",
|
prompt_name="extract-actions",
|
||||||
step=step,
|
step=step,
|
||||||
screenshots=scraped_page.screenshots,
|
screenshots=scraped_page.screenshots,
|
||||||
llm_key_override=llm_caller.llm_key if llm_caller else None,
|
llm_key_override=task.llm_key,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
json_response = await self.handle_potential_verification_code(
|
json_response = await self.handle_potential_verification_code(
|
||||||
|
|
|
@ -810,9 +810,9 @@ async def models() -> ModelsResponse:
|
||||||
Get a list of available models.
|
Get a list of available models.
|
||||||
"""
|
"""
|
||||||
mapping = settings.get_model_name_to_llm_key()
|
mapping = settings.get_model_name_to_llm_key()
|
||||||
models = list(mapping.keys())
|
just_labels = {k: v["label"] for k, v in mapping.items() if "anthropic" not in k.lower()}
|
||||||
|
|
||||||
return ModelsResponse(models=models)
|
return ModelsResponse(models=just_labels)
|
||||||
|
|
||||||
|
|
||||||
@legacy_base_router.post(
|
@legacy_base_router.post(
|
||||||
|
|
|
@ -57,10 +57,10 @@ class TaskV2(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.model:
|
if self.model:
|
||||||
model_name = self.model.get("name")
|
model_name = self.model.get("model_name")
|
||||||
if model_name:
|
if model_name:
|
||||||
mapping = settings.get_model_name_to_llm_key()
|
mapping = settings.get_model_name_to_llm_key()
|
||||||
llm_key = mapping.get(model_name)
|
llm_key = mapping.get(model_name, {}).get("llm_key")
|
||||||
if llm_key:
|
if llm_key:
|
||||||
return llm_key
|
return llm_key
|
||||||
|
|
||||||
|
|
|
@ -248,10 +248,10 @@ class Task(TaskBase):
|
||||||
Otherwise return `None`.
|
Otherwise return `None`.
|
||||||
"""
|
"""
|
||||||
if self.model:
|
if self.model:
|
||||||
model_name = self.model.get("name")
|
model_name = self.model.get("model_name")
|
||||||
if model_name:
|
if model_name:
|
||||||
mapping = settings.get_model_name_to_llm_key()
|
mapping = settings.get_model_name_to_llm_key()
|
||||||
return mapping.get(model_name)
|
return mapping.get(model_name, {}).get("llm_key")
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -370,4 +370,4 @@ class SortDirection(StrEnum):
|
||||||
|
|
||||||
|
|
||||||
class ModelsResponse(BaseModel):
|
class ModelsResponse(BaseModel):
|
||||||
models: list[str]
|
models: dict[str, str]
|
||||||
|
|
|
@ -5,11 +5,10 @@ from typing import Any, List
|
||||||
from pydantic import BaseModel, field_validator
|
from pydantic import BaseModel, field_validator
|
||||||
from typing_extensions import deprecated
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
from skyvern.config import settings
|
|
||||||
from skyvern.forge.sdk.schemas.files import FileInfo
|
from skyvern.forge.sdk.schemas.files import FileInfo
|
||||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2
|
from skyvern.forge.sdk.schemas.task_v2 import TaskV2
|
||||||
from skyvern.forge.sdk.workflow.exceptions import WorkflowDefinitionHasDuplicateBlockLabels
|
from skyvern.forge.sdk.workflow.exceptions import WorkflowDefinitionHasDuplicateBlockLabels
|
||||||
from skyvern.forge.sdk.workflow.models.block import Block, BlockTypeVar
|
from skyvern.forge.sdk.workflow.models.block import BlockTypeVar
|
||||||
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE
|
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE
|
||||||
from skyvern.schemas.runs import ProxyLocation
|
from skyvern.schemas.runs import ProxyLocation
|
||||||
from skyvern.utils.url_validators import validate_url
|
from skyvern.utils.url_validators import validate_url
|
||||||
|
@ -82,34 +81,6 @@ class Workflow(BaseModel):
|
||||||
modified_at: datetime
|
modified_at: datetime
|
||||||
deleted_at: datetime | None = None
|
deleted_at: datetime | None = None
|
||||||
|
|
||||||
def determine_llm_key(self, *, block: Block | None = None) -> str | None:
|
|
||||||
"""
|
|
||||||
Determine the LLM key override to use for a block, if it has one.
|
|
||||||
|
|
||||||
It has one if:
|
|
||||||
- it defines one, or
|
|
||||||
- the workflow it is a part of (if applicable) defines one
|
|
||||||
"""
|
|
||||||
|
|
||||||
mapping = settings.get_model_name_to_llm_key()
|
|
||||||
|
|
||||||
if block:
|
|
||||||
model_name = (block.model or {}).get("name")
|
|
||||||
|
|
||||||
if model_name:
|
|
||||||
llm_key = mapping.get(model_name)
|
|
||||||
if llm_key:
|
|
||||||
return llm_key
|
|
||||||
|
|
||||||
workflow_model_name = (self.model or {}).get("name")
|
|
||||||
|
|
||||||
if workflow_model_name:
|
|
||||||
llm_key = mapping.get(workflow_model_name)
|
|
||||||
if llm_key:
|
|
||||||
return llm_key
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class WorkflowRunStatus(StrEnum):
|
class WorkflowRunStatus(StrEnum):
|
||||||
created = "created"
|
created = "created"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue