mirror of
https://github.com/facebookresearch/blt.git
synced 2025-01-18 16:37:46 +00:00
123 lines
4.7 KiB
Python
123 lines
4.7 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
from logging import getLogger
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
from pydantic import BaseModel, ConfigDict
|
|
|
|
from bytelatent.data.data_types import BltSequence
|
|
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
|
|
from bytelatent.data.iterators.preprocess_iterator import (
|
|
PreprocessIterator,
|
|
PreprocessIteratorState,
|
|
)
|
|
|
|
# Module-level logger. getLogger() is called without a name, so this is the
# root logger rather than a per-module one.
logger = getLogger()
|
|
|
|
|
|
class SequencePackingArgs(BaseModel):
    """Settings that control how ``SequenceIterator`` packs examples.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # Number of patches in every packed sequence that SequenceIterator yields.
    output_seq_len: int
    # Number of packed sequences assembled (and then shuffled) per buffer.
    buffer_size: int
|
|
|
|
|
|
class SequenceIteratorState(BaseModel, IteratorState):
    """Serializable snapshot from which a ``SequenceIterator`` can be rebuilt.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # Packing configuration handed back to the rebuilt iterator.
    sequence_packing_args: SequencePackingArgs
    # State of the upstream preprocess iterator; rebuilt via its own build().
    preprocess_iterator_state: PreprocessIteratorState
    # Value assigned to the numpy bit generator's ``state`` on rebuild.
    rng_state: dict[str, Any]

    def build(self) -> "SequenceIterator":
        """Rebuild the upstream iterator, then wrap it in a ``SequenceIterator``."""
        return SequenceIterator(
            self.preprocess_iterator_state.build(),
            rng_state=self.rng_state,
            sequence_packing_args=self.sequence_packing_args,
        )
|
|
|
|
|
|
class SequenceIterator(StatefulIterator):
    """Pack preprocessed examples into sequences with a fixed patch count.

    Examples pulled from ``preprocess_iterator`` are appended to running
    token / mask / patch-length buffers. Whenever at least
    ``buffer_size * output_seq_len`` patches are buffered, the first
    ``buffer_size * output_seq_len`` of them are cut into ``buffer_size``
    sequences of exactly ``output_seq_len`` patches each, and those sequences
    are yielded in an order drawn from ``self.rng.permutation``. Whatever is
    left over stays in the buffers for the next round.
    """

    def __init__(
        self,
        preprocess_iterator: PreprocessIterator,
        *,
        rng_state: dict[str, Any],
        sequence_packing_args: SequencePackingArgs,
    ):
        """Store the configuration and restore the random generator state.

        Args:
            preprocess_iterator: Upstream iterator whose ``create_iter()``
                supplies examples with ``tokens``, ``mask`` and
                ``patch_lengths``.
            rng_state: Assigned to ``rng.bit_generator.state`` of a fresh
                ``np.random.default_rng()``.
            sequence_packing_args: Supplies ``output_seq_len`` and
                ``buffer_size``.

        Raises:
            ValueError: If ``output_seq_len`` or ``buffer_size`` is not
                positive. Such values cannot produce a valid buffer:
                ``buffer_size == 0`` would make ``create_iter`` loop forever
                without yielding, and the other cases fail while reshaping
                or indexing.
        """
        if sequence_packing_args.output_seq_len <= 0:
            raise ValueError(
                "output_seq_len must be positive, "
                f"got {sequence_packing_args.output_seq_len}"
            )
        if sequence_packing_args.buffer_size <= 0:
            raise ValueError(
                f"buffer_size must be positive, got {sequence_packing_args.buffer_size}"
            )

        self.preprocess_iterator = preprocess_iterator
        self.sequence_packing_args = sequence_packing_args
        self.output_seq_len = sequence_packing_args.output_seq_len
        self.buffer_size = sequence_packing_args.buffer_size
        self.rng = np.random.default_rng()
        self.rng.bit_generator.state = rng_state

    def get_state(self):
        """Return a ``SequenceIteratorState`` describing this iterator.

        Only the packing args, the upstream iterator state and the RNG state
        are captured; data currently held in the packing buffers is not.
        """
        # TODO: need to also persist the current shuffle buffer
        return SequenceIteratorState(
            sequence_packing_args=self.sequence_packing_args,
            preprocess_iterator_state=self.preprocess_iterator.get_state(),
            rng_state=self.rng.bit_generator.state,
        )

    def create_iter(self):
        """Yield ``BltSequence`` objects of ``output_seq_len`` patches each.

        The generator ends when the upstream iterator is exhausted; patches
        still buffered at that point (fewer than a full buffer) are not
        yielded.
        """
        example_iter = self.preprocess_iterator.create_iter()
        n_buffer_patches = self.buffer_size * self.output_seq_len

        patch_lengths: list[int] = []
        tokens: list[int] = []
        mask: list[bool] = []
        first = True
        for example in example_iter:
            assert example.tokens is not None
            assert example.mask is not None
            assert example.patch_lengths is not None
            assert len(example.tokens) != 0
            assert len(example.mask) != 0
            assert len(example.tokens) == len(example.mask)
            assert len(example.tokens) == sum(example.patch_lengths)

            tokens.extend(example.tokens)
            mask.extend(example.mask)
            patch_lengths.extend(example.patch_lengths)

            # A single long example can fill more than one buffer, hence a
            # loop rather than a single check.
            while len(patch_lengths) >= n_buffer_patches:
                if first:
                    first = False
                    logger.info("First buffer complete")

                # One row per output sequence, one column per patch.
                x_patches = np.array(patch_lengths[:n_buffer_patches]).reshape(
                    self.buffer_size, self.output_seq_len
                )
                seq_tokens = []
                seq_mask = []
                start_id = 0
                # We fix the number of patches and therefore global steps per batch
                # so we have a variable number of tokens we need to account for
                for num_tokens in x_patches.sum(axis=-1).tolist():
                    end_id = start_id + num_tokens
                    seq_tokens.append(tokens[start_id:end_id])
                    seq_mask.append(mask[start_id:end_id])
                    start_id = end_id

                # Total tokens covered by this buffer, computed once as a
                # plain int and reused for the check and the trimming below.
                n_consumed_tokens = int(x_patches.sum())
                assert start_id == n_consumed_tokens

                # Remove what we just added from the buffer
                patch_lengths = patch_lengths[n_buffer_patches:]
                tokens = tokens[n_consumed_tokens:]
                mask = mask[n_consumed_tokens:]

                seq_patch_lengths: list[list[int]] = x_patches.tolist()
                assert len(seq_patch_lengths) == self.buffer_size
                # Emit the sequences of this buffer in a random order.
                for idx in self.rng.permutation(len(seq_patch_lengths)):
                    assert len(seq_patch_lengths[idx]) == self.output_seq_len
                    assert (
                        sum(seq_patch_lengths[idx])
                        == len(seq_tokens[idx])
                        == len(seq_mask[idx])
                    ), f"{sum(seq_patch_lengths[idx])}, {len(seq_tokens[idx])} {len(seq_mask[idx])}, idx={idx}"
                    assert seq_patch_lengths[idx][0] > 0, f"{seq_patch_lengths[idx]}"
                    yield BltSequence(
                        tokens=seq_tokens[idx],
                        mask=seq_mask[idx],
                        patch_lengths=seq_patch_lengths[idx],
                    )
|