Merge 55ddb0f84b into sapling-pr-archive-EntilZha

This commit is contained in:
Pedro Rodriguez 2025-02-20 12:16:14 -08:00 committed by GitHub
commit 86abff94d0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -165,33 +165,51 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
batch_size = self.packing_args.batch_size
pad_id = self.packing_args.pad_id
seq_len = self.packing_args.seq_len
stop_iteration = False
tokens: list[list[int]] = []
masks: list[list[bool]] = []
while True:
tokens: list[list[int]] = []
masks: list[list[bool]] = []
for _ in range(self.packing_args.batch_size):
sequence = next(sequence_iter)
_tokens = sequence.tokens
_mask = sequence.mask
assert (
sequence.patch_lengths is None
), "patch_lengths should not be used in byte packing"
tokens.append(_tokens)
masks.append(_mask)
try:
for _ in range(self.packing_args.batch_size):
sequence = next(sequence_iter)
_tokens = sequence.tokens
_mask = sequence.mask
assert (
sequence.patch_lengths is None
), "patch_lengths should not be used in byte packing"
tokens.append(_tokens)
masks.append(_mask)
except StopIteration:
# At this point, there will be no new sequences, so we need to stop
# after yielding the already accumulated data (one batch).
# In this case, either:
# 1. We have a complete batch, so things go as normal
# 2. We have an incomplete batch, but due to creating a right sized batch,
# then filling the values in, this case is automatically handled.
stop_iteration = True
x = np.full((batch_size, seq_len), fill_value=pad_id)
y = np.full((batch_size, seq_len), fill_value=pad_id)
m = np.zeros((batch_size, seq_len))
for i, tok_seq in enumerate(tokens):
x[i, : len(tok_seq)] = tok_seq
y[i, : len(tok_seq) - 1] = tok_seq[1:]
batch = Batch(x=x, y=y)
m[i : len(tok_seq)] = masks[i]
batch = Batch(x=x, y=y, mask=m)
tokens = []
masks = []
assert (
batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
yield batch
if stop_iteration:
break
def _create_iter_from_patch_lengths(self):
"""
TODO: Make this work for evals, specifically the last partial batch
"""
sequence_iter = self.sequence_iterator.create_iter()
batch_size = self.packing_args.batch_size
pad_id = self.packing_args.pad_id
@ -199,29 +217,33 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
pad_to_max_length = self.packing_args.pad_to_max_length
enable_byte_ngrams = self.packing_args.enable_byte_ngrams
max_length = self.packing_args.max_length
assert max_length is not None
tokens: list[list[int]] = []
masks: list[list[bool]] = []
patch_lengths: list[list[int]] = []
stop_iteration = False
while True:
tokens: list[list[int]] = []
masks: list[list[bool]] = []
patch_lengths: list[list[int]] = []
for _ in range(self.packing_args.batch_size):
sequence = next(sequence_iter)
_tokens = sequence.tokens
_mask = sequence.mask
_patch_lengths = sequence.patch_lengths
assert (
_patch_lengths is not None
), "patch lengths are required for packing based on patches."
# Reminder: seq_len is in terms of patches
assert len(sequence.patch_lengths) == self.packing_args.seq_len
last_patch_length = 0
if _patch_lengths[0] > 1:
last_patch_length = _patch_lengths[-1]
_patch_lengths[0] -= 1
_patch_lengths = [1] + _patch_lengths[:-1]
tokens.append(_tokens[: len(_tokens) - last_patch_length])
masks.append(_mask[: len(_mask) - last_patch_length])
patch_lengths.append(_patch_lengths)
try:
for _ in range(self.packing_args.batch_size):
sequence = next(sequence_iter)
_tokens = sequence.tokens
_mask = sequence.mask
_patch_lengths = sequence.patch_lengths
assert (
_patch_lengths is not None
), "patch lengths are required for packing based on patches."
# Reminder: seq_len is in terms of patches
assert len(sequence.patch_lengths) == self.packing_args.seq_len
last_patch_length = 0
if _patch_lengths[0] > 1:
last_patch_length = _patch_lengths[-1]
_patch_lengths[0] -= 1
_patch_lengths = [1] + _patch_lengths[:-1]
tokens.append(_tokens[: len(_tokens) - last_patch_length])
masks.append(_mask[: len(_mask) - last_patch_length])
patch_lengths.append(_patch_lengths)
except StopIteration:
stop_iteration = True
x_patch_lengths = np.array(patch_lengths)
# pad batch to same length
@ -249,6 +271,10 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
ngram_ids=ngram_ids,
mask=_merge_patch_seq_masks(batch_size, tok_seq_len, masks),
)
tokens = []
masks = []
patch_lengths = []
assert (
x_patch_lengths.sum() == x.size + batch_size
), f"{x_patch_lengths.sum()} != {x.size + batch_size}"
@ -269,3 +295,5 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
enable_byte_ngrams=enable_byte_ngrams,
)
yield batch
if stop_iteration:
break