diff --git a/skyvern/forge/sdk/db/repositories/scripts.py b/skyvern/forge/sdk/db/repositories/scripts.py index 0f1b82b87..527bfbc22 100644 --- a/skyvern/forge/sdk/db/repositories/scripts.py +++ b/skyvern/forge/sdk/db/repositories/scripts.py @@ -521,6 +521,35 @@ class ScriptsRepository(BaseRepository): workflow_script_model = (await session.scalars(query)).first() return WorkflowScript.model_validate(workflow_script_model) if workflow_script_model else None + @db_operation("get_workflow_script_source_workflow_id") + async def get_workflow_script_source_workflow_id( + self, + *, + organization_id: str, + workflow_permanent_id: str, + script_id: str, + cache_key_value: str, + ) -> str | None: + """Return the workflow version (w_*) that produced a given cached script row. + + Used to detect when the workflow definition has changed since the cached + script was generated (SKY-9254). + """ + async with self.Session() as session: + query = ( + select(WorkflowScriptModel.workflow_id) + .where( + WorkflowScriptModel.organization_id == organization_id, + WorkflowScriptModel.workflow_permanent_id == workflow_permanent_id, + WorkflowScriptModel.script_id == script_id, + WorkflowScriptModel.cache_key_value == cache_key_value, + WorkflowScriptModel.deleted_at.is_(None), + ) + .order_by(WorkflowScriptModel.created_at.desc()) + .limit(1) + ) + return (await session.scalars(query)).first() + @db_operation("get_workflow_script_by_cache_key_value") async def get_workflow_script_by_cache_key_value( self, diff --git a/skyvern/services/workflow_script_service.py b/skyvern/services/workflow_script_service.py index 0e7e62ba9..d1ac694fc 100644 --- a/skyvern/services/workflow_script_service.py +++ b/skyvern/services/workflow_script_service.py @@ -246,6 +246,60 @@ async def generate_or_update_pending_workflow_script( ) +async def _invalidate_if_parameters_changed( + workflow: Workflow, + existing_script: Script, + cache_key_value: str, + workflow_run_id: str, +) -> bool: + """Return True if the cached script should be invalidated because the + workflow's parameter key set has changed since the script was generated. + + Only fires when the workflow version id differs from the version that + produced the cached script, so steady-state cache hits pay no extra DB + work. A missing prior workflow row (hard-deleted) is treated as a cache + miss for safety. + """ + cache_workflow_id = await app.DATABASE.scripts.get_workflow_script_source_workflow_id( + organization_id=workflow.organization_id, + workflow_permanent_id=workflow.workflow_permanent_id, + script_id=existing_script.script_id, + cache_key_value=cache_key_value, + ) + if not cache_workflow_id or cache_workflow_id == workflow.workflow_id: + return False + + old_workflow = await app.DATABASE.workflows.get_workflow( + workflow_id=cache_workflow_id, + organization_id=workflow.organization_id, + ) + if old_workflow is None: + LOG.info( + "Cached script invalidated: prior workflow version not found", + workflow_id=workflow.workflow_id, + cache_workflow_id=cache_workflow_id, + script_id=existing_script.script_id, + workflow_run_id=workflow_run_id, + ) + return True + + old_param_keys = {p.key for p in old_workflow.workflow_definition.parameters} + new_param_keys = {p.key for p in workflow.workflow_definition.parameters} + if old_param_keys != new_param_keys: + LOG.info( + "Cached script invalidated: workflow parameter set changed", + workflow_id=workflow.workflow_id, + cache_workflow_id=cache_workflow_id, + script_id=existing_script.script_id, + workflow_run_id=workflow_run_id, + added_params=sorted(new_param_keys - old_param_keys), + removed_params=sorted(old_param_keys - new_param_keys), + ) + return True + + return False + + async def get_workflow_script( workflow: Workflow, workflow_run: WorkflowRun, @@ -313,6 +367,22 @@ async def get_workflow_script( ) if existing_script: + # SKY-9254: invalidate the cached script when the workflow's parameter + # set has changed since it was generated. Cache lookup keys on + # (org, wpid, cache_key_value) — none of which change when a user + # edits the workflow to add/remove a parameter. Without this check + # the old cached code (which has no reference to the new param) + # keeps getting served, and the new param ends up injected wherever + # the agent guesses. + invalidated = await _invalidate_if_parameters_changed( + workflow=workflow, + existing_script=existing_script, + cache_key_value=rendered_cache_key_value, + workflow_run_id=workflow_run.workflow_run_id, + ) + if invalidated: + return None, rendered_cache_key_value, False + LOG.info( "Found cached script for workflow (cache hit)", workflow_id=workflow.workflow_id, diff --git a/tests/unit/workflow/test_cache_invalidation_on_param_change.py b/tests/unit/workflow/test_cache_invalidation_on_param_change.py new file mode 100644 index 000000000..6241238f7 --- /dev/null +++ b/tests/unit/workflow/test_cache_invalidation_on_param_change.py @@ -0,0 +1,194 @@ +"""Tests for cache invalidation when a workflow's parameter set changes (SKY-9254).""" + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from skyvern.forge.sdk.workflow.models.parameter import ( + WorkflowParameter, + WorkflowParameterType, +) +from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowDefinition +from skyvern.schemas.scripts import Script +from skyvern.services.workflow_script_service import _invalidate_if_parameters_changed + + +def _workflow_param(key: str) -> WorkflowParameter: + now = datetime.now(timezone.utc) + return WorkflowParameter( + key=key, + description="", + workflow_parameter_id=f"wp_{key}", + workflow_parameter_type=WorkflowParameterType.STRING, + workflow_id="w_current", + created_at=now, + modified_at=now, + ) + + +def _workflow(workflow_id: str, param_keys: list[str]) -> Workflow: + now = datetime.now(timezone.utc) + return Workflow( + workflow_id=workflow_id, + organization_id="o_test", + title="Test", + workflow_permanent_id="wpid_test", + version=1, + is_saved_task=False, + workflow_definition=WorkflowDefinition( + blocks=[], + parameters=[_workflow_param(k) for k in param_keys], + ), + created_at=now, + modified_at=now, + ) + + +def _script() -> Script: + now = datetime.now(timezone.utc) + return Script( + script_id="s_test", + script_revision_id="sr_test", + organization_id="o_test", + run_id="wr_old", + version=1, + created_at=now, + modified_at=now, + ) + + +@pytest.mark.asyncio +async def test_does_not_invalidate_when_workflow_version_matches() -> None: + """No DB work past the source lookup if the cached row was produced by the + current workflow version — this is the hot path on every cache hit.""" + mock_db = MagicMock() + mock_db.scripts.get_workflow_script_source_workflow_id = AsyncMock(return_value="w_current") + mock_db.workflows.get_workflow = AsyncMock() + + with patch("skyvern.services.workflow_script_service.app") as mock_app: + mock_app.DATABASE = mock_db + result = await _invalidate_if_parameters_changed( + workflow=_workflow("w_current", ["name"]), + existing_script=_script(), + cache_key_value="default:v2", + workflow_run_id="wr_new", + ) + + assert result is False + mock_db.workflows.get_workflow.assert_not_called() + + +@pytest.mark.asyncio +async def test_does_not_invalidate_when_source_workflow_id_missing() -> None: + """Legacy rows without a workflow_id stored can't be diffed — keep serving.""" + mock_db = MagicMock() + mock_db.scripts.get_workflow_script_source_workflow_id = AsyncMock(return_value=None) + mock_db.workflows.get_workflow = AsyncMock() + + with patch("skyvern.services.workflow_script_service.app") as mock_app: + mock_app.DATABASE = mock_db + result = await _invalidate_if_parameters_changed( + workflow=_workflow("w_current", ["name"]), + existing_script=_script(), + cache_key_value="default:v2", + workflow_run_id="wr_new", + ) + + assert result is False + mock_db.workflows.get_workflow.assert_not_called() + + +@pytest.mark.asyncio +async def test_invalidates_when_prior_workflow_hard_deleted() -> None: + """If the old workflow row is gone, we can't verify param set — play it safe.""" + mock_db = MagicMock() + mock_db.scripts.get_workflow_script_source_workflow_id = AsyncMock(return_value="w_old") + mock_db.workflows.get_workflow = AsyncMock(return_value=None) + + with patch("skyvern.services.workflow_script_service.app") as mock_app: + mock_app.DATABASE = mock_db + result = await _invalidate_if_parameters_changed( + workflow=_workflow("w_current", ["name"]), + existing_script=_script(), + cache_key_value="default:v2", + workflow_run_id="wr_new", + ) + + assert result is True + + +@pytest.mark.asyncio +async def test_does_not_invalidate_when_param_keys_identical() -> None: + """Cosmetic edits (title, description, webhook, proxy) bump workflow_id + but don't change the parameter set — cache stays warm.""" + mock_db = MagicMock() + mock_db.scripts.get_workflow_script_source_workflow_id = AsyncMock(return_value="w_old") + mock_db.workflows.get_workflow = AsyncMock(return_value=_workflow("w_old", ["name", "email"])) + + with patch("skyvern.services.workflow_script_service.app") as mock_app: + mock_app.DATABASE = mock_db + result = await _invalidate_if_parameters_changed( + workflow=_workflow("w_current", ["name", "email"]), + existing_script=_script(), + cache_key_value="default:v2", + workflow_run_id="wr_new", + ) + + assert result is False + + +@pytest.mark.asyncio +async def test_invalidates_when_parameter_added() -> None: + """This is the SKY-9254 case: a phone parameter was added post-cache.""" + mock_db = MagicMock() + mock_db.scripts.get_workflow_script_source_workflow_id = AsyncMock(return_value="w_old") + mock_db.workflows.get_workflow = AsyncMock(return_value=_workflow("w_old", ["name"])) + + with patch("skyvern.services.workflow_script_service.app") as mock_app: + mock_app.DATABASE = mock_db + result = await _invalidate_if_parameters_changed( + workflow=_workflow("w_current", ["name", "phone"]), + existing_script=_script(), + cache_key_value="default:v2", + workflow_run_id="wr_new", + ) + + assert result is True + + +@pytest.mark.asyncio +async def test_invalidates_when_parameter_removed() -> None: + mock_db = MagicMock() + mock_db.scripts.get_workflow_script_source_workflow_id = AsyncMock(return_value="w_old") + mock_db.workflows.get_workflow = AsyncMock(return_value=_workflow("w_old", ["name", "phone"])) + + with patch("skyvern.services.workflow_script_service.app") as mock_app: + mock_app.DATABASE = mock_db + result = await _invalidate_if_parameters_changed( + workflow=_workflow("w_current", ["name"]), + existing_script=_script(), + cache_key_value="default:v2", + workflow_run_id="wr_new", + ) + + assert result is True + + +@pytest.mark.asyncio +async def test_invalidates_when_parameter_renamed() -> None: + """Rename = remove old key + add new key. Both directions in the symmetric diff.""" + mock_db = MagicMock() + mock_db.scripts.get_workflow_script_source_workflow_id = AsyncMock(return_value="w_old") + mock_db.workflows.get_workflow = AsyncMock(return_value=_workflow("w_old", ["phone_number"])) + + with patch("skyvern.services.workflow_script_service.app") as mock_app: + mock_app.DATABASE = mock_db + result = await _invalidate_if_parameters_changed( + workflow=_workflow("w_current", ["phone"]), + existing_script=_script(), + cache_key_value="default:v2", + workflow_run_id="wr_new", + ) + + assert result is True