Include downloaded files in the task block outputs within workflows so subsequent blocks can use them (#1797)

Shuchang Zheng 2025-02-20 01:19:03 -08:00 committed by GitHub
parent 367473f930
commit 167f219a3e
4 changed files with 50 additions and 18 deletions


@@ -7,7 +7,7 @@ import shutil
 import tempfile
 import zipfile
 from pathlib import Path
-from urllib.parse import urlparse
+from urllib.parse import unquote, urlparse
 
 import aiohttp
 import structlog
@@ -72,6 +72,15 @@ async def download_file(url: str, max_size_mb: int | None = None) -> str:
         client = AsyncAWSClient()
         return await download_from_s3(client, url)
 
+    # Check if URL is a file:// URI
+    # we only support to download local files when the environment is local
+    # and the file is in the skyvern downloads directory
+    if url.startswith("file://") and settings.ENV == "local":
+        file_path = parse_uri_to_path(url)
+        if file_path.startswith(f"{REPO_ROOT_DIR}/downloads"):
+            LOG.info("Downloading file from local file system", url=url)
+            return file_path
+
     async with aiohttp.ClientSession(raise_for_status=True) as session:
         LOG.info("Starting to download file", url=url)
         async with session.get(url) as response:
@@ -273,3 +282,11 @@ def clean_up_dir(dir: str) -> None:
 
 def clean_up_skyvern_temp_dir() -> None:
     return clean_up_dir(get_skyvern_temp_dir())
+
+
+def parse_uri_to_path(uri: str) -> str:
+    parsed_uri = urlparse(uri)
+    if parsed_uri.scheme != "file":
+        raise ValueError(f"Invalid URI scheme: {parsed_uri.scheme} expected: file")
+    path = parsed_uri.netloc + parsed_uri.path
+    return unquote(path)
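Taken together, download_file can now short-circuit for file:// URIs when running locally, as long as the decoded path stays under the repo's downloads directory. A minimal sketch of how the new parse_uri_to_path helper resolves such a URI; the repo root value and file names below are hypothetical, and REPO_ROOT_DIR stands in for the constant used above:

from urllib.parse import unquote, urlparse

REPO_ROOT_DIR = "/opt/skyvern"  # hypothetical stand-in for the real repo root constant

def parse_uri_to_path(uri: str) -> str:
    # same logic as the helper added in this commit
    parsed_uri = urlparse(uri)
    if parsed_uri.scheme != "file":
        raise ValueError(f"Invalid URI scheme: {parsed_uri.scheme} expected: file")
    path = parsed_uri.netloc + parsed_uri.path
    return unquote(path)

url = "file:///opt/skyvern/downloads/wr_123/invoice%201.pdf"
path = parse_uri_to_path(url)
assert path == "/opt/skyvern/downloads/wr_123/invoice 1.pdf"  # percent-encoding is decoded by unquote
assert path.startswith(f"{REPO_ROOT_DIR}/downloads")  # the guard download_file applies before returning the path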


@@ -2,12 +2,11 @@ import os
 import shutil
 from datetime import datetime
 from pathlib import Path
-from urllib.parse import unquote, urlparse
 
 import structlog
 
 from skyvern.config import settings
-from skyvern.forge.sdk.api.files import get_download_dir, get_skyvern_temp_dir
+from skyvern.forge.sdk.api.files import get_download_dir, get_skyvern_temp_dir, parse_uri_to_path
 from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityType
 from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENTSION_MAP, BaseStorage
 from skyvern.forge.sdk.models import Step
@@ -68,7 +67,7 @@ class LocalStorage(BaseStorage):
     async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
         file_path = None
         try:
-            file_path = Path(self._parse_uri_to_path(artifact.uri))
+            file_path = Path(parse_uri_to_path(artifact.uri))
             self._create_directories_if_not_exists(file_path)
             with open(file_path, "wb") as f:
                 f.write(data)
@@ -82,7 +81,7 @@ class LocalStorage(BaseStorage):
     async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None:
         file_path = None
         try:
-            file_path = Path(self._parse_uri_to_path(artifact.uri))
+            file_path = Path(parse_uri_to_path(artifact.uri))
             self._create_directories_if_not_exists(file_path)
             Path(path).replace(file_path)
         except Exception:
@@ -95,7 +94,7 @@ class LocalStorage(BaseStorage):
     async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
         file_path = None
         try:
-            file_path = self._parse_uri_to_path(artifact.uri)
+            file_path = parse_uri_to_path(artifact.uri)
             with open(file_path, "rb") as f:
                 return f.read()
         except Exception:
@@ -170,14 +169,6 @@ class LocalStorage(BaseStorage):
             files.append(f"file://{path}")
         return files
 
-    @staticmethod
-    def _parse_uri_to_path(uri: str) -> str:
-        parsed_uri = urlparse(uri)
-        if parsed_uri.scheme != "file":
-            raise ValueError("Invalid URI scheme: {parsed_uri.scheme} expected: file")
-        path = parsed_uri.netloc + parsed_uri.path
-        return unquote(path)
-
     @staticmethod
     def _create_directories_if_not_exists(path_including_file_name: Path) -> None:
         path = path_including_file_name.parent


@@ -352,15 +352,17 @@ class TaskOutput(BaseModel):
     extracted_information: list | dict[str, Any] | str | None = None
     failure_reason: str | None = None
     errors: list[dict[str, Any]] = []
+    downloaded_file_urls: list[str] | None = None
 
     @staticmethod
-    def from_task(task: Task) -> TaskOutput:
+    def from_task(task: Task, downloaded_file_urls: list[str] | None = None) -> TaskOutput:
         return TaskOutput(
             task_id=task.task_id,
             status=task.status,
             extracted_information=task.extracted_information,
             failure_reason=task.failure_reason,
             errors=task.errors,
+            downloaded_file_urls=downloaded_file_urls,
         )
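With the new field, the value a task block records as its output parameter (TaskOutput.model_dump()) carries the downloaded file URLs alongside the extracted data. A sketch of what that recorded value might look like; all IDs, values, and paths below are hypothetical:

task_output_value = {
    "task_id": "tsk_123",
    "status": "completed",
    "extracted_information": {"confirmation_number": "A-42"},
    "failure_reason": None,
    "errors": [],
    # new in this commit: subsequent workflow blocks can read these URLs
    "downloaded_file_urls": ["file:///opt/skyvern/downloads/wr_123/invoice 1.pdf"],
}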


@@ -25,7 +25,7 @@ from pypdf import PdfReader
 from pypdf.errors import PdfReadError
 
 from skyvern.config import settings
-from skyvern.constants import MAX_UPLOAD_FILE_COUNT
+from skyvern.constants import GET_DOWNLOADED_FILES_TIMEOUT, MAX_UPLOAD_FILE_COUNT
 from skyvern.exceptions import (
     ContextParameterValueNotFound,
     DisabledBlockExecutionError,
@@ -633,7 +633,18 @@ class BaseTaskBlock(Block):
                 organization_id=workflow_run.organization_id,
             )
             success = updated_task.status == TaskStatus.completed
-            task_output = TaskOutput.from_task(updated_task)
+            downloaded_file_urls = []
+            try:
+                async with asyncio.timeout(GET_DOWNLOADED_FILES_TIMEOUT):
+                    downloaded_file_urls = await app.STORAGE.get_downloaded_files(
+                        organization_id=workflow_run.organization_id,
+                        task_id=updated_task.task_id,
+                        workflow_run_id=workflow_run_id,
+                    )
+            except asyncio.TimeoutError:
+                LOG.warning("Timeout getting downloaded files", task_id=updated_task.task_id)
+            task_output = TaskOutput.from_task(updated_task, downloaded_file_urls)
             output_parameter_value = task_output.model_dump()
             await self.record_output_parameter_value(workflow_run_context, workflow_run_id, output_parameter_value)
             return await self.build_block_result(
@@ -682,7 +693,18 @@ class BaseTaskBlock(Block):
                 current_retry += 1
                 will_retry = current_retry <= self.max_retries
                 retry_message = f", retrying task {current_retry}/{self.max_retries}" if will_retry else ""
-                task_output = TaskOutput.from_task(updated_task)
+                downloaded_file_urls = []
+                try:
+                    async with asyncio.timeout(GET_DOWNLOADED_FILES_TIMEOUT):
+                        downloaded_file_urls = await app.STORAGE.get_downloaded_files(
+                            organization_id=workflow_run.organization_id,
+                            task_id=updated_task.task_id,
+                            workflow_run_id=workflow_run_id,
+                        )
+                except asyncio.TimeoutError:
+                    LOG.warning("Timeout getting downloaded files", task_id=updated_task.task_id)
+                task_output = TaskOutput.from_task(updated_task, downloaded_file_urls)
                 LOG.warning(
                     f"Task failed with status {updated_task.status}{retry_message}",
                     task_id=updated_task.task_id,
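Both call sites guard the storage lookup the same way: bound it with asyncio.timeout and fall back to an empty list so a slow storage backend cannot stall the block. A generic sketch of that pattern, where the storage object and the 30-second default are stand-ins for app.STORAGE and GET_DOWNLOADED_FILES_TIMEOUT:

import asyncio

async def get_downloaded_files_with_timeout(
    storage, organization_id: str, task_id: str, workflow_run_id: str, timeout_seconds: float = 30.0
) -> list[str]:
    downloaded_file_urls: list[str] = []
    try:
        # asyncio.timeout (Python 3.11+) cancels the awaited call if it overruns
        async with asyncio.timeout(timeout_seconds):
            downloaded_file_urls = await storage.get_downloaded_files(
                organization_id=organization_id,
                task_id=task_id,
                workflow_run_id=workflow_run_id,
            )
    except asyncio.TimeoutError:
        # mirror the block's behavior: continue with an empty list instead of failing
        pass
    return downloaded_file_urls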