Skyvern/skyvern/forge/sdk/api/files.py

import hashlib
import mimetypes
import os
import re
import tempfile
import zipfile
from pathlib import Path
from urllib.parse import urlparse

import aiohttp
import structlog
from multidict import CIMultiDictProxy

from skyvern.constants import REPO_ROOT_DIR
from skyvern.exceptions import DownloadFileMaxSizeExceeded
from skyvern.forge.sdk.api.aws import AsyncAWSClient
from skyvern.forge.sdk.settings_manager import SettingsManager

LOG = structlog.get_logger()


async def download_from_s3(client: AsyncAWSClient, s3_uri: str) -> str:
    downloaded_bytes = await client.download_file(uri=s3_uri)
    file_path = create_named_temporary_file(delete=False)
    file_path.write(downloaded_bytes)
    return file_path.name


def get_file_extension_from_headers(headers: CIMultiDictProxy[str]) -> str:
    # retrieve it from Content-Disposition
    content_disposition = headers.get("Content-Disposition")
    if content_disposition:
        filename = re.findall('filename="(.+)"', content_disposition, re.IGNORECASE)
        if len(filename) > 0 and Path(filename[0]).suffix:
            return Path(filename[0]).suffix

    # retrieve it from Content-Type
    content_type = headers.get("Content-Type")
    if content_type:
        if file_extension := mimetypes.guess_extension(content_type):
            return file_extension

    return ""


async def download_file(url: str, max_size_mb: int | None = None) -> str:
    try:
        async with aiohttp.ClientSession(raise_for_status=True) as session:
            LOG.info("Starting to download file")
            async with session.get(url) as response:
                # Check the content length if available
                if max_size_mb and response.content_length and response.content_length > max_size_mb * 1024 * 1024:
                    # todo: move to root exception.py
                    raise DownloadFileMaxSizeExceeded(max_size_mb)

                # Parse the URL
                a = urlparse(url)
                # Get the file name
                temp_dir = make_temp_directory(prefix="skyvern_downloads_")
                file_name = os.path.basename(a.path)

                # if no suffix in the URL, we need to parse it from HTTP headers
                if not Path(file_name).suffix:
                    LOG.info("No file extension detected, trying to retrieve it from HTTP headers")
                    try:
                        if extension_name := get_file_extension_from_headers(response.headers):
                            file_name = file_name + extension_name
                        else:
                            LOG.warning("No extension name retrieved from HTTP headers")
                    except Exception:
                        LOG.exception("Failed to retrieve the file extension from HTTP headers")

                file_path = os.path.join(temp_dir, file_name)
                LOG.info(f"Downloading file to {file_path}")
                with open(file_path, "wb") as f:
                    # Write the content of the request into the file
                    total_bytes_downloaded = 0
                    async for chunk in response.content.iter_chunked(1024):
                        f.write(chunk)
                        total_bytes_downloaded += len(chunk)
                        if max_size_mb and total_bytes_downloaded > max_size_mb * 1024 * 1024:
                            raise DownloadFileMaxSizeExceeded(max_size_mb)

                LOG.info(f"File downloaded successfully to {file_path}")
                return file_path
    except aiohttp.ClientResponseError as e:
        LOG.error(f"Failed to download file, status code: {e.status}")
        raise
    except DownloadFileMaxSizeExceeded as e:
        LOG.error(f"Failed to download file, max size exceeded: {e.max_size}")
        raise
    except Exception:
        LOG.exception("Failed to download file")
        raise
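

# Illustrative sketch, not part of the original module: one way download_file
# might be called from other async code, surfacing the size cap to the caller.
# The URL default and the 10 MB limit are assumptions for the example.
async def _example_download_with_limit(url: str = "https://example.com/report.pdf") -> str | None:
    try:
        return await download_file(url, max_size_mb=10)
    except DownloadFileMaxSizeExceeded as e:
        LOG.warning(f"Skipping download, file larger than {e.max_size} MB")
        return None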


def zip_files(files_path: str, zip_file_path: str) -> str:
    with zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_DEFLATED) as zipf:
        for root, dirs, files in os.walk(files_path):
            for file in files:
                file_path = os.path.join(root, file)
                arcname = os.path.relpath(file_path, files_path)  # Relative path within the zip
                zipf.write(file_path, arcname)

    return zip_file_path


def unzip_files(zip_file_path: str, output_dir: str) -> None:
    with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
        zip_ref.extractall(output_dir)


def get_path_for_workflow_download_directory(workflow_run_id: str) -> Path:
    return Path(f"{REPO_ROOT_DIR}/downloads/{workflow_run_id}/")


def list_files_in_directory(directory: Path, recursive: bool = False) -> list[str]:
    listed_files: list[str] = []
    for root, dirs, files in os.walk(directory):
        listed_files.extend([os.path.join(root, file) for file in files])
        if not recursive:
            break
    return listed_files


def get_number_of_files_in_directory(directory: Path, recursive: bool = False) -> int:
    return len(list_files_in_directory(directory, recursive))


def sanitize_filename(filename: str) -> str:
    return "".join(c for c in filename if c.isalnum() or c in ["-", "_", "."])


def rename_file(file_path: str, new_file_name: str) -> str:
    try:
        new_file_name = sanitize_filename(new_file_name)
        new_file_path = os.path.join(os.path.dirname(file_path), new_file_name)
        os.rename(file_path, new_file_path)
        return new_file_path
    except Exception:
        LOG.exception(f"Failed to rename file {file_path} to {new_file_name}")
        return file_path


def calculate_sha256_for_file(file_path: str) -> str:
    """Helper function to calculate SHA256 hash of a file."""
    sha256_hash = hashlib.sha256()
    with open(file_path, "rb") as f:
        for byte_block in iter(lambda: f.read(4096), b""):
            sha256_hash.update(byte_block)
    return sha256_hash.hexdigest()


def create_folder_if_not_exist(dir: str) -> None:
    path = Path(dir)
    if path.exists():
        return
    path.mkdir(parents=True)


def get_skyvern_temp_dir() -> str:
    temp_dir = SettingsManager.get_settings().TEMP_PATH
    create_folder_if_not_exist(temp_dir)
    return temp_dir


def make_temp_directory(
    suffix: str | None = None,
    prefix: str | None = None,
) -> str:
    temp_dir = SettingsManager.get_settings().TEMP_PATH
    create_folder_if_not_exist(temp_dir)
    return tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=temp_dir)


def create_named_temporary_file(delete: bool = True) -> tempfile._TemporaryFileWrapper:
    temp_dir = SettingsManager.get_settings().TEMP_PATH
    create_folder_if_not_exist(temp_dir)
    return tempfile.NamedTemporaryFile(dir=temp_dir, delete=delete)
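

# Illustrative sketch, not part of the original module: combining the helpers to
# locate a workflow's download directory, fingerprint its files, and archive the
# directory into a temp location. The workflow_run_id default is an assumption.
def _example_archive_workflow_downloads(workflow_run_id: str = "wr_example") -> str:
    download_dir = get_path_for_workflow_download_directory(workflow_run_id)
    create_folder_if_not_exist(str(download_dir))
    for file_path in list_files_in_directory(download_dir, recursive=True):
        LOG.info(f"Found {file_path} with sha256 {calculate_sha256_for_file(file_path)}")
    archive_dir = make_temp_directory(prefix="skyvern_archive_")
    return zip_files(str(download_dir), os.path.join(archive_dir, f"{workflow_run_id}.zip"))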