mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2026-04-28 03:30:10 +00:00
Stop MCP model-key hallucinations in workflow creation (#SKY-8302) (#5061)
This commit is contained in:
parent
af91183d75
commit
87f18f4ee0
11 changed files with 810 additions and 26 deletions
|
|
@ -198,6 +198,7 @@ Validate blocks with skyvern_block_validate() before submitting.
|
|||
Split workflows into multiple blocks — one block per logical step — rather than cramming everything into a single block.
|
||||
Use **navigation** blocks for actions (filling forms, clicking buttons) and **extraction** blocks for pulling data.
|
||||
Do NOT use the deprecated "task" or "task_v2" block types — use "navigation" for actions and "extraction" for data extraction. These replacements give clearer semantics and are what the Skyvern UI uses. Existing workflows with task/task_v2 blocks will continue to work — do not convert them unless the user asks. New workflows must use navigation/extraction.
|
||||
For **text_prompt** blocks, default to Skyvern Optimized by omitting both `model` and `llm_key`. If an explicit model is required, use `model: {"model_name": "<value from /models>"}`. Do not invent internal `llm_key` strings.
|
||||
|
||||
GOOD (4 blocks, each with clear single responsibility):
|
||||
Block 1 (navigation): "Select Sole Proprietor and click Continue"
|
||||
|
|
|
|||
|
|
@ -187,6 +187,18 @@ BLOCK_EXAMPLES: dict[str, dict[str, Any]] = {
|
|||
"label": "wait_for_processing",
|
||||
"wait_sec": 30,
|
||||
},
|
||||
"text_prompt": {
|
||||
"block_type": "text_prompt",
|
||||
"label": "summarize_results",
|
||||
"prompt": "Summarize {{ raw_results }} into a short customer-facing update",
|
||||
"parameter_keys": ["raw_results"],
|
||||
"json_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"summary": {"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"goto_url": {
|
||||
"block_type": "goto_url",
|
||||
"label": "open_cart",
|
||||
|
|
|
|||
|
|
@ -90,6 +90,16 @@ Keep engine 1.0 (default, omit field) when:
|
|||
When in doubt, split into multiple 1.0 blocks rather than using one 2.0 block — it's cheaper and
|
||||
gives you per-block observability. Only navigation blocks support engine 2.0.
|
||||
|
||||
### Model selection for text_prompt blocks
|
||||
|
||||
Default to Skyvern Optimized for text_prompt blocks by omitting both `model` and `llm_key`.
|
||||
|
||||
If the user explicitly asks for a specific model, use the public `model` field:
|
||||
`"model": {"model_name": "<one of the values returned by /models>"}`
|
||||
|
||||
Do NOT invent internal `llm_key` strings like `ANTHROPIC_CLAUDE_3_5_SONNET`.
|
||||
Only use `llm_key` when the user explicitly provides an exact internal key and wants that advanced override.
|
||||
|
||||
### One block per logical step
|
||||
|
||||
Split workflows into small, focused blocks. Each block should do ONE thing.
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from pydantic import Field
|
|||
from skyvern.client.errors import NotFoundError
|
||||
from skyvern.client.types import WorkflowCreateYamlRequest
|
||||
from skyvern.schemas.runs import ProxyLocation
|
||||
from skyvern.schemas.workflows import WorkflowCreateYAMLRequest as WorkflowCreateYAMLRequestSchema
|
||||
|
||||
from ._common import ErrorCode, Timer, make_error, make_result
|
||||
from ._session import get_skyvern
|
||||
|
|
@ -206,6 +207,63 @@ _CODE_V2_DEFAULTS: dict[str, Any] = {
|
|||
}
|
||||
|
||||
|
||||
def _deep_merge(base: Any, override: Any) -> Any:
|
||||
"""Recursively merge normalized JSON-like data over the raw payload.
|
||||
|
||||
Unknown fields should survive normalization. Lists are merged by index so
|
||||
overlapping items keep raw unknown keys even if normalization changes the
|
||||
list length.
|
||||
"""
|
||||
|
||||
if isinstance(base, dict) and isinstance(override, dict):
|
||||
result = dict(base)
|
||||
for key, value in override.items():
|
||||
if key in result:
|
||||
result[key] = _deep_merge(result[key], value)
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
if isinstance(base, list) and isinstance(override, list):
|
||||
merged: list[Any] = []
|
||||
for idx in range(max(len(base), len(override))):
|
||||
if idx < len(base) and idx < len(override):
|
||||
merged.append(_deep_merge(base[idx], override[idx]))
|
||||
elif idx < len(override):
|
||||
merged.append(override[idx])
|
||||
else:
|
||||
merged.append(base[idx])
|
||||
return merged
|
||||
|
||||
return override
|
||||
|
||||
|
||||
def _normalize_json_definition(raw: Any) -> WorkflowCreateYamlRequest:
    """Normalize JSON workflow definitions through the shared backend schema.

    Validating against the internal schema lets its field-level normalizers
    run; the result is deep-merged over the raw payload so unknown/future
    fields survive the round trip.

    Raises:
        TypeError: if *raw* is not a JSON object (dict).
    """
    if not isinstance(raw, dict):
        raise TypeError("Workflow definition JSON must be an object")

    try:
        normalized = WorkflowCreateYAMLRequestSchema.model_validate(raw)
    except Exception as exc:
        # Internal schema is stricter than the Fern SDK — skip normalization so
        # unknown/future fields are not rejected.
        LOG.warning("Skipping text-prompt normalization; internal schema rejected payload", error=str(exc))
        return WorkflowCreateYamlRequest(**raw)

    merged_payload = _deep_merge(raw, normalized.model_dump(mode="json"))
    return WorkflowCreateYamlRequest(**merged_payload)
|
||||
|
||||
|
||||
def _make_invalid_json_definition_error(exc: Exception) -> dict[str, Any]:
    """Build the structured INVALID_INPUT error for an unparseable JSON definition."""
    message = f"Invalid JSON definition: {exc}"
    hint = "Provide a valid JSON object for the workflow definition"
    return make_error(ErrorCode.INVALID_INPUT, message, hint)
|
||||
|
||||
|
||||
def _inject_code_v2_defaults(definition: str, fmt: str) -> str:
|
||||
"""Inject Code 2.0 defaults into a JSON definition string when not explicitly set.
|
||||
|
||||
|
|
@ -237,42 +295,28 @@ def _parse_definition(
|
|||
Exactly one of the first two will be set on success, or error on failure.
|
||||
JSON input is parsed into a WorkflowCreateYamlRequest (the type the SDK expects).
|
||||
"""
|
||||
|
||||
if fmt == "json":
|
||||
try:
|
||||
raw = json.loads(definition)
|
||||
return WorkflowCreateYamlRequest(**raw), None, None
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
return (
|
||||
None,
|
||||
None,
|
||||
make_error(
|
||||
ErrorCode.INVALID_INPUT,
|
||||
f"Invalid JSON definition: {e}",
|
||||
"Provide a valid JSON object for the workflow definition",
|
||||
),
|
||||
)
|
||||
return None, None, _make_invalid_json_definition_error(e)
|
||||
try:
|
||||
return _normalize_json_definition(raw), None, None
|
||||
except Exception as e:
|
||||
return (
|
||||
None,
|
||||
None,
|
||||
make_error(
|
||||
ErrorCode.INVALID_INPUT,
|
||||
f"Invalid workflow definition: {e}",
|
||||
"Check the workflow definition fields (title, workflow_definition with blocks)",
|
||||
),
|
||||
)
|
||||
return None, None, _make_invalid_json_definition_error(e)
|
||||
elif fmt == "yaml":
|
||||
return None, definition, None
|
||||
else:
|
||||
# auto: try JSON first, fall back to YAML
|
||||
try:
|
||||
raw = json.loads(definition)
|
||||
return WorkflowCreateYamlRequest(**raw), None, None
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None, definition, None
|
||||
except Exception:
|
||||
# JSON parsed but failed model validation — treat as YAML
|
||||
return None, definition, None
|
||||
try:
|
||||
return _normalize_json_definition(raw), None, None
|
||||
except Exception as e:
|
||||
return None, None, _make_invalid_json_definition_error(e)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import abc
|
||||
import functools
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Any, Literal
|
||||
|
|
@ -7,6 +8,8 @@ import structlog
|
|||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter, ParameterType, WorkflowParameterType
|
||||
from skyvern.schemas.runs import GeoTarget, ProxyLocation, RunEngine
|
||||
from skyvern.utils.strings import sanitize_identifier
|
||||
|
|
@ -44,6 +47,26 @@ def sanitize_parameter_key(value: str) -> str:
|
|||
return sanitize_identifier(value, default="parameter")
|
||||
|
||||
|
||||
def _has_jinja_syntax(value: str) -> bool:
|
||||
return "{{" in value or "{%" in value
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
def _get_text_prompt_model_name_by_llm_key() -> dict[str, str]:
    """Build a reverse mapping from internal llm_key to public model_name.

    Cached because settings don't change at runtime. Tests that monkeypatch
    settings must call ``_get_text_prompt_model_name_by_llm_key.cache_clear()``
    to avoid cross-test pollution.
    """
    forward = SettingsManager.get_settings().get_model_name_to_llm_key()
    mapping: dict[str, str] = {}
    for model_name, metadata in forward.items():
        key = metadata.get("llm_key")
        if key:
            # First model_name wins when several share the same llm_key.
            mapping.setdefault(key, model_name)
    return mapping
|
||||
|
||||
|
||||
def _replace_references_in_value(value: Any, old_key: str, new_key: str) -> Any:
|
||||
"""Recursively replaces Jinja references in a value (string, dict, or list)."""
|
||||
if isinstance(value, str):
|
||||
|
|
@ -679,6 +702,42 @@ class TextPromptBlockYAML(BlockYAML):
|
|||
parameter_keys: list[str] | None = None
|
||||
json_schema: dict[str, Any] | None = None
|
||||
|
||||
@model_validator(mode="after")
def normalize_llm_selection(self) -> "TextPromptBlockYAML":
    """Normalize model/llm_key so only recognized selections reach the backend.

    Precedence: an explicit ``model`` wins outright; an empty key clears to
    the default; a Jinja-templated key is left to resolve at run time; a known
    internal key is translated to the public ``model`` contract; a registered
    internal key is kept as an explicit override; anything else is dropped
    with a warning so the default (Skyvern Optimized) path is used.
    """
    cleaned_key = self.llm_key.strip() if self.llm_key else None

    if self.model:
        # `model` is the stable public contract; ignore any raw llm_key override
        # once a model has been selected.
        self.llm_key = None
    elif not cleaned_key:
        self.llm_key = None
    elif _has_jinja_syntax(cleaned_key):
        # Templated keys can't be validated statically; keep them verbatim.
        self.llm_key = cleaned_key
    elif (public_name := _get_text_prompt_model_name_by_llm_key().get(cleaned_key)):
        # Known internal key — translate it to the public model contract.
        self.model = {"model_name": public_name}
        self.llm_key = None
    elif cleaned_key in LLMConfigRegistry.get_model_names():
        # Explicitly registered internal override — honored as-is.
        self.llm_key = cleaned_key
    else:
        LOG.warning(
            "Unrecognized text prompt llm_key; defaulting to Skyvern Optimized/default model path",
            label=self.label,
            llm_key=cleaned_key,
        )
        self.llm_key = None
    return self
|
||||
|
||||
|
||||
class DownloadToS3BlockYAML(BlockYAML):
|
||||
# There is a mypy bug with Literal. Without the type: ignore, mypy will raise an error:
|
||||
|
|
|
|||
|
|
@ -495,7 +495,8 @@ _FOR_LOOP_RE = re.compile(
|
|||
# nested quantifiers would cause exponential backtracking.
|
||||
_PAGE_CLICK_START_RE = re.compile(r"await page\.click\(")
|
||||
# Matches a prompt='...' or prompt="..." keyword argument inside a function call.
|
||||
_PROMPT_KWARG_RE = re.compile(r"""prompt\s*=\s*(['"])(.*?)\1""", re.DOTALL)
|
||||
# Uses (?:\\.|…) to skip backslash-escaped quotes so apostrophes don't truncate the match.
|
||||
_PROMPT_KWARG_RE = re.compile(r"""prompt\s*=\s*(['"])((?:\\.|(?!\1).)*?)\1""", re.DOTALL)
|
||||
|
||||
|
||||
def _find_click_calls(text: str) -> list[tuple[int, int, str]]:
|
||||
|
|
@ -622,8 +623,10 @@ def _patch_static_clicks_in_block(body: str) -> str:
|
|||
if prompt_match:
|
||||
quote = prompt_match.group(1)
|
||||
original_prompt = prompt_match.group(2)
|
||||
# Use an f-string so current_value is evaluated at runtime
|
||||
new_prompt = f"prompt=f{quote}{original_prompt} Target: {{current_value}}{quote}"
|
||||
# Use an f-string so current_value is evaluated at runtime.
|
||||
# Escape existing braces so they are literal in the f-string.
|
||||
escaped_prompt = original_prompt.replace("{", "{{").replace("}", "}}")
|
||||
new_prompt = f"prompt=f{quote}{escaped_prompt} Target: {{current_value}}{quote}"
|
||||
patched = patched[: prompt_match.start()] + new_prompt + patched[prompt_match.end() :]
|
||||
else:
|
||||
# No prompt= kwarg — add one with current_value context
|
||||
|
|
|
|||
|
|
@ -378,3 +378,31 @@ class TestFixStaticActionsInForLoops:
|
|||
break
|
||||
else:
|
||||
raise AssertionError("Expected a prompt=f line with current_value")
|
||||
|
||||
def test_prompt_with_curly_braces_are_escaped(self) -> None:
|
||||
"""Existing braces in prompt text must be escaped to avoid f-string evaluation."""
|
||||
body = textwrap.dedent("""\
|
||||
await page.click(
|
||||
selector='a.download-link',
|
||||
ai='fallback',
|
||||
prompt='Extract items matching {pattern}',
|
||||
)""")
|
||||
result = _patch_static_clicks_in_block(body)
|
||||
assert "ai='proactive'" in result
|
||||
# Original braces must be doubled so they're literal in the f-string
|
||||
assert "{{pattern}}" in result
|
||||
# The injected Target should still reference current_value
|
||||
assert "{current_value}" in result
|
||||
|
||||
def test_two_level_nested_selector_is_matched(self) -> None:
|
||||
"""Selectors like tr:has(td:has-text("Report")) must not break the regex."""
|
||||
body = textwrap.dedent("""\
|
||||
await page.click(
|
||||
selector='tr:has(td:has-text("Report"))',
|
||||
ai='fallback',
|
||||
prompt='Click row',
|
||||
)""")
|
||||
result = _patch_static_clicks_in_block(body)
|
||||
assert "ai='proactive'" in result
|
||||
assert "ai='fallback'" not in result
|
||||
assert "current_value" in result
|
||||
|
|
|
|||
29
tests/unit/test_mcp_workflow_guidance.py
Normal file
29
tests/unit/test_mcp_workflow_guidance.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.cli.mcp_tools import mcp
|
||||
from skyvern.cli.mcp_tools.blocks import skyvern_block_schema
|
||||
from skyvern.cli.mcp_tools.prompts import BUILD_WORKFLOW_CONTENT
|
||||
|
||||
|
||||
def test_build_workflow_prompt_guides_text_prompt_defaults() -> None:
    """The build-workflow prompt must carry the text_prompt model guidance verbatim."""
    expected_snippets = (
        "Default to Skyvern Optimized for text_prompt blocks by omitting both `model` and `llm_key`.",
        "Do NOT invent internal `llm_key` strings like `ANTHROPIC_CLAUDE_3_5_SONNET`.",
    )
    for snippet in expected_snippets:
        assert snippet in BUILD_WORKFLOW_CONTENT
|
||||
|
||||
|
||||
def test_mcp_instructions_guide_text_prompt_defaults() -> None:
    """The MCP server instructions must include the text_prompt default-model guidance."""
    guidance = "For **text_prompt** blocks, default to Skyvern Optimized by omitting both `model` and `llm_key`."
    assert guidance in mcp.instructions
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_text_prompt_block_schema_example_omits_raw_llm_key() -> None:
    """The canonical text_prompt example must not advertise model/llm_key overrides."""
    result = await skyvern_block_schema(block_type="text_prompt")

    assert result["ok"] is True
    example = result["data"]["example"]
    for forbidden_field in ("llm_key", "model"):
        assert forbidden_field not in example
|
||||
335
tests/unit/test_mcp_workflow_tools.py
Normal file
335
tests/unit/test_mcp_workflow_tools.py
Normal file
|
|
@ -0,0 +1,335 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import skyvern.cli.mcp_tools.workflow as workflow_tools
|
||||
|
||||
|
||||
def _fake_workflow_response() -> SimpleNamespace:
|
||||
now = datetime.now(timezone.utc)
|
||||
return SimpleNamespace(
|
||||
workflow_permanent_id="wpid_test",
|
||||
workflow_id="wf_test",
|
||||
title="Example Workflow",
|
||||
version=1,
|
||||
status="published",
|
||||
description=None,
|
||||
is_saved_task=False,
|
||||
folder_id=None,
|
||||
created_at=now,
|
||||
modified_at=now,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_workflow_create_normalizes_invalid_text_prompt_llm_key(monkeypatch: pytest.MonkeyPatch) -> None:
    """A hallucinated llm_key must be stripped before the definition reaches the SDK."""
    fake_client = SimpleNamespace(create_workflow=AsyncMock(return_value=_fake_workflow_response()))
    monkeypatch.setattr(workflow_tools, "get_skyvern", lambda: fake_client)

    hallucinated_key = "ANTHROPIC_CLAUDE_3_5_SONNET"
    text_prompt_block = {
        "block_type": "text_prompt",
        "label": "summarize",
        "prompt": "Summarize the result",
        "llm_key": hallucinated_key,
    }
    definition = {
        "title": "Normalize invalid llm_key",
        "workflow_definition": {"blocks": [text_prompt_block], "parameters": []},
    }

    result = await workflow_tools.skyvern_workflow_create(definition=json.dumps(definition), format="json")

    sent_definition = fake_client.create_workflow.await_args.kwargs["json_definition"]
    sent_block = sent_definition.workflow_definition.blocks[0]

    assert result["ok"] is True
    assert sent_block.llm_key is None
    assert sent_block.model is None
    # The hallucinated key must not survive anywhere in the serialized payload.
    assert hallucinated_key not in json.dumps(sent_definition.model_dump(mode="json"))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workflow_create_preserves_explicit_internal_text_prompt_llm_key(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
fake_client = SimpleNamespace(create_workflow=AsyncMock(return_value=_fake_workflow_response()))
|
||||
monkeypatch.setattr(workflow_tools, "get_skyvern", lambda: fake_client)
|
||||
|
||||
definition = {
|
||||
"title": "Preserve explicit internal llm_key",
|
||||
"workflow_definition": {
|
||||
"blocks": [
|
||||
{
|
||||
"block_type": "text_prompt",
|
||||
"label": "summarize",
|
||||
"prompt": "Summarize the result",
|
||||
"llm_key": "SPECIAL_INTERNAL_KEY",
|
||||
}
|
||||
],
|
||||
"parameters": [],
|
||||
},
|
||||
}
|
||||
|
||||
with patch(
|
||||
"skyvern.schemas.workflows.LLMConfigRegistry.get_model_names",
|
||||
return_value=["SPECIAL_INTERNAL_KEY"],
|
||||
):
|
||||
result = await workflow_tools.skyvern_workflow_create(definition=json.dumps(definition), format="json")
|
||||
|
||||
sent_definition = fake_client.create_workflow.await_args.kwargs["json_definition"]
|
||||
sent_block = sent_definition.workflow_definition.blocks[0]
|
||||
|
||||
assert result["ok"] is True
|
||||
assert sent_block.model is None
|
||||
assert sent_block.llm_key == "SPECIAL_INTERNAL_KEY"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prove the exact Slack scenario: MCP agent hallucinates various model strings
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"hallucinated_key",
|
||||
[
|
||||
"ANTHROPIC_CLAUDE_3_5_SONNET", # exact key from the Slack thread
|
||||
"ANTHROPIC_CLAUDE3.5_SONNET", # the "correct" key Pedro mentioned — still not public
|
||||
"ANTHROPIC_CLAUDE_3_5_HAIKU",
|
||||
"OPENAI_GPT4_TURBO",
|
||||
"VERTEX_GEMINI_2_FLASH",
|
||||
"claude-3-opus-20240229",
|
||||
"gpt-4o-mini",
|
||||
"gemini-pro",
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_strips_all_common_hallucinated_llm_keys(
|
||||
monkeypatch: pytest.MonkeyPatch, hallucinated_key: str
|
||||
) -> None:
|
||||
"""MCP workflow creation must strip ANY hallucinated llm_key and default to Skyvern Optimized (null)."""
|
||||
fake_client = SimpleNamespace(create_workflow=AsyncMock(return_value=_fake_workflow_response()))
|
||||
monkeypatch.setattr(workflow_tools, "get_skyvern", lambda: fake_client)
|
||||
|
||||
definition = {
|
||||
"title": "Agent-generated workflow",
|
||||
"workflow_definition": {
|
||||
"blocks": [
|
||||
{
|
||||
"block_type": "text_prompt",
|
||||
"label": "limit_pages",
|
||||
"prompt": "Extract only the first 10 results",
|
||||
"llm_key": hallucinated_key,
|
||||
}
|
||||
],
|
||||
"parameters": [],
|
||||
},
|
||||
}
|
||||
|
||||
with patch(
|
||||
"skyvern.schemas.workflows.LLMConfigRegistry.get_model_names",
|
||||
return_value=[], # simulate: none of these are registered
|
||||
):
|
||||
result = await workflow_tools.skyvern_workflow_create(definition=json.dumps(definition), format="auto")
|
||||
|
||||
sent_def = fake_client.create_workflow.await_args.kwargs["json_definition"]
|
||||
sent_block = sent_def.workflow_definition.blocks[0]
|
||||
|
||||
assert result["ok"] is True
|
||||
assert sent_block.llm_key is None, f"hallucinated key {hallucinated_key!r} was NOT stripped"
|
||||
assert sent_block.model is None, "should default to Skyvern Optimized (null model)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workflow_create_preserves_unknown_fields(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Fields not in the internal schema should survive normalization.
|
||||
|
||||
The Fern-generated WorkflowCreateYamlRequest uses extra='allow', so unknown
|
||||
fields are accepted. Our normalization deep-merges the original raw dict with
|
||||
the normalized output so that future SDK fields not yet mirrored in the
|
||||
internal schema are preserved at any nesting depth.
|
||||
"""
|
||||
fake_client = SimpleNamespace(create_workflow=AsyncMock(return_value=_fake_workflow_response()))
|
||||
monkeypatch.setattr(workflow_tools, "get_skyvern", lambda: fake_client)
|
||||
|
||||
definition = {
|
||||
"title": "Unknown fields test",
|
||||
"some_future_sdk_field": "should_survive",
|
||||
"workflow_definition": {
|
||||
"parameters": [],
|
||||
"some_nested_future_field": "also_survives",
|
||||
"blocks": [
|
||||
{
|
||||
"block_type": "text_prompt",
|
||||
"label": "summarize",
|
||||
"prompt": "Summarize",
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
result = await workflow_tools.skyvern_workflow_create(definition=json.dumps(definition), format="json")
|
||||
assert result["ok"] is True
|
||||
|
||||
sent = fake_client.create_workflow.await_args.kwargs["json_definition"]
|
||||
# Top-level unknown field preserved
|
||||
assert sent.some_future_sdk_field == "should_survive"
|
||||
# Nested unknown field inside workflow_definition also preserved via deep merge
|
||||
wd = sent.workflow_definition
|
||||
wd_dict = wd.model_dump(mode="json") if hasattr(wd, "model_dump") else wd.__dict__
|
||||
assert wd_dict.get("some_nested_future_field") == "also_survives"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workflow_create_preserves_block_level_unknown_fields(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Unknown fields inside individual block dicts survive normalization via deep merge."""
|
||||
fake_client = SimpleNamespace(create_workflow=AsyncMock(return_value=_fake_workflow_response()))
|
||||
monkeypatch.setattr(workflow_tools, "get_skyvern", lambda: fake_client)
|
||||
|
||||
definition = {
|
||||
"title": "Block-level unknown fields test",
|
||||
"workflow_definition": {
|
||||
"parameters": [],
|
||||
"blocks": [
|
||||
{
|
||||
"block_type": "text_prompt",
|
||||
"label": "summarize",
|
||||
"prompt": "Summarize",
|
||||
"some_future_block_field": 42,
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
result = await workflow_tools.skyvern_workflow_create(definition=json.dumps(definition), format="json")
|
||||
assert result["ok"] is True
|
||||
|
||||
sent = fake_client.create_workflow.await_args.kwargs["json_definition"]
|
||||
sent_block = sent.workflow_definition.blocks[0]
|
||||
block_dict = sent_block.model_dump(mode="json") if hasattr(sent_block, "model_dump") else sent_block.__dict__
|
||||
assert block_dict.get("some_future_block_field") == 42
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workflow_create_falls_back_on_schema_validation_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""If the internal schema rejects the payload, normalization is skipped and the raw dict is forwarded."""
|
||||
fake_client = SimpleNamespace(create_workflow=AsyncMock(return_value=_fake_workflow_response()))
|
||||
monkeypatch.setattr(workflow_tools, "get_skyvern", lambda: fake_client)
|
||||
|
||||
definition = {
|
||||
"title": "Schema rejection test",
|
||||
"workflow_definition": {
|
||||
"parameters": [],
|
||||
"blocks": [
|
||||
{
|
||||
"block_type": "text_prompt",
|
||||
"label": "summarize",
|
||||
"prompt": "Summarize",
|
||||
"llm_key": "HALLUCINATED_KEY",
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
with patch(
|
||||
"skyvern.cli.mcp_tools.workflow.WorkflowCreateYAMLRequestSchema.model_validate",
|
||||
side_effect=Exception("schema rejected"),
|
||||
):
|
||||
result = await workflow_tools.skyvern_workflow_create(definition=json.dumps(definition), format="json")
|
||||
|
||||
assert result["ok"] is True
|
||||
sent = fake_client.create_workflow.await_args.kwargs["json_definition"]
|
||||
# Normalization was skipped, so the hallucinated key passes through to the SDK
|
||||
sent_block = sent.workflow_definition.blocks[0]
|
||||
block_dict = sent_block.model_dump(mode="json") if hasattr(sent_block, "model_dump") else sent_block.__dict__
|
||||
assert block_dict.get("llm_key") == "HALLUCINATED_KEY"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("format_name", ["json", "auto"])
def test_parse_definition_returns_invalid_input_when_schema_fallback_still_fails(format_name: str) -> None:
    """If both the internal schema and Fern fallback reject the JSON, return a structured INVALID_INPUT error."""
    bad_definition = json.dumps({"title": "Missing workflow_definition"})

    parsed_json, parsed_yaml, error = workflow_tools._parse_definition(bad_definition, format_name)

    assert parsed_json is None
    assert parsed_yaml is None
    assert error is not None
    assert error["code"] == workflow_tools.ErrorCode.INVALID_INPUT
    assert "Invalid JSON definition" in error["message"]
|
||||
|
||||
|
||||
def test_deep_merge_preserves_block_unknown_fields_when_list_lengths_differ() -> None:
    """Overlapping block data should still merge when normalization changes list length."""
    base_block = {
        "block_type": "text_prompt",
        "label": "summarize",
        "prompt": "Summarize",
    }
    # Raw payload carries an unknown field on its single block.
    raw = {"workflow_definition": {"blocks": [dict(base_block, some_future_block_field=42)]}}
    # Normalized output dropped the unknown field and grew an extra block.
    followup_block = {
        "block_type": "text_prompt",
        "label": "synthetic_followup",
        "prompt": "Follow up",
    }
    normalized = {"workflow_definition": {"blocks": [dict(base_block), followup_block]}}

    merged = workflow_tools._deep_merge(raw, normalized)
    merged_blocks = merged["workflow_definition"]["blocks"]

    assert len(merged_blocks) == 2
    assert merged_blocks[0]["some_future_block_field"] == 42
    assert merged_blocks[1]["label"] == "synthetic_followup"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_text_prompt_without_llm_key_stays_null(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""When MCP correctly omits llm_key (Skyvern Optimized), it stays null through the whole pipeline."""
|
||||
fake_client = SimpleNamespace(create_workflow=AsyncMock(return_value=_fake_workflow_response()))
|
||||
monkeypatch.setattr(workflow_tools, "get_skyvern", lambda: fake_client)
|
||||
|
||||
definition = {
|
||||
"title": "Well-behaved MCP workflow",
|
||||
"workflow_definition": {
|
||||
"blocks": [
|
||||
{
|
||||
"block_type": "text_prompt",
|
||||
"label": "summarize",
|
||||
"prompt": "Summarize the extracted data",
|
||||
}
|
||||
],
|
||||
"parameters": [],
|
||||
},
|
||||
}
|
||||
|
||||
result = await workflow_tools.skyvern_workflow_create(definition=json.dumps(definition), format="json")
|
||||
sent_block = fake_client.create_workflow.await_args.kwargs["json_definition"].workflow_definition.blocks[0]
|
||||
|
||||
assert result["ok"] is True
|
||||
assert sent_block.llm_key is None
|
||||
assert sent_block.model is None
|
||||
|
|
@ -9,6 +9,7 @@ from skyvern.forge.prompts import prompt_engine
|
|||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
from skyvern.forge.sdk.workflow.models.block import TextPromptBlock
|
||||
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter, ParameterType
|
||||
from skyvern.schemas.workflows import TextPromptBlockYAML, WorkflowRequest
|
||||
|
||||
block_module = sys.modules["skyvern.forge.sdk.workflow.models.block"]
|
||||
|
||||
|
|
@ -212,3 +213,177 @@ async def test_text_prompt_block_prefers_prompt_type_config_over_secondary(monke
|
|||
assert captured["default_handler"] == prompt_config_handler
|
||||
prompt_config_handler.assert_awaited_once()
|
||||
assert response == {"llm_response": "config"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_prompt_block_bad_llm_key_uses_same_runtime_path_as_no_override(monkeypatch):
|
||||
monkeypatch.setattr(SettingsManager, "_SettingsManager__instance", base_settings)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
normalized_bad = TextPromptBlockYAML(
|
||||
label="bad_key",
|
||||
prompt="Summarize status.",
|
||||
llm_key="ANTHROPIC_CLAUDE_3_5_SONNET",
|
||||
)
|
||||
no_override = TextPromptBlockYAML(
|
||||
label="no_override",
|
||||
prompt="Summarize status.",
|
||||
llm_key=None,
|
||||
)
|
||||
|
||||
blocks = []
|
||||
for idx, yaml_block in enumerate((normalized_bad, no_override), start=1):
|
||||
output_parameter = OutputParameter(
|
||||
parameter_type=ParameterType.OUTPUT,
|
||||
key=f"text_prompt_output_{idx}",
|
||||
description=None,
|
||||
output_parameter_id=f"output-{idx}",
|
||||
workflow_id="workflow-1",
|
||||
created_at=now,
|
||||
modified_at=now,
|
||||
deleted_at=None,
|
||||
)
|
||||
blocks.append(
|
||||
TextPromptBlock(
|
||||
label=yaml_block.label,
|
||||
llm_key=yaml_block.llm_key,
|
||||
prompt=yaml_block.prompt,
|
||||
parameters=[],
|
||||
json_schema=None,
|
||||
output_parameter=output_parameter,
|
||||
model=yaml_block.model,
|
||||
)
|
||||
)
|
||||
|
||||
captured: list[tuple[str | None, object]] = []
|
||||
fake_secondary_handler = AsyncMock(return_value={"llm_response": "secondary"})
|
||||
|
||||
async def fake_prompt_type_handler(*args, **kwargs):
|
||||
return None
|
||||
|
||||
def fake_get_override_handler(llm_key: str | None, *, default):
|
||||
captured.append((llm_key, default))
|
||||
return default
|
||||
|
||||
block_module.app.SECONDARY_LLM_API_HANDLER = fake_secondary_handler
|
||||
block_module.app.LLM_API_HANDLER = AsyncMock()
|
||||
LLMAPIHandlerFactory = block_module.LLMAPIHandlerFactory
|
||||
monkeypatch.setattr(
|
||||
LLMAPIHandlerFactory,
|
||||
"get_override_llm_api_handler",
|
||||
fake_get_override_handler,
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
block_module,
|
||||
"get_llm_handler_for_prompt_type",
|
||||
fake_prompt_type_handler,
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
prompt_engine,
|
||||
"load_prompt_from_string",
|
||||
lambda template, **kwargs: template,
|
||||
)
|
||||
|
||||
for block in blocks:
|
||||
response = await block.send_prompt(block.prompt, {}, workflow_run_id="workflow-run", organization_id="org-1")
|
||||
assert response == {"llm_response": "secondary"}
|
||||
|
||||
assert captured == [
|
||||
(None, fake_secondary_handler),
|
||||
(None, fake_secondary_handler),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_text_prompt_block_uses_explicit_internal_llm_key_override(monkeypatch):
    """A registered internal llm_key must be forwarded to the override handler factory.

    The block is built with ``llm_key="SPECIAL_INTERNAL_KEY"`` and no model; the
    factory stub records what the block asked for and hands back an override
    handler, whose response must be what ``send_prompt`` returns.
    """
    created = datetime.now(timezone.utc)
    out_param = OutputParameter(
        parameter_type=ParameterType.OUTPUT,
        key="text_prompt_output_internal",
        description=None,
        output_parameter_id="output-internal",
        workflow_id="workflow-1",
        created_at=created,
        modified_at=created,
        deleted_at=None,
    )

    text_block = TextPromptBlock(
        label="text-block",
        llm_key="SPECIAL_INTERNAL_KEY",
        prompt="Summarize status.",
        parameters=[],
        json_schema=None,
        output_parameter=out_param,
        model=None,
    )

    seen: dict[str, object] = {}
    default_handler = AsyncMock()
    override_handler = AsyncMock(return_value={"llm_response": "override"})

    async def resolve_default(*_args, **_kwargs):
        # Stand-in for TextPromptBlock._resolve_default_llm_handler.
        return default_handler

    def record_override(llm_key: str | None, *, default):
        # Capture the key/default the block requested, then supply the override.
        seen["llm_key"] = llm_key
        seen["default_handler"] = default
        return override_handler

    block_module.app.LLM_API_HANDLER = default_handler
    factory = block_module.LLMAPIHandlerFactory
    monkeypatch.setattr(
        factory,
        "get_override_llm_api_handler",
        record_override,
        raising=False,
    )
    monkeypatch.setattr(
        TextPromptBlock,
        "_resolve_default_llm_handler",
        resolve_default,
        raising=False,
    )
    # Make prompt rendering a no-op so send_prompt receives the raw template.
    monkeypatch.setattr(
        prompt_engine,
        "load_prompt_from_string",
        lambda template, **kwargs: template,
    )

    result = await text_block.send_prompt(
        text_block.prompt, {}, workflow_run_id="workflow-run", organization_id="org-1"
    )

    assert seen["llm_key"] == "SPECIAL_INTERNAL_KEY"
    assert seen["default_handler"] == default_handler
    override_handler.assert_awaited_once()
    assert result == {"llm_response": "override"}
|
||||
|
||||
|
||||
def test_workflow_request_deserialization_strips_invalid_text_prompt_llm_key() -> None:
    """Verify FastAPI deserialization (not just explicit model_validate) strips bad keys.

    Moved from tests/scenario/ — this test only validates Pydantic behavior and
    does not require a database connection.
    """
    payload = {
        "json_definition": {
            "title": "Deserialization test",
            "workflow_definition": {
                "blocks": [
                    {
                        "block_type": "text_prompt",
                        "label": "summarize",
                        "prompt": "Summarize the result",
                        # A model name is not a valid internal llm_key; validation
                        # is expected to drop it rather than pass it through.
                        "llm_key": "ANTHROPIC_CLAUDE_3_5_SONNET",
                    }
                ],
                "parameters": [],
            },
        }
    }

    parsed = WorkflowRequest.model_validate(payload)
    first_block = parsed.json_definition.workflow_definition.blocks[0]

    assert first_block.llm_key is None
    assert first_block.model is None
|
||||
|
|
|
|||
88
tests/unit/test_text_prompt_block_yaml.py
Normal file
88
tests/unit/test_text_prompt_block_yaml.py
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.config import settings as base_settings
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
from skyvern.schemas.workflows import TextPromptBlockYAML, _get_text_prompt_model_name_by_llm_key
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _clear_llm_key_cache():
    """Reset the llm-key lookup lru_cache around every test to avoid cross-test pollution."""
    # Bind once; the same reset runs before the test body and again at teardown.
    reset_cache = _get_text_prompt_model_name_by_llm_key.cache_clear
    reset_cache()
    yield
    reset_cache()
|
||||
|
||||
|
||||
class TestTextPromptBlockYAMLNormalization:
    """Validation-time normalization of ``llm_key``/``model`` on TextPromptBlockYAML."""

    @staticmethod
    def _use_base_settings(monkeypatch) -> None:
        # Point the SettingsManager singleton at the unmodified base settings.
        monkeypatch.setattr(SettingsManager, "_SettingsManager__instance", base_settings)

    @staticmethod
    def _build_block(**overrides):
        # Minimal valid YAML block; each test overrides only the fields under test.
        fields: dict = {"label": "summarize", "prompt": "Summarize the data."}
        fields.update(overrides)
        return TextPromptBlockYAML(**fields)

    def test_converts_known_llm_key_to_model(self, monkeypatch) -> None:
        self._use_base_settings(monkeypatch)

        yaml_block = self._build_block(llm_key="VERTEX_GEMINI_2.5_FLASH")

        assert yaml_block.model == {"model_name": "gemini-2.5-flash"}
        assert yaml_block.llm_key is None

    def test_clears_invalid_llm_key_to_use_default_model(self, monkeypatch) -> None:
        self._use_base_settings(monkeypatch)

        # With no registered model names, an unknown key is dropped entirely.
        with patch(
            "skyvern.schemas.workflows.LLMConfigRegistry.get_model_names",
            return_value=[],
        ):
            yaml_block = self._build_block(llm_key="ANTHROPIC_CLAUDE_3_5_SONNET")

        assert yaml_block.model is None
        assert yaml_block.llm_key is None

    def test_preserves_registered_advanced_llm_key(self, monkeypatch) -> None:
        self._use_base_settings(monkeypatch)

        # Keys present in the registry are legitimate internal keys and survive.
        with patch(
            "skyvern.schemas.workflows.LLMConfigRegistry.get_model_names",
            return_value=["SPECIAL_INTERNAL_KEY"],
        ):
            yaml_block = self._build_block(llm_key="SPECIAL_INTERNAL_KEY")

        assert yaml_block.model is None
        assert yaml_block.llm_key == "SPECIAL_INTERNAL_KEY"

    def test_preserves_templated_llm_key(self, monkeypatch) -> None:
        self._use_base_settings(monkeypatch)

        # Jinja-templated keys are resolved later at runtime, so they pass through.
        yaml_block = self._build_block(llm_key="{{ prompt_block_llm_key }}")

        assert yaml_block.model is None
        assert yaml_block.llm_key == "{{ prompt_block_llm_key }}"

    def test_model_override_clears_raw_llm_key(self, monkeypatch) -> None:
        self._use_base_settings(monkeypatch)

        # An explicit model wins; the raw llm_key is discarded.
        yaml_block = self._build_block(
            llm_key="ANTHROPIC_CLAUDE_3_5_SONNET",
            model={"model_name": "gemini-3-pro-preview"},
        )

        assert yaml_block.model == {"model_name": "gemini-3-pro-preview"}
        assert yaml_block.llm_key is None
|
||||
Loading…
Add table
Add a link
Reference in a new issue