Pass mask in packing_iterator, correctly handle last batch

Pedro Rodriguez 2025-02-20 20:15:45 +00:00
parent 0ffe2ab685
commit 55ddb0f84b

@@ -165,10 +165,11 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
        batch_size = self.packing_args.batch_size
        pad_id = self.packing_args.pad_id
        seq_len = self.packing_args.seq_len
        stop_iteration = False
        tokens: list[list[int]] = []
        masks: list[list[bool]] = []
        while True:
            try:
                for _ in range(self.packing_args.batch_size):
                    sequence = next(sequence_iter)
                    _tokens = sequence.tokens
@@ -178,20 +179,37 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
), "patch_lengths should not be used in byte packing"
tokens.append(_tokens)
masks.append(_mask)
except StopIteration:
# At this point, there will be no new sequences, so we need to stop
# after yielding the already accumulated data (one batch).
# In this case, either:
# 1. We have a complete batch, so things go as normal
# 2. We have an incomplete batch, but due to creating a right sized batch,
# then filling the values in, this case is automatically handled.
stop_iteration = True
x = np.full((batch_size, seq_len), fill_value=pad_id)
y = np.full((batch_size, seq_len), fill_value=pad_id)
m = np.zeros((batch_size, seq_len))
            for i, tok_seq in enumerate(tokens):
                x[i, : len(tok_seq)] = tok_seq
                y[i, : len(tok_seq) - 1] = tok_seq[1:]
                m[i, : len(tok_seq)] = masks[i]
            batch = Batch(x=x, y=y, mask=m)
            tokens = []
            masks = []
            assert (
                batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
            ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
            yield batch
            if stop_iteration:
                break
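
In isolation, the control flow this commit introduces amounts to the following pattern. This is a minimal, self-contained sketch, not the blt API: pack_token_batches and the plain numpy-tuple return type are hypothetical stand-ins for the iterator method and Batch above. Because the output arrays are allocated at full batch size and pre-filled with pad_id (and the mask with zeros), a trailing partial batch needs no special casing: the missing rows simply stay padded and masked out.

    import numpy as np
    from typing import Iterable, Iterator

    def pack_token_batches(
        seqs: Iterable[tuple[list[int], list[bool]]],
        batch_size: int,
        seq_len: int,
        pad_id: int,
    ) -> Iterator[tuple[np.ndarray, np.ndarray, np.ndarray]]:
        it = iter(seqs)
        stop_iteration = False
        while True:
            tokens: list[list[int]] = []
            masks: list[list[bool]] = []
            try:
                for _ in range(batch_size):
                    toks, mask = next(it)
                    tokens.append(toks)
                    masks.append(mask)
            except StopIteration:
                # No more sequences: emit whatever accumulated (possibly a
                # partial batch), then stop after yielding it.
                stop_iteration = True
            # Right-sized, pre-padded outputs make the partial case free.
            x = np.full((batch_size, seq_len), fill_value=pad_id)
            y = np.full((batch_size, seq_len), fill_value=pad_id)
            m = np.zeros((batch_size, seq_len))
            for i, tok_seq in enumerate(tokens):
                x[i, : len(tok_seq)] = tok_seq
                y[i, : len(tok_seq) - 1] = tok_seq[1:]  # next-token targets
                m[i, : len(tok_seq)] = masks[i]
            yield x, y, m
            if stop_iteration:
                break

The assert after Batch(...) above then ties the two changes together: every non-pad position in x must be covered by the mask, so np.sum(x != pad_id) == mask.sum() holds for full and partial batches alike.
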
    def _create_iter_from_patch_lengths(self):
        """
        TODO: Make this work for evals, specifically the last partial batch
        """
        sequence_iter = self.sequence_iterator.create_iter()
        batch_size = self.packing_args.batch_size
        pad_id = self.packing_args.pad_id
@@ -199,11 +217,13 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
        pad_to_max_length = self.packing_args.pad_to_max_length
        enable_byte_ngrams = self.packing_args.enable_byte_ngrams
        max_length = self.packing_args.max_length
        assert max_length is not None
        tokens: list[list[int]] = []
        masks: list[list[bool]] = []
        patch_lengths: list[list[int]] = []
        stop_iteration = False
        while True:
            try:
                for _ in range(self.packing_args.batch_size):
                    sequence = next(sequence_iter)
                    _tokens = sequence.tokens
@@ -222,6 +242,8 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                    tokens.append(_tokens[: len(_tokens) - last_patch_length])
                    masks.append(_mask[: len(_mask) - last_patch_length])
                    patch_lengths.append(_patch_lengths)
            except StopIteration:
                stop_iteration = True
            x_patch_lengths = np.array(patch_lengths)
            # pad batch to same length
@@ -249,6 +271,10 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                ngram_ids=ngram_ids,
                mask=_merge_patch_seq_masks(batch_size, tok_seq_len, masks),
            )
            tokens = []
            masks = []
            patch_lengths = []
            assert (
                x_patch_lengths.sum() == x.size + batch_size
            ), f"{x_patch_lengths.sum()} != {x.size + batch_size}"
@@ -269,3 +295,5 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                enable_byte_ngrams=enable_byte_ngrams,
            )
            yield batch
            if stop_iteration:
                break
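
The TODO on _create_iter_from_patch_lengths points at what both changes serve: with mask now populated and the trailing partial batch emitted before the iterator stops, an eval loop can average a loss over exactly the mask-covered tokens. A consumer-side sketch, with hypothetical names (masked_mean_loss and per_token_loss_fn are not blt API); it assumes batches expose numpy x, y, and mask fields as in the code above.

    import numpy as np

    def masked_mean_loss(batches, per_token_loss_fn):
        # `batches` is any iterable of batch objects with numpy x, y, mask
        # fields, e.g. PackingIterator.create_iter(); `per_token_loss_fn`
        # is a placeholder for a model call returning a (batch, seq) array.
        total_loss, total_tokens = 0.0, 0
        for batch in batches:  # terminates after the final partial batch
            losses = per_token_loss_fn(batch.x, batch.y)
            total_loss += float((losses * batch.mask).sum())
            total_tokens += int(batch.mask.sum())  # pad positions excluded
        return total_loss / max(total_tokens, 1)

Because padded rows and positions carry a zero mask, the partial last batch contributes only its real tokens to the average; without the mask and the final yield, eval statistics would either drop the tail of the dataset or count pad tokens.
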