From 63913e4dbab3fb57097434570caae8728befbf46 Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez <par@meta.com>
Date: Wed, 5 Mar 2025 15:03:42 -0800
Subject: [PATCH] Reduce per file resources arrow uses (#77)

Summary:

Test Plan:
---
 bytelatent/args.py                            |  3 +-
 bytelatent/data/iterators/arrow_iterator.py   |  8 ++++-
 .../data/iterators/sequence_iterator.py       | 27 ++++++++++++++++-
 bytelatent/iterate_data.py                    | 30 +++++++++++++++++++
 bytelatent/train.py                           |  3 ++
 5 files changed, 68 insertions(+), 3 deletions(-)
 create mode 100644 bytelatent/iterate_data.py

diff --git a/bytelatent/args.py b/bytelatent/args.py
index 4a67231..bad4d17 100644
--- a/bytelatent/args.py
+++ b/bytelatent/args.py
@@ -138,7 +138,8 @@ class DataloaderArgs(BaseModel):
     preprocess_dir: str | None = None
     dataset_files: list[str] | None = None
     entropy_model_name: str | None = "transformer_100m"
-    arrow_batch_size: int = 100
+    # Be very careful with increasing, increases memory usage by that factor per rank, per data source
+    arrow_batch_size: int = 20
     buffer_size: int = 64
 
     file_format: str = "arrow"
diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py
index ee57472..bfb1c17 100644
--- a/bytelatent/data/iterators/arrow_iterator.py
+++ b/bytelatent/data/iterators/arrow_iterator.py
@@ -226,7 +226,13 @@ class ArrowFileIterator(StatefulIterator):
                 if (self.row_num - 1) % self.num_workers == self.worker_id:
                     yield out
 
-        self.batch_iterator = self.dataset.to_batches(batch_size=self.arrow_batch_size)
+        self.batch_iterator = self.dataset.to_batches(
+            batch_size=self.arrow_batch_size,
+            # We have large files in GBs, no need to readahead
+            fragment_readahead=1,
+            # Don't readahead in case batches are huge (e.g., books)
+            batch_readahead=1,
+        )
         for batch in self.batch_iterator:
             batch_columns = batch.to_pydict()
             if self.file_format == "arrow":
diff --git a/bytelatent/data/iterators/sequence_iterator.py b/bytelatent/data/iterators/sequence_iterator.py
index 0a492be..3b5214d 100644
--- a/bytelatent/data/iterators/sequence_iterator.py
+++ b/bytelatent/data/iterators/sequence_iterator.py
@@ -10,6 +10,9 @@ from bytelatent.data.iterators.abstract_iterator import (
     PydanticIteratorState,
     StatefulIterator,
 )
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
+from bytelatent.data.iterators.limit_iterator import LimitIterator
+from bytelatent.data.iterators.looping_iterator import LoopingIterator
 from bytelatent.data.iterators.preprocess_iterator import (
     PreprocessIterator,
     PreprocessIteratorState,
@@ -40,6 +43,21 @@ class SequenceIteratorState(PydanticIteratorState):
         )
 
 
+def get_datafile(
+    iterator: PreprocessIterator | ArrowFileIterator | LoopingIterator | LimitIterator,
+):
+    if isinstance(iterator, ArrowFileIterator):
+        return f"file={iterator.file_path} n_shards={len(iterator.dataset_files) if iterator.dataset_files is not None else None}"
+    elif isinstance(iterator, PreprocessIterator):
+        return get_datafile(iterator.arrow_iterator)
+    elif isinstance(iterator, LoopingIterator):
+        return get_datafile(iterator.file_iterator)
+    elif isinstance(iterator, LimitIterator):
+        return get_datafile(iterator.base_iterator)
+    else:
+        raise NotImplementedError()
+
+
 class SequenceIterator(StatefulIterator):
     def __init__(
         self,
@@ -74,6 +92,10 @@ class SequenceIterator(StatefulIterator):
         tokens: list[int] = []
         mask: list[bool] = []
         first = True
+        logger.info(
+            "Starting first buffer for: %s",
+            get_datafile(self.preprocess_iterator),
+        )
         for example in example_iter:
             assert example.tokens is not None
             assert example.mask is not None
@@ -97,7 +119,10 @@ class SequenceIterator(StatefulIterator):
             while len(patch_lengths) >= n_buffer_patches:
                 if first:
                     first = False
-                    logger.info("First buffer complete")
+                    logger.info(
+                        "First buffer complete for: %s",
+                        get_datafile(self.preprocess_iterator),
+                    )
 
                 x_patches = np.array(patch_lengths[:n_buffer_patches]).reshape(
                     self.buffer_size, self.output_seq_len
diff --git a/bytelatent/iterate_data.py b/bytelatent/iterate_data.py
new file mode 100644
index 0000000..bdb8f22
--- /dev/null
+++ b/bytelatent/iterate_data.py
@@ -0,0 +1,30 @@
+import json
+
+import pyarrow
+import typer
+from rich.progress import track
+
+from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIteratorState
+from bytelatent.logger import init_logger
+
+
+def main(state_file: str):
+    init_logger()
+    pyarrow.set_io_thread_count(4)
+    pyarrow.set_cpu_count(4)
+    with open(state_file) as f:
+        train_state = json.load(f)
+    dl_state = MultiprocessIteratorState(**train_state["data_loader_state"])
+    packing_iterator_state = dl_state.base_iterator_state
+    print("building")
+    packing_iterator = packing_iterator_state.build()
+    print("iter")
+    batch_iter = packing_iterator.create_iter()
+    batch = None
+    print("looping")
+    for i in track(range(1_000)):
+        batch = next(batch_iter)
+
+
+if __name__ == "__main__":
+    typer.run(main)
diff --git a/bytelatent/train.py b/bytelatent/train.py
index ad2fd9b..5a8f937 100644
--- a/bytelatent/train.py
+++ b/bytelatent/train.py
@@ -13,6 +13,7 @@ from timeit import default_timer as timer
 from typing import Any, TypeVar
 
 import numpy as np
+import pyarrow
 import torch
 import torch.distributed
 import torch.nn.functional
@@ -266,6 +267,8 @@ def compute_loss(p, y, mask, scale):
 
 def train(args: TrainArgs):
     with ExitStack() as context_stack:
+        pyarrow.set_io_thread_count(4)
+        pyarrow.set_cpu_count(4)
         tokenizer = args.data.tokenizer_args.build()
         validate_train_args(
             args,
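
Note on the new debugging script: bytelatent/iterate_data.py is a typer CLI whose single argument is the path to a saved train-state JSON containing a "data_loader_state" entry. A minimal invocation might look like the following (the state file path here is only illustrative, not one produced by this patch):

    python -m bytelatent.iterate_data /path/to/train_state.json

The script rebuilds the packing iterator from that state and steps through 1,000 batches with a progress bar, which can help when inspecting data-loader memory behavior outside of a full training run.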