From 3e9de6276360ab8fa52dc8ce9b551445f387eb38 Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez
Date: Thu, 20 Feb 2025 20:15:45 +0000
Subject: [PATCH] Pass mask in packing_iterator, correctly handle last batch

---
 bytelatent/data/iterators/packing_iterator.py |  96 ++++++---
 .../data/iterators/test_packing_iterator.py   | 193 ++++++++++++++++++
 2 files changed, 255 insertions(+), 34 deletions(-)
 create mode 100644 bytelatent/data/iterators/test_packing_iterator.py

diff --git a/bytelatent/data/iterators/packing_iterator.py b/bytelatent/data/iterators/packing_iterator.py
index 5ed280d..cdb9465 100644
--- a/bytelatent/data/iterators/packing_iterator.py
+++ b/bytelatent/data/iterators/packing_iterator.py
@@ -165,31 +165,46 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
         batch_size = self.packing_args.batch_size
         pad_id = self.packing_args.pad_id
         seq_len = self.packing_args.seq_len
+        stop_iteration = False
+        tokens: list[list[int]] = []
+        masks: list[list[bool]] = []
         while True:
-            tokens: list[list[int]] = []
-            masks: list[list[bool]] = []
-
-            for _ in range(self.packing_args.batch_size):
-                sequence = next(sequence_iter)
-                _tokens = sequence.tokens
-                _mask = sequence.mask
-                assert (
-                    sequence.patch_lengths is None
-                ), "patch_lengths should not be used in byte packing"
-                tokens.append(_tokens)
-                masks.append(_mask)
+            try:
+                for _ in range(self.packing_args.batch_size):
+                    sequence = next(sequence_iter)
+                    _tokens = sequence.tokens
+                    _mask = sequence.mask
+                    assert (
+                        sequence.patch_lengths is None
+                    ), "patch_lengths should not be used in byte packing"
+                    tokens.append(_tokens)
+                    masks.append(_mask)
+            except StopIteration:
+                # At this point, there will be no new sequences, so we need to stop
+                # after yielding the already accumulated data (one batch).
+                # In this case, either:
+                # 1. We have a complete batch, so things go as normal
+                # 2. We have an incomplete batch; since we allocate a full-sized batch
+                #    and then fill values in, this case is handled automatically.
+                stop_iteration = True
 
             x = np.full((batch_size, seq_len), fill_value=pad_id)
             y = np.full((batch_size, seq_len), fill_value=pad_id)
+            m = np.zeros((batch_size, seq_len))
 
             for i, tok_seq in enumerate(tokens):
                 x[i, : len(tok_seq)] = tok_seq
                 y[i, : len(tok_seq) - 1] = tok_seq[1:]
-            batch = Batch(x=x, y=y)
+                m[i, : len(tok_seq)] = masks[i]
+            batch = Batch(x=x, y=y, mask=m)
             assert (
                 batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
             ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
+            tokens = []
+            masks = []
             yield batch
+            if stop_iteration:
+                break
 
     def _create_iter_from_patch_lengths(self):
         sequence_iter = self.sequence_iterator.create_iter()
@@ -199,29 +214,36 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
         pad_to_max_length = self.packing_args.pad_to_max_length
         enable_byte_ngrams = self.packing_args.enable_byte_ngrams
         max_length = self.packing_args.max_length
+        assert max_length is not None
+        tokens: list[list[int]] = []
+        masks: list[list[bool]] = []
+        patch_lengths: list[list[int]] = []
+        stop_iteration = False
         while True:
-            tokens: list[list[int]] = []
-            masks: list[list[bool]] = []
-            patch_lengths: list[list[int]] = []
+            try:
+                for _ in range(self.packing_args.batch_size):
+                    sequence = next(sequence_iter)
+                    _tokens = sequence.tokens
+                    _mask = sequence.mask
+                    _patch_lengths = sequence.patch_lengths
+                    assert (
+                        _patch_lengths is not None
+                    ), "patch lengths are required for packing based on patches."
+                    # Reminder: seq_len is in terms of patches
+                    assert len(sequence.patch_lengths) == self.packing_args.seq_len
+                    last_patch_length = 0
+                    if _patch_lengths[0] > 1:
+                        last_patch_length = _patch_lengths[-1]
+                        _patch_lengths[0] -= 1
+                        _patch_lengths = [1] + _patch_lengths[:-1]
+                    tokens.append(_tokens[: len(_tokens) - last_patch_length])
+                    masks.append(_mask[: len(_mask) - last_patch_length])
+                    patch_lengths.append(_patch_lengths)
+            except StopIteration:
+                stop_iteration = True
 
-            for _ in range(self.packing_args.batch_size):
-                sequence = next(sequence_iter)
-                _tokens = sequence.tokens
-                _mask = sequence.mask
-                _patch_lengths = sequence.patch_lengths
-                assert (
-                    _patch_lengths is not None
-                ), "patch lengths are required for packing based on patches."
-                # Reminder: seq_len is in terms of patches
-                assert len(sequence.patch_lengths) == self.packing_args.seq_len
-                last_patch_length = 0
-                if _patch_lengths[0] > 1:
-                    last_patch_length = _patch_lengths[-1]
-                    _patch_lengths[0] -= 1
-                    _patch_lengths = [1] + _patch_lengths[:-1]
-                tokens.append(_tokens[: len(_tokens) - last_patch_length])
-                masks.append(_mask[: len(_mask) - last_patch_length])
-                patch_lengths.append(_patch_lengths)
+            if len(tokens) == 0 and stop_iteration:
+                break
 
             x_patch_lengths = np.array(patch_lengths)
             # pad batch to same length
@@ -249,6 +271,10 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                 ngram_ids=ngram_ids,
                 mask=_merge_patch_seq_masks(batch_size, tok_seq_len, masks),
             )
+            tokens = []
+            masks = []
+            patch_lengths = []
+
             assert (
                 x_patch_lengths.sum() == x.size + batch_size
             ), f"{x_patch_lengths.sum()} != {x.size + batch_size}"
@@ -269,3 +295,5 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                 enable_byte_ngrams=enable_byte_ngrams,
             )
             yield batch
+            if stop_iteration:
+                break
diff --git a/bytelatent/data/iterators/test_packing_iterator.py b/bytelatent/data/iterators/test_packing_iterator.py
new file mode 100644
index 0000000..8cf1ca2
--- /dev/null
+++ b/bytelatent/data/iterators/test_packing_iterator.py
@@ -0,0 +1,193 @@
+import numpy as np
+
+from bytelatent.data.data_types import BltSequence
+from bytelatent.data.iterators.abstract_iterator import StatefulIterator
+from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
+
+
+class DummySequenceIterator(StatefulIterator):
+    def __init__(
+        self, *, seq_len: int, n_seqs: int, patch_lengths: list[int] | None = None
+    ):
+        self.seq_len = seq_len
+        self.n_seqs = n_seqs
+        self.patch_lengths = patch_lengths
+
+    def get_state(self):
+        raise NotImplementedError()
+
+    def create_iter(self):
+        for i in range(self.n_seqs):
+            tokens = np.arange(
+                i * self.seq_len + 1, (i + 1) * self.seq_len + 1
+            ).tolist()
+            assert len(tokens) == self.seq_len
+            if self.patch_lengths is not None:
+                assert sum(self.patch_lengths) == len(tokens)
+            yield BltSequence(
+                tokens=tokens,
+                mask=[1] * self.seq_len,  # type: ignore
+                patch_lengths=self.patch_lengths,
+            )
+
+
+def create_bytes_iter(*, seq_len: int, n_seqs: int, batch_size: int, pad_id: int):
+    sequence_iterator = DummySequenceIterator(seq_len=seq_len, n_seqs=n_seqs)
+    packing_iterator = PackingIterator(
+        sequence_iterator,
+        packing_args=PackingArgs(
+            batch_size=batch_size,
+            seq_len=seq_len,
+            pad_id=pad_id,
+            tokenizer_name="bytes",
+            max_length=None,
+            pad_to_max_length=False,
+            enable_byte_ngrams=False,
+        ),
+    )
+    return packing_iterator.create_iter()
+
+
+def create_patches_iter(
+    *,
+    seq_len: int,
+    n_seqs: int,
+    batch_size: int,
+    pad_id: int,
+    patch_lengths: list[int] | None,
+    max_length: int,
+):
+    sequence_iterator = DummySequenceIterator(
+        # seq_len here is the number of bytes, which for blt/patches is max_length,
+        # since the packing seq_len is in terms of patches
+        seq_len=max_length,
+        n_seqs=n_seqs,
+        patch_lengths=patch_lengths,
+    )
+    packing_iterator = PackingIterator(
+        sequence_iterator,
+        packing_args=PackingArgs(
+            batch_size=batch_size,
+            seq_len=seq_len,
+            pad_id=pad_id,
+            tokenizer_name="blt",
+            max_length=max_length,
+            pad_to_max_length=True,
+            enable_byte_ngrams=False,
+        ),
+    )
+    return packing_iterator.create_iter()
+
+
+def test_last_batch_correctness_bytes():
+    seq_len = 1024
+    n_seqs = 10
+    batch_size = 4
+    pad_id = 0
+    iterator = create_bytes_iter(
+        seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id
+    )
+    batches = []
+    n_nonpad = 0
+    n_nonmask = 0
+    for b in iterator:
+        assert b.x.shape[0] == batch_size
+        assert b.x.shape[1] == seq_len
+        n_nonpad += (b.x != pad_id).sum()
+        if b.mask is None:
+            n_nonmask += b.x.size
+        else:
+            n_nonmask += b.mask.sum()
+        batches.append(b)
+    assert len(batches) == 3
+    assert n_nonpad == n_nonmask == seq_len * n_seqs
+    # The second half of the last batch should be all pads
+    assert batches[-1].mask[2:].sum() == 0
+
+
+def test_edgecase_batch_correctness_bytes():
+    seq_len = 1024
+    n_seqs = 10
+    batch_size = 12
+    pad_id = 0
+    iterator = create_bytes_iter(
+        seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id
+    )
+    batches = []
+    n_nonpad = 0
+    n_nonmask = 0
+    for b in iterator:
+        assert b.x.shape[0] == batch_size
+        assert b.x.shape[1] == seq_len
+        n_nonpad += (b.x != pad_id).sum()
+        if b.mask is None:
+            n_nonmask += b.x.size
+        else:
+            n_nonmask += b.mask.sum()
+        batches.append(b)
+    assert len(batches) == 1
+    assert n_nonpad == n_nonmask == seq_len * n_seqs
+    # The last two rows of the only batch should be all pads
+    assert batches[0].mask[10:].sum() == 0
+
+
+def test_exact_batch_correctness_bytes():
+    seq_len = 1024
+    n_seqs = 12
+    batch_size = 4
+    pad_id = 0
+    iterator = create_bytes_iter(
+        seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id
+    )
+    batches = []
+    n_nonpad = 0
+    n_nonmask = 0
+    for b in iterator:
+        assert b.x.shape[0] == batch_size
+        assert b.x.shape[1] == seq_len
+        n_nonpad += (b.x != pad_id).sum()
+        if b.mask is None:
+            n_nonmask += b.x.size
+        else:
+            n_nonmask += b.mask.sum()
+        batches.append(b)
+    assert len(batches) == 4
+    assert n_nonpad == n_nonmask == seq_len * n_seqs
+
+
+def test_exact_batch_correctness_patches():
+    # First patch length is forced to be 1
+    patch_lengths = [1, 255, 256, 256, 256]
+    # Recall: This is in terms of bytes
+    max_length = 1024
+    # Recall: This is in terms of patches
+    seq_len = 5
+    n_seqs = 12
+    batch_size = 4
+    pad_id = 0
+    iterator = create_patches_iter(
+        seq_len=seq_len,
+        n_seqs=n_seqs,
+        batch_size=batch_size,
+        pad_id=pad_id,
+        patch_lengths=patch_lengths,
+        max_length=max_length,
+    )
+    batches = []
+    n_nonpad = 0
+    n_nonmask = 0
+    for batch in iterator:
+        assert batch.x.shape[0] == batch_size
+        assert batch.x.shape[1] == max_length
+        n_nonpad += (batch.x != pad_id).sum()
+        # TODO: Discuss with artidoro if the code or expected behavior is wrong
+        # assert (batch.patch_lengths == patch_lengths).all()
+        if batch.mask is None:
+            n_nonmask += batch.x.size
+        else:
+            n_nonmask += batch.mask.sum()
+        batches.append(batch)
+
+    assert len(batches) == 3
+    # TODO: Ditto above, this is due to how things get padded in the end
+    # assert n_nonpad == n_nonmask == max_length * n_seqs
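
Usage sketch (illustrative, not part of the patch): with this change, Batch.mask is populated and the final, possibly partial batch is still emitted at the full (batch_size, seq_len) shape, so consumers should count real tokens through the mask rather than assuming every row is valid. The snippet below is a minimal sketch of that accounting; it reuses DummySequenceIterator from the test file purely as a stand-in sequence source, and the PackingArgs values mirror create_bytes_iter above.

from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
from bytelatent.data.iterators.test_packing_iterator import DummySequenceIterator

seq_len, n_seqs, batch_size, pad_id = 1024, 10, 4, 0
packing_iterator = PackingIterator(
    DummySequenceIterator(seq_len=seq_len, n_seqs=n_seqs),
    packing_args=PackingArgs(
        batch_size=batch_size,
        seq_len=seq_len,
        pad_id=pad_id,
        tokenizer_name="bytes",
        max_length=None,
        pad_to_max_length=False,
        enable_byte_ngrams=False,
    ),
)

n_real = 0
for batch in packing_iterator.create_iter():
    # Every batch, including the last partial one, keeps the full shape; padded
    # rows and positions are flagged by the mask instead of being dropped.
    assert batch.x.shape == (batch_size, seq_len)
    n_real += int(batch.mask.sum()) if batch.mask is not None else batch.x.size

# 10 sequences do not divide evenly into batches of 4, yet no tokens are lost:
# the mask accounts for exactly the tokens that came from the source sequences.
assert n_real == seq_len * n_seqs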