# Copyright (c) Meta Platforms, Inc. and affiliates.
# Launch one SLURM array task per input shard to precompute entropies via
# bytelatent.preprocess.preprocess_entropies, skipping shards that already
# have a ".complete" marker in the output directory.
import subprocess
from pathlib import Path

import submitit
import typer


class PreprocessEntropiesJob(submitit.helpers.Checkpointable):
    # Checkpointable lets submitit requeue this task with the same arguments
    # if the SLURM job is preempted or times out.
    def __init__(self) -> None:
        pass

    def __call__(self, shard_file: str, output_filename: str):
        # Run the preprocessing module on a single shard in a fresh
        # interpreter; check=True raises on a non-zero exit so the SLURM
        # task is marked failed.
        subprocess.run(
            [
                "python",
                "-u",
                "-m",
                "bytelatent.preprocess.preprocess_entropies",
                str(shard_file),
                str(output_filename),
            ],
            check=True,
        )
        return True


def chunk(items, size):
    # Yield successive size-sized slices of items.
    for i in range(0, len(items), size):
        yield items[i : i + size]


def main(
    job_folder: str,
    input_dir: str,
    output_dir: str,
    qos: str = "explore",
    slurm_batch_size: int = 1000,
    check_only: bool = False,
    wait: bool = False,
):
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    # Collect input shards, ignoring COMPLETE marker files.
    shard_files = [
        p for p in input_dir.glob("*.jsonl.shard*") if "COMPLETE" not in p.name
    ]
    if check_only:
        # Report which shards already have finished outputs, then exit
        # without submitting anything.
        exist = []
        missing = []
        for shard_file in shard_files:
            shard_file = Path(shard_file)
            complete_file = output_dir / f"{shard_file.name}.arrow.complete"
            if complete_file.exists():
                exist.append(complete_file)
            else:
                missing.append(complete_file)
        print("Checked for output files for input_dir=", input_dir)
        print("Exist:", len(exist))
        print("Missing:", len(missing))
        print(missing)
        return

    print("Running parallel job over N files=", len(shard_files))
    print("Input Directory:", input_dir)
    print("Output Directory:", output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)
    executor = submitit.SlurmExecutor(job_folder)
    executor.update_parameters(
        # 12 hours in minutes
        time=60 * 12,
        qos=qos,
        exclusive="user",
        cpus_per_task=4,
        num_gpus=1,
        mem_per_gpu="80G",
        # Cap on how many array tasks run concurrently.
        array_parallelism=slurm_batch_size,
    )
    jobs = []
    n_batches = 0
    n_skipped = 0
    n_launched = 0
    for file_batch in chunk(shard_files, slurm_batch_size):
        # executor.batch() groups the submissions below into one SLURM job array.
        with executor.batch():
            for shard_file in file_batch:
                output_filename = Path(output_dir) / f"{shard_file.name}.arrow"
                complete_output_filename = (
                    Path(output_dir) / f"{shard_file.name}.arrow.complete"
                )
                if complete_output_filename.exists():
                    # Shard was finalized by a previous run; skip it.
                    n_skipped += 1
                else:
                    job = executor.submit(
                        PreprocessEntropiesJob(), str(shard_file), str(output_filename)
                    )
                    n_launched += 1
                    jobs.append(job)
        n_batches += 1
    print("launched array jobs n=", n_launched)
    print("skipped (completed) array jobs n=", n_skipped)
    print("number of slurm batches=", n_batches)
    if wait:
        # Block until every launched task finishes; each returns True on success.
        output = [job.result() for job in jobs]
        assert all(output)


if __name__ == "__main__":
    typer.run(main)
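
# Usage sketch (the filename and paths below are hypothetical; the flags follow
# from typer's mapping of main()'s signature to a CLI, not from project docs):
#
#   # Check which shards already have completed outputs, without submitting:
#   python launch_preprocess_entropies.py \
#       /checkpoint/me/jobs /data/shards /data/entropies --check-only
#
#   # Submit the SLURM array and block until every shard finishes:
#   python launch_preprocess_entropies.py \
#       /checkpoint/me/jobs /data/shards /data/entropies --wait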