mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2025-09-10 15:35:51 +00:00
generate GeneratedWorkflowParameters (#3264)
This commit is contained in:
parent
988416829f
commit
2a62dc08aa
7 changed files with 504 additions and 64 deletions
|
@ -25,6 +25,10 @@ import structlog
|
|||
from libcst import Attribute, Call, Dict, DictElement, FunctionDef, Name, Param
|
||||
|
||||
from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS
|
||||
from skyvern.core.script_generations.generate_workflow_parameters import (
|
||||
generate_workflow_parameters_schema,
|
||||
hydrate_input_text_actions_with_field_names,
|
||||
)
|
||||
from skyvern.forge import app
|
||||
from skyvern.webeye.actions.action_types import ActionType
|
||||
|
||||
|
@ -61,6 +65,7 @@ ACTIONS_WITH_XPATH = [
|
|||
]
|
||||
|
||||
INDENT = " " * 4
|
||||
DOUBLE_INDENT = " " * 8
|
||||
|
||||
|
||||
def _safe_name(label: str) -> str:
|
||||
|
@ -97,6 +102,57 @@ def _value(value: Any) -> cst.BaseExpression:
|
|||
return cst.SimpleString(repr(str(value)))
|
||||
|
||||
|
||||
def _generate_text_call(text_value: str, intention: str, parameter_key: str) -> cst.BaseExpression:
|
||||
"""Create a generate_text function call CST expression."""
|
||||
return cst.Await(
|
||||
expression=cst.Call(
|
||||
func=cst.Name("generate_text"),
|
||||
whitespace_before_args=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
last_line=cst.SimpleWhitespace(DOUBLE_INDENT),
|
||||
),
|
||||
args=[
|
||||
# First positional argument: context.generated_parameters['parameter_key']
|
||||
cst.Arg(
|
||||
value=cst.Subscript(
|
||||
value=cst.Attribute(
|
||||
value=cst.Name("context"),
|
||||
attr=cst.Name("generated_parameters"),
|
||||
),
|
||||
slice=[cst.SubscriptElement(slice=cst.Index(value=_value(parameter_key)))],
|
||||
),
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
last_line=cst.SimpleWhitespace(DOUBLE_INDENT),
|
||||
),
|
||||
),
|
||||
# intention keyword argument
|
||||
cst.Arg(
|
||||
keyword=cst.Name("intention"),
|
||||
value=_value(intention),
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
last_line=cst.SimpleWhitespace(DOUBLE_INDENT),
|
||||
),
|
||||
),
|
||||
# data keyword argument
|
||||
cst.Arg(
|
||||
keyword=cst.Name("data"),
|
||||
value=cst.Attribute(
|
||||
value=cst.Name("context"),
|
||||
attr=cst.Name("parameters"),
|
||||
),
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
last_line=cst.SimpleWhitespace(INDENT),
|
||||
),
|
||||
comma=cst.Comma(),
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------- #
|
||||
# 2. utility builders #
|
||||
# --------------------------------------------------------------------- #
|
||||
|
@ -177,10 +233,21 @@ def _action_to_stmt(act: dict[str, Any], assign_to_output: bool = False) -> cst.
|
|||
)
|
||||
|
||||
if method in ["type", "fill"]:
|
||||
# Get intention from action
|
||||
intention = act.get("intention") or act.get("reasoning") or ""
|
||||
|
||||
# Use generate_text call if field_name is available, otherwise fallback to direct value
|
||||
if act.get("field_name"):
|
||||
text_value = _generate_text_call(
|
||||
text_value=act["text"], intention=intention, parameter_key=act["field_name"]
|
||||
)
|
||||
else:
|
||||
text_value = _value(act["text"])
|
||||
|
||||
args.append(
|
||||
cst.Arg(
|
||||
keyword=cst.Name("text"),
|
||||
value=_value(act["text"]),
|
||||
value=text_value,
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
last_line=cst.SimpleWhitespace(INDENT),
|
||||
|
@ -212,7 +279,7 @@ def _action_to_stmt(act: dict[str, Any], assign_to_output: bool = False) -> cst.
|
|||
elif method == "extract":
|
||||
args.append(
|
||||
cst.Arg(
|
||||
keyword=cst.Name("data_extraction_goal"),
|
||||
keyword=cst.Name("prompt"),
|
||||
value=_value(act["data_extraction_goal"]),
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
|
@ -309,8 +376,8 @@ def _build_block_fn(block: dict[str, Any], actions: list[dict[str, Any]]) -> Fun
|
|||
def _build_model(workflow: dict[str, Any]) -> cst.ClassDef:
|
||||
"""
|
||||
class WorkflowParameters(BaseModel):
|
||||
ein_info: str
|
||||
company_name: str
|
||||
param1: str
|
||||
param2: str
|
||||
...
|
||||
"""
|
||||
ann_lines: list[cst.BaseStatement] = []
|
||||
|
@ -319,7 +386,6 @@ def _build_model(workflow: dict[str, Any]) -> cst.ClassDef:
|
|||
if p["parameter_type"] != "workflow":
|
||||
continue
|
||||
|
||||
# ein_info: str
|
||||
ann = cst.AnnAssign(
|
||||
target=cst.Name(p["key"]),
|
||||
annotation=cst.Annotation(cst.Name("str")),
|
||||
|
@ -337,21 +403,24 @@ def _build_model(workflow: dict[str, Any]) -> cst.ClassDef:
|
|||
)
|
||||
|
||||
|
||||
def _build_cached_params(values: dict[str, Any]) -> cst.SimpleStatementLine:
|
||||
def _build_generated_model_from_schema(schema_code: str) -> cst.ClassDef | None:
|
||||
"""
|
||||
Make a CST for:
|
||||
cached_parameters = WorkflowParameters(ein_info="...", ...)
|
||||
Parse the generated schema code and return a ClassDef, or None if parsing fails.
|
||||
"""
|
||||
call = cst.Call(
|
||||
func=cst.Name("WorkflowParameters"),
|
||||
args=[cst.Arg(keyword=cst.Name(k), value=_value(v)) for k, v in values.items()],
|
||||
)
|
||||
try:
|
||||
# Parse the schema code and extract just the class definition
|
||||
parsed_module = cst.parse_module(schema_code)
|
||||
|
||||
assign = cst.Assign(
|
||||
targets=[cst.AssignTarget(cst.Name("cached_parameters"))],
|
||||
value=call,
|
||||
)
|
||||
return cst.SimpleStatementLine([assign])
|
||||
# Find the GeneratedWorkflowParameters class in the parsed module
|
||||
for node in parsed_module.body:
|
||||
if isinstance(node, cst.ClassDef) and node.name.value == "GeneratedWorkflowParameters":
|
||||
return node
|
||||
|
||||
# If no class found, return None
|
||||
return None
|
||||
except Exception as e:
|
||||
LOG.warning("Failed to parse generated schema code", error=str(e))
|
||||
return None
|
||||
|
||||
|
||||
# --------------------------------------------------------------------- #
|
||||
|
@ -804,7 +873,7 @@ def _build_run_fn(blocks: list[dict[str, Any]], wf_req: dict[str, Any]) -> Funct
|
|||
cst.parse_statement(
|
||||
"parameters = parameters.model_dump() if isinstance(parameters, WorkflowParameters) else parameters"
|
||||
),
|
||||
cst.parse_statement("page, context = await skyvern.setup(parameters)"),
|
||||
cst.parse_statement("page, context = await skyvern.setup(parameters, GeneratedWorkflowParameters)"),
|
||||
]
|
||||
|
||||
for block in blocks:
|
||||
|
@ -867,8 +936,27 @@ def _build_run_fn(blocks: list[dict[str, Any]], wf_req: dict[str, Any]) -> Funct
|
|||
params=[
|
||||
Param(
|
||||
name=cst.Name("parameters"),
|
||||
annotation=cst.Annotation(cst.Name("WorkflowParameters")),
|
||||
default=cst.Name("cached_parameters"),
|
||||
annotation=cst.Annotation(
|
||||
cst.BinaryOperation(
|
||||
left=cst.Name("WorkflowParameters"),
|
||||
operator=cst.BitOr(
|
||||
whitespace_before=cst.SimpleWhitespace(" "),
|
||||
whitespace_after=cst.SimpleWhitespace(" "),
|
||||
),
|
||||
right=cst.Subscript(
|
||||
value=cst.Name("dict"),
|
||||
slice=[
|
||||
cst.SubscriptElement(
|
||||
slice=cst.Index(value=cst.Name("str")),
|
||||
comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
|
||||
),
|
||||
cst.SubscriptElement(
|
||||
slice=cst.Index(value=cst.Name("Any")),
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
),
|
||||
whitespace_after_param=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
last_line=cst.SimpleWhitespace(INDENT),
|
||||
|
@ -948,11 +1036,24 @@ async def generate_workflow_script(
|
|||
imports: list[cst.BaseStatement] = [
|
||||
cst.SimpleStatementLine([cst.Import(names=[cst.ImportAlias(cst.Name("asyncio"))])]),
|
||||
cst.SimpleStatementLine([cst.Import(names=[cst.ImportAlias(cst.Name("pydantic"))])]),
|
||||
cst.SimpleStatementLine(
|
||||
[
|
||||
cst.ImportFrom(
|
||||
module=cst.Name("typing"),
|
||||
names=[
|
||||
cst.ImportAlias(cst.Name("Any")),
|
||||
],
|
||||
)
|
||||
]
|
||||
),
|
||||
cst.SimpleStatementLine(
|
||||
[
|
||||
cst.ImportFrom(
|
||||
module=cst.Name("pydantic"),
|
||||
names=[cst.ImportAlias(cst.Name("BaseModel"))],
|
||||
names=[
|
||||
cst.ImportAlias(cst.Name("BaseModel")),
|
||||
cst.ImportAlias(cst.Name("Field")),
|
||||
],
|
||||
)
|
||||
]
|
||||
),
|
||||
|
@ -964,15 +1065,20 @@ async def generate_workflow_script(
|
|||
names=[
|
||||
cst.ImportAlias(cst.Name("RunContext")),
|
||||
cst.ImportAlias(cst.Name("SkyvernPage")),
|
||||
cst.ImportAlias(cst.Name("generate_text")),
|
||||
],
|
||||
)
|
||||
]
|
||||
),
|
||||
]
|
||||
|
||||
# --- generate schema and hydrate actions ---------------------------
|
||||
generated_schema, field_mappings = await generate_workflow_parameters_schema(actions_by_task)
|
||||
actions_by_task = hydrate_input_text_actions_with_field_names(actions_by_task, field_mappings)
|
||||
|
||||
# --- class + cached params -----------------------------------------
|
||||
model_cls = _build_model(workflow)
|
||||
cached_params_stmt = _build_cached_params(workflow_run_request.get("parameters", {}))
|
||||
generated_model_cls = _build_generated_model_from_schema(generated_schema)
|
||||
|
||||
# --- blocks ---------------------------------------------------------
|
||||
block_fns = []
|
||||
|
@ -1008,17 +1114,29 @@ async def generate_workflow_script(
|
|||
# --- runner ---------------------------------------------------------
|
||||
run_fn = _build_run_fn(blocks, workflow_run_request)
|
||||
|
||||
module = cst.Module(
|
||||
body=[
|
||||
*imports,
|
||||
cst.EmptyLine(),
|
||||
cst.EmptyLine(),
|
||||
model_cls,
|
||||
cst.EmptyLine(),
|
||||
cst.EmptyLine(),
|
||||
cached_params_stmt,
|
||||
cst.EmptyLine(),
|
||||
cst.EmptyLine(),
|
||||
# Build module body with optional generated model class
|
||||
module_body = [
|
||||
*imports,
|
||||
cst.EmptyLine(),
|
||||
cst.EmptyLine(),
|
||||
model_cls,
|
||||
cst.EmptyLine(),
|
||||
cst.EmptyLine(),
|
||||
]
|
||||
|
||||
# Add generated model class if available
|
||||
if generated_model_cls:
|
||||
module_body.extend(
|
||||
[
|
||||
generated_model_cls,
|
||||
cst.EmptyLine(),
|
||||
cst.EmptyLine(),
|
||||
]
|
||||
)
|
||||
|
||||
# Continue with the rest of the module
|
||||
module_body.extend(
|
||||
[
|
||||
*block_fns,
|
||||
cst.EmptyLine(),
|
||||
cst.EmptyLine(),
|
||||
|
@ -1029,6 +1147,8 @@ async def generate_workflow_script(
|
|||
]
|
||||
)
|
||||
|
||||
module = cst.Module(body=module_body)
|
||||
|
||||
with open(file_name, "w") as f:
|
||||
f.write(module.code)
|
||||
return module.code
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue