mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2026-04-28 11:40:32 +00:00
feat: parallel loop execution — backend engine, sessions, concurrency (SKY-8175, SKY-8176, SKY-8177, SKY-8180) (#5412)
This commit is contained in:
parent
acfae2118c
commit
ae9741a381
12 changed files with 2040 additions and 1 deletions
122
skyvern/cli/mcp_tools/README.md
Normal file
122
skyvern/cli/mcp_tools/README.md
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
# Skyvern MCP Server
|
||||
|
||||
The Skyvern MCP server gives AI assistants (Claude, Cursor, Windsurf, Codex) full browser control -- clicking, filling forms, extracting data, navigating pages, uploading files, managing workflows, and more. 75+ tools, one server.
|
||||
|
||||
## Quickstart
|
||||
|
||||
```bash
|
||||
pip install skyvern
|
||||
skyvern setup claude-code
|
||||
|
||||
# or if you're using other coding agents
|
||||
skyvern setup
|
||||
```
|
||||
|
||||
## Setup
|
||||
|
||||
### Cloud (recommended)
|
||||
|
||||
Get an API key from [app.skyvern.com](https://app.skyvern.com), then configure your client:
|
||||
|
||||
**Claude Code:**
|
||||
```bash
|
||||
claude mcp add-json skyvern '{"type":"http","url":"https://api.skyvern.com/mcp/","headers":{"x-api-key":"YOUR_API_KEY"}}' --scope user
|
||||
```
|
||||
|
||||
**Cursor** (`~/.cursor/mcp.json`):
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"skyvern": {
|
||||
"type": "streamable-http",
|
||||
"url": "https://api.skyvern.com/mcp/",
|
||||
"headers": { "x-api-key": "YOUR_API_KEY" }
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Windsurf** (`~/.codeium/windsurf/mcp_config.json`):
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"skyvern": {
|
||||
"type": "streamable-http",
|
||||
"url": "https://api.skyvern.com/mcp/",
|
||||
"headers": { "x-api-key": "YOUR_API_KEY" }
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Local (self-hosted)
|
||||
|
||||
```bash
|
||||
skyvern init # interactive setup wizard
|
||||
skyvern run server # start the local API server
|
||||
```
|
||||
|
||||
Manual config for any MCP client:
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"skyvern": {
|
||||
"command": "python3",
|
||||
"args": ["-m", "skyvern", "run", "mcp"],
|
||||
"env": {
|
||||
"SKYVERN_BASE_URL": "http://localhost:8000",
|
||||
"SKYVERN_API_KEY": "YOUR_API_KEY"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Tools
|
||||
|
||||
### Browser Sessions
|
||||
`skyvern_browser_session_create`, `skyvern_browser_session_close`, `skyvern_browser_session_list`, `skyvern_browser_session_get`, `skyvern_browser_session_connect`
|
||||
|
||||
### Browser Actions
|
||||
`skyvern_act` (natural language), `skyvern_navigate`, `skyvern_click`, `skyvern_type`, `skyvern_hover`, `skyvern_scroll`, `skyvern_select_option`, `skyvern_press_key`, `skyvern_drag`, `skyvern_file_upload`, `skyvern_wait`
|
||||
|
||||
### Data Extraction & Validation
|
||||
`skyvern_extract` (structured JSON output), `skyvern_screenshot`, `skyvern_find`, `skyvern_validate`, `skyvern_evaluate` (run JavaScript), `skyvern_get_html`, `skyvern_get_value`, `skyvern_get_styles`
|
||||
|
||||
### Authentication & Credentials
|
||||
`skyvern_login`, `skyvern_credential_list`, `skyvern_credential_get`, `skyvern_credential_delete`
|
||||
|
||||
Supports Skyvern vault, Bitwarden, 1Password, and Azure Key Vault with automatic 2FA/TOTP.
|
||||
|
||||
### Tabs & Frames
|
||||
`skyvern_tab_new`, `skyvern_tab_list`, `skyvern_tab_switch`, `skyvern_tab_close`, `skyvern_tab_wait_for_new`, `skyvern_frame_list`, `skyvern_frame_switch`, `skyvern_frame_main`
|
||||
|
||||
### Network & Console Inspection
|
||||
`skyvern_console_messages`, `skyvern_network_requests`, `skyvern_network_request_detail`, `skyvern_network_route`, `skyvern_network_unroute`, `skyvern_get_errors`, `skyvern_har_start`, `skyvern_har_stop`, `skyvern_handle_dialog`
|
||||
|
||||
### Browser State & Storage
|
||||
`skyvern_state_save`, `skyvern_state_load`, `skyvern_get_session_storage`, `skyvern_set_session_storage`, `skyvern_clear_session_storage`, `skyvern_clear_local_storage`, `skyvern_clipboard_read`, `skyvern_clipboard_write`
|
||||
|
||||
### Workflows
|
||||
`skyvern_workflow_create`, `skyvern_workflow_list`, `skyvern_workflow_get`, `skyvern_workflow_run`, `skyvern_workflow_status`, `skyvern_workflow_update`, `skyvern_workflow_delete`, `skyvern_workflow_cancel`, `skyvern_workflow_update_folder`
|
||||
|
||||
### Workflow Building Blocks
|
||||
`skyvern_block_schema`, `skyvern_block_validate` -- 23 block types for multi-step automations.
|
||||
|
||||
### Cached Scripts
|
||||
`skyvern_script_list_for_workflow`, `skyvern_script_get_code`, `skyvern_script_versions`, `skyvern_script_deploy`, `skyvern_script_fallback_episodes`
|
||||
|
||||
### Organization
|
||||
`skyvern_folder_create`, `skyvern_folder_list`, `skyvern_folder_get`, `skyvern_folder_update`, `skyvern_folder_delete`
|
||||
|
||||
## Switching Configs
|
||||
|
||||
Use the CLI to switch between API keys or environments without manual editing:
|
||||
|
||||
```bash
|
||||
skyvern mcp switch
|
||||
```
|
||||
|
||||
## Full Documentation
|
||||
|
||||
[skyvern.com/docs/integrations/mcp](https://www.skyvern.com/docs/integrations/mcp)
|
||||
|
|
@ -1,6 +1,20 @@
|
|||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
|
||||
# Separator embedded in browser cache keys that belong to parallel loop iterations.
LOOP_ITERATION_SEPARATOR = "__iter_"


def loop_iteration_key(workflow_run_id: str, loop_idx: int) -> str:
    """Return the browser-state cache key for one parallel loop iteration.

    The key joins the workflow run id and the iteration index with
    LOOP_ITERATION_SEPARATOR, e.g. ``wr_123__iter_0``.
    """
    return "".join((workflow_run_id, LOOP_ITERATION_SEPARATOR, str(loop_idx)))


def is_loop_iteration_key(key: str) -> bool:
    """Tell whether *key* names a parallel-loop-iteration browser session."""
    return key.find(LOOP_ITERATION_SEPARATOR) != -1
|
||||
|
||||
|
||||
# This is the attribute name used to tag interactable elements
SKYVERN_ID_ATTR: str = "unique_id"
# Filesystem directory that contains this module.
SKYVERN_DIR = Path(__file__).parent
|
||||
|
|
|
|||
|
|
@ -651,6 +651,19 @@ class AgentFunction:
|
|||
|
||||
return cleanup_element_tree_func
|
||||
|
||||
async def check_parallel_loop_quota(self, organization_id: str, requested_concurrency: int) -> int:
    """Return how many parallel loop iterations this org may run right now.

    The open-source base implementation performs no enforcement and grants
    everything that was requested. The cloud subclass overrides this to
    apply per-organization caps via Redis.
    """
    granted = requested_concurrency
    return granted
|
||||
|
||||
async def release_parallel_loop_quota(self, organization_id: str, count: int) -> None:
    """Give back parallel-loop iteration slots.

    No-op in the OSS base; the cloud override decrements the org's
    slot count.
    """
    return None
|
||||
|
||||
async def validate_code_block(self, organization_id: str | None = None) -> None:
    """Raise DisabledBlockExecutionError unless code blocks are enabled in settings."""
    if settings.ENABLE_CODE_BLOCK:
        return
    raise DisabledBlockExecutionError("CodeBlock is disabled")
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import builtins
|
||||
import dataclasses
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
|
@ -93,6 +95,34 @@ class SkyvernContext:
|
|||
# preventing repeated injection loops when the captcha solver succeeds but the page doesn't change
|
||||
proactive_captcha_task_ids: set[str] = field(default_factory=set)
|
||||
|
||||
def create_iteration_copy(self, browser_session_id: str) -> "SkyvernContext":
    """Fork this context for one parallel loop iteration.

    Scalar fields are carried over by value, mutable containers
    (dict/list/set) are rebuilt as fresh instances, and fields known to
    hold live Playwright objects (frame_index_map, magic_link_pages) are
    reset to empty dicts because those objects cannot be copied safely.
    The copy's browser_session_id is replaced with the iteration key.
    """
    # Dict-valued fields that wrap Playwright Frame/Page handles; these
    # must start empty in the fork rather than be copied.
    uncopyable_dict_fields = ("frame_index_map", "magic_link_pages")
    field_values: dict[str, Any] = {}
    for spec in dataclasses.fields(self):
        value = getattr(self, spec.name)
        if isinstance(value, dict):
            field_values[spec.name] = {} if spec.name in uncopyable_dict_fields else dict(value)
        elif isinstance(value, list):
            field_values[spec.name] = list(value)
        elif isinstance(value, builtins.set):
            # builtins.set is required: a module-level function named `set`
            # shadows the built-in inside this module.
            field_values[spec.name] = builtins.set(value)
        else:
            field_values[spec.name] = value
    field_values["browser_session_id"] = browser_session_id
    return SkyvernContext(**field_values)
|
||||
|
||||
def __repr__(self) -> str:
    """Compact one-line summary of the identifying run fields."""
    return (
        f"SkyvernContext(request_id={self.request_id}, "
        f"organization_id={self.organization_id}, "
        f"task_id={self.task_id}, "
        f"step_id={self.step_id}, "
        f"workflow_id={self.workflow_id}, "
        f"workflow_run_id={self.workflow_run_id}, "
        f"task_v2_id={self.task_v2_id}, "
        f"max_steps_override={self.max_steps_override}, "
        f"run_id={self.run_id})"
    )
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import copy
|
||||
from contextvars import ContextVar
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Self
|
||||
|
||||
import structlog
|
||||
|
|
@ -225,6 +226,50 @@ class WorkflowRunContext:
|
|||
label = ""
|
||||
return self.blocks_metadata.get(label, BlockMetadata())
|
||||
|
||||
def create_iteration_snapshot(self, loop_idx: int) -> "WorkflowRunContext":
    """Produce an isolated WorkflowRunContext for one parallel loop iteration.

    Reference fields that are effectively immutable during a loop
    (aws_client, workflow, run identifiers) are shared. `parameters` and
    `secrets` get shallow dict copies as a safety net against future
    in-loop mutation leaking across iterations. `values`,
    `blocks_metadata` and `workflow_run_outputs` are deep-copied because
    they are actively mutated while blocks run.

    NOTE(review): `loop_idx` is currently unused in the body — presumably
    kept for call-site symmetry / future per-iteration tagging; confirm.
    """
    snapshot = WorkflowRunContext(
        workflow_title=self.workflow_title,
        workflow_id=self.workflow_id,
        workflow_permanent_id=self.workflow_permanent_id,
        workflow_run_id=self.workflow_run_id,
        aws_client=self._aws_client,
        workflow=self.workflow,
    )
    # Scalars / shared references copied straight across.
    for attr in (
        "organization_id",
        "browser_session_id",
        "include_secrets_in_templates",
        "credential_totp_identifiers",
    ):
        setattr(snapshot, attr, getattr(self, attr))
    # Shallow copies: read-only during iteration today, defensive for tomorrow.
    snapshot.parameters = dict(self.parameters)
    snapshot.secrets = dict(self.secrets)
    # Deep copies: mutated concurrently by iteration blocks.
    snapshot.values = copy.deepcopy(self.values)
    snapshot.blocks_metadata = copy.deepcopy(self.blocks_metadata)
    snapshot.workflow_run_outputs = copy.deepcopy(self.workflow_run_outputs)
    return snapshot
|
||||
|
||||
def merge_iteration_results(self, snapshots: list[tuple[int, "WorkflowRunContext"]]) -> None:
    """Fold parallel iteration snapshots back into this shared context.

    Snapshots are merged in ascending iteration order, so when keys
    collide the highest-indexed iteration wins — the same outcome a
    sequential loop would have produced.
    """
    ordered = sorted(snapshots, key=lambda pair: pair[0])
    for _idx, snap in ordered:
        # Iteration outputs are keyed by block label.
        self.values.update(snap.values)
        # Per-block metadata.
        self.blocks_metadata.update(snap.blocks_metadata)
        # Workflow-level outputs.
        self.workflow_run_outputs.update(snap.workflow_run_outputs)
|
||||
|
||||
async def _should_include_secrets_in_templates(self) -> bool:
|
||||
"""
|
||||
Check if secrets should be included in template formatting based on experimentation provider.
|
||||
|
|
@ -1225,6 +1270,22 @@ class WorkflowRunContext:
|
|||
self.values[parameter.key][key] = secret_id
|
||||
|
||||
|
||||
# Per-asyncio-task override for the active WorkflowRunContext. Parallel loop
# iterations install their snapshot here; get_workflow_run_context() then
# returns that snapshot instead of the shared entry in workflow_run_contexts.
# asyncio.create_task copies ContextVar state per task, so each iteration
# observes only its own snapshot, and child blocks resolving context through
# the classmethod get an isolated view with no call-signature changes.
_iteration_workflow_run_context: ContextVar["WorkflowRunContext | None"] = ContextVar(
    "_iteration_workflow_run_context", default=None
)


def set_iteration_workflow_run_context(context: "WorkflowRunContext | None") -> None:
    """Install (or clear, by passing None) the current task's per-iteration context override."""
    _iteration_workflow_run_context.set(context)
|
||||
|
||||
|
||||
class WorkflowContextManager:
|
||||
aws_client: AsyncAWSClient
|
||||
workflow_run_contexts: dict[str, WorkflowRunContext]
|
||||
|
|
@ -1282,6 +1343,14 @@ class WorkflowContextManager:
|
|||
return workflow_run_context
|
||||
|
||||
def get_workflow_run_context(self, workflow_run_id: str) -> WorkflowRunContext:
    """Resolve the WorkflowRunContext for a run, honoring per-task overrides.

    Parallel loop iterations stash a snapshot in the
    _iteration_workflow_run_context ContextVar; when one is present for
    this same run id, it takes precedence over the globally shared
    mapping so child blocks see their isolated view.
    """
    snapshot = _iteration_workflow_run_context.get()
    if snapshot is not None:
        if snapshot.workflow_run_id == workflow_run_id:
            return snapshot
    self._validate_workflow_run_context(workflow_run_id)
    return self.workflow_run_contexts[workflow_run_id]
|
||||
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ from skyvern.constants import (
|
|||
GET_DOWNLOADED_FILES_TIMEOUT,
|
||||
MAX_FILE_PARSE_INPUT_TOKENS,
|
||||
MAX_UPLOAD_FILE_COUNT,
|
||||
loop_iteration_key,
|
||||
)
|
||||
from skyvern.exceptions import (
|
||||
AzureConfigurationError,
|
||||
|
|
@ -81,7 +82,11 @@ from skyvern.forge.sdk.settings_manager import SettingsManager
|
|||
from skyvern.forge.sdk.trace import traced
|
||||
from skyvern.forge.sdk.utils.pdf_parser import extract_pdf_file, render_pdf_pages_as_images, validate_pdf_file
|
||||
from skyvern.forge.sdk.utils.sanitization import sanitize_postgres_text
|
||||
from skyvern.forge.sdk.workflow.context_manager import BlockMetadata, WorkflowRunContext
|
||||
from skyvern.forge.sdk.workflow.context_manager import (
|
||||
BlockMetadata,
|
||||
WorkflowRunContext,
|
||||
set_iteration_workflow_run_context,
|
||||
)
|
||||
from skyvern.forge.sdk.workflow.exceptions import (
|
||||
CustomizedCodeException,
|
||||
FailedToFormatJinjaStyleParameter,
|
||||
|
|
@ -954,6 +959,8 @@ class BaseTaskBlock(Block):
|
|||
)
|
||||
else:
|
||||
# if not the first task block, need to navigate manually
|
||||
# get_for_workflow_run checks SkyvernContext for iteration-specific
|
||||
# browser keys (__iter_N) first, falling back to bare workflow_run_id.
|
||||
browser_state = app.BROWSER_MANAGER.get_for_workflow_run(workflow_run_id=workflow_run_id)
|
||||
if browser_state is None:
|
||||
raise MissingBrowserState(task_id=task.task_id, workflow_run_id=workflow_run_id)
|
||||
|
|
@ -1333,6 +1340,14 @@ class ForLoopBlock(Block):
|
|||
# Schema for each individual loop item. Note: intentionally excludes `list`
# (unlike BaseTaskBlock.data_schema) because a list schema
# does not describe the shape of individual loop items -- only dict schemas are meaningful here.
data_schema: dict[str, Any] | str | None = None
# Parallel execution: None or 1 = sequential (existing behavior), 2-20 = parallel batch size.
# Out-of-range values are clamped into [1, 20] by a model validator.
max_concurrency: int | None = None
|
||||
|
||||
@model_validator(mode="after")
def _clamp_max_concurrency(self) -> ForLoopBlock:
    """Clamp a user-supplied max_concurrency into the supported [1, 20] range."""
    if self.max_concurrency is None:
        return self
    self.max_concurrency = min(20, max(1, self.max_concurrency))
    return self
|
||||
|
||||
def get_all_parameters(
|
||||
self,
|
||||
|
|
@ -1824,6 +1839,53 @@ class ForLoopBlock(Block):
|
|||
organization_id: str | None = None,
|
||||
browser_session_id: str | None = None,
|
||||
) -> LoopBlockExecutedResult:
|
||||
# Dispatch to parallel path when max_concurrency > 1
|
||||
effective_concurrency = self.max_concurrency or 1
|
||||
if effective_concurrency > 1 and browser_session_id:
|
||||
# Persistent browser sessions cannot be safely shared across
|
||||
# parallel iterations — every iteration would race on the same
|
||||
# live page. Fall back to sequential execution rather than
|
||||
# corrupt the session.
|
||||
LOG.info(
|
||||
"Persistent browser session set; forcing sequential execution for for-loop",
|
||||
workflow_run_id=workflow_run_id,
|
||||
browser_session_id=browser_session_id,
|
||||
requested_concurrency=self.max_concurrency,
|
||||
)
|
||||
effective_concurrency = 1
|
||||
if effective_concurrency > 1:
|
||||
# Check per-org quota — may reduce concurrency or force sequential
|
||||
if organization_id:
|
||||
effective_concurrency = await app.AGENT_FUNCTION.check_parallel_loop_quota(
|
||||
organization_id, effective_concurrency
|
||||
)
|
||||
if effective_concurrency <= 1:
|
||||
# Release whatever was granted before falling back to sequential.
|
||||
# A return value of 0 means no slots were acquired, so nothing
|
||||
# to release. A return value of 1 means one slot was incremented
|
||||
# in Redis and must be decremented to avoid a leak.
|
||||
if effective_concurrency > 0:
|
||||
await app.AGENT_FUNCTION.release_parallel_loop_quota(organization_id, effective_concurrency)
|
||||
LOG.info(
|
||||
"Org concurrency quota exhausted, falling back to sequential execution",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
requested_concurrency=self.max_concurrency,
|
||||
)
|
||||
# Fall through to sequential path below
|
||||
|
||||
if effective_concurrency > 1:
|
||||
return await self._execute_loop_parallel(
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
workflow_run_context=workflow_run_context,
|
||||
loop_over_values=loop_over_values,
|
||||
organization_id=organization_id,
|
||||
browser_session_id=browser_session_id,
|
||||
granted_concurrency=effective_concurrency,
|
||||
)
|
||||
|
||||
# --- Sequential path (existing behavior, unchanged) ---
|
||||
outputs_with_loop_values: list[list[dict[str, Any]]] = []
|
||||
block_outputs: list[BlockResult] = []
|
||||
current_block: BlockTypeVar | None = None
|
||||
|
|
@ -2131,6 +2193,499 @@ class ForLoopBlock(Block):
|
|||
last_block=current_block,
|
||||
)
|
||||
|
||||
async def _execute_single_iteration_parallel(
    self,
    loop_idx: int,
    loop_over_value: Any,
    workflow_run_id: str,
    workflow_run_block_id: str,
    workflow_run_context: WorkflowRunContext,
    start_label: str,
    label_to_block: dict[str, BlockTypeVar],
    default_next_map: dict[str, str | None],
    conditional_scopes: dict[str, str],
    organization_id: str | None = None,
    browser_session_id: str | None = None,
) -> tuple[int, list[dict[str, Any]], list[BlockResult], BlockTypeVar | None, WorkflowRunContext]:
    """Execute a single loop iteration with its own isolated context.

    Walks the loop's block graph from `start_label`, following
    `default_next_map` (or a conditional block's reported
    `next_block_label`) until there is no next block, a failure/cancel
    stops the iteration, or the step budget is exhausted. All reads and
    writes of loop values go through the per-iteration
    `workflow_run_context` snapshot passed in by the caller.

    Returns (loop_idx, iteration_outputs, block_results, last_block, context_snapshot).
    """
    each_loop_output_values: list[dict[str, Any]] = []
    block_outputs: list[BlockResult] = []
    current_block: BlockTypeVar | None = None

    # Create an isolated browser context for this iteration. It is stored
    # under the iteration-specific key (see the comment after the call),
    # NOT under the bare workflow_run_id, so concurrent iterations cannot
    # race on the same browser entry.
    try:
        await app.BROWSER_MANAGER.get_or_create_for_loop_iteration(
            workflow_run_id=workflow_run_id,
            loop_idx=loop_idx,
            organization_id=organization_id,
            browser_session_id=browser_session_id,
        )
        # The iteration browser is stored under the iteration key
        # (e.g. "wr_xxx__iter_0") and resolved via SkyvernContext in
        # get_or_create_for_workflow_run(). We intentionally do NOT
        # alias it under the bare workflow_run_id to avoid races
        # between concurrent iterations.
    except Exception:
        # Fail this iteration immediately. Falling back to the shared
        # workflow browser would cause concurrent iterations to race on
        # the same page, corrupting navigation/actions and producing
        # non-deterministic data. Sequential loops can safely share a
        # browser; parallel loops cannot.
        LOG.error(
            "Failed to create isolated browser for parallel iteration",
            workflow_run_id=workflow_run_id,
            loop_idx=loop_idx,
            exc_info=True,
        )
        failure_block_result = await self.build_block_result(
            success=False,
            status=BlockStatus.failed,
            failure_reason=f"Failed to create isolated browser for parallel iteration {loop_idx}",
            workflow_run_block_id=None,
            organization_id=organization_id,
        )
        return loop_idx, [], [failure_block_result], None, workflow_run_context

    # Capture baseline downloaded files for per-iteration scoping (SKY-7005)
    loop_context = skyvern_context.current()
    if loop_context:
        downloaded_file_sigs_before: list[tuple[str | None, str | None, str | None]] = []
        baseline_timed_out = False
        try:
            async with asyncio.timeout(GET_DOWNLOADED_FILES_TIMEOUT):
                downloaded_file_sigs_before = [
                    to_downloaded_file_signature(fi)
                    for fi in await app.STORAGE.get_downloaded_files(
                        organization_id=organization_id or "",
                        run_id=loop_context.run_id if loop_context.run_id else workflow_run_id,
                    )
                ]
        except asyncio.TimeoutError:
            baseline_timed_out = True
            LOG.warning(
                "Timeout getting baseline downloaded files for parallel loop iteration",
                workflow_run_id=workflow_run_id,
                loop_idx=loop_idx,
            )
        # On timeout the baseline is unknown — record None rather than an
        # empty (and therefore wrong) baseline list.
        if baseline_timed_out:
            loop_context.loop_internal_state = None
        else:
            loop_context.loop_internal_state = {
                "downloaded_file_signatures_before_iteration": downloaded_file_sigs_before,
            }

    # Expose this iteration's loop value through context parameters on the
    # per-iteration context snapshot.
    context_parameters_with_value = self.get_loop_block_context_parameters(workflow_run_id, loop_over_value)
    for context_parameter in context_parameters_with_value:
        workflow_run_context.set_value(context_parameter.key, context_parameter.value)

    iteration_step_count = 0
    LOG.info(
        "ForLoopBlock parallel: starting iteration",
        workflow_run_id=workflow_run_id,
        loop_idx=loop_idx,
        max_steps_per_iteration=DEFAULT_MAX_STEPS_PER_ITERATION,
    )

    # Graph walk over the loop's inner blocks.
    block_idx = 0
    current_label: str | None = start_label
    # Maps a conditional block's label -> its workflow_run_block_id, so
    # blocks inside that conditional's scope get parented to it.
    conditional_wrb_ids: dict[str, str] = {}
    while current_label:
        loop_block = label_to_block.get(current_label)
        if not loop_block:
            LOG.error(
                "Unable to find loop block with label in parallel loop graph",
                workflow_run_id=workflow_run_id,
                loop_label=self.label,
                current_label=current_label,
            )
            failure_block_result = await self.build_block_result(
                success=False,
                status=BlockStatus.failed,
                failure_reason=f"Unable to find block with label {current_label} inside loop {self.label}",
                workflow_run_block_id=workflow_run_block_id,
                organization_id=organization_id,
            )
            block_outputs.append(failure_block_result)
            return (loop_idx, each_loop_output_values, block_outputs, current_block, workflow_run_context)

        # Record the loop cursor (index/value) for both the loop itself and
        # the inner block about to run.
        metadata: BlockMetadata = {
            "current_index": loop_idx,
            "current_value": loop_over_value,
            "current_item": loop_over_value,
        }
        workflow_run_context.update_block_metadata(self.label, metadata)
        workflow_run_context.update_block_metadata(loop_block.label, metadata)

        # Execute a deep copy of the block template so any mutation during
        # execution stays local to this iteration; the original reference
        # is restored after execute_safe returns.
        original_loop_block = loop_block
        loop_block = loop_block.model_copy(deep=True)
        current_block = loop_block

        # Blocks inside a conditional's scope are parented to that
        # conditional's workflow_run_block, when it has produced one.
        parent_wrb_id = workflow_run_block_id
        if current_label in conditional_scopes:
            cond_label = conditional_scopes[current_label]
            if cond_label in conditional_wrb_ids:
                parent_wrb_id = conditional_wrb_ids[cond_label]

        block_output = await loop_block.execute_safe(
            workflow_run_id=workflow_run_id,
            parent_workflow_run_block_id=parent_wrb_id,
            organization_id=organization_id,
            browser_session_id=browser_session_id,
            current_value=str(loop_over_value),
            current_index=loop_idx,
        )

        if loop_block.block_type == BlockType.CONDITIONAL and block_output.workflow_run_block_id:
            conditional_wrb_ids[current_label] = block_output.workflow_run_block_id

        output_value = (
            workflow_run_context.get_value(block_output.output_parameter.key)
            if workflow_run_context.has_value(block_output.output_parameter.key)
            else None
        )

        if block_output.output_parameter.key.endswith("_output"):
            LOG.debug("Block output (parallel)", block_type=loop_block.block_type, output_value=output_value)

        each_loop_output_values.append(
            {
                "loop_value": loop_over_value,
                "output_parameter": block_output.output_parameter,
                "output_value": output_value,
            }
        )
        # Best-effort bookkeeping: failure to persist the cursor must not
        # fail the iteration itself.
        try:
            if block_output.workflow_run_block_id:
                await app.DATABASE.observer.update_workflow_run_block(
                    workflow_run_block_id=block_output.workflow_run_block_id,
                    organization_id=organization_id,
                    current_value=str(loop_over_value),
                    current_index=loop_idx,
                )
        except Exception:
            LOG.warning(
                "Failed to update workflow run block (parallel)",
                workflow_run_block_id=block_output.workflow_run_block_id,
                loop_over_value=loop_over_value,
                loop_idx=loop_idx,
            )
        # Restore the original (uncopied) template reference.
        loop_block = original_loop_block
        block_outputs.append(block_output)

        # Hard step budget: guards against cycles in the block graph.
        iteration_step_count += 1
        if iteration_step_count >= DEFAULT_MAX_STEPS_PER_ITERATION:
            LOG.info(
                "ForLoopBlock parallel: reached max_steps_per_iteration limit",
                workflow_run_id=workflow_run_id,
                loop_idx=loop_idx,
            )
            failure_block_result = await self.build_block_result(
                success=False,
                status=BlockStatus.failed,
                failure_reason=f"Reached max_steps_per_iteration limit of {DEFAULT_MAX_STEPS_PER_ITERATION}",
                workflow_run_block_id=workflow_run_block_id,
                organization_id=organization_id,
            )
            block_outputs.append(failure_block_result)
            break

        if block_output.status == BlockStatus.canceled:
            LOG.info(
                "ForLoopBlock parallel: iteration was canceled",
                workflow_run_id=workflow_run_id,
                loop_idx=loop_idx,
            )
            break

        # Hard failure: no continue-on-failure / next-loop-on-failure flag
        # anywhere allows us to proceed.
        if (
            not block_output.success
            and not loop_block.continue_on_failure
            and not loop_block.next_loop_on_failure
            and not self.next_loop_on_failure
        ):
            LOG.info(
                "ForLoopBlock parallel: iteration block failed, terminating",
                loop_idx=loop_idx,
                block_idx=block_idx,
                failure_reason=block_output.failure_reason,
            )
            break

        if block_output.success or loop_block.continue_on_failure:
            # Pick the next block: conditionals report their taken branch
            # in output_parameter_value["next_block_label"]; everything
            # else follows the static default_next_map edge.
            next_label: str | None = None
            if loop_block.block_type == BlockType.CONDITIONAL:
                branch_metadata = (
                    block_output.output_parameter_value
                    if isinstance(block_output.output_parameter_value, dict)
                    else None
                )
                next_label = (branch_metadata or {}).get("next_block_label")
            else:
                next_label = default_next_map.get(loop_block.label)

            if not next_label:
                break

            if next_label not in label_to_block:
                failure_block_result = await self.build_block_result(
                    success=False,
                    status=BlockStatus.failed,
                    failure_reason=f"Next block label {next_label} not found inside loop {self.label}",
                    workflow_run_block_id=workflow_run_block_id,
                    organization_id=organization_id,
                )
                block_outputs.append(failure_block_result)
                break

            current_label = next_label
            block_idx += 1
            continue

        # Failed but flagged to move on to the next loop value: end this
        # iteration without marking the whole loop failed.
        if loop_block.next_loop_on_failure or self.next_loop_on_failure:
            LOG.info(
                "ForLoopBlock parallel: iteration block failed, continuing to next",
                loop_idx=loop_idx,
                block_idx=block_idx,
            )
            break

        break

    return (loop_idx, each_loop_output_values, block_outputs, current_block, workflow_run_context)
|
||||
|
||||
async def _execute_loop_parallel(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
workflow_run_block_id: str,
|
||||
workflow_run_context: WorkflowRunContext,
|
||||
loop_over_values: list[Any],
|
||||
organization_id: str | None = None,
|
||||
browser_session_id: str | None = None,
|
||||
granted_concurrency: int | None = None,
|
||||
) -> LoopBlockExecutedResult:
|
||||
"""Execute loop iterations in parallel batches of max_concurrency.
|
||||
|
||||
Slot ownership contract:
|
||||
- The CALLER (`execute_loop_helper`) acquires `granted_concurrency`
|
||||
slots via `check_parallel_loop_quota` before invoking this method.
|
||||
- This method takes ownership of those slots and is responsible for
|
||||
releasing exactly `granted_concurrency` slots in its `finally` block.
|
||||
- Callers must NOT release the slots themselves after this method
|
||||
returns or raises.
|
||||
"""
|
||||
# Caller (execute_loop_helper) only invokes this method when the
|
||||
# post-quota granted_concurrency is > 1. Use an explicit check rather
|
||||
# than `assert` so the contract holds when running with -O.
|
||||
if not granted_concurrency or granted_concurrency <= 1:
|
||||
raise ValueError(f"_execute_loop_parallel requires granted_concurrency > 1, got {granted_concurrency}")
|
||||
effective_concurrency = granted_concurrency
|
||||
outputs_with_loop_values: list[list[dict[str, Any]]] = []
|
||||
all_block_outputs: list[BlockResult] = []
|
||||
last_block: BlockTypeVar | None = None
|
||||
# Release exactly what we acquired. acquire_parallel_slots already
|
||||
# returns min(requested, available), so effective_concurrency is the
|
||||
# true granted count. Releasing fewer (e.g. min(granted, len(values)))
|
||||
# would orphan the unused slots in Redis until TTL expiry.
|
||||
slots_to_release = effective_concurrency
|
||||
|
||||
# Wrap the entire method body so concurrency slots are always released,
|
||||
# even if _build_loop_graph or compute_conditional_scopes fails.
|
||||
try:
|
||||
start_label, label_to_block, default_next_map = self._build_loop_graph(self.loop_blocks)
|
||||
conditional_scopes = compute_conditional_scopes(label_to_block, default_next_map)
|
||||
|
||||
# Enforce max iterations
|
||||
capped_values = loop_over_values[:DEFAULT_MAX_LOOP_ITERATIONS]
|
||||
if len(loop_over_values) > DEFAULT_MAX_LOOP_ITERATIONS:
|
||||
LOG.info(
|
||||
"ForLoopBlock parallel: capping iterations",
|
||||
workflow_run_id=workflow_run_id,
|
||||
max_iterations=DEFAULT_MAX_LOOP_ITERATIONS,
|
||||
total_values=len(loop_over_values),
|
||||
)
|
||||
# Process in batches
|
||||
for batch_start in range(0, len(capped_values), effective_concurrency):
|
||||
batch_end = min(batch_start + effective_concurrency, len(capped_values))
|
||||
batch = [(idx, val) for idx, val in enumerate(capped_values[batch_start:batch_end], start=batch_start)]
|
||||
|
||||
LOG.info(
|
||||
"ForLoopBlock parallel: executing batch",
|
||||
workflow_run_id=workflow_run_id,
|
||||
batch_start=batch_start,
|
||||
batch_end=batch_end,
|
||||
batch_size=len(batch),
|
||||
max_concurrency=effective_concurrency,
|
||||
)
|
||||
|
||||
# Create context snapshots for each iteration in the batch
|
||||
iteration_tasks = []
|
||||
batch_loop_indices = []
|
||||
for loop_idx, loop_over_value in batch:
|
||||
batch_loop_indices.append(loop_idx)
|
||||
ctx_snapshot = workflow_run_context.create_iteration_snapshot(loop_idx)
|
||||
|
||||
# Default args freeze loop variables at definition time so each
|
||||
# task closes over its own loop_idx/value/snapshot rather than
|
||||
# the late-binding loop locals.
|
||||
async def _run_iteration(
|
||||
_loop_idx: int = loop_idx,
|
||||
_loop_value: Any = loop_over_value,
|
||||
_ctx_snapshot: WorkflowRunContext = ctx_snapshot,
|
||||
) -> tuple[int, list[dict[str, Any]], list[BlockResult], BlockTypeVar | None, WorkflowRunContext]:
|
||||
# Install the per-iteration workflow_run_context snapshot
|
||||
# into the ContextVar override so child blocks that resolve
|
||||
# context via Block.get_workflow_run_context(workflow_run_id)
|
||||
# see the isolated snapshot instead of the shared global.
|
||||
# create_task() snapshotted ContextVars at task creation,
|
||||
# so this set() only mutates the current task's copy.
|
||||
set_iteration_workflow_run_context(_ctx_snapshot)
|
||||
|
||||
# Fork SkyvernContext for this iteration with isolated mutable fields.
|
||||
# Uses create_iteration_copy() which shallow-copies scalars and creates
|
||||
# fresh dicts/lists/sets, avoiding deepcopy issues with Playwright
|
||||
# Frame/Page objects that are not safely copyable.
|
||||
parent_skyvern_ctx = skyvern_context.current()
|
||||
if parent_skyvern_ctx:
|
||||
iter_ctx = parent_skyvern_ctx.create_iteration_copy(
|
||||
browser_session_id=loop_iteration_key(workflow_run_id, _loop_idx),
|
||||
)
|
||||
skyvern_context.set(iter_ctx)
|
||||
else:
|
||||
LOG.warning(
|
||||
"No parent SkyvernContext to fork for parallel iteration — "
|
||||
"per-iteration file tracking skipped",
|
||||
workflow_run_id=workflow_run_id,
|
||||
loop_idx=_loop_idx,
|
||||
)
|
||||
|
||||
return await self._execute_single_iteration_parallel(
|
||||
loop_idx=_loop_idx,
|
||||
loop_over_value=_loop_value,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
workflow_run_context=_ctx_snapshot,
|
||||
start_label=start_label,
|
||||
label_to_block=label_to_block,
|
||||
default_next_map=default_next_map,
|
||||
conditional_scopes=conditional_scopes,
|
||||
organization_id=organization_id,
|
||||
browser_session_id=browser_session_id,
|
||||
)
|
||||
|
||||
# asyncio.create_task() copies the current ContextVar state at
|
||||
# creation time, giving each iteration its own isolated copy.
|
||||
# Plain coroutines in asyncio.gather() share the caller's
|
||||
# ContextVar scope, so skyvern_context.set() would race.
|
||||
iteration_tasks.append(asyncio.create_task(_run_iteration()))
|
||||
|
||||
# Execute batch concurrently; return_exceptions=True so one failure doesn't kill the batch
|
||||
batch_results = await asyncio.gather(*iteration_tasks, return_exceptions=True)
|
||||
|
||||
# Collect snapshots for merging
|
||||
context_snapshots: list[tuple[int, WorkflowRunContext]] = []
|
||||
|
||||
# Sort results by loop_idx to preserve ordering
|
||||
indexed_results: list[tuple[int, Any]] = []
|
||||
for i, result in enumerate(batch_results):
|
||||
loop_idx_for_result = batch[i][0]
|
||||
indexed_results.append((loop_idx_for_result, result))
|
||||
indexed_results.sort(key=lambda t: t[0])
|
||||
|
||||
# Process every result in the batch before deciding whether to
|
||||
# stop. asyncio.gather(return_exceptions=True) already ran all
|
||||
# iterations to completion, so successful outputs from indices
|
||||
# after a failing one must still be captured — otherwise we'd
|
||||
# silently discard data the browser already produced.
|
||||
batch_had_failure = False
|
||||
for loop_idx_sorted, result in indexed_results:
|
||||
if isinstance(result, Exception):
|
||||
LOG.error(
|
||||
"ForLoopBlock parallel: iteration raised an exception",
|
||||
workflow_run_id=workflow_run_id,
|
||||
loop_idx=loop_idx_sorted,
|
||||
error=str(result),
|
||||
exc_info=result,
|
||||
)
|
||||
failure_block_result = await self.build_block_result(
|
||||
success=False,
|
||||
status=BlockStatus.failed,
|
||||
failure_reason=f"Parallel iteration {loop_idx_sorted} failed: {str(result)}",
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
all_block_outputs.append(failure_block_result)
|
||||
outputs_with_loop_values.append([])
|
||||
batch_had_failure = True
|
||||
else:
|
||||
iter_idx, iter_outputs, iter_block_results, iter_last_block, iter_ctx = result
|
||||
outputs_with_loop_values.append(iter_outputs)
|
||||
all_block_outputs.extend(iter_block_results)
|
||||
if iter_last_block is not None:
|
||||
last_block = iter_last_block
|
||||
context_snapshots.append((iter_idx, iter_ctx))
|
||||
|
||||
if batch_had_failure and not self.next_loop_on_failure:
|
||||
# Merge any successful snapshots before stopping so their
|
||||
# context mutations aren't lost.
|
||||
workflow_run_context.merge_iteration_results(context_snapshots)
|
||||
# Persist accumulated outputs before the early return so a
|
||||
# subsequent Temporal timeout can recover prior-batch data.
|
||||
await self._persist_partial_loop_output(
|
||||
workflow_run_id, outputs_with_loop_values, batch_loop_indices[-1]
|
||||
)
|
||||
try:
|
||||
await app.BROWSER_MANAGER.cleanup_loop_iterations(
|
||||
workflow_run_id=workflow_run_id,
|
||||
loop_indices=batch_loop_indices,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
except Exception:
|
||||
LOG.warning("Failed to cleanup loop iteration browsers", exc_info=True)
|
||||
return LoopBlockExecutedResult(
|
||||
outputs_with_loop_values=outputs_with_loop_values,
|
||||
block_outputs=all_block_outputs,
|
||||
last_block=last_block,
|
||||
)
|
||||
|
||||
# Merge iteration context snapshots back into the main context
|
||||
workflow_run_context.merge_iteration_results(context_snapshots)
|
||||
|
||||
# Persist accumulated outputs after each batch so progress
|
||||
# survives Temporal activity timeouts mid-loop. Mirrors the
|
||||
# sequential path's per-iteration persistence.
|
||||
await self._persist_partial_loop_output(
|
||||
workflow_run_id, outputs_with_loop_values, batch_loop_indices[-1]
|
||||
)
|
||||
|
||||
# Clean up iteration browser contexts for this batch
|
||||
try:
|
||||
await app.BROWSER_MANAGER.cleanup_loop_iterations(
|
||||
workflow_run_id=workflow_run_id,
|
||||
loop_indices=batch_loop_indices,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
except Exception:
|
||||
LOG.warning("Failed to cleanup loop iteration browsers after batch", exc_info=True)
|
||||
finally:
|
||||
# Always release concurrency slots back to the org quota.
|
||||
# Release the full granted count — acquire incremented Redis by
|
||||
# this amount, so we must decrement by the same amount even if
|
||||
# the loop had fewer iterations than granted slots.
|
||||
if organization_id and slots_to_release > 0:
|
||||
await app.AGENT_FUNCTION.release_parallel_loop_quota(organization_id, slots_to_release)
|
||||
|
||||
return LoopBlockExecutedResult(
|
||||
outputs_with_loop_values=outputs_with_loop_values,
|
||||
block_outputs=all_block_outputs,
|
||||
last_block=last_block,
|
||||
)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
|
|
|
|||
|
|
@ -439,6 +439,7 @@ def block_yaml_to_block(
|
|||
loop_blocks=loop_blocks,
|
||||
complete_if_empty=block_yaml.complete_if_empty,
|
||||
data_schema=block_yaml.data_schema,
|
||||
max_concurrency=block_yaml.max_concurrency,
|
||||
)
|
||||
elif block_yaml.block_type == BlockType.CONDITIONAL:
|
||||
branch_conditions = []
|
||||
|
|
|
|||
|
|
@ -641,6 +641,7 @@ class ForLoopBlockYAML(BlockYAML):
|
|||
loop_variable_reference: str | None = None
|
||||
complete_if_empty: bool = False
|
||||
data_schema: dict[str, Any] | str | None = None
|
||||
max_concurrency: int | None = None
|
||||
|
||||
|
||||
class BranchCriteriaYAML(BaseModel):
|
||||
|
|
|
|||
|
|
@ -61,6 +61,21 @@ class BrowserManager(Protocol):
|
|||
|
||||
def get_for_script(self, script_id: str | None = None) -> BrowserState | None: ...
|
||||
|
||||
    async def get_or_create_for_loop_iteration(
        self,
        workflow_run_id: str,
        loop_idx: int,
        organization_id: str | None = None,
        browser_session_id: str | None = None,
    ) -> BrowserState:
        """Return (creating on first use) the isolated browser state for one parallel loop iteration.

        Implementations key the state off (workflow_run_id, loop_idx) so that
        concurrent iterations of the same loop do not share a browser.
        """
        ...
|
||||
|
||||
    async def cleanup_loop_iterations(
        self,
        workflow_run_id: str,
        loop_indices: list[int],
        organization_id: str | None = None,
    ) -> None:
        """Close and discard the per-iteration browser states for the given loop indices."""
        ...
|
||||
|
||||
def set_video_artifact_for_task(self, task: Task, artifacts: list[VideoArtifact]) -> None: ...
|
||||
|
||||
async def get_video_artifacts(
|
||||
|
|
|
|||
|
|
@ -1,12 +1,15 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
import structlog
|
||||
from playwright.async_api import async_playwright
|
||||
|
||||
from skyvern.constants import is_loop_iteration_key, loop_iteration_key
|
||||
from skyvern.exceptions import MissingBrowserState
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.schemas.tasks import Task
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun
|
||||
from skyvern.schemas.runs import ProxyLocation, ProxyLocationInput
|
||||
|
|
@ -22,6 +25,16 @@ LOG = structlog.get_logger()
|
|||
class RealBrowserManager(BrowserManager):
|
||||
def __init__(self) -> None:
|
||||
self.pages: dict[str, BrowserState] = {}
|
||||
# Lazily initialized inside an async context to avoid binding the lock
|
||||
# to the wrong event loop when RealBrowserManager is instantiated
|
||||
# during module import or test fixtures that haven't started a loop.
|
||||
self._loop_iteration_lock_instance: asyncio.Lock | None = None
|
||||
|
||||
@property
|
||||
def _loop_iteration_lock(self) -> asyncio.Lock:
|
||||
if self._loop_iteration_lock_instance is None:
|
||||
self._loop_iteration_lock_instance = asyncio.Lock()
|
||||
return self._loop_iteration_lock_instance
|
||||
|
||||
@staticmethod
|
||||
async def _create_browser_state(
|
||||
|
|
@ -160,6 +173,28 @@ class RealBrowserManager(BrowserManager):
|
|||
if browser_profile_id is None:
|
||||
browser_profile_id = workflow_run.browser_profile_id
|
||||
|
||||
# When running inside a parallel loop iteration, SkyvernContext holds an
|
||||
# iteration-specific browser_session_id (e.g. "wr_xxx__iter_0"). The
|
||||
# iteration browser was pre-created by get_or_create_for_loop_iteration()
|
||||
# and stored under that key. Return it directly so child blocks (task,
|
||||
# action, etc.) use the isolated browser instead of racing to create one
|
||||
# under the bare workflow_run_id.
|
||||
ctx = skyvern_context.current()
|
||||
if ctx and ctx.browser_session_id and is_loop_iteration_key(ctx.browser_session_id):
|
||||
iteration_browser = self.pages.get(ctx.browser_session_id)
|
||||
if iteration_browser:
|
||||
# Navigate to the task URL if page is still on about:blank
|
||||
if url:
|
||||
page = await iteration_browser.get_working_page()
|
||||
if page and page.url == "about:blank":
|
||||
await iteration_browser.navigate_to_url(page=page, url=url)
|
||||
LOG.debug(
|
||||
"Returning iteration-specific browser state from parallel loop context",
|
||||
workflow_run_id=workflow_run_id,
|
||||
iteration_key=ctx.browser_session_id,
|
||||
)
|
||||
return iteration_browser
|
||||
|
||||
# Check own cache entry first so navigate_to_url is only called on the first step.
|
||||
# Don't pass parent_workflow_run_id here — that lookup is deferred to the block
|
||||
# below so PBS runs don't accidentally inherit the parent's browser.
|
||||
|
|
@ -252,6 +287,15 @@ class RealBrowserManager(BrowserManager):
|
|||
def get_for_workflow_run(
|
||||
self, workflow_run_id: str, parent_workflow_run_id: str | None = None
|
||||
) -> BrowserState | None:
|
||||
# Check for parallel loop iteration browser via SkyvernContext first.
|
||||
# This mirrors the async get_or_create_for_workflow_run() so callers
|
||||
# like task block's non-first-task path get the correct iteration browser.
|
||||
ctx = skyvern_context.current()
|
||||
if ctx and ctx.browser_session_id and is_loop_iteration_key(ctx.browser_session_id):
|
||||
iteration_browser = self.pages.get(ctx.browser_session_id)
|
||||
if iteration_browser:
|
||||
return iteration_browser
|
||||
|
||||
# Priority: parent first, then own entry.
|
||||
# Callers that need to avoid parent inheritance must omit parent_workflow_run_id.
|
||||
# See get_or_create_for_workflow_run() for the two-phase lookup pattern.
|
||||
|
|
@ -264,6 +308,13 @@ class RealBrowserManager(BrowserManager):
|
|||
return None
|
||||
|
||||
def set_video_artifact_for_task(self, task: Task, artifacts: list[VideoArtifact]) -> None:
|
||||
# Check parallel loop iteration browser first
|
||||
ctx = skyvern_context.current()
|
||||
if ctx and ctx.browser_session_id and is_loop_iteration_key(ctx.browser_session_id):
|
||||
iter_browser = self.pages.get(ctx.browser_session_id)
|
||||
if iter_browser:
|
||||
iter_browser.browser_artifacts.video_artifacts = artifacts
|
||||
return
|
||||
if task.workflow_run_id and task.workflow_run_id in self.pages:
|
||||
self.pages[task.workflow_run_id].browser_artifacts.video_artifacts = artifacts
|
||||
return
|
||||
|
|
@ -525,3 +576,95 @@ class RealBrowserManager(BrowserManager):
|
|||
if script_id and script_id in self.pages:
|
||||
return self.pages[script_id]
|
||||
return None
|
||||
|
||||
async def get_or_create_for_loop_iteration(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
loop_idx: int,
|
||||
organization_id: str | None = None,
|
||||
browser_session_id: str | None = None,
|
||||
) -> BrowserState:
|
||||
"""Get or create an isolated BrowserContext for a parallel loop iteration.
|
||||
|
||||
Each iteration gets its own browser context so cookies, auth state, and
|
||||
storage are fully isolated between concurrent iterations.
|
||||
|
||||
Uses _loop_iteration_lock to prevent concurrent create_task coroutines
|
||||
from racing through the check-then-act on self.pages.
|
||||
"""
|
||||
key = loop_iteration_key(workflow_run_id, loop_idx)
|
||||
|
||||
async with self._loop_iteration_lock:
|
||||
if key in self.pages:
|
||||
return self.pages[key]
|
||||
|
||||
# Persistent sessions cannot be aliased under per-iteration keys —
|
||||
# multiple iterations would race on the same live page. The caller
|
||||
# (execute_loop_helper) is expected to force sequential execution
|
||||
# when a persistent session is in use; this branch only runs if
|
||||
# that contract is bypassed, in which case we still create a fresh
|
||||
# isolated context to preserve correctness over the persistence.
|
||||
if browser_session_id:
|
||||
LOG.warning(
|
||||
"Persistent browser session not used for parallel loop iteration — "
|
||||
"creating isolated context to avoid cross-iteration races",
|
||||
workflow_run_id=workflow_run_id,
|
||||
loop_idx=loop_idx,
|
||||
browser_session_id=browser_session_id,
|
||||
)
|
||||
|
||||
LOG.info(
|
||||
"Creating isolated browser state for loop iteration",
|
||||
workflow_run_id=workflow_run_id,
|
||||
loop_idx=loop_idx,
|
||||
key=key,
|
||||
)
|
||||
browser_state = await self._create_browser_state(
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
self.pages[key] = browser_state
|
||||
|
||||
# Page creation can happen outside the lock — the key is already
|
||||
# reserved in self.pages so no other coroutine will race on it.
|
||||
await browser_state.get_or_create_page(
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return browser_state
|
||||
|
||||
async def cleanup_loop_iterations(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
loop_indices: list[int],
|
||||
organization_id: str | None = None,
|
||||
) -> None:
|
||||
"""Close and remove browser states for the given parallel loop iterations.
|
||||
|
||||
Uses _loop_iteration_lock so cleanup cannot race with create or with
|
||||
a concurrent cleanup call from another batch.
|
||||
"""
|
||||
# Collect entries to close under the lock, then close outside it
|
||||
# to avoid holding the lock during potentially slow browser teardown.
|
||||
to_close: list[tuple[int, BrowserState]] = []
|
||||
async with self._loop_iteration_lock:
|
||||
for loop_idx in loop_indices:
|
||||
key = loop_iteration_key(workflow_run_id, loop_idx)
|
||||
browser_state = self.pages.pop(key, None)
|
||||
if browser_state is None:
|
||||
continue
|
||||
# Only close if no other entry still references the same object
|
||||
shared = any(bs is browser_state for bs in self.pages.values())
|
||||
if not shared:
|
||||
to_close.append((loop_idx, browser_state))
|
||||
|
||||
for loop_idx, browser_state in to_close:
|
||||
try:
|
||||
await browser_state.close()
|
||||
except Exception:
|
||||
LOG.warning(
|
||||
"Failed to close loop iteration browser state",
|
||||
workflow_run_id=workflow_run_id,
|
||||
loop_idx=loop_idx,
|
||||
exc_info=True,
|
||||
)
|
||||
|
|
|
|||
985
tests/unit/test_parallel_loop_execution.py
Normal file
985
tests/unit/test_parallel_loop_execution.py
Normal file
|
|
@ -0,0 +1,985 @@
|
|||
"""Tests for parallel loop execution (SKY-8175 + SKY-8176 + SKY-8180).
|
||||
|
||||
Tests cover:
|
||||
1. max_concurrency field validation and clamping
|
||||
2. Sequential behavior unchanged when max_concurrency is None or 1
|
||||
3. Parallel execution dispatches correctly
|
||||
4. Result ordering is preserved regardless of completion order
|
||||
5. Error handling: one iteration failure doesn't kill others (when next_loop_on_failure=True)
|
||||
6. WorkflowRunContext snapshot/merge isolation
|
||||
7. Browser isolation key format
|
||||
8. YAML schema passthrough
|
||||
9. Batch sizing logic
|
||||
10. Quota-enforced fallback to sequential
|
||||
11. Concurrency slot release on success and failure
|
||||
12. YAML round-trip through workflow_definition_converter
|
||||
13. Browser cleanup verification
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
|
||||
from skyvern.forge.sdk.workflow.models.block import (
|
||||
BlockStatus,
|
||||
BlockType,
|
||||
ForLoopBlock,
|
||||
LoopBlockExecutedResult,
|
||||
TaskBlock,
|
||||
)
|
||||
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter
|
||||
from skyvern.schemas.workflows import BlockResult, ForLoopBlockYAML, TaskBlockYAML, WorkflowDefinitionYAML
|
||||
|
||||
_NOW = datetime.now(UTC)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_failure_block_result(label: str = "failure") -> BlockResult:
    """Build a concrete BlockResult representing a failed iteration."""
    output_param = _make_output_parameter(label)
    return BlockResult(
        success=False,
        output_parameter=output_param,
        status=BlockStatus.failed,
        failure_reason="test failure",
    )
|
||||
|
||||
|
||||
def _make_output_parameter(label: str) -> OutputParameter:
    """Create a minimal OutputParameter whose ids/keys are derived from *label*."""
    return OutputParameter(
        output_parameter_id="op_" + label,
        key=label + "_output",
        workflow_id="wf_test",
        created_at=_NOW,
        modified_at=_NOW,
    )
|
||||
|
||||
|
||||
def _make_task_block(label: str) -> TaskBlock:
    """Build a minimal TaskBlock pointing at a placeholder URL."""
    kwargs = dict(
        label=label,
        block_type=BlockType.TASK,
        output_parameter=_make_output_parameter(label),
        url="https://example.com",
    )
    return TaskBlock(**kwargs)
|
||||
|
||||
|
||||
def _make_for_loop_block(
    label: str = "loop_block",
    loop_blocks: list | None = None,
    max_concurrency: int | None = None,
    next_loop_on_failure: bool = False,
) -> ForLoopBlock:
    """Build a ForLoopBlock for tests; defaults to one inner task block."""
    if loop_blocks:
        inner_blocks = loop_blocks
    else:
        inner_blocks = [_make_task_block("inner_task")]
    return ForLoopBlock(
        label=label,
        block_type=BlockType.FOR_LOOP,
        output_parameter=_make_output_parameter(label),
        loop_blocks=inner_blocks,
        max_concurrency=max_concurrency,
        next_loop_on_failure=next_loop_on_failure,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# max_concurrency field validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMaxConcurrencyValidation:
    """max_concurrency passes through when None or in [1, 20]; out-of-range clamps."""

    def test_none_is_preserved(self) -> None:
        assert _make_for_loop_block(max_concurrency=None).max_concurrency is None

    def test_one_is_preserved(self) -> None:
        assert _make_for_loop_block(max_concurrency=1).max_concurrency == 1

    def test_valid_value_preserved(self) -> None:
        assert _make_for_loop_block(max_concurrency=5).max_concurrency == 5

    def test_clamped_to_min_one(self) -> None:
        assert _make_for_loop_block(max_concurrency=0).max_concurrency == 1

    def test_clamped_negative_to_one(self) -> None:
        assert _make_for_loop_block(max_concurrency=-5).max_concurrency == 1

    def test_clamped_to_max_twenty(self) -> None:
        assert _make_for_loop_block(max_concurrency=100).max_concurrency == 20

    def test_twenty_is_preserved(self) -> None:
        assert _make_for_loop_block(max_concurrency=20).max_concurrency == 20
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# YAML schema passthrough
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestForLoopBlockYAMLMaxConcurrency:
    """ForLoopBlockYAML carries max_concurrency through and defaults it to None."""

    @staticmethod
    def _build_yaml(**extra) -> ForLoopBlockYAML:
        # Shared builder: one inner task block plus whatever the test overrides.
        return ForLoopBlockYAML(
            label="test_loop",
            loop_blocks=[TaskBlockYAML(label="inner", url="https://example.com")],
            **extra,
        )

    def test_yaml_accepts_max_concurrency(self) -> None:
        assert self._build_yaml(max_concurrency=5).max_concurrency == 5

    def test_yaml_defaults_to_none(self) -> None:
        assert self._build_yaml().max_concurrency is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WorkflowRunContext snapshot/merge
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWorkflowRunContextSnapshot:
    """Exercise WorkflowRunContext.create_iteration_snapshot / merge_iteration_results.

    These assertions pin the copy semantics the parallel loop engine relies on:
    deep copies for values and blocks_metadata, shallow copies for parameters
    and secrets, and index-ordered merge with later-index-wins on collisions.
    """

    def _make_context(self) -> WorkflowRunContext:
        """Build a context pre-populated with values, metadata, params, and secrets."""
        ctx = WorkflowRunContext(
            workflow_title="Test",
            workflow_id="wf_1",
            workflow_permanent_id="wpid_1",
            workflow_run_id="wr_1",
            aws_client=MagicMock(),
        )
        ctx.values = {"key1": "original_value", "key2": [1, 2, 3]}
        ctx.blocks_metadata = {"block_a": {"current_index": 0}}
        ctx.parameters = {"param1": MagicMock()}
        ctx.secrets = {"secret1": "s3cr3t"}
        ctx.organization_id = "org_1"
        return ctx

    def test_snapshot_deep_copies_values(self) -> None:
        """Mutating a snapshot's values (including nested lists) must not leak back."""
        ctx = self._make_context()
        snapshot = ctx.create_iteration_snapshot(0)

        # Modifying snapshot values should not affect original
        snapshot.values["key1"] = "modified"
        snapshot.values["key2"].append(4)

        assert ctx.values["key1"] == "original_value"
        assert ctx.values["key2"] == [1, 2, 3]

    def test_snapshot_shallow_copies_parameters_and_secrets(self) -> None:
        """Dicts are distinct objects, but entries are shared (read-only contract)."""
        ctx = self._make_context()
        snapshot = ctx.create_iteration_snapshot(0)

        # Parameters and secrets dicts are shallow-copied so a future code
        # path that mutates them inside a loop block can't leak across
        # iterations. Values inside are still shared (read-only contract).
        assert snapshot.parameters is not ctx.parameters
        assert snapshot.parameters == ctx.parameters
        assert snapshot.secrets is not ctx.secrets
        assert snapshot.secrets == ctx.secrets

    def test_snapshot_deep_copies_blocks_metadata(self) -> None:
        """Nested metadata dicts must be iteration-private as well."""
        ctx = self._make_context()
        snapshot = ctx.create_iteration_snapshot(0)

        snapshot.blocks_metadata["block_a"]["current_index"] = 99

        assert ctx.blocks_metadata["block_a"]["current_index"] == 0

    def test_snapshot_preserves_immutable_fields(self) -> None:
        """Identity fields (run/org/workflow ids) carry over unchanged."""
        ctx = self._make_context()
        snapshot = ctx.create_iteration_snapshot(0)

        assert snapshot.workflow_run_id == ctx.workflow_run_id
        assert snapshot.organization_id == ctx.organization_id
        assert snapshot.workflow_id == ctx.workflow_id

    def test_merge_iteration_results_preserves_order(self) -> None:
        """Merge applies snapshots in loop-index order even if supplied out of order."""
        ctx = self._make_context()

        # Create two snapshots with different values
        snap1 = ctx.create_iteration_snapshot(0)
        snap1.values["iter_0_result"] = "result_0"
        snap1.blocks_metadata["block_iter_0"] = {"done": True}

        snap2 = ctx.create_iteration_snapshot(1)
        snap2.values["iter_1_result"] = "result_1"
        snap2.blocks_metadata["block_iter_1"] = {"done": True}

        # Merge in reverse order — should still apply in index order
        ctx.merge_iteration_results([(1, snap2), (0, snap1)])

        assert ctx.values["iter_0_result"] == "result_0"
        assert ctx.values["iter_1_result"] == "result_1"
        assert "block_iter_0" in ctx.blocks_metadata
        assert "block_iter_1" in ctx.blocks_metadata

    def test_merge_later_iteration_overwrites_earlier_on_collision(self) -> None:
        """On key collision, the snapshot with the higher loop index wins."""
        ctx = self._make_context()

        snap0 = ctx.create_iteration_snapshot(0)
        snap0.values["shared_key"] = "from_iter_0"

        snap1 = ctx.create_iteration_snapshot(1)
        snap1.values["shared_key"] = "from_iter_1"

        ctx.merge_iteration_results([(0, snap0), (1, snap1)])

        # Later iteration (idx=1) should win
        assert ctx.values["shared_key"] == "from_iter_1"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Browser iteration key format
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBrowserIterationKey:
    """loop_iteration_key composes '<workflow_run_id>__iter_<idx>'."""

    def test_key_format(self) -> None:
        from skyvern.constants import loop_iteration_key

        key = loop_iteration_key("wr_abc123", 5)
        assert key == "wr_abc123__iter_5"

    def test_key_format_zero(self) -> None:
        from skyvern.constants import loop_iteration_key

        key = loop_iteration_key("wr_xyz", 0)
        assert key == "wr_xyz__iter_0"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sequential path unchanged
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSequentialPathUnchanged:
    """Verify that max_concurrency=None and max_concurrency=1 use the sequential path."""

    async def _assert_parallel_not_dispatched(self, max_concurrency: int | None) -> None:
        """Run execute_loop_helper and assert _execute_loop_parallel is never called.

        The sequential path is expected to blow up on unmocked dependencies;
        the broad except swallows that deliberately — the only thing under
        test is that the parallel branch was not taken.
        """
        block = _make_for_loop_block(max_concurrency=max_concurrency)
        # Mock _execute_loop_parallel to verify it's NOT called
        mock_parallel = AsyncMock()
        object.__setattr__(block, "_execute_loop_parallel", mock_parallel)

        mock_context = MagicMock(spec=WorkflowRunContext)
        mock_context.get_value.return_value = None
        mock_context.has_value.return_value = False

        try:
            await block.execute_loop_helper(
                workflow_run_id="wr_test",
                workflow_run_block_id="wrb_test",
                workflow_run_context=mock_context,
                loop_over_values=["a", "b"],
            )
        except Exception:
            pass

        mock_parallel.assert_not_called()

    @pytest.mark.asyncio
    async def test_none_concurrency_does_not_call_parallel(self) -> None:
        await self._assert_parallel_not_dispatched(None)

    @pytest.mark.asyncio
    async def test_concurrency_one_does_not_call_parallel(self) -> None:
        await self._assert_parallel_not_dispatched(1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parallel dispatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParallelDispatch:
    """Verify that max_concurrency > 1 dispatches to _execute_loop_parallel."""

    @staticmethod
    def _block_with_mock_parallel(max_concurrency: int):
        """Build a ForLoopBlock whose _execute_loop_parallel is an AsyncMock.

        Returns (block, mock_parallel, expected_result) so each test can
        assert on dispatch without repeating the setup.
        """
        expected_result = LoopBlockExecutedResult(
            outputs_with_loop_values=[],
            block_outputs=[],
            last_block=None,
        )
        block = _make_for_loop_block(max_concurrency=max_concurrency)
        mock_parallel = AsyncMock(return_value=expected_result)
        object.__setattr__(block, "_execute_loop_parallel", mock_parallel)
        return block, mock_parallel, expected_result

    @pytest.mark.asyncio
    async def test_concurrency_gt_one_calls_parallel(self) -> None:
        block, mock_parallel, expected_result = self._block_with_mock_parallel(3)

        result = await block.execute_loop_helper(
            workflow_run_id="wr_test",
            workflow_run_block_id="wrb_test",
            workflow_run_context=MagicMock(spec=WorkflowRunContext),
            loop_over_values=["a", "b", "c"],
        )

        mock_parallel.assert_called_once()
        assert result is expected_result

    @pytest.mark.asyncio
    async def test_parallel_passes_granted_concurrency(self) -> None:
        """Verify granted_concurrency kwarg is forwarded from execute_loop_helper."""
        block, mock_parallel, _ = self._block_with_mock_parallel(5)

        await block.execute_loop_helper(
            workflow_run_id="wr_test",
            workflow_run_block_id="wrb_test",
            workflow_run_context=MagicMock(spec=WorkflowRunContext),
            loop_over_values=["a", "b"],
        )

        call_kwargs = mock_parallel.call_args.kwargs
        assert call_kwargs["granted_concurrency"] == 5
|
||||
|
||||
|
||||
class TestBatchSizing:
    """Verify that _execute_loop_parallel processes iterations in correct batch sizes.

    Both tests shared ~40 duplicated lines of gather/context/patch scaffolding;
    the scaffolding now lives in named helpers so each test states only its
    inputs and the expected batch shape.
    """

    @staticmethod
    def _make_tracking_gather(gather_calls: list[int]):
        """Build an asyncio.gather replacement that records each batch size.

        The fake closes every scheduled coroutine (they are never awaited —
        running them would hit the real browser path) and returns one
        successful iteration tuple per coroutine:
        ``(index_in_batch, outputs, block_outputs, last_block, iteration_ctx)``.
        """

        async def tracking_gather(*coros_or_futures, return_exceptions=False):
            gather_calls.append(len(coros_or_futures))
            results = []
            for coro in coros_or_futures:
                try:
                    # Cancel the coroutine since we can't actually run them.
                    coro.close()
                except Exception:
                    pass
                iteration_ctx = MagicMock(spec=WorkflowRunContext)
                iteration_ctx.values = {}
                iteration_ctx.blocks_metadata = {}
                iteration_ctx.workflow_run_outputs = {}
                results.append((len(results), [], [], None, iteration_ctx))
            return results

        return tracking_gather

    @staticmethod
    def _make_parent_context() -> MagicMock:
        """Parent WorkflowRunContext whose iteration snapshots are empty mocks."""
        parent = MagicMock(spec=WorkflowRunContext)
        parent.create_iteration_snapshot.return_value = MagicMock(
            spec=WorkflowRunContext,
            values={},
            blocks_metadata={},
            workflow_run_outputs={},
        )
        return parent

    async def _run_parallel(self, block, loop_over_values, granted_concurrency, tracking_gather) -> None:
        """Drive _execute_loop_parallel with all external collaborators patched out."""
        parent_context = self._make_parent_context()
        with (
            patch("skyvern.forge.sdk.workflow.models.block.app") as mock_app,
            patch("skyvern.forge.sdk.workflow.models.block.asyncio") as mock_asyncio,
            patch("skyvern.forge.sdk.workflow.models.block.skyvern_context") as mock_skyvern_ctx,
        ):
            mock_app.AGENT_FUNCTION = AsyncMock()
            mock_app.AGENT_FUNCTION.release_parallel_loop_quota = AsyncMock()
            mock_app.BROWSER_MANAGER = AsyncMock()
            # Pass coroutines straight through so the fake gather receives them.
            mock_asyncio.create_task = lambda coro: coro
            mock_asyncio.gather = tracking_gather
            mock_skyvern_ctx.current.return_value = None

            await block._execute_loop_parallel(
                workflow_run_id="wr_test",
                workflow_run_block_id="wrb_test",
                workflow_run_context=parent_context,
                loop_over_values=loop_over_values,
                granted_concurrency=granted_concurrency,
            )

    @pytest.mark.asyncio
    async def test_single_batch_when_values_lte_concurrency(self) -> None:
        """3 values with concurrency=5 should produce a single batch of 3."""
        block = _make_for_loop_block(max_concurrency=5)
        gather_calls: list[int] = []

        await self._run_parallel(
            block,
            loop_over_values=["a", "b", "c"],
            granted_concurrency=5,
            tracking_gather=self._make_tracking_gather(gather_calls),
        )

        assert gather_calls == [3]  # single batch of 3

    @pytest.mark.asyncio
    async def test_multiple_batches_when_values_gt_concurrency(self) -> None:
        """7 values with concurrency=3 should produce 3 batches: [3, 3, 1]."""
        block = _make_for_loop_block(max_concurrency=3)
        gather_calls: list[int] = []

        await self._run_parallel(
            block,
            loop_over_values=["a", "b", "c", "d", "e", "f", "g"],
            granted_concurrency=3,
            tracking_gather=self._make_tracking_gather(gather_calls),
        )

        assert gather_calls == [3, 3, 1]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parallel error handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParallelErrorHandling:
    """Verify error handling in _execute_loop_parallel.

    The three tests previously triplicated the success-tuple construction,
    coroutine draining, and patch scaffolding; those now live in helpers so
    each test only describes which iterations fail and what must result.
    """

    @staticmethod
    def _success_tuple(index: int, outputs: list) -> tuple:
        """Build a successful iteration result tuple with an empty mocked context."""
        iteration_ctx = MagicMock(spec=WorkflowRunContext)
        iteration_ctx.values = {}
        iteration_ctx.blocks_metadata = {}
        iteration_ctx.workflow_run_outputs = {}
        return (index, outputs, [], None, iteration_ctx)

    @staticmethod
    def _close_all(coros) -> None:
        """Close scheduled coroutines that the fake gather never awaits."""
        for coro in coros:
            try:
                coro.close()
            except Exception:
                pass

    @staticmethod
    def _make_parent_context() -> MagicMock:
        """Parent WorkflowRunContext whose iteration snapshots are empty mocks."""
        parent = MagicMock(spec=WorkflowRunContext)
        parent.create_iteration_snapshot.return_value = MagicMock(
            spec=WorkflowRunContext,
            values={},
            blocks_metadata={},
            workflow_run_outputs={},
        )
        return parent

    async def _run(self, block, fake_gather, loop_over_values, granted_concurrency):
        """Patch collaborators, install a failing build_block_result, run the loop."""
        parent_context = self._make_parent_context()
        with (
            patch("skyvern.forge.sdk.workflow.models.block.app") as mock_app,
            patch("skyvern.forge.sdk.workflow.models.block.asyncio") as mock_asyncio,
            patch("skyvern.forge.sdk.workflow.models.block.skyvern_context") as mock_skyvern_ctx,
        ):
            mock_app.AGENT_FUNCTION = AsyncMock()
            mock_app.AGENT_FUNCTION.release_parallel_loop_quota = AsyncMock()
            mock_app.BROWSER_MANAGER = AsyncMock()
            mock_asyncio.create_task = lambda coro: coro
            mock_asyncio.gather = fake_gather
            mock_skyvern_ctx.current.return_value = None

            # build_block_result needs to return a mock BlockResult
            object.__setattr__(block, "build_block_result", AsyncMock(return_value=_make_failure_block_result()))

            return await block._execute_loop_parallel(
                workflow_run_id="wr_test",
                workflow_run_block_id="wrb_test",
                workflow_run_context=parent_context,
                loop_over_values=loop_over_values,
                granted_concurrency=granted_concurrency,
            )

    @pytest.mark.asyncio
    async def test_failed_iteration_with_next_loop_on_failure_continues(self) -> None:
        """When next_loop_on_failure=True, one exception doesn't stop the loop."""
        block = _make_for_loop_block(max_concurrency=3, next_loop_on_failure=True)

        async def fake_gather(*coros_or_futures, return_exceptions=False):
            self._close_all(coros_or_futures)
            results = []
            for i in range(len(coros_or_futures)):
                if i == 1:
                    # Second iteration raises
                    results.append(RuntimeError("iteration 1 failed"))
                else:
                    results.append(self._success_tuple(i, [{"output": f"result_{i}"}]))
            return results

        result = await self._run(block, fake_gather, ["a", "b", "c"], granted_concurrency=3)

        # All 3 iterations processed: 2 success + 1 failure
        assert len(result.outputs_with_loop_values) == 3
        # Failure produces empty outputs list
        assert result.outputs_with_loop_values[1] == []

    @pytest.mark.asyncio
    async def test_failed_iteration_without_next_loop_on_failure_stops(self) -> None:
        """When next_loop_on_failure=False, first exception stops the loop early."""
        block = _make_for_loop_block(max_concurrency=3, next_loop_on_failure=False)

        async def fake_gather(*coros_or_futures, return_exceptions=False):
            self._close_all(coros_or_futures)
            results = []
            for i in range(len(coros_or_futures)):
                if i == 1:
                    # Second iteration (index 1) raises
                    results.append(RuntimeError("iteration 1 failed"))
                else:
                    results.append(self._success_tuple(i, [{"output": "ok"}]))
            return results

        result = await self._run(block, fake_gather, ["a", "b", "c"], granted_concurrency=3)

        # Stops after the batch finishes, but all 3 iteration results are
        # captured first (idx 0 success + idx 1 failure + idx 2 success).
        # asyncio.gather already ran them to completion, so dropping the
        # later successes would silently lose data the browser produced.
        assert len(result.outputs_with_loop_values) == 3

    @pytest.mark.asyncio
    async def test_all_iterations_fail_with_next_loop_on_failure(self) -> None:
        """Edge case: all iterations fail but loop continues due to next_loop_on_failure."""
        block = _make_for_loop_block(max_concurrency=3, next_loop_on_failure=True)

        async def fake_gather(*coros_or_futures, return_exceptions=False):
            self._close_all(coros_or_futures)
            return [RuntimeError("failed") for _ in coros_or_futures]

        result = await self._run(block, fake_gather, ["a", "b", "c"], granted_concurrency=3)

        # All 3 iterations produced output (all failures)
        assert len(result.outputs_with_loop_values) == 3
        # All outputs are empty lists (failure case)
        assert all(o == [] for o in result.outputs_with_loop_values)
|
||||
|
||||
|
||||
class TestWorkflowRunContextSnapshotIsolation:
    """Additional snapshot/merge isolation tests for parallel loops."""

    def _make_context(self) -> WorkflowRunContext:
        """Build a real WorkflowRunContext pre-populated with sample state."""
        context = WorkflowRunContext(
            workflow_title="Test",
            workflow_id="wf_1",
            workflow_permanent_id="wpid_1",
            workflow_run_id="wr_1",
            aws_client=MagicMock(),
        )
        context.values = {"shared": "original"}
        context.blocks_metadata = {"block_a": {"idx": 0}}
        context.workflow_run_outputs = {"output_a": "val_a"}
        context.parameters = {"p1": MagicMock()}
        context.secrets = {"s1": "secret"}
        context.organization_id = "org_1"
        return context

    def test_snapshot_deep_copies_workflow_run_outputs(self) -> None:
        """Mutating a snapshot's outputs must not leak back into the parent."""
        context = self._make_context()
        snapshot = context.create_iteration_snapshot(0)

        snapshot.workflow_run_outputs["output_b"] = "val_b"

        assert "output_b" not in context.workflow_run_outputs

    def test_multiple_snapshots_are_independent(self) -> None:
        """Two snapshots of the same parent must not see each other's writes."""
        context = self._make_context()
        first = context.create_iteration_snapshot(0)
        second = context.create_iteration_snapshot(1)

        first.values["only_in_0"] = True
        second.values["only_in_1"] = True

        assert "only_in_1" not in first.values
        assert "only_in_0" not in second.values

    def test_merge_workflow_run_outputs(self) -> None:
        """merge_iteration_results folds each snapshot's outputs into the parent."""
        context = self._make_context()
        first = context.create_iteration_snapshot(0)
        first.workflow_run_outputs["iter_0_out"] = "r0"

        second = context.create_iteration_snapshot(1)
        second.workflow_run_outputs["iter_1_out"] = "r1"

        context.merge_iteration_results([(0, first), (1, second)])

        assert context.workflow_run_outputs["iter_0_out"] == "r0"
        assert context.workflow_run_outputs["iter_1_out"] == "r1"

    def test_merge_empty_snapshots_is_safe(self) -> None:
        """Merging an empty result list leaves parent values untouched."""
        context = self._make_context()
        values_before = dict(context.values)

        context.merge_iteration_results([])

        assert context.values == values_before
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Browser cleanup verification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBrowserCleanup:
    """Verify browser cleanup is called correctly in parallel loops."""

    @pytest.mark.asyncio
    async def test_cleanup_called_for_each_batch(self) -> None:
        """cleanup_loop_iterations should be called once per batch."""
        block = _make_for_loop_block(max_concurrency=2)
        cleanup_calls: list[list[int]] = []

        async def fake_gather(*scheduled, return_exceptions=False):
            batch_results = []
            for position, coro in enumerate(scheduled):
                try:
                    # Never actually run the iteration coroutine.
                    coro.close()
                except Exception:
                    pass
                iteration_ctx = MagicMock(spec=WorkflowRunContext)
                iteration_ctx.values = {}
                iteration_ctx.blocks_metadata = {}
                iteration_ctx.workflow_run_outputs = {}
                batch_results.append((position, [], [], None, iteration_ctx))
            return batch_results

        parent_ctx = MagicMock(spec=WorkflowRunContext)
        parent_ctx.create_iteration_snapshot.return_value = MagicMock(
            spec=WorkflowRunContext,
            values={},
            blocks_metadata={},
            workflow_run_outputs={},
        )

        async def record_cleanup(workflow_run_id, loop_indices, organization_id=None):
            cleanup_calls.append(list(loop_indices))

        with (
            patch("skyvern.forge.sdk.workflow.models.block.app") as mock_app,
            patch("skyvern.forge.sdk.workflow.models.block.asyncio") as mock_asyncio,
            patch("skyvern.forge.sdk.workflow.models.block.skyvern_context") as mock_skyvern_ctx,
        ):
            mock_app.AGENT_FUNCTION = AsyncMock()
            mock_app.AGENT_FUNCTION.release_parallel_loop_quota = AsyncMock()
            mock_app.BROWSER_MANAGER = AsyncMock()
            mock_app.BROWSER_MANAGER.cleanup_loop_iterations = record_cleanup
            mock_asyncio.create_task = lambda coro: coro
            mock_asyncio.gather = fake_gather
            mock_skyvern_ctx.current.return_value = None

            await block._execute_loop_parallel(
                workflow_run_id="wr_test",
                workflow_run_block_id="wrb_test",
                workflow_run_context=parent_ctx,
                loop_over_values=["a", "b", "c", "d", "e"],
                granted_concurrency=2,
            )

        # 5 values / 2 concurrency = 3 batches: [0, 1], [2, 3], [4]
        assert cleanup_calls == [[0, 1], [2, 3], [4]]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# YAML round-trip through workflow_definition_converter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestYAMLRoundTrip:
    """Verify max_concurrency survives YAML→block conversion.

    The two tests previously duplicated the whole workflow-definition literal;
    the shared construction now lives in ``_convert``, parameterized only on
    the loop label and ``max_concurrency``.
    """

    @staticmethod
    def _convert(label: str, max_concurrency: int | None):
        """Build a one-loop v2 workflow YAML and run it through the converter.

        When ``max_concurrency`` is None the field is omitted entirely so the
        YAML model's own default applies. Returns the converted loop block
        (the only block in the definition).
        """
        from skyvern.forge.sdk.workflow.workflow_definition_converter import (
            convert_workflow_definition,
        )
        from skyvern.schemas.workflows import WorkflowParameterYAML

        loop_kwargs: dict = {}
        if max_concurrency is not None:
            loop_kwargs["max_concurrency"] = max_concurrency

        yaml_def = WorkflowDefinitionYAML(
            version=2,
            blocks=[
                ForLoopBlockYAML(
                    label=label,
                    loop_blocks=[TaskBlockYAML(label="inner_task", url="https://example.com")],
                    loop_over_parameter_key="items",
                    **loop_kwargs,
                ),
            ],
            parameters=[
                WorkflowParameterYAML(
                    key="items",
                    parameter_type="workflow",
                    workflow_parameter_type="json",
                    default_value=["a", "b"],
                ),
            ],
        )

        workflow_definition = convert_workflow_definition(yaml_def, "wf_test")
        assert len(workflow_definition.blocks) == 1
        return workflow_definition.blocks[0]

    def test_converter_preserves_max_concurrency(self) -> None:
        """An explicit max_concurrency value round-trips through the converter."""
        loop_block = self._convert("parallel_loop", 5)
        assert isinstance(loop_block, ForLoopBlock)
        assert loop_block.max_concurrency == 5

    def test_converter_preserves_none_max_concurrency(self) -> None:
        """A loop without max_concurrency converts with the field left None."""
        loop_block = self._convert("sequential_loop", None)
        assert isinstance(loop_block, ForLoopBlock)
        assert loop_block.max_concurrency is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Single iteration edge case
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSingleIterationEdgeCase:
    """With only 1 iteration, parallel path should still work correctly."""

    @pytest.mark.asyncio
    async def test_single_value_parallel(self) -> None:
        """A one-element loop runs through the parallel path and keeps its output."""
        block = _make_for_loop_block(max_concurrency=5)

        async def fake_gather(*scheduled, return_exceptions=False):
            outcomes = []
            for coro in scheduled:
                try:
                    coro.close()  # never actually execute the iteration
                except Exception:
                    pass
                iteration_ctx = MagicMock(spec=WorkflowRunContext)
                iteration_ctx.values = {"result": "single"}
                iteration_ctx.blocks_metadata = {}
                iteration_ctx.workflow_run_outputs = {}
                outcomes.append((0, [{"output": "single_result"}], [], None, iteration_ctx))
            return outcomes

        parent_ctx = MagicMock(spec=WorkflowRunContext)
        parent_ctx.create_iteration_snapshot.return_value = MagicMock(
            spec=WorkflowRunContext,
            values={},
            blocks_metadata={},
            workflow_run_outputs={},
        )

        with (
            patch("skyvern.forge.sdk.workflow.models.block.app") as mock_app,
            patch("skyvern.forge.sdk.workflow.models.block.asyncio") as mock_asyncio,
            patch("skyvern.forge.sdk.workflow.models.block.skyvern_context") as mock_skyvern_ctx,
        ):
            mock_app.AGENT_FUNCTION = AsyncMock()
            mock_app.AGENT_FUNCTION.release_parallel_loop_quota = AsyncMock()
            mock_app.BROWSER_MANAGER = AsyncMock()
            mock_asyncio.create_task = lambda coro: coro
            mock_asyncio.gather = fake_gather
            mock_skyvern_ctx.current.return_value = None

            result = await block._execute_loop_parallel(
                workflow_run_id="wr_test",
                workflow_run_block_id="wrb_test",
                workflow_run_context=parent_ctx,
                loop_over_values=["only_one"],
                granted_concurrency=5,
            )

        assert len(result.outputs_with_loop_values) == 1
        assert result.outputs_with_loop_values[0] == [{"output": "single_result"}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Iteration cap enforcement in parallel path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIterationCapParallel:
    """Verify DEFAULT_MAX_LOOP_ITERATIONS is enforced in parallel path."""

    @pytest.mark.asyncio
    async def test_values_capped_at_max_iterations(self) -> None:
        """Oversized loop input is truncated to DEFAULT_MAX_LOOP_ITERATIONS."""
        from skyvern.forge.sdk.workflow.models.block import DEFAULT_MAX_LOOP_ITERATIONS

        block = _make_for_loop_block(max_concurrency=20)
        batch_sizes: list[int] = []

        async def fake_gather(*scheduled, return_exceptions=False):
            batch_sizes.append(len(scheduled))
            outcomes = []
            for position, coro in enumerate(scheduled):
                try:
                    coro.close()  # never actually execute the iteration
                except Exception:
                    pass
                iteration_ctx = MagicMock(spec=WorkflowRunContext)
                iteration_ctx.values = {}
                iteration_ctx.blocks_metadata = {}
                iteration_ctx.workflow_run_outputs = {}
                outcomes.append((position, [], [], None, iteration_ctx))
            return outcomes

        parent_ctx = MagicMock(spec=WorkflowRunContext)
        parent_ctx.create_iteration_snapshot.return_value = MagicMock(
            spec=WorkflowRunContext,
            values={},
            blocks_metadata={},
            workflow_run_outputs={},
        )

        # More values than the cap allows.
        oversized_values = list(range(DEFAULT_MAX_LOOP_ITERATIONS + 50))

        with (
            patch("skyvern.forge.sdk.workflow.models.block.app") as mock_app,
            patch("skyvern.forge.sdk.workflow.models.block.asyncio") as mock_asyncio,
            patch("skyvern.forge.sdk.workflow.models.block.skyvern_context") as mock_skyvern_ctx,
        ):
            mock_app.AGENT_FUNCTION = AsyncMock()
            mock_app.AGENT_FUNCTION.release_parallel_loop_quota = AsyncMock()
            mock_app.BROWSER_MANAGER = AsyncMock()
            mock_asyncio.create_task = lambda coro: coro
            mock_asyncio.gather = fake_gather
            mock_skyvern_ctx.current.return_value = None

            await block._execute_loop_parallel(
                workflow_run_id="wr_test",
                workflow_run_block_id="wrb_test",
                workflow_run_context=parent_ctx,
                loop_over_values=oversized_values,
                granted_concurrency=20,
            )

        # Total scheduled across all batches equals the cap, not len(values).
        assert sum(batch_sizes) == DEFAULT_MAX_LOOP_ITERATIONS
|
||||
|
|
@ -10,6 +10,8 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.webeye.real_browser_manager import RealBrowserManager
|
||||
|
||||
|
||||
|
|
@ -131,3 +133,92 @@ async def test_non_pbs_workflow_run_inherits_parent_browser() -> None:
|
|||
# Both entries should be synced
|
||||
assert manager.pages["wfr_child"] is parent_state
|
||||
assert manager.pages["wfr_parent"] is parent_state
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_parallel_iteration_browser_returned_from_context() -> None:
    """When SkyvernContext has an iteration browser_session_id, get_or_create_for_workflow_run
    must return the pre-created iteration browser instead of creating a new one."""
    manager = RealBrowserManager()
    iteration_state = MagicMock()
    iteration_state.get_working_page = AsyncMock(return_value=MagicMock(url="https://example.com"))
    manager.pages["wr_abc__iter_0"] = iteration_state

    workflow_run = make_workflow_run("wr_abc")

    # Point the ambient context at the per-iteration browser session.
    skyvern_context.set(SkyvernContext(browser_session_id="wr_abc__iter_0"))
    try:
        with patch("skyvern.webeye.real_browser_manager.app"):
            resolved = await manager.get_or_create_for_workflow_run(
                workflow_run=workflow_run,
                url="https://example.com",
            )
        assert resolved is iteration_state
    finally:
        skyvern_context.reset()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_parallel_iteration_browser_not_used_without_context() -> None:
    """Without a parallel iteration context, normal lookup applies even if
    iteration keys exist in pages."""
    manager = RealBrowserManager()
    iteration_state = MagicMock()
    plain_state = MagicMock()
    manager.pages["wr_abc__iter_0"] = iteration_state
    manager.pages["wr_abc"] = plain_state

    workflow_run = make_workflow_run("wr_abc")

    # Ensure no SkyvernContext (and hence no iteration marker) is active.
    skyvern_context.reset()
    with patch("skyvern.webeye.real_browser_manager.app"):
        resolved = await manager.get_or_create_for_workflow_run(
            workflow_run=workflow_run,
            url="https://example.com",
        )
    assert resolved is plain_state
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_parallel_iteration_different_iterations_get_different_browsers() -> None:
    """Each parallel iteration should get its own isolated browser state."""
    manager = RealBrowserManager()
    states: dict[int, MagicMock] = {}
    for idx in (0, 1):
        state = MagicMock()
        state.get_working_page = AsyncMock(return_value=MagicMock(url="https://example.com"))
        manager.pages[f"wr_abc__iter_{idx}"] = state
        states[idx] = state

    workflow_run = make_workflow_run("wr_abc")

    async def resolve_for_iteration(idx: int):
        """Resolve the browser under a context scoped to iteration ``idx``."""
        skyvern_context.set(SkyvernContext(browser_session_id=f"wr_abc__iter_{idx}"))
        try:
            with patch("skyvern.webeye.real_browser_manager.app"):
                return await manager.get_or_create_for_workflow_run(
                    workflow_run=workflow_run,
                    url="https://example.com",
                )
        finally:
            skyvern_context.reset()

    result0 = await resolve_for_iteration(0)
    result1 = await resolve_for_iteration(1)

    assert result0 is states[0]
    assert result1 is states[1]
    assert result0 is not result1
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue