From 11cad6c84d083645fbde617c0b125a615ecc094d Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez
Date: Tue, 28 Jan 2025 00:57:06 +0000
Subject: [PATCH] WIP parallel copy script

Summary:

Test Plan:
---
 bytelatent/data/parallel_copy.py | 224 +++++++++++++++++++++++++++++++
 1 file changed, 224 insertions(+)
 create mode 100644 bytelatent/data/parallel_copy.py

diff --git a/bytelatent/data/parallel_copy.py b/bytelatent/data/parallel_copy.py
new file mode 100644
index 0000000..2820429
--- /dev/null
+++ b/bytelatent/data/parallel_copy.py
@@ -0,0 +1,224 @@
import logging
import os
import shutil
import time
from enum import Enum

import fsspec
import submitit
import typer
from rich.logging import RichHandler

FORMAT = "%(message)s"
logging.basicConfig(
    level="INFO", format=FORMAT, datefmt="[%X]", handlers=[RichHandler()]
)
logger = logging.getLogger("parallel_copy")

S3_PREFIX = "s3://"


def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem:
    if path.startswith(S3_PREFIX):
        # Use a very large block size (~1 TB) so each file is transferred in
        # as few requests as possible; parallelism comes from running many
        # jobs rather than many threads per file.
        s3_kwargs: dict = {
            "default_block_size": 1000000 * 2**20,
            "max_concurrency": 1,
        }
        if s3_profile is not None:
            s3_kwargs["profile"] = s3_profile
        return fsspec.filesystem("s3", **s3_kwargs)
    else:
        return fsspec.filesystem("file")


def strip_s3_prefix(path: str) -> str:
    if path.startswith(S3_PREFIX):
        return path[len(S3_PREFIX) :]
    else:
        return path


class OverwriteMode(str, Enum):
    ALWAYS = "always"
    SIZE_MISMATCH = "size_mismatch"
    NEVER = "never"


class ParallelMode(str, Enum):
    SLURM = "slurm"
    MULTIPROCESS = "multiprocess"


def copy_src_to_dst(
    src_fs: fsspec.AbstractFileSystem,
    dst_fs: fsspec.AbstractFileSystem,
    src_file: str,
    dst_file: str,
    dry_run: bool = False,
):
    if dry_run:
        logger.info("Dry run copy: %s -> %s", src_file, dst_file)
    else:
        dst_parent_directory = os.path.dirname(dst_file)
        dst_fs.mkdirs(dst_parent_directory, exist_ok=True)
        with src_fs.open(src_file, "rb") as src_pointer, dst_fs.open(
            dst_file, "wb"
        ) as dst_pointer:
            shutil.copyfileobj(src_pointer, dst_pointer)


class CopyJob(submitit.helpers.Checkpointable):
    def __call__(
        self,
        src_fs_dict: dict,
        dst_fs_dict: dict,
        src_file: str,
        dst_file: str,
        dry_run: bool = False,
        validate_size: bool = True,
    ):
        # Filesystems are passed as dicts because fsspec filesystem objects
        # are rebuilt inside the worker via AbstractFileSystem.from_dict.
        src_fs = fsspec.AbstractFileSystem.from_dict(src_fs_dict)
        dst_fs = fsspec.AbstractFileSystem.from_dict(dst_fs_dict)
        copy_src_to_dst(src_fs, dst_fs, src_file, dst_file, dry_run=dry_run)
        if validate_size and not dry_run:
            src_size = src_fs.size(src_file)
            dst_size = dst_fs.size(dst_file)
            if src_size != dst_size:
                raise ValueError(
                    f"Mismatched sizes for src={src_file} dst={dst_file}: {src_size} != {dst_size}"
                )
        return True


def main(
    src_dir: str,
    dst_dir: str,
    src_s3_profile: str | None = None,
    dst_s3_profile: str | None = None,
    n_workers: int = 16,
    cpus_per_task: int = 2,
    overwrite_mode: OverwriteMode = OverwriteMode.SIZE_MISMATCH,
    validate_size: bool = True,
    parallel_mode: ParallelMode = ParallelMode.MULTIPROCESS,
    dry_run: bool = False,
    job_dir: str = "jobs_parallel-copy",
    slurm_qos: str | None = None,
    slurm_time_hours: int = 72,
    slurm_memory: str = "0",
    wait: bool = True,
    wait_period: int = 5,
):
    logger.info("Starting parallel copy: %s -> %s", src_dir, dst_dir)
    logger.info("job_dir=%s", job_dir)
    logger.info(
        "parallel_mode=%s validate_size=%s overwrite_mode=%s qos=%s",
        parallel_mode,
        validate_size,
        overwrite_mode,
        slurm_qos,
    )
    if parallel_mode == ParallelMode.MULTIPROCESS:
        executor = submitit.LocalExecutor(folder=job_dir)
    elif parallel_mode == ParallelMode.SLURM:
        executor = submitit.SlurmExecutor(folder=job_dir)
        executor.update_parameters(
            time=slurm_time_hours * 60,  # SlurmExecutor expects minutes
            ntasks_per_node=1,
            cpus_per_task=cpus_per_task,
            array_parallelism=n_workers,
            mem=slurm_memory,
            gpus_per_node=0,
        )
        if slurm_qos is not None:
            executor.update_parameters(qos=slurm_qos)
    else:
        raise ValueError(f"Invalid parallel mode: {parallel_mode}")

    assert src_dir.endswith("/"), "src_dir must end with a /"
    assert dst_dir.endswith("/"), "dst_dir must end with a /"
    src_fs = get_fs(src_dir, s3_profile=src_s3_profile)
    dst_fs = get_fs(dst_dir, s3_profile=dst_s3_profile)

    src_dir = strip_s3_prefix(src_dir)
    dst_dir = strip_s3_prefix(dst_dir)
    logger.info("src: %s, dst: %s", src_dir, dst_dir)

    assert src_fs.isdir(src_dir), "src_dir must be a directory"
    if dst_fs.exists(dst_dir):
        assert dst_fs.isdir(dst_dir), "dst_dir must be a directory if it exists"
    else:
        dst_fs.mkdirs(dst_dir, exist_ok=True)

    files = src_fs.find(src_dir)
    logger.info("Files found to check for transfer: %s", len(files))
    jobs = []
    with executor.batch():
        for src_file in files:
            relative_src = src_file[len(src_dir) :]
            dst_file_path = os.path.join(dst_dir, relative_src)
            logger.debug("src: %s -> dst: %s", src_file, dst_file_path)
            # Decide whether this file needs a copy, then submit one job per
            # file so all three overwrite modes share a single submit path.
            if not dst_fs.exists(dst_file_path):
                should_copy = True
            elif overwrite_mode == OverwriteMode.NEVER:
                should_copy = False
            elif overwrite_mode == OverwriteMode.ALWAYS:
                should_copy = True
            elif overwrite_mode == OverwriteMode.SIZE_MISMATCH:
                should_copy = src_fs.size(src_file) != dst_fs.size(dst_file_path)
            else:
                raise ValueError(f"Unknown overwrite_mode: {overwrite_mode}")
            if should_copy:
                logger.info("copy: %s -> %s", src_file, dst_file_path)
                job = executor.submit(
                    CopyJob(),
                    src_fs.to_dict(),
                    dst_fs.to_dict(),
                    src_file,
                    dst_file_path,
                    dry_run=dry_run,
                    validate_size=validate_size,
                )
                jobs.append(job)
    if wait:
        while True:
            num_finished = sum(job.done() for job in jobs)
            logger.info("Total Jobs: %s Completed Jobs: %s", len(jobs), num_finished)
            if num_finished == len(jobs):
                break
            time.sleep(wait_period)
        # job.result() re-raises any exception from the worker, so collect
        # failures rather than letting the first one abort the summary.
        results = []
        for job in jobs:
            try:
                results.append(job.result())
            except Exception:
                logger.exception("Copy job %s failed", job.job_id)
                results.append(False)
        if all(results):
            logger.info("All copies succeeded")
        else:
            logger.info("Some copies failed")
    else:
        logger.info("Not waiting for jobs to complete before exiting submit program")


if __name__ == "__main__":
    typer.run(main)
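
For reference, a hypothetical invocation of the script (the bucket and
destination paths below are placeholders; typer derives the flags from the
function signature, and both directories must end with a trailing slash since
the script asserts this):

    python -m bytelatent.data.parallel_copy \
        s3://my-bucket/datasets/ /local/datasets/ \
        --parallel-mode slurm --n-workers 32 \
        --overwrite-mode size_mismatch --dry-run

Dropping --dry-run performs the actual copies; with it, each job only logs the
src -> dst pairs it would transfer.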