Merge 11cad6c84d into sapling-pr-archive-EntilZha

commit 4ff8341738
Pedro Rodriguez 2025-02-03 18:29:37 -08:00, committed by GitHub

@@ -0,0 +1,224 @@
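"""Copy a directory tree in parallel, locally or on a Slurm cluster.

Every file under src_dir becomes one copy job, dispatched either to local
worker processes or to a Slurm job array via submitit. Source and destination
may be local paths or s3:// URLs (handled through fsspec/s3fs).

Example invocation (file name, bucket, and profile are illustrative):

    python parallel_copy.py s3://my-bucket/data/ /checkpoint/data/ \
        --src-s3-profile my-profile --parallel-mode slurm --n-workers 32
"""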
import logging
import os
import shutil
import time
from enum import Enum
import fsspec
import submitit
import typer
from rich.logging import RichHandler
FORMAT = "%(message)s"
logging.basicConfig(
level="INFO", format=FORMAT, datefmt="[%X]", handlers=[RichHandler()]
)
logger = logging.getLogger("parallel_copy")
S3_PREFIX = "s3://"
def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem:
    if path.startswith(S3_PREFIX):
        # A very large block size plus max_concurrency=1 keeps each individual
        # transfer a single sequential stream; parallelism comes from running
        # many copy jobs at once instead.
        s3_kwargs: dict = {
            "default_block_size": 1000000 * 2**20,
            "max_concurrency": 1,
        }
        if s3_profile is not None:
            s3_kwargs["profile"] = s3_profile
        return fsspec.filesystem("s3", **s3_kwargs)
    else:
        return fsspec.filesystem("file")
def strip_s3_prefix(path: str) -> str:
if path.startswith(S3_PREFIX):
return path[len(S3_PREFIX) :]
else:
return path
class OverwriteMode(str, Enum):
ALWAYS = "always"
SIZE_MISMATCH = "size_mismatch"
NEVER = "never"
class ParallelMode(str, Enum):
SLURM = "slurm"
MULTIPROCESS = "multiprocess"
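# ParallelMode selects the submitit executor below: MULTIPROCESS runs copies
# in local processes (LocalExecutor), SLURM submits them as a job array
# (SlurmExecutor).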
def copy_src_to_dst(
src_fs: fsspec.AbstractFileSystem,
dst_fs: fsspec.AbstractFileSystem,
src_file: str,
dst_file: str,
dry_run: bool = False,
):
if dry_run:
logging.info("Dry run copy: %s -> %s", src_file, dst_file)
else:
dst_parent_directory = os.path.dirname(dst_file)
dst_fs.mkdirs(dst_parent_directory, exist_ok=True)
with src_fs.open(src_file, "rb") as src_pointer, dst_fs.open(
dst_file, "wb"
) as dst_pointer:
shutil.copyfileobj(src_pointer, dst_pointer)
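# A CopyJob is the unit of work shipped to each worker. submitit pickles the
# callable and its arguments, so the filesystems are passed in their
# to_dict()/from_dict() serialized form and rebuilt on the worker rather than
# pickling live filesystem objects. Subclassing submitit.helpers.Checkpointable
# gives the job a default checkpoint method, so Slurm can requeue and resume
# it after preemption or timeout.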
class CopyJob(submitit.helpers.Checkpointable):
def __call__(
self,
src_fs_dict: dict,
dst_fs_dict: dict,
src_file: str,
dst_file: str,
dry_run: bool = False,
validate_size: bool = True,
):
src_fs = fsspec.AbstractFileSystem.from_dict(src_fs_dict)
dst_fs = fsspec.AbstractFileSystem.from_dict(dst_fs_dict)
copy_src_to_dst(src_fs, dst_fs, src_file, dst_file, dry_run=dry_run)
if validate_size and not dry_run:
src_size = src_fs.size(src_file)
dst_size = dst_fs.size(dst_file)
if src_size != dst_size:
raise ValueError(
f"Mismatched sizes for src={src_file} dst={dst_file} {src_size} != {dst_size}"
)
return True
def main(
src_dir: str,
dst_dir: str,
src_s3_profile: str | None = None,
dst_s3_profile: str | None = None,
n_workers: int = 16,
cpus_per_task: int = 2,
overwrite_mode: OverwriteMode = OverwriteMode.SIZE_MISMATCH,
validate_size: bool = True,
parallel_mode: ParallelMode = ParallelMode.MULTIPROCESS,
dry_run: bool = False,
job_dir: str = "jobs_parallel-copy",
slurm_qos: str | None = None,
slurm_time_hours: int = 72,
slurm_memory: str = "0",
wait: bool = True,
wait_period: int = 5,
):
logging.info("Starting parallell copy: %s -> %s", src_dir, dst_dir)
logging.info("job_dir=%s", job_dir)
logging.info(
"Parallel=%s validate_size=%s overwrite_mode=%s qos=%s",
parallel_mode,
validate_size,
overwrite_mode,
slurm_qos,
)
if parallel_mode == ParallelMode.MULTIPROCESS:
executor = submitit.LocalExecutor(folder=job_dir)
elif parallel_mode == ParallelMode.SLURM:
executor = submitit.SlurmExecutor(folder=job_dir)
        executor.update_parameters(
            time=slurm_time_hours * 60,  # submitit's Slurm time limit is in minutes
            ntasks_per_node=1,
            cpus_per_task=cpus_per_task,
            array_parallelism=n_workers,  # cap on concurrently running array tasks
            mem=slurm_memory,  # "0" asks Slurm for all memory on the node
            gpus_per_node=0,
        )
if slurm_qos is not None:
executor.update_parameters(qos=slurm_qos)
else:
raise ValueError("Invalid parallel mode")
assert src_dir.endswith("/"), "src_dir must end with a /"
assert dst_dir.endswith("/"), "dst_dir must end with a /"
src_fs = get_fs(src_dir, s3_profile=src_s3_profile)
dst_fs = get_fs(dst_dir, s3_profile=dst_s3_profile)
src_dir = strip_s3_prefix(src_dir)
dst_dir = strip_s3_prefix(dst_dir)
logging.info("src: %s, dst: %s", src_dir, dst_dir)
assert src_fs.isdir(src_dir), "src_dir must be a directory"
if dst_fs.exists(dst_dir):
assert dst_dir, "dst_dir must be a directory if it exists"
else:
dst_fs.mkdirs(dst_dir, exist_ok=True)
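    # find() walks src_dir recursively; each file found becomes one copy job.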
files = src_fs.find(src_dir)
logging.info("Files found to check for transfer: %s", len(files))
jobs = []
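    # executor.batch() defers the submit() calls below and sends them together
    # (a single Slurm job array when using the SlurmExecutor).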
with executor.batch():
for src_file in files:
relative_src = src_file[len(src_dir) :]
dst_file_path = os.path.join(dst_dir, relative_src)
logging.debug("src: %s -> dst %s", src_file, dst_file_path)
if dst_fs.exists(dst_file_path):
if overwrite_mode == OverwriteMode.NEVER:
pass
elif overwrite_mode == OverwriteMode.ALWAYS:
logging.info("copy: %s -> %s", src_file, dst_file_path)
job = executor.submit(
CopyJob(),
src_fs.to_dict(),
dst_fs.to_dict(),
src_file,
dst_file_path,
dry_run=dry_run,
validate_size=validate_size,
)
jobs.append(job)
elif overwrite_mode == OverwriteMode.SIZE_MISMATCH:
if src_fs.size(src_file) != dst_fs.size(dst_file_path):
logging.info("copy: %s -> %s", src_file, dst_file_path)
job = executor.submit(
CopyJob(),
src_fs.to_dict(),
dst_fs.to_dict(),
src_file,
dst_file_path,
dry_run=dry_run,
validate_size=validate_size,
)
jobs.append(job)
else:
raise ValueError("Unknown overwrite_mode")
else:
logging.info("copy: %s -> %s", src_file, dst_file_path)
job = executor.submit(
CopyJob(),
src_fs.to_dict(),
dst_fs.to_dict(),
src_file,
dst_file_path,
dry_run=dry_run,
validate_size=validate_size,
)
jobs.append(job)
    if wait:
        while True:
            num_finished = sum(job.done() for job in jobs)
            logger.info("Total Jobs: %s Completed Jobs: %s", len(jobs), num_finished)
            if num_finished == len(jobs):
                break
            time.sleep(wait_period)
        # result() re-raises any exception from the worker, so collect results
        # per job; otherwise the first failure would abort reporting entirely.
        num_failed = 0
        for job in jobs:
            try:
                job.result()
            except Exception:
                logger.exception("Copy job failed: %s", job.job_id)
                num_failed += 1
        if num_failed == 0:
            logger.info("All copies succeeded")
        else:
            logger.info("Some copies failed: %s of %s", num_failed, len(jobs))
    else:
        logger.info("Not waiting for jobs to complete before exiting")
if __name__ == "__main__":
typer.run(main)