# Copyright (c) Meta Platforms, Inc. and affiliates.

import argparse
import os
import subprocess
import time

import fsspec
import requests
from huggingface_hub import snapshot_download


def run_command(command):
    print(f"Running: {command}")
    subprocess.run(command, shell=True, check=True)


def download_dataset(repo_id, local_dir, allow_patterns):
    print(f"Downloading dataset from {repo_id}...")
    max_retries = 5
    retry_delay = 10  # seconds
    for attempt in range(max_retries):
        try:
            snapshot_download(
                repo_id,
                repo_type="dataset",
                local_dir=local_dir,
                allow_patterns=allow_patterns,
                resume_download=True,
                max_workers=16,  # Increase this number to speed up the download
            )
            break
        except requests.exceptions.ReadTimeout:
            if attempt < max_retries - 1:
                print(f"Timeout occurred. Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
            else:
                raise
    print(f"Dataset downloaded to {local_dir}")


def parquet_to_jsonl(
    dataset, work_dir, src_dir, tgt_dir, ntasks=64, s3_profile: str | None = None
):
    from datatrove.executor import LocalPipelineExecutor
    from datatrove.pipeline.readers import ParquetReader
    from datatrove.pipeline.writers import JsonlWriter

    if tgt_dir.startswith("s3://"):
        if s3_profile is None:
            out_spec = tgt_dir
        else:
            # Pass an explicit fsspec S3 filesystem when an AWS profile is given
            out_spec = (tgt_dir, fsspec.filesystem("s3", profile=s3_profile))
    else:
        out_spec = tgt_dir

    pipeline_exec = LocalPipelineExecutor(
        pipeline=[
            ParquetReader(
                src_dir,
                file_progress=True,
                doc_progress=True,
                glob_pattern="**/*.parquet",
            ),
            JsonlWriter(
                out_spec,
                output_filename=dataset + ".chunk.${rank}.jsonl",
                compression=None,
            ),
        ],
        tasks=ntasks,
        logging_dir=os.path.join(work_dir, "datatrove"),
    )
    pipeline_exec.run()


def setup_terashuf(work_dir):
    terashuf_dir = os.path.join(work_dir, "terashuf")
    terashuf_executable = os.path.join(terashuf_dir, "terashuf")

    if os.path.exists(terashuf_executable):
        print("terashuf executable already exists. Skipping setup.")
        return terashuf_dir

    print("Setting up terashuf...")
    run_command(f"git clone https://github.com/alexandres/terashuf {terashuf_dir}")
    run_command(f"make -C {terashuf_dir}")
    return terashuf_dir


def main(
    dataset, memory, data_dir, seed=42, nchunks=32, s3_profile: str | None = None
):
    # Configuration
    repo_id = {
        "fineweb_edu": "HuggingFaceFW/fineweb-edu",
        "fineweb_edu_10bt": "HuggingFaceFW/fineweb-edu",
        "dclm_baseline_1.0": "mlfoundations/dclm-baseline-1.0",
        "dclm_baseline_1.0_10prct": "mlfoundations/dclm-baseline-1.0",
    }[dataset]
    src_dir = f"{data_dir}/{dataset}"
    out_dir = f"{src_dir}_shuffled"
    os.makedirs(out_dir, exist_ok=True)
    work_dir = src_dir  # terashuf is built and datatrove logs are written here
    prefix = f"{dataset}.chunk."
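    # Per-dataset lookup tables: the on-disk extension of the raw files, the
    # command used to stream them (zstdcat for the zstd-compressed DCLM
    # shards), and the snapshot_download filter restricting which files are
    # fetched from the Hub.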
    orig_extension = {
        "fineweb_edu": ".jsonl",
        "fineweb_edu_10bt": ".jsonl",
        "dclm_baseline_1.0": ".jsonl.zst",
        "dclm_baseline_1.0_10prct": ".jsonl.zst",
    }[dataset]
    cat_command = {
        "fineweb_edu": "cat",
        "fineweb_edu_10bt": "cat",
        "dclm_baseline_1.0": "zstdcat",
        "dclm_baseline_1.0_10prct": "zstdcat",
    }[dataset]
    allow_patterns = {
        "fineweb_edu": None,
        "fineweb_edu_10bt": "sample/10BT/*",
        "dclm_baseline_1.0": "*.jsonl.zst",
        "dclm_baseline_1.0_10prct": "global-shard_01_of_10/*.jsonl.zst",
    }[dataset]
    suffix = ".jsonl"
    k_validation = 10000  # Number of lines to take from each chunk for validation

    # Set up terashuf
    terashuf_dir = setup_terashuf(work_dir)

    # Download the dataset
    download_dataset(repo_id, src_dir, allow_patterns)

    # FineWeb-Edu ships as parquet; convert it to jsonl so the shell pipeline
    # below can stream it
    if "fineweb" in dataset:
        parquet_to_jsonl(dataset, work_dir, src_dir, src_dir, s3_profile=s3_profile)

    # terashuf reads its memory budget (GB) and RNG seed from the environment
    os.environ["MEMORY"] = f"{memory}"
    os.environ["SEED"] = f"{seed}"

    # Shuffle all lines with terashuf, then deal them round-robin
    # (split -n r/N) into nchunks numbered .jsonl files
    terashuf_executable = os.path.join(terashuf_dir, "terashuf")
    run_command(
        f"ulimit -n 100000 && "
        "trap 'echo \"Caught signal 13, exiting with code 1\"; exit 1' SIGPIPE; "
        f"find {src_dir} -type f -name '*{orig_extension}' -print0 | xargs -0 {cat_command} | {terashuf_executable} | "
        f"split -n r/{nchunks} -d --suffix-length 2 --additional-suffix {suffix} - {out_dir}/{prefix}"
    )

    # Create the validation set by moving the first k_validation lines
    # out of each chunk
    validation_file = f"{out_dir}/{dataset}.val{suffix}"
    for i in range(nchunks):
        chunk_file = f"{out_dir}/{prefix}{i:02d}{suffix}"
        run_command(f"head -n {k_validation} {chunk_file} >> {validation_file}")
        run_command(f"sed -i '1,{k_validation}d' {chunk_file}")

    print("All tasks completed successfully!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset", type=str)
    parser.add_argument("memory", type=float, nargs="?", default=8)
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--nchunks", type=int, default=32)
    parser.add_argument("--s3_profile", type=str, default=None)

    args = parser.parse_args()

    main(
        args.dataset,
        args.memory,
        args.data_dir,
        args.seed,
        args.nchunks,
        args.s3_profile,
    )
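# Example invocation (illustrative; substitute the actual script filename).
# The positional "memory" argument is terashuf's shuffle buffer budget in GB,
# passed through the MEMORY environment variable:
#   python prepare_data.py fineweb_edu_10bt 8 --data_dir data --nchunks 32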