workflow apis (#326)

Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
Kerem Yilmaz 2024-05-16 10:51:22 -07:00 committed by GitHub
parent 50026f33c2
commit 72d25cd37d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 364 additions and 19 deletions

View file

@ -19,7 +19,7 @@ from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
from skyvern.forge.sdk.workflow.exceptions import (
ContextParameterSourceNotDefined,
WorkflowDefinitionHasDuplicateParameterKeys,
@ -89,6 +89,10 @@ class WorkflowService:
if workflow.organization_id != organization_id:
LOG.error(f"Workflow {workflow_id} does not belong to organization {organization_id}")
raise WorkflowOrganizationMismatch(workflow_id=workflow_id, organization_id=organization_id)
if workflow_request.proxy_location is None and workflow.proxy_location is not None:
workflow_request.proxy_location = workflow.proxy_location
if workflow_request.webhook_callback_url is None and workflow.webhook_callback_url is not None:
workflow_request.webhook_callback_url = workflow.webhook_callback_url
# Create the workflow run and set skyvern context
workflow_run = await self.create_workflow_run(workflow_request=workflow_request, workflow_id=workflow_id)
LOG.info(
@ -97,6 +101,7 @@ class WorkflowService:
workflow_run_id=workflow_run.workflow_run_id,
workflow_id=workflow.workflow_id,
proxy_location=workflow_request.proxy_location,
webhook_callback_url=workflow_request.webhook_callback_url,
)
skyvern_context.set(
SkyvernContext(
@ -266,20 +271,58 @@ class WorkflowService:
title: str,
workflow_definition: WorkflowDefinition,
description: str | None = None,
proxy_location: ProxyLocation | None = None,
webhook_callback_url: str | None = None,
workflow_permanent_id: str | None = None,
version: int | None = None,
) -> Workflow:
return await app.DATABASE.create_workflow(
organization_id=organization_id,
title=title,
description=description,
workflow_definition=workflow_definition.model_dump(),
organization_id=organization_id,
description=description,
proxy_location=proxy_location,
webhook_callback_url=webhook_callback_url,
workflow_permanent_id=workflow_permanent_id,
version=version,
)
async def get_workflow(self, workflow_id: str, organization_id: str | None = None) -> Workflow:
workflow = await app.DATABASE.get_workflow(workflow_id=workflow_id, organization_id=organization_id)
if not workflow:
raise WorkflowNotFound(workflow_id)
raise WorkflowNotFound(workflow_id=workflow_id)
return workflow
async def get_workflow_by_permanent_id(
self,
workflow_permanent_id: str,
organization_id: str | None = None,
version: int | None = None,
) -> Workflow:
workflow = await app.DATABASE.get_workflow_by_permanent_id(
workflow_permanent_id,
organization_id=organization_id,
version=version,
)
if not workflow:
raise WorkflowNotFound(workflow_permanent_id=workflow_permanent_id, version=version)
return workflow
async def get_workflows_by_organization_id(
self,
organization_id: str,
page: int = 1,
page_size: int = 10,
) -> list[Workflow]:
"""
Get all workflows with the latest version for the organization.
"""
return await app.DATABASE.get_workflows_by_organization_id(
organization_id=organization_id,
page=page,
page_size=page_size,
)
async def update_workflow(
self,
workflow_id: str,
@ -290,14 +333,25 @@ class WorkflowService:
) -> Workflow:
if workflow_definition:
workflow_definition.validate()
return await app.DATABASE.update_workflow(
workflow_id=workflow_id,
organization_id=organization_id,
title=title,
organization_id=organization_id,
description=description,
workflow_definition=workflow_definition.model_dump() if workflow_definition else None,
)
async def delete_workflow_by_permanent_id(
self,
workflow_permanent_id: str,
organization_id: str | None = None,
) -> None:
await app.DATABASE.soft_delete_workflow_by_permanent_id(
workflow_permanent_id=workflow_permanent_id,
organization_id=organization_id,
)
async def create_workflow_run(self, workflow_request: WorkflowRequestBody, workflow_id: str) -> WorkflowRun:
return await app.DATABASE.create_workflow_run(
workflow_id=workflow_id,
@ -669,15 +723,39 @@ class WorkflowService:
await self.persist_har_data(browser_state, last_step, workflow, workflow_run)
await self.persist_tracing_data(browser_state, last_step, workflow_run)
async def create_workflow_from_request(self, organization_id: str, request: WorkflowCreateYAMLRequest) -> Workflow:
async def create_workflow_from_request(
self,
organization_id: str,
request: WorkflowCreateYAMLRequest,
workflow_permanent_id: str | None = None,
) -> Workflow:
LOG.info("Creating workflow from request", organization_id=organization_id, title=request.title)
try:
workflow = await self.create_workflow(
organization_id=organization_id,
title=request.title,
description=request.description,
workflow_definition=WorkflowDefinition(parameters=[], blocks=[]),
)
if workflow_permanent_id:
existing_latest_workflow = await self.get_workflow_by_permanent_id(
workflow_permanent_id=workflow_permanent_id,
organization_id=organization_id,
)
existing_version = existing_latest_workflow.version
workflow = await self.create_workflow(
title=request.title,
workflow_definition=WorkflowDefinition(parameters=[], blocks=[]),
description=request.description,
organization_id=organization_id,
proxy_location=request.proxy_location,
webhook_callback_url=request.webhook_callback_url,
workflow_permanent_id=workflow_permanent_id,
version=existing_version + 1,
)
else:
workflow = await self.create_workflow(
title=request.title,
workflow_definition=WorkflowDefinition(parameters=[], blocks=[]),
description=request.description,
organization_id=organization_id,
proxy_location=request.proxy_location,
webhook_callback_url=request.webhook_callback_url,
)
# Create parameters from the request
parameters: dict[str, PARAMETER_TYPE] = {}
duplicate_parameter_keys = set()