mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2026-04-28 03:30:10 +00:00
215 lines
9 KiB
Python
215 lines
9 KiB
Python
"""Tests for WorkflowService.setup_workflow_run batch parameter persistence.
|
|
|
|
Verifies that setup_workflow_run collects all parameter values first and
|
|
persists them in a single batch insert, and that validation failures
|
|
(missing params, invalid credentials, DB errors) are handled correctly.
|
|
"""
|
|
|
|
from __future__ import annotations

from collections.abc import Iterator
from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch

import pytest
from sqlalchemy.exc import IntegrityError

from skyvern.exceptions import InvalidCredentialId, MissingValueForParameter, WorkflowRunParameterPersistenceError
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.workflow.models.parameter import WorkflowParameter, WorkflowParameterType
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRequestBody
from skyvern.forge.sdk.workflow.service import WorkflowService
|
|
|
|
|
|
def _make_workflow_parameter(
    key: str,
    *,
    workflow_parameter_type: WorkflowParameterType = WorkflowParameterType.STRING,
    default_value: str | int | float | bool | dict | list | None = None,
) -> WorkflowParameter:
    """Build a ``WorkflowParameter`` for the test workflow ``wf_test``.

    Args:
        key: Parameter key; also used to derive the id ``wp_<key>``.
        workflow_parameter_type: Declared type of the parameter (STRING by default).
        default_value: Optional default stored on the parameter.

    Returns:
        A parameter whose ``created_at`` and ``modified_at`` share one
        timezone-aware UTC timestamp.
    """
    # One timestamp for both fields so they are guaranteed to be equal.
    timestamp = datetime.now(tz=timezone.utc)
    return WorkflowParameter(
        key=key,
        workflow_id="wf_test",
        workflow_parameter_id=f"wp_{key}",
        workflow_parameter_type=workflow_parameter_type,
        default_value=default_value,
        created_at=timestamp,
        modified_at=timestamp,
    )
|
|
|
|
|
|
def _make_service_with_mocks(
    *,
    workflow_parameters: list[WorkflowParameter],
    batch_side_effect: Exception | None = None,
    single_side_effect: Exception | None = None,
) -> tuple[WorkflowService, SimpleNamespace, SimpleNamespace]:
    """Create a ``WorkflowService`` whose collaborators are replaced by ``AsyncMock`` objects.

    Args:
        workflow_parameters: Value returned by the mocked ``get_workflow_parameters``.
        batch_side_effect: If given, raised by the mocked batch insert
            (``create_workflow_run_parameters``); otherwise that mock returns ``[]``.
        single_side_effect: If given, raised by the mocked single insert
            (``create_workflow_run_parameter``).

    Returns:
        ``(service, organization, workflow_run)`` where the last two are
        ``SimpleNamespace`` stand-ins.
    """
    # Lightweight stand-ins for the objects the mocked service methods hand back.
    organization = SimpleNamespace(organization_id="org_test", organization_name="Test Org")
    workflow_run = SimpleNamespace(workflow_run_id="wr_test", workflow_permanent_id="wpid_test")
    workflow = SimpleNamespace(
        workflow_id="wf_test",
        workflow_permanent_id="wpid_test",
        organization_id="org_test",
        proxy_location=None,
        webhook_callback_url=None,
        extra_http_headers=None,
        run_with="agent",
        code_version=None,
        adaptive_caching=False,
        sequential_key=None,
    )

    # Pick the persistence mocks up front: raising variants when an error is supplied.
    batch_insert = AsyncMock(side_effect=batch_side_effect) if batch_side_effect else AsyncMock(return_value=[])
    single_insert = AsyncMock(side_effect=single_side_effect) if single_side_effect else AsyncMock()

    service = WorkflowService()
    service.get_workflow_by_permanent_id = AsyncMock(return_value=workflow)  # type: ignore[method-assign]
    service.create_workflow_run = AsyncMock(return_value=workflow_run)  # type: ignore[method-assign]
    service.get_workflow_parameters = AsyncMock(return_value=workflow_parameters)  # type: ignore[method-assign]
    service.create_workflow_run_parameters = batch_insert  # type: ignore[method-assign]
    service.create_workflow_run_parameter = single_insert  # type: ignore[method-assign]
    service.mark_workflow_run_as_failed = AsyncMock(return_value=workflow_run)  # type: ignore[method-assign]

    return service, organization, workflow_run
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def reset_context() -> Iterator[None]:
    """Reset the Skyvern context before and after every test in this module.

    Declared ``autouse`` so tests that set a context cannot leak it into the
    next test. The function is a generator (yield fixture), hence the
    ``Iterator[None]`` return annotation rather than ``None``.
    """
    skyvern_context.reset()
    yield
    skyvern_context.reset()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_setup_workflow_run_raises_on_missing_required_parameters() -> None:
    """A parameter with neither a supplied value nor a default must raise MissingValueForParameter."""
    # "api_key" is created without a default_value, and the request carries no data for it.
    service, organization, _ = _make_service_with_mocks(
        workflow_parameters=[_make_workflow_parameter("api_key")],
    )
    empty_request = WorkflowRequestBody(data={})

    with patch("skyvern.forge.sdk.workflow.service.app") as mock_app:
        mock_app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached = AsyncMock(return_value=False)

        with pytest.raises(MissingValueForParameter):
            await service.setup_workflow_run(
                request_id="req_test",
                workflow_request=empty_request,
                workflow_permanent_id="wpid_test",
                organization=organization,
            )

    # The batch insert must never be attempted, and the run must be marked failed exactly once.
    service.create_workflow_run_parameters.assert_not_awaited()
    service.mark_workflow_run_as_failed.assert_awaited_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_setup_workflow_run_persistence_error_identifies_specific_failing_parameter() -> None:
    """After a failed batch insert, the raised error must name only the parameter whose single insert failed."""
    batch_error = IntegrityError("INSERT", {}, Exception("constraint failed"))
    single_error = IntegrityError("INSERT", {}, Exception("NOT NULL constraint on middle_label"))

    workflow_parameters = [
        _make_workflow_parameter(
            "alpha_count",
            workflow_parameter_type=WorkflowParameterType.INTEGER,
            default_value="1",
        ),
        _make_workflow_parameter("middle_label", default_value="mid"),
        _make_workflow_parameter("zebra_url", default_value="https://zebra.example.com"),
    ]

    async def _fail_only_for_middle_label(
        *, workflow_run_id: str, workflow_parameter: WorkflowParameter, value: object
    ) -> None:
        # Every single insert succeeds except the one for "middle_label".
        if workflow_parameter.key != "middle_label":
            return
        raise single_error

    service, organization, _ = _make_service_with_mocks(
        workflow_parameters=workflow_parameters,
        batch_side_effect=batch_error,
    )
    # Replace the helper's default single-insert mock with the selective failure above.
    service.create_workflow_run_parameter = AsyncMock(side_effect=_fail_only_for_middle_label)  # type: ignore[method-assign]

    request = WorkflowRequestBody(data={"alpha_count": 5, "middle_label": "test", "zebra_url": "https://z.com"})

    with patch("skyvern.forge.sdk.workflow.service.app") as mock_app:
        mock_app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached = AsyncMock(return_value=False)

        with pytest.raises(WorkflowRunParameterPersistenceError) as exc_info:
            await service.setup_workflow_run(
                request_id="req_test",
                workflow_request=request,
                workflow_permanent_id="wpid_test",
                organization=organization,
            )

    raised = exc_info.value
    error_message = str(raised)
    # Only the failing key may be mentioned; the other two must be absent.
    assert "middle_label" in error_message
    assert "alpha_count" not in error_message
    assert "zebra_url" not in error_message
    # The single-insert error (not the batch error) must be chained as the cause.
    assert raised.__cause__ is single_error
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_setup_workflow_run_raises_on_non_string_credential_id() -> None:
    """Supplying an int for a CREDENTIAL_ID parameter must raise InvalidCredentialId."""
    credential_parameter = _make_workflow_parameter(
        "credential",
        workflow_parameter_type=WorkflowParameterType.CREDENTIAL_ID,
    )
    service, organization, _ = _make_service_with_mocks(workflow_parameters=[credential_parameter])

    # 12345 is deliberately an int rather than a string credential id.
    request_with_int_credential = WorkflowRequestBody(data={"credential": 12345})

    with patch("skyvern.forge.sdk.workflow.service.app") as mock_app:
        mock_app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached = AsyncMock(return_value=False)

        with pytest.raises(InvalidCredentialId):
            await service.setup_workflow_run(
                request_id="req_test",
                workflow_request=request_with_int_credential,
                workflow_permanent_id="wpid_test",
                organization=organization,
            )

    # The batch insert must never be attempted when validation fails.
    service.create_workflow_run_parameters.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_setup_workflow_run_preserves_parent_loop_state_when_replacing_context() -> None:
    """The context installed by setup_workflow_run keeps the parent's run ids and an equal-but-distinct loop state."""
    loop_state = {"downloaded_file_signatures_before_iteration": [("a.pdf", "abc", "https://files/a.pdf")]}

    # Install a parent context first so setup_workflow_run has something to replace.
    skyvern_context.set(
        SkyvernContext(
            organization_id="org_test",
            organization_name="Test Org",
            workflow_run_id="wr_parent",
            root_workflow_run_id="wr_root",
            run_id="wr_parent",
            loop_internal_state=loop_state,
        )
    )

    service, organization, _ = _make_service_with_mocks(workflow_parameters=[])

    with patch("skyvern.forge.sdk.workflow.service.app") as mock_app:
        mock_app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached = AsyncMock(return_value=False)

        await service.setup_workflow_run(
            request_id="req_test",
            workflow_request=WorkflowRequestBody(data={}),
            workflow_permanent_id="wpid_test",
            organization=organization,
        )

    new_context = skyvern_context.current()
    assert new_context is not None
    # The workflow run id comes from the mocked create_workflow_run ("wr_test") ...
    assert new_context.workflow_run_id == "wr_test"
    # ... while run_id and root_workflow_run_id still carry the parent's values.
    assert new_context.run_id == "wr_parent"
    assert new_context.root_workflow_run_id == "wr_root"
    # Loop state must be equal in content but must not be the very same object.
    assert new_context.loop_internal_state == loop_state
    assert new_context.loop_internal_state is not loop_state
|