diff --git a/bytelatent/data/iterators/packing_iterator.py b/bytelatent/data/iterators/packing_iterator.py
index dc34120..87b0e84 100644
--- a/bytelatent/data/iterators/packing_iterator.py
+++ b/bytelatent/data/iterators/packing_iterator.py
@@ -119,12 +119,18 @@ def truncate_batch(
         y = np.full((batch.y.shape[0], max_length), pad_id, dtype=batch.y.dtype)
         y[:, : batch.y.shape[1]] = batch.y
         batch.y = y
-    if batch.mask is not None and batch.mask.shape[1] < max_length:
-        mask = np.full(
-            (batch.mask.shape[0], max_length), False, dtype=batch.mask.dtype
-        )
-        mask[:, : batch.mask.shape[1]] = batch.mask
-        batch.mask = mask
+    if batch.mask is None:
+        mask = batch.x != pad_id
+        # Only set the mask if it's actually doing anything
+        if mask.sum() != batch.x.size:
+            batch.mask = mask
+    else:
+        if batch.mask.shape[1] < max_length:
+            mask = np.full(
+                (batch.mask.shape[0], max_length), False, dtype=batch.mask.dtype
+            )
+            mask[:, : batch.mask.shape[1]] = batch.mask
+            batch.mask = mask
 
     assert batch.x.shape[1] <= max_length
     assert batch.y.shape[1] <= max_length
@@ -173,31 +179,46 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
         batch_size = self.packing_args.batch_size
         pad_id = self.packing_args.pad_id
         seq_len = self.packing_args.seq_len
+        stop_iteration = False
+        tokens: list[list[int]] = []
+        masks: list[list[bool]] = []
         while True:
-            tokens: list[list[int]] = []
-            masks: list[list[bool]] = []
-
-            for _ in range(self.packing_args.batch_size):
-                sequence = next(sequence_iter)
-                _tokens = sequence.tokens
-                _mask = sequence.mask
-                assert (
-                    sequence.patch_lengths is None
-                ), "patch_lengths should not be used in byte packing"
-                tokens.append(_tokens)
-                masks.append(_mask)
+            try:
+                for _ in range(self.packing_args.batch_size):
+                    sequence = next(sequence_iter)
+                    _tokens = sequence.tokens
+                    _mask = sequence.mask
+                    assert (
+                        sequence.patch_lengths is None
+                    ), "patch_lengths should not be used in byte packing"
+                    tokens.append(_tokens)
+                    masks.append(_mask)
+            except StopIteration:
+                # At this point, there will be no new sequences, so we need to stop
+                # after yielding the already accumulated data (one batch).
+                # In this case, either:
+                # 1. We have a complete batch, so things go as normal
+                # 2. We have an incomplete batch, but since we create a right-sized batch
+                #    first and then fill the values in, this case is handled automatically.
+                stop_iteration = True
 
             x = np.full((batch_size, seq_len), fill_value=pad_id)
             y = np.full((batch_size, seq_len), fill_value=pad_id)
+            m = np.zeros((batch_size, seq_len), dtype=np.bool)
 
             for i, tok_seq in enumerate(tokens):
                 x[i, : len(tok_seq)] = tok_seq
                 y[i, : len(tok_seq) - 1] = tok_seq[1:]
-            batch = Batch(x=x, y=y)
+                m[i, : len(tok_seq)] = masks[i]
+            batch = Batch(x=x, y=y, mask=m)
             assert (
                 batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
             ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
+            tokens = []
+            masks = []
             yield batch
+            if stop_iteration:
+                break
 
     def _create_iter_from_patch_lengths(self):
         sequence_iter = self.sequence_iterator.create_iter()
@@ -207,29 +228,36 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
         pad_to_max_length = self.packing_args.pad_to_max_length
         enable_byte_ngrams = self.packing_args.enable_byte_ngrams
         max_length = self.packing_args.max_length
+        assert max_length is not None
+        tokens: list[list[int]] = []
+        masks: list[list[bool]] = []
+        patch_lengths: list[list[int]] = []
+        stop_iteration = False
         while True:
-            tokens: list[list[int]] = []
-            masks: list[list[bool]] = []
-            patch_lengths: list[list[int]] = []
+            try:
+                for _ in range(self.packing_args.batch_size):
+                    sequence = next(sequence_iter)
+                    _tokens = sequence.tokens
+                    _mask = sequence.mask
+                    _patch_lengths = sequence.patch_lengths
+                    assert (
+                        _patch_lengths is not None
+                    ), "patch lengths are required for packing based on patches."
+                    # Reminder: seq_len is in terms of patches
+                    assert len(sequence.patch_lengths) == self.packing_args.seq_len
+                    last_patch_length = 0
+                    if _patch_lengths[0] > 1:
+                        last_patch_length = _patch_lengths[-1]
+                        _patch_lengths[0] -= 1
+                        _patch_lengths = [1] + _patch_lengths[:-1]
+                    tokens.append(_tokens[: len(_tokens) - last_patch_length])
+                    masks.append(_mask[: len(_mask) - last_patch_length])
+                    patch_lengths.append(_patch_lengths)
+            except StopIteration:
+                stop_iteration = True
 
-            for _ in range(self.packing_args.batch_size):
-                sequence = next(sequence_iter)
-                _tokens = sequence.tokens
-                _mask = sequence.mask
-                _patch_lengths = sequence.patch_lengths
-                assert (
-                    _patch_lengths is not None
-                ), "patch lengths are required for packing based on patches."
-                # Reminder: seq_len is in terms of patches
-                assert len(sequence.patch_lengths) == self.packing_args.seq_len
-                last_patch_length = 0
-                if _patch_lengths[0] > 1:
-                    last_patch_length = _patch_lengths[-1]
-                    _patch_lengths[0] -= 1
-                    _patch_lengths = [1] + _patch_lengths[:-1]
-                tokens.append(_tokens[: len(_tokens) - last_patch_length])
-                masks.append(_mask[: len(_mask) - last_patch_length])
-                patch_lengths.append(_patch_lengths)
+            if len(tokens) == 0 and stop_iteration:
+                break
 
             x_patch_lengths = np.array(patch_lengths)
             # pad batch to same length
@@ -257,6 +285,10 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                 ngram_ids=ngram_ids,
                 mask=_merge_patch_seq_masks(batch_size, tok_seq_len, masks),
             )
+            tokens = []
+            masks = []
+            patch_lengths = []
+
             assert (
                 x_patch_lengths.sum() == x.size + batch_size
             ), f"{x_patch_lengths.sum()} != {x.size + batch_size}"
@@ -277,3 +309,5 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                 enable_byte_ngrams=enable_byte_ngrams,
             )
             yield batch
+            if stop_iteration:
+                break
diff --git a/bytelatent/data/iterators/test_packing_iterator.py b/bytelatent/data/iterators/test_packing_iterator.py
new file mode 100644
index 0000000..ef296e4
--- /dev/null
+++ b/bytelatent/data/iterators/test_packing_iterator.py
@@ -0,0 +1,285 @@
+import numpy as np
+
+from bytelatent.data.data_types import BltSequence
+from bytelatent.data.iterators.abstract_iterator import StatefulIterator
+from bytelatent.data.iterators.packing_iterator import (
+    PackingArgs,
+    PackingIterator,
+    PackingMode,
+)
+
+
+class DummySequenceIterator(StatefulIterator):
+    def __init__(
+        self,
+        *,
+        seq_len: int,
+        n_seqs: int,
+        patch_lengths: list[int] | None = None,
+        pad_id: int = 0,
+    ):
+        self.seq_len = seq_len
+        self.n_seqs = n_seqs
+        self.patch_lengths = patch_lengths
+        self.pad_id = pad_id
+
+    def get_state(self):
+        raise NotImplementedError()
+
+    def create_iter(self):
+        for i in range(self.n_seqs):
+            if self.patch_lengths is None:
+                tokens = np.arange(
+                    i * self.seq_len + 1, (i + 1) * self.seq_len + 1
+                ).tolist()
+                mask = [True] * self.seq_len  # type: ignore
+                assert len(tokens) == self.seq_len
+            else:
+                n = sum(self.patch_lengths)
+                tokens = np.arange(i * n + 1, (i + 1) * n + 1).tolist()
+                assert len(tokens) == n
+                mask = [True] * n
+            assert len(mask) == len(tokens)
+            yield BltSequence(
+                tokens=tokens,
+                mask=mask,
+                patch_lengths=self.patch_lengths,
+            )
+
+
+def create_bytes_iter(*, seq_len: int, n_seqs: int, batch_size: int, pad_id: int):
+    sequence_iterator = DummySequenceIterator(seq_len=seq_len, n_seqs=n_seqs)
+    packing_iterator = PackingIterator(
+        sequence_iterator,
+        packing_args=PackingArgs(
+            batch_size=batch_size,
+            seq_len=seq_len,
+            pad_id=pad_id,
+            packing_mode=PackingMode.BYTES,
+            max_length=None,
+            pad_to_max_length=False,
+            enable_byte_ngrams=False,
+        ),
+    )
+    return packing_iterator.create_iter()
+
+
+def create_patches_iter(
+    *,
+    seq_len: int,
+    n_seqs: int,
+    batch_size: int,
+    pad_id: int,
+    patch_lengths: list[int] | None,
+    max_length: int,
+):
+    sequence_iterator = DummySequenceIterator(
+        # seq_len = number of bytes, which for blt/patches is max_length, since seq_len
+        # is in terms of number of patches
+        seq_len=max_length,
+        n_seqs=n_seqs,
+        patch_lengths=patch_lengths,
+    )
+    packing_iterator = PackingIterator(
+        sequence_iterator,
+        packing_args=PackingArgs(
+            batch_size=batch_size,
+            seq_len=seq_len,
+            pad_id=pad_id,
+            packing_mode=PackingMode.PATCHING,
+            max_length=max_length,
+            pad_to_max_length=True,
+            enable_byte_ngrams=False,
+        ),
+    )
+    return packing_iterator.create_iter()
+
+
+def test_last_batch_correctness_bytes():
+    seq_len = 1024
+    n_seqs = 10
+    batch_size = 4
+    pad_id = 0
+    iterator = create_bytes_iter(
+        seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id
+    )
+    batches = []
+    n_nonpad = 0
+    n_nonmask = 0
+    for b in iterator:
+        assert b.x.shape[0] == batch_size
+        assert b.x.shape[1] == seq_len
+        n_nonpad += (b.x != pad_id).sum()
+        if b.mask is None:
+            n_nonmask += b.x.size
+        else:
+            n_nonmask += b.mask.sum()
+        batches.append(b)
+    assert len(batches) == 3
+    assert n_nonpad == n_nonmask == seq_len * n_seqs
+    # The second half of the last batch should be all pads
+    assert batches[-1].mask[2:].sum() == 0
+
+
+def test_edgecase_batch_correctness_bytes():
+    seq_len = 1024
+    n_seqs = 10
+    batch_size = 12
+    pad_id = 0
+    iterator = create_bytes_iter(
+        seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id
+    )
+    batches = []
+    n_nonpad = 0
+    n_nonmask = 0
+    for b in iterator:
+        assert b.x.shape[0] == batch_size
+        assert b.x.shape[1] == seq_len
+        n_nonpad += (b.x != pad_id).sum()
+        if b.mask is None:
+            n_nonmask += b.x.size
+        else:
+            n_nonmask += b.mask.sum()
+        batches.append(b)
+    assert len(batches) == 1
+    assert n_nonpad == n_nonmask == seq_len * n_seqs
+    # The second half of the last batch should be all pads
+    assert batches[0].mask[10:].sum() == 0
+
+
+def test_exact_batch_correctness_bytes():
+    seq_len = 1024
+    n_seqs = 12
+    batch_size = 4
+    pad_id = 0
+    iterator = create_bytes_iter(
+        seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id
+    )
+    batches = []
+    n_nonpad = 0
+    n_nonmask = 0
+    for b in iterator:
+        assert b.x.shape[0] == batch_size
+        assert b.x.shape[1] == seq_len
+        n_nonpad += (b.x != pad_id).sum()
+        if b.mask is None:
+            n_nonmask += b.x.size
+        else:
+            n_nonmask += b.mask.sum()
+        batches.append(b)
+    assert len(batches) == 4
+    assert n_nonpad == n_nonmask == seq_len * n_seqs
+
+
+def test_exact_batch_correctness_patches():
+    # First patch length is forced to be 1
+    patch_lengths = [1, 255, 256, 256, 256]
+    # Recall: This is in terms of bytes
+    max_length = 1024
+    # Recall: This is in terms of patches
+    seq_len = 5
+    n_seqs = 12
+    batch_size = 4
+    pad_id = 0
+    iterator = create_patches_iter(
+        seq_len=seq_len,
+        n_seqs=n_seqs,
+        batch_size=batch_size,
+        pad_id=pad_id,
+        patch_lengths=patch_lengths,
+        max_length=max_length,
+    )
+    batches = []
+    n_nonpad = 0
+    n_nonmask = 0
+    for batch in iterator:
+        assert batch.x.shape[0] == batch_size
+        assert batch.x.shape[1] == max_length
+        n_nonpad += (batch.x != pad_id).sum()
+        if batch.mask is None:
+            n_nonmask += batch.x.size
+        else:
+            n_nonmask += batch.mask.sum()
+        batches.append(batch)
+
+    assert len(batches) == 3
+
+    # max_length - 1 is due to chopping off the last byte for
+    # having a y target
+    assert n_nonpad == n_nonmask == (max_length - 1) * n_seqs
+
+
+def test_short_batch_correctness_patches():
+    # First patch length is forced to be 1
+    # Total=48
+    patch_lengths = [1, 11, 12, 12, 12]
+    # Recall: This is in terms of bytes
+    max_length = 1024
+    # Recall: This is in terms of patches
+    seq_len = 5
+    n_seqs = 12
+    batch_size = 4
+    pad_id = 0
+    iterator = create_patches_iter(
+        seq_len=seq_len,
+        n_seqs=n_seqs,
+        batch_size=batch_size,
+        pad_id=pad_id,
+        patch_lengths=patch_lengths,
+        max_length=max_length,
+    )
+    batches = []
+    n_nonpad = 0
+    n_nonmask = 0
+    for batch in iterator:
+        assert batch.x.shape[0] == batch_size
+        assert batch.x.shape[1] == max_length
+        n_nonpad += (batch.x != pad_id).sum()
+        if batch.mask is None:
+            n_nonmask += batch.x.size
+        else:
+            n_nonmask += batch.mask.sum()
+        batches.append(batch)
+
+    assert len(batches) == 3
+
+    # We'll still always have one byte chopped off the end
+    assert n_nonpad == n_nonmask == ((sum(patch_lengths) - 1) * n_seqs)
+
+
+def test_long_batch_correctness_patches():
+    # First patch length is forced to be 1
+    # Total=1792
+    patch_lengths = [1, 255, 256, 256, 1024]
+    # Recall: This is in terms of bytes
+    max_length = 1024
+    # Recall: This is in terms of patches
+    seq_len = 5
+    n_seqs = 12
+    batch_size = 4
+    pad_id = 0
+    iterator = create_patches_iter(
+        seq_len=seq_len,
+        n_seqs=n_seqs,
+        batch_size=batch_size,
+        pad_id=pad_id,
+        patch_lengths=patch_lengths,
+        max_length=max_length,
+    )
+    batches = []
+    n_nonpad = 0
+    n_nonmask = 0
+    for batch in iterator:
+        assert batch.x.shape[0] == batch_size
+        assert batch.x.shape[1] == max_length
+        n_nonpad += (batch.x != pad_id).sum()
+        if batch.mask is None:
+            n_nonmask += batch.x.size
+        else:
+            n_nonmask += batch.mask.sum()
+        batches.append(batch)
+
+    assert len(batches) == 3
+
+    # No chop here since the next byte is available
+    assert n_nonpad == n_nonmask == max_length * n_seqs