blt/setup/download_prepare_hf_data.py
Pedro Rodriguez b0120da72f
Some checks failed
Lint with Black / lint (push) Has been cancelled
Lint with isort / lint (push) Has been cancelled
Replace regular filesystem calls with fsspec + add s3 support (#18)
Summary:

For compatibility with either local/nfs or S3 datasets, swap to fsspec.

Add a tool to compare local and remote filesystems

Test Plan:

- Ran regular train script
- Ran with config with data in S3
2025-01-10 11:04:41 -08:00

168 lines
5.5 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
import argparse
import os
import subprocess
import time
import fsspec
import requests
from huggingface_hub import snapshot_download
def run_command(command):
    """Echo ``command`` and run it through the shell.

    The command string is interpreted by the shell (``shell=True``), so it may
    use pipes, ``&&`` and other shell syntax.

    Raises:
        subprocess.CalledProcessError: if the shell exits with a non-zero status.
    """
    # Flush so the banner appears before the child's output even when stdout
    # is redirected to a file/pipe (block-buffered) while the child writes
    # directly to the inherited file descriptor.
    print(f"Running: {command}", flush=True)
    subprocess.run(command, shell=True, check=True)
def download_dataset(repo_id, local_dir, allow_patterns):
    """Fetch a dataset snapshot from the Hugging Face Hub into ``local_dir``.

    Only ``requests.exceptions.ReadTimeout`` is retried: at most 5 attempts are
    made, with a 10 second pause between them. If the last attempt also times
    out, that exception propagates to the caller.
    """
    print(f"Downloading dataset from {repo_id}...")
    attempts_allowed = 5
    pause_seconds = 10  # seconds
    attempt = 0
    while True:
        attempt += 1
        try:
            snapshot_download(
                repo_id,
                repo_type="dataset",
                local_dir=local_dir,
                allow_patterns=allow_patterns,
                resume_download=True,
                # Raise this to shorten the download if bandwidth allows.
                max_workers=16,
            )
        except requests.exceptions.ReadTimeout:
            if attempt >= attempts_allowed:
                raise
            print(f"Timeout occurred. Retrying in {pause_seconds} seconds...")
            time.sleep(pause_seconds)
        else:
            break
    print(f"Dataset downloaded to {local_dir}")
def parquet_to_jsonl(
    dataset, work_dir, src_dir, tgt_dir, ntasks=64, s3_profile: str | None = None
):
    """Convert all parquet files under ``src_dir`` into uncompressed JSONL files.

    Runs a local datatrove pipeline (``ParquetReader`` -> ``JsonlWriter``) with
    ``ntasks`` tasks. Output files are written to ``tgt_dir`` and named
    ``{dataset}.chunk.${rank}.jsonl``; ``${rank}`` is a datatrove filename
    template placeholder (presumably the task rank). Pipeline logs are written
    to ``{work_dir}/datatrove``.

    Args:
        dataset: Dataset name, used as the output filename prefix.
        work_dir: Directory under which the ``datatrove`` log folder is created.
        src_dir: Directory searched recursively (``**/*.parquet``) for inputs.
        tgt_dir: Output location; may be an ``s3://`` URL.
        ntasks: Number of datatrove tasks.
        s3_profile: Named AWS profile used to build the S3 filesystem when
            ``tgt_dir`` is an ``s3://`` URL. Ignored for other targets.
    """
    from datatrove.executor import LocalPipelineExecutor
    from datatrove.pipeline.readers import ParquetReader
    from datatrove.pipeline.writers import JsonlWriter

    # An explicit (path, filesystem) pair is only needed to pin a specific AWS
    # profile for S3 output; otherwise the plain path is handed to datatrove.
    # Note the prefix must be "s3://" (the previous "s3//" never matched).
    if tgt_dir.startswith("s3://") and s3_profile is not None:
        out_spec = (tgt_dir, fsspec.filesystem("s3", profile=s3_profile))
    else:
        out_spec = tgt_dir

    pipeline_exec = LocalPipelineExecutor(
        pipeline=[
            ParquetReader(
                src_dir,
                file_progress=True,
                doc_progress=True,
                glob_pattern="**/*.parquet",
            ),
            JsonlWriter(
                out_spec,
                output_filename=dataset + ".chunk.${rank}.jsonl",
                compression=None,
            ),
        ],
        tasks=ntasks,
        logging_dir=os.path.join(work_dir, "datatrove"),
    )
    pipeline_exec.run()
def setup_terashuf(work_dir):
    """Ensure a compiled terashuf binary exists under ``work_dir``.

    terashuf is cloned from GitHub into ``{work_dir}/terashuf`` and built with
    ``make``. Nothing is done if the executable already exists. If an earlier
    run cloned the repository but the build did not complete, the existing
    checkout is rebuilt instead of re-cloned; ``git clone`` refuses to clone
    into a non-empty directory.

    Returns:
        Path of the terashuf directory (the executable is ``<dir>/terashuf``).

    Raises:
        subprocess.CalledProcessError: if ``git clone`` or ``make`` fails.
    """
    import shlex

    terashuf_dir = os.path.join(work_dir, "terashuf")
    terashuf_executable = os.path.join(terashuf_dir, "terashuf")
    if os.path.exists(terashuf_executable):
        print("terashuf executable already exists. Skipping setup.")
        return terashuf_dir
    print("Setting up terashuf...")
    # work_dir comes from user input; quote it before handing it to the shell.
    quoted_dir = shlex.quote(terashuf_dir)
    if not os.path.isdir(os.path.join(terashuf_dir, ".git")):
        run_command(f"git clone https://github.com/alexandres/terashuf {quoted_dir}")
    run_command(f"make -C {quoted_dir}")
    return terashuf_dir
def main(dataset, memory, data_dir, seed=42, nchunks=32, s3_profile: str | None = None):
    """Download a dataset, shuffle it with terashuf, and split it into chunks.

    Steps:
      1. Build terashuf under ``{data_dir}/{dataset}/terashuf`` if needed.
      2. Download ``dataset`` from the Hugging Face Hub into
         ``{data_dir}/{dataset}``.
      3. For fineweb variants, convert the downloaded parquet files to JSONL.
      4. Stream every raw file through terashuf (``MEMORY`` and ``SEED`` are
         exported as environment variables) and split the output round-robin
         into ``nchunks`` files ``{dataset}.chunk.NN.jsonl`` in
         ``{data_dir}/{dataset}_shuffled``.
      5. Move the first 10,000 lines of every chunk into
         ``{dataset}.val.jsonl`` in the same directory.

    Args:
        dataset: One of ``fineweb_edu``, ``fineweb_edu_10bt``,
            ``dclm_baseline_1.0``, ``dclm_baseline_1.0_10prct``.
        memory: Value exported as the ``MEMORY`` env var for terashuf.
        data_dir: Root directory for raw and shuffled data.
        seed: Value exported as the ``SEED`` env var for terashuf.
        nchunks: Number of output chunks.
        s3_profile: Forwarded to ``parquet_to_jsonl``.

    Raises:
        KeyError: if ``dataset`` is not a supported name.
        subprocess.CalledProcessError: if any shell step fails.
    """
    import shlex

    # dataset -> (HF repo id, raw file extension, command that decompresses
    # raw files to stdout, snapshot_download allow_patterns)
    configs = {
        "fineweb_edu": ("HuggingFaceFW/fineweb-edu", ".jsonl", "cat", None),
        "fineweb_edu_10bt": (
            "HuggingFaceFW/fineweb-edu",
            ".jsonl",
            "cat",
            "sample/10BT/*",
        ),
        "dclm_baseline_1.0": (
            "mlfoundations/dclm-baseline-1.0",
            ".jsonl.zst",
            "zstdcat",
            "*.jsonl.zst",
        ),
        "dclm_baseline_1.0_10prct": (
            "mlfoundations/dclm-baseline-1.0",
            ".jsonl.zst",
            "zstdcat",
            "global-shard_01_of_10/*.jsonl.zst",
        ),
    }
    repo_id, orig_extension, cat_command, allow_patterns = configs[dataset]

    src_dir = f"{data_dir}/{dataset}"
    out_dir = f"{src_dir}_shuffled"
    os.makedirs(out_dir, exist_ok=True)
    # Scratch space (terashuf checkout, datatrove logs) lives inside the raw
    # dataset directory.
    work_dir = src_dir
    prefix = f"{dataset}.chunk."
    suffix = ".jsonl"
    k_validation = 10000  # Number of lines to take from each chunk for validation
    # `split -d` numbers chunks 00, 01, ...; widen the numeric suffix when more
    # than 100 chunks are requested so split does not run out of suffixes.
    suffix_len = max(2, len(str(nchunks - 1)))

    terashuf_dir = setup_terashuf(work_dir)
    download_dataset(repo_id, src_dir, allow_patterns)
    if "fineweb" in dataset:
        parquet_to_jsonl(dataset, work_dir, src_dir, src_dir, s3_profile=s3_profile)

    os.environ["MEMORY"] = f"{memory}"
    os.environ["SEED"] = f"{seed}"

    terashuf_executable = os.path.join(terashuf_dir, "terashuf")
    q = shlex.quote  # data_dir is user input; quote paths for the shell
    run_command(
        # Install the trap first. When it was appended after the pipeline it
        # never applied, and as the last command its exit status (0) hid any
        # failure of the ulimit/find/terashuf/split steps from check=True.
        # `PIPE` is the portable POSIX signal name; some /bin/sh
        # implementations may reject the `SIGPIPE` spelling.
        "trap 'echo \"Caught signal 13, exiting with code 1\"; exit 1' PIPE; "
        "ulimit -n 100000 && "
        f"find {q(src_dir)} -type f -name {q('*' + orig_extension)} -print0 | "
        f"xargs -0 {cat_command} | {q(terashuf_executable)} | "
        f"split -n r/{nchunks} -d --suffix-length {suffix_len} "
        f"--additional-suffix {suffix} - {q(out_dir + '/' + prefix)}"
    )

    validation_file = f"{out_dir}/{dataset}.val{suffix}"
    # Start from an empty validation file so a re-run does not append
    # duplicate lines to the output of a previous run.
    with open(validation_file, "w"):
        pass
    for i in range(nchunks):
        chunk_file = f"{out_dir}/{prefix}{i:0{suffix_len}d}{suffix}"
        run_command(f"head -n {k_validation} {q(chunk_file)} >> {q(validation_file)}")
        run_command(f"sed -i '1,{k_validation}d' {q(chunk_file)}")
    print("All tasks completed successfully!")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=(
            "Download a Hugging Face dataset, shuffle it with terashuf and "
            "split it into chunks plus a validation file."
        )
    )
    parser.add_argument("dataset", type=str, help="Dataset name to prepare.")
    # nargs="?" is what lets the default apply: without it argparse treats a
    # positional argument as required and `default` is never used.
    parser.add_argument(
        "memory",
        type=float,
        nargs="?",
        default=8,
        help="Value exported as the MEMORY env var for terashuf (default: 8).",
    )
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--nchunks", type=int, default=32)
    parser.add_argument(
        "--s3_profile",
        type=str,
        default=None,
        help="AWS profile used when writing converted data to an s3:// target.",
    )
    args = parser.parse_args()
    main(
        args.dataset,
        args.memory,
        args.data_dir,
        args.seed,
        args.nchunks,
        s3_profile=args.s3_profile,
    )