mirror of
https://github.com/facebookresearch/blt.git
synced 2025-09-10 06:14:35 +00:00
Initial commit
This commit is contained in:
commit
bcc039bb75
86 changed files with 12203 additions and 0 deletions
156
setup/download_prepare_hf_data.py
Normal file
156
setup/download_prepare_hf_data.py
Normal file
|
@ -0,0 +1,156 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import requests
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
|
||||
def run_command(command):
|
||||
print(f"Running: {command}")
|
||||
subprocess.run(command, shell=True, check=True)
|
||||
|
||||
|
||||
def download_dataset(repo_id, local_dir, allow_patterns):
|
||||
print(f"Downloading dataset from {repo_id}...")
|
||||
max_retries = 5
|
||||
retry_delay = 10 # seconds
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
snapshot_download(
|
||||
repo_id,
|
||||
repo_type="dataset",
|
||||
local_dir=local_dir,
|
||||
allow_patterns=allow_patterns,
|
||||
resume_download=True,
|
||||
max_workers=16, # Don't hesitate to increase this number to lower the download time
|
||||
)
|
||||
break
|
||||
except requests.exceptions.ReadTimeout:
|
||||
if attempt < max_retries - 1:
|
||||
print(f"Timeout occurred. Retrying in {retry_delay} seconds...")
|
||||
time.sleep(retry_delay)
|
||||
else:
|
||||
raise
|
||||
print(f"Dataset downloaded to {local_dir}")
|
||||
|
||||
|
||||
def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64):
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.pipeline.readers import ParquetReader
|
||||
from datatrove.pipeline.writers import JsonlWriter
|
||||
|
||||
pipeline_exec = LocalPipelineExecutor(
|
||||
pipeline=[
|
||||
ParquetReader(
|
||||
src_dir,
|
||||
file_progress=True,
|
||||
doc_progress=True,
|
||||
glob_pattern="**/*.parquet",
|
||||
),
|
||||
JsonlWriter(
|
||||
tgt_dir,
|
||||
output_filename=dataset + ".chunk.${rank}.jsonl",
|
||||
compression=None,
|
||||
),
|
||||
],
|
||||
tasks=ntasks,
|
||||
logging_dir=os.path.join(work_dir, "datatrove"),
|
||||
)
|
||||
pipeline_exec.run()
|
||||
|
||||
|
||||
def setup_terashuf(work_dir):
|
||||
terashuf_dir = os.path.join(work_dir, "terashuf")
|
||||
terashuf_executable = os.path.join(terashuf_dir, "terashuf")
|
||||
|
||||
if os.path.exists(terashuf_executable):
|
||||
print("terashuf executable already exists. Skipping setup.")
|
||||
return terashuf_dir
|
||||
|
||||
print("Setting up terashuf...")
|
||||
run_command(f"git clone https://github.com/alexandres/terashuf {terashuf_dir}")
|
||||
run_command(f"make -C {terashuf_dir}")
|
||||
return terashuf_dir
|
||||
|
||||
|
||||
def main(dataset, memory, data_dir, seed=42, nchunks=32):
|
||||
# Configuration
|
||||
repo_id = {
|
||||
"fineweb_edu": "HuggingFaceFW/fineweb-edu",
|
||||
"fineweb_edu_10bt": "HuggingFaceFW/fineweb-edu",
|
||||
"dclm_baseline_1.0": "mlfoundations/dclm-baseline-1.0",
|
||||
"dclm_baseline_1.0_10prct": "mlfoundations/dclm-baseline-1.0",
|
||||
}[dataset]
|
||||
src_dir = f"{data_dir}/{dataset}"
|
||||
out_dir = f"{src_dir}_shuffled"
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
work_dir = src_dir # Directory of this Python file
|
||||
prefix = f"{dataset}.chunk."
|
||||
orig_extension = {
|
||||
"fineweb_edu": ".jsonl",
|
||||
"fineweb_edu_10bt": ".jsonl",
|
||||
"dclm_baseline_1.0": ".jsonl.zst",
|
||||
"dclm_baseline_1.0_10prct": ".jsonl.zst",
|
||||
}[dataset]
|
||||
cat_command = {
|
||||
"fineweb_edu": "cat",
|
||||
"fineweb_edu_10bt": "cat",
|
||||
"dclm_baseline_1.0": "zstdcat",
|
||||
"dclm_baseline_1.0_10prct": "zstdcat",
|
||||
}[dataset]
|
||||
allow_patterns = {
|
||||
"fineweb_edu": None,
|
||||
"fineweb_edu_10bt": "sample/10BT/*",
|
||||
"dclm_baseline_1.0": "*.jsonl.zst",
|
||||
"dclm_baseline_1.0_10prct": "global-shard_01_of_10/*.jsonl.zst",
|
||||
}[dataset]
|
||||
suffix = ".jsonl"
|
||||
k_validation = 10000 # Number of lines to take from each chunk for validation
|
||||
|
||||
# Setup terashuf
|
||||
terashuf_dir = setup_terashuf(work_dir)
|
||||
|
||||
# Download dataset
|
||||
download_dataset(repo_id, src_dir, allow_patterns)
|
||||
|
||||
if "fineweb" in dataset:
|
||||
parquet_to_jsonl(dataset, work_dir, src_dir, src_dir)
|
||||
|
||||
# Set up environment variables
|
||||
os.environ["MEMORY"] = f"{memory}"
|
||||
os.environ["SEED"] = f"{seed}"
|
||||
|
||||
# Run the original shuffling and splitting command
|
||||
terashuf_executable = os.path.join(terashuf_dir, "terashuf")
|
||||
run_command(
|
||||
f"ulimit -n 100000 && "
|
||||
f"find {src_dir} -type f -name '*{orig_extension}' -print0 | xargs -0 {cat_command} | {terashuf_executable} | "
|
||||
f"split -n r/{nchunks} -d --suffix-length 2 --additional-suffix {suffix} - {out_dir}/{prefix}"
|
||||
"; trap 'echo \"Caught signal 13, exiting with code 1\"; exit 1' SIGPIPE;"
|
||||
)
|
||||
|
||||
# Create validation set and remove lines from chunks
|
||||
validation_file = f"{out_dir}/{dataset}.val{suffix}"
|
||||
for i in range(nchunks):
|
||||
chunk_file = f"{out_dir}/{prefix}{i:02d}{suffix}"
|
||||
run_command(f"head -n {k_validation} {chunk_file} >> {validation_file}")
|
||||
run_command(f"sed -i '1,{k_validation}d' {chunk_file}")
|
||||
|
||||
print("All tasks completed successfully!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("dataset", type=str)
|
||||
parser.add_argument("memory", type=float, default=8)
|
||||
parser.add_argument("--data_dir", type=str, default="data")
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument("--nchunks", type=int, default=32)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args.dataset, args.memory, args.data_dir, args.seed, args.nchunks)
|
Loading…
Add table
Add a link
Reference in a new issue