From 08b8c7cd0584d5a0226555a580d8a5411d1a45b8 Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez <par@meta.com>
Date: Thu, 27 Feb 2025 11:41:47 -0800
Subject: [PATCH] Pass mask in packing_iterator, correctly handle last batch,
 fix masking (#65)

This commit does/fixes the following:

1. Adds unit tests for byte and patch packing to ensure both work correctly.
2. Fixes a bug where, for batches that end up with fewer than max_length bytes
   (e.g., short patches), the mask included elements whose value was pad_id.
   The mask is now set to x != pad_id when it is not specified.
3. Correctly handles the last batch, where previously the iterator would
   crash. This didn't affect training since we had enough data and/or looped
   iterators, but it comes up for evaluation perplexity if we validate on an
   entire file.
4. Correctly forwards the mask, if it exists, for byte packing.

Test Plan:

```
pytest bytelatent/
```

Testing these changes more thoroughly in a stacked PR that fixes evals.
---
 bytelatent/data/iterators/packing_iterator.py |  87 +++--
 .../data/iterators/test_packing_iterator.py   | 312 ++++++++++++++++++
 2 files changed, 367 insertions(+), 32 deletions(-)
 create mode 100644 bytelatent/data/iterators/test_packing_iterator.py

diff --git a/bytelatent/data/iterators/packing_iterator.py b/bytelatent/data/iterators/packing_iterator.py
index dc34120..f407f9f 100644
--- a/bytelatent/data/iterators/packing_iterator.py
+++ b/bytelatent/data/iterators/packing_iterator.py
@@ -41,12 +41,12 @@ class PackingIteratorState(PydanticIteratorState):
         )
 
 
-def _merge_patch_seq_masks(bs, slen: int, mask_seqs: list[list[bool]]):
+def _merge_patch_seq_masks(bs: int, slen: int, mask_seqs: list[list[bool]]):
     assert len(mask_seqs) == bs
     lens = [len(m) for m in mask_seqs]
     if all(all(m) for m in mask_seqs) and all(lens[0] == l for l in lens):
-        return None
-    assert slen == max(lens) - 1
+        return np.ones((bs, slen), dtype=bool)
+    assert slen == max(lens) - 1, f"slen={slen} != max(lens)-1={max(lens) - 1}"
     mask = np.zeros((bs, slen), dtype=bool)
     for i, m in enumerate(mask_seqs):
         if m is None:
@@ -176,28 +176,41 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
         while True:
             tokens: list[list[int]] = []
             masks: list[list[bool]] = []
-
-            for _ in range(self.packing_args.batch_size):
-                sequence = next(sequence_iter)
-                _tokens = sequence.tokens
-                _mask = sequence.mask
-                assert (
-                    sequence.patch_lengths is None
-                ), "patch_lengths should not be used in byte packing"
-                tokens.append(_tokens)
-                masks.append(_mask)
+            stop_iteration = False
+            try:
+                for _ in range(self.packing_args.batch_size):
+                    sequence = next(sequence_iter)
+                    _tokens = sequence.tokens
+                    _mask = sequence.mask
+                    assert (
+                        sequence.patch_lengths is None
+                    ), "patch_lengths should not be used in byte packing"
+                    tokens.append(_tokens)
+                    masks.append(_mask)
+            except StopIteration:
+                # At this point there will be no new sequences, so we need to stop
+                # after yielding the already accumulated data (one batch).
+                # In this case, either:
+                # 1. We have a complete batch, so things go as normal
+                # 2. We have an incomplete batch; because we create a right-sized
+                #    batch and then fill the values in, this case is handled automatically.
+                stop_iteration = True
             x = np.full((batch_size, seq_len), fill_value=pad_id)
             y = np.full((batch_size, seq_len), fill_value=pad_id)
+            m = np.zeros((batch_size, seq_len), dtype=bool)
             for i, tok_seq in enumerate(tokens):
                 x[i, : len(tok_seq)] = tok_seq
                 y[i, : len(tok_seq) - 1] = tok_seq[1:]
-            batch = Batch(x=x, y=y)
+                m[i, : len(tok_seq)] = masks[i]
+            batch = Batch(x=x, y=y, mask=m)
             assert (
                 batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
             ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
             yield batch
+            if stop_iteration:
+                break
 
     def _create_iter_from_patch_lengths(self):
         sequence_iter = self.sequence_iterator.create_iter()
@@ -207,29 +220,36 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
         pad_to_max_length = self.packing_args.pad_to_max_length
         enable_byte_ngrams = self.packing_args.enable_byte_ngrams
         max_length = self.packing_args.max_length
+        assert max_length is not None
         while True:
             tokens: list[list[int]] = []
             masks: list[list[bool]] = []
             patch_lengths: list[list[int]] = []
+            stop_iteration = False
+            try:
+                for _ in range(self.packing_args.batch_size):
+                    sequence = next(sequence_iter)
+                    _tokens = sequence.tokens
+                    _mask = sequence.mask
+                    _patch_lengths = sequence.patch_lengths
+                    assert (
+                        _patch_lengths is not None
+                    ), "patch lengths are required for packing based on patches."
+                    # Reminder: seq_len is in terms of patches
+                    assert len(sequence.patch_lengths) == self.packing_args.seq_len
+                    last_patch_length = 0
+                    if _patch_lengths[0] > 1:
+                        last_patch_length = _patch_lengths[-1]
+                        _patch_lengths[0] -= 1
+                        _patch_lengths = [1] + _patch_lengths[:-1]
+                    tokens.append(_tokens[: len(_tokens) - last_patch_length])
+                    masks.append(_mask[: len(_mask) - last_patch_length])
+                    patch_lengths.append(_patch_lengths)
+            except StopIteration:
+                stop_iteration = True
 
-            for _ in range(self.packing_args.batch_size):
-                sequence = next(sequence_iter)
-                _tokens = sequence.tokens
-                _mask = sequence.mask
-                _patch_lengths = sequence.patch_lengths
-                assert (
-                    _patch_lengths is not None
-                ), "patch lengths are required for packing based on patches."
-                # Reminder: seq_len is in terms of patches
-                assert len(sequence.patch_lengths) == self.packing_args.seq_len
-                last_patch_length = 0
-                if _patch_lengths[0] > 1:
-                    last_patch_length = _patch_lengths[-1]
-                    _patch_lengths[0] -= 1
-                    _patch_lengths = [1] + _patch_lengths[:-1]
-                tokens.append(_tokens[: len(_tokens) - last_patch_length])
-                masks.append(_mask[: len(_mask) - last_patch_length])
-                patch_lengths.append(_patch_lengths)
+            if len(tokens) == 0 and stop_iteration:
+                break
 
             x_patch_lengths = np.array(patch_lengths)
             # pad batch to same length
@@ -257,6 +277,7 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                 ngram_ids=ngram_ids,
                 mask=_merge_patch_seq_masks(batch_size, tok_seq_len, masks),
             )
+
             assert (
                 x_patch_lengths.sum() == x.size + batch_size
             ), f"{x_patch_lengths.sum()} != {x.size + batch_size}"
@@ -277,3 +298,5 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                 enable_byte_ngrams=enable_byte_ngrams,
             )
             yield batch
+            if stop_iteration:
+                break
diff --git a/bytelatent/data/iterators/test_packing_iterator.py b/bytelatent/data/iterators/test_packing_iterator.py
new file mode 100644
index 0000000..83b03f4
--- /dev/null
+++ b/bytelatent/data/iterators/test_packing_iterator.py
@@ -0,0 +1,312 @@
+import numpy as np
+import pytest
+
+from bytelatent.data.data_types import BltSequence
+from bytelatent.data.iterators.abstract_iterator import StatefulIterator
+from bytelatent.data.iterators.packing_iterator import (
+    PackingArgs,
+    PackingIterator,
+    PackingMode,
+    _merge_patch_seq_masks,
+)
+
+
+class DummySequenceIterator(StatefulIterator):
+    def __init__(
+        self,
+        *,
+        seq_len: int,
+        n_seqs: int,
+        patch_lengths: list[int] | None = None,
+        pad_id: int = 0,
+    ):
+        self.seq_len = seq_len
+        self.n_seqs = n_seqs
+        self.patch_lengths = patch_lengths
+        self.pad_id = pad_id
+
+    def get_state(self):
+        raise NotImplementedError()
+
+    def create_iter(self):
+        for i in range(self.n_seqs):
+            if self.patch_lengths is None:
+                tokens = np.arange(
+                    i * self.seq_len + 1, (i + 1) * self.seq_len + 1
+                ).tolist()
+                mask = [True] * self.seq_len  # type: ignore
+                assert len(tokens) == self.seq_len
+            else:
+                n = sum(self.patch_lengths)
+                tokens = np.arange(i * n + 1, (i + 1) * n + 1).tolist()
+                assert len(tokens) == n
+                mask = [True] * n
+            assert len(mask) == len(tokens)
+            yield BltSequence(
+                tokens=tokens,
+                mask=mask,
+                patch_lengths=self.patch_lengths,
+            )
+
+
+def create_bytes_iter(*, seq_len: int, n_seqs: int, batch_size: int, pad_id: int):
+    sequence_iterator = DummySequenceIterator(seq_len=seq_len, n_seqs=n_seqs)
+    packing_iterator = PackingIterator(
+        sequence_iterator,
+        packing_args=PackingArgs(
+            batch_size=batch_size,
+            seq_len=seq_len,
+            pad_id=pad_id,
+            packing_mode=PackingMode.BYTES,
+            max_length=None,
+            pad_to_max_length=False,
+            enable_byte_ngrams=False,
+        ),
+    )
+    return packing_iterator.create_iter()
+
+
+def create_patches_iter(
+    *,
+    seq_len: int,
+    n_seqs: int,
+    batch_size: int,
+    pad_id: int,
+    patch_lengths: list[int] | None,
+    max_length: int,
+):
+    sequence_iterator = DummySequenceIterator(
+        # seq_len here is the number of bytes; for blt/patch packing that is
+        # max_length, since the packing seq_len is in terms of patches
+        seq_len=max_length,
+        n_seqs=n_seqs,
+        patch_lengths=patch_lengths,
+    )
+    packing_iterator = PackingIterator(
+        sequence_iterator,
+        packing_args=PackingArgs(
+            batch_size=batch_size,
+            seq_len=seq_len,
+            pad_id=pad_id,
+            packing_mode=PackingMode.PATCHING,
+            max_length=max_length,
+            pad_to_max_length=True,
+            enable_byte_ngrams=False,
+        ),
+    )
+    return packing_iterator.create_iter()
+
+
+def test_last_batch_correctness_bytes():
+    seq_len = 1024
+    n_seqs = 10
+    batch_size = 4
+    pad_id = 0
+    iterator = create_bytes_iter(
+        seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id
+    )
+    batches = []
+    n_nonpad = 0
+    n_nonmask = 0
+    for b in iterator:
+        assert b.x.shape[0] == batch_size
+        assert b.x.shape[1] == seq_len
+        n_nonpad += (b.x != pad_id).sum()
+        if b.mask is None:
+            n_nonmask += b.x.size
+        else:
+            n_nonmask += b.mask.sum()
+        batches.append(b)
+    assert len(batches) == 3
+    assert n_nonpad == n_nonmask == seq_len * n_seqs
+    # The second half of the last batch should be all pads
+    assert batches[-1].mask[2:].sum() == 0
+
+
+def test_edgecase_batch_correctness_bytes():
+    seq_len = 1024
+    n_seqs = 10
+    batch_size = 12
+    pad_id = 0
+    iterator = create_bytes_iter(
+        seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id
+    )
+    batches = []
+    n_nonpad = 0
+    n_nonmask = 0
+    for b in iterator:
+        assert b.x.shape[0] == batch_size
+        assert b.x.shape[1] == seq_len
+        n_nonpad += (b.x != pad_id).sum()
+        if b.mask is None:
+            n_nonmask += b.x.size
+        else:
+            n_nonmask += b.mask.sum()
+        batches.append(b)
+    assert len(batches) == 1
+    assert n_nonpad == n_nonmask == seq_len * n_seqs
+    # The last two rows of the only batch should be all pads
+    assert batches[0].mask[10:].sum() == 0
+
+
+def test_exact_batch_correctness_bytes():
+    seq_len = 1024
+    n_seqs = 12
+    batch_size = 4
+    pad_id = 0
+    iterator = create_bytes_iter(
+        seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id
+    )
+    batches = []
+    n_nonpad = 0
+    n_nonmask = 0
+    for b in iterator:
+        assert b.x.shape[0] == batch_size
+        assert b.x.shape[1] == seq_len
+        n_nonpad += (b.x != pad_id).sum()
+        if b.mask is None:
+            n_nonmask += b.x.size
+        else:
+            n_nonmask += b.mask.sum()
+        batches.append(b)
+    assert len(batches) == 4
+    assert n_nonpad == n_nonmask == seq_len * n_seqs
+
+
+def test_exact_batch_correctness_patches():
+    # First patch length is forced to be 1
+    patch_lengths = [1, 255, 256, 256, 256]
+    # Recall: This is in terms of bytes
+    max_length = 1024
+    # Recall: This is in terms of patches
+    seq_len = 5
+    n_seqs = 12
+    batch_size = 4
+    pad_id = 0
+    iterator = create_patches_iter(
+        seq_len=seq_len,
+        n_seqs=n_seqs,
+        batch_size=batch_size,
+        pad_id=pad_id,
+        patch_lengths=patch_lengths,
+        max_length=max_length,
+    )
+    batches = []
+    n_nonpad = 0
+    n_nonmask = 0
+    for batch in iterator:
+        assert batch.x.shape[0] == batch_size
+        assert batch.x.shape[1] == max_length
+        n_nonpad += (batch.x != pad_id).sum()
+        if batch.mask is None:
+            n_nonmask += batch.x.size
+        else:
+            n_nonmask += batch.mask.sum()
+        batches.append(batch)
+
+    assert len(batches) == 3
+
+    # max_length - 1 because the last byte is chopped off
+    # to serve as the y target
+    assert n_nonpad == n_nonmask == (max_length - 1) * n_seqs
+
+
+def test_short_batch_correctness_patches():
+    # First patch length is forced to be 1
+    # Total=48
+    patch_lengths = [1, 11, 12, 12, 12]
+    # Recall: This is in terms of bytes
+    max_length = 1024
+    # Recall: This is in terms of patches
+    seq_len = 5
+    n_seqs = 12
+    batch_size = 4
+    pad_id = 0
+    iterator = create_patches_iter(
+        seq_len=seq_len,
+        n_seqs=n_seqs,
+        batch_size=batch_size,
+        pad_id=pad_id,
+        patch_lengths=patch_lengths,
+        max_length=max_length,
+    )
+    batches = []
+    n_nonpad = 0
+    n_nonmask = 0
+    for batch in iterator:
+        assert batch.x.shape[0] == batch_size
+        assert batch.x.shape[1] == max_length
+        n_nonpad += (batch.x != pad_id).sum()
+        if batch.mask is None:
+            n_nonmask += batch.x.size
+        else:
+            n_nonmask += batch.mask.sum()
+        batches.append(batch)
+
+    assert len(batches) == 3
+
+    # We'll still always have one byte chopped off the end
+    assert n_nonpad == n_nonmask == ((sum(patch_lengths) - 1) * n_seqs)
+
+
+def test_long_batch_correctness_patches():
+    # First patch length is forced to be 1
+    # Total=1792
+    patch_lengths = [1, 255, 256, 256, 1024]
+    # Recall: This is in terms of bytes
+    max_length = 1024
+    # Recall: This is in terms of patches
+    seq_len = 5
+    n_seqs = 12
+    batch_size = 4
+    pad_id = 0
+    iterator = create_patches_iter(
+        seq_len=seq_len,
+        n_seqs=n_seqs,
+        batch_size=batch_size,
+        pad_id=pad_id,
+        patch_lengths=patch_lengths,
+        max_length=max_length,
+    )
+    batches = []
+    n_nonpad = 0
+    n_nonmask = 0
+    for batch in iterator:
+        assert batch.x.shape[0] == batch_size
+        assert batch.x.shape[1] == max_length
+        n_nonpad += (batch.x != pad_id).sum()
+        if batch.mask is None:
+            n_nonmask += batch.x.size
+        else:
+            n_nonmask += batch.mask.sum()
+        batches.append(batch)
+
+    assert len(batches) == 3
+
+    # No chop here since the next byte is available
+    assert n_nonpad == n_nonmask == max_length * n_seqs
+
+
+def test_merge_patch_seq_masks():
+    batch_size = 4
+    seq_len = 1024
+    masks = []
+    masks.append([True] * 1025)
+    masks.append([True] * 512)
+    masks.append([True] * 256)
+    masks.append([True] * 10)
+    expected_mask = np.zeros((batch_size, seq_len), dtype=bool)
+    expected_mask[0, :] = True
+    expected_mask[1, :511] = True
+    expected_mask[2, :255] = True
+    expected_mask[3, :9] = True
+    merged_mask = _merge_patch_seq_masks(batch_size, seq_len, masks)
+    assert (merged_mask == expected_mask).all()
+
+    with pytest.raises(AssertionError):
+        masks = []
+        masks.append([True] * 1024)
+        masks.append([True] * 512)
+        masks.append([True] * 256)
+        masks.append([True] * 10)
+        _merge_patch_seq_masks(batch_size, seq_len, masks)
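
For reference, a minimal standalone sketch of the padding and mask construction the byte-packing loop in this patch performs on a short final batch. The sizes and token values below are made up for illustration; the invariant at the end is the one the new assert checks:

```
import numpy as np

# Illustrative values only; in the real iterator these come from PackingArgs
# and the upstream sequence iterator.
pad_id, batch_size, seq_len = 0, 4, 8
tokens = [[1, 2, 3, 4, 5], [6, 7, 8]]  # a short final batch: only 2 of 4 rows
masks = [[True] * 5, [True] * 3]

x = np.full((batch_size, seq_len), fill_value=pad_id)
y = np.full((batch_size, seq_len), fill_value=pad_id)
m = np.zeros((batch_size, seq_len), dtype=bool)
for i, tok_seq in enumerate(tokens):
    x[i, : len(tok_seq)] = tok_seq
    y[i, : len(tok_seq) - 1] = tok_seq[1:]
    m[i, : len(tok_seq)] = masks[i]

# Invariant enforced by the new assert: every non-pad position is unmasked.
assert np.sum(x != pad_id) == m.sum()
```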
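
The first-patch adjustment in `_create_iter_from_patch_lengths` is easier to follow outside diff form. Here is a small walk-through of the same logic with hypothetical patch lengths:

```
# Hypothetical per-sequence patch lengths; the loop forces the first patch to
# length 1 by shifting lengths right and trims the bytes that belonged to the
# dropped last patch.
_patch_lengths = [2, 3, 4]
_tokens = list(range(9))  # sum(_patch_lengths) == 9 bytes
_mask = [True] * 9

last_patch_length = 0
if _patch_lengths[0] > 1:
    last_patch_length = _patch_lengths[-1]      # 4 bytes to trim from the end
    _patch_lengths[0] -= 1                      # [1, 3, 4]
    _patch_lengths = [1] + _patch_lengths[:-1]  # [1, 1, 3]

tokens = _tokens[: len(_tokens) - last_patch_length]  # 5 bytes remain
mask = _mask[: len(_mask) - last_patch_length]

# The adjusted patch lengths still account for every remaining byte.
assert sum(_patch_lengths) == len(tokens) == 5
```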
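
Finally, a tiny usage sketch of `_merge_patch_seq_masks` with illustrative sizes, consistent with the expectations encoded in `test_merge_patch_seq_masks` above: `slen` is one less than the longest mask because the final byte only serves as a y target, and all-True masks of equal length now yield an all-ones array instead of None:

```
import numpy as np

from bytelatent.data.iterators.packing_iterator import _merge_patch_seq_masks

# Two all-True byte masks of different lengths; slen == max(lens) - 1.
bs, slen = 2, 4
merged = _merge_patch_seq_masks(bs, slen, [[True] * 5, [True] * 3])
expected = np.array(
    [
        [True, True, True, True],
        [True, True, False, False],
    ]
)
assert (merged == expected).all()
```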