For loop block updates (#176)

This commit is contained in:
Kerem Yilmaz 2024-04-10 13:47:25 -07:00 committed by GitHub
parent 39d7d91938
commit 8c12e2bc20
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 35 additions and 15 deletions

View file

@ -282,20 +282,25 @@ class ForLoopBlock(Block):
# TODO (kerem): Add support for ContextParameter # TODO (kerem): Add support for ContextParameter
loop_over: PARAMETER_TYPE loop_over: PARAMETER_TYPE
loop_block: "BlockTypeVar" loop_blocks: list["BlockTypeVar"]
def get_all_parameters( def get_all_parameters(
self, self,
workflow_run_id: str, workflow_run_id: str,
) -> list[PARAMETER_TYPE]: ) -> list[PARAMETER_TYPE]:
return self.loop_block.get_all_parameters(workflow_run_id) + [self.loop_over] parameters = {self.loop_over}
for loop_block in self.loop_blocks:
for parameter in loop_block.get_all_parameters(workflow_run_id):
parameters.add(parameter)
return list(parameters)
def get_loop_block_context_parameters(self, workflow_run_id: str, loop_data: Any) -> list[ContextParameter]: def get_loop_block_context_parameters(self, workflow_run_id: str, loop_data: Any) -> list[ContextParameter]:
if not isinstance(loop_data, dict): if not isinstance(loop_data, dict):
# TODO (kerem): Should we add support for other types? # TODO (kerem): Should we add support for other types?
raise ValueError("loop_data should be a dictionary") raise ValueError("loop_data should be a dict")
loop_block_parameters = self.loop_block.get_all_parameters(workflow_run_id) loop_block_parameters = self.get_all_parameters(workflow_run_id)
context_parameters = [ context_parameters = [
parameter for parameter in loop_block_parameters if isinstance(parameter, ContextParameter) parameter for parameter in loop_block_parameters if isinstance(parameter, ContextParameter)
] ]
@ -332,28 +337,37 @@ class ForLoopBlock(Block):
num_loop_over_values=len(loop_over_values), num_loop_over_values=len(loop_over_values),
) )
outputs_with_loop_values = [] outputs_with_loop_values = []
block_outputs = []
for loop_over_value in loop_over_values: for loop_over_value in loop_over_values:
context_parameters_with_value = self.get_loop_block_context_parameters(workflow_run_id, loop_over_value) context_parameters_with_value = self.get_loop_block_context_parameters(workflow_run_id, loop_over_value)
for context_parameter in context_parameters_with_value: for context_parameter in context_parameters_with_value:
workflow_run_context.set_value(context_parameter.key, context_parameter.value) workflow_run_context.set_value(context_parameter.key, context_parameter.value)
try: try:
block_output = await self.loop_block.execute(workflow_run_id=workflow_run_id) block_outputs = [
block_outputs.append(block_output) await loop_block.execute(workflow_run_id=workflow_run_id) for loop_block in self.loop_blocks
]
except Exception as e: except Exception as e:
LOG.error("ForLoopBlock: Failed to execute loop block", exc_info=True) LOG.error("ForLoopBlock: Failed to execute loop block", exc_info=True)
raise e raise e
if block_output.output_parameter:
outputs_with_loop_values.append( outputs_with_loop_values.append(
[
{ {
"loop_value": loop_over_value, "loop_value": loop_over_value,
"output_parameter": block_output.output_parameter, "output_parameter": block_output.output_parameter,
"output_value": workflow_run_context.get_value(block_output.output_parameter.key), "output_value": workflow_run_context.get_value(block_output.output_parameter.key),
} }
for block_output in block_outputs
if block_output.output_parameter
]
) )
# If all block outputs are successful, the loop is successful # If all block outputs are successful, the loop is successful
success = all([block_output.success for block_output in block_outputs]) success = all([block_output.success for block_output in block_outputs])
if not success:
LOG.info(
"ForLoopBlock: Encountered an failure processing block, terminating early",
block_outputs=block_outputs,
)
break
if self.output_parameter: if self.output_parameter:
await workflow_run_context.register_output_parameter_value_post_execution( await workflow_run_context.register_output_parameter_value_post_execution(

View file

@ -21,6 +21,9 @@ class Parameter(BaseModel, abc.ABC):
key: str key: str
description: str | None = None description: str | None = None
def __hash__(self) -> int:
return hash(self.key)
@classmethod @classmethod
def get_subclasses(cls) -> tuple[type["Parameter"], ...]: def get_subclasses(cls) -> tuple[type["Parameter"], ...]:
return tuple(cls.__subclasses__()) return tuple(cls.__subclasses__())

View file

@ -95,7 +95,7 @@ class ForLoopBlockYAML(BlockYAML):
block_type: Literal[BlockType.FOR_LOOP] = BlockType.FOR_LOOP # type: ignore block_type: Literal[BlockType.FOR_LOOP] = BlockType.FOR_LOOP # type: ignore
loop_over_parameter_key: str loop_over_parameter_key: str
loop_block: "BLOCK_YAML_SUBCLASSES" loop_blocks: list["BLOCK_YAML_SUBCLASSES"]
class CodeBlockYAML(BlockYAML): class CodeBlockYAML(BlockYAML):

View file

@ -760,12 +760,15 @@ class WorkflowService:
max_retries=block_yaml.max_retries, max_retries=block_yaml.max_retries,
) )
elif block_yaml.block_type == BlockType.FOR_LOOP: elif block_yaml.block_type == BlockType.FOR_LOOP:
loop_block = await WorkflowService.block_yaml_to_block(block_yaml.loop_block, parameters) loop_blocks = [
await WorkflowService.block_yaml_to_block(loop_block, parameters)
for loop_block in block_yaml.loop_blocks
]
loop_over_parameter = parameters[block_yaml.loop_over_parameter_key] loop_over_parameter = parameters[block_yaml.loop_over_parameter_key]
return ForLoopBlock( return ForLoopBlock(
label=block_yaml.label, label=block_yaml.label,
loop_over=loop_over_parameter, loop_over=loop_over_parameter,
loop_block=loop_block, loop_blocks=loop_blocks,
output_parameter=output_parameter, output_parameter=output_parameter,
) )
elif block_yaml.block_type == BlockType.CODE: elif block_yaml.block_type == BlockType.CODE: