Implement cancel workflow run endpoint (#1188)

This commit is contained in:
Shuchang Zheng 2024-11-14 01:32:53 -08:00 committed by GitHub
parent d107c3d4db
commit 28d37545bc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 76 additions and 4 deletions

View file

@ -44,7 +44,7 @@ from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskStatus
from skyvern.forge.sdk.settings_manager import SettingsManager from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
from skyvern.forge.sdk.workflow.models.block import TaskBlock from skyvern.forge.sdk.workflow.models.block import TaskBlock
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunStatus
from skyvern.webeye.actions.actions import ( from skyvern.webeye.actions.actions import (
Action, Action,
ActionType, ActionType,
@ -220,10 +220,29 @@ class ForgeAgent:
task: Task, task: Task,
step: Step, step: Step,
api_key: str | None = None, api_key: str | None = None,
workflow_run: WorkflowRun | None = None,
close_browser_on_completion: bool = True, close_browser_on_completion: bool = True,
task_block: TaskBlock | None = None, task_block: TaskBlock | None = None,
) -> Tuple[Step, DetailedAgentStepOutput | None, Step | None]: ) -> Tuple[Step, DetailedAgentStepOutput | None, Step | None]:
workflow_run: WorkflowRun | None = None
if task.workflow_run_id:
workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id=task.workflow_run_id)
if workflow_run and workflow_run.status == WorkflowRunStatus.canceled:
LOG.info(
"Workflow run is canceled, stopping execution inside task",
workflow_run_id=workflow_run.workflow_run_id,
step_id=step.step_id,
)
step = await self.update_step(
step,
status=StepStatus.canceled,
is_last=True,
)
task = await self.update_task(
task,
status=TaskStatus.canceled,
)
return step, None, None
refreshed_task = await app.DATABASE.get_task(task_id=task.task_id, organization_id=organization.organization_id) refreshed_task = await app.DATABASE.get_task(task_id=task.task_id, organization_id=organization.organization_id)
if refreshed_task: if refreshed_task:
task = refreshed_task task = refreshed_task

View file

@ -345,6 +345,15 @@ async def cancel_task(
await app.agent.update_task(task_obj, status=TaskStatus.canceled) await app.agent.update_task(task_obj, status=TaskStatus.canceled)
@base_router.post("/workflows/runs/{workflow_run_id}/cancel")
@base_router.post("/workflows/runs/{workflow_run_id}/cancel/", include_in_schema=False)
async def cancel_workflow_run(
workflow_run_id: str,
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> None:
await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(workflow_run_id)
@base_router.post( @base_router.post(
"/tasks/{task_id}/retry_webhook", "/tasks/{task_id}/retry_webhook",
tags=["agent"], tags=["agent"],

View file

@ -335,7 +335,6 @@ class TaskBlock(Block):
organization=organization, organization=organization,
task=task, task=task,
step=step, step=step,
workflow_run=workflow_run,
task_block=self, task_block=self,
) )
except Exception as e: except Exception as e:

View file

@ -188,6 +188,24 @@ class WorkflowService:
for block_idx, block in enumerate(blocks): for block_idx, block in enumerate(blocks):
is_last_block = block_idx + 1 == blocks_cnt is_last_block = block_idx + 1 == blocks_cnt
try: try:
refreshed_workflow_run = await app.DATABASE.get_workflow_run(
workflow_run_id=workflow_run.workflow_run_id
)
if refreshed_workflow_run and refreshed_workflow_run.status == WorkflowRunStatus.canceled:
LOG.info(
"Workflow run is canceled, stopping execution inside workflow execution loop",
workflow_run_id=workflow_run.workflow_run_id,
block_idx=block_idx,
block_type=block.block_type,
block_label=block.label,
)
await self.clean_up_workflow(
workflow=workflow,
workflow_run=workflow_run,
api_key=api_key,
need_call_webhook=False,
)
return workflow_run
parameters = block.get_all_parameters(workflow_run_id) parameters = block.get_all_parameters(workflow_run_id)
await app.WORKFLOW_CONTEXT_MANAGER.register_block_parameters_for_workflow_run( await app.WORKFLOW_CONTEXT_MANAGER.register_block_parameters_for_workflow_run(
workflow_run_id, parameters, organization workflow_run_id, parameters, organization
@ -197,6 +215,8 @@ class WorkflowService:
block_type=block.block_type, block_type=block.block_type,
workflow_run_id=workflow_run.workflow_run_id, workflow_run_id=workflow_run.workflow_run_id,
block_idx=block_idx, block_idx=block_idx,
block_type_var=block.block_type,
block_label=block.label,
) )
block_result = await block.execute_safe(workflow_run_id=workflow_run_id) block_result = await block.execute_safe(workflow_run_id=workflow_run_id)
if block_result.status == BlockStatus.canceled: if block_result.status == BlockStatus.canceled:
@ -206,6 +226,8 @@ class WorkflowService:
workflow_run_id=workflow_run.workflow_run_id, workflow_run_id=workflow_run.workflow_run_id,
block_idx=block_idx, block_idx=block_idx,
block_result=block_result, block_result=block_result,
block_type_var=block.block_type,
block_label=block.label,
) )
await self.mark_workflow_run_as_canceled(workflow_run_id=workflow_run.workflow_run_id) await self.mark_workflow_run_as_canceled(workflow_run_id=workflow_run.workflow_run_id)
# We're not sending a webhook here because the workflow run is manually marked as canceled. # We're not sending a webhook here because the workflow run is manually marked as canceled.
@ -223,6 +245,8 @@ class WorkflowService:
workflow_run_id=workflow_run.workflow_run_id, workflow_run_id=workflow_run.workflow_run_id,
block_idx=block_idx, block_idx=block_idx,
block_result=block_result, block_result=block_result,
block_type_var=block.block_type,
block_label=block.label,
) )
if block.continue_on_failure and not is_last_block: if block.continue_on_failure and not is_last_block:
LOG.warning( LOG.warning(
@ -232,6 +256,8 @@ class WorkflowService:
block_idx=block_idx, block_idx=block_idx,
block_result=block_result, block_result=block_result,
continue_on_failure=block.continue_on_failure, continue_on_failure=block.continue_on_failure,
block_type_var=block.block_type,
block_label=block.label,
) )
else: else:
await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id) await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
@ -248,6 +274,8 @@ class WorkflowService:
workflow_run_id=workflow_run.workflow_run_id, workflow_run_id=workflow_run.workflow_run_id,
block_idx=block_idx, block_idx=block_idx,
block_result=block_result, block_result=block_result,
block_type_var=block.block_type,
block_label=block.label,
) )
if block.continue_on_failure and not is_last_block: if block.continue_on_failure and not is_last_block:
LOG.warning( LOG.warning(
@ -257,6 +285,8 @@ class WorkflowService:
block_idx=block_idx, block_idx=block_idx,
block_result=block_result, block_result=block_result,
continue_on_failure=block.continue_on_failure, continue_on_failure=block.continue_on_failure,
block_type_var=block.block_type,
block_label=block.label,
) )
else: else:
await self.mark_workflow_run_as_terminated(workflow_run_id=workflow_run.workflow_run_id) await self.mark_workflow_run_as_terminated(workflow_run_id=workflow_run.workflow_run_id)
@ -270,12 +300,27 @@ class WorkflowService:
LOG.exception( LOG.exception(
f"Error while executing workflow run {workflow_run.workflow_run_id}", f"Error while executing workflow run {workflow_run.workflow_run_id}",
workflow_run_id=workflow_run.workflow_run_id, workflow_run_id=workflow_run.workflow_run_id,
block_idx=block_idx,
block_type=block.block_type,
block_label=block.label,
) )
await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id) await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id)
await self.clean_up_workflow(workflow=workflow, workflow_run=workflow_run, api_key=api_key) await self.clean_up_workflow(workflow=workflow, workflow_run=workflow_run, api_key=api_key)
return workflow_run return workflow_run
await self.mark_workflow_run_as_completed(workflow_run_id=workflow_run.workflow_run_id) refreshed_workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id=workflow_run.workflow_run_id)
if refreshed_workflow_run and refreshed_workflow_run.status not in (
WorkflowRunStatus.canceled,
WorkflowRunStatus.failed,
WorkflowRunStatus.terminated,
):
await self.mark_workflow_run_as_completed(workflow_run_id=workflow_run.workflow_run_id)
else:
LOG.info(
"Workflow run is already canceled, failed, or terminated, not marking as completed",
workflow_run_id=workflow_run.workflow_run_id,
workflow_run_status=refreshed_workflow_run.status if refreshed_workflow_run else None,
)
await self.clean_up_workflow(workflow=workflow, workflow_run=workflow_run, api_key=api_key) await self.clean_up_workflow(workflow=workflow, workflow_run=workflow_run, api_key=api_key)
return workflow_run return workflow_run