From a1d05403b42cdb111be2cf3034a044e63f45b91b Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez
Date: Fri, 10 Jan 2025 01:02:25 +0000
Subject: [PATCH] Replace regular filesystem calls with fsspec + add s3 support

Summary:
For compatibility with either local/NFS or S3 datasets, swap to fsspec.
Add a tool to compare local and remote filesystems.

Test Plan:
- Ran regular train script
- Ran with config with data in S3
---
 .gitignore                                  |   2 +-
 bytelatent/args.py                          |   8 +-
 bytelatent/data/file_util.py                | 117 ++++++++++++++++++++
 bytelatent/data/iterators/arrow_iterator.py |  89 ++++++++++++---
 bytelatent/logger.py                        |   2 +-
 requirements.txt                            |   1 +
 setup/download_prepare_hf_data.py           |  17 ++-
 7 files changed, 217 insertions(+), 19 deletions(-)
 create mode 100644 bytelatent/data/file_util.py

diff --git a/.gitignore b/.gitignore
index 56891a9..6c664b8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -165,4 +165,4 @@ cython_debug/
 figures/
 .vscode/
 .DS_Store
-
+internal/
diff --git a/bytelatent/args.py b/bytelatent/args.py
index 4fba100..cfba3bf 100644
--- a/bytelatent/args.py
+++ b/bytelatent/args.py
@@ -46,8 +46,11 @@ def distribute_data_to_rank(
     arrow_batch_size: int,
     rank: int,
     world_size: int,
+    s3_profile: str | None = None,
 ) -> ArrowFileIterator:
-    dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size)
+    dataset_chunks = find_and_sanitize_chunks(
+        dataset_path, world_size, s3_profile=s3_profile
+    )
     n_workers_per_chunk = world_size // len(dataset_chunks)
     rank_to_arrow_iterator_params = []
     for chunk_path in dataset_chunks:
@@ -61,6 +64,7 @@
                 dataset_files=None,
                 entropy_model_name=entropy_model_name,
                 arrow_batch_size=arrow_batch_size,
+                s3_profile=s3_profile,
             )
         )
     return rank_to_arrow_iterator_params[rank]
@@ -68,6 +72,7 @@

 class DataloaderArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
+    s3_profile: str | None = None
     root_dir: str | None = None
     sources: dict[str, float] = {}
     batch_size: int = 2
@@ -107,6 +112,7 @@ class DataloaderArgs(BaseModel):
             arrow_batch_size=self.arrow_batch_size,
             rank=rank,
             world_size=world_size,
+            s3_profile=self.s3_profile,
         )
         looping_iterator = LoopingIterator(arrow_iterator)
         preprocess_iterator = PreprocessIterator(
diff --git a/bytelatent/data/file_util.py b/bytelatent/data/file_util.py
new file mode 100644
index 0000000..d67b6db
--- /dev/null
+++ b/bytelatent/data/file_util.py
@@ -0,0 +1,117 @@
+import os
+
+import fsspec
+import pyarrow as pa
+
+# pyarrow needs the initialization from this import
+import pyarrow.dataset  # pyright: ignore
+import typer
+from pyarrow.lib import ArrowInvalid
+from rich.progress import track
+
+
+def is_valid_arrow_file(path: str):
+    try:
+        dataset = pa.dataset.dataset(path, format="arrow")
+        return True
+    except ArrowInvalid:
+        return False
+
+
+app = typer.Typer()
+
+S3_PREFIX = "s3://"
+
+
+def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem:
+    if path.startswith("s3://"):
+        if s3_profile is None:
+            return fsspec.filesystem("s3")
+        else:
+            return fsspec.filesystem("s3", profile=s3_profile)
+    else:
+        return fsspec.filesystem("file")
+
+
+@app.command()
+def print_local_to_delete(
+    blob_dir: str, local_dirs: list[str], s3_profile: str = "blt"
+):
+    for s in local_dirs:
+        assert s.endswith("/"), "Dirs must end with /"
+    assert blob_dir.endswith("/"), "Dirs must end with /"
+    blob_fs = fsspec.filesystem("s3", profile=s3_profile)
+    blob_files = blob_fs.find(blob_dir)
+    for f in track(blob_files):
+        size = blob_fs.info(f)["Size"]
+        if not f.lower().endswith(".complete"):
+            assert size != 0, f"Size was invalidly zero for {f}"
+
+    blob_relative_paths = {f[len(blob_dir) - len(S3_PREFIX) :] for f in blob_files}
+    local_fs = fsspec.filesystem("file")
+
+    files_to_delete = []
+    for local_dir in local_dirs:
+        local_files = local_fs.find(local_dir)
+        for f in local_files:
+            relative_path = f[len(local_dir) :]
+            if relative_path in blob_relative_paths and not os.path.islink(f):
+                files_to_delete.append(f)
+    print(len(files_to_delete))
+    with open("/tmp/files_to_delete.txt", "w") as f:
+        for file in files_to_delete:
+            f.write(f"{file}\n")
+
+
+@app.command()
+def compare_local_to_blob(
+    source_dirs: list[str], dst_dir: str, s3_profile: str = "blt"
+):
+    for s in source_dirs:
+        assert s.endswith("/"), "Dirs must end with /"
+    assert dst_dir.endswith("/"), "Dirs must end with /"
+    assert len(source_dirs) != 0
+    assert dst_dir.startswith("s3://")
+    local_fs = fsspec.filesystem("file")
+    dst_fs = fsspec.filesystem("s3", profile=s3_profile)
+    source_to_files = {}
+    all_local_files = set()
+    for s in source_dirs:
+        skipped = []
+        if s not in source_to_files:
+            source_to_files[s] = []
+        for f in local_fs.find(s):
+            if os.path.islink(f):
+                continue
+            if f.endswith(".COMPLETE") or f.endswith(".complete"):
+                is_complete_file = True
+                assert os.path.getsize(f) == 0, ".COMPLETE files should be empty"
+            else:
+                is_complete_file = False
+
+            if not is_complete_file and os.path.getsize(f) == 0:
+                skipped.append(f)
+                continue
+            if f.endswith(".arrow"):
+                if not is_valid_arrow_file(f):
+                    skipped.append(f)
+                    continue
+
+            source_to_files[s].append(f)
+            all_local_files.add(f[len(s) :])
+        print(s, len(source_to_files[s]), "skipped", len(skipped), skipped[:10])
+
+    dst_files = dst_fs.find(dst_dir)
+    print(dst_dir, len(dst_files))
+
+    dst_file_set = {f[len(dst_dir) - len(S3_PREFIX) :] for f in dst_files}
+    diff = all_local_files.symmetric_difference(dst_file_set)
+    print("Local files", len(all_local_files))
+    print("DST Files", len(dst_file_set))
+    print("Symmetric difference", len(diff))
+    dst_only_files = dst_file_set - all_local_files
+    print("DST only", len(dst_only_files), list(dst_only_files)[:10])
+
+
+if __name__ == "__main__":
+    app()
diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py
index df5f023..4e7b99e 100644
--- a/bytelatent/data/iterators/arrow_iterator.py
+++ b/bytelatent/data/iterators/arrow_iterator.py
@@ -1,17 +1,20 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import os
 import re
 from logging import getLogger
-from pathlib import Path
 from typing import Any, Generator

+import fsspec
 import pyarrow as pa

 # pyarrow needs the initialization from this import
 import pyarrow.dataset  # pyright: ignore
+import s3fs
 from pydantic import BaseModel, ConfigDict

 from bytelatent import ByteLatentError
 from bytelatent.data.data_types import BltExample
+from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator

 logger = getLogger(__name__)
@@ -27,6 +30,8 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
     dataset_files: list[str] | None
     entropy_model_name: str | None
     arrow_batch_size: int = 100
+    s3_profile: str | None
+    filesystem_type: str | None = None

     def build(self) -> "ArrowFileIterator":
         arrow_file = ArrowFileIterator(
@@ -37,14 +42,17 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
             entropy_model_name=self.entropy_model_name,
             arrow_batch_size=self.arrow_batch_size,
             dataset_files=self.dataset_files,
+            s3_profile=self.s3_profile,
+            filesystem_type=self.filesystem_type,
         )
         if self.row_num != 0:
             arrow_file._set_row_num(self.row_num)
         return arrow_file


-def shard_sort_key(file: str | Path):
-    match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file))
+def shard_sort_key(file: str):
+    assert isinstance(file, str)
+    match = re.search(r".+\.shard_([0-9]+)\.arrow", file)
     shard_number = int(match.group(1))
     return shard_number

@@ -60,6 +68,8 @@ class ArrowFileIterator(StatefulIterator):
         entropy_model_name: str | None,
         arrow_batch_size: int,
         dataset_files: list[str] | None = None,
+        s3_profile: str | None = None,
+        filesystem_type: str | None = None,
     ):
         assert 0 <= worker_id < num_workers, (worker_id, num_workers)
         if file_path is None and dataset_files is None:
@@ -75,16 +85,41 @@ class ArrowFileIterator(StatefulIterator):
         self.preprocess_dir = preprocess_dir
         self.entropy_model_name = entropy_model_name
         self.arrow_batch_size = arrow_batch_size
+        self.s3_profile = s3_profile
+        self.filesystem_type = filesystem_type
+        self.fs = None
+        if self.filesystem_type is not None:
+            if self.filesystem_type == "file":
+                self.fs = fsspec.filesystem("file")
+            elif self.filesystem_type == "s3":
+                self.fs = fsspec.filesystem("s3", profile=s3_profile)
+
         if dataset_files is None:
             # Prepare arrow shards
-            jsonl_file = Path(file_path)
-            parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name)
+            jsonl_file = file_path
+            parts = re.match(
+                r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file)
+            )
             assert parts is not None
             dataset = parts.group(1)
-            data_dir = Path(preprocess_dir) / dataset / entropy_model_name
-            shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow"))
+            data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name)
+            data_dir_with_glob = os.path.join(
+                data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow"
+            )
+            if self.fs is None:
+                self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile)
+                if isinstance(self.fs, s3fs.S3FileSystem):
+                    self.filesystem_type = "s3"
+                else:
+                    self.filesystem_type = "file"
+            shard_files = self.fs.glob(data_dir_with_glob)
+
             for s in shard_files:
-                if not (data_dir / f"{s.name}.complete").exists():
+                complete_file = os.path.join(
+                    data_dir, f"{os.path.basename(s)}.complete"
+                )
+
+                if not self.fs.exists(complete_file):
                     raise ValueError(f"Missing .complete for input file: {s}")

             shard_files = sorted(shard_files, key=shard_sort_key)
@@ -92,10 +127,19 @@ class ArrowFileIterator(StatefulIterator):
                 raise ByteLatentError(
                     f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
                 )
-            self.dataset_files = [str(f) for f in shard_files]
+            self.dataset_files = [f for f in shard_files]
         else:
             self.preprocess_dir = None
             self.dataset_files = dataset_files
+            if dataset_files[0].startswith("s3://"):
+                for f in dataset_files:
+                    assert f.startswith("s3://")
+            if self.fs is None:
+                self.fs = get_fs(dataset_files[0], s3_profile=s3_profile)
+                if isinstance(self.fs, s3fs.S3FileSystem):
+                    self.filesystem_type = "s3"
+                else:
+                    self.filesystem_type = "file"

     def get_state(self) -> ArrowFileIteratorState:
         return ArrowFileIteratorState(
@@ -107,13 +151,21 @@ class ArrowFileIterator(StatefulIterator):
             entropy_model_name=self.entropy_model_name,
             arrow_batch_size=self.arrow_batch_size,
             dataset_files=self.dataset_files,
+            s3_profile=self.s3_profile,
+            filesystem_type=self.filesystem_type,
         )

     def create_iter(
         self,
     ) -> Generator[BltExample, Any, None]:
         if self.dataset is None:
-            self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow")
+            if isinstance(self.fs, s3fs.core.S3FileSystem):
+                filesystem = self.fs
+            else:
+                filesystem = None
+            self.dataset = pa.dataset.dataset(
+                self.dataset_files, format="arrow", filesystem=filesystem
+            )
             self.batch_iterator = self.dataset.to_batches(
                 batch_size=self.arrow_batch_size
             )
@@ -165,7 +217,13 @@ class ArrowFileIterator(StatefulIterator):
             self.batch_iterator = None
             self.batch_to_consume = None
         else:
-            self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow")
+            if isinstance(self.fs, s3fs.core.S3FileSystem):
+                filesystem = self.fs
+            else:
+                filesystem = None
+            self.dataset = pa.dataset.dataset(
+                self.dataset_files, format="arrow", filesystem=filesystem
+            )
             self.batch_iterator = self.dataset.to_batches(
                 batch_size=self.arrow_batch_size
             )
@@ -198,9 +256,14 @@ TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"


 def find_and_sanitize_chunks(
-    dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN
+    dataset_path: str,
+    world_size: int,
+    file_pattern: str = TRAIN_DATA_FILE_PATTERN,
+    s3_profile: str | None = None,
 ):
-    dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)]
+    fs = get_fs(dataset_path, s3_profile=s3_profile)
+    path_with_glob = os.path.join(dataset_path, file_pattern)
+    dataset_chunks = fs.glob(path_with_glob)
     n_chunks = len(dataset_chunks)

     if n_chunks > world_size:
diff --git a/bytelatent/logger.py b/bytelatent/logger.py
index 6723a84..87f04cc 100644
--- a/bytelatent/logger.py
+++ b/bytelatent/logger.py
@@ -91,7 +91,7 @@ def init_logger(
     log_file: str | None = None,
     *,
     name: str | None = None,
-    level: str = "NOTSET",
+    level: str = "INFO",
 ):
     """
     Setup logging.
diff --git a/requirements.txt b/requirements.txt
index c6d87f1..59192cd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -20,3 +20,4 @@ altair
 submitit
 typer
 rich
+fsspec[full]
diff --git a/setup/download_prepare_hf_data.py b/setup/download_prepare_hf_data.py
index 1aacf8f..e194857 100644
--- a/setup/download_prepare_hf_data.py
+++ b/setup/download_prepare_hf_data.py
@@ -5,6 +5,7 @@ import os
 import subprocess
 import time

+import fsspec
 import requests
 from huggingface_hub import snapshot_download

@@ -38,11 +39,21 @@ def download_dataset(repo_id, local_dir, allow_patterns):
     print(f"Dataset downloaded to {local_dir}")


-def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64):
+def parquet_to_jsonl(
+    dataset, work_dir, src_dir, tgt_dir, ntasks=64, s3_profile: str | None = None
+):
     from datatrove.executor import LocalPipelineExecutor
     from datatrove.pipeline.readers import ParquetReader
     from datatrove.pipeline.writers import JsonlWriter

+    if tgt_dir.startswith("s3://"):
+        if s3_profile is None:
+            out_spec = tgt_dir
+        else:
+            out_spec = (tgt_dir, fsspec.filesystem("s3", profile=s3_profile))
+    else:
+        out_spec = tgt_dir
+
     pipeline_exec = LocalPipelineExecutor(
         pipeline=[
             ParquetReader(
@@ -52,7 +63,7 @@ def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64):
                 glob_pattern="**/*.parquet",
             ),
             JsonlWriter(
-                tgt_dir,
+                out_spec,
                 output_filename=dataset + ".chunk.${rank}.jsonl",
                 compression=None,
             ),
@@ -77,7 +88,7 @@ def setup_terashuf(work_dir):
     return terashuf_dir


-def main(dataset, memory, data_dir, seed=42, nchunks=32):
+def main(dataset, memory, data_dir, seed=42, nchunks=32, s3_profile: str | None = None):
     # Configuration
     repo_id = {
         "fineweb_edu": "HuggingFaceFW/fineweb-edu",