mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2026-04-26 10:41:14 +00:00
Invalidate cached script on workflow parameter set change (SKY-9254) (#5652)
This commit is contained in:
parent
91d967f5d7
commit
6abb0c619f
3 changed files with 293 additions and 0 deletions
|
|
@ -521,6 +521,35 @@ class ScriptsRepository(BaseRepository):
|
|||
workflow_script_model = (await session.scalars(query)).first()
|
||||
return WorkflowScript.model_validate(workflow_script_model) if workflow_script_model else None
|
||||
|
||||
@db_operation("get_workflow_script_source_workflow_id")
async def get_workflow_script_source_workflow_id(
    self,
    *,
    organization_id: str,
    workflow_permanent_id: str,
    script_id: str,
    cache_key_value: str,
) -> str | None:
    """Look up which workflow version (w_*) a cached script row was generated from.

    Callers compare the returned id with the workflow version being run to
    notice that the workflow definition changed after the script was cached
    (SKY-9254).

    Only non-deleted rows matching all four identifiers are considered; when
    several match, the most recently created one is used.

    Returns:
        The ``workflow_id`` stored on the newest matching row, or ``None``
        when no row matches (or the stored value is null).
    """
    # Every condition must hold for a row to count as "this cached script".
    row_filters = (
        WorkflowScriptModel.organization_id == organization_id,
        WorkflowScriptModel.workflow_permanent_id == workflow_permanent_id,
        WorkflowScriptModel.script_id == script_id,
        WorkflowScriptModel.cache_key_value == cache_key_value,
        WorkflowScriptModel.deleted_at.is_(None),
    )
    newest_source_stmt = (
        select(WorkflowScriptModel.workflow_id)
        .where(*row_filters)
        .order_by(WorkflowScriptModel.created_at.desc())
        .limit(1)
    )
    async with self.Session() as session:
        matching_ids = await session.scalars(newest_source_stmt)
        return matching_ids.first()
|
||||
|
||||
@db_operation("get_workflow_script_by_cache_key_value")
|
||||
async def get_workflow_script_by_cache_key_value(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -246,6 +246,60 @@ async def generate_or_update_pending_workflow_script(
|
|||
)
|
||||
|
||||
|
||||
async def _invalidate_if_parameters_changed(
    workflow: Workflow,
    existing_script: Script,
    cache_key_value: str,
    workflow_run_id: str,
) -> bool:
    """Decide whether a cached script is stale because the parameter keys changed.

    The workflow version that produced the cached script is looked up first.
    If no source version is recorded, or it is the version being run now,
    nothing further is fetched and the script is kept. Otherwise the earlier
    version is loaded and its set of parameter keys is compared with the
    current workflow's.

    Returns:
        ``True`` when the earlier workflow version cannot be loaded, or when
        its parameter keys differ from the current workflow's; ``False`` in
        every other case. This function only reports the decision; it does
        not modify or delete the cached script itself.
    """
    source_workflow_id = await app.DATABASE.scripts.get_workflow_script_source_workflow_id(
        organization_id=workflow.organization_id,
        workflow_permanent_id=workflow.workflow_permanent_id,
        script_id=existing_script.script_id,
        cache_key_value=cache_key_value,
    )

    # Nothing to diff against: no source version is recorded for the row.
    if not source_workflow_id:
        return False
    # The script was generated from the very version being run.
    if source_workflow_id == workflow.workflow_id:
        return False

    previous_workflow = await app.DATABASE.workflows.get_workflow(
        workflow_id=source_workflow_id,
        organization_id=workflow.organization_id,
    )

    # Fields attached to both invalidation log lines.
    log_context = {
        "workflow_id": workflow.workflow_id,
        "cache_workflow_id": source_workflow_id,
        "script_id": existing_script.script_id,
        "workflow_run_id": workflow_run_id,
    }

    # Without the earlier version the key sets cannot be compared, so the
    # cached script is reported as stale.
    if previous_workflow is None:
        LOG.info("Cached script invalidated: prior workflow version not found", **log_context)
        return True

    previous_keys = {param.key for param in previous_workflow.workflow_definition.parameters}
    current_keys = {param.key for param in workflow.workflow_definition.parameters}
    if previous_keys == current_keys:
        return False

    LOG.info(
        "Cached script invalidated: workflow parameter set changed",
        **log_context,
        added_params=sorted(current_keys - previous_keys),
        removed_params=sorted(previous_keys - current_keys),
    )
    return True
|
||||
|
||||
|
||||
async def get_workflow_script(
|
||||
workflow: Workflow,
|
||||
workflow_run: WorkflowRun,
|
||||
|
|
@ -313,6 +367,22 @@ async def get_workflow_script(
|
|||
)
|
||||
|
||||
if existing_script:
|
||||
# SKY-9254: invalidate the cached script when the workflow's parameter
|
||||
# set has changed since it was generated. Cache lookup keys on
|
||||
# (org, wpid, cache_key_value) — none of which change when a user
|
||||
# edits the workflow to add/remove a parameter. Without this check
|
||||
# the old cached code (which has no reference to the new param)
|
||||
# keeps getting served, and the new param ends up injected wherever
|
||||
# the agent guesses.
|
||||
invalidated = await _invalidate_if_parameters_changed(
|
||||
workflow=workflow,
|
||||
existing_script=existing_script,
|
||||
cache_key_value=rendered_cache_key_value,
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
)
|
||||
if invalidated:
|
||||
return None, rendered_cache_key_value, False
|
||||
|
||||
LOG.info(
|
||||
"Found cached script for workflow (cache hit)",
|
||||
workflow_id=workflow.workflow_id,
|
||||
|
|
|
|||
194
tests/unit/workflow/test_cache_invalidation_on_param_change.py
Normal file
194
tests/unit/workflow/test_cache_invalidation_on_param_change.py
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
"""Tests for cache invalidation when a workflow's parameter set changes (SKY-9254)."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.sdk.workflow.models.parameter import (
|
||||
WorkflowParameter,
|
||||
WorkflowParameterType,
|
||||
)
|
||||
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowDefinition
|
||||
from skyvern.schemas.scripts import Script
|
||||
from skyvern.services.workflow_script_service import _invalidate_if_parameters_changed
|
||||
|
||||
|
||||
def _workflow_param(key: str, workflow_id: str = "w_current") -> WorkflowParameter:
    """Build a minimal string-typed workflow parameter fixture.

    Args:
        key: Parameter key; also used to derive the parameter id (``wp_<key>``).
        workflow_id: Workflow version the parameter is attached to. Defaults to
            ``"w_current"`` so existing single-argument callers are unaffected.
    """
    now = datetime.now(timezone.utc)
    return WorkflowParameter(
        key=key,
        description="",
        workflow_parameter_id=f"wp_{key}",
        workflow_parameter_type=WorkflowParameterType.STRING,
        workflow_id=workflow_id,
        created_at=now,
        modified_at=now,
    )


def _workflow(workflow_id: str, param_keys: list[str]) -> Workflow:
    """Build a block-less workflow fixture with one string parameter per key.

    Args:
        workflow_id: Version id (``w_*``) assigned to the workflow and to each
            of its parameters, so the fixture is internally consistent.
        param_keys: Keys of the parameters placed in the workflow definition.
    """
    now = datetime.now(timezone.utc)
    return Workflow(
        workflow_id=workflow_id,
        organization_id="o_test",
        title="Test",
        workflow_permanent_id="wpid_test",
        version=1,
        is_saved_task=False,
        workflow_definition=WorkflowDefinition(
            blocks=[],
            # Parameters carry the same version id as the workflow that owns them.
            parameters=[_workflow_param(k, workflow_id) for k in param_keys],
        ),
        created_at=now,
        modified_at=now,
    )
|
||||
|
||||
|
||||
def _script() -> Script:
    """Return the cached-script fixture (``s_test``) shared by these tests."""
    timestamp = datetime.now(timezone.utc)
    script_fields = {
        "script_id": "s_test",
        "script_revision_id": "sr_test",
        "organization_id": "o_test",
        "run_id": "wr_old",
        "version": 1,
        "created_at": timestamp,
        "modified_at": timestamp,
    }
    return Script(**script_fields)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_does_not_invalidate_when_workflow_version_matches() -> None:
    """A script generated from the running workflow version is kept, and the
    prior workflow is never fetched."""
    db = MagicMock()
    db.scripts.get_workflow_script_source_workflow_id = AsyncMock(return_value="w_current")
    db.workflows.get_workflow = AsyncMock()
    current_workflow = _workflow("w_current", ["name"])
    cached_script = _script()

    with patch("skyvern.services.workflow_script_service.app") as patched_app:
        patched_app.DATABASE = db
        invalidated = await _invalidate_if_parameters_changed(
            workflow=current_workflow,
            existing_script=cached_script,
            cache_key_value="default:v2",
            workflow_run_id="wr_new",
        )

    assert invalidated is False
    db.workflows.get_workflow.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_does_not_invalidate_when_source_workflow_id_missing() -> None:
    """When the source lookup yields no workflow id there is nothing to diff,
    so the cached script keeps being served."""
    db = MagicMock()
    db.scripts.get_workflow_script_source_workflow_id = AsyncMock(return_value=None)
    db.workflows.get_workflow = AsyncMock()
    call_kwargs = {
        "workflow": _workflow("w_current", ["name"]),
        "existing_script": _script(),
        "cache_key_value": "default:v2",
        "workflow_run_id": "wr_new",
    }

    with patch("skyvern.services.workflow_script_service.app") as patched_app:
        patched_app.DATABASE = db
        invalidated = await _invalidate_if_parameters_changed(**call_kwargs)

    assert invalidated is False
    db.workflows.get_workflow.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_invalidates_when_prior_workflow_hard_deleted() -> None:
    """If the prior workflow row cannot be loaded, the parameter sets cannot be
    compared, so the cached script is reported as invalidated.

    Also pins that the lookup targets the version recorded on the cached
    script, within the current workflow's organization.
    """
    mock_db = MagicMock()
    mock_db.scripts.get_workflow_script_source_workflow_id = AsyncMock(return_value="w_old")
    mock_db.workflows.get_workflow = AsyncMock(return_value=None)

    with patch("skyvern.services.workflow_script_service.app") as mock_app:
        mock_app.DATABASE = mock_db
        result = await _invalidate_if_parameters_changed(
            workflow=_workflow("w_current", ["name"]),
            existing_script=_script(),
            cache_key_value="default:v2",
            workflow_run_id="wr_new",
        )

    assert result is True
    # Guard against the function loading the wrong workflow version or org.
    mock_db.workflows.get_workflow.assert_awaited_once_with(
        workflow_id="w_old",
        organization_id="o_test",
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_does_not_invalidate_when_param_keys_identical() -> None:
    """A newer workflow version whose parameter keys match the cached
    version's keeps the cached script."""
    shared_keys = ["name", "email"]
    db = MagicMock()
    db.scripts.get_workflow_script_source_workflow_id = AsyncMock(return_value="w_old")
    db.workflows.get_workflow = AsyncMock(return_value=_workflow("w_old", list(shared_keys)))
    current_workflow = _workflow("w_current", list(shared_keys))

    with patch("skyvern.services.workflow_script_service.app") as patched_app:
        patched_app.DATABASE = db
        invalidated = await _invalidate_if_parameters_changed(
            workflow=current_workflow,
            existing_script=_script(),
            cache_key_value="default:v2",
            workflow_run_id="wr_new",
        )

    assert invalidated is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_invalidates_when_parameter_added() -> None:
    """The SKY-9254 scenario: a ``phone`` parameter added after the script was
    cached must invalidate it."""
    previous_workflow = _workflow("w_old", ["name"])
    current_workflow = _workflow("w_current", ["name", "phone"])
    db = MagicMock()
    db.scripts.get_workflow_script_source_workflow_id = AsyncMock(return_value="w_old")
    db.workflows.get_workflow = AsyncMock(return_value=previous_workflow)

    with patch("skyvern.services.workflow_script_service.app") as patched_app:
        patched_app.DATABASE = db
        invalidated = await _invalidate_if_parameters_changed(
            workflow=current_workflow,
            existing_script=_script(),
            cache_key_value="default:v2",
            workflow_run_id="wr_new",
        )

    assert invalidated is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_invalidates_when_parameter_removed() -> None:
    """Dropping a parameter that the cached version had invalidates the script."""
    previous_workflow = _workflow("w_old", ["name", "phone"])
    current_workflow = _workflow("w_current", ["name"])
    db = MagicMock()
    db.scripts.get_workflow_script_source_workflow_id = AsyncMock(return_value="w_old")
    db.workflows.get_workflow = AsyncMock(return_value=previous_workflow)

    with patch("skyvern.services.workflow_script_service.app") as patched_app:
        patched_app.DATABASE = db
        invalidated = await _invalidate_if_parameters_changed(
            workflow=current_workflow,
            existing_script=_script(),
            cache_key_value="default:v2",
            workflow_run_id="wr_new",
        )

    assert invalidated is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_invalidates_when_parameter_renamed() -> None:
    """A rename shows up as one key removed and another added, so the key sets
    differ and the script is invalidated."""
    previous_workflow = _workflow("w_old", ["phone_number"])
    current_workflow = _workflow("w_current", ["phone"])
    db = MagicMock()
    db.scripts.get_workflow_script_source_workflow_id = AsyncMock(return_value="w_old")
    db.workflows.get_workflow = AsyncMock(return_value=previous_workflow)

    with patch("skyvern.services.workflow_script_service.app") as patched_app:
        patched_app.DATABASE = db
        invalidated = await _invalidate_if_parameters_changed(
            workflow=current_workflow,
            existing_script=_script(),
            cache_key_value="default:v2",
            workflow_run_id="wr_new",
        )

    assert invalidated is True
|
||||
Loading…
Add table
Add a link
Reference in a new issue