From 203bff36960be4f47b499ff0a099846e8bf7d9d4 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Sat, 22 Feb 2025 01:23:16 +0000 Subject: [PATCH] Pass mask in packing_iterator, correctly handle last batch, fix masking This commit does/fixes the following: 1. Adds unit tests for byte and patch packing to ensure it works correctly 2. Fixes a bug where for batches that end up with 1: + last_patch_length = _patch_lengths[-1] + _patch_lengths[0] -= 1 + _patch_lengths = [1] + _patch_lengths[:-1] + tokens.append(_tokens[: len(_tokens) - last_patch_length]) + masks.append(_mask[: len(_mask) - last_patch_length]) + patch_lengths.append(_patch_lengths) + except StopIteration: + stop_iteration = True - for _ in range(self.packing_args.batch_size): - sequence = next(sequence_iter) - _tokens = sequence.tokens - _mask = sequence.mask - _patch_lengths = sequence.patch_lengths - assert ( - _patch_lengths is not None - ), "patch lengths are required for packing based on patches." - # Reminder: seq_len is in terms of patches - assert len(sequence.patch_lengths) == self.packing_args.seq_len - last_patch_length = 0 - if _patch_lengths[0] > 1: - last_patch_length = _patch_lengths[-1] - _patch_lengths[0] -= 1 - _patch_lengths = [1] + _patch_lengths[:-1] - tokens.append(_tokens[: len(_tokens) - last_patch_length]) - masks.append(_mask[: len(_mask) - last_patch_length]) - patch_lengths.append(_patch_lengths) + if len(tokens) == 0 and stop_iteration: + break x_patch_lengths = np.array(patch_lengths) # pad batch to same length @@ -257,6 +285,10 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]): ngram_ids=ngram_ids, mask=_merge_patch_seq_masks(batch_size, tok_seq_len, masks), ) + tokens = [] + masks = [] + patch_lengths = [] + assert ( x_patch_lengths.sum() == x.size + batch_size ), f"{x_patch_lengths.sum()} != {x.size + batch_size}" @@ -277,3 +309,5 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]): enable_byte_ngrams=enable_byte_ngrams, ) yield 
batch + if stop_iteration: + break diff --git a/bytelatent/data/iterators/test_packing_iterator.py b/bytelatent/data/iterators/test_packing_iterator.py new file mode 100644 index 0000000..ef296e4 --- /dev/null +++ b/bytelatent/data/iterators/test_packing_iterator.py @@ -0,0 +1,285 @@ +import numpy as np + +from bytelatent.data.data_types import BltSequence +from bytelatent.data.iterators.abstract_iterator import StatefulIterator +from bytelatent.data.iterators.packing_iterator import ( + PackingArgs, + PackingIterator, + PackingMode, +) + + +class DummySequenceIterator(StatefulIterator): + def __init__( + self, + *, + seq_len: int, + n_seqs: int, + patch_lengths: list[int] | None = None, + pad_id: int = 0, + ): + self.seq_len = seq_len + self.n_seqs = n_seqs + self.patch_lengths = patch_lengths + self.pad_id = pad_id + + def get_state(self): + raise NotImplementedError() + + def create_iter(self): + for i in range(self.n_seqs): + if self.patch_lengths is None: + tokens = np.arange( + i * self.seq_len + 1, (i + 1) * self.seq_len + 1 + ).tolist() + mask = [True] * self.seq_len # type: ignore + assert len(tokens) == self.seq_len + else: + n = sum(self.patch_lengths) + tokens = np.arange(i * n + 1, (i + 1) * n + 1).tolist() + assert len(tokens) == n + mask = [True] * n + assert len(mask) == len(tokens) + yield BltSequence( + tokens=tokens, + mask=mask, + patch_lengths=self.patch_lengths, + ) + + +def create_bytes_iter(*, seq_len: int, n_seqs: int, batch_size: int, pad_id: int): + sequence_iterator = DummySequenceIterator(seq_len=seq_len, n_seqs=n_seqs) + packing_iterator = PackingIterator( + sequence_iterator, + packing_args=PackingArgs( + batch_size=batch_size, + seq_len=seq_len, + pad_id=pad_id, + packing_mode=PackingMode.BYTES, + max_length=None, + pad_to_max_length=False, + enable_byte_ngrams=False, + ), + ) + return packing_iterator.create_iter() + + +def create_patches_iter( + *, + seq_len: int, + n_seqs: int, + batch_size: int, + pad_id: int, + patch_lengths: 
list[int] | None, + max_length: int, +): + sequence_iterator = DummySequenceIterator( + # seq_len=number of bytes, which for blt/patches, is max_length since seq_len is + # in terms of number of patches + seq_len=max_length, + n_seqs=n_seqs, + patch_lengths=patch_lengths, + ) + packing_iterator = PackingIterator( + sequence_iterator, + packing_args=PackingArgs( + batch_size=batch_size, + seq_len=seq_len, + pad_id=pad_id, + packing_mode=PackingMode.PATCHING, + max_length=max_length, + pad_to_max_length=True, + enable_byte_ngrams=False, + ), + ) + return packing_iterator.create_iter() + + +def test_last_batch_correctness_bytes(): + seq_len = 1024 + n_seqs = 10 + batch_size = 4 + pad_id = 0 + iterator = create_bytes_iter( + seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id + ) + batches = [] + n_nonpad = 0 + n_nonmask = 0 + for b in iterator: + assert b.x.shape[0] == batch_size + assert b.x.shape[1] == seq_len + n_nonpad += (b.x != pad_id).sum() + if b.mask is None: + n_nonmask += b.x.size + else: + n_nonmask += b.mask.sum() + batches.append(b) + assert len(batches) == 3 + assert n_nonpad == n_nonmask == seq_len * n_seqs + # The second half of the last batch should be all pads + assert batches[-1].mask[2:].sum() == 0 + + +def test_edgecase_batch_correctness_bytes(): + seq_len = 1024 + n_seqs = 10 + batch_size = 12 + pad_id = 0 + iterator = create_bytes_iter( + seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id + ) + batches = [] + n_nonpad = 0 + n_nonmask = 0 + for b in iterator: + assert b.x.shape[0] == batch_size + assert b.x.shape[1] == seq_len + n_nonpad += (b.x != pad_id).sum() + if b.mask is None: + n_nonmask += b.x.size + else: + n_nonmask += b.mask.sum() + batches.append(b) + assert len(batches) == 1 + assert n_nonpad == n_nonmask == seq_len * n_seqs + # The last two rows of the only batch should be all pads + assert batches[0].mask[10:].sum() == 0 + + +def test_exact_batch_correctness_bytes(): + seq_len = 1024 + n_seqs = 12 
+ batch_size = 4 + pad_id = 0 + iterator = create_bytes_iter( + seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id + ) + batches = [] + n_nonpad = 0 + n_nonmask = 0 + for b in iterator: + assert b.x.shape[0] == batch_size + assert b.x.shape[1] == seq_len + n_nonpad += (b.x != pad_id).sum() + if b.mask is None: + n_nonmask += b.x.size + else: + n_nonmask += b.mask.sum() + batches.append(b) + assert len(batches) == 4 + assert n_nonpad == n_nonmask == seq_len * n_seqs + + +def test_exact_batch_correctness_patches(): + # First patch length is forced to be 1 + patch_lengths = [1, 255, 256, 256, 256] + # Recall: This is in terms of bytes + max_length = 1024 + # Recall: This is in terms of patches + seq_len = 5 + n_seqs = 12 + batch_size = 4 + pad_id = 0 + iterator = create_patches_iter( + seq_len=seq_len, + n_seqs=n_seqs, + batch_size=batch_size, + pad_id=pad_id, + patch_lengths=patch_lengths, + max_length=max_length, + ) + batches = [] + n_nonpad = 0 + n_nonmask = 0 + for batch in iterator: + assert batch.x.shape[0] == batch_size + assert batch.x.shape[1] == max_length + n_nonpad += (batch.x != pad_id).sum() + if batch.mask is None: + n_nonmask += batch.x.size + else: + n_nonmask += batch.mask.sum() + batches.append(batch) + + assert len(batches) == 3 + + # max_length - 1 is due to chopping off the last byte for + # having a y target + assert n_nonpad == n_nonmask == (max_length - 1) * n_seqs + + +def test_short_batch_correctness_patches(): + # First patch length is forced to be 1 + # Total=48 + patch_lengths = [1, 11, 12, 12, 12] + # Recall: This is in terms of bytes + max_length = 1024 + # Recall: This is in terms of patches + seq_len = 5 + n_seqs = 12 + batch_size = 4 + pad_id = 0 + iterator = create_patches_iter( + seq_len=seq_len, + n_seqs=n_seqs, + batch_size=batch_size, + pad_id=pad_id, + patch_lengths=patch_lengths, + max_length=max_length, + ) + batches = [] + n_nonpad = 0 + n_nonmask = 0 + for batch in iterator: + assert batch.x.shape[0] == 
batch_size + assert batch.x.shape[1] == max_length + n_nonpad += (batch.x != pad_id).sum() + if batch.mask is None: + n_nonmask += batch.x.size + else: + n_nonmask += batch.mask.sum() + batches.append(batch) + + assert len(batches) == 3 + + # We'll still always have one byte chopped off the end + assert n_nonpad == n_nonmask == ((sum(patch_lengths) - 1) * n_seqs) + + +def test_long_batch_correctness_patches(): + # First patch length is forced to be 1 + # Total=1792 (longer than max_length) + patch_lengths = [1, 255, 256, 256, 1024] + # Recall: This is in terms of bytes + max_length = 1024 + # Recall: This is in terms of patches + seq_len = 5 + n_seqs = 12 + batch_size = 4 + pad_id = 0 + iterator = create_patches_iter( + seq_len=seq_len, + n_seqs=n_seqs, + batch_size=batch_size, + pad_id=pad_id, + patch_lengths=patch_lengths, + max_length=max_length, + ) + batches = [] + n_nonpad = 0 + n_nonmask = 0 + for batch in iterator: + assert batch.x.shape[0] == batch_size + assert batch.x.shape[1] == max_length + n_nonpad += (batch.x != pad_id).sum() + if batch.mask is None: + n_nonmask += batch.x.size + else: + n_nonmask += batch.mask.sum() + batches.append(batch) + + assert len(batches) == 3 + + # No chop here since the next byte is available + assert n_nonpad == n_nonmask == max_length * n_seqs