Skyvern/skyvern/forge/sdk/api/files.py

import asyncio
import hashlib
import mimetypes
import os
import re
import shutil
import tempfile
import zipfile
from pathlib import Path
from urllib.parse import parse_qsl, unquote, urlparse

import aiohttp
import structlog
from multidict import CIMultiDictProxy
from yarl import URL

from skyvern.config import settings
from skyvern.constants import BROWSER_DOWNLOAD_TIMEOUT, BROWSER_DOWNLOADING_SUFFIX, REPO_ROOT_DIR
from skyvern.exceptions import DownloadFileMaxSizeExceeded, DownloadFileMaxWaitingTime
from skyvern.forge import app
from skyvern.utils.url_validators import encode_url

LOG = structlog.get_logger()


def get_file_name_and_suffix_from_headers(headers: CIMultiDictProxy[str] | dict[str, str]) -> tuple[str, str]:
    file_stem = ""
    file_suffix: str | None = ""

    # retrieve the stem and suffix from Content-Disposition
    content_disposition = headers.get("Content-Disposition")
    if content_disposition:
        filename = re.findall('filename="(.+)"', content_disposition, re.IGNORECASE)
        if len(filename) > 0:
            file_stem = Path(filename[0]).stem
            file_suffix = Path(filename[0]).suffix
        if file_suffix:
            return file_stem, file_suffix

    # retrieve the suffix from Content-Type
    content_type = headers.get("Content-Type")
    if content_type:
        if file_suffix := mimetypes.guess_extension(content_type):
            return file_stem, file_suffix

    return file_stem, file_suffix or ""
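
# A minimal usage sketch for the header parsing above (the header values are
# illustrative, not taken from a real response):
#
#   get_file_name_and_suffix_from_headers(
#       {"Content-Disposition": 'attachment; filename="report.pdf"'}
#   )  # -> ("report", ".pdf")
#   get_file_name_and_suffix_from_headers({"Content-Type": "text/csv"})
#   # -> ("", ".csv"), via mimetypes.guess_extension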


def extract_google_drive_file_id(url: str) -> str | None:
    """Extract file ID from Google Drive URL."""
    # Handle format: https://drive.google.com/file/d/{file_id}/view
    match = re.search(r"/file/d/([a-zA-Z0-9_-]+)", url)
    if match:
        return match.group(1)
    return None
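
# For illustration, with a made-up file ID (real IDs are opaque alphanumeric
# strings):
#
#   extract_google_drive_file_id("https://drive.google.com/file/d/1AbC-xYz_9/view")
#   # -> "1AbC-xYz_9"
#   extract_google_drive_file_id("https://drive.google.com/drive/my-drive")
#   # -> None (no /file/d/ segment)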


def is_valid_mime_type(file_path: str) -> bool:
    mime_type, _ = mimetypes.guess_type(file_path)
    return mime_type is not None


def _determine_download_filename(
    filename: str | None,
    response_headers: dict,
    url: str,
) -> str:
    """Determine the filename for a downloaded file."""
    if filename:
        file_name = filename
        if not os.path.splitext(file_name)[1]:
            content_type = response_headers.get("Content-Type", "")
            if content_type:
                ext = mimetypes.guess_extension(content_type.split(";")[0].strip())
                if ext:
                    file_name = file_name + ext
        return sanitize_filename(file_name)

    file_name = ""
    file_suffix = ""
    try:
        file_name, file_suffix = get_file_name_and_suffix_from_headers(response_headers)
        if not file_suffix:
            LOG.warning("No extension name retrieved from HTTP headers")
    except Exception:
        LOG.exception("Failed to retrieve the file extension from HTTP headers")

    query_params = dict(parse_qsl(urlparse(url).query))
    if "download" in query_params:
        file_name = query_params["download"]
    if not file_name:
        LOG.info("No file name retrieved from HTTP headers, using the file name from the URL")
        file_name = os.path.basename(urlparse(url).path) or "download"
    if not is_valid_mime_type(file_name) and file_suffix:
        LOG.info("No file extension detected, adding the extension from HTTP headers")
        file_name = file_name + file_suffix
    return sanitize_filename(file_name)


def validate_download_url(url: str, organization_id: str | None = None) -> bool:
    """Validate whether a URL is supported for downloading.

    Security validation for URL downloads to prevent:
    - File system access outside allowed directories
    - Access to the local file system in non-local environments
    - Unsupported or dangerous URL schemes

    Args:
        url: The URL to validate.
        organization_id: Required to authorize s3:// and azure:// URIs.

    Returns:
        True if valid, False otherwise.
    """
    try:
        parsed_url = urlparse(url)
        scheme = parsed_url.scheme.lower()

        # Allow http/https URLs (includes Google Drive, which uses https)
        if scheme in ("http", "https"):
            return True

        # Allow managed cloud storage URIs only for the owning organization
        if scheme in ("s3", "azure"):
            try:
                if organization_id is None:
                    return False
                app.STORAGE.assert_managed_file_access(url, organization_id)
                return True
            except (PermissionError, RuntimeError):
                return False

        # Allow file:// URLs only in the local environment
        if scheme == "file":
            if settings.ENV != "local":
                return False
            # Validate that the file path is within allowed directories
            try:
                file_path = parse_uri_to_path(url)
                allowed_prefix = f"{REPO_ROOT_DIR}/downloads"
                if not file_path.startswith(allowed_prefix):
                    return False
                return True
            except ValueError:
                return False

        # Reject unsupported schemes
        return False
    except Exception:
        return False
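
# A few illustrative outcomes (the example.com URLs are placeholders):
#
#   validate_download_url("https://example.com/report.pdf")  # -> True
#   validate_download_url("s3://bucket/key")                 # -> False (no organization_id)
#   validate_download_url("ftp://example.com/file")          # -> False (unsupported scheme)
#   validate_download_url("file:///etc/passwd")              # -> False (outside the downloads dir)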


async def download_file(
    url: str,
    max_size_mb: int | None = None,
    headers: dict[str, str] | None = None,
    output_dir: str | None = None,
    filename: str | None = None,
    organization_id: str | None = None,
) -> str:
    if not url or not url.strip():
        raise ValueError("Download URL is empty; no file download was triggered by the browser")
    try:
        # Check if URL is a Google Drive link
        if "drive.google.com" in url:
            file_id = extract_google_drive_file_id(url)
            if file_id:
                # Convert to a direct download URL
                url = f"https://drive.google.com/uc?export=download&id={file_id}"
                LOG.info("Converting Google Drive link to direct download", url=url)

        # Check if URL is a cloud storage URI handled by the configured storage backend
        parsed = urlparse(url)
        if parsed.scheme in ("s3", "azure"):
            if organization_id is None:
                raise PermissionError(f"No permission to access storage URI: {url}")
            app.STORAGE.assert_managed_file_access(url, organization_id)
            LOG.info(
                "Downloading managed storage file",
                url=url,
                organization_id=organization_id,
                storage_type=getattr(app.STORAGE, "storage_type", None),
            )
            data = await app.STORAGE.download_managed_file(url, organization_id)
            if data is None:
                raise Exception(f"Failed to download managed storage file: {url}")
            filename = url.split("/")[-1]
            temp_file = create_named_temporary_file(delete=False, file_name=filename)
            LOG.info(f"Downloaded file to {temp_file.name}")
            temp_file.write(data)
            return temp_file.name

        # Check if URL is a file:// URI. We only support downloading local files
        # when the environment is local and the file is in the skyvern downloads directory.
        if url.startswith("file://") and settings.ENV == "local":
            local_path = parse_uri_to_path(url)
            if local_path.startswith(f"{REPO_ROOT_DIR}/downloads"):
                LOG.info("Downloading file from local file system", url=url)
                return local_path

        async with aiohttp.ClientSession(raise_for_status=True) as session:
            LOG.info("Starting to download file", url=url)
            encoded_url = encode_url(url)
            async with session.get(URL(encoded_url, encoded=True), headers=headers) as response:
                # Check the content length if available
                if max_size_mb and response.content_length and response.content_length > max_size_mb * 1024 * 1024:
                    # TODO: move to root exception.py
                    raise DownloadFileMaxSizeExceeded(max_size_mb)

                # Resolve the download directory
                if output_dir:
                    os.makedirs(output_dir, exist_ok=True)
                    download_dir = output_dir
                else:
                    download_dir = make_temp_directory(prefix="skyvern_downloads_")

                # Determine the filename: use the provided one or derive it from the response/URL
                file_name = _determine_download_filename(filename, dict(response.headers), url)
                file_path = os.path.join(download_dir, file_name)
                LOG.info(f"Downloading file to {file_path}")
                with open(file_path, "wb") as f:
                    # Stream the response body into the file, enforcing the size limit
                    total_bytes_downloaded = 0
                    async for chunk in response.content.iter_chunked(1024):
                        f.write(chunk)
                        total_bytes_downloaded += len(chunk)
                        if max_size_mb and total_bytes_downloaded > max_size_mb * 1024 * 1024:
                            raise DownloadFileMaxSizeExceeded(max_size_mb)
                LOG.info(f"File downloaded successfully to {file_path}")
                return file_path
    except aiohttp.ClientResponseError as e:
        LOG.exception(f"Failed to download file, status code: {e.status}")
        raise
    except DownloadFileMaxSizeExceeded as e:
        LOG.exception(f"Failed to download file, max size exceeded: {e.max_size}")
        raise
    except PermissionError as e:
        LOG.warning(
            "Rejected storage URI download",
            url=url,
            organization_id=organization_id,
            reason=str(e),
        )
        raise
    except Exception:
        LOG.exception("Failed to download file")
        raise
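
# A minimal usage sketch (the URL and directory are hypothetical; download_file
# is a coroutine, so it must run inside an event loop):
#
#   file_path = await download_file(
#       "https://example.com/data.csv",
#       max_size_mb=10,
#       output_dir="/tmp/skyvern_example",
#   )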


def zip_files(files_path: str, zip_file_path: str) -> str:
    with zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_DEFLATED) as zipf:
        for root, dirs, files in os.walk(files_path):
            for file in files:
                file_path = os.path.join(root, file)
                arcname = os.path.relpath(file_path, files_path)  # Relative path within the zip
                zipf.write(file_path, arcname)

    return zip_file_path


def unzip_files(zip_file_path: str, output_dir: str) -> None:
    with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
        zip_ref.extractall(output_dir)
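
# zip_files/unzip_files round trip, with hypothetical paths:
#
#   zip_path = zip_files("/tmp/artifacts", "/tmp/artifacts.zip")
#   unzip_files(zip_path, "/tmp/restored")  # restores the same relative layout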


_REMOTE_URL_PREFIXES = ("http://", "https://", "s3://", "azure://", "www.")


def is_remote_url(path: str) -> bool:
    """Return True if the path is a remote URL (HTTP, S3, Azure) rather than a local filesystem path."""
    return path.startswith(_REMOTE_URL_PREFIXES)


def validate_local_file_path(candidate_path: str, run_id: str | None) -> str:
    """Validate that a local file path is within the workflow's download directory.

    Uses os.path.realpath() to resolve symlinks and '..' traversal before checking
    containment. Raises PermissionError if the path resolves outside the allowed directory.
    Returns the resolved canonical path on success.
    """
    if run_id is None:
        raise PermissionError("File access denied: no workflow run ID provided")
    if not candidate_path:
        LOG.warning("Empty path provided for file access validation", run_id=run_id)
        raise PermissionError(f"File access denied: path must not be empty for run {run_id}")

    allowed_dir = os.path.realpath(os.path.join(settings.DOWNLOAD_PATH, str(run_id)))
    resolved = os.path.realpath(candidate_path)

    # The resolved path must be the allowed dir itself or a child of it
    if resolved != allowed_dir and not resolved.startswith(allowed_dir + os.sep):
        LOG.warning(
            "Path traversal attempt blocked",
            candidate_path=candidate_path,
            resolved_path=resolved,
            allowed_dir=allowed_dir,
            run_id=run_id,
        )
        raise PermissionError(f"File access denied: path is outside the allowed download directory for run {run_id}")
    return resolved
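
# Illustrative behavior, assuming settings.DOWNLOAD_PATH == "/data/downloads"
# and no symlinks (the run ID and paths below are made up):
#
#   validate_local_file_path("/data/downloads/run_1/report.pdf", "run_1")
#   # -> "/data/downloads/run_1/report.pdf"
#   validate_local_file_path("/data/downloads/run_1/../run_2/x", "run_1")
#   # -> raises PermissionError: '..' resolves outside run_1's directory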


def get_path_for_workflow_download_directory(run_id: str | None) -> Path:
    return Path(get_download_dir(run_id=run_id))


def get_download_dir(run_id: str | None) -> str:
    download_dir = os.path.join(settings.DOWNLOAD_PATH, str(run_id))
    os.makedirs(download_dir, exist_ok=True)
    return download_dir


def list_files_in_directory(directory: Path, recursive: bool = False) -> list[str]:
    listed_files: list[str] = []
    for root, dirs, files in os.walk(directory):
        listed_files.extend([os.path.join(root, file) for file in files])
        if not recursive:
            break
    return listed_files


def list_downloading_files_in_directory(
    directory: Path, downloading_suffix: str = BROWSER_DOWNLOADING_SUFFIX
) -> list[str]:
    # Check whether any file in the directory is still downloading
    downloading_files: list[str] = []
    for file in list_files_in_directory(directory):
        path = Path(file)
        if path.suffix == downloading_suffix:
            downloading_files.append(file)
    return downloading_files
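
# For example, if BROWSER_DOWNLOADING_SUFFIX were ".crdownload" (Chromium's
# in-progress marker; the actual value lives in skyvern.constants), a directory
# containing "report.pdf.crdownload" would be reported as still downloading:
#
#   list_downloading_files_in_directory(Path("/tmp/dl"))
#   # -> ["/tmp/dl/report.pdf.crdownload"]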


async def wait_for_download_finished(downloading_files: list[str], timeout: float = BROWSER_DOWNLOAD_TIMEOUT) -> None:
    cur_downloading_files = downloading_files
    try:
        async with asyncio.timeout(timeout):
            while len(cur_downloading_files) > 0:
                new_downloading_files: list[str] = []
                for path in cur_downloading_files:
                    # Check for cloud storage URIs (S3 or Azure)
                    parsed = urlparse(path)
                    if parsed.scheme in ("s3", "azure"):
                        if not await app.STORAGE.file_exists(path):
                            LOG.debug(
                                "Downloading file is no longer in cloud storage, which means it finished downloading",
                                path=path,
                            )
                            continue
                    else:
                        if not Path(path).exists():
                            LOG.debug(
                                "Downloading file is no longer in the local file system, which means it finished downloading",
                                path=path,
                            )
                            continue
                    new_downloading_files.append(path)
                cur_downloading_files = new_downloading_files
                await asyncio.sleep(1)
    except asyncio.TimeoutError:
        raise DownloadFileMaxWaitingTime(downloading_files=cur_downloading_files)


async def check_downloading_files_and_wait_for_download_to_complete(
    download_dir: Path,
    organization_id: str,
    browser_session_id: str | None = None,
    timeout: float = BROWSER_DOWNLOAD_TIMEOUT,
) -> None:
    # Check whether any file is still downloading
    downloading_files = list_downloading_files_in_directory(download_dir)
    if browser_session_id:
        files_in_browser_session = await app.STORAGE.list_downloading_files_in_browser_session(
            organization_id=organization_id, browser_session_id=browser_session_id
        )
        downloading_files = downloading_files + files_in_browser_session
    if len(downloading_files) == 0:
        return

    LOG.info(
        "File downloading hasn't completed, waiting for it to finish",
        downloading_files=downloading_files,
    )
    try:
        await wait_for_download_finished(
            downloading_files=downloading_files,
            timeout=timeout,
        )
    except DownloadFileMaxWaitingTime as e:
        LOG.warning(
            "Some files are still downloading after the timeout; they might be broken",
            downloading_files=e.downloading_files,
        )


def get_number_of_files_in_directory(directory: Path, recursive: bool = False) -> int:
    return len(list_files_in_directory(directory, recursive))


def sanitize_filename(filename: str) -> str:
    return "".join(c for c in filename if c.isalnum() or c in ["-", "_", ".", "%", " "])


def rename_file(file_path: str, new_file_name: str) -> str:
    try:
        new_file_name = sanitize_filename(new_file_name)
        new_file_path = os.path.join(os.path.dirname(file_path), new_file_name)
        os.rename(file_path, new_file_path)
        return new_file_path
    except Exception:
        LOG.exception(f"Failed to rename file {file_path} to {new_file_name}")
        return file_path


def calculate_sha256_for_file(file_path: str) -> str:
    """Helper function to calculate the SHA256 hash of a file."""
    sha256_hash = hashlib.sha256()
    with open(file_path, "rb") as f:
        for byte_block in iter(lambda: f.read(4096), b""):
            sha256_hash.update(byte_block)
    return sha256_hash.hexdigest()


def create_folder_if_not_exist(dir: str) -> None:
    path = Path(dir)
    path.mkdir(parents=True, exist_ok=True)


def get_skyvern_temp_dir() -> str:
    temp_dir = settings.TEMP_PATH
    create_folder_if_not_exist(temp_dir)
    return temp_dir


def make_temp_directory(
    suffix: str | None = None,
    prefix: str | None = None,
) -> str:
    temp_dir = settings.TEMP_PATH
    create_folder_if_not_exist(temp_dir)
    return tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=temp_dir)


def create_named_temporary_file(delete: bool = True, file_name: str | None = None) -> tempfile._TemporaryFileWrapper:
    temp_dir = settings.TEMP_PATH
    create_folder_if_not_exist(temp_dir)
    if file_name:
        # Sanitize the filename to remove any dangerous characters
        safe_file_name = sanitize_filename(file_name)
        # Create the file with the exact name (no random characters appended)
        file_path = os.path.join(temp_dir, safe_file_name)
        # Open in binary mode and return a NamedTemporaryFile-like object
        file = open(file_path, "wb")
        return tempfile._TemporaryFileWrapper(file, file_path, delete=delete)
    return tempfile.NamedTemporaryFile(dir=temp_dir, delete=delete)


def clean_up_dir(dir: str) -> None:
    if not os.path.exists(dir):
        return
    if os.path.isfile(dir):
        os.unlink(dir)
        return
    for item in os.listdir(dir):
        item_path = os.path.join(dir, item)
        if os.path.isfile(item_path) or os.path.islink(item_path):
            os.unlink(item_path)
        elif os.path.isdir(item_path):
            shutil.rmtree(item_path)
    return


def clean_up_skyvern_temp_dir() -> None:
    return clean_up_dir(get_skyvern_temp_dir())


def parse_uri_to_path(uri: str) -> str:
    parsed_uri = urlparse(uri)
    if parsed_uri.scheme != "file":
        raise ValueError(f"Invalid URI scheme: {parsed_uri.scheme}, expected: file")
    path = parsed_uri.netloc + parsed_uri.path
    return unquote(path)
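
# Percent-escapes are decoded after the scheme check:
#
#   parse_uri_to_path("file:///tmp/My%20Report.pdf")  # -> "/tmp/My Report.pdf"
#   parse_uri_to_path("https://example.com/x")        # -> raises ValueError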