Sync cloud skyvern to oss skyvern (#55)

This commit is contained in:
Kerem Yilmaz 2024-03-12 22:28:16 -07:00 committed by GitHub
parent 647ea2ac0f
commit 15d78d7b08
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 554 additions and 163 deletions

View file

@ -1,7 +1,9 @@
import uuid
from typing import TYPE_CHECKING, Any
import structlog
from skyvern.exceptions import WorkflowRunContextNotInitialized
from skyvern.forge.sdk.api.aws import AsyncAWSClient
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, Parameter, ParameterType, WorkflowParameter
@ -12,15 +14,15 @@ if TYPE_CHECKING:
LOG = structlog.get_logger()
class ContextManager:
aws_client: AsyncAWSClient
class WorkflowRunContext:
parameters: dict[str, PARAMETER_TYPE]
values: dict[str, Any]
secrets: dict[str, Any]
def __init__(self, workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]]) -> None:
self.aws_client = AsyncAWSClient()
self.parameters = {}
self.values = {}
self.secrets = {}
for parameter, run_parameter in workflow_parameter_tuples:
if parameter.key in self.parameters:
prev_value = self.parameters[parameter.key]
@ -32,8 +34,33 @@ class ContextManager:
self.parameters[parameter.key] = parameter
self.values[parameter.key] = run_parameter.value
def get_parameter(self, key: str) -> Parameter:
return self.parameters[key]
def get_value(self, key: str) -> Any:
"""
Get the value of a parameter. If the parameter is an AWS secret, the value will be the random secret id, not
the actual secret value. This will be used when building the navigation payload since we don't want to expose
the actual secret value in the payload.
"""
return self.values[key]
def set_value(self, key: str, value: Any) -> None:
self.values[key] = value
def get_original_secret_value_or_none(self, secret_id: str) -> Any:
"""
Get the original secret value from the secrets dict. If the secret id is not found, return None.
"""
return self.secrets.get(secret_id)
@staticmethod
def generate_random_secret_id() -> str:
return f"secret_{uuid.uuid4()}"
async def register_parameter_value(
self,
aws_client: AsyncAWSClient,
parameter: PARAMETER_TYPE,
) -> None:
if parameter.parameter_type == ParameterType.WORKFLOW:
@ -42,15 +69,21 @@ class ContextManager:
f"Workflow parameters are set while initializing context manager. Parameter key: {parameter.key}"
)
elif parameter.parameter_type == ParameterType.AWS_SECRET:
secret_value = await self.aws_client.get_secret(parameter.aws_key)
# If the parameter is an AWS secret, fetch the secret value and store it in the secrets dict
# The value of the parameter will be the random secret id with format `secret_<uuid>`.
# We'll replace the random secret id with the actual secret value when we need to use it.
secret_value = await aws_client.get_secret(parameter.aws_key)
if secret_value is not None:
self.values[parameter.key] = secret_value
random_secret_id = self.generate_random_secret_id()
self.secrets[random_secret_id] = secret_value
self.values[parameter.key] = random_secret_id
else:
# ContextParameter values will be set within the blocks
return None
async def register_block_parameters(
self,
aws_client: AsyncAWSClient,
parameters: list[PARAMETER_TYPE],
) -> None:
for parameter in parameters:
@ -67,13 +100,41 @@ class ContextManager:
)
self.parameters[parameter.key] = parameter
await self.register_parameter_value(parameter)
await self.register_parameter_value(aws_client, parameter)
def get_parameter(self, key: str) -> Parameter:
return self.parameters[key]
def get_value(self, key: str) -> Any:
return self.values[key]
class WorkflowContextManager:
aws_client: AsyncAWSClient
workflow_run_contexts: dict[str, WorkflowRunContext]
def set_value(self, key: str, value: Any) -> None:
self.values[key] = value
parameters: dict[str, PARAMETER_TYPE]
values: dict[str, Any]
secrets: dict[str, Any]
def __init__(self) -> None:
self.aws_client = AsyncAWSClient()
self.workflow_run_contexts = {}
def _validate_workflow_run_context(self, workflow_run_id: str) -> None:
if workflow_run_id not in self.workflow_run_contexts:
LOG.error(f"WorkflowRunContext not initialized for workflow run {workflow_run_id}")
raise WorkflowRunContextNotInitialized(workflow_run_id=workflow_run_id)
def initialize_workflow_run_context(
self, workflow_run_id: str, workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]]
) -> WorkflowRunContext:
workflow_run_context = WorkflowRunContext(workflow_parameter_tuples)
self.workflow_run_contexts[workflow_run_id] = workflow_run_context
return workflow_run_context
def get_workflow_run_context(self, workflow_run_id: str) -> WorkflowRunContext:
self._validate_workflow_run_context(workflow_run_id)
return self.workflow_run_contexts[workflow_run_id]
async def register_block_parameters_for_workflow_run(
self,
workflow_run_id: str,
parameters: list[PARAMETER_TYPE],
) -> None:
self._validate_workflow_run_context(workflow_run_id)
await self.workflow_run_contexts[workflow_run_id].register_block_parameters(self.aws_client, parameters)