Replace regular filesystem calls with fsspec + add s3 support

Summary:

For compatibility with either local/nfs or S3 datasets, swap to fsspec.

Add a tool to compare local and remote filesystems

Test Plan:

- Ran regular train script
- Ran with config with data in S3
This commit is contained in:
Pedro Rodriguez 2025-01-10 01:02:25 +00:00
parent d4ddb95322
commit a1d05403b4
7 changed files with 217 additions and 19 deletions

2
.gitignore vendored
View file

@ -165,4 +165,4 @@ cython_debug/
figures/ figures/
.vscode/ .vscode/
.DS_Store .DS_Store
internal/

View file

@ -46,8 +46,11 @@ def distribute_data_to_rank(
arrow_batch_size: int, arrow_batch_size: int,
rank: int, rank: int,
world_size: int, world_size: int,
s3_profile: str | None = None,
) -> ArrowFileIterator: ) -> ArrowFileIterator:
dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size) dataset_chunks = find_and_sanitize_chunks(
dataset_path, world_size, s3_profile=s3_profile
)
n_workers_per_chunk = world_size // len(dataset_chunks) n_workers_per_chunk = world_size // len(dataset_chunks)
rank_to_arrow_iterator_params = [] rank_to_arrow_iterator_params = []
for chunk_path in dataset_chunks: for chunk_path in dataset_chunks:
@ -61,6 +64,7 @@ def distribute_data_to_rank(
dataset_files=None, dataset_files=None,
entropy_model_name=entropy_model_name, entropy_model_name=entropy_model_name,
arrow_batch_size=arrow_batch_size, arrow_batch_size=arrow_batch_size,
s3_profile=s3_profile,
) )
) )
return rank_to_arrow_iterator_params[rank] return rank_to_arrow_iterator_params[rank]
@ -68,6 +72,7 @@ def distribute_data_to_rank(
class DataloaderArgs(BaseModel): class DataloaderArgs(BaseModel):
model_config = ConfigDict(extra="forbid") model_config = ConfigDict(extra="forbid")
s3_profile: str | None = None
root_dir: str | None = None root_dir: str | None = None
sources: dict[str, float] = {} sources: dict[str, float] = {}
batch_size: int = 2 batch_size: int = 2
@ -107,6 +112,7 @@ class DataloaderArgs(BaseModel):
arrow_batch_size=self.arrow_batch_size, arrow_batch_size=self.arrow_batch_size,
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
s3_profile=self.s3_profile,
) )
looping_iterator = LoopingIterator(arrow_iterator) looping_iterator = LoopingIterator(arrow_iterator)
preprocess_iterator = PreprocessIterator( preprocess_iterator = PreprocessIterator(

View file

@ -0,0 +1,117 @@
import os
import fsspec
import pyarrow as pa
# pyarrow needs the initialization from this import
import pyarrow.dataset # pyright: ignore
import typer
from pyarrow.lib import ArrowInvalid
from rich.progress import track
def is_valid_arrow_file(path: str):
try:
dataset = pa.dataset.dataset(path, format="arrow")
return True
except ArrowInvalid:
return False
app = typer.Typer()
S3_PREFIX = "s3://"
def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem:
if path.startswith("s3://"):
if s3_profile is None:
return fsspec.filesystem("s3")
else:
return fsspec.filesystem("s3", profile=s3_profile)
else:
return fsspec.filesystem("file")
@app.command()
def print_local_to_delete(
blob_dir: str, local_dirs: list[str], s3_profile: str = "blt"
):
for s in local_dirs:
assert s.endswith("/"), "Dirs must end with /"
assert blob_dir.endswith("/"), "Dirs must end with /"
blob_fs = fsspec.filesystem("s3", profile=s3_profile)
blob_files = blob_fs.find(blob_dir)
for f in track(blob_files):
size = blob_fs.info(f)["Size"]
if not f.lower().endswith(".complete"):
assert size != 0, f"Size was invalidly zero for {f}"
blob_relative_paths = {f[len(blob_dir) - len(S3_PREFIX) :] for f in blob_files}
local_fs = fsspec.filesystem("file")
files_to_delete = []
for local_dir in local_dirs:
local_files = local_fs.find(local_dir)
for f in local_files:
relative_path = f[len(local_dir) :]
if relative_path in blob_relative_paths and not os.path.islink(f):
files_to_delete.append(f)
print(len(files_to_delete))
with open("/tmp/files_to_delete.txt", "w") as f:
for file in files_to_delete:
f.write(f"{file}\n")
@app.command()
def compare_local_to_blob(
source_dirs: list[str], dst_dir: str, s3_profile: str = "blt"
):
for s in source_dirs:
assert s.endswith("/"), "Dirs must end with /"
assert dst_dir.endswith("/"), "Dirs must end with /"
assert len(source_dirs) != 0
assert dst_dir.startswith("s3://")
local_fs = fsspec.filesystem("file")
dst_fs = fsspec.filesystem("s3", profile=s3_profile)
source_to_files = {}
all_local_files = set()
for s in source_dirs:
skipped = []
if s not in source_to_files:
source_to_files[s] = []
for f in local_fs.find(s):
if os.path.islink(f):
continue
if f.endswith(".COMPLETE") or f.endswith(".complete"):
is_complete_file = True
assert os.path.getsize(f) == 0, ".COMPLETE files should be empty"
else:
is_complete_file = False
if not is_complete_file and os.path.getsize(f) == 0:
skipped.append(f)
continue
if f.endswith(".arrow"):
if not is_valid_arrow_file(f):
skipped.append(f)
continue
source_to_files[s].append(f)
all_local_files.add(f[len(s) :])
print(s, len(source_to_files[s]), "skipped", len(skipped), skipped[:10])
dst_files = dst_fs.find(dst_dir)
print(dst_dir, len(dst_files))
dst_file_set = {f[len(dst_dir) - len(S3_PREFIX) :] for f in dst_files}
diff = all_local_files.symmetric_difference(dst_file_set)
print("Local files", len(all_local_files))
print("DST Files", len(dst_file_set))
print("Symmetric difference", len(diff))
dst_only_files = dst_file_set - all_local_files
print("DST only", len(dst_only_files), list(dst_only_files)[:10])
if __name__ == "__main__":
app()

View file

@ -1,17 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
import os
import re import re
from logging import getLogger from logging import getLogger
from pathlib import Path
from typing import Any, Generator from typing import Any, Generator
import fsspec
import pyarrow as pa import pyarrow as pa
# pyarrow needs the initialization from this import # pyarrow needs the initialization from this import
import pyarrow.dataset # pyright: ignore import pyarrow.dataset # pyright: ignore
import s3fs
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from bytelatent import ByteLatentError from bytelatent import ByteLatentError
from bytelatent.data.data_types import BltExample from bytelatent.data.data_types import BltExample
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
logger = getLogger(__name__) logger = getLogger(__name__)
@ -27,6 +30,8 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
dataset_files: list[str] | None dataset_files: list[str] | None
entropy_model_name: str | None entropy_model_name: str | None
arrow_batch_size: int = 100 arrow_batch_size: int = 100
s3_profile: str | None
filesystem_type: str | None = None
def build(self) -> "ArrowFileIterator": def build(self) -> "ArrowFileIterator":
arrow_file = ArrowFileIterator( arrow_file = ArrowFileIterator(
@ -37,14 +42,17 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
entropy_model_name=self.entropy_model_name, entropy_model_name=self.entropy_model_name,
arrow_batch_size=self.arrow_batch_size, arrow_batch_size=self.arrow_batch_size,
dataset_files=self.dataset_files, dataset_files=self.dataset_files,
s3_profile=self.s3_profile,
filesystem_type=self.filesystem_type,
) )
if self.row_num != 0: if self.row_num != 0:
arrow_file._set_row_num(self.row_num) arrow_file._set_row_num(self.row_num)
return arrow_file return arrow_file
def shard_sort_key(file: str | Path): def shard_sort_key(file: str):
match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file)) assert isinstance(file, str)
match = re.search(r".+\.shard_([0-9]+)\.arrow", file)
shard_number = int(match.group(1)) shard_number = int(match.group(1))
return shard_number return shard_number
@ -60,6 +68,8 @@ class ArrowFileIterator(StatefulIterator):
entropy_model_name: str | None, entropy_model_name: str | None,
arrow_batch_size: int, arrow_batch_size: int,
dataset_files: list[str] | None = None, dataset_files: list[str] | None = None,
s3_profile: str | None = None,
filesystem_type: str | None = None,
): ):
assert 0 <= worker_id < num_workers, (worker_id, num_workers) assert 0 <= worker_id < num_workers, (worker_id, num_workers)
if file_path is None and dataset_files is None: if file_path is None and dataset_files is None:
@ -75,16 +85,41 @@ class ArrowFileIterator(StatefulIterator):
self.preprocess_dir = preprocess_dir self.preprocess_dir = preprocess_dir
self.entropy_model_name = entropy_model_name self.entropy_model_name = entropy_model_name
self.arrow_batch_size = arrow_batch_size self.arrow_batch_size = arrow_batch_size
self.s3_profile = s3_profile
self.filesystem_type = filesystem_type
self.fs = None
if self.filesystem_type is not None:
if self.filesystem_type == "file":
self.fs = fsspec.filesystem("file")
elif self.filesystem_type == "s3":
self.fs = fsspec.filesystem("s3", profile=s3_profile)
if dataset_files is None: if dataset_files is None:
# Prepare arrow shards # Prepare arrow shards
jsonl_file = Path(file_path) jsonl_file = file_path
parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name) parts = re.match(
r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file)
)
assert parts is not None assert parts is not None
dataset = parts.group(1) dataset = parts.group(1)
data_dir = Path(preprocess_dir) / dataset / entropy_model_name data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name)
shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow")) data_dir_with_glob = os.path.join(
data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow"
)
if self.fs is None:
self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile)
if isinstance(self.fs, s3fs.S3FileSystem):
self.filesystem_type = "s3"
else:
self.filesystem_type = "file"
shard_files = self.fs.glob(data_dir_with_glob)
for s in shard_files: for s in shard_files:
if not (data_dir / f"{s.name}.complete").exists(): complete_file = os.path.join(
data_dir, f"{os.path.basename(s)}.complete"
)
if not self.fs.exists(complete_file):
raise ValueError(f"Missing .complete for input file: {s}") raise ValueError(f"Missing .complete for input file: {s}")
shard_files = sorted(shard_files, key=shard_sort_key) shard_files = sorted(shard_files, key=shard_sort_key)
@ -92,10 +127,19 @@ class ArrowFileIterator(StatefulIterator):
raise ByteLatentError( raise ByteLatentError(
f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow" f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
) )
self.dataset_files = [str(f) for f in shard_files] self.dataset_files = [f for f in shard_files]
else: else:
self.preprocess_dir = None self.preprocess_dir = None
self.dataset_files = dataset_files self.dataset_files = dataset_files
if dataset_files[0].startswith("s3://"):
for f in dataset_files:
assert f.startswith("s3://")
if self.fs is None:
self.fs = get_fs(dataset_files[0], s3_profile=s3_profile)
if isinstance(self.fs, s3fs.S3FileSystem):
self.filesystem_type = "s3"
else:
self.filesystem_type = "file"
def get_state(self) -> ArrowFileIteratorState: def get_state(self) -> ArrowFileIteratorState:
return ArrowFileIteratorState( return ArrowFileIteratorState(
@ -107,13 +151,21 @@ class ArrowFileIterator(StatefulIterator):
entropy_model_name=self.entropy_model_name, entropy_model_name=self.entropy_model_name,
arrow_batch_size=self.arrow_batch_size, arrow_batch_size=self.arrow_batch_size,
dataset_files=self.dataset_files, dataset_files=self.dataset_files,
s3_profile=self.s3_profile,
filesystem_type=self.filesystem_type,
) )
def create_iter( def create_iter(
self, self,
) -> Generator[BltExample, Any, None]: ) -> Generator[BltExample, Any, None]:
if self.dataset is None: if self.dataset is None:
self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") if isinstance(self.fs, s3fs.core.S3FileSystem):
filesystem = self.fs
else:
filesystem = None
self.dataset = pa.dataset.dataset(
self.dataset_files, format="arrow", filesystem=filesystem
)
self.batch_iterator = self.dataset.to_batches( self.batch_iterator = self.dataset.to_batches(
batch_size=self.arrow_batch_size batch_size=self.arrow_batch_size
) )
@ -165,7 +217,13 @@ class ArrowFileIterator(StatefulIterator):
self.batch_iterator = None self.batch_iterator = None
self.batch_to_consume = None self.batch_to_consume = None
else: else:
self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") if isinstance(self.fs, s3fs.core.S3FileSystem):
filesystem = self.fs
else:
filesystem = None
self.dataset = pa.dataset.dataset(
self.dataset_files, format="arrow", filesystem=filesystem
)
self.batch_iterator = self.dataset.to_batches( self.batch_iterator = self.dataset.to_batches(
batch_size=self.arrow_batch_size batch_size=self.arrow_batch_size
) )
@ -198,9 +256,14 @@ TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
def find_and_sanitize_chunks( def find_and_sanitize_chunks(
dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN dataset_path: str,
world_size: int,
file_pattern: str = TRAIN_DATA_FILE_PATTERN,
s3_profile: str | None = None,
): ):
dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)] fs = get_fs(dataset_path, s3_profile=s3_profile)
path_with_glob = os.path.join(dataset_path, file_pattern)
dataset_chunks = fs.glob(path_with_glob)
n_chunks = len(dataset_chunks) n_chunks = len(dataset_chunks)
if n_chunks > world_size: if n_chunks > world_size:

View file

@ -91,7 +91,7 @@ def init_logger(
log_file: str | None = None, log_file: str | None = None,
*, *,
name: str | None = None, name: str | None = None,
level: str = "NOTSET", level: str = "INFO",
): ):
""" """
Setup logging. Setup logging.

View file

@ -20,3 +20,4 @@ altair
submitit submitit
typer typer
rich rich
fsspec[full]

View file

@ -5,6 +5,7 @@ import os
import subprocess import subprocess
import time import time
import fsspec
import requests import requests
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
@ -38,11 +39,21 @@ def download_dataset(repo_id, local_dir, allow_patterns):
print(f"Dataset downloaded to {local_dir}") print(f"Dataset downloaded to {local_dir}")
def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): def parquet_to_jsonl(
dataset, work_dir, src_dir, tgt_dir, ntasks=64, s3_profile: str | None = None
):
from datatrove.executor import LocalPipelineExecutor from datatrove.executor import LocalPipelineExecutor
from datatrove.pipeline.readers import ParquetReader from datatrove.pipeline.readers import ParquetReader
from datatrove.pipeline.writers import JsonlWriter from datatrove.pipeline.writers import JsonlWriter
if tgt_dir.startswith("s3//"):
if s3_profile is None:
out_spec = tgt_dir
else:
out_spec = (tgt_dir, fsspec.filesystem("s3", profile=s3_profile))
else:
out_spec = tgt_dir
pipeline_exec = LocalPipelineExecutor( pipeline_exec = LocalPipelineExecutor(
pipeline=[ pipeline=[
ParquetReader( ParquetReader(
@ -52,7 +63,7 @@ def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64):
glob_pattern="**/*.parquet", glob_pattern="**/*.parquet",
), ),
JsonlWriter( JsonlWriter(
tgt_dir, out_spec,
output_filename=dataset + ".chunk.${rank}.jsonl", output_filename=dataset + ".chunk.${rank}.jsonl",
compression=None, compression=None,
), ),
@ -77,7 +88,7 @@ def setup_terashuf(work_dir):
return terashuf_dir return terashuf_dir
def main(dataset, memory, data_dir, seed=42, nchunks=32): def main(dataset, memory, data_dir, seed=42, nchunks=32, s3_profile: str | None = None):
# Configuration # Configuration
repo_id = { repo_id = {
"fineweb_edu": "HuggingFaceFW/fineweb-edu", "fineweb_edu": "HuggingFaceFW/fineweb-edu",