Workflow CodeGen (#2740)

This commit is contained in:
Shuchang Zheng 2025-06-18 00:44:46 -07:00 committed by GitHub
parent 14bc711240
commit f6a0ccd32b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 565 additions and 51 deletions

View file

@ -1,18 +1,15 @@
# skyvern_codegen_cst.py
"""
Generate a runnable Skyvern workflow script **with LibCST**.
Generate a runnable Skyvern workflow script.
Example
-------
from skyvern_codegen_cst import generate_workflow_script
src = generate_workflow_script(
workflow=workflow_dict,
tasks=[task1, task2, ...],
actions_by_task={
task1["task_id"]: task1_actions,
task2["task_id"]: task2_actions,
},
generated_code = generate_workflow_script(
file_name="workflow.py",
workflow_run_request=workflow_run_request,
workflow=workflow,
tasks=tasks,
actions_by_task=actions_by_task,
)
Path("workflow.py").write_text(generated_code)
"""
@ -20,11 +17,14 @@ Path("workflow.py").write_text(src)
from __future__ import annotations
import keyword
from typing import Any, Iterable, Mapping
from enum import StrEnum
from typing import Any
import libcst as cst
from libcst import Attribute, Call, Dict, DictElement, FunctionDef, Name, Param
from skyvern.webeye.actions.action_types import ActionType
# --------------------------------------------------------------------- #
# 1. helpers #
# --------------------------------------------------------------------- #
@ -42,6 +42,8 @@ ACTION_MAP = {
"drag": "drag",
"solve_captcha": "solve_captcha",
"verification_code": "verification_code",
"wait": "wait",
"extract": "extract",
}
INDENT = " " * 4
@ -86,12 +88,47 @@ def _value(value: Any) -> cst.BaseExpression:
# --------------------------------------------------------------------- #
def _make_decorator(block: Mapping[str, Any]) -> cst.Decorator:
def _workflow_decorator(wf_req: dict[str, Any]) -> cst.Decorator:
    """
    Build the ``@skyvern.workflow(...)`` decorator node.

    The keyword arguments are read from ``wf_req`` and emitted in this order:
    ``title``, ``totp_url``, ``totp_identifier``, ``webhook_url``, ``max_steps``.
    ``title`` falls back to ``""`` when the key is absent; any entry whose
    value is ``None`` is left out of the generated call entirely.
    """
    # (keyword, value) pairs, listed in the order they appear in the output.
    candidates = (
        ("title", wf_req.get("title", "")),
        ("totp_url", wf_req.get("totp_url")),
        ("totp_identifier", wf_req.get("totp_identifier")),
        ("webhook_url", wf_req.get("webhook_url")),
        ("max_steps", wf_req.get("max_steps")),
    )

    # Dropping None-valued entries keeps the emitted decorator concise.
    decorator_args: list = [
        cst.Arg(keyword=cst.Name(arg_name), value=_value(arg_value))
        for arg_name, arg_value in candidates
        if arg_value is not None
    ]

    workflow_attr = cst.Attribute(value=cst.Name("skyvern"), attr=cst.Name("workflow"))
    return cst.Decorator(decorator=cst.Call(func=workflow_attr, args=decorator_args))
def _make_decorator(block: dict[str, Any]) -> cst.Decorator:
bt = block["block_type"]
deco_name = {
"task": "task_block",
"file_download": "file_download_block",
"send_email": "email_block",
"wait": "wait_block",
}[bt]
kwargs = []
@ -104,12 +141,18 @@ def _make_decorator(block: Mapping[str, Any]) -> cst.Decorator:
"totp_identifier": "totp_identifier",
"webhook_callback_url": "webhook_callback_url",
"max_steps_per_run": "max_steps",
"wait_sec": "seconds",
}
for src_key, kw in field_map.items():
v = block.get(src_key)
if v not in (None, "", [], {}):
kwargs.append(cst.Arg(value=_value(v), keyword=Name(kw)))
if isinstance(v, StrEnum):
v = v.value
try:
kwargs.append(cst.Arg(value=_value(v), keyword=Name(kw)))
except Exception:
raise
# booleans
if block.get("complete_on_download"):
@ -125,7 +168,7 @@ def _make_decorator(block: Mapping[str, Any]) -> cst.Decorator:
)
def _action_to_stmt(act: Mapping[str, Any]) -> cst.BaseStatement:
def _action_to_stmt(act: dict[str, Any]) -> cst.BaseStatement:
"""
Turn one Action dict into:
@ -157,14 +200,16 @@ def _action_to_stmt(act: Mapping[str, Any]) -> cst.BaseStatement:
return cst.SimpleStatementLine([cst.Expr(await_expr)])
def _build_block_fn(block: Mapping[str, Any], actions: Iterable[Mapping[str, Any]]) -> FunctionDef:
name = _safe_name(block["title"])
def _build_block_fn(block: dict[str, Any], actions: list[dict[str, Any]]) -> FunctionDef:
name = _safe_name(block.get("title") or block.get("label") or f"block_{block.get('workflow_run_block_id')}")
body_stmts: list[cst.BaseStatement] = []
if block.get("url"):
body_stmts.append(cst.parse_statement(f"await page.goto({repr(block['url'])})"))
for act in actions:
if act["action_type"] in [ActionType.COMPLETE]:
continue
body_stmts.append(_action_to_stmt(act))
if not body_stmts:
@ -174,8 +219,8 @@ def _build_block_fn(block: Mapping[str, Any], actions: Iterable[Mapping[str, Any
name=Name(name),
params=cst.Parameters(
params=[
Param(name=Name("page")),
Param(name=Name("context")),
Param(name=Name("page"), annotation=cst.Annotation(cst.Name("SkyvernPage"))),
Param(name=Name("context"), annotation=cst.Annotation(cst.Name("RunContext"))),
]
),
decorators=[_make_decorator(block)],
@ -185,7 +230,7 @@ def _build_block_fn(block: Mapping[str, Any], actions: Iterable[Mapping[str, Any
)
def _build_model(workflow: Mapping[str, Any]) -> cst.ClassDef:
def _build_model(workflow: dict[str, Any]) -> cst.ClassDef:
"""
class WorkflowParameters(BaseModel):
ein_info: str
@ -216,31 +261,65 @@ def _build_model(workflow: Mapping[str, Any]) -> cst.ClassDef:
)
def _build_cached_params() -> cst.SimpleStatementLine:
src = "cached_parameters = WorkflowParameters(**{k: f'<{k}>' for k in WorkflowParameters.model_fields})"
return cst.parse_statement(src)
def _build_cached_params(values: dict[str, Any]) -> cst.SimpleStatementLine:
    """
    Build the CST for the module-level assignment::

        cached_parameters = WorkflowParameters(<key>=<value>, ...)

    One keyword argument is emitted per entry of ``values``, in the dict's
    iteration order; each value is converted to an expression via ``_value``.
    """
    keyword_args = []
    for param_name, param_value in values.items():
        keyword_args.append(cst.Arg(keyword=cst.Name(param_name), value=_value(param_value)))

    constructor_call = cst.Call(func=cst.Name("WorkflowParameters"), args=keyword_args)
    target = cst.AssignTarget(cst.Name("cached_parameters"))
    return cst.SimpleStatementLine([cst.Assign(targets=[target], value=constructor_call)])
def _build_run_fn(task_fns: list[str]) -> FunctionDef:
body = [cst.parse_statement("page, context = await skyvern.setup(parameters.model_dump())")] + [
cst.parse_statement(f"await {_safe_name(t)}(page, context)") for t in task_fns
def _build_run_fn(task_titles: list[str], wf_req: dict[str, Any]) -> FunctionDef:
body = [
cst.parse_statement("page, context = await skyvern.setup(parameters.model_dump())"),
*[cst.parse_statement(f"await {_safe_name(t)}(page, context)") for t in task_titles],
]
params = cst.Parameters(
params=[
Param(
name=cst.Name("parameters"),
annotation=cst.Annotation(cst.Name("WorkflowParameters")),
default=cst.Name("cached_parameters"),
),
Param(
name=cst.Name("title"),
annotation=cst.Annotation(cst.Name("str")),
default=_value(wf_req.get("title", "")),
),
Param(
name=cst.Name("webhook_url"),
annotation=cst.Annotation(cst.parse_expression("str | None")),
default=_value(wf_req.get("webhook_url")),
),
Param(
name=cst.Name("totp_url"),
annotation=cst.Annotation(cst.parse_expression("str | None")),
default=_value(wf_req.get("totp_url")),
),
Param(
name=cst.Name("totp_identifier"),
annotation=cst.Annotation(cst.parse_expression("str | None")),
default=_value(wf_req.get("totp_identifier")),
),
]
)
return FunctionDef(
name=Name("run_workflow"),
decorators=[cst.Decorator(Attribute(value=Name("skyvern"), attr=Name("workflow")))],
params=cst.Parameters(
params=[
Param(
name=Name("parameters"),
default=Name("cached_parameters"),
annotation=cst.Annotation(Name("WorkflowParameters")),
)
]
),
body=cst.IndentedBlock(body),
returns=None,
name=cst.Name("run_workflow"),
asynchronous=cst.Asynchronous(),
decorators=[_workflow_decorator(wf_req)],
params=params,
body=cst.IndentedBlock(body),
)
@ -251,9 +330,11 @@ def _build_run_fn(task_fns: list[str]) -> FunctionDef:
def generate_workflow_script(
*,
workflow: Mapping[str, Any],
tasks: Iterable[Mapping[str, Any]],
actions_by_task: Mapping[str, Iterable[Mapping[str, Any]]],
file_name: str,
workflow_run_request: dict[str, Any],
workflow: dict[str, Any],
tasks: list[dict[str, Any]],
actions_by_task: dict[str, list[dict[str, Any]]],
) -> str:
"""
Build a LibCST Module and emit .code (PEP-8-formatted source).
@ -285,31 +366,41 @@ def generate_workflow_script(
# --- class + cached params -----------------------------------------
model_cls = _build_model(workflow)
cached_params_stmt = _build_cached_params()
cached_params_stmt = _build_cached_params(workflow_run_request.get("parameters", {}))
# --- blocks ---------------------------------------------------------
block_fns: list[FunctionDef] = []
task_titles = []
for t in tasks:
fn = _build_block_fn(t, actions_by_task.get(t["task_id"], []))
block_fns.append(fn)
task_titles.append(t["title"])
block_fns = []
length_of_tasks = len(tasks)
for idx, task in enumerate(tasks):
block_fns.append(_build_block_fn(task, actions_by_task.get(task.get("task_id", ""), [])))
if idx < length_of_tasks - 1:
block_fns.append(cst.EmptyLine())
block_fns.append(cst.EmptyLine())
task_titles: list[str] = [
t.get("title") or t.get("label") or t.get("task_id") or f"unknown_title_{idx}" for idx, t in enumerate(tasks)
]
# --- runner ---------------------------------------------------------
run_fn = _build_run_fn(task_titles)
run_fn = _build_run_fn(task_titles, workflow_run_request)
module = cst.Module(
body=[
*imports,
cst.EmptyLine(),
cst.EmptyLine(),
model_cls,
cst.EmptyLine(),
cst.EmptyLine(),
cached_params_stmt,
cst.EmptyLine(),
cst.EmptyLine(),
*block_fns,
cst.EmptyLine(),
run_fn,
cst.EmptyLine(),
run_fn,
]
)
with open(file_name, "w") as f:
f.write(module.code)
return module.code