Ykeremy/context parameter source parameters (#200)

This commit is contained in:
Kerem Yilmaz 2024-04-16 15:41:44 -07:00 committed by GitHub
parent 02cf2a1e87
commit 4a3e897dad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 159 additions and 27 deletions

View file

@ -20,7 +20,10 @@ from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
from skyvern.forge.sdk.workflow.exceptions import WorkflowDefinitionHasDuplicateParameterKeys
from skyvern.forge.sdk.workflow.exceptions import (
ContextParameterSourceNotDefined,
WorkflowDefinitionHasDuplicateParameterKeys,
)
from skyvern.forge.sdk.workflow.models.block import (
BlockResult,
BlockType,
@ -34,6 +37,7 @@ from skyvern.forge.sdk.workflow.models.block import (
UploadToS3Block,
)
from skyvern.forge.sdk.workflow.models.parameter import (
PARAMETER_TYPE,
AWSSecretParameter,
ContextParameter,
OutputParameter,
@ -145,11 +149,17 @@ class WorkflowService:
# Set workflow run status to running, create workflow run parameters
await self.mark_workflow_run_as_running(workflow_run_id=workflow_run.workflow_run_id)
# Get all context parameters from the workflow definition
context_parameters = [
parameter
for parameter in workflow.workflow_definition.parameters
if isinstance(parameter, ContextParameter)
]
# Get all <workflow parameter, workflow run parameter> tuples
wp_wps_tuples = await self.get_workflow_run_parameter_tuples(workflow_run_id=workflow_run.workflow_run_id)
workflow_output_parameters = await self.get_workflow_output_parameters(workflow_id=workflow.workflow_id)
app.WORKFLOW_CONTEXT_MANAGER.initialize_workflow_run_context(
workflow_run_id, wp_wps_tuples, workflow_output_parameters
workflow_run_id, wp_wps_tuples, workflow_output_parameters, context_parameters
)
# Execute workflow blocks
blocks = workflow.workflow_definition.blocks
@ -649,10 +659,10 @@ class WorkflowService:
organization_id=organization_id,
title=request.title,
description=request.description,
workflow_definition=WorkflowDefinition(blocks=[]),
workflow_definition=WorkflowDefinition(parameters=[], blocks=[]),
)
# Create parameters from the request
parameters = {}
parameters: dict[str, PARAMETER_TYPE] = {}
duplicate_parameter_keys = set()
# We're going to process context parameters after other parameters since they depend on the other parameters
@ -701,10 +711,23 @@ class WorkflowService:
# Now we can process the context parameters since all other parameters have been created
for context_parameter in context_parameter_yamls:
if context_parameter.source_parameter_key not in parameters:
raise ContextParameterSourceNotDefined(
context_parameter_key=context_parameter.key, source_key=context_parameter.source_parameter_key
)
if context_parameter.key in parameters:
LOG.error(f"Duplicate parameter key {context_parameter.key}")
duplicate_parameter_keys.add(context_parameter.key)
continue
# We're only adding the context parameter to the parameters dict, we're not creating it in the database
# It'll only be stored in the `workflow.workflow_definition`
# todo (kerem): should we have a database table for context parameters?
parameters[context_parameter.key] = ContextParameter(
key=context_parameter.key,
description=context_parameter.description,
source=parameters[context_parameter.source_workflow_parameter_key],
source=parameters[context_parameter.source_parameter_key],
# Context parameters don't have a default value, the value always depends on the source parameter
value=None,
)
@ -720,7 +743,7 @@ class WorkflowService:
block_label_mapping[block.label] = block
# Set the blocks for the workflow definition
workflow_definition = WorkflowDefinition(blocks=blocks)
workflow_definition = WorkflowDefinition(parameters=parameters.values(), blocks=blocks)
workflow = await self.update_workflow(
workflow_id=workflow.workflow_id,
workflow_definition=workflow_definition,