workflow apis (#326)

Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
Kerem Yilmaz 2024-05-16 10:51:22 -07:00 committed by GitHub
parent 50026f33c2
commit 72d25cd37d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 364 additions and 19 deletions

View file

@ -0,0 +1,49 @@
"""add proxy_location and webhook_callback_url to workflows table
Revision ID: 04bf06540db6
Revises: baec12642d77
Create Date: 2024-05-16 17:29:55.083124+00:00
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "04bf06540db6"
down_revision: Union[str, None] = "baec12642d77"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"workflows",
sa.Column(
"proxy_location",
sa.Enum(
"US_CA",
"US_NY",
"US_TX",
"US_FL",
"US_WA",
"RESIDENTIAL",
"RESIDENTIAL_ES",
"NONE",
name="proxylocation",
),
nullable=True,
),
)
op.add_column("workflows", sa.Column("webhook_callback_url", sa.String(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("workflows", "webhook_callback_url")
op.drop_column("workflows", "proxy_location")
# ### end Alembic commands ###

View file

@ -107,8 +107,22 @@ class UnknownBlockType(SkyvernException):
class WorkflowNotFound(SkyvernHTTPException): class WorkflowNotFound(SkyvernHTTPException):
def __init__(self, workflow_id: str) -> None: def __init__(
super().__init__(f"Workflow {workflow_id} not found", status_code=status.HTTP_404_NOT_FOUND) self,
workflow_id: str | None = None,
workflow_permanent_id: str | None = None,
version: int | None = None,
) -> None:
workflow_repr = ""
if workflow_id:
workflow_repr = f"workflow_id={workflow_id}"
if workflow_permanent_id:
if version:
workflow_repr = f"workflow_permanent_id={workflow_permanent_id}, version={version}"
else:
workflow_repr = f"workflow_permanent_id={workflow_permanent_id}"
super().__init__(f"Workflow not found. {workflow_repr}", status_code=status.HTTP_404_NOT_FOUND)
class WorkflowRunNotFound(SkyvernException): class WorkflowRunNotFound(SkyvernException):

View file

@ -2,7 +2,7 @@ from datetime import datetime
from typing import Any, Sequence from typing import Any, Sequence
import structlog import structlog
from sqlalchemy import and_, delete, select from sqlalchemy import and_, delete, func, select, update
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
@ -723,10 +723,14 @@ class AgentDB:
async def create_workflow( async def create_workflow(
self, self,
organization_id: str,
title: str, title: str,
workflow_definition: dict[str, Any], workflow_definition: dict[str, Any],
organization_id: str | None = None,
description: str | None = None, description: str | None = None,
proxy_location: ProxyLocation | None = None,
webhook_callback_url: str | None = None,
workflow_permanent_id: str | None = None,
version: int | None = None,
) -> Workflow: ) -> Workflow:
async with self.Session() as session: async with self.Session() as session:
workflow = WorkflowModel( workflow = WorkflowModel(
@ -734,7 +738,13 @@ class AgentDB:
title=title, title=title,
description=description, description=description,
workflow_definition=workflow_definition, workflow_definition=workflow_definition,
proxy_location=proxy_location,
webhook_callback_url=webhook_callback_url,
) )
if workflow_permanent_id:
workflow.workflow_permanent_id = workflow_permanent_id
if version:
workflow.version = version
session.add(workflow) session.add(workflow)
await session.commit() await session.commit()
await session.refresh(workflow) await session.refresh(workflow)
@ -743,7 +753,9 @@ class AgentDB:
async def get_workflow(self, workflow_id: str, organization_id: str | None = None) -> Workflow | None: async def get_workflow(self, workflow_id: str, organization_id: str | None = None) -> Workflow | None:
try: try:
async with self.Session() as session: async with self.Session() as session:
get_workflow_query = select(WorkflowModel).filter_by(workflow_id=workflow_id) get_workflow_query = (
select(WorkflowModel).filter_by(workflow_id=workflow_id).filter(WorkflowModel.deleted_at.is_(None))
)
if organization_id: if organization_id:
get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id) get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id)
if workflow := (await session.scalars(get_workflow_query)).first(): if workflow := (await session.scalars(get_workflow_query)).first():
@ -753,6 +765,74 @@ class AgentDB:
LOG.error("SQLAlchemyError", exc_info=True) LOG.error("SQLAlchemyError", exc_info=True)
raise raise
async def get_workflow_by_permanent_id(
self,
workflow_permanent_id: str,
organization_id: str | None = None,
version: int | None = None,
) -> Workflow | None:
try:
get_workflow_query = (
select(WorkflowModel)
.filter_by(workflow_permanent_id=workflow_permanent_id)
.filter(WorkflowModel.deleted_at.is_(None))
)
if organization_id:
get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id)
if version:
get_workflow_query = get_workflow_query.filter_by(version=version)
get_workflow_query = get_workflow_query.order_by(WorkflowModel.version.desc())
async with self.Session() as session:
if workflow := (await session.scalars(get_workflow_query)).first():
return convert_to_workflow(workflow, self.debug_enabled)
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def get_workflows_by_organization_id(
self,
organization_id: str,
page: int = 1,
page_size: int = 10,
) -> list[Workflow]:
"""
Get all workflows with the latest version for the organization.
"""
if page < 1:
raise ValueError(f"Page must be greater than 0, got {page}")
db_page = page - 1
try:
async with self.Session() as session:
subquery = (
select(
WorkflowModel.organization_id,
WorkflowModel.workflow_permanent_id,
func.max(WorkflowModel.version).label("max_version"),
)
.where(WorkflowModel.organization_id == organization_id)
.where(WorkflowModel.deleted_at.is_(None))
.group_by(WorkflowModel.organization_id, WorkflowModel.workflow_permanent_id)
.subquery()
)
main_query = (
select(WorkflowModel)
.join(
subquery,
(WorkflowModel.organization_id == subquery.c.organization_id)
& (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id)
& (WorkflowModel.version == subquery.c.max_version),
)
.order_by(WorkflowModel.created_at.desc()) # Example ordering by creation date
.limit(page_size)
.offset(db_page * page_size)
)
workflows = (await session.scalars(main_query)).all()
return [convert_to_workflow(workflow, self.debug_enabled) for workflow in workflows]
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def update_workflow( async def update_workflow(
self, self,
workflow_id: str, workflow_id: str,
@ -760,10 +840,13 @@ class AgentDB:
title: str | None = None, title: str | None = None,
description: str | None = None, description: str | None = None,
workflow_definition: dict[str, Any] | None = None, workflow_definition: dict[str, Any] | None = None,
version: int | None = None,
) -> Workflow: ) -> Workflow:
try: try:
async with self.Session() as session: async with self.Session() as session:
get_workflow_query = select(WorkflowModel).filter_by(workflow_id=workflow_id) get_workflow_query = (
select(WorkflowModel).filter_by(workflow_id=workflow_id).filter(WorkflowModel.deleted_at.is_(None))
)
if organization_id: if organization_id:
get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id) get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id)
if workflow := (await session.scalars(get_workflow_query)).first(): if workflow := (await session.scalars(get_workflow_query)).first():
@ -773,6 +856,8 @@ class AgentDB:
workflow.description = description workflow.description = description
if workflow_definition: if workflow_definition:
workflow.workflow_definition = workflow_definition workflow.workflow_definition = workflow_definition
if version:
workflow.version = version
await session.commit() await session.commit()
await session.refresh(workflow) await session.refresh(workflow)
return convert_to_workflow(workflow, self.debug_enabled) return convert_to_workflow(workflow, self.debug_enabled)
@ -789,8 +874,29 @@ class AgentDB:
LOG.error("UnexpectedError", exc_info=True) LOG.error("UnexpectedError", exc_info=True)
raise raise
async def soft_delete_workflow_by_permanent_id(
self,
workflow_permanent_id: str,
organization_id: str | None = None,
) -> None:
async with self.Session() as session:
# soft delete the workflow by setting the deleted_at field
update_deleted_at_query = (
update(WorkflowModel)
.where(WorkflowModel.workflow_permanent_id == workflow_permanent_id)
.where(WorkflowModel.deleted_at.is_(None))
)
if organization_id:
update_deleted_at_query = update_deleted_at_query.filter_by(organization_id=organization_id)
update_deleted_at_query = update_deleted_at_query.values(deleted_at=datetime.utcnow())
await session.execute(update_deleted_at_query)
await session.commit()
async def create_workflow_run( async def create_workflow_run(
self, workflow_id: str, proxy_location: ProxyLocation | None = None, webhook_callback_url: str | None = None self,
workflow_id: str,
proxy_location: ProxyLocation | None = None,
webhook_callback_url: str | None = None,
) -> WorkflowRun: ) -> WorkflowRun:
try: try:
async with self.Session() as session: async with self.Session() as session:

View file

@ -145,6 +145,8 @@ class WorkflowModel(Base):
title = Column(String, nullable=False) title = Column(String, nullable=False)
description = Column(String, nullable=True) description = Column(String, nullable=True)
workflow_definition = Column(JSON, nullable=False) workflow_definition = Column(JSON, nullable=False)
proxy_location = Column(Enum(ProxyLocation))
webhook_callback_url = Column(String)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False) modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)

View file

@ -148,6 +148,10 @@ def convert_to_workflow(workflow_model: WorkflowModel, debug_enabled: bool = Fal
workflow_id=workflow_model.workflow_id, workflow_id=workflow_model.workflow_id,
organization_id=workflow_model.organization_id, organization_id=workflow_model.organization_id,
title=workflow_model.title, title=workflow_model.title,
workflow_permanent_id=workflow_model.workflow_permanent_id,
webhook_callback_url=workflow_model.webhook_callback_url,
proxy_location=ProxyLocation(workflow_model.proxy_location) if workflow_model.proxy_location else None,
version=workflow_model.version,
description=workflow_model.description, description=workflow_model.description,
workflow_definition=WorkflowDefinition.model_validate(workflow_model.workflow_definition), workflow_definition=WorkflowDefinition.model_validate(workflow_model.workflow_definition),
created_at=workflow_model.created_at, created_at=workflow_model.created_at,

View file

@ -532,3 +532,88 @@ async def create_workflow(
return await app.WORKFLOW_SERVICE.create_workflow_from_request( return await app.WORKFLOW_SERVICE.create_workflow_from_request(
organization_id=current_org.organization_id, request=workflow_create_request organization_id=current_org.organization_id, request=workflow_create_request
) )
@base_router.put(
"/workflows/{workflow_permanent_id}",
openapi_extra={
"requestBody": {
"content": {"application/x-yaml": {"schema": WorkflowCreateYAMLRequest.model_json_schema()}},
"required": True,
},
},
response_model=Workflow,
)
@base_router.put(
"/workflows/{workflow_permanent_id}/",
openapi_extra={
"requestBody": {
"content": {"application/x-yaml": {"schema": WorkflowCreateYAMLRequest.model_json_schema()}},
"required": True,
},
},
response_model=Workflow,
include_in_schema=False,
)
async def update_workflow(
workflow_permanent_id: str,
request: Request,
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Workflow:
analytics.capture("skyvern-oss-agent-workflow-update")
# validate the workflow
raw_yaml = await request.body()
try:
workflow_yaml = yaml.safe_load(raw_yaml)
except yaml.YAMLError:
raise HTTPException(status_code=422, detail="Invalid YAML")
workflow_create_request = WorkflowCreateYAMLRequest.model_validate(workflow_yaml)
return await app.WORKFLOW_SERVICE.create_workflow_from_request(
organization_id=current_org.organization_id,
request=workflow_create_request,
workflow_permanent_id=workflow_permanent_id,
)
@base_router.delete("/workflows/{workflow_permanent_id}")
@base_router.delete("/workflows/{workflow_permanent_id}/", include_in_schema=False)
async def delete_workflow(
workflow_permanent_id: str,
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> None:
analytics.capture("skyvern-oss-agent-workflow-delete")
await app.WORKFLOW_SERVICE.delete_workflow_by_permanent_id(workflow_permanent_id, current_org.organization_id)
@base_router.get("/workflows", response_model=list[Workflow])
@base_router.get("/workflows/", response_model=list[Workflow])
async def get_workflows(
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1),
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> list[Workflow]:
"""
Get all workflows with the latest version for the organization.
"""
analytics.capture("skyvern-oss-agent-workflows-get")
return await app.WORKFLOW_SERVICE.get_workflows_by_organization_id(
organization_id=current_org.organization_id,
page=page,
page_size=page_size,
)
@base_router.get("/workflows/{workflow_permanent_id}", response_model=Workflow)
@base_router.get("/workflows/{workflow_permanent_id}/", response_model=Workflow)
async def get_workflow(
workflow_permanent_id: str,
version: int | None = None,
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Workflow:
analytics.capture("skyvern-oss-agent-workflows-get")
return await app.WORKFLOW_SERVICE.get_workflow_by_permanent_id(
workflow_permanent_id=workflow_permanent_id,
organization_id=current_org.organization_id,
version=version,
)

View file

@ -42,8 +42,12 @@ class Workflow(BaseModel):
workflow_id: str workflow_id: str
organization_id: str organization_id: str
title: str title: str
workflow_permanent_id: str
version: int
description: str | None = None description: str | None = None
workflow_definition: WorkflowDefinition workflow_definition: WorkflowDefinition
proxy_location: ProxyLocation | None = None
webhook_callback_url: str | None = None
created_at: datetime created_at: datetime
modified_at: datetime modified_at: datetime

View file

@ -3,6 +3,7 @@ from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from skyvern.forge.sdk.schemas.tasks import ProxyLocation
from skyvern.forge.sdk.workflow.models.block import BlockType from skyvern.forge.sdk.workflow.models.block import BlockType
from skyvern.forge.sdk.workflow.models.parameter import ParameterType, WorkflowParameterType from skyvern.forge.sdk.workflow.models.parameter import ParameterType, WorkflowParameterType
@ -187,4 +188,6 @@ class WorkflowDefinitionYAML(BaseModel):
class WorkflowCreateYAMLRequest(BaseModel): class WorkflowCreateYAMLRequest(BaseModel):
title: str title: str
description: str | None = None description: str | None = None
proxy_location: ProxyLocation | None = None
webhook_callback_url: str | None = None
workflow_definition: WorkflowDefinitionYAML workflow_definition: WorkflowDefinitionYAML

View file

@ -19,7 +19,7 @@ from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_signature from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.models import Step from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
from skyvern.forge.sdk.workflow.exceptions import ( from skyvern.forge.sdk.workflow.exceptions import (
ContextParameterSourceNotDefined, ContextParameterSourceNotDefined,
WorkflowDefinitionHasDuplicateParameterKeys, WorkflowDefinitionHasDuplicateParameterKeys,
@ -89,6 +89,10 @@ class WorkflowService:
if workflow.organization_id != organization_id: if workflow.organization_id != organization_id:
LOG.error(f"Workflow {workflow_id} does not belong to organization {organization_id}") LOG.error(f"Workflow {workflow_id} does not belong to organization {organization_id}")
raise WorkflowOrganizationMismatch(workflow_id=workflow_id, organization_id=organization_id) raise WorkflowOrganizationMismatch(workflow_id=workflow_id, organization_id=organization_id)
if workflow_request.proxy_location is None and workflow.proxy_location is not None:
workflow_request.proxy_location = workflow.proxy_location
if workflow_request.webhook_callback_url is None and workflow.webhook_callback_url is not None:
workflow_request.webhook_callback_url = workflow.webhook_callback_url
# Create the workflow run and set skyvern context # Create the workflow run and set skyvern context
workflow_run = await self.create_workflow_run(workflow_request=workflow_request, workflow_id=workflow_id) workflow_run = await self.create_workflow_run(workflow_request=workflow_request, workflow_id=workflow_id)
LOG.info( LOG.info(
@ -97,6 +101,7 @@ class WorkflowService:
workflow_run_id=workflow_run.workflow_run_id, workflow_run_id=workflow_run.workflow_run_id,
workflow_id=workflow.workflow_id, workflow_id=workflow.workflow_id,
proxy_location=workflow_request.proxy_location, proxy_location=workflow_request.proxy_location,
webhook_callback_url=workflow_request.webhook_callback_url,
) )
skyvern_context.set( skyvern_context.set(
SkyvernContext( SkyvernContext(
@ -266,20 +271,58 @@ class WorkflowService:
title: str, title: str,
workflow_definition: WorkflowDefinition, workflow_definition: WorkflowDefinition,
description: str | None = None, description: str | None = None,
proxy_location: ProxyLocation | None = None,
webhook_callback_url: str | None = None,
workflow_permanent_id: str | None = None,
version: int | None = None,
) -> Workflow: ) -> Workflow:
return await app.DATABASE.create_workflow( return await app.DATABASE.create_workflow(
organization_id=organization_id,
title=title, title=title,
description=description,
workflow_definition=workflow_definition.model_dump(), workflow_definition=workflow_definition.model_dump(),
organization_id=organization_id,
description=description,
proxy_location=proxy_location,
webhook_callback_url=webhook_callback_url,
workflow_permanent_id=workflow_permanent_id,
version=version,
) )
async def get_workflow(self, workflow_id: str, organization_id: str | None = None) -> Workflow: async def get_workflow(self, workflow_id: str, organization_id: str | None = None) -> Workflow:
workflow = await app.DATABASE.get_workflow(workflow_id=workflow_id, organization_id=organization_id) workflow = await app.DATABASE.get_workflow(workflow_id=workflow_id, organization_id=organization_id)
if not workflow: if not workflow:
raise WorkflowNotFound(workflow_id) raise WorkflowNotFound(workflow_id=workflow_id)
return workflow return workflow
async def get_workflow_by_permanent_id(
self,
workflow_permanent_id: str,
organization_id: str | None = None,
version: int | None = None,
) -> Workflow:
workflow = await app.DATABASE.get_workflow_by_permanent_id(
workflow_permanent_id,
organization_id=organization_id,
version=version,
)
if not workflow:
raise WorkflowNotFound(workflow_permanent_id=workflow_permanent_id, version=version)
return workflow
async def get_workflows_by_organization_id(
self,
organization_id: str,
page: int = 1,
page_size: int = 10,
) -> list[Workflow]:
"""
Get all workflows with the latest version for the organization.
"""
return await app.DATABASE.get_workflows_by_organization_id(
organization_id=organization_id,
page=page,
page_size=page_size,
)
async def update_workflow( async def update_workflow(
self, self,
workflow_id: str, workflow_id: str,
@ -290,14 +333,25 @@ class WorkflowService:
) -> Workflow: ) -> Workflow:
if workflow_definition: if workflow_definition:
workflow_definition.validate() workflow_definition.validate()
return await app.DATABASE.update_workflow( return await app.DATABASE.update_workflow(
workflow_id=workflow_id, workflow_id=workflow_id,
organization_id=organization_id,
title=title, title=title,
organization_id=organization_id,
description=description, description=description,
workflow_definition=workflow_definition.model_dump() if workflow_definition else None, workflow_definition=workflow_definition.model_dump() if workflow_definition else None,
) )
async def delete_workflow_by_permanent_id(
self,
workflow_permanent_id: str,
organization_id: str | None = None,
) -> None:
await app.DATABASE.soft_delete_workflow_by_permanent_id(
workflow_permanent_id=workflow_permanent_id,
organization_id=organization_id,
)
async def create_workflow_run(self, workflow_request: WorkflowRequestBody, workflow_id: str) -> WorkflowRun: async def create_workflow_run(self, workflow_request: WorkflowRequestBody, workflow_id: str) -> WorkflowRun:
return await app.DATABASE.create_workflow_run( return await app.DATABASE.create_workflow_run(
workflow_id=workflow_id, workflow_id=workflow_id,
@ -669,15 +723,39 @@ class WorkflowService:
await self.persist_har_data(browser_state, last_step, workflow, workflow_run) await self.persist_har_data(browser_state, last_step, workflow, workflow_run)
await self.persist_tracing_data(browser_state, last_step, workflow_run) await self.persist_tracing_data(browser_state, last_step, workflow_run)
async def create_workflow_from_request(self, organization_id: str, request: WorkflowCreateYAMLRequest) -> Workflow: async def create_workflow_from_request(
self,
organization_id: str,
request: WorkflowCreateYAMLRequest,
workflow_permanent_id: str | None = None,
) -> Workflow:
LOG.info("Creating workflow from request", organization_id=organization_id, title=request.title) LOG.info("Creating workflow from request", organization_id=organization_id, title=request.title)
try: try:
workflow = await self.create_workflow( if workflow_permanent_id:
organization_id=organization_id, existing_latest_workflow = await self.get_workflow_by_permanent_id(
title=request.title, workflow_permanent_id=workflow_permanent_id,
description=request.description, organization_id=organization_id,
workflow_definition=WorkflowDefinition(parameters=[], blocks=[]), )
) existing_version = existing_latest_workflow.version
workflow = await self.create_workflow(
title=request.title,
workflow_definition=WorkflowDefinition(parameters=[], blocks=[]),
description=request.description,
organization_id=organization_id,
proxy_location=request.proxy_location,
webhook_callback_url=request.webhook_callback_url,
workflow_permanent_id=workflow_permanent_id,
version=existing_version + 1,
)
else:
workflow = await self.create_workflow(
title=request.title,
workflow_definition=WorkflowDefinition(parameters=[], blocks=[]),
description=request.description,
organization_id=organization_id,
proxy_location=request.proxy_location,
webhook_callback_url=request.webhook_callback_url,
)
# Create parameters from the request # Create parameters from the request
parameters: dict[str, PARAMETER_TYPE] = {} parameters: dict[str, PARAMETER_TYPE] = {}
duplicate_parameter_keys = set() duplicate_parameter_keys = set()