Stop MCP model-key hallucinations in workflow creation (#SKY-8302) (#5061)

This commit is contained in:
Marc Kelechava 2026-03-11 16:56:19 -07:00 committed by GitHub
parent af91183d75
commit 87f18f4ee0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 810 additions and 26 deletions

View file

@ -198,6 +198,7 @@ Validate blocks with skyvern_block_validate() before submitting.
Split workflows into multiple blocks — one block per logical step — rather than cramming everything into a single block.
Use **navigation** blocks for actions (filling forms, clicking buttons) and **extraction** blocks for pulling data.
Do NOT use the deprecated "task" or "task_v2" block types — use "navigation" for actions and "extraction" for data extraction. These replacements give clearer semantics and are what the Skyvern UI uses. Existing workflows with task/task_v2 blocks will continue to work — do not convert them unless the user asks. New workflows must use navigation/extraction.
For **text_prompt** blocks, default to Skyvern Optimized by omitting both `model` and `llm_key`. If an explicit model is required, use `model: {"model_name": "<value from /models>"}`. Do not invent internal `llm_key` strings.
GOOD (4 blocks, each with clear single responsibility):
Block 1 (navigation): "Select Sole Proprietor and click Continue"

View file

@ -187,6 +187,18 @@ BLOCK_EXAMPLES: dict[str, dict[str, Any]] = {
"label": "wait_for_processing",
"wait_sec": 30,
},
"text_prompt": {
"block_type": "text_prompt",
"label": "summarize_results",
"prompt": "Summarize {{ raw_results }} into a short customer-facing update",
"parameter_keys": ["raw_results"],
"json_schema": {
"type": "object",
"properties": {
"summary": {"type": "string"},
},
},
},
"goto_url": {
"block_type": "goto_url",
"label": "open_cart",

View file

@ -90,6 +90,16 @@ Keep engine 1.0 (default, omit field) when:
When in doubt, split into multiple 1.0 blocks rather than using one 2.0 block — it's cheaper and
gives you per-block observability. Only navigation blocks support engine 2.0.
### Model selection for text_prompt blocks
Default to Skyvern Optimized for text_prompt blocks by omitting both `model` and `llm_key`.
If the user explicitly asks for a specific model, use the public `model` field:
`"model": {"model_name": "<one of the values returned by /models>"}`
Do NOT invent internal `llm_key` strings like `ANTHROPIC_CLAUDE_3_5_SONNET`.
Only use `llm_key` when the user explicitly provides an exact internal key and wants that advanced override.
### One block per logical step
Split workflows into small, focused blocks. Each block should do ONE thing.

View file

@ -17,6 +17,7 @@ from pydantic import Field
from skyvern.client.errors import NotFoundError
from skyvern.client.types import WorkflowCreateYamlRequest
from skyvern.schemas.runs import ProxyLocation
from skyvern.schemas.workflows import WorkflowCreateYAMLRequest as WorkflowCreateYAMLRequestSchema
from ._common import ErrorCode, Timer, make_error, make_result
from ._session import get_skyvern
@ -206,6 +207,63 @@ _CODE_V2_DEFAULTS: dict[str, Any] = {
}
def _deep_merge(base: Any, override: Any) -> Any:
"""Recursively merge normalized JSON-like data over the raw payload.
Unknown fields should survive normalization. Lists are merged by index so
overlapping items keep raw unknown keys even if normalization changes the
list length.
"""
if isinstance(base, dict) and isinstance(override, dict):
result = dict(base)
for key, value in override.items():
if key in result:
result[key] = _deep_merge(result[key], value)
else:
result[key] = value
return result
if isinstance(base, list) and isinstance(override, list):
merged: list[Any] = []
for idx in range(max(len(base), len(override))):
if idx < len(base) and idx < len(override):
merged.append(_deep_merge(base[idx], override[idx]))
elif idx < len(override):
merged.append(override[idx])
else:
merged.append(base[idx])
return merged
return override
def _normalize_json_definition(raw: Any) -> WorkflowCreateYamlRequest:
    """Normalize JSON workflow definitions through the shared backend schema."""
    if not isinstance(raw, dict):
        raise TypeError("Workflow definition JSON must be an object")
    try:
        validated = WorkflowCreateYAMLRequestSchema.model_validate(raw)
    except Exception as exc:
        # Internal schema is stricter than the Fern SDK — skip normalization so
        # unknown/future fields are not rejected.
        LOG.warning("Skipping text-prompt normalization; internal schema rejected payload", error=str(exc))
        return WorkflowCreateYamlRequest(**raw)
    # Deep-merge so raw fields unknown to the internal schema still survive.
    normalized_payload = _deep_merge(raw, validated.model_dump(mode="json"))
    return WorkflowCreateYamlRequest(**normalized_payload)
def _make_invalid_json_definition_error(exc: Exception) -> dict[str, Any]:
    """Build the structured INVALID_INPUT error for an unparseable JSON definition."""
    message = f"Invalid JSON definition: {exc}"
    hint = "Provide a valid JSON object for the workflow definition"
    return make_error(ErrorCode.INVALID_INPUT, message, hint)
def _inject_code_v2_defaults(definition: str, fmt: str) -> str:
"""Inject Code 2.0 defaults into a JSON definition string when not explicitly set.
@ -237,42 +295,28 @@ def _parse_definition(
Exactly one of the first two will be set on success, or error on failure.
JSON input is parsed into a WorkflowCreateYamlRequest (the type the SDK expects).
"""
if fmt == "json":
try:
raw = json.loads(definition)
return WorkflowCreateYamlRequest(**raw), None, None
except (json.JSONDecodeError, TypeError) as e:
return (
None,
None,
make_error(
ErrorCode.INVALID_INPUT,
f"Invalid JSON definition: {e}",
"Provide a valid JSON object for the workflow definition",
),
)
return None, None, _make_invalid_json_definition_error(e)
try:
return _normalize_json_definition(raw), None, None
except Exception as e:
return (
None,
None,
make_error(
ErrorCode.INVALID_INPUT,
f"Invalid workflow definition: {e}",
"Check the workflow definition fields (title, workflow_definition with blocks)",
),
)
return None, None, _make_invalid_json_definition_error(e)
elif fmt == "yaml":
return None, definition, None
else:
# auto: try JSON first, fall back to YAML
try:
raw = json.loads(definition)
return WorkflowCreateYamlRequest(**raw), None, None
except (json.JSONDecodeError, TypeError):
return None, definition, None
except Exception:
# JSON parsed but failed model validation — treat as YAML
return None, definition, None
try:
return _normalize_json_definition(raw), None, None
except Exception as e:
return None, None, _make_invalid_json_definition_error(e)
# ---------------------------------------------------------------------------

View file

@ -1,4 +1,5 @@
import abc
import functools
from dataclasses import dataclass
from enum import StrEnum
from typing import Annotated, Any, Literal
@ -7,6 +8,8 @@ import structlog
from pydantic import BaseModel, Field, field_validator, model_validator
from skyvern.config import settings
from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter, ParameterType, WorkflowParameterType
from skyvern.schemas.runs import GeoTarget, ProxyLocation, RunEngine
from skyvern.utils.strings import sanitize_identifier
@ -44,6 +47,26 @@ def sanitize_parameter_key(value: str) -> str:
return sanitize_identifier(value, default="parameter")
def _has_jinja_syntax(value: str) -> bool:
return "{{" in value or "{%" in value
@functools.lru_cache(maxsize=1)
def _get_text_prompt_model_name_by_llm_key() -> dict[str, str]:
    """Build a reverse mapping from internal llm_key to public model_name.

    Cached because settings don't change at runtime. Tests that monkeypatch
    settings must call ``_get_text_prompt_model_name_by_llm_key.cache_clear()``
    to avoid cross-test pollution.
    """
    forward_mapping = SettingsManager.get_settings().get_model_name_to_llm_key()
    reverse: dict[str, str] = {}
    for public_name, metadata in forward_mapping.items():
        internal_key = metadata.get("llm_key")
        # First public name wins if multiple model names share one llm_key.
        if internal_key and internal_key not in reverse:
            reverse[internal_key] = public_name
    return reverse
def _replace_references_in_value(value: Any, old_key: str, new_key: str) -> Any:
"""Recursively replaces Jinja references in a value (string, dict, or list)."""
if isinstance(value, str):
@ -679,6 +702,42 @@ class TextPromptBlockYAML(BlockYAML):
parameter_keys: list[str] | None = None
json_schema: dict[str, Any] | None = None
@model_validator(mode="after")
def normalize_llm_selection(self) -> "TextPromptBlockYAML":
    """Normalize the block's LLM selection so only sanctioned overrides survive.

    Precedence, in order:
    1. An explicit ``model`` wins; any raw ``llm_key`` is dropped.
    2. No ``llm_key`` (or blank/whitespace) — cleared to None (default path).
    3. A Jinja-templated ``llm_key`` is kept verbatim for runtime resolution.
    4. A key with a known public model name is promoted to ``model``.
    5. A key registered in ``LLMConfigRegistry`` is kept as an advanced override.
    6. Anything else (e.g. a hallucinated key) is logged and cleared.
    """
    raw_llm_key = self.llm_key.strip() if self.llm_key else None
    if self.model:
        # `model` is the stable public contract; ignore any raw llm_key override
        # once a model has been selected.
        self.llm_key = None
        return self
    if not raw_llm_key:
        # Nothing supplied (or whitespace only): normalize to None.
        self.llm_key = None
        return self
    if _has_jinja_syntax(raw_llm_key):
        # Templated keys are resolved at runtime; keep the stripped value.
        self.llm_key = raw_llm_key
        return self
    model_name = _get_text_prompt_model_name_by_llm_key().get(raw_llm_key)
    if model_name:
        # Known internal key: promote it to the public `model` field.
        self.model = {"model_name": model_name}
        self.llm_key = None
        return self
    if raw_llm_key in LLMConfigRegistry.get_model_names():
        # Registered advanced override: preserve as-is.
        self.llm_key = raw_llm_key
        return self
    LOG.warning(
        "Unrecognized text prompt llm_key; defaulting to Skyvern Optimized/default model path",
        label=self.label,
        llm_key=raw_llm_key,
    )
    self.llm_key = None
    return self
class DownloadToS3BlockYAML(BlockYAML):
# There is a mypy bug with Literal. Without the type: ignore, mypy will raise an error:

View file

@ -495,7 +495,8 @@ _FOR_LOOP_RE = re.compile(
# nested quantifiers would cause exponential backtracking.
_PAGE_CLICK_START_RE = re.compile(r"await page\.click\(")
# Matches a prompt='...' or prompt="..." keyword argument inside a function call.
_PROMPT_KWARG_RE = re.compile(r"""prompt\s*=\s*(['"])(.*?)\1""", re.DOTALL)
# Uses (?:\\.|…) to skip backslash-escaped quotes so apostrophes don't truncate the match.
_PROMPT_KWARG_RE = re.compile(r"""prompt\s*=\s*(['"])((?:\\.|(?!\1).)*?)\1""", re.DOTALL)
def _find_click_calls(text: str) -> list[tuple[int, int, str]]:
@ -622,8 +623,10 @@ def _patch_static_clicks_in_block(body: str) -> str:
if prompt_match:
quote = prompt_match.group(1)
original_prompt = prompt_match.group(2)
# Use an f-string so current_value is evaluated at runtime
new_prompt = f"prompt=f{quote}{original_prompt} Target: {{current_value}}{quote}"
# Use an f-string so current_value is evaluated at runtime.
# Escape existing braces so they are literal in the f-string.
escaped_prompt = original_prompt.replace("{", "{{").replace("}", "}}")
new_prompt = f"prompt=f{quote}{escaped_prompt} Target: {{current_value}}{quote}"
patched = patched[: prompt_match.start()] + new_prompt + patched[prompt_match.end() :]
else:
# No prompt= kwarg — add one with current_value context

View file

@ -378,3 +378,31 @@ class TestFixStaticActionsInForLoops:
break
else:
raise AssertionError("Expected a prompt=f line with current_value")
def test_prompt_with_curly_braces_are_escaped(self) -> None:
"""Existing braces in prompt text must be escaped to avoid f-string evaluation."""
body = textwrap.dedent("""\
await page.click(
selector='a.download-link',
ai='fallback',
prompt='Extract items matching {pattern}',
)""")
result = _patch_static_clicks_in_block(body)
assert "ai='proactive'" in result
# Original braces must be doubled so they're literal in the f-string
assert "{{pattern}}" in result
# The injected Target should still reference current_value
assert "{current_value}" in result
def test_two_level_nested_selector_is_matched(self) -> None:
"""Selectors like tr:has(td:has-text("Report")) must not break the regex."""
body = textwrap.dedent("""\
await page.click(
selector='tr:has(td:has-text("Report"))',
ai='fallback',
prompt='Click row',
)""")
result = _patch_static_clicks_in_block(body)
assert "ai='proactive'" in result
assert "ai='fallback'" not in result
assert "current_value" in result

View file

@ -0,0 +1,29 @@
from __future__ import annotations
import pytest
from skyvern.cli.mcp_tools import mcp
from skyvern.cli.mcp_tools.blocks import skyvern_block_schema
from skyvern.cli.mcp_tools.prompts import BUILD_WORKFLOW_CONTENT
def test_build_workflow_prompt_guides_text_prompt_defaults() -> None:
    """The build-workflow prompt must state the text_prompt model-selection defaults verbatim."""
    assert "Default to Skyvern Optimized for text_prompt blocks by omitting both `model` and `llm_key`." in (
        BUILD_WORKFLOW_CONTENT
    )
    assert "Do NOT invent internal `llm_key` strings like `ANTHROPIC_CLAUDE_3_5_SONNET`." in BUILD_WORKFLOW_CONTENT
def test_mcp_instructions_guide_text_prompt_defaults() -> None:
    """The MCP server instructions must carry the same text_prompt default guidance."""
    assert "For **text_prompt** blocks, default to Skyvern Optimized by omitting both `model` and `llm_key`." in (
        mcp.instructions
    )
@pytest.mark.asyncio
async def test_text_prompt_block_schema_example_omits_raw_llm_key() -> None:
    """The text_prompt schema example must not showcase `llm_key`/`model` overrides."""
    result = await skyvern_block_schema(block_type="text_prompt")
    assert result["ok"] is True
    # Omitting both fields is what selects the Skyvern Optimized default.
    assert "llm_key" not in result["data"]["example"]
    assert "model" not in result["data"]["example"]

View file

@ -0,0 +1,335 @@
from __future__ import annotations
import json
from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
import skyvern.cli.mcp_tools.workflow as workflow_tools
def _fake_workflow_response() -> SimpleNamespace:
now = datetime.now(timezone.utc)
return SimpleNamespace(
workflow_permanent_id="wpid_test",
workflow_id="wf_test",
title="Example Workflow",
version=1,
status="published",
description=None,
is_saved_task=False,
folder_id=None,
created_at=now,
modified_at=now,
)
@pytest.mark.asyncio
async def test_workflow_create_normalizes_invalid_text_prompt_llm_key(monkeypatch: pytest.MonkeyPatch) -> None:
    """A hallucinated llm_key must be stripped before the definition reaches the SDK client."""
    # Fake client captures the exact payload forwarded to the SDK.
    fake_client = SimpleNamespace(create_workflow=AsyncMock(return_value=_fake_workflow_response()))
    monkeypatch.setattr(workflow_tools, "get_skyvern", lambda: fake_client)
    definition = {
        "title": "Normalize invalid llm_key",
        "workflow_definition": {
            "blocks": [
                {
                    "block_type": "text_prompt",
                    "label": "summarize",
                    "prompt": "Summarize the result",
                    # Not a real internal key — must be normalized away.
                    "llm_key": "ANTHROPIC_CLAUDE_3_5_SONNET",
                }
            ],
            "parameters": [],
        },
    }
    result = await workflow_tools.skyvern_workflow_create(definition=json.dumps(definition), format="json")
    sent_definition = fake_client.create_workflow.await_args.kwargs["json_definition"]
    sent_block = sent_definition.workflow_definition.blocks[0]
    assert result["ok"] is True
    # Both cleared: block falls back to the Skyvern Optimized default path.
    assert sent_block.llm_key is None
    assert sent_block.model is None
    # The hallucinated key must not appear anywhere in the serialized payload.
    assert "ANTHROPIC_CLAUDE_3_5_SONNET" not in json.dumps(sent_definition.model_dump(mode="json"))
@pytest.mark.asyncio
async def test_workflow_create_preserves_explicit_internal_text_prompt_llm_key(monkeypatch: pytest.MonkeyPatch) -> None:
fake_client = SimpleNamespace(create_workflow=AsyncMock(return_value=_fake_workflow_response()))
monkeypatch.setattr(workflow_tools, "get_skyvern", lambda: fake_client)
definition = {
"title": "Preserve explicit internal llm_key",
"workflow_definition": {
"blocks": [
{
"block_type": "text_prompt",
"label": "summarize",
"prompt": "Summarize the result",
"llm_key": "SPECIAL_INTERNAL_KEY",
}
],
"parameters": [],
},
}
with patch(
"skyvern.schemas.workflows.LLMConfigRegistry.get_model_names",
return_value=["SPECIAL_INTERNAL_KEY"],
):
result = await workflow_tools.skyvern_workflow_create(definition=json.dumps(definition), format="json")
sent_definition = fake_client.create_workflow.await_args.kwargs["json_definition"]
sent_block = sent_definition.workflow_definition.blocks[0]
assert result["ok"] is True
assert sent_block.model is None
assert sent_block.llm_key == "SPECIAL_INTERNAL_KEY"
# ---------------------------------------------------------------------------
# Prove the exact Slack scenario: MCP agent hallucinates various model strings
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"hallucinated_key",
[
"ANTHROPIC_CLAUDE_3_5_SONNET", # exact key from the Slack thread
"ANTHROPIC_CLAUDE3.5_SONNET", # the "correct" key Pedro mentioned — still not public
"ANTHROPIC_CLAUDE_3_5_HAIKU",
"OPENAI_GPT4_TURBO",
"VERTEX_GEMINI_2_FLASH",
"claude-3-opus-20240229",
"gpt-4o-mini",
"gemini-pro",
],
)
@pytest.mark.asyncio
async def test_mcp_strips_all_common_hallucinated_llm_keys(
monkeypatch: pytest.MonkeyPatch, hallucinated_key: str
) -> None:
"""MCP workflow creation must strip ANY hallucinated llm_key and default to Skyvern Optimized (null)."""
fake_client = SimpleNamespace(create_workflow=AsyncMock(return_value=_fake_workflow_response()))
monkeypatch.setattr(workflow_tools, "get_skyvern", lambda: fake_client)
definition = {
"title": "Agent-generated workflow",
"workflow_definition": {
"blocks": [
{
"block_type": "text_prompt",
"label": "limit_pages",
"prompt": "Extract only the first 10 results",
"llm_key": hallucinated_key,
}
],
"parameters": [],
},
}
with patch(
"skyvern.schemas.workflows.LLMConfigRegistry.get_model_names",
return_value=[], # simulate: none of these are registered
):
result = await workflow_tools.skyvern_workflow_create(definition=json.dumps(definition), format="auto")
sent_def = fake_client.create_workflow.await_args.kwargs["json_definition"]
sent_block = sent_def.workflow_definition.blocks[0]
assert result["ok"] is True
assert sent_block.llm_key is None, f"hallucinated key {hallucinated_key!r} was NOT stripped"
assert sent_block.model is None, "should default to Skyvern Optimized (null model)"
@pytest.mark.asyncio
async def test_workflow_create_preserves_unknown_fields(monkeypatch: pytest.MonkeyPatch) -> None:
"""Fields not in the internal schema should survive normalization.
The Fern-generated WorkflowCreateYamlRequest uses extra='allow', so unknown
fields are accepted. Our normalization deep-merges the original raw dict with
the normalized output so that future SDK fields not yet mirrored in the
internal schema are preserved at any nesting depth.
"""
fake_client = SimpleNamespace(create_workflow=AsyncMock(return_value=_fake_workflow_response()))
monkeypatch.setattr(workflow_tools, "get_skyvern", lambda: fake_client)
definition = {
"title": "Unknown fields test",
"some_future_sdk_field": "should_survive",
"workflow_definition": {
"parameters": [],
"some_nested_future_field": "also_survives",
"blocks": [
{
"block_type": "text_prompt",
"label": "summarize",
"prompt": "Summarize",
}
],
},
}
result = await workflow_tools.skyvern_workflow_create(definition=json.dumps(definition), format="json")
assert result["ok"] is True
sent = fake_client.create_workflow.await_args.kwargs["json_definition"]
# Top-level unknown field preserved
assert sent.some_future_sdk_field == "should_survive"
# Nested unknown field inside workflow_definition also preserved via deep merge
wd = sent.workflow_definition
wd_dict = wd.model_dump(mode="json") if hasattr(wd, "model_dump") else wd.__dict__
assert wd_dict.get("some_nested_future_field") == "also_survives"
@pytest.mark.asyncio
async def test_workflow_create_preserves_block_level_unknown_fields(monkeypatch: pytest.MonkeyPatch) -> None:
"""Unknown fields inside individual block dicts survive normalization via deep merge."""
fake_client = SimpleNamespace(create_workflow=AsyncMock(return_value=_fake_workflow_response()))
monkeypatch.setattr(workflow_tools, "get_skyvern", lambda: fake_client)
definition = {
"title": "Block-level unknown fields test",
"workflow_definition": {
"parameters": [],
"blocks": [
{
"block_type": "text_prompt",
"label": "summarize",
"prompt": "Summarize",
"some_future_block_field": 42,
}
],
},
}
result = await workflow_tools.skyvern_workflow_create(definition=json.dumps(definition), format="json")
assert result["ok"] is True
sent = fake_client.create_workflow.await_args.kwargs["json_definition"]
sent_block = sent.workflow_definition.blocks[0]
block_dict = sent_block.model_dump(mode="json") if hasattr(sent_block, "model_dump") else sent_block.__dict__
assert block_dict.get("some_future_block_field") == 42
@pytest.mark.asyncio
async def test_workflow_create_falls_back_on_schema_validation_error(monkeypatch: pytest.MonkeyPatch) -> None:
"""If the internal schema rejects the payload, normalization is skipped and the raw dict is forwarded."""
fake_client = SimpleNamespace(create_workflow=AsyncMock(return_value=_fake_workflow_response()))
monkeypatch.setattr(workflow_tools, "get_skyvern", lambda: fake_client)
definition = {
"title": "Schema rejection test",
"workflow_definition": {
"parameters": [],
"blocks": [
{
"block_type": "text_prompt",
"label": "summarize",
"prompt": "Summarize",
"llm_key": "HALLUCINATED_KEY",
}
],
},
}
with patch(
"skyvern.cli.mcp_tools.workflow.WorkflowCreateYAMLRequestSchema.model_validate",
side_effect=Exception("schema rejected"),
):
result = await workflow_tools.skyvern_workflow_create(definition=json.dumps(definition), format="json")
assert result["ok"] is True
sent = fake_client.create_workflow.await_args.kwargs["json_definition"]
# Normalization was skipped, so the hallucinated key passes through to the SDK
sent_block = sent.workflow_definition.blocks[0]
block_dict = sent_block.model_dump(mode="json") if hasattr(sent_block, "model_dump") else sent_block.__dict__
assert block_dict.get("llm_key") == "HALLUCINATED_KEY"
@pytest.mark.parametrize("format_name", ["json", "auto"])
def test_parse_definition_returns_invalid_input_when_schema_fallback_still_fails(format_name: str) -> None:
"""If both the internal schema and Fern fallback reject the JSON, return a structured INVALID_INPUT error."""
json_def, yaml_def, err = workflow_tools._parse_definition(
json.dumps({"title": "Missing workflow_definition"}),
format_name,
)
assert json_def is None
assert yaml_def is None
assert err is not None
assert err["code"] == workflow_tools.ErrorCode.INVALID_INPUT
assert "Invalid JSON definition" in err["message"]
def test_deep_merge_preserves_block_unknown_fields_when_list_lengths_differ() -> None:
"""Overlapping block data should still merge when normalization changes list length."""
raw = {
"workflow_definition": {
"blocks": [
{
"block_type": "text_prompt",
"label": "summarize",
"prompt": "Summarize",
"some_future_block_field": 42,
}
]
}
}
normalized = {
"workflow_definition": {
"blocks": [
{
"block_type": "text_prompt",
"label": "summarize",
"prompt": "Summarize",
},
{
"block_type": "text_prompt",
"label": "synthetic_followup",
"prompt": "Follow up",
},
]
}
}
merged = workflow_tools._deep_merge(raw, normalized)
merged_blocks = merged["workflow_definition"]["blocks"]
assert len(merged_blocks) == 2
assert merged_blocks[0]["some_future_block_field"] == 42
assert merged_blocks[1]["label"] == "synthetic_followup"
@pytest.mark.asyncio
async def test_mcp_text_prompt_without_llm_key_stays_null(monkeypatch: pytest.MonkeyPatch) -> None:
"""When MCP correctly omits llm_key (Skyvern Optimized), it stays null through the whole pipeline."""
fake_client = SimpleNamespace(create_workflow=AsyncMock(return_value=_fake_workflow_response()))
monkeypatch.setattr(workflow_tools, "get_skyvern", lambda: fake_client)
definition = {
"title": "Well-behaved MCP workflow",
"workflow_definition": {
"blocks": [
{
"block_type": "text_prompt",
"label": "summarize",
"prompt": "Summarize the extracted data",
}
],
"parameters": [],
},
}
result = await workflow_tools.skyvern_workflow_create(definition=json.dumps(definition), format="json")
sent_block = fake_client.create_workflow.await_args.kwargs["json_definition"].workflow_definition.blocks[0]
assert result["ok"] is True
assert sent_block.llm_key is None
assert sent_block.model is None

View file

@ -9,6 +9,7 @@ from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.forge.sdk.workflow.models.block import TextPromptBlock
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter, ParameterType
from skyvern.schemas.workflows import TextPromptBlockYAML, WorkflowRequest
block_module = sys.modules["skyvern.forge.sdk.workflow.models.block"]
@ -212,3 +213,177 @@ async def test_text_prompt_block_prefers_prompt_type_config_over_secondary(monke
assert captured["default_handler"] == prompt_config_handler
prompt_config_handler.assert_awaited_once()
assert response == {"llm_response": "config"}
@pytest.mark.asyncio
async def test_text_prompt_block_bad_llm_key_uses_same_runtime_path_as_no_override(monkeypatch):
monkeypatch.setattr(SettingsManager, "_SettingsManager__instance", base_settings)
now = datetime.now(timezone.utc)
normalized_bad = TextPromptBlockYAML(
label="bad_key",
prompt="Summarize status.",
llm_key="ANTHROPIC_CLAUDE_3_5_SONNET",
)
no_override = TextPromptBlockYAML(
label="no_override",
prompt="Summarize status.",
llm_key=None,
)
blocks = []
for idx, yaml_block in enumerate((normalized_bad, no_override), start=1):
output_parameter = OutputParameter(
parameter_type=ParameterType.OUTPUT,
key=f"text_prompt_output_{idx}",
description=None,
output_parameter_id=f"output-{idx}",
workflow_id="workflow-1",
created_at=now,
modified_at=now,
deleted_at=None,
)
blocks.append(
TextPromptBlock(
label=yaml_block.label,
llm_key=yaml_block.llm_key,
prompt=yaml_block.prompt,
parameters=[],
json_schema=None,
output_parameter=output_parameter,
model=yaml_block.model,
)
)
captured: list[tuple[str | None, object]] = []
fake_secondary_handler = AsyncMock(return_value={"llm_response": "secondary"})
async def fake_prompt_type_handler(*args, **kwargs):
return None
def fake_get_override_handler(llm_key: str | None, *, default):
captured.append((llm_key, default))
return default
block_module.app.SECONDARY_LLM_API_HANDLER = fake_secondary_handler
block_module.app.LLM_API_HANDLER = AsyncMock()
LLMAPIHandlerFactory = block_module.LLMAPIHandlerFactory
monkeypatch.setattr(
LLMAPIHandlerFactory,
"get_override_llm_api_handler",
fake_get_override_handler,
raising=False,
)
monkeypatch.setattr(
block_module,
"get_llm_handler_for_prompt_type",
fake_prompt_type_handler,
raising=False,
)
monkeypatch.setattr(
prompt_engine,
"load_prompt_from_string",
lambda template, **kwargs: template,
)
for block in blocks:
response = await block.send_prompt(block.prompt, {}, workflow_run_id="workflow-run", organization_id="org-1")
assert response == {"llm_response": "secondary"}
assert captured == [
(None, fake_secondary_handler),
(None, fake_secondary_handler),
]
@pytest.mark.asyncio
async def test_text_prompt_block_uses_explicit_internal_llm_key_override(monkeypatch):
now = datetime.now(timezone.utc)
output_parameter = OutputParameter(
parameter_type=ParameterType.OUTPUT,
key="text_prompt_output_internal",
description=None,
output_parameter_id="output-internal",
workflow_id="workflow-1",
created_at=now,
modified_at=now,
deleted_at=None,
)
block = TextPromptBlock(
label="text-block",
llm_key="SPECIAL_INTERNAL_KEY",
prompt="Summarize status.",
parameters=[],
json_schema=None,
output_parameter=output_parameter,
model=None,
)
captured: dict[str, object] = {}
fake_default_handler = AsyncMock()
fake_override_handler = AsyncMock(return_value={"llm_response": "override"})
async def fake_resolve_default_llm_handler(*args, **kwargs):
return fake_default_handler
def fake_get_override_handler(llm_key: str | None, *, default):
captured["llm_key"] = llm_key
captured["default_handler"] = default
return fake_override_handler
block_module.app.LLM_API_HANDLER = fake_default_handler
LLMAPIHandlerFactory = block_module.LLMAPIHandlerFactory
monkeypatch.setattr(
LLMAPIHandlerFactory,
"get_override_llm_api_handler",
fake_get_override_handler,
raising=False,
)
monkeypatch.setattr(
TextPromptBlock,
"_resolve_default_llm_handler",
fake_resolve_default_llm_handler,
raising=False,
)
monkeypatch.setattr(
prompt_engine,
"load_prompt_from_string",
lambda template, **kwargs: template,
)
response = await block.send_prompt(block.prompt, {}, workflow_run_id="workflow-run", organization_id="org-1")
assert captured["llm_key"] == "SPECIAL_INTERNAL_KEY"
assert captured["default_handler"] == fake_default_handler
fake_override_handler.assert_awaited_once()
assert response == {"llm_response": "override"}
def test_workflow_request_deserialization_strips_invalid_text_prompt_llm_key() -> None:
"""Verify FastAPI deserialization (not just explicit model_validate) strips bad keys.
Moved from tests/scenario/ this test only validates Pydantic behavior and
does not require a database connection.
"""
raw_request = {
"json_definition": {
"title": "Deserialization test",
"workflow_definition": {
"blocks": [
{
"block_type": "text_prompt",
"label": "summarize",
"prompt": "Summarize the result",
"llm_key": "ANTHROPIC_CLAUDE_3_5_SONNET",
}
],
"parameters": [],
},
}
}
workflow_request = WorkflowRequest.model_validate(raw_request)
block = workflow_request.json_definition.workflow_definition.blocks[0]
assert block.llm_key is None
assert block.model is None

View file

@ -0,0 +1,88 @@
from __future__ import annotations
from unittest.mock import patch
import pytest
from skyvern.config import settings as base_settings
from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.schemas.workflows import TextPromptBlockYAML, _get_text_prompt_model_name_by_llm_key
@pytest.fixture(autouse=True)
def _clear_llm_key_cache():
    """Clear the lru_cache before each test to prevent cross-test pollution."""
    _get_text_prompt_model_name_by_llm_key.cache_clear()
    yield
    # Clear again on teardown so this module leaves no cached state behind.
    _get_text_prompt_model_name_by_llm_key.cache_clear()
class TestTextPromptBlockYAMLNormalization:
def test_converts_known_llm_key_to_model(self, monkeypatch) -> None:
    """A recognized internal llm_key is promoted to the public `model` field."""
    monkeypatch.setattr(SettingsManager, "_SettingsManager__instance", base_settings)
    block = TextPromptBlockYAML(
        label="summarize",
        prompt="Summarize the data.",
        llm_key="VERTEX_GEMINI_2.5_FLASH",
    )
    # The internal key maps to its public model name; the raw key is cleared.
    assert block.model == {"model_name": "gemini-2.5-flash"}
    assert block.llm_key is None
def test_clears_invalid_llm_key_to_use_default_model(self, monkeypatch) -> None:
monkeypatch.setattr(SettingsManager, "_SettingsManager__instance", base_settings)
with patch(
"skyvern.schemas.workflows.LLMConfigRegistry.get_model_names",
return_value=[],
):
block = TextPromptBlockYAML(
label="summarize",
prompt="Summarize the data.",
llm_key="ANTHROPIC_CLAUDE_3_5_SONNET",
)
assert block.model is None
assert block.llm_key is None
def test_preserves_registered_advanced_llm_key(self, monkeypatch) -> None:
monkeypatch.setattr(SettingsManager, "_SettingsManager__instance", base_settings)
with patch(
"skyvern.schemas.workflows.LLMConfigRegistry.get_model_names",
return_value=["SPECIAL_INTERNAL_KEY"],
):
block = TextPromptBlockYAML(
label="summarize",
prompt="Summarize the data.",
llm_key="SPECIAL_INTERNAL_KEY",
)
assert block.model is None
assert block.llm_key == "SPECIAL_INTERNAL_KEY"
def test_preserves_templated_llm_key(self, monkeypatch) -> None:
monkeypatch.setattr(SettingsManager, "_SettingsManager__instance", base_settings)
block = TextPromptBlockYAML(
label="summarize",
prompt="Summarize the data.",
llm_key="{{ prompt_block_llm_key }}",
)
assert block.model is None
assert block.llm_key == "{{ prompt_block_llm_key }}"
def test_model_override_clears_raw_llm_key(self, monkeypatch) -> None:
monkeypatch.setattr(SettingsManager, "_SettingsManager__instance", base_settings)
block = TextPromptBlockYAML(
label="summarize",
prompt="Summarize the data.",
llm_key="ANTHROPIC_CLAUDE_3_5_SONNET",
model={"model_name": "gemini-3-pro-preview"},
)
assert block.model == {"model_name": "gemini-3-pro-preview"}
assert block.llm_key is None