# Mirror of https://github.com/facebookresearch/blt.git
# Synced 2025-02-23 13:32:14 +00:00 (225 lines, 7.3 KiB, Python)
import logging
import os
import shutil
import time
from enum import Enum

import fsspec
import submitit
import typer
from rich.logging import RichHandler

# RichHandler renders timestamp/level itself, so the format is just the message.
FORMAT = "%(message)s"
logging.basicConfig(
    level="INFO", format=FORMAT, datefmt="[%X]", handlers=[RichHandler()]
)
logger = logging.getLogger("parallel_copy")

# URL scheme marking a path as living on S3 rather than the local filesystem.
S3_PREFIX = "s3://"
def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem:
    """Return the fsspec filesystem appropriate for *path*.

    Paths starting with ``s3://`` get an S3 filesystem (optionally bound to a
    named AWS profile); anything else gets the local filesystem.

    Args:
        path: Source or destination path; only its scheme prefix is inspected.
        s3_profile: Optional AWS profile name for the S3 filesystem.

    Returns:
        A configured ``fsspec.AbstractFileSystem`` instance.
    """
    # Use the module-level constant instead of re-hardcoding "s3://"
    # (keeps this check consistent with strip_s3_prefix).
    if not path.startswith(S3_PREFIX):
        return fsspec.filesystem("file")
    s3_kwargs = {
        # Very large block size (~1 TiB) with max_concurrency=1 — presumably
        # to force whole-file, single-stream transfers; confirm if tuning.
        "default_block_size": 1000000 * 2**20,
        "max_concurrency": 1,
    }
    # Only pass `profile` when explicitly requested so the default credential
    # chain is used otherwise, exactly as the original two branches did.
    if s3_profile is not None:
        s3_kwargs["profile"] = s3_profile
    return fsspec.filesystem("s3", **s3_kwargs)
def strip_s3_prefix(path: str):
|
|
if path.startswith(S3_PREFIX):
|
|
return path[len(S3_PREFIX) :]
|
|
else:
|
|
return path
|
|
|
|
|
|
class OverwriteMode(str, Enum):
    """Policy for files that already exist at the destination."""

    # Copy every source file regardless of destination state.
    ALWAYS = "always"
    # Copy only when source and destination byte sizes differ.
    SIZE_MISMATCH = "size_mismatch"
    # Never overwrite an existing destination file.
    NEVER = "never"
class ParallelMode(str, Enum):
    """Execution backend for the copy jobs."""

    # Submit an array of jobs to a SLURM cluster via submitit.SlurmExecutor.
    SLURM = "slurm"
    # Run locally with submitit.LocalExecutor (separate processes).
    MULTIPROCESS = "multiprocess"
def copy_src_to_dst(
    src_fs: fsspec.AbstractFileSystem,
    dst_fs: fsspec.AbstractFileSystem,
    src_file: str,
    dst_file: str,
    dry_run: bool = False,
):
    """Copy one file from ``src_fs``/``src_file`` to ``dst_fs``/``dst_file``.

    Creates the destination's parent directory first. When *dry_run* is true,
    only logs what would be copied and performs no filesystem operations.
    """
    if dry_run:
        logging.info("Dry run copy: %s -> %s", src_file, dst_file)
        return
    # Ensure the parent directory exists before opening the target for write.
    dst_fs.mkdirs(os.path.dirname(dst_file), exist_ok=True)
    with src_fs.open(src_file, "rb") as reader:
        with dst_fs.open(dst_file, "wb") as writer:
            # Stream in chunks; never loads the whole file into memory.
            shutil.copyfileobj(reader, writer)
class CopyJob(submitit.helpers.Checkpointable):
    """Submitit task copying one file, optionally validating the result size.

    Filesystems travel as dicts (``AbstractFileSystem.to_dict``) so the job
    payload stays picklable; they are rehydrated inside the worker process.
    """

    def __call__(
        self,
        src_fs_dict: dict,
        dst_fs_dict: dict,
        src_file: str,
        dst_file: str,
        dry_run: bool = False,
        validate_size: bool = True,
    ):
        source_fs = fsspec.AbstractFileSystem.from_dict(src_fs_dict)
        dest_fs = fsspec.AbstractFileSystem.from_dict(dst_fs_dict)
        copy_src_to_dst(source_fs, dest_fs, src_file, dst_file, dry_run=dry_run)
        # A dry run writes nothing, so there is nothing to size-check.
        if validate_size and not dry_run:
            src_size = source_fs.size(src_file)
            dst_size = dest_fs.size(dst_file)
            if src_size != dst_size:
                raise ValueError(
                    f"Mismatched sizes for src={src_file} dst={dst_file} {src_size} != {dst_size}"
                )
        return True
def _make_executor(
    parallel_mode: ParallelMode,
    job_dir: str,
    n_workers: int,
    cpus_per_task: int,
    slurm_qos: str | None,
    slurm_time_hours: int,
    slurm_memory: str,
):
    """Build the submitit executor for the requested parallel mode.

    Raises:
        ValueError: if *parallel_mode* is not a known mode.
    """
    if parallel_mode == ParallelMode.MULTIPROCESS:
        return submitit.LocalExecutor(folder=job_dir)
    if parallel_mode == ParallelMode.SLURM:
        executor = submitit.SlurmExecutor(folder=job_dir)
        executor.update_parameters(
            time=slurm_time_hours * 60,  # submitit expects minutes
            ntasks_per_node=1,
            cpus_per_task=cpus_per_task,
            array_parallelism=n_workers,
            mem=slurm_memory,
            gpus_per_node=0,
        )
        if slurm_qos is not None:
            executor.update_parameters(qos=slurm_qos)
        return executor
    raise ValueError("Invalid parallel mode")


def _should_copy(
    src_fs: fsspec.AbstractFileSystem,
    dst_fs: fsspec.AbstractFileSystem,
    src_file: str,
    dst_file: str,
    overwrite_mode: OverwriteMode,
) -> bool:
    """Decide whether *src_file* should be copied to *dst_file*.

    Missing destinations are always copied (without consulting the mode,
    matching the original control flow). For existing destinations the
    OverwriteMode policy applies.

    Raises:
        ValueError: if an existing destination is hit with an unknown mode.
    """
    if not dst_fs.exists(dst_file):
        return True
    if overwrite_mode == OverwriteMode.NEVER:
        return False
    if overwrite_mode == OverwriteMode.ALWAYS:
        return True
    if overwrite_mode == OverwriteMode.SIZE_MISMATCH:
        return src_fs.size(src_file) != dst_fs.size(dst_file)
    raise ValueError("Unknown overwrite_mode")


def main(
    src_dir: str,
    dst_dir: str,
    src_s3_profile: str | None = None,
    dst_s3_profile: str | None = None,
    n_workers: int = 16,
    cpus_per_task: int = 2,
    overwrite_mode: OverwriteMode = OverwriteMode.SIZE_MISMATCH,
    validate_size: bool = True,
    parallel_mode: ParallelMode = ParallelMode.MULTIPROCESS,
    dry_run: bool = False,
    job_dir: str = "jobs_parallel-copy",
    slurm_qos: str | None = None,
    slurm_time_hours: int = 72,
    slurm_memory: str = "0",
    wait: bool = True,
    wait_period: int = 5,
):
    """Recursively copy src_dir to dst_dir with one submitit job per file.

    Both directories must end with "/" and may be local paths or s3:// URLs.
    Files already present at the destination are handled per *overwrite_mode*.
    When *wait* is true, polls every *wait_period* seconds until all jobs
    finish and then reports overall success.
    """
    # Fix: original message had a typo ("parallell").
    logging.info("Starting parallel copy: %s -> %s", src_dir, dst_dir)
    logging.info("job_dir=%s", job_dir)
    logging.info(
        "Parallel=%s validate_size=%s overwrite_mode=%s qos=%s",
        parallel_mode,
        validate_size,
        overwrite_mode,
        slurm_qos,
    )
    executor = _make_executor(
        parallel_mode,
        job_dir,
        n_workers,
        cpus_per_task,
        slurm_qos,
        slurm_time_hours,
        slurm_memory,
    )

    # Trailing slash keeps the relative-path slicing below well-formed.
    assert src_dir.endswith("/"), "src_dir must end with a /"
    assert dst_dir.endswith("/"), "dst_dir must end with a /"
    src_fs = get_fs(src_dir, s3_profile=src_s3_profile)
    dst_fs = get_fs(dst_dir, s3_profile=dst_s3_profile)

    # fsspec methods take scheme-less paths, so strip s3:// up front.
    src_dir = strip_s3_prefix(src_dir)
    dst_dir = strip_s3_prefix(dst_dir)
    logging.info("src: %s, dst: %s", src_dir, dst_dir)

    assert src_fs.isdir(src_dir), "src_dir must be a directory"
    if dst_fs.exists(dst_dir):
        # BUG FIX: the original asserted `dst_dir` (a non-empty string, always
        # truthy); the message shows the intent was an isdir check.
        assert dst_fs.isdir(dst_dir), "dst_dir must be a directory if it exists"
    else:
        dst_fs.mkdirs(dst_dir, exist_ok=True)

    files = src_fs.find(src_dir)
    logging.info("Files found to check for transfer: %s", len(files))
    jobs = []
    # batch() coalesces the submits into one job array on submission backends
    # that support it.
    with executor.batch():
        for src_file in files:
            # find() returns scheme-less paths, so the stripped src_dir is a
            # plain string prefix of each result.
            relative_src = src_file[len(src_dir) :]
            dst_file_path = os.path.join(dst_dir, relative_src)
            logging.debug("src: %s -> dst %s", src_file, dst_file_path)
            if not _should_copy(
                src_fs, dst_fs, src_file, dst_file_path, overwrite_mode
            ):
                continue
            logging.info("copy: %s -> %s", src_file, dst_file_path)
            jobs.append(
                executor.submit(
                    CopyJob(),
                    # Pass dict forms so the payload pickles cleanly.
                    src_fs.to_dict(),
                    dst_fs.to_dict(),
                    src_file,
                    dst_file_path,
                    dry_run=dry_run,
                    validate_size=validate_size,
                )
            )
    if wait:
        while True:
            num_finished = sum(job.done() for job in jobs)
            logging.info("Total Jobs: %s Completed Jobs: %s", len(jobs), num_finished)
            if num_finished == len(jobs):
                break
            time.sleep(wait_period)
        # result() re-raises any worker exception; each CopyJob returns True.
        output = [job.result() for job in jobs]
        if all(output):
            logging.info("All copies succeeded")
        else:
            logging.info("Some copies failed")
    else:
        logging.info("Not waiting for job to complete before exiting submit program")
if __name__ == "__main__":
    # typer derives the CLI (flags, defaults, enum choices) from main's signature.
    typer.run(main)