Pass mask in packing_iterator, correctly handle last batch

This commit is contained in:
Pedro Rodriguez 2025-02-20 20:15:45 +00:00
parent 0ffe2ab685
commit 55ddb0f84b

View file

@ -165,33 +165,51 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
batch_size = self.packing_args.batch_size batch_size = self.packing_args.batch_size
pad_id = self.packing_args.pad_id pad_id = self.packing_args.pad_id
seq_len = self.packing_args.seq_len seq_len = self.packing_args.seq_len
stop_iteration = False
tokens: list[list[int]] = []
masks: list[list[bool]] = []
while True: while True:
tokens: list[list[int]] = [] try:
masks: list[list[bool]] = [] for _ in range(self.packing_args.batch_size):
sequence = next(sequence_iter)
for _ in range(self.packing_args.batch_size): _tokens = sequence.tokens
sequence = next(sequence_iter) _mask = sequence.mask
_tokens = sequence.tokens assert (
_mask = sequence.mask sequence.patch_lengths is None
assert ( ), "patch_lengths should not be used in byte packing"
sequence.patch_lengths is None tokens.append(_tokens)
), "patch_lengths should not be used in byte packing" masks.append(_mask)
tokens.append(_tokens) except StopIteration:
masks.append(_mask) # At this point, there will be no new sequences, so we need to stop
# after yielding the already accumulated data (one batch).
# In this case, either:
# 1. We have a complete batch, so things go as normal
# 2. We have an incomplete batch, but due to creating a right sized batch,
# then filling the values in, this case is automatically handled.
stop_iteration = True
x = np.full((batch_size, seq_len), fill_value=pad_id) x = np.full((batch_size, seq_len), fill_value=pad_id)
y = np.full((batch_size, seq_len), fill_value=pad_id) y = np.full((batch_size, seq_len), fill_value=pad_id)
m = np.zeros((batch_size, seq_len))
for i, tok_seq in enumerate(tokens): for i, tok_seq in enumerate(tokens):
x[i, : len(tok_seq)] = tok_seq x[i, : len(tok_seq)] = tok_seq
y[i, : len(tok_seq) - 1] = tok_seq[1:] y[i, : len(tok_seq) - 1] = tok_seq[1:]
batch = Batch(x=x, y=y) m[i : len(tok_seq)] = masks[i]
batch = Batch(x=x, y=y, mask=m)
tokens = []
masks = []
assert ( assert (
batch.mask is None or np.sum(x != pad_id) == batch.mask.sum() batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
), f"{np.sum(x != pad_id)} != {batch.mask.sum()}" ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
yield batch yield batch
if stop_iteration:
break
def _create_iter_from_patch_lengths(self): def _create_iter_from_patch_lengths(self):
"""
TODO: Make this work for evals, specifically the last partial batch
"""
sequence_iter = self.sequence_iterator.create_iter() sequence_iter = self.sequence_iterator.create_iter()
batch_size = self.packing_args.batch_size batch_size = self.packing_args.batch_size
pad_id = self.packing_args.pad_id pad_id = self.packing_args.pad_id
@ -199,29 +217,33 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
pad_to_max_length = self.packing_args.pad_to_max_length pad_to_max_length = self.packing_args.pad_to_max_length
enable_byte_ngrams = self.packing_args.enable_byte_ngrams enable_byte_ngrams = self.packing_args.enable_byte_ngrams
max_length = self.packing_args.max_length max_length = self.packing_args.max_length
assert max_length is not None
tokens: list[list[int]] = []
masks: list[list[bool]] = []
patch_lengths: list[list[int]] = []
stop_iteration = False
while True: while True:
tokens: list[list[int]] = [] try:
masks: list[list[bool]] = [] for _ in range(self.packing_args.batch_size):
patch_lengths: list[list[int]] = [] sequence = next(sequence_iter)
_tokens = sequence.tokens
for _ in range(self.packing_args.batch_size): _mask = sequence.mask
sequence = next(sequence_iter) _patch_lengths = sequence.patch_lengths
_tokens = sequence.tokens assert (
_mask = sequence.mask _patch_lengths is not None
_patch_lengths = sequence.patch_lengths ), "patch lengths are required for packing based on patches."
assert ( # Reminder: seq_len is in terms of patches
_patch_lengths is not None assert len(sequence.patch_lengths) == self.packing_args.seq_len
), "patch lengths are required for packing based on patches." last_patch_length = 0
# Reminder: seq_len is in terms of patches if _patch_lengths[0] > 1:
assert len(sequence.patch_lengths) == self.packing_args.seq_len last_patch_length = _patch_lengths[-1]
last_patch_length = 0 _patch_lengths[0] -= 1
if _patch_lengths[0] > 1: _patch_lengths = [1] + _patch_lengths[:-1]
last_patch_length = _patch_lengths[-1] tokens.append(_tokens[: len(_tokens) - last_patch_length])
_patch_lengths[0] -= 1 masks.append(_mask[: len(_mask) - last_patch_length])
_patch_lengths = [1] + _patch_lengths[:-1] patch_lengths.append(_patch_lengths)
tokens.append(_tokens[: len(_tokens) - last_patch_length]) except StopIteration:
masks.append(_mask[: len(_mask) - last_patch_length]) stop_iteration = True
patch_lengths.append(_patch_lengths)
x_patch_lengths = np.array(patch_lengths) x_patch_lengths = np.array(patch_lengths)
# pad batch to same length # pad batch to same length
@ -249,6 +271,10 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
ngram_ids=ngram_ids, ngram_ids=ngram_ids,
mask=_merge_patch_seq_masks(batch_size, tok_seq_len, masks), mask=_merge_patch_seq_masks(batch_size, tok_seq_len, masks),
) )
tokens = []
masks = []
patch_lengths = []
assert ( assert (
x_patch_lengths.sum() == x.size + batch_size x_patch_lengths.sum() == x.size + batch_size
), f"{x_patch_lengths.sum()} != {x.size + batch_size}" ), f"{x_patch_lengths.sum()} != {x.size + batch_size}"
@ -269,3 +295,5 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
enable_byte_ngrams=enable_byte_ngrams, enable_byte_ngrams=enable_byte_ngrams,
) )
yield batch yield batch
if stop_iteration:
break