Mirror of https://github.com/facebookresearch/blt.git (synced 2025-02-22 21:12:15 +00:00)

Merge 11cad6c84d into sapling-pr-archive-EntilZha
Commit 4ff8341738
bytelatent/data/parallel_copy.py (new file, 224 lines)
@@ -0,0 +1,224 @@
import logging
import os
import shutil
import time
from enum import Enum

import fsspec
import submitit
import typer
from rich.logging import RichHandler

FORMAT = "%(message)s"
logging.basicConfig(
    level="INFO", format=FORMAT, datefmt="[%X]", handlers=[RichHandler()]
)
logger = logging.getLogger("parallel_copy")


S3_PREFIX = "s3://"


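# Build an fsspec filesystem for an s3:// or local path. The very large
# default_block_size (1,000,000 MiB, roughly 1 TiB) presumably forces each file
# to be read as a single block, and max_concurrency=1 keeps each worker to one
# in-flight transfer; both readings are assumptions about intent.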
def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem:
    if path.startswith(S3_PREFIX):
        if s3_profile is None:
            return fsspec.filesystem(
                "s3", default_block_size=1000000 * 2**20, max_concurrency=1
            )
        else:
            return fsspec.filesystem(
                "s3",
                profile=s3_profile,
                default_block_size=1000000 * 2**20,
                max_concurrency=1,
            )
    else:
        return fsspec.filesystem("file")


def strip_s3_prefix(path: str) -> str:
    if path.startswith(S3_PREFIX):
        return path[len(S3_PREFIX) :]
    else:
        return path


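# OverwriteMode decides what to do when the destination file already exists:
# always recopy, recopy only on a size mismatch, or never overwrite.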
class OverwriteMode(str, Enum):
    ALWAYS = "always"
    SIZE_MISMATCH = "size_mismatch"
    NEVER = "never"


class ParallelMode(str, Enum):
    SLURM = "slurm"
    MULTIPROCESS = "multiprocess"


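# Create the destination's parent directory, then stream the file between open
# file objects; shutil.copyfileobj copies in chunks, so nothing is held fully
# in memory.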
def copy_src_to_dst(
    src_fs: fsspec.AbstractFileSystem,
    dst_fs: fsspec.AbstractFileSystem,
    src_file: str,
    dst_file: str,
    dry_run: bool = False,
):
    if dry_run:
        logger.info("Dry run copy: %s -> %s", src_file, dst_file)
    else:
        dst_parent_directory = os.path.dirname(dst_file)
        dst_fs.mkdirs(dst_parent_directory, exist_ok=True)
        with src_fs.open(src_file, "rb") as src_pointer, dst_fs.open(
            dst_file, "wb"
        ) as dst_pointer:
            shutil.copyfileobj(src_pointer, dst_pointer)


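# submitit pickles the callable and its arguments for each worker, so the
# filesystems travel as dicts (fsspec's to_dict/from_dict) and are rebuilt
# inside the worker process. Subclassing submitit.helpers.Checkpointable lets a
# preempted SLURM job be requeued and rerun.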
class CopyJob(submitit.helpers.Checkpointable):
    def __call__(
        self,
        src_fs_dict: dict,
        dst_fs_dict: dict,
        src_file: str,
        dst_file: str,
        dry_run: bool = False,
        validate_size: bool = True,
    ):
        src_fs = fsspec.AbstractFileSystem.from_dict(src_fs_dict)
        dst_fs = fsspec.AbstractFileSystem.from_dict(dst_fs_dict)
        copy_src_to_dst(src_fs, dst_fs, src_file, dst_file, dry_run=dry_run)
        if validate_size and not dry_run:
            src_size = src_fs.size(src_file)
            dst_size = dst_fs.size(dst_file)
            if src_size != dst_size:
                raise ValueError(
                    f"Mismatched sizes for src={src_file} dst={dst_file} {src_size} != {dst_size}"
                )
        return True


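# CLI entrypoint: lists every file under src_dir and fans the copies out to
# either local worker processes or a SLURM job array via submitit.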
def main(
    src_dir: str,
    dst_dir: str,
    src_s3_profile: str | None = None,
    dst_s3_profile: str | None = None,
    n_workers: int = 16,
    cpus_per_task: int = 2,
    overwrite_mode: OverwriteMode = OverwriteMode.SIZE_MISMATCH,
    validate_size: bool = True,
    parallel_mode: ParallelMode = ParallelMode.MULTIPROCESS,
    dry_run: bool = False,
    job_dir: str = "jobs_parallel-copy",
    slurm_qos: str | None = None,
    slurm_time_hours: int = 72,
    slurm_memory: str = "0",
    wait: bool = True,
    wait_period: int = 5,
):
    logger.info("Starting parallel copy: %s -> %s", src_dir, dst_dir)
    logger.info("job_dir=%s", job_dir)
    logger.info(
        "Parallel=%s validate_size=%s overwrite_mode=%s qos=%s",
        parallel_mode,
        validate_size,
        overwrite_mode,
        slurm_qos,
    )
    if parallel_mode == ParallelMode.MULTIPROCESS:
        executor = submitit.LocalExecutor(folder=job_dir)
    elif parallel_mode == ParallelMode.SLURM:
        executor = submitit.SlurmExecutor(folder=job_dir)
        executor.update_parameters(
            time=slurm_time_hours * 60,
            ntasks_per_node=1,
            cpus_per_task=cpus_per_task,
            array_parallelism=n_workers,
            mem=slurm_memory,
            gpus_per_node=0,
        )
        if slurm_qos is not None:
            executor.update_parameters(qos=slurm_qos)
    else:
        raise ValueError(f"Invalid parallel mode: {parallel_mode}")

    assert src_dir.endswith("/"), "src_dir must end with a /"
    assert dst_dir.endswith("/"), "dst_dir must end with a /"
    src_fs = get_fs(src_dir, s3_profile=src_s3_profile)
    dst_fs = get_fs(dst_dir, s3_profile=dst_s3_profile)

    src_dir = strip_s3_prefix(src_dir)
    dst_dir = strip_s3_prefix(dst_dir)
    logger.info("src: %s, dst: %s", src_dir, dst_dir)

    assert src_fs.isdir(src_dir), "src_dir must be a directory"
    if dst_fs.exists(dst_dir):
        assert dst_fs.isdir(dst_dir), "dst_dir must be a directory if it exists"
    else:
        dst_fs.mkdirs(dst_dir, exist_ok=True)

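    # Enumerate all files under src_dir; executor.batch() collects the submit
    # calls below so they are launched together (as one job array under SLURM).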
    files = src_fs.find(src_dir)
    logger.info("Files found to check for transfer: %s", len(files))
    jobs = []
    with executor.batch():
        for src_file in files:
            relative_src = src_file[len(src_dir) :]
            dst_file_path = os.path.join(dst_dir, relative_src)
            logger.debug("src: %s -> dst %s", src_file, dst_file_path)
            if dst_fs.exists(dst_file_path):
                if overwrite_mode == OverwriteMode.NEVER:
                    should_copy = False
                elif overwrite_mode == OverwriteMode.ALWAYS:
                    should_copy = True
                elif overwrite_mode == OverwriteMode.SIZE_MISMATCH:
                    should_copy = src_fs.size(src_file) != dst_fs.size(dst_file_path)
                else:
                    raise ValueError(f"Unknown overwrite_mode: {overwrite_mode}")
            else:
                should_copy = True
            if should_copy:
                logger.info("copy: %s -> %s", src_file, dst_file_path)
                job = executor.submit(
                    CopyJob(),
                    src_fs.to_dict(),
                    dst_fs.to_dict(),
                    src_file,
                    dst_file_path,
                    dry_run=dry_run,
                    validate_size=validate_size,
                )
                jobs.append(job)
    if wait:
        while True:
            num_finished = sum(job.done() for job in jobs)
            logger.info("Total Jobs: %s Completed Jobs: %s", len(jobs), num_finished)
            if num_finished == len(jobs):
                break
            time.sleep(wait_period)
        output = [job.result() for job in jobs]
        if all(output):
            logger.info("All copies succeeded")
        else:
            logger.info("Some copies failed")
    else:
        logger.info("Not waiting for jobs to complete before exiting submit program")


if __name__ == "__main__":
    typer.run(main)
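Since main is a plain typer function, it can also be driven from Python. A
minimal sketch, assuming the module path shown in this diff; the directories,
bucket, and profile below are placeholders, not values from the commit:

from bytelatent.data.parallel_copy import ParallelMode, main

# Hypothetical: dry-run a local -> S3 mirror with 8 local worker processes.
main(
    src_dir="/data/shards/",           # placeholder; must end with "/"
    dst_dir="s3://my-bucket/shards/",  # placeholder; must end with "/"
    dst_s3_profile="my-profile",       # hypothetical AWS profile name
    parallel_mode=ParallelMode.MULTIPROCESS,
    n_workers=8,
    dry_run=True,                      # log the plan without copying anything
)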