Reduce per-file resources arrow uses ()

Summary:

Test Plan:
Pedro Rodriguez 2025-03-05 15:03:42 -08:00 committed by GitHub
parent 8f2cf8899d
commit 63913e4dba
5 changed files with 68 additions and 3 deletions


@@ -138,7 +138,8 @@ class DataloaderArgs(BaseModel):
     preprocess_dir: str | None = None
     dataset_files: list[str] | None = None
     entropy_model_name: str | None = "transformer_100m"
-    arrow_batch_size: int = 100
+    # Be very careful with increasing, increases memory usage by that factor per rank, per data source
+    arrow_batch_size: int = 20
     buffer_size: int = 64
     file_format: str = "arrow"
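
The new inline comment is worth taking literally: rows stay buffered per rank and per data source, so total memory scales with all three factors. A rough, purely illustrative estimate (the helper and the rank/source counts below are made up, not from the repo):

def approx_buffered_rows(arrow_batch_size: int, num_ranks: int, num_sources: int) -> int:
    # Each rank keeps up to arrow_batch_size rows in flight for each data source.
    return arrow_batch_size * num_ranks * num_sources

# e.g. 20 rows * 32 ranks * 8 sources = 5,120 buffered rows,
# versus 25,600 with the old default of 100.
print(approx_buffered_rows(20, 32, 8))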


@@ -226,7 +226,13 @@ class ArrowFileIterator(StatefulIterator):
                     if (self.row_num - 1) % self.num_workers == self.worker_id:
                         yield out
-        self.batch_iterator = self.dataset.to_batches(batch_size=self.arrow_batch_size)
+        self.batch_iterator = self.dataset.to_batches(
+            batch_size=self.arrow_batch_size,
+            # We have large files in GBs, no need to readahead
+            fragment_readahead=1,
+            # Don't readahead in case batches are huge (e.g., books)
+            batch_readahead=1,
+        )
         for batch in self.batch_iterator:
             batch_columns = batch.to_pydict()
             if self.file_format == "arrow":
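
fragment_readahead and batch_readahead are regular arguments of pyarrow.dataset.Dataset.to_batches; left at their defaults the scanner prefetches several files and batches ahead of the consumer, which is exactly the memory this change gives back. A minimal standalone sketch of the same pattern, assuming a placeholder dataset path:

import pyarrow.dataset as ds

dataset = ds.dataset("/path/to/shards", format="arrow")  # placeholder path
batches = dataset.to_batches(
    batch_size=20,         # rows can be whole documents, so keep batches small
    fragment_readahead=1,  # don't prefetch additional multi-GB files
    batch_readahead=1,     # don't queue extra decoded batches in memory
)
for batch in batches:
    rows = batch.to_pydict()  # same access pattern as the iterator above
    break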


@@ -10,6 +10,9 @@ from bytelatent.data.iterators.abstract_iterator import (
     PydanticIteratorState,
     StatefulIterator,
 )
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
+from bytelatent.data.iterators.limit_iterator import LimitIterator
+from bytelatent.data.iterators.looping_iterator import LoopingIterator
 from bytelatent.data.iterators.preprocess_iterator import (
     PreprocessIterator,
     PreprocessIteratorState,
@@ -40,6 +43,21 @@ class SequenceIteratorState(PydanticIteratorState):
     )
+
+
+def get_datafile(
+    iterator: PreprocessIterator | ArrowFileIterator | LoopingIterator | LimitIterator,
+):
+    if isinstance(iterator, ArrowFileIterator):
+        return f"file={iterator.file_path} n_shards={len(iterator.dataset_files) if iterator.dataset_files is not None else None}"
+    elif isinstance(iterator, PreprocessIterator):
+        return get_datafile(iterator.arrow_iterator)
+    elif isinstance(iterator, LoopingIterator):
+        return get_datafile(iterator.file_iterator)
+    elif isinstance(iterator, LimitIterator):
+        return get_datafile(iterator.base_iterator)
+    else:
+        raise NotImplementedError()
 class SequenceIterator(StatefulIterator):
     def __init__(
         self,
@@ -74,6 +92,10 @@ class SequenceIterator(StatefulIterator):
         tokens: list[int] = []
         mask: list[bool] = []
         first = True
+        logger.info(
+            "Starting first buffer for: %s",
+            get_datafile(self.preprocess_iterator),
+        )
         for example in example_iter:
             assert example.tokens is not None
             assert example.mask is not None
@@ -97,7 +119,10 @@
             while len(patch_lengths) >= n_buffer_patches:
                 if first:
                     first = False
-                    logger.info("First buffer complete")
+                    logger.info(
+                        "First buffer complete for: %s",
+                        get_datafile(self.preprocess_iterator),
+                    )
                 x_patches = np.array(patch_lengths[:n_buffer_patches]).reshape(
                     self.buffer_size, self.output_seq_len
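
get_datafile recursively unwraps whichever wrapper it is handed (Preprocess, Looping, or Limit) until it reaches the ArrowFileIterator and can report the underlying file, so the two new log calls tie buffer progress to a concrete shard. Illustrative only, with an invented path and shard count, the emitted lines look roughly like:

# Starting first buffer for: file=/data/source_a/shard_000.arrow n_shards=32
# First buffer complete for: file=/data/source_a/shard_000.arrow n_shards=32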


@@ -0,0 +1,30 @@
+import json
+
+import pyarrow
+import typer
+from rich.progress import track
+
+from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIteratorState
+from bytelatent.logger import init_logger
+
+
+def main(state_file: str):
+    init_logger()
+    pyarrow.set_io_thread_count(4)
+    pyarrow.set_cpu_count(4)
+    with open(state_file) as f:
+        train_state = json.load(f)
+        dl_state = MultiprocessIteratorState(**train_state["data_loader_state"])
+    packing_iterator_state = dl_state.base_iterator_state
+    print("building")
+    packing_iterator = packing_iterator_state.build()
+    print("iter")
+    batch_iter = packing_iterator.create_iter()
+    batch = None
+    print("looping")
+    for i in track(range(1_000)):
+        batch = next(batch_iter)
+
+
+if __name__ == "__main__":
+    typer.run(main)
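
The new file is a small standalone harness for profiling the dataloader outside of training: it rebuilds the packing iterator from a saved train state JSON and pulls 1,000 batches while rich's track draws a progress bar. Since typer exposes main's signature directly as the CLI, the invocation (hypothetical, as the script's path isn't shown in this view) is just python path/to/the_new_script.py /checkpoints/run/train_state.json, with the state file as the single positional argument.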


@@ -13,6 +13,7 @@ from timeit import default_timer as timer
 from typing import Any, TypeVar
 import numpy as np
+import pyarrow
 import torch
 import torch.distributed
 import torch.nn.functional
@@ -266,6 +267,8 @@ def compute_loss(p, y, mask, scale):
 def train(args: TrainArgs):
     with ExitStack() as context_stack:
+        pyarrow.set_io_thread_count(4)
+        pyarrow.set_cpu_count(4)
         tokenizer = args.data.tokenizer_args.build()
         validate_train_args(
             args,
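
pyarrow.set_cpu_count and pyarrow.set_io_thread_count cap PyArrow's two global thread pools (compute work and I/O, respectively), which are otherwise sized without regard to how many training ranks share a machine; pinning both to 4 per process keeps each rank's Arrow footprint small. A minimal, self-contained sketch of the same knobs:

import pyarrow

# Inspect the library defaults before capping them (values depend on the host).
print("cpu threads:", pyarrow.cpu_count())
print("io threads:", pyarrow.io_thread_count())

# Mirror the training setup: a small fixed pool per process.
pyarrow.set_cpu_count(4)
pyarrow.set_io_thread_count(4)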