From 55ddb0f84b0c533e623dab12f09258c619a0aaa0 Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez
Date: Thu, 20 Feb 2025 20:15:45 +0000
Subject: [PATCH] Pass mask in packing_iterator, correctly handle last batch

---
 bytelatent/data/iterators/packing_iterator.py | 98 ++++++++++++-------
 1 file changed, 63 insertions(+), 35 deletions(-)

diff --git a/bytelatent/data/iterators/packing_iterator.py b/bytelatent/data/iterators/packing_iterator.py
index 5ed280d..7b19594 100644
--- a/bytelatent/data/iterators/packing_iterator.py
+++ b/bytelatent/data/iterators/packing_iterator.py
@@ -165,33 +165,51 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
         sequence_iter = self.sequence_iterator.create_iter()
         batch_size = self.packing_args.batch_size
         pad_id = self.packing_args.pad_id
         seq_len = self.packing_args.seq_len
+        stop_iteration = False
+        tokens: list[list[int]] = []
+        masks: list[list[bool]] = []
         while True:
-            tokens: list[list[int]] = []
-            masks: list[list[bool]] = []
-
-            for _ in range(self.packing_args.batch_size):
-                sequence = next(sequence_iter)
-                _tokens = sequence.tokens
-                _mask = sequence.mask
-                assert (
-                    sequence.patch_lengths is None
-                ), "patch_lengths should not be used in byte packing"
-                tokens.append(_tokens)
-                masks.append(_mask)
+            try:
+                for _ in range(self.packing_args.batch_size):
+                    sequence = next(sequence_iter)
+                    _tokens = sequence.tokens
+                    _mask = sequence.mask
+                    assert (
+                        sequence.patch_lengths is None
+                    ), "patch_lengths should not be used in byte packing"
+                    tokens.append(_tokens)
+                    masks.append(_mask)
+            except StopIteration:
+                # At this point, there will be no new sequences, so we need to stop
+                # after yielding the already accumulated data (one batch).
+                # In this case, either:
+                # 1. We have a complete batch, so things go as normal.
+                # 2. We have an incomplete batch; since the batch arrays are allocated
+                #    at full size and then filled in, this case is handled automatically.
+                stop_iteration = True
 
             x = np.full((batch_size, seq_len), fill_value=pad_id)
             y = np.full((batch_size, seq_len), fill_value=pad_id)
+            m = np.zeros((batch_size, seq_len))
             for i, tok_seq in enumerate(tokens):
                 x[i, : len(tok_seq)] = tok_seq
                 y[i, : len(tok_seq) - 1] = tok_seq[1:]
-            batch = Batch(x=x, y=y)
+                m[i, : len(tok_seq)] = masks[i]
+            batch = Batch(x=x, y=y, mask=m)
+            tokens = []
+            masks = []
             assert (
                 batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
             ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
             yield batch
+            if stop_iteration:
+                break
 
     def _create_iter_from_patch_lengths(self):
+        """
+        TODO: Make this work for evals, specifically the last partial batch
+        """
         sequence_iter = self.sequence_iterator.create_iter()
         batch_size = self.packing_args.batch_size
         pad_id = self.packing_args.pad_id
@@ -199,29 +217,33 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
         pad_to_max_length = self.packing_args.pad_to_max_length
         enable_byte_ngrams = self.packing_args.enable_byte_ngrams
         max_length = self.packing_args.max_length
+        assert max_length is not None
+        tokens: list[list[int]] = []
+        masks: list[list[bool]] = []
+        patch_lengths: list[list[int]] = []
+        stop_iteration = False
         while True:
-            tokens: list[list[int]] = []
-            masks: list[list[bool]] = []
-            patch_lengths: list[list[int]] = []
-
-            for _ in range(self.packing_args.batch_size):
-                sequence = next(sequence_iter)
-                _tokens = sequence.tokens
-                _mask = sequence.mask
-                _patch_lengths = sequence.patch_lengths
-                assert (
-                    _patch_lengths is not None
-                ), "patch lengths are required for packing based on patches."
-                # Reminder: seq_len is in terms of patches
-                assert len(sequence.patch_lengths) == self.packing_args.seq_len
-                last_patch_length = 0
-                if _patch_lengths[0] > 1:
-                    last_patch_length = _patch_lengths[-1]
-                    _patch_lengths[0] -= 1
-                    _patch_lengths = [1] + _patch_lengths[:-1]
-                tokens.append(_tokens[: len(_tokens) - last_patch_length])
-                masks.append(_mask[: len(_mask) - last_patch_length])
-                patch_lengths.append(_patch_lengths)
+            try:
+                for _ in range(self.packing_args.batch_size):
+                    sequence = next(sequence_iter)
+                    _tokens = sequence.tokens
+                    _mask = sequence.mask
+                    _patch_lengths = sequence.patch_lengths
+                    assert (
+                        _patch_lengths is not None
+                    ), "patch lengths are required for packing based on patches."
+                    # Reminder: seq_len is in terms of patches
+                    assert len(sequence.patch_lengths) == self.packing_args.seq_len
+                    last_patch_length = 0
+                    if _patch_lengths[0] > 1:
+                        last_patch_length = _patch_lengths[-1]
+                        _patch_lengths[0] -= 1
+                        _patch_lengths = [1] + _patch_lengths[:-1]
+                    tokens.append(_tokens[: len(_tokens) - last_patch_length])
+                    masks.append(_mask[: len(_mask) - last_patch_length])
+                    patch_lengths.append(_patch_lengths)
+            except StopIteration:
+                stop_iteration = True
 
             x_patch_lengths = np.array(patch_lengths)
             # pad batch to same length
@@ -249,6 +271,10 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                 ngram_ids=ngram_ids,
                 mask=_merge_patch_seq_masks(batch_size, tok_seq_len, masks),
             )
+            tokens = []
+            masks = []
+            patch_lengths = []
+
             assert (
                 x_patch_lengths.sum() == x.size + batch_size
             ), f"{x_patch_lengths.sum()} != {x.size + batch_size}"
@@ -269,3 +295,5 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                 enable_byte_ngrams=enable_byte_ngrams,
             )
             yield batch
+            if stop_iteration:
+                break
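
The snippet below is a minimal, self-contained sketch of the allocate-then-fill pattern the patched _create_iter_from_bytes relies on: x, y, and m are created at the full (batch_size, seq_len) shape, so when StopIteration arrives mid-batch the unfilled rows simply stay at pad_id with an all-zero mask. SimpleSequence, pack_batches, and the toy data are illustrative stand-ins, not the repository's sequence iterator or Batch class; only the fill-and-mask logic mirrors the diff.

# Minimal sketch of the allocate-then-fill packing pattern from the diff above.
# SimpleSequence, pack_batches, and the toy data are hypothetical stand-ins.
from dataclasses import dataclass

import numpy as np


@dataclass
class SimpleSequence:
    tokens: list[int]
    mask: list[bool]


def pack_batches(sequences, batch_size: int, seq_len: int, pad_id: int):
    sequence_iter = iter(sequences)
    stop_iteration = False
    tokens: list[list[int]] = []
    masks: list[list[bool]] = []
    while True:
        try:
            for _ in range(batch_size):
                sequence = next(sequence_iter)
                tokens.append(sequence.tokens)
                masks.append(sequence.mask)
        except StopIteration:
            # No further sequences: yield whatever accumulated, then stop.
            stop_iteration = True

        # Arrays are allocated at the full batch shape, so an incomplete
        # batch leaves its trailing rows as pad_id with an all-zero mask.
        x = np.full((batch_size, seq_len), fill_value=pad_id)
        y = np.full((batch_size, seq_len), fill_value=pad_id)
        m = np.zeros((batch_size, seq_len))
        for i, tok_seq in enumerate(tokens):
            x[i, : len(tok_seq)] = tok_seq
            y[i, : len(tok_seq) - 1] = tok_seq[1:]  # next-token targets
            m[i, : len(tok_seq)] = masks[i]
        tokens, masks = [], []
        yield x, y, m
        if stop_iteration:
            break


if __name__ == "__main__":
    # Five sequences with batch_size=2: the third batch holds one real row,
    # and its second row is entirely pad_id with a zero mask row.
    data = [SimpleSequence(tokens=[1, 2, 3, 4], mask=[True] * 4) for _ in range(5)]
    for x, y, m in pack_batches(data, batch_size=2, seq_len=4, pad_id=0):
        print(x)
        print(m.sum(axis=1))  # per-row count of real (unmasked) positions

As in the patched iterator, a dataset whose length is an exact multiple of batch_size ends with one fully padded batch whose mask sums to zero, so downstream consumers should key off the mask rather than the batch count.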