generate GeneratedWorkflowParameters (#3264)

This commit is contained in:
Shuchang Zheng 2025-08-21 15:42:34 -07:00 committed by GitHub
parent 988416829f
commit 2a62dc08aa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 504 additions and 64 deletions

View file

@ -25,6 +25,10 @@ import structlog
from libcst import Attribute, Call, Dict, DictElement, FunctionDef, Name, Param
from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS
from skyvern.core.script_generations.generate_workflow_parameters import (
generate_workflow_parameters_schema,
hydrate_input_text_actions_with_field_names,
)
from skyvern.forge import app
from skyvern.webeye.actions.action_types import ActionType
@ -61,6 +65,7 @@ ACTIONS_WITH_XPATH = [
]
INDENT = " " * 4
DOUBLE_INDENT = " " * 8
def _safe_name(label: str) -> str:
@ -97,6 +102,57 @@ def _value(value: Any) -> cst.BaseExpression:
return cst.SimpleString(repr(str(value)))
def _generate_text_call(text_value: str, intention: str, parameter_key: str) -> cst.BaseExpression:
"""Create a generate_text function call CST expression."""
return cst.Await(
expression=cst.Call(
func=cst.Name("generate_text"),
whitespace_before_args=cst.ParenthesizedWhitespace(
indent=True,
last_line=cst.SimpleWhitespace(DOUBLE_INDENT),
),
args=[
# First positional argument: context.generated_parameters['parameter_key']
cst.Arg(
value=cst.Subscript(
value=cst.Attribute(
value=cst.Name("context"),
attr=cst.Name("generated_parameters"),
),
slice=[cst.SubscriptElement(slice=cst.Index(value=_value(parameter_key)))],
),
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
last_line=cst.SimpleWhitespace(DOUBLE_INDENT),
),
),
# intention keyword argument
cst.Arg(
keyword=cst.Name("intention"),
value=_value(intention),
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
last_line=cst.SimpleWhitespace(DOUBLE_INDENT),
),
),
# data keyword argument
cst.Arg(
keyword=cst.Name("data"),
value=cst.Attribute(
value=cst.Name("context"),
attr=cst.Name("parameters"),
),
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
last_line=cst.SimpleWhitespace(INDENT),
),
comma=cst.Comma(),
),
],
)
)
# --------------------------------------------------------------------- #
# 2. utility builders #
# --------------------------------------------------------------------- #
@ -177,10 +233,21 @@ def _action_to_stmt(act: dict[str, Any], assign_to_output: bool = False) -> cst.
)
if method in ["type", "fill"]:
# Get intention from action
intention = act.get("intention") or act.get("reasoning") or ""
# Use generate_text call if field_name is available, otherwise fallback to direct value
if act.get("field_name"):
text_value = _generate_text_call(
text_value=act["text"], intention=intention, parameter_key=act["field_name"]
)
else:
text_value = _value(act["text"])
args.append(
cst.Arg(
keyword=cst.Name("text"),
value=_value(act["text"]),
value=text_value,
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
last_line=cst.SimpleWhitespace(INDENT),
@ -212,7 +279,7 @@ def _action_to_stmt(act: dict[str, Any], assign_to_output: bool = False) -> cst.
elif method == "extract":
args.append(
cst.Arg(
keyword=cst.Name("data_extraction_goal"),
keyword=cst.Name("prompt"),
value=_value(act["data_extraction_goal"]),
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
@ -309,8 +376,8 @@ def _build_block_fn(block: dict[str, Any], actions: list[dict[str, Any]]) -> Fun
def _build_model(workflow: dict[str, Any]) -> cst.ClassDef:
"""
class WorkflowParameters(BaseModel):
ein_info: str
company_name: str
param1: str
param2: str
...
"""
ann_lines: list[cst.BaseStatement] = []
@ -319,7 +386,6 @@ def _build_model(workflow: dict[str, Any]) -> cst.ClassDef:
if p["parameter_type"] != "workflow":
continue
# ein_info: str
ann = cst.AnnAssign(
target=cst.Name(p["key"]),
annotation=cst.Annotation(cst.Name("str")),
@ -337,21 +403,24 @@ def _build_model(workflow: dict[str, Any]) -> cst.ClassDef:
)
def _build_cached_params(values: dict[str, Any]) -> cst.SimpleStatementLine:
def _build_generated_model_from_schema(schema_code: str) -> cst.ClassDef | None:
"""
Make a CST for:
cached_parameters = WorkflowParameters(ein_info="...", ...)
Parse the generated schema code and return a ClassDef, or None if parsing fails.
"""
call = cst.Call(
func=cst.Name("WorkflowParameters"),
args=[cst.Arg(keyword=cst.Name(k), value=_value(v)) for k, v in values.items()],
)
try:
# Parse the schema code and extract just the class definition
parsed_module = cst.parse_module(schema_code)
assign = cst.Assign(
targets=[cst.AssignTarget(cst.Name("cached_parameters"))],
value=call,
)
return cst.SimpleStatementLine([assign])
# Find the GeneratedWorkflowParameters class in the parsed module
for node in parsed_module.body:
if isinstance(node, cst.ClassDef) and node.name.value == "GeneratedWorkflowParameters":
return node
# If no class found, return None
return None
except Exception as e:
LOG.warning("Failed to parse generated schema code", error=str(e))
return None
# --------------------------------------------------------------------- #
@ -804,7 +873,7 @@ def _build_run_fn(blocks: list[dict[str, Any]], wf_req: dict[str, Any]) -> Funct
cst.parse_statement(
"parameters = parameters.model_dump() if isinstance(parameters, WorkflowParameters) else parameters"
),
cst.parse_statement("page, context = await skyvern.setup(parameters)"),
cst.parse_statement("page, context = await skyvern.setup(parameters, GeneratedWorkflowParameters)"),
]
for block in blocks:
@ -867,8 +936,27 @@ def _build_run_fn(blocks: list[dict[str, Any]], wf_req: dict[str, Any]) -> Funct
params=[
Param(
name=cst.Name("parameters"),
annotation=cst.Annotation(cst.Name("WorkflowParameters")),
default=cst.Name("cached_parameters"),
annotation=cst.Annotation(
cst.BinaryOperation(
left=cst.Name("WorkflowParameters"),
operator=cst.BitOr(
whitespace_before=cst.SimpleWhitespace(" "),
whitespace_after=cst.SimpleWhitespace(" "),
),
right=cst.Subscript(
value=cst.Name("dict"),
slice=[
cst.SubscriptElement(
slice=cst.Index(value=cst.Name("str")),
comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
),
cst.SubscriptElement(
slice=cst.Index(value=cst.Name("Any")),
),
],
),
)
),
whitespace_after_param=cst.ParenthesizedWhitespace(
indent=True,
last_line=cst.SimpleWhitespace(INDENT),
@ -948,11 +1036,24 @@ async def generate_workflow_script(
imports: list[cst.BaseStatement] = [
cst.SimpleStatementLine([cst.Import(names=[cst.ImportAlias(cst.Name("asyncio"))])]),
cst.SimpleStatementLine([cst.Import(names=[cst.ImportAlias(cst.Name("pydantic"))])]),
cst.SimpleStatementLine(
[
cst.ImportFrom(
module=cst.Name("typing"),
names=[
cst.ImportAlias(cst.Name("Any")),
],
)
]
),
cst.SimpleStatementLine(
[
cst.ImportFrom(
module=cst.Name("pydantic"),
names=[cst.ImportAlias(cst.Name("BaseModel"))],
names=[
cst.ImportAlias(cst.Name("BaseModel")),
cst.ImportAlias(cst.Name("Field")),
],
)
]
),
@ -964,15 +1065,20 @@ async def generate_workflow_script(
names=[
cst.ImportAlias(cst.Name("RunContext")),
cst.ImportAlias(cst.Name("SkyvernPage")),
cst.ImportAlias(cst.Name("generate_text")),
],
)
]
),
]
# --- generate schema and hydrate actions ---------------------------
generated_schema, field_mappings = await generate_workflow_parameters_schema(actions_by_task)
actions_by_task = hydrate_input_text_actions_with_field_names(actions_by_task, field_mappings)
# --- class + cached params -----------------------------------------
model_cls = _build_model(workflow)
cached_params_stmt = _build_cached_params(workflow_run_request.get("parameters", {}))
generated_model_cls = _build_generated_model_from_schema(generated_schema)
# --- blocks ---------------------------------------------------------
block_fns = []
@ -1008,17 +1114,29 @@ async def generate_workflow_script(
# --- runner ---------------------------------------------------------
run_fn = _build_run_fn(blocks, workflow_run_request)
module = cst.Module(
body=[
*imports,
cst.EmptyLine(),
cst.EmptyLine(),
model_cls,
cst.EmptyLine(),
cst.EmptyLine(),
cached_params_stmt,
cst.EmptyLine(),
cst.EmptyLine(),
# Build module body with optional generated model class
module_body = [
*imports,
cst.EmptyLine(),
cst.EmptyLine(),
model_cls,
cst.EmptyLine(),
cst.EmptyLine(),
]
# Add generated model class if available
if generated_model_cls:
module_body.extend(
[
generated_model_cls,
cst.EmptyLine(),
cst.EmptyLine(),
]
)
# Continue with the rest of the module
module_body.extend(
[
*block_fns,
cst.EmptyLine(),
cst.EmptyLine(),
@ -1029,6 +1147,8 @@ async def generate_workflow_script(
]
)
module = cst.Module(body=module_body)
with open(file_name, "w") as f:
f.write(module.code)
return module.code