Mirror of https://github.com/Skyvern-AI/skyvern.git, synced 2025-09-16 10:19:42 +00:00
workflow script creation (#3151)

parent 136fa70c48
commit 16596e5c61

4 changed files with 187 additions and 5 deletions
@@ -129,6 +129,7 @@ def _make_decorator(block: dict[str, Any]) -> cst.Decorator:
         "file_download": "file_download_block",
         "send_email": "email_block",
         "wait": "wait_block",
+        "navigation": "navigation_block",
     }[bt]
 
     kwargs = []
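For context: the table above maps a workflow block type to the name of the decorator emitted in generated scripts, and an unrecognized block type fails fast with a KeyError. A minimal standalone sketch of the lookup (the function name and free-standing framing are hypothetical; the real _make_decorator wraps the result in a libcst Decorator node):

    # Hypothetical standalone version of the block-type -> decorator-name lookup.
    def decorator_name_for(bt: str) -> str:
        return {
            "file_download": "file_download_block",
            "send_email": "email_block",
            "wait": "wait_block",
            "navigation": "navigation_block",  # added in this commit
        }[bt]  # any other block type raises KeyError

    assert decorator_name_for("navigation") == "navigation_block"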
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Any, Iterable, Mapping
+from typing import Any
 
 import structlog
 
@@ -13,10 +13,10 @@ LOG = structlog.get_logger(__name__)
 @dataclass
 class CodeGenInput:
     file_name: str
-    workflow_run: Mapping[str, Any]
-    workflow: Mapping[str, Any]
-    workflow_blocks: Iterable[Mapping[str, Any]]
-    actions_by_task: Mapping[str, Iterable[Mapping[str, Any]]]
+    workflow_run: dict[str, Any]
+    workflow: dict[str, Any]
+    workflow_blocks: list[dict[str, Any]]
+    actions_by_task: dict[str, list[dict[str, Any]]]
 
 
 async def transform_workflow_run_to_code_gen_input(workflow_run_id: str, organization_id: str) -> CodeGenInput:
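Narrowing the annotations from Mapping/Iterable to concrete dict/list matches what the transform actually builds and what the code generator consumes downstream. A construction sketch, assuming placeholder field values (the values here are invented for illustration):

    from dataclasses import dataclass
    from typing import Any

    @dataclass
    class CodeGenInput:  # mirrors the declaration above
        file_name: str
        workflow_run: dict[str, Any]
        workflow: dict[str, Any]
        workflow_blocks: list[dict[str, Any]]
        actions_by_task: dict[str, list[dict[str, Any]]]

    example = CodeGenInput(
        file_name="main.py",
        workflow_run={"workflow_run_id": "wr_123"},              # placeholder
        workflow={"title": "example workflow"},                  # placeholder
        workflow_blocks=[{"block_type": "navigation"}],          # placeholder
        actions_by_task={"task_1": [{"action_type": "click"}]},  # placeholder
    )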
@@ -45,6 +45,7 @@ from skyvern.forge.sdk.db.models import (
     WorkflowRunModel,
     WorkflowRunOutputParameterModel,
     WorkflowRunParameterModel,
+    WorkflowScriptModel,
 )
 from skyvern.forge.sdk.db.utils import (
     _custom_json_serializer,
@@ -3807,3 +3808,83 @@ class AgentDB:
                 )
             ).all()
             return [convert_to_script_block(record) for record in records]
+
+    async def create_workflow_script(
+        self,
+        *,
+        organization_id: str,
+        script_id: str,
+        workflow_permanent_id: str,
+        cache_key: str,
+        cache_key_value: str,
+        workflow_id: str | None = None,
+        workflow_run_id: str | None = None,
+    ) -> None:
+        """Create a workflow->script cache mapping entry."""
+        try:
+            async with self.Session() as session:
+                record = WorkflowScriptModel(
+                    organization_id=organization_id,
+                    script_id=script_id,
+                    workflow_permanent_id=workflow_permanent_id,
+                    workflow_id=workflow_id,
+                    workflow_run_id=workflow_run_id,
+                    cache_key=cache_key,
+                    cache_key_value=cache_key_value,
+                )
+                session.add(record)
+                await session.commit()
+        except SQLAlchemyError:
+            LOG.error("SQLAlchemyError", exc_info=True)
+            raise
+        except Exception:
+            LOG.error("UnexpectedError", exc_info=True)
+            raise
+
+    async def get_workflow_scripts_by_cache_key_value(
+        self,
+        *,
+        organization_id: str,
+        workflow_permanent_id: str,
+        cache_key_value: str,
+    ) -> list[Script]:
+        """Get latest script versions linked to a workflow by a specific cache_key_value."""
+        try:
+            async with self.Session() as session:
+                # Subquery: script_ids associated with this workflow + cache_key_value
+                ws_script_ids_subquery = (
+                    select(WorkflowScriptModel.script_id)
+                    .where(WorkflowScriptModel.organization_id == organization_id)
+                    .where(WorkflowScriptModel.workflow_permanent_id == workflow_permanent_id)
+                    .where(WorkflowScriptModel.cache_key_value == cache_key_value)
+                    .where(WorkflowScriptModel.deleted_at.is_(None))
+                )
+
+                # Latest version per script_id within the org and not deleted
+                latest_versions_subquery = (
+                    select(
+                        ScriptModel.script_id,
+                        func.max(ScriptModel.version).label("latest_version"),
+                    )
+                    .where(ScriptModel.organization_id == organization_id)
+                    .where(ScriptModel.deleted_at.is_(None))
+                    .where(ScriptModel.script_id.in_(ws_script_ids_subquery))
+                    .group_by(ScriptModel.script_id)
+                    .subquery()
+                )
+
+                query = select(ScriptModel).join(
+                    latest_versions_subquery,
+                    (ScriptModel.script_id == latest_versions_subquery.c.script_id)
+                    & (ScriptModel.version == latest_versions_subquery.c.latest_version),
+                )
+                query = query.order_by(ScriptModel.created_at.desc())
+
+                scripts = (await session.scalars(query)).all()
+                return [convert_to_script(script) for script in scripts]
+        except SQLAlchemyError:
+            LOG.error("SQLAlchemyError", exc_info=True)
+            raise
+        except Exception:
+            LOG.error("UnexpectedError", exc_info=True)
+            raise
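get_workflow_scripts_by_cache_key_value is a greatest-n-per-group query: a grouped subquery finds max(version) per script_id, and a join back retrieves the full latest rows. A self-contained SQLAlchemy Core sketch of the same pattern against a toy scripts table (the table and data here are invented for illustration, not the real models):

    from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, func, select

    metadata = MetaData()
    scripts = Table(
        "scripts",
        metadata,
        Column("script_id", String),
        Column("version", Integer),
    )

    engine = create_engine("sqlite://")
    metadata.create_all(engine)
    with engine.begin() as conn:
        conn.execute(
            scripts.insert(),
            [
                {"script_id": "s1", "version": 1},
                {"script_id": "s1", "version": 2},
                {"script_id": "s2", "version": 1},
            ],
        )
        # Latest version per script_id, then join back for the full rows.
        latest = (
            select(scripts.c.script_id, func.max(scripts.c.version).label("latest_version"))
            .group_by(scripts.c.script_id)
            .subquery()
        )
        rows = conn.execute(
            select(scripts).join(
                latest,
                (scripts.c.script_id == latest.c.script_id)
                & (scripts.c.version == latest.c.latest_version),
            )
        ).all()
        assert sorted(rows) == [("s1", 2), ("s2", 1)]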
@@ -1,14 +1,18 @@
 import asyncio
+import base64
 import json
 from datetime import UTC, datetime
 from typing import Any
 
 import httpx
 import structlog
+from jinja2.sandbox import SandboxedEnvironment
 
 from skyvern import analytics
 from skyvern.config import settings
 from skyvern.constants import GET_DOWNLOADED_FILES_TIMEOUT, SAVE_DOWNLOADED_FILES_TIMEOUT
+from skyvern.core.code_generations.generate_code import generate_workflow_script as generate_python_workflow_script
+from skyvern.core.code_generations.transform_workflow_run import transform_workflow_run_to_code_gen_input
 from skyvern.exceptions import (
     BlockNotFound,
     BrowserSessionNotFound,
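Among the new imports, SandboxedEnvironment is what later renders the workflow's user-authored cache_key template against run parameters (see generate_script_for_workflow below); the sandbox refuses unsafe attribute access, so a template cannot reach Python internals. A quick illustration:

    from jinja2.exceptions import SecurityError
    from jinja2.sandbox import SandboxedEnvironment

    env = SandboxedEnvironment()
    # Ordinary variable substitution works as usual.
    assert env.from_string("user-{{ user_id }}").render({"user_id": 42}) == "user-42"

    try:
        # Dunder attribute access is rejected by the sandbox at render time.
        env.from_string("{{ x.__class__ }}").render({"x": 1})
    except SecurityError:
        print("blocked by sandbox")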
@@ -99,6 +103,8 @@ from skyvern.forge.sdk.workflow.models.yaml import (
     WorkflowDefinitionYAML,
 )
 from skyvern.schemas.runs import ProxyLocation, RunStatus, RunType, WorkflowRunRequest, WorkflowRunResponse
+from skyvern.schemas.scripts import FileEncoding, ScriptFileCreate
+from skyvern.services import script_service
 from skyvern.webeye.browser_factory import BrowserState
 
 LOG = structlog.get_logger()
@@ -615,6 +621,10 @@ class WorkflowService:
             organization_id=organization_id,
         )
 
+        # TODO: generate script for workflow if the workflow.use_cache is True AND there's no script cached for the workflow
+        if workflow.use_cache:
+            await self.generate_script_for_workflow(workflow=workflow, workflow_run=workflow_run)
+
         return workflow_run
 
     async def create_workflow(
@@ -2236,3 +2246,93 @@ class WorkflowService:
                 break
 
         return result
+
+    async def generate_script_for_workflow(
+        self,
+        workflow: Workflow,
+        workflow_run: WorkflowRun,
+    ) -> None:
+        cache_key = workflow.cache_key
+        rendered_cache_key_value = ""
+        # 1) Build cache_key_value from workflow run parameters via jinja
+        if cache_key:
+            parameter_tuples = await app.DATABASE.get_workflow_run_parameters(
+                workflow_run_id=workflow_run.workflow_run_id
+            )
+            parameters = {wf_param.key: run_param.value for wf_param, run_param in parameter_tuples}
+            jinja_sandbox_env = SandboxedEnvironment()
+            try:
+                rendered_cache_key_value = jinja_sandbox_env.from_string(cache_key).render(parameters)
+            except Exception:
+                LOG.warning("Failed to render cache key; skip script generation", exc_info=True)
+                return
+
+        # 2) Check existing cached scripts for this workflow + cache_key_value
+        existing_scripts = await app.DATABASE.get_workflow_scripts_by_cache_key_value(
+            organization_id=workflow.organization_id,
+            workflow_permanent_id=workflow.workflow_permanent_id,
+            cache_key_value=rendered_cache_key_value,
+        )
+        if existing_scripts:
+            LOG.info(
+                "Found cached script for workflow",
+                workflow_id=workflow.workflow_id,
+                cache_key_value=rendered_cache_key_value,
+            )
+            return
+
+        # 3) Generate script code from workflow run
+        try:
+            codegen_input = await transform_workflow_run_to_code_gen_input(
+                workflow_run_id=workflow_run.workflow_run_id,
+                organization_id=workflow.organization_id,
+            )
+            python_src = generate_python_workflow_script(
+                file_name=codegen_input.file_name,
+                workflow_run_request=codegen_input.workflow_run,
+                workflow=codegen_input.workflow,
+                tasks=codegen_input.workflow_blocks,
+                actions_by_task=codegen_input.actions_by_task,
+            )
+        except Exception:
+            LOG.error("Failed to generate workflow script source", exc_info=True)
+            return
+
+        # 4) Persist script and files, then record mapping
+        content_bytes = python_src.encode("utf-8")
+        content_b64 = base64.b64encode(content_bytes).decode("utf-8")
+        files = [
+            ScriptFileCreate(
+                path="main.py",
+                content=content_b64,
+                encoding=FileEncoding.BASE64,
+                mime_type="text/x-python",
+            )
+        ]
+
+        created = await app.DATABASE.create_script(
+            organization_id=workflow.organization_id,
+            run_id=workflow_run.workflow_run_id,
+        )
+
+        # Upload script file(s) as artifacts and create rows
+        await script_service.build_file_tree(
+            files=files,
+            organization_id=workflow.organization_id,
+            script_id=created.script_id,
+            script_version=created.version,
+            script_revision_id=created.script_revision_id,
+        )
+
+        # Record the workflow->script mapping for cache lookup
+        await app.DATABASE.create_workflow_script(
+            organization_id=workflow.organization_id,
+            script_id=created.script_id,
+            workflow_permanent_id=workflow.workflow_permanent_id,
+            cache_key=cache_key or "",
+            cache_key_value=rendered_cache_key_value,
+            workflow_id=workflow.workflow_id,
+            workflow_run_id=workflow_run.workflow_run_id,
+        )
+
+        return
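The generated source is wrapped in a ScriptFileCreate with encoding=FileEncoding.BASE64, so the file bytes travel as plain text; decoding is the exact inverse. The round trip in isolation (the source string below is a stand-in):

    import base64

    python_src = "print('hello')\n"  # stand-in for the generated script source
    content_b64 = base64.b64encode(python_src.encode("utf-8")).decode("utf-8")
    assert base64.b64decode(content_b64).decode("utf-8") == python_src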