mirror of
https://github.com/facebookresearch/blt.git
synced 2025-04-08 10:49:09 +00:00
Reduce per file resources arrow uses
Summary: Test Plan:
This commit is contained in:
parent
8f2cf8899d
commit
880493e742
5 changed files with 68 additions and 3 deletions
bytelatent
|
@ -138,7 +138,8 @@ class DataloaderArgs(BaseModel):
|
|||
preprocess_dir: str | None = None
|
||||
dataset_files: list[str] | None = None
|
||||
entropy_model_name: str | None = "transformer_100m"
|
||||
arrow_batch_size: int = 100
|
||||
# Be very careful with increasing, increases memory usage by that factor per rank, per data source
|
||||
arrow_batch_size: int = 20
|
||||
buffer_size: int = 64
|
||||
file_format: str = "arrow"
|
||||
|
||||
|
|
|
@ -226,7 +226,13 @@ class ArrowFileIterator(StatefulIterator):
|
|||
if (self.row_num - 1) % self.num_workers == self.worker_id:
|
||||
yield out
|
||||
|
||||
self.batch_iterator = self.dataset.to_batches(batch_size=self.arrow_batch_size)
|
||||
self.batch_iterator = self.dataset.to_batches(
|
||||
batch_size=self.arrow_batch_size,
|
||||
# We have large files in GBs, no need to readahead
|
||||
fragment_readahead=1,
|
||||
# Don't readahead in case batches are huge (e.g., books)
|
||||
batch_readahead=1,
|
||||
)
|
||||
for batch in self.batch_iterator:
|
||||
batch_columns = batch.to_pydict()
|
||||
if self.file_format == "arrow":
|
||||
|
|
|
@ -10,6 +10,9 @@ from bytelatent.data.iterators.abstract_iterator import (
|
|||
PydanticIteratorState,
|
||||
StatefulIterator,
|
||||
)
|
||||
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
|
||||
from bytelatent.data.iterators.limit_iterator import LimitIterator
|
||||
from bytelatent.data.iterators.looping_iterator import LoopingIterator
|
||||
from bytelatent.data.iterators.preprocess_iterator import (
|
||||
PreprocessIterator,
|
||||
PreprocessIteratorState,
|
||||
|
@ -40,6 +43,21 @@ class SequenceIteratorState(PydanticIteratorState):
|
|||
)
|
||||
|
||||
|
||||
def get_datafile(
|
||||
iterator: PreprocessIterator | ArrowFileIterator | LoopingIterator | LimitIterator,
|
||||
):
|
||||
if isinstance(iterator, ArrowFileIterator):
|
||||
return f"file={iterator.file_path} n_shards={len(iterator.dataset_files) if iterator.dataset_files is not None else None}"
|
||||
elif isinstance(iterator, PreprocessIterator):
|
||||
return get_datafile(iterator.arrow_iterator)
|
||||
elif isinstance(iterator, LoopingIterator):
|
||||
return get_datafile(iterator.file_iterator)
|
||||
elif isinstance(iterator, LimitIterator):
|
||||
return get_datafile(iterator.base_iterator)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class SequenceIterator(StatefulIterator):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -74,6 +92,10 @@ class SequenceIterator(StatefulIterator):
|
|||
tokens: list[int] = []
|
||||
mask: list[bool] = []
|
||||
first = True
|
||||
logger.info(
|
||||
"Starting first buffer for: %s",
|
||||
get_datafile(self.preprocess_iterator),
|
||||
)
|
||||
for example in example_iter:
|
||||
assert example.tokens is not None
|
||||
assert example.mask is not None
|
||||
|
@ -97,7 +119,10 @@ class SequenceIterator(StatefulIterator):
|
|||
while len(patch_lengths) >= n_buffer_patches:
|
||||
if first:
|
||||
first = False
|
||||
logger.info("First buffer complete")
|
||||
logger.info(
|
||||
"First buffer complete for: %s",
|
||||
get_datafile(self.preprocess_iterator),
|
||||
)
|
||||
|
||||
x_patches = np.array(patch_lengths[:n_buffer_patches]).reshape(
|
||||
self.buffer_size, self.output_seq_len
|
||||
|
|
30
bytelatent/iterate_data.py
Normal file
30
bytelatent/iterate_data.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
import json
|
||||
|
||||
import pyarrow
|
||||
import typer
|
||||
from rich.progress import track
|
||||
|
||||
from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIteratorState
|
||||
from bytelatent.logger import init_logger
|
||||
|
||||
|
||||
def main(state_file: str):
|
||||
init_logger()
|
||||
pyarrow.set_io_thread_count(4)
|
||||
pyarrow.set_cpu_count(4)
|
||||
with open(state_file) as f:
|
||||
train_state = json.load(f)
|
||||
dl_state = MultiprocessIteratorState(**train_state["data_loader_state"])
|
||||
packing_iterator_state = dl_state.base_iterator_state
|
||||
print("building")
|
||||
packing_iterator = packing_iterator_state.build()
|
||||
print("iter")
|
||||
batch_iter = packing_iterator.create_iter()
|
||||
batch = None
|
||||
print("looping")
|
||||
for i in track(range(1_000)):
|
||||
batch = next(batch_iter)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
typer.run(main)
|
|
@ -13,6 +13,7 @@ from timeit import default_timer as timer
|
|||
from typing import Any, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import pyarrow
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn.functional
|
||||
|
@ -266,6 +267,8 @@ def compute_loss(p, y, mask, scale):
|
|||
|
||||
def train(args: TrainArgs):
|
||||
with ExitStack() as context_stack:
|
||||
pyarrow.set_io_thread_count(4)
|
||||
pyarrow.set_cpu_count(4)
|
||||
tokenizer = args.data.tokenizer_args.build()
|
||||
validate_train_args(
|
||||
args,
|
||||
|
|
Loading…
Add table
Reference in a new issue