Skyvern/tests/unit/test_workflow_run_parameter_batching.py
2026-04-13 14:12:12 -07:00

215 lines
9 KiB
Python

"""Tests for WorkflowService.setup_workflow_run batch parameter persistence.
Verifies that setup_workflow_run collects all parameter values first and
persists them in a single batch insert, and that validation failures
(missing params, invalid credentials, DB errors) are handled correctly.
"""
from __future__ import annotations

from collections.abc import Iterator
from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch

import pytest
from sqlalchemy.exc import IntegrityError

from skyvern.exceptions import InvalidCredentialId, MissingValueForParameter, WorkflowRunParameterPersistenceError
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.workflow.models.parameter import WorkflowParameter, WorkflowParameterType
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRequestBody
from skyvern.forge.sdk.workflow.service import WorkflowService
def _make_workflow_parameter(
    key: str,
    *,
    workflow_parameter_type: WorkflowParameterType = WorkflowParameterType.STRING,
    default_value: str | int | float | bool | dict | list | None = None,
) -> WorkflowParameter:
    """Build a ``WorkflowParameter`` belonging to workflow ``wf_test``.

    Args:
        key: Parameter key; also used to derive the id ``wp_<key>``.
        workflow_parameter_type: Declared type of the parameter (STRING by default).
        default_value: Optional default; ``None`` means no default is set.

    Returns:
        A parameter whose ``created_at`` and ``modified_at`` share one
        timezone-aware UTC timestamp taken at call time.
    """
    parameter_id = f"wp_{key}"
    # One timestamp for both fields so they are guaranteed to be equal.
    timestamp = datetime.now(tz=timezone.utc)
    parameter = WorkflowParameter(
        workflow_parameter_id=parameter_id,
        workflow_id="wf_test",
        key=key,
        workflow_parameter_type=workflow_parameter_type,
        default_value=default_value,
        created_at=timestamp,
        modified_at=timestamp,
    )
    return parameter
def _make_service_with_mocks(
    *,
    workflow_parameters: list[WorkflowParameter],
    batch_side_effect: Exception | None = None,
    single_side_effect: Exception | None = None,
) -> tuple[WorkflowService, SimpleNamespace, SimpleNamespace]:
    """Build a WorkflowService with mocked internals for setup_workflow_run tests.

    The service methods that look up the workflow, create the run, fetch the
    workflow parameters, persist run parameters (batch and single) and mark
    the run as failed are all replaced with ``AsyncMock`` objects.

    Args:
        workflow_parameters: Value returned by the mocked ``get_workflow_parameters``.
        batch_side_effect: If given, raised by the mocked batch insert
            (``create_workflow_run_parameters``); otherwise that mock returns ``[]``.
        single_side_effect: If given, raised by the mocked single insert
            (``create_workflow_run_parameter``).

    Returns:
        ``(service, organization, workflow_run)`` where ``organization`` and
        ``workflow_run`` are ``SimpleNamespace`` stand-ins.
    """
    service = WorkflowService()

    workflow = SimpleNamespace(
        workflow_id="wf_test",
        workflow_permanent_id="wpid_test",
        organization_id="org_test",
        proxy_location=None,
        webhook_callback_url=None,
        extra_http_headers=None,
        run_with="agent",
        code_version=None,
        adaptive_caching=False,
        sequential_key=None,
    )
    workflow_run = SimpleNamespace(workflow_run_id="wr_test", workflow_permanent_id="wpid_test")
    organization = SimpleNamespace(organization_id="org_test", organization_name="Test Org")

    service.get_workflow_by_permanent_id = AsyncMock(return_value=workflow)  # type: ignore[method-assign]
    service.create_workflow_run = AsyncMock(return_value=workflow_run)  # type: ignore[method-assign]
    service.get_workflow_parameters = AsyncMock(return_value=workflow_parameters)  # type: ignore[method-assign]

    # ``side_effect=None`` is the mock default and means "use return_value",
    # so the optional exceptions can be passed straight through instead of
    # branching on their truthiness.
    service.create_workflow_run_parameters = AsyncMock(  # type: ignore[method-assign]
        side_effect=batch_side_effect,
        return_value=[],
    )
    service.create_workflow_run_parameter = AsyncMock(  # type: ignore[method-assign]
        side_effect=single_side_effect,
    )

    service.mark_workflow_run_as_failed = AsyncMock(return_value=workflow_run)  # type: ignore[method-assign]
    return service, organization, workflow_run
@pytest.fixture(autouse=True)
def reset_context() -> Iterator[None]:
    """Reset the Skyvern context before and after each test.

    Declared ``autouse`` so every test in this module gets it without
    requesting it. The fixture is a generator, hence the ``Iterator[None]``
    return annotation rather than ``None``.
    """
    skyvern_context.reset()
    yield
    # Teardown: clear whatever context the test (or the service) installed.
    skyvern_context.reset()
@pytest.mark.asyncio
async def test_setup_workflow_run_raises_on_missing_required_parameters() -> None:
    """A required parameter with neither a request value nor a default must abort setup.

    Expects ``MissingValueForParameter``, no batch insert, and exactly one call
    marking the workflow run as failed.
    """
    # "api_key" is created without a default_value, which makes it required.
    api_key_param = _make_workflow_parameter("api_key")
    service, organization, _ = _make_service_with_mocks(workflow_parameters=[api_key_param])
    # The request body supplies nothing for "api_key".
    empty_request = WorkflowRequestBody(data={})
    feature_flag_off = AsyncMock(return_value=False)

    with patch("skyvern.forge.sdk.workflow.service.app") as mock_app:
        mock_app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached = feature_flag_off
        with pytest.raises(MissingValueForParameter):
            await service.setup_workflow_run(
                request_id="req_test",
                workflow_request=empty_request,
                workflow_permanent_id="wpid_test",
                organization=organization,
            )

    service.create_workflow_run_parameters.assert_not_awaited()
    service.mark_workflow_run_as_failed.assert_awaited_once()
@pytest.mark.asyncio
async def test_setup_workflow_run_persistence_error_identifies_specific_failing_parameter() -> None:
    """A failed batch insert must fall back to single inserts and name only the bad key.

    The batch mock always raises; the single-insert mock raises only for
    ``middle_label``. The resulting ``WorkflowRunParameterPersistenceError``
    must mention that key alone and chain the single-insert error as its cause.
    """
    workflow_parameters = [
        _make_workflow_parameter(
            "alpha_count",
            workflow_parameter_type=WorkflowParameterType.INTEGER,
            default_value="1",
        ),
        _make_workflow_parameter("middle_label", default_value="mid"),
        _make_workflow_parameter("zebra_url", default_value="https://zebra.example.com"),
    ]
    batch_error = IntegrityError("INSERT", {}, Exception("constraint failed"))
    single_error = IntegrityError("INSERT", {}, Exception("NOT NULL constraint on middle_label"))

    async def _fail_only_for_middle_label(
        *, workflow_run_id: str, workflow_parameter: WorkflowParameter, value: object
    ) -> None:
        # Every key other than "middle_label" inserts successfully (returns None).
        if workflow_parameter.key == "middle_label":
            raise single_error

    service, organization, _ = _make_service_with_mocks(
        workflow_parameters=workflow_parameters,
        batch_side_effect=batch_error,
    )
    # Replace the helper's single-insert mock with the per-key behaviour above.
    service.create_workflow_run_parameter = AsyncMock(  # type: ignore[method-assign]
        side_effect=_fail_only_for_middle_label,
    )
    request_body = WorkflowRequestBody(
        data={"alpha_count": 5, "middle_label": "test", "zebra_url": "https://z.com"},
    )
    feature_flag_off = AsyncMock(return_value=False)

    with patch("skyvern.forge.sdk.workflow.service.app") as mock_app:
        mock_app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached = feature_flag_off
        with pytest.raises(WorkflowRunParameterPersistenceError) as exc_info:
            await service.setup_workflow_run(
                request_id="req_test",
                workflow_request=request_body,
                workflow_permanent_id="wpid_test",
                organization=organization,
            )

    raised = exc_info.value
    error_message = str(raised)
    # Only the failing parameter may be named, not all three.
    assert "middle_label" in error_message
    assert "alpha_count" not in error_message
    assert "zebra_url" not in error_message
    assert raised.__cause__ is single_error
@pytest.mark.asyncio
async def test_setup_workflow_run_raises_on_non_string_credential_id() -> None:
    """A CREDENTIAL_ID parameter given a non-string value must raise InvalidCredentialId.

    Also checks that the batch insert is never awaited.
    """
    credential_param = _make_workflow_parameter(
        "credential",
        workflow_parameter_type=WorkflowParameterType.CREDENTIAL_ID,
    )
    service, organization, _ = _make_service_with_mocks(workflow_parameters=[credential_param])
    # An integer is supplied where a string credential id is expected.
    bad_request = WorkflowRequestBody(data={"credential": 12345})
    feature_flag_off = AsyncMock(return_value=False)

    with patch("skyvern.forge.sdk.workflow.service.app") as mock_app:
        mock_app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached = feature_flag_off
        with pytest.raises(InvalidCredentialId):
            await service.setup_workflow_run(
                request_id="req_test",
                workflow_request=bad_request,
                workflow_permanent_id="wpid_test",
                organization=organization,
            )

    service.create_workflow_run_parameters.assert_not_awaited()
@pytest.mark.asyncio
async def test_setup_workflow_run_preserves_parent_loop_state_when_replacing_context() -> None:
    """Setup under an existing parent context must carry over the parent's loop state.

    After setup, the current context must point at the new workflow run
    (``wr_test``) while keeping the parent's ``run_id`` and
    ``root_workflow_run_id``. Its ``loop_internal_state`` must equal the
    parent's but be a distinct object.
    """
    service, organization, _ = _make_service_with_mocks(workflow_parameters=[])
    parent_loop_state = {
        "downloaded_file_signatures_before_iteration": [("a.pdf", "abc", "https://files/a.pdf")],
    }
    # Install the parent context that setup_workflow_run will replace.
    skyvern_context.set(
        SkyvernContext(
            organization_id="org_test",
            organization_name="Test Org",
            workflow_run_id="wr_parent",
            root_workflow_run_id="wr_root",
            run_id="wr_parent",
            loop_internal_state=parent_loop_state,
        )
    )
    feature_flag_off = AsyncMock(return_value=False)

    with patch("skyvern.forge.sdk.workflow.service.app") as mock_app:
        mock_app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached = feature_flag_off
        await service.setup_workflow_run(
            request_id="req_test",
            workflow_request=WorkflowRequestBody(data={}),
            workflow_permanent_id="wpid_test",
            organization=organization,
        )

    active_context = skyvern_context.current()
    assert active_context is not None
    assert active_context.workflow_run_id == "wr_test"
    assert active_context.run_id == "wr_parent"
    assert active_context.root_workflow_run_id == "wr_root"
    # Equal content but a different object than the parent's state.
    assert active_context.loop_internal_state == parent_loop_state
    assert active_context.loop_internal_state is not parent_loop_state