create and update workflow run block in workflow execution (#1419)

This commit is contained in:
Shuchang Zheng 2024-12-22 11:16:23 -08:00 committed by GitHub
parent 8b75586fb1
commit b256bace6a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 279 additions and 54 deletions

View file

@ -128,13 +128,27 @@ class Block(BaseModel, abc.ABC):
workflow_run_id=workflow_run_id,
)
def build_block_result(
async def build_block_result(
self,
success: bool,
failure_reason: str | None,
output_parameter_value: dict[str, Any] | list | str | None = None,
status: BlockStatus | None = None,
workflow_run_block_id: str | None = None,
organization_id: str | None = None,
) -> BlockResult:
# TODO: update workflow run block status and failure reason
if isinstance(output_parameter_value, str):
output_parameter_value = {"value": output_parameter_value}
if workflow_run_block_id:
await app.DATABASE.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
output=output_parameter_value,
status=status,
failure_reason=failure_reason,
organization_id=organization_id,
)
return BlockResult(
success=success,
failure_reason=failure_reason,
@ -167,12 +181,30 @@ class Block(BaseModel, abc.ABC):
return app.WORKFLOW_CONTEXT_MANAGER.aws_client
# Abstract coroutine that each concrete block implements to do its work and
# return a BlockResult. execute_safe calls it with the workflow_run_id, the id
# of the workflow run block row it has just created, and the organization_id
# it was given; any exception raised here is caught by execute_safe.
# NOTE(review): the single-line signature directly below looks like the
# pre-change version kept by the diff rendering; the multi-line signature
# after it is the one the callers in this file use — confirm.
@abc.abstractmethod
async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
async def execute(
self, workflow_run_id: str, workflow_run_block_id: str, organization_id: str | None = None, **kwargs: dict
) -> BlockResult:
pass
async def execute_safe(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
async def execute_safe(
self,
workflow_run_id: str,
parent_workflow_run_block_id: str | None = None,
organization_id: str | None = None,
**kwargs: dict,
) -> BlockResult:
workflow_run_block_id = None
try:
return await self.execute(workflow_run_id, **kwargs)
workflow_run_block = await app.DATABASE.create_workflow_run_block(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
parent_workflow_run_block_id=parent_workflow_run_block_id,
label=self.label,
block_type=self.block_type,
continue_on_failure=self.continue_on_failure,
)
workflow_run_block_id = workflow_run_block.workflow_run_block_id
return await self.execute(workflow_run_id, workflow_run_block_id, organization_id=organization_id, **kwargs)
except Exception as e:
LOG.exception(
"Block execution failed",
@ -188,7 +220,14 @@ class Block(BaseModel, abc.ABC):
failure_reason = "unexpected exception"
if isinstance(e, SkyvernException):
failure_reason = f"unexpected SkyvernException({e.__class__.__name__})"
return self.build_block_result(success=False, failure_reason=failure_reason, status=BlockStatus.failed)
return await self.build_block_result(
success=False,
failure_reason=failure_reason,
status=BlockStatus.failed,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
@abc.abstractmethod
def get_all_parameters(
@ -304,7 +343,9 @@ class BaseTaskBlock(Block):
return order, retry + 1
async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
async def execute(
self, workflow_run_id: str, workflow_run_block_id: str, organization_id: str | None = None, **kwargs: dict
) -> BlockResult:
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
current_retry = 0
# initial value for will_retry is True, so that the loop runs at least once
@ -350,11 +391,13 @@ class BaseTaskBlock(Block):
try:
self.format_potential_template_parameters(workflow_run_context=workflow_run_context)
except Exception as e:
return self.build_block_result(
return await self.build_block_result(
success=False,
failure_reason=f"Failed to format jinja template: {str(e)}",
output_parameter_value=None,
status=BlockStatus.failed,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
# TODO (kerem) we should always retry on terminated. We should make a distinction between retriable and
@ -370,6 +413,11 @@ class BaseTaskBlock(Block):
task_order=task_order,
task_retry=task_retry,
)
await app.DATABASE.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
task_id=task.task_id,
organization_id=workflow.organization_id,
)
current_running_task = task
organization = await app.DATABASE.get_organization(organization_id=workflow.organization_id)
if not organization:
@ -475,11 +523,13 @@ class BaseTaskBlock(Block):
task_output = TaskOutput.from_task(updated_task)
output_parameter_value = task_output.model_dump()
await self.record_output_parameter_value(workflow_run_context, workflow_run_id, output_parameter_value)
return self.build_block_result(
return await self.build_block_result(
success=success,
failure_reason=updated_task.failure_reason,
output_parameter_value=output_parameter_value,
status=block_status_mapping[updated_task.status],
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
elif updated_task.status == TaskStatus.canceled:
LOG.info(
@ -490,11 +540,13 @@ class BaseTaskBlock(Block):
workflow_id=workflow.workflow_id,
organization_id=workflow.organization_id,
)
return self.build_block_result(
return await self.build_block_result(
success=False,
failure_reason=updated_task.failure_reason,
output_parameter_value=None,
status=block_status_mapping[updated_task.status],
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
else:
current_retry += 1
@ -517,18 +569,22 @@ class BaseTaskBlock(Block):
await self.record_output_parameter_value(
workflow_run_context, workflow_run_id, output_parameter_value
)
return self.build_block_result(
return await self.build_block_result(
success=False,
failure_reason=updated_task.failure_reason,
output_parameter_value=output_parameter_value,
status=block_status_mapping[updated_task.status],
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
await self.record_output_parameter_value(workflow_run_context, workflow_run_id)
return self.build_block_result(
return await self.build_block_result(
success=False,
status=BlockStatus.failed,
failure_reason=current_running_task.failure_reason if current_running_task else None,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
@ -675,7 +731,12 @@ class ForLoopBlock(Block):
return [parameter_value]
async def execute_loop_helper(
self, workflow_run_id: str, workflow_run_context: WorkflowRunContext, loop_over_values: list[Any]
self,
workflow_run_id: str,
workflow_run_block_id: str,
workflow_run_context: WorkflowRunContext,
loop_over_values: list[Any],
organization_id: str | None = None,
) -> LoopBlockExecutedResult:
outputs_with_loop_values: list[list[dict[str, Any]]] = []
block_outputs: list[BlockResult] = []
@ -698,7 +759,11 @@ class ForLoopBlock(Block):
loop_block = loop_block.copy()
current_block = loop_block
block_output = await loop_block.execute_safe(workflow_run_id=workflow_run_id)
block_output = await loop_block.execute_safe(
workflow_run_id=workflow_run_id,
parent_workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
each_loop_output_values.append(
{
"loop_value": loop_over_value,
@ -747,15 +812,19 @@ class ForLoopBlock(Block):
last_block=current_block,
)
async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
async def execute(
self, workflow_run_id: str, workflow_run_block_id: str, organization_id: str | None = None, **kwargs: dict
) -> BlockResult:
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
try:
loop_over_values = self.get_loop_over_parameter_values(workflow_run_context)
except Exception as e:
return self.build_block_result(
return await self.build_block_result(
success=False,
failure_reason=f"failed to get loop values: {str(e)}",
status=BlockStatus.failed,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
LOG.info(
@ -772,10 +841,12 @@ class ForLoopBlock(Block):
num_loop_over_values=len(loop_over_values),
)
await self.record_output_parameter_value(workflow_run_context, workflow_run_id, [])
return self.build_block_result(
return await self.build_block_result(
success=False,
failure_reason="No iterable value found for the loop block",
status=BlockStatus.terminated,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
if not self.loop_blocks or len(self.loop_blocks) == 0:
@ -786,14 +857,20 @@ class ForLoopBlock(Block):
num_loop_blocks=len(self.loop_blocks),
)
await self.record_output_parameter_value(workflow_run_context, workflow_run_id, [])
return self.build_block_result(
success=False, failure_reason="No defined blocks to loop", status=BlockStatus.terminated
return await self.build_block_result(
success=False,
failure_reason="No defined blocks to loop",
status=BlockStatus.terminated,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
loop_executed_result = await self.execute_loop_helper(
workflow_run_id=workflow_run_id,
workflow_run_block_id=workflow_run_block_id,
workflow_run_context=workflow_run_context,
loop_over_values=loop_over_values,
organization_id=organization_id,
)
await self.record_output_parameter_value(
workflow_run_context, workflow_run_id, loop_executed_result.outputs_with_loop_values
@ -811,11 +888,13 @@ class ForLoopBlock(Block):
else:
block_status = BlockStatus.failed
return self.build_block_result(
return await self.build_block_result(
success=success,
failure_reason=loop_executed_result.get_failure_reason(),
output_parameter_value=loop_executed_result.outputs_with_loop_values,
status=block_status,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
@ -834,18 +913,22 @@ class CodeBlock(Block):
def format_potential_template_parameters(self, workflow_run_context: WorkflowRunContext) -> None:
self.code = self.format_block_parameter_template_from_workflow_run_context(self.code, workflow_run_context)
async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
async def execute(
self, workflow_run_id: str, workflow_run_block_id: str, organization_id: str | None = None, **kwargs: dict
) -> BlockResult:
raise DisabledBlockExecutionError("CodeBlock is disabled")
# get workflow run context
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
try:
self.format_potential_template_parameters(workflow_run_context)
except Exception as e:
return self.build_block_result(
return await self.build_block_result(
success=False,
failure_reason=f"Failed to format jinja template: {str(e)}",
output_parameter_value=None,
status=BlockStatus.failed,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
# get all parameters into a dictionary
@ -887,7 +970,13 @@ async def user_code():
result = {"result": result_container.get("result")}
await self.record_output_parameter_value(workflow_run_context, workflow_run_id, result)
return self.build_block_result(success=True, output_parameter_value=result, status=BlockStatus.completed)
return await self.build_block_result(
success=True,
output_parameter_value=result,
status=BlockStatus.completed,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
DEFAULT_TEXT_PROMPT_LLM_KEY = settings.SECONDARY_LLM_KEY or settings.LLM_KEY
@ -944,17 +1033,21 @@ class TextPromptBlock(Block):
LOG.info("TextPromptBlock: Received response from LLM", response=response)
return response
async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
async def execute(
self, workflow_run_id: str, workflow_run_block_id: str, organization_id: str | None = None, **kwargs: dict
) -> BlockResult:
# get workflow run context
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
try:
self.format_potential_template_parameters(workflow_run_context)
except Exception as e:
return self.build_block_result(
return await self.build_block_result(
success=False,
failure_reason=f"Failed to format jinja template: {str(e)}",
output_parameter_value=None,
status=BlockStatus.failed,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
# get all parameters into a dictionary
parameter_values = {}
@ -968,8 +1061,13 @@ class TextPromptBlock(Block):
response = await self.send_prompt(self.prompt, parameter_values)
await self.record_output_parameter_value(workflow_run_context, workflow_run_id, response)
return self.build_block_result(
success=True, failure_reason=None, output_parameter_value=response, status=BlockStatus.completed
return await self.build_block_result(
success=True,
failure_reason=None,
output_parameter_value=response,
status=BlockStatus.completed,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
@ -1000,7 +1098,9 @@ class DownloadToS3Block(Block):
# Clean up the temporary file since it's created with delete=False
os.unlink(file_path)
async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
async def execute(
self, workflow_run_id: str, workflow_run_block_id: str, organization_id: str | None = None, **kwargs: dict
) -> BlockResult:
# get workflow run context
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
# get all parameters into a dictionary
@ -1017,11 +1117,13 @@ class DownloadToS3Block(Block):
try:
self.format_potential_template_parameters(workflow_run_context)
except Exception as e:
return self.build_block_result(
return await self.build_block_result(
success=False,
failure_reason=f"Failed to format jinja template: {str(e)}",
output_parameter_value=None,
status=BlockStatus.failed,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
try:
@ -1040,8 +1142,13 @@ class DownloadToS3Block(Block):
LOG.info("DownloadToS3Block: File downloaded and uploaded to S3", uri=uri)
await self.record_output_parameter_value(workflow_run_context, workflow_run_id, uri)
return self.build_block_result(
success=True, failure_reason=None, output_parameter_value=uri, status=BlockStatus.completed
return await self.build_block_result(
success=True,
failure_reason=None,
output_parameter_value=uri,
status=BlockStatus.completed,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
@ -1072,7 +1179,9 @@ class UploadToS3Block(Block):
s3_key = f"{settings.ENV}/{workflow_run_id}/{uuid.uuid4()}_{Path(path).name}"
return f"s3://{s3_bucket}/{s3_key}"
async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
async def execute(
self, workflow_run_id: str, workflow_run_block_id: str, organization_id: str | None = None, **kwargs: dict
) -> BlockResult:
# get workflow run context
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
# get all parameters into a dictionary
@ -1092,11 +1201,13 @@ class UploadToS3Block(Block):
try:
self.format_potential_template_parameters(workflow_run_context)
except Exception as e:
return self.build_block_result(
return await self.build_block_result(
success=False,
failure_reason=f"Failed to format jinja template: {str(e)}",
output_parameter_value=None,
status=BlockStatus.failed,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
if not self.path or not os.path.exists(self.path):
@ -1130,8 +1241,13 @@ class UploadToS3Block(Block):
LOG.info("UploadToS3Block: File(s) uploaded to S3", file_path=self.path)
await self.record_output_parameter_value(workflow_run_context, workflow_run_id, s3_uris)
return self.build_block_result(
success=True, failure_reason=None, output_parameter_value=s3_uris, status=BlockStatus.completed
return await self.build_block_result(
success=True,
failure_reason=None,
output_parameter_value=s3_uris,
status=BlockStatus.completed,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
@ -1386,16 +1502,20 @@ class SendEmailBlock(Block):
return msg
async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
async def execute(
self, workflow_run_id: str, workflow_run_block_id: str, organization_id: str | None = None, **kwargs: dict
) -> BlockResult:
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
try:
self.format_potential_template_parameters(workflow_run_context)
except Exception as e:
return self.build_block_result(
return await self.build_block_result(
success=False,
failure_reason=f"Failed to format jinja template: {str(e)}",
output_parameter_value=None,
status=BlockStatus.failed,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
smtp_host_value, smtp_port_value, smtp_username_value, smtp_password_value = self._decrypt_smtp_parameters(
workflow_run_context
@ -1415,8 +1535,13 @@ class SendEmailBlock(Block):
LOG.error("SendEmailBlock: Failed to send email", exc_info=True)
result_dict = {"success": False, "error": str(e)}
await self.record_output_parameter_value(workflow_run_context, workflow_run_id, result_dict)
return self.build_block_result(
success=False, failure_reason=str(e), output_parameter_value=result_dict, status=BlockStatus.failed
return await self.build_block_result(
success=False,
failure_reason=str(e),
output_parameter_value=result_dict,
status=BlockStatus.failed,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
finally:
if smtp_host:
@ -1424,8 +1549,13 @@ class SendEmailBlock(Block):
result_dict = {"success": True}
await self.record_output_parameter_value(workflow_run_context, workflow_run_id, result_dict)
return self.build_block_result(
success=True, failure_reason=None, output_parameter_value=result_dict, status=BlockStatus.completed
return await self.build_block_result(
success=True,
failure_reason=None,
output_parameter_value=result_dict,
status=BlockStatus.completed,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
@ -1461,7 +1591,9 @@ class FileParserBlock(Block):
except csv.Error as e:
raise InvalidFileType(file_url=file_url_used, file_type=self.file_type, error=str(e))
async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
async def execute(
self, workflow_run_id: str, workflow_run_block_id: str, organization_id: str | None = None, **kwargs: dict
) -> BlockResult:
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
if (
self.file_url
@ -1480,11 +1612,13 @@ class FileParserBlock(Block):
try:
self.format_potential_template_parameters(workflow_run_context)
except Exception as e:
return self.build_block_result(
return await self.build_block_result(
success=False,
failure_reason=f"Failed to format jinja template: {str(e)}",
output_parameter_value=None,
status=BlockStatus.failed,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
# Download the file
@ -1503,8 +1637,13 @@ class FileParserBlock(Block):
parsed_data.append(row)
# Record the parsed data
await self.record_output_parameter_value(workflow_run_context, workflow_run_id, parsed_data)
return self.build_block_result(
success=True, failure_reason=None, output_parameter_value=parsed_data, status=BlockStatus.completed
return await self.build_block_result(
success=True,
failure_reason=None,
output_parameter_value=parsed_data,
status=BlockStatus.completed,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
@ -1520,7 +1659,9 @@ class WaitBlock(Block):
) -> list[PARAMETER_TYPE]:
return self.parameters
async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
async def execute(
self, workflow_run_id: str, workflow_run_block_id: str, organization_id: str | None = None, **kwargs: dict
) -> BlockResult:
# TODO: we need to support to interrupt the sleep when the workflow run failed/cancelled/terminated
LOG.info(
"Going to pause the workflow for a while",
@ -1531,8 +1672,13 @@ class WaitBlock(Block):
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
result_dict = {"success": True}
await self.record_output_parameter_value(workflow_run_context, workflow_run_id, result_dict)
return self.build_block_result(
success=True, failure_reason=None, output_parameter_value=result_dict, status=BlockStatus.completed
return await self.build_block_result(
success=True,
failure_reason=None,
output_parameter_value=result_dict,
status=BlockStatus.completed,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
@ -1545,18 +1691,27 @@ class ValidationBlock(BaseTaskBlock):
) -> list[PARAMETER_TYPE]:
return self.parameters
async def execute(
    self, workflow_run_id: str, workflow_run_block_id: str, organization_id: str | None = None, **kwargs: dict
) -> BlockResult:
    """Run the validation task, refusing to run as the first task of a workflow run.

    Args:
        workflow_run_id: Id of the workflow run this block belongs to.
        workflow_run_block_id: Id of the workflow run block row that the
            result is reported against.
        organization_id: Organization that owns the run, if known.
        **kwargs: Extra keyword arguments forwarded unchanged to the parent
            ``execute``.

    Returns:
        A terminated ``BlockResult`` when the task order is 0, otherwise
        whatever the parent ``execute`` returns.
    """
    # Only the order is needed; the second element of the tuple is ignored.
    task_order, _ = await self.get_task_order(workflow_run_id, 0)
    if task_order == 0:
        return await self.build_block_result(
            success=False,
            failure_reason="Validation block should not be the first block",
            output_parameter_value=None,
            status=BlockStatus.terminated,
            workflow_run_block_id=workflow_run_block_id,
            organization_id=organization_id,
        )

    # Unpack the extra keyword arguments. Passing ``kwargs=kwargs`` would hand
    # the parent a single keyword named "kwargs" holding the whole dict
    # instead of the individual arguments.
    return await super().execute(
        workflow_run_id=workflow_run_id,
        workflow_run_block_id=workflow_run_block_id,
        organization_id=organization_id,
        **kwargs,
    )
class ActionBlock(BaseTaskBlock):