Mirror of https://github.com/facebookresearch/blt.git (synced 2025-04-10 19:59:09 +00:00)

Parent: 8f2cf8899d
Commit: 63913e4dba
5 changed files with 68 additions and 3 deletions
bytelatent

@@ -138,7 +138,8 @@ class DataloaderArgs(BaseModel):
     preprocess_dir: str | None = None
     dataset_files: list[str] | None = None
     entropy_model_name: str | None = "transformer_100m"
-    arrow_batch_size: int = 100
+    # Be very careful with increasing, increases memory usage by that factor per rank, per data source
+    arrow_batch_size: int = 20
     buffer_size: int = 64
     file_format: str = "arrow"
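The comment in the hunk above carries the key reasoning: the Arrow read buffer is held once per rank and per data source, so its memory cost multiplies. A back-of-envelope sketch of that scaling (all numbers are illustrative assumptions, not values from the source):

    # Illustrative only: row size and source count below are assumptions.
    arrow_batch_size = 20  # rows buffered per data source (the new default)
    avg_row_bytes = 2**20  # assume ~1 MiB per row, e.g. long documents
    n_data_sources = 8     # assume 8 datasets mixed per rank
    per_rank = arrow_batch_size * avg_row_bytes * n_data_sources
    print(f"{per_rank / 2**30:.2f} GiB per rank")  # ~0.16 GiB; 5x that at the old default of 100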
@@ -226,7 +226,13 @@ class ArrowFileIterator(StatefulIterator):
                 if (self.row_num - 1) % self.num_workers == self.worker_id:
                     yield out
 
-        self.batch_iterator = self.dataset.to_batches(batch_size=self.arrow_batch_size)
+        self.batch_iterator = self.dataset.to_batches(
+            batch_size=self.arrow_batch_size,
+            # We have large files in GBs, no need to readahead
+            fragment_readahead=1,
+            # Don't readahead in case batches are huge (e.g., books)
+            batch_readahead=1,
+        )
         for batch in self.batch_iterator:
             batch_columns = batch.to_pydict()
             if self.file_format == "arrow":
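For context on the two new keyword arguments: pyarrow.dataset scanners prefetch batches within a file and whole file fragments in parallel, and every prefetched unit sits in memory until consumed, so dropping both readaheads to 1 bounds the buffer at roughly one batch in flight. A minimal standalone sketch (the dataset path is hypothetical):

    import pyarrow.dataset as ds

    dataset = ds.dataset("shards/", format="arrow")  # hypothetical location
    batches = dataset.to_batches(
        batch_size=20,         # rows per RecordBatch
        fragment_readahead=1,  # files prefetched ahead (default is larger)
        batch_readahead=1,     # batches prefetched per file (default is larger)
    )
    for batch in batches:
        pass  # each item is a pyarrow.RecordBatch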
@@ -10,6 +10,9 @@ from bytelatent.data.iterators.abstract_iterator import (
     PydanticIteratorState,
     StatefulIterator,
 )
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
+from bytelatent.data.iterators.limit_iterator import LimitIterator
+from bytelatent.data.iterators.looping_iterator import LoopingIterator
 from bytelatent.data.iterators.preprocess_iterator import (
     PreprocessIterator,
     PreprocessIteratorState,
@@ -40,6 +43,21 @@ class SequenceIteratorState(PydanticIteratorState):
     )
 
 
+def get_datafile(
+    iterator: PreprocessIterator | ArrowFileIterator | LoopingIterator | LimitIterator,
+):
+    if isinstance(iterator, ArrowFileIterator):
+        return f"file={iterator.file_path} n_shards={len(iterator.dataset_files) if iterator.dataset_files is not None else None}"
+    elif isinstance(iterator, PreprocessIterator):
+        return get_datafile(iterator.arrow_iterator)
+    elif isinstance(iterator, LoopingIterator):
+        return get_datafile(iterator.file_iterator)
+    elif isinstance(iterator, LimitIterator):
+        return get_datafile(iterator.base_iterator)
+    else:
+        raise NotImplementedError()
+
+
 class SequenceIterator(StatefulIterator):
     def __init__(
         self,
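The helper exists so the logging calls added in the next two hunks can name the shard behind an arbitrarily nested iterator stack: each branch unwraps one wrapper and recurses. A hypothetical illustration (the constructions are invented; only the attribute names come from the diff):

    # Hypothetical nesting: Limit -> Looping -> Arrow
    looping = LoopingIterator(file_iterator=arrow_iter)  # invented construction
    limited = LimitIterator(base_iterator=looping)       # invented construction
    get_datafile(limited)
    # -> e.g. "file=/data/shard_00.arrow n_shards=32" (illustrative output)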
@@ -74,6 +92,10 @@ class SequenceIterator(StatefulIterator):
         tokens: list[int] = []
         mask: list[bool] = []
         first = True
+        logger.info(
+            "Starting first buffer for: %s",
+            get_datafile(self.preprocess_iterator),
+        )
         for example in example_iter:
             assert example.tokens is not None
             assert example.mask is not None
@@ -97,7 +119,10 @@ class SequenceIterator(StatefulIterator):
             while len(patch_lengths) >= n_buffer_patches:
                 if first:
                     first = False
-                    logger.info("First buffer complete")
+                    logger.info(
+                        "First buffer complete for: %s",
+                        get_datafile(self.preprocess_iterator),
+                    )
 
                 x_patches = np.array(patch_lengths[:n_buffer_patches]).reshape(
                     self.buffer_size, self.output_seq_len
bytelatent/iterate_data.py (new file, 30 lines)

@@ -0,0 +1,30 @@
+import json
+
+import pyarrow
+import typer
+from rich.progress import track
+
+from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIteratorState
+from bytelatent.logger import init_logger
+
+
+def main(state_file: str):
+    init_logger()
+    pyarrow.set_io_thread_count(4)
+    pyarrow.set_cpu_count(4)
+    with open(state_file) as f:
+        train_state = json.load(f)
+    dl_state = MultiprocessIteratorState(**train_state["data_loader_state"])
+    packing_iterator_state = dl_state.base_iterator_state
+    print("building")
+    packing_iterator = packing_iterator_state.build()
+    print("iter")
+    batch_iter = packing_iterator.create_iter()
+    batch = None
+    print("looping")
+    for i in track(range(1_000)):
+        batch = next(batch_iter)
+
+
+if __name__ == "__main__":
+    typer.run(main)
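This script replays a saved dataloader state outside of training, which makes it handy for reproducing or profiling data-pipeline stalls. Since typer derives the CLI from main's signature, the presumed invocation is:

    python -m bytelatent.iterate_data path/to/train_state.json  # presumed usage

where the argument is a JSON train-state checkpoint containing a "data_loader_state" key, as read by the code above.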
@@ -13,6 +13,7 @@ from timeit import default_timer as timer
 from typing import Any, TypeVar
 
 import numpy as np
+import pyarrow
 import torch
 import torch.distributed
 import torch.nn.functional
@@ -266,6 +267,8 @@ def compute_loss(p, y, mask, scale):
 
 def train(args: TrainArgs):
     with ExitStack() as context_stack:
+        pyarrow.set_io_thread_count(4)
+        pyarrow.set_cpu_count(4)
         tokenizer = args.data.tokenizer_args.build()
         validate_train_args(
             args,
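These two calls cap PyArrow's process-global thread pools so that each training rank does not spawn a machine's worth of Arrow threads; the cap of 4 mirrors iterate_data.py and is a tuning choice, not a requirement. A minimal sketch of the knobs involved:

    import pyarrow

    # Both pools are global to the process and affect all Arrow work in it.
    pyarrow.set_cpu_count(4)        # compute pool: decoding, filtering
    pyarrow.set_io_thread_count(4)  # I/O pool: file reads, readahead
    print(pyarrow.cpu_count(), pyarrow.io_thread_count())  # verify: 4 4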