blt/bytelatent/data/iterators/sequence_iterator.py
2024-12-12 15:32:30 -08:00

123 lines
4.7 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
from logging import getLogger
from typing import Any
import numpy as np
from pydantic import BaseModel, ConfigDict
from bytelatent.data.data_types import BltSequence
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.preprocess_iterator import (
PreprocessIterator,
PreprocessIteratorState,
)
logger = getLogger()
class SequencePackingArgs(BaseModel):
    """Configuration for packing preprocessed examples into fixed-size sequences."""

    model_config = ConfigDict(extra="forbid")
    # Number of patches per packed output sequence.
    output_seq_len: int
    # Number of sequences accumulated before shuffling and yielding
    # (the shuffle buffer holds buffer_size * output_seq_len patches).
    buffer_size: int
class SequenceIteratorState(BaseModel, IteratorState):
    """Serializable snapshot of a SequenceIterator, sufficient to resume it."""

    model_config = ConfigDict(extra="forbid")
    # Packing configuration the iterator was created with.
    sequence_packing_args: SequencePackingArgs
    # Snapshot of the upstream preprocessing iterator.
    preprocess_iterator_state: PreprocessIteratorState
    # numpy bit-generator state dict, restored verbatim on rebuild.
    rng_state: dict[str, Any]

    def build(self):
        """Reconstruct a SequenceIterator from this snapshot."""
        return SequenceIterator(
            self.preprocess_iterator_state.build(),
            sequence_packing_args=self.sequence_packing_args,
            rng_state=self.rng_state,
        )
class SequenceIterator(StatefulIterator):
    """Packs preprocessed examples into sequences with a fixed number of
    patches, shuffled within a bounded buffer.

    Each yielded BltSequence carries exactly ``output_seq_len`` patch
    lengths; because patches vary in length, the number of tokens per
    sequence varies. Sequences are emitted ``buffer_size`` at a time in a
    randomly permuted order driven by the restored RNG state.
    """

    def __init__(
        self,
        preprocess_iterator: PreprocessIterator,
        *,
        rng_state: dict[str, Any],
        sequence_packing_args: SequencePackingArgs,
    ):
        """
        Args:
            preprocess_iterator: Upstream source of examples with tokens,
                mask, and patch_lengths.
            rng_state: numpy bit-generator state dict; restored verbatim so
                shuffling is reproducible across checkpoint resumes.
            sequence_packing_args: Packing configuration (see
                SequencePackingArgs).
        """
        self.preprocess_iterator = preprocess_iterator
        self.sequence_packing_args = sequence_packing_args
        self.output_seq_len = sequence_packing_args.output_seq_len
        self.buffer_size = sequence_packing_args.buffer_size
        # Recreate the generator, then overwrite its state so the shuffle
        # order continues exactly where the checkpoint left off.
        self.rng = np.random.default_rng()
        self.rng.bit_generator.state = rng_state

    def get_state(self):
        """Snapshot this iterator as a SequenceIteratorState."""
        # TODO: need to also persist the current shuffle buffer; as written,
        # examples accumulated but not yet yielded are lost on resume.
        return SequenceIteratorState(
            sequence_packing_args=self.sequence_packing_args,
            preprocess_iterator_state=self.preprocess_iterator.get_state(),
            rng_state=self.rng.bit_generator.state,
        )

    def create_iter(self):
        """Yield BltSequence items of exactly ``output_seq_len`` patches each.

        Accumulates upstream examples until ``buffer_size * output_seq_len``
        patches are available, slices them into ``buffer_size`` sequences,
        and yields those in a random order. Any trailing partial buffer left
        when the upstream iterator is exhausted is dropped.
        """
        example_iter = self.preprocess_iterator.create_iter()
        # Patches needed before one full shuffle buffer can be emitted.
        n_buffer_patches = self.buffer_size * self.output_seq_len
        patch_lengths: list[int] = []
        tokens: list[int] = []
        mask: list[bool] = []
        first = True
        for example in example_iter:
            # Upstream must deliver fully-populated, internally consistent
            # examples: tokens/mask aligned, and patch_lengths partitioning
            # the tokens exactly.
            assert example.tokens is not None
            assert example.mask is not None
            assert example.patch_lengths is not None
            assert len(example.tokens) != 0
            assert len(example.mask) != 0
            assert len(example.tokens) == len(example.mask)
            assert len(example.tokens) == sum(example.patch_lengths)

            tokens.extend(example.tokens)
            mask.extend(example.mask)
            patch_lengths.extend(example.patch_lengths)

            while len(patch_lengths) >= n_buffer_patches:
                if first:
                    first = False
                    logger.info("First buffer complete")
                # One row per sequence, one column per patch.
                x_patches = np.array(patch_lengths[:n_buffer_patches]).reshape(
                    self.buffer_size, self.output_seq_len
                )
                # Hoisted: total tokens consumed by this buffer (previously
                # recomputed three times per buffer).
                buffer_token_total = int(x_patches.sum())

                seq_tokens = []
                seq_mask = []
                start_id = 0
                # We fix the number of patches and therefore global steps per
                # batch, so each sequence consumes a variable number of tokens
                # that we slice off contiguously.
                for num_tokens in x_patches.sum(axis=-1):
                    seq_tokens.append(tokens[start_id : start_id + num_tokens])
                    seq_mask.append(mask[start_id : start_id + num_tokens])
                    start_id += num_tokens
                assert start_id == buffer_token_total

                # Remove what we just added from the buffer
                patch_lengths = patch_lengths[n_buffer_patches:]
                tokens = tokens[buffer_token_total:]
                mask = mask[buffer_token_total:]

                seq_patch_lengths: list[list[int]] = x_patches.tolist()
                assert len(seq_patch_lengths) == self.buffer_size
                # Emit the buffered sequences in a shuffled order.
                for idx in self.rng.permutation(len(seq_patch_lengths)):
                    assert len(seq_patch_lengths[idx]) == self.output_seq_len
                    assert (
                        sum(seq_patch_lengths[idx])
                        == len(seq_tokens[idx])
                        == len(seq_mask[idx])
                    ), f"{sum(seq_patch_lengths[idx])}, {len(seq_tokens[idx])} {len(seq_mask[idx])}, idx={idx}"
                    assert seq_patch_lengths[idx][0] > 0, f"{seq_patch_lengths[idx]}"
                    yield BltSequence(
                        tokens=seq_tokens[idx],
                        mask=seq_mask[idx],
                        patch_lengths=seq_patch_lengths[idx],
                    )