all blocks support jinja template (#1288)

This commit is contained in:
LawyZheng 2024-11-29 15:24:35 +08:00 committed by GitHub
parent f491b017d1
commit d697023994
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -154,10 +154,24 @@ class Block(BaseModel, abc.ABC):
def get_async_aws_client() -> AsyncAWSClient:
return app.WORKFLOW_CONTEXT_MANAGER.aws_client
@staticmethod
def format_block_parameter_template_from_workflow_run_context(
potential_template: str, workflow_run_context: WorkflowRunContext
) -> str:
if not potential_template:
return potential_template
template = Template(potential_template)
return template.render(workflow_run_context.values)
@abc.abstractmethod
async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
    """Run this block for the given workflow run.

    Declared abstract: every concrete block has to supply its own implementation.
    """
def format_potential_template_parameters(self, workflow_run_context: WorkflowRunContext) -> None:
    """Render the block's ``label`` in place against the workflow run context.

    Subclasses extend this to render their own template-capable fields.
    """
    rendered_label = self.format_block_parameter_template_from_workflow_run_context(
        self.label, workflow_run_context
    )
    self.label = rendered_label
async def execute_safe(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
try:
return await self.execute(workflow_run_id, **kwargs)
@ -219,14 +233,48 @@ class BaseTaskBlock(Block):
return parameters
@staticmethod
def format_task_block_parameter_template_from_workflow_run_context(
potential_template: str | None, workflow_run_context: WorkflowRunContext
) -> str | None:
if not potential_template:
return potential_template
template = Template(potential_template)
return template.render(workflow_run_context.values)
def format_potential_template_parameters(self, workflow_run_context: WorkflowRunContext) -> None:
    """Render every template-capable field of the task block in place.

    The base implementation runs first, then ``title`` is always passed through
    the renderer. The remaining optional fields are rendered only when they
    currently hold a truthy value, so unset fields are never reassigned.
    """
    super().format_potential_template_parameters(workflow_run_context=workflow_run_context)

    render = self.format_block_parameter_template_from_workflow_run_context
    self.title = render(self.title, workflow_run_context)

    # Order matches the original sequence of assignments.
    optional_fields = (
        "url",
        "totp_identifier",
        "totp_verification_url",
        "download_suffix",
        "navigation_goal",
        "data_extraction_goal",
        "complete_criterion",
        "terminate_criterion",
    )
    for field_name in optional_fields:
        current_value = getattr(self, field_name)
        if not current_value:
            continue
        setattr(self, field_name, render(current_value, workflow_run_context))
@staticmethod
async def get_task_order(workflow_run_id: str, current_retry: int) -> tuple[int, int]:
@ -301,26 +349,7 @@ class BaseTaskBlock(Block):
)
self.download_suffix = download_suffix_parameter_value
self.url = self.format_task_block_parameter_template_from_workflow_run_context(self.url, workflow_run_context)
self.totp_identifier = self.format_task_block_parameter_template_from_workflow_run_context(
self.totp_identifier, workflow_run_context
)
self.download_suffix = self.format_task_block_parameter_template_from_workflow_run_context(
self.download_suffix, workflow_run_context
)
self.navigation_goal = self.format_task_block_parameter_template_from_workflow_run_context(
self.navigation_goal, workflow_run_context
)
self.data_extraction_goal = self.format_task_block_parameter_template_from_workflow_run_context(
self.data_extraction_goal, workflow_run_context
)
self.complete_criterion = self.format_task_block_parameter_template_from_workflow_run_context(
self.complete_criterion, workflow_run_context
)
self.terminate_criterion = self.format_task_block_parameter_template_from_workflow_run_context(
self.terminate_criterion, workflow_run_context
)
self.format_potential_template_parameters(workflow_run_context=workflow_run_context)
# TODO (kerem) we should always retry on terminated. We should make a distinction between retriable and
# non-retryable terminations
while will_retry:
@ -698,6 +727,7 @@ class ForLoopBlock(Block):
async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
self.format_potential_template_parameters(workflow_run_context)
loop_over_values = self.get_loop_over_parameter_values(workflow_run_context)
LOG.info(
f"Number of loop_over values: {len(loop_over_values)}",
@ -772,10 +802,15 @@ class CodeBlock(Block):
) -> list[PARAMETER_TYPE]:
return self.parameters
def format_potential_template_parameters(self, workflow_run_context: WorkflowRunContext) -> None:
    """Render the base fields, then render ``code`` in place as a template."""
    super().format_potential_template_parameters(workflow_run_context)
    rendered_code = self.format_block_parameter_template_from_workflow_run_context(
        potential_template=self.code, workflow_run_context=workflow_run_context
    )
    self.code = rendered_code
async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
raise DisabledBlockExecutionError("CodeBlock is disabled")
# get workflow run context
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
self.format_potential_template_parameters(workflow_run_context)
# get all parameters into a dictionary
parameter_values = {}
@ -836,6 +871,13 @@ class TextPromptBlock(Block):
) -> list[PARAMETER_TYPE]:
return self.parameters
def format_potential_template_parameters(self, workflow_run_context: WorkflowRunContext) -> None:
    """Render the base fields, then ``llm_key`` and ``prompt`` in place as templates."""
    super().format_potential_template_parameters(workflow_run_context)

    render = self.format_block_parameter_template_from_workflow_run_context
    # llm_key is rendered before prompt, matching the original assignment order.
    for field_name in ("llm_key", "prompt"):
        setattr(self, field_name, render(getattr(self, field_name), workflow_run_context))
async def send_prompt(self, prompt: str, parameter_values: dict[str, Any]) -> dict[str, Any]:
llm_key = self.llm_key or DEFAULT_TEXT_PROMPT_LLM_KEY
llm_api_handler = LLMAPIHandlerFactory.get_llm_api_handler(llm_key)
@ -870,6 +912,7 @@ class TextPromptBlock(Block):
async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
# get workflow run context
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
self.format_potential_template_parameters(workflow_run_context)
# get all parameters into a dictionary
parameter_values = {}
for parameter in self.parameters:
@ -903,6 +946,10 @@ class DownloadToS3Block(Block):
return []
def format_potential_template_parameters(self, workflow_run_context: WorkflowRunContext) -> None:
    """Render the base fields, then render ``url`` in place as a template."""
    super().format_potential_template_parameters(workflow_run_context)
    rendered_url = self.format_block_parameter_template_from_workflow_run_context(
        potential_template=self.url, workflow_run_context=workflow_run_context
    )
    self.url = rendered_url
async def _upload_file_to_s3(self, uri: str, file_path: str) -> None:
try:
client = self.get_async_aws_client()
@ -925,6 +972,8 @@ class DownloadToS3Block(Block):
)
self.url = task_url_parameter_value
self.format_potential_template_parameters(workflow_run_context)
try:
file_path = await download_file(self.url, max_size_mb=10)
except Exception as e:
@ -963,6 +1012,11 @@ class UploadToS3Block(Block):
return []
def format_potential_template_parameters(self, workflow_run_context: WorkflowRunContext) -> None:
    """Render the base fields, then render ``path`` in place when it is set."""
    super().format_potential_template_parameters(workflow_run_context)

    # Nothing to render when no path has been configured.
    if not self.path:
        return

    self.path = self.format_block_parameter_template_from_workflow_run_context(
        potential_template=self.path, workflow_run_context=workflow_run_context
    )
@staticmethod
def _get_s3_uri(workflow_run_id: str, path: str) -> str:
s3_bucket = SettingsManager.get_settings().AWS_S3_BUCKET_UPLOADS
@ -986,6 +1040,7 @@ class UploadToS3Block(Block):
elif self.path == SettingsManager.get_settings().WORKFLOW_DOWNLOAD_DIRECTORY_PARAMETER_KEY:
self.path = str(get_path_for_workflow_download_directory(workflow_run_id).absolute())
self.format_potential_template_parameters(workflow_run_context)
if not self.path or not os.path.exists(self.path):
raise FileNotFoundError(f"UploadToS3Block: File not found at path: {self.path}")
@ -1061,6 +1116,16 @@ class SendEmailBlock(Block):
return parameters
def format_potential_template_parameters(self, workflow_run_context: WorkflowRunContext) -> None:
    """Render the base fields, then ``sender``, ``subject`` and ``body`` in place.

    Attachments and recipients are deliberately left untouched here.
    """
    super().format_potential_template_parameters(workflow_run_context)

    render = self.format_block_parameter_template_from_workflow_run_context
    for field_name in ("sender", "subject", "body"):
        setattr(self, field_name, render(getattr(self, field_name), workflow_run_context))
    # file_attachments are formatted in _get_file_paths()
    # recipients are formatted in get_real_email_recipients()
def _decrypt_smtp_parameters(self, workflow_run_context: WorkflowRunContext) -> tuple[str, int, str, str]:
obfuscated_smtp_host_value = workflow_run_context.get_value(self.smtp_host.key)
obfuscated_smtp_port_value = workflow_run_context.get_value(self.smtp_port.key)
@ -1117,6 +1182,7 @@ class SendEmailBlock(Block):
file_path=path,
)
path = self.format_block_parameter_template_from_workflow_run_context(path, workflow_run_context)
# if the file path is a directory, add all files in the directory, skip directories, limit to 10 files
if os.path.exists(path):
if os.path.isdir(path):
@ -1157,6 +1223,7 @@ class SendEmailBlock(Block):
else:
maybe_recipient = recipient
recipient = self.format_block_parameter_template_from_workflow_run_context(recipient, workflow_run_context)
# check if maybe_recipient is a valid email address
try:
validate_email(maybe_recipient)
@ -1269,6 +1336,7 @@ class SendEmailBlock(Block):
async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
self.format_potential_template_parameters(workflow_run_context)
smtp_host_value, smtp_port_value, smtp_username_value, smtp_password_value = self._decrypt_smtp_parameters(
workflow_run_context
)
@ -1320,6 +1388,12 @@ class FileParserBlock(Block):
return [workflow_run_context.get_parameter(self.file_url)]
return []
def format_potential_template_parameters(self, workflow_run_context: WorkflowRunContext) -> None:
    """Render the base fields, then render ``file_url`` in place as a template."""
    super().format_potential_template_parameters(workflow_run_context)
    rendered_file_url = self.format_block_parameter_template_from_workflow_run_context(
        potential_template=self.file_url, workflow_run_context=workflow_run_context
    )
    self.file_url = rendered_file_url
def validate_file_type(self, file_url_used: str, file_path: str) -> None:
if self.file_type == FileType.CSV:
try:
@ -1330,7 +1404,6 @@ class FileParserBlock(Block):
async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
file_url_to_use = self.file_url
if (
self.file_url
and workflow_run_context.has_parameter(self.file_url)
@ -1343,15 +1416,17 @@ class FileParserBlock(Block):
file_url_parameter_value=file_url_parameter_value,
file_url_parameter_key=self.file_url,
)
file_url_to_use = file_url_parameter_value
self.file_url = file_url_parameter_value
self.format_potential_template_parameters(workflow_run_context)
# Download the file
if file_url_to_use.startswith("s3://"):
file_path = await download_from_s3(self.get_async_aws_client(), file_url_to_use)
if self.file_url.startswith("s3://"):
file_path = await download_from_s3(self.get_async_aws_client(), self.file_url)
else:
file_path = await download_file(file_url_to_use)
file_path = await download_file(self.file_url)
# Validate the file type
self.validate_file_type(file_url_to_use, file_path)
self.validate_file_type(self.file_url, file_path)
# Parse the file into a list of dictionaries where each dictionary represents a row in the file
parsed_data = []
with open(file_path, "r") as file: