Workflow: Output Parameters & Code Blocks (#117)

This commit is contained in:
Kerem Yilmaz 2024-03-21 17:16:56 -07:00 committed by GitHub
parent d2ca6ca792
commit 066c2302b5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 556 additions and 44 deletions

View file

@ -14,7 +14,12 @@ from skyvern.exceptions import (
from skyvern.forge import app
from skyvern.forge.sdk.schemas.tasks import TaskStatus
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, ContextParameter, WorkflowParameter
from skyvern.forge.sdk.workflow.models.parameter import (
PARAMETER_TYPE,
ContextParameter,
OutputParameter,
WorkflowParameter,
)
LOG = structlog.get_logger()
@ -22,12 +27,14 @@ LOG = structlog.get_logger()
class BlockType(StrEnum):
TASK = "task"
FOR_LOOP = "for_loop"
CODE = "code"
class Block(BaseModel, abc.ABC):
# Must be unique within workflow definition
label: str
block_type: BlockType
parent_block_id: str | None = None
next_block_id: str | None = None
output_parameter: OutputParameter | None = None
@classmethod
def get_subclasses(cls) -> tuple[type["Block"], ...]:
@ -38,7 +45,7 @@ class Block(BaseModel, abc.ABC):
return app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)
@abc.abstractmethod
async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any:
async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
pass
@abc.abstractmethod
@ -96,7 +103,7 @@ class TaskBlock(Block):
return order, retry + 1
async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any:
async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
current_retry = 0
# initial value for will_retry is True, so that the loop runs at least once
@ -158,6 +165,32 @@ class TaskBlock(Block):
raise UnexpectedTaskStatus(task_id=updated_task.task_id, status=updated_task.status)
if updated_task.status == TaskStatus.completed:
will_retry = False
LOG.info(
f"Task completed",
task_id=updated_task.task_id,
workflow_run_id=workflow_run_id,
workflow_id=workflow.workflow_id,
organization_id=workflow.organization_id,
)
if self.output_parameter:
await workflow_run_context.register_output_parameter_value_post_execution(
parameter=self.output_parameter,
value=updated_task.extracted_information,
)
await app.DATABASE.create_workflow_run_output_parameter(
workflow_run_id=workflow_run_id,
output_parameter_id=self.output_parameter.output_parameter_id,
value=updated_task.extracted_information,
)
LOG.info(
f"Registered output parameter value",
output_parameter_id=self.output_parameter.output_parameter_id,
value=updated_task.extracted_information,
workflow_run_id=workflow_run_id,
workflow_id=workflow.workflow_id,
task_id=updated_task.task_id,
)
return self.output_parameter
else:
current_retry += 1
will_retry = current_retry <= self.max_retries
@ -172,6 +205,7 @@ class TaskBlock(Block):
current_retry=current_retry,
max_retries=self.max_retries,
)
return None
class ForLoopBlock(Block):
@ -216,9 +250,10 @@ class ForLoopBlock(Block):
return [parameter_value]
else:
# TODO (kerem): Implement this for context parameters
# TODO (kerem): Implement this for output parameters
raise NotImplementedError
async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any:
async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
loop_over_values = self.get_loop_over_parameter_values(workflow_run_context)
LOG.info(
@ -227,14 +262,77 @@ class ForLoopBlock(Block):
workflow_run_id=workflow_run_id,
num_loop_over_values=len(loop_over_values),
)
outputs_with_loop_values = []
for loop_over_value in loop_over_values:
context_parameters_with_value = self.get_loop_block_context_parameters(workflow_run_id, loop_over_value)
for context_parameter in context_parameters_with_value:
workflow_run_context.set_value(context_parameter.key, context_parameter.value)
await self.loop_block.execute(workflow_run_id=workflow_run_id)
if self.loop_block.output_parameter:
outputs_with_loop_values.append(
{
"loop_value": loop_over_value,
"output_parameter": self.loop_block.output_parameter,
"output_value": workflow_run_context.get_value(self.loop_block.output_parameter.key),
}
)
if self.output_parameter:
await workflow_run_context.register_output_parameter_value_post_execution(
parameter=self.output_parameter,
value=outputs_with_loop_values,
)
await app.DATABASE.create_workflow_run_output_parameter(
workflow_run_id=workflow_run_id,
output_parameter_id=self.output_parameter.output_parameter_id,
value=outputs_with_loop_values,
)
return self.output_parameter
return None
BlockSubclasses = Union[ForLoopBlock, TaskBlock]
class CodeBlock(Block):
block_type: Literal[BlockType.CODE] = BlockType.CODE
code: str
parameters: list[PARAMETER_TYPE] = []
def get_all_parameters(
self,
) -> list[PARAMETER_TYPE]:
return self.parameters
async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
# get workflow run context
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
# get all parameters into a dictionary
parameter_values = {}
for parameter in self.parameters:
value = workflow_run_context.get_value(parameter.key)
secret_value = workflow_run_context.get_original_secret_value_or_none(value)
if secret_value is not None:
parameter_values[parameter.key] = secret_value
else:
parameter_values[parameter.key] = value
local_variables: dict[str, Any] = {}
exec(self.code, parameter_values, local_variables)
result = {"result": local_variables.get("result")}
if self.output_parameter:
await workflow_run_context.register_output_parameter_value_post_execution(
parameter=self.output_parameter,
value=result,
)
await app.DATABASE.create_workflow_run_output_parameter(
workflow_run_id=workflow_run_id,
output_parameter_id=self.output_parameter.output_parameter_id,
value=result,
)
return self.output_parameter
return None
BlockSubclasses = Union[ForLoopBlock, TaskBlock, CodeBlock]
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]