mirror of
https://github.com/facebookresearch/blt.git
synced 2025-01-18 08:27:45 +00:00
b0120da72f
Summary: For compatibility with either local/NFS or S3 datasets, swap to fsspec. Add a tool to compare local and remote filesystems. Test Plan: - Ran regular train script. - Ran with a config with data in S3.
168 lines
5.5 KiB
Python
168 lines
5.5 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
|
|
import argparse
|
|
import os
|
|
import subprocess
|
|
import time
|
|
|
|
import fsspec
|
|
import requests
|
|
from huggingface_hub import snapshot_download
|
|
|
|
|
|
def run_command(command):
    """Echo ``command`` to stdout, then execute it through the shell.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero status.
    """
    print("Running: {}".format(command))
    subprocess.run(command, check=True, shell=True)
|
|
|
|
|
|
def download_dataset(
    repo_id,
    local_dir,
    allow_patterns,
    *,
    max_retries=5,
    retry_delay=10,
    max_workers=16,
):
    """Download a Hugging Face dataset snapshot into ``local_dir``, retrying on read timeouts.

    Args:
        repo_id: Hugging Face dataset repository id (e.g. ``"HuggingFaceFW/fineweb-edu"``).
        local_dir: Local directory the snapshot is written into.
        allow_patterns: Glob pattern(s) forwarded to ``snapshot_download`` to
            restrict which files are fetched (``None`` applies no filter).
        max_retries: Total number of download attempts; must be >= 1.
        retry_delay: Seconds to sleep after a timeout before the next attempt.
        max_workers: Concurrent download workers passed to ``snapshot_download``.
            Increasing it can lower the download time.

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        requests.exceptions.ReadTimeout: If the final attempt also times out.
    """
    if max_retries < 1:
        # Without this guard the retry loop would not run at all and the
        # success message below would be printed without downloading anything.
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    print(f"Downloading dataset from {repo_id}...")
    for attempt in range(max_retries):
        try:
            snapshot_download(
                repo_id,
                repo_type="dataset",
                local_dir=local_dir,
                allow_patterns=allow_patterns,
                # NOTE(review): recent huggingface_hub releases deprecate this
                # flag (downloads always resume); kept for older versions.
                resume_download=True,
                max_workers=max_workers,
            )
            break
        except requests.exceptions.ReadTimeout:
            if attempt == max_retries - 1:
                # Out of attempts: surface the original timeout to the caller.
                raise
            print(f"Timeout occurred. Retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
    print(f"Dataset downloaded to {local_dir}")
|
|
|
|
|
|
def parquet_to_jsonl(
    dataset, work_dir, src_dir, tgt_dir, ntasks=64, s3_profile: str | None = None
):
    """Convert every ``*.parquet`` file under ``src_dir`` into JSONL files in ``tgt_dir``.

    Runs a datatrove ``LocalPipelineExecutor`` with ``ntasks`` tasks: a
    ``ParquetReader`` over ``src_dir`` feeding an uncompressed ``JsonlWriter``
    whose files are named ``{dataset}.chunk.${rank}.jsonl``.

    Args:
        dataset: Dataset name used as the output filename prefix.
        work_dir: Directory under which datatrove logs are written
            (``work_dir/datatrove``).
        src_dir: Directory searched recursively (``**/*.parquet``) for input files.
        tgt_dir: Output directory; a local path or an ``s3://`` URI.
        ntasks: Number of datatrove tasks to run.
        s3_profile: Optional profile used to build the S3 filesystem when
            ``tgt_dir`` is an ``s3://`` URI; ignored for other targets.
    """
    # datatrove is imported lazily so the rest of the script works without it.
    from datatrove.executor import LocalPipelineExecutor
    from datatrove.pipeline.readers import ParquetReader
    from datatrove.pipeline.writers import JsonlWriter

    # Match the full "s3://" scheme; the previous "s3//" check never matched a
    # real S3 URI, so the requested profile was silently ignored.
    if tgt_dir.startswith("s3://") and s3_profile is not None:
        # Pass an explicit (path, filesystem) pair so the writer uses the profile.
        out_spec = (tgt_dir, fsspec.filesystem("s3", profile=s3_profile))
    else:
        out_spec = tgt_dir

    pipeline_exec = LocalPipelineExecutor(
        pipeline=[
            ParquetReader(
                src_dir,
                file_progress=True,
                doc_progress=True,
                glob_pattern="**/*.parquet",
            ),
            JsonlWriter(
                out_spec,
                # "${rank}" is a datatrove filename template, not an f-string.
                output_filename=dataset + ".chunk.${rank}.jsonl",
                compression=None,
            ),
        ],
        tasks=ntasks,
        logging_dir=os.path.join(work_dir, "datatrove"),
    )
    pipeline_exec.run()
|
|
|
|
|
|
def setup_terashuf(work_dir):
    """Clone and build terashuf under ``work_dir`` unless the binary already exists.

    Args:
        work_dir: Directory in which the ``terashuf`` checkout is created.

    Returns:
        Path of the terashuf checkout directory (``work_dir/terashuf``).
    """
    import shlex

    terashuf_dir = os.path.join(work_dir, "terashuf")
    terashuf_executable = os.path.join(terashuf_dir, "terashuf")

    if os.path.exists(terashuf_executable):
        print("terashuf executable already exists. Skipping setup.")
        return terashuf_dir

    print("Setting up terashuf...")
    # run_command executes through the shell, so quote the path to survive
    # spaces or shell metacharacters in work_dir.
    quoted_dir = shlex.quote(terashuf_dir)
    # A previous run may have cloned successfully and then failed in `make`;
    # `git clone` refuses to write into that existing checkout, so only clone
    # when there is no repository yet.
    if not os.path.isdir(os.path.join(terashuf_dir, ".git")):
        run_command(f"git clone https://github.com/alexandres/terashuf {quoted_dir}")
    run_command(f"make -C {quoted_dir}")
    return terashuf_dir
|
|
|
|
|
|
def main(dataset, memory, data_dir, seed=42, nchunks=32, s3_profile: str | None = None):
    """Download a dataset, shuffle it with terashuf, split it into chunks and carve out a validation set.

    Args:
        dataset: One of ``fineweb_edu``, ``fineweb_edu_10bt``, ``dclm_baseline_1.0``
            or ``dclm_baseline_1.0_10prct``; any other value raises ``KeyError``.
        memory: Exported as the ``MEMORY`` environment variable read by the
            shuffle command (presumably terashuf's memory budget in GB — confirm).
        data_dir: Parent directory. Raw data goes to ``{data_dir}/{dataset}``
            and shuffled chunks to ``{data_dir}/{dataset}_shuffled``.
        seed: Exported as the ``SEED`` environment variable for the shuffle.
        nchunks: Number of output chunks produced by ``split``.
        s3_profile: Optional S3 profile forwarded to ``parquet_to_jsonl``.
    """
    # Per-dataset configuration; an unknown ``dataset`` raises KeyError here.
    repo_id = {
        "fineweb_edu": "HuggingFaceFW/fineweb-edu",
        "fineweb_edu_10bt": "HuggingFaceFW/fineweb-edu",
        "dclm_baseline_1.0": "mlfoundations/dclm-baseline-1.0",
        "dclm_baseline_1.0_10prct": "mlfoundations/dclm-baseline-1.0",
    }[dataset]
    src_dir = f"{data_dir}/{dataset}"
    out_dir = f"{src_dir}_shuffled"
    os.makedirs(out_dir, exist_ok=True)
    # terashuf is built and datatrove logs are written inside the dataset directory.
    work_dir = src_dir
    prefix = f"{dataset}.chunk."
    # Extension of the files fed to the shuffle, and the command that streams them.
    orig_extension = {
        "fineweb_edu": ".jsonl",
        "fineweb_edu_10bt": ".jsonl",
        "dclm_baseline_1.0": ".jsonl.zst",
        "dclm_baseline_1.0_10prct": ".jsonl.zst",
    }[dataset]
    cat_command = {
        "fineweb_edu": "cat",
        "fineweb_edu_10bt": "cat",
        "dclm_baseline_1.0": "zstdcat",
        "dclm_baseline_1.0_10prct": "zstdcat",
    }[dataset]
    # Subset of each repository to download (None downloads the whole repo).
    allow_patterns = {
        "fineweb_edu": None,
        "fineweb_edu_10bt": "sample/10BT/*",
        "dclm_baseline_1.0": "*.jsonl.zst",
        "dclm_baseline_1.0_10prct": "global-shard_01_of_10/*.jsonl.zst",
    }[dataset]
    suffix = ".jsonl"
    k_validation = 10000  # Number of lines to take from each chunk for validation

    # Setup terashuf
    terashuf_dir = setup_terashuf(work_dir)

    # Download dataset
    download_dataset(repo_id, src_dir, allow_patterns)

    if "fineweb" in dataset:
        # FineWeb inputs are parquet; convert them to JSONL in place so the
        # `find ... -name '*.jsonl'` below picks them up.
        parquet_to_jsonl(dataset, work_dir, src_dir, src_dir, s3_profile=s3_profile)

    # Set up environment variables inherited by the shell command below
    os.environ["MEMORY"] = f"{memory}"
    os.environ["SEED"] = f"{seed}"

    # Stream all input files through terashuf and deal the shuffled lines
    # round-robin into `nchunks` files named {prefix}00{suffix}, {prefix}01{suffix}, ...
    # NOTE(review): the SIGPIPE trap is installed after the pipeline has already
    # run, so it appears to have no effect on it — confirm intent.
    terashuf_executable = os.path.join(terashuf_dir, "terashuf")
    run_command(
        f"ulimit -n 100000 && "
        f"find {src_dir} -type f -name '*{orig_extension}' -print0 | xargs -0 {cat_command} | {terashuf_executable} | "
        f"split -n r/{nchunks} -d --suffix-length 2 --additional-suffix {suffix} - {out_dir}/{prefix}"
        "; trap 'echo \"Caught signal 13, exiting with code 1\"; exit 1' SIGPIPE;"
    )

    # Create validation set and remove lines from chunks
    validation_file = f"{out_dir}/{dataset}.val{suffix}"
    # The loop appends with `>>`; start from an empty file so rerunning the
    # script does not accumulate duplicate validation lines from earlier runs.
    if os.path.exists(validation_file):
        os.remove(validation_file)
    for i in range(nchunks):
        chunk_file = f"{out_dir}/{prefix}{i:02d}{suffix}"
        # Move the first k_validation lines of each chunk into the validation set.
        run_command(f"head -n {k_validation} {chunk_file} >> {validation_file}")
        run_command(f"sed -i '1,{k_validation}d' {chunk_file}")

    print("All tasks completed successfully!")
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset", type=str)
    # argparse ignores `default` on a positional argument unless nargs="?" makes
    # it optional; without it the default of 8 was dead and memory was required.
    parser.add_argument("memory", type=float, nargs="?", default=8)
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--nchunks", type=int, default=32)
    # Expose main()'s s3_profile option, which previously had no CLI flag.
    parser.add_argument("--s3_profile", type=str, default=None)

    args = parser.parse_args()

    main(
        args.dataset,
        args.memory,
        args.data_dir,
        args.seed,
        args.nchunks,
        s3_profile=args.s3_profile,
    )
|