Reduce per file resources arrow uses

Summary:

Test Plan:
This commit is contained in:
Pedro Rodriguez 2025-03-05 23:03:14 +00:00
parent 8f2cf8899d
commit 880493e742
5 changed files with 68 additions and 3 deletions

View file

@ -138,7 +138,8 @@ class DataloaderArgs(BaseModel):
preprocess_dir: str | None = None
dataset_files: list[str] | None = None
entropy_model_name: str | None = "transformer_100m"
arrow_batch_size: int = 100
# Be very careful with increasing, increases memory usage by that factor per rank, per data source
arrow_batch_size: int = 20
buffer_size: int = 64
file_format: str = "arrow"

View file

@ -226,7 +226,13 @@ class ArrowFileIterator(StatefulIterator):
if (self.row_num - 1) % self.num_workers == self.worker_id:
yield out
self.batch_iterator = self.dataset.to_batches(batch_size=self.arrow_batch_size)
self.batch_iterator = self.dataset.to_batches(
batch_size=self.arrow_batch_size,
# We have large files in GBs, no need to readahead
fragment_readahead=1,
# Don't readahead in case batches are huge (e.g., books)
batch_readahead=1,
)
for batch in self.batch_iterator:
batch_columns = batch.to_pydict()
if self.file_format == "arrow":

View file

@ -10,6 +10,9 @@ from bytelatent.data.iterators.abstract_iterator import (
PydanticIteratorState,
StatefulIterator,
)
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
from bytelatent.data.iterators.limit_iterator import LimitIterator
from bytelatent.data.iterators.looping_iterator import LoopingIterator
from bytelatent.data.iterators.preprocess_iterator import (
PreprocessIterator,
PreprocessIteratorState,
@ -40,6 +43,21 @@ class SequenceIteratorState(PydanticIteratorState):
)
def get_datafile(
iterator: PreprocessIterator | ArrowFileIterator | LoopingIterator | LimitIterator,
):
if isinstance(iterator, ArrowFileIterator):
return f"file={iterator.file_path} n_shards={len(iterator.dataset_files) if iterator.dataset_files is not None else None}"
elif isinstance(iterator, PreprocessIterator):
return get_datafile(iterator.arrow_iterator)
elif isinstance(iterator, LoopingIterator):
return get_datafile(iterator.file_iterator)
elif isinstance(iterator, LimitIterator):
return get_datafile(iterator.base_iterator)
else:
raise NotImplementedError()
class SequenceIterator(StatefulIterator):
def __init__(
self,
@ -74,6 +92,10 @@ class SequenceIterator(StatefulIterator):
tokens: list[int] = []
mask: list[bool] = []
first = True
logger.info(
"Starting first buffer for: %s",
get_datafile(self.preprocess_iterator),
)
for example in example_iter:
assert example.tokens is not None
assert example.mask is not None
@ -97,7 +119,10 @@ class SequenceIterator(StatefulIterator):
while len(patch_lengths) >= n_buffer_patches:
if first:
first = False
logger.info("First buffer complete")
logger.info(
"First buffer complete for: %s",
get_datafile(self.preprocess_iterator),
)
x_patches = np.array(patch_lengths[:n_buffer_patches]).reshape(
self.buffer_size, self.output_seq_len

View file

@ -0,0 +1,30 @@
import json
import pyarrow
import typer
from rich.progress import track
from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIteratorState
from bytelatent.logger import init_logger
def main(state_file: str):
init_logger()
pyarrow.set_io_thread_count(4)
pyarrow.set_cpu_count(4)
with open(state_file) as f:
train_state = json.load(f)
dl_state = MultiprocessIteratorState(**train_state["data_loader_state"])
packing_iterator_state = dl_state.base_iterator_state
print("building")
packing_iterator = packing_iterator_state.build()
print("iter")
batch_iter = packing_iterator.create_iter()
batch = None
print("looping")
for i in track(range(1_000)):
batch = next(batch_iter)
if __name__ == "__main__":
typer.run(main)

View file

@ -13,6 +13,7 @@ from timeit import default_timer as timer
from typing import Any, TypeVar
import numpy as np
import pyarrow
import torch
import torch.distributed
import torch.nn.functional
@ -266,6 +267,8 @@ def compute_loss(p, y, mask, scale):
def train(args: TrainArgs):
with ExitStack() as context_stack:
pyarrow.set_io_thread_count(4)
pyarrow.set_cpu_count(4)
tokenizer = args.data.tokenizer_args.build()
validate_train_args(
args,