# blt/bytelatent/data/parallel_copy.py

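"""Copy a directory tree between two fsspec filesystems (local disk or S3),
submitting one copy job per file via submitit, either as local processes or
as a SLURM job array.

Illustrative invocation (bucket and destination paths are hypothetical;
typer derives the CLI from main()'s signature):

    python -m bytelatent.data.parallel_copy s3://my-bucket/data/ /tmp/data/ \
        --parallel-mode slurm --n-workers 32 --overwrite-mode size_mismatch
"""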

import logging
import os
import shutil
import time
from enum import Enum

import fsspec
import submitit
import typer
from rich.logging import RichHandler
FORMAT = "%(message)s"
logging.basicConfig(
    level="INFO", format=FORMAT, datefmt="[%X]", handlers=[RichHandler()]
)
logger = logging.getLogger("parallel_copy")
S3_PREFIX = "s3://"


def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem:
    """Return the fsspec filesystem matching path: S3 for s3:// paths, local otherwise."""
    if path.startswith(S3_PREFIX):
        # Very large default_block_size so each file is read in (at most)
        # a single S3 request rather than many small ranged reads.
        kwargs = {"default_block_size": 1000000 * 2**20, "max_concurrency": 1}
        if s3_profile is not None:
            kwargs["profile"] = s3_profile
        return fsspec.filesystem("s3", **kwargs)
    else:
        return fsspec.filesystem("file")


def strip_s3_prefix(path: str) -> str:
    if path.startswith(S3_PREFIX):
        return path[len(S3_PREFIX) :]
    else:
        return path
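

# Overwrite semantics for files that already exist at the destination:
#   ALWAYS        -> recopy every file
#   SIZE_MISMATCH -> recopy only when src and dst sizes differ
#   NEVER         -> skip files that already exist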
class OverwriteMode(str, Enum):
    ALWAYS = "always"
    SIZE_MISMATCH = "size_mismatch"
    NEVER = "never"


class ParallelMode(str, Enum):
    SLURM = "slurm"
    MULTIPROCESS = "multiprocess"


def copy_src_to_dst(
    src_fs: fsspec.AbstractFileSystem,
    dst_fs: fsspec.AbstractFileSystem,
    src_file: str,
    dst_file: str,
    dry_run: bool = False,
):
    if dry_run:
        logger.info("Dry run copy: %s -> %s", src_file, dst_file)
    else:
        # Make sure the destination directory exists, then stream the bytes.
        dst_parent_directory = os.path.dirname(dst_file)
        dst_fs.mkdirs(dst_parent_directory, exist_ok=True)
        with src_fs.open(src_file, "rb") as src_pointer, dst_fs.open(
            dst_file, "wb"
        ) as dst_pointer:
            shutil.copyfileobj(src_pointer, dst_pointer)
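

# Illustrative synchronous use of the helpers above, without submitit
# (bucket and file names are hypothetical):
#
#   src = get_fs("s3://my-bucket/data/")
#   dst = get_fs("/tmp/data/")
#   copy_src_to_dst(src, dst, "my-bucket/data/shard0.bin", "/tmp/data/shard0.bin")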


class CopyJob(submitit.helpers.Checkpointable):
    """Copy a single file and optionally validate that src and dst sizes match.

    The filesystems are passed as plain dicts (AbstractFileSystem.to_dict())
    and rebuilt with from_dict() inside the job, so the submitted arguments
    serialize cleanly when submitit pickles them.
    """

    def __call__(
        self,
        src_fs_dict: dict,
        dst_fs_dict: dict,
        src_file: str,
        dst_file: str,
        dry_run: bool = False,
        validate_size: bool = True,
    ):
        src_fs = fsspec.AbstractFileSystem.from_dict(src_fs_dict)
        dst_fs = fsspec.AbstractFileSystem.from_dict(dst_fs_dict)
        copy_src_to_dst(src_fs, dst_fs, src_file, dst_file, dry_run=dry_run)
        if validate_size and not dry_run:
            src_size = src_fs.size(src_file)
            dst_size = dst_fs.size(dst_file)
            if src_size != dst_size:
                raise ValueError(
                    f"Mismatched sizes for src={src_file} dst={dst_file} {src_size} != {dst_size}"
                )
        return True


def main(
    src_dir: str,
    dst_dir: str,
    src_s3_profile: str | None = None,
    dst_s3_profile: str | None = None,
    n_workers: int = 16,
    cpus_per_task: int = 2,
    overwrite_mode: OverwriteMode = OverwriteMode.SIZE_MISMATCH,
    validate_size: bool = True,
    parallel_mode: ParallelMode = ParallelMode.MULTIPROCESS,
    dry_run: bool = False,
    job_dir: str = "jobs_parallel-copy",
    slurm_qos: str | None = None,
    slurm_time_hours: int = 72,
    slurm_memory: str = "0",
    wait: bool = True,
    wait_period: int = 5,
):
logging.info("Starting parallell copy: %s -> %s", src_dir, dst_dir)
logging.info("job_dir=%s", job_dir)
logging.info(
"Parallel=%s validate_size=%s overwrite_mode=%s qos=%s",
parallel_mode,
validate_size,
overwrite_mode,
slurm_qos,
)
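    # MULTIPROCESS runs each copy in a separate local process via submitit's
    # LocalExecutor; SLURM submits the copies as a job array, with
    # array_parallelism capping how many run at once.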
    if parallel_mode == ParallelMode.MULTIPROCESS:
        executor = submitit.LocalExecutor(folder=job_dir)
    elif parallel_mode == ParallelMode.SLURM:
        executor = submitit.SlurmExecutor(folder=job_dir)
        executor.update_parameters(
            # The SLURM time limit is given in minutes.
            time=slurm_time_hours * 60,
            ntasks_per_node=1,
            cpus_per_task=cpus_per_task,
            array_parallelism=n_workers,
            mem=slurm_memory,
            gpus_per_node=0,
        )
        if slurm_qos is not None:
            executor.update_parameters(qos=slurm_qos)
    else:
        raise ValueError("Invalid parallel mode")
    assert src_dir.endswith("/"), "src_dir must end with a /"
    assert dst_dir.endswith("/"), "dst_dir must end with a /"
    src_fs = get_fs(src_dir, s3_profile=src_s3_profile)
    dst_fs = get_fs(dst_dir, s3_profile=dst_s3_profile)
    src_dir = strip_s3_prefix(src_dir)
    dst_dir = strip_s3_prefix(dst_dir)
    logger.info("src: %s, dst: %s", src_dir, dst_dir)
    assert src_fs.isdir(src_dir), "src_dir must be a directory"
    if dst_fs.exists(dst_dir):
        assert dst_fs.isdir(dst_dir), "dst_dir must be a directory if it exists"
    else:
        dst_fs.mkdirs(dst_dir, exist_ok=True)
    files = src_fs.find(src_dir)
    logger.info("Files found to check for transfer: %s", len(files))
    jobs = []
    with executor.batch():
        for src_file in files:
            relative_src = src_file[len(src_dir) :]
            dst_file_path = os.path.join(dst_dir, relative_src)
            logger.debug("src: %s -> dst %s", src_file, dst_file_path)
            if dst_fs.exists(dst_file_path):
                if overwrite_mode == OverwriteMode.NEVER:
                    should_copy = False
                elif overwrite_mode == OverwriteMode.ALWAYS:
                    should_copy = True
                elif overwrite_mode == OverwriteMode.SIZE_MISMATCH:
                    should_copy = src_fs.size(src_file) != dst_fs.size(dst_file_path)
                else:
                    raise ValueError("Unknown overwrite_mode")
            else:
                # Destination file does not exist yet, so always copy.
                should_copy = True
            if should_copy:
                logger.info("copy: %s -> %s", src_file, dst_file_path)
                job = executor.submit(
                    CopyJob(),
                    src_fs.to_dict(),
                    dst_fs.to_dict(),
                    src_file,
                    dst_file_path,
                    dry_run=dry_run,
                    validate_size=validate_size,
                )
                jobs.append(job)
    if wait:
        while True:
            num_finished = sum(job.done() for job in jobs)
            logger.info("Total Jobs: %s Completed Jobs: %s", len(jobs), num_finished)
            if num_finished == len(jobs):
                break
            time.sleep(wait_period)
        # job.result() re-raises any exception from the worker, so a failed
        # copy (e.g. a size mismatch) surfaces here.
        output = [job.result() for job in jobs]
        if all(output):
            logger.info("All copies succeeded")
        else:
            logger.info("Some copies failed")
    else:
        logger.info("Not waiting for jobs to complete before exiting submit program")
if __name__ == "__main__":
    typer.run(main)