Mirror of https://github.com/facebookresearch/blt.git
Pass mask in packing_iterator, correctly handle last batch, fix masking
This commit does/fixes the following:

1. Adds unit tests for byte and patch packing to ensure both work correctly.
2. Fixes a bug where, for batches that end up with fewer than max_length bytes (e.g., short patches), the mask included elements whose value was pad_id. The mask is now set to `!= pad_id` when it is not specified.
3. Correctly handles the last batch, where previously the iterator would crash. This didn't affect training, since we had enough data and/or looped iterators, but it comes up for evaluation perplexity when we validate on an entire file.
4. Correctly forwards the mask, if it exists, for byte packing.

Test Plan:

```
pytest bytelatent/
```

These changes are tested more thoroughly in a stacked PR that fixes evals.
This commit is contained in:
parent 2655e4cf82
commit 203bff3696
bytelatent/data/iterators/packing_iterator.py

```diff
@@ -119,12 +119,18 @@ def truncate_batch(
         y = np.full((batch.y.shape[0], max_length), pad_id, dtype=batch.y.dtype)
         y[:, : batch.y.shape[1]] = batch.y
         batch.y = y
-    if batch.mask is not None and batch.mask.shape[1] < max_length:
-        mask = np.full(
-            (batch.mask.shape[0], max_length), False, dtype=batch.mask.dtype
-        )
-        mask[:, : batch.mask.shape[1]] = batch.mask
-        batch.mask = mask
+    if batch.mask is None:
+        mask = batch.x != pad_id
+        # Only set the mask if it's actually doing anything
+        if mask.sum() != batch.x.size:
+            batch.mask = mask
+    else:
+        if batch.mask.shape[1] < max_length:
+            mask = np.full(
+                (batch.mask.shape[0], max_length), False, dtype=batch.mask.dtype
+            )
+            mask[:, : batch.mask.shape[1]] = batch.mask
+            batch.mask = mask

     assert batch.x.shape[1] <= max_length
     assert batch.y.shape[1] <= max_length
```
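The new branch gives the mask a clear default: if no mask was provided, every non-pad position counts as real data, and the mask is only materialized when it would actually exclude something. A minimal standalone sketch of that rule (the array `x` and its values here are illustrative, not the repo's `Batch` type):

```python
import numpy as np

pad_id = 0
# Hypothetical short batch: row 1 is right-padded with pad_id.
x = np.array([[5, 6, 7, 8],
              [9, 3, 0, 0]])

mask = x != pad_id          # default: every non-pad position is real data
if mask.sum() != x.size:    # only keep the mask if it excludes anything
    print(mask)
# [[ True  True  True  True]
#  [ True  True False False]]
```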
```diff
@@ -173,31 +179,46 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
         batch_size = self.packing_args.batch_size
         pad_id = self.packing_args.pad_id
         seq_len = self.packing_args.seq_len
+        stop_iteration = False
+        tokens: list[list[int]] = []
+        masks: list[list[bool]] = []
         while True:
-            tokens: list[list[int]] = []
-            masks: list[list[bool]] = []
-            for _ in range(self.packing_args.batch_size):
-                sequence = next(sequence_iter)
-                _tokens = sequence.tokens
-                _mask = sequence.mask
-                assert (
-                    sequence.patch_lengths is None
-                ), "patch_lengths should not be used in byte packing"
-                tokens.append(_tokens)
-                masks.append(_mask)
+            try:
+                for _ in range(self.packing_args.batch_size):
+                    sequence = next(sequence_iter)
+                    _tokens = sequence.tokens
+                    _mask = sequence.mask
+                    assert (
+                        sequence.patch_lengths is None
+                    ), "patch_lengths should not be used in byte packing"
+                    tokens.append(_tokens)
+                    masks.append(_mask)
+            except StopIteration:
+                # At this point, there will be no new sequences, so we need to stop
+                # after yielding the already accumulated data (one batch).
+                # In this case, either:
+                # 1. We have a complete batch, so things go as normal
+                # 2. We have an incomplete batch, but due to creating a right sized batch,
+                # then filling the values in, this case is automatically handled.
+                stop_iteration = True

             x = np.full((batch_size, seq_len), fill_value=pad_id)
             y = np.full((batch_size, seq_len), fill_value=pad_id)
+            m = np.zeros((batch_size, seq_len), dtype=np.bool)

             for i, tok_seq in enumerate(tokens):
                 x[i, : len(tok_seq)] = tok_seq
                 y[i, : len(tok_seq) - 1] = tok_seq[1:]
-            batch = Batch(x=x, y=y)
+                m[i, : len(tok_seq)] = masks[i]
+
+            batch = Batch(x=x, y=y, mask=m)
             assert (
                 batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
             ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
+            tokens = []
+            masks = []
             yield batch
+            if stop_iteration:
+                break

     def _create_iter_from_patch_lengths(self):
         sequence_iter = self.sequence_iterator.create_iter()
```
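Structurally, this is the generic accumulate-then-flush pattern: keep pulling from the upstream iterator, and when StopIteration fires mid-batch, still emit whatever has accumulated before breaking out of the loop. A self-contained sketch of the same control flow (the `batched` helper and plain-integer payload are hypothetical, not repo code):

```python
from typing import Iterable, Iterator

def batched(items: Iterable[int], batch_size: int, pad: int = 0) -> Iterator[list[int]]:
    it = iter(items)
    stop_iteration = False
    buf: list[int] = []
    while True:
        try:
            for _ in range(batch_size):
                buf.append(next(it))
        except StopIteration:
            # Upstream is exhausted: flush the partial buffer, padded to
            # full size, then stop after this final yield.
            stop_iteration = True
        yield buf + [pad] * (batch_size - len(buf))
        buf = []
        if stop_iteration:
            break

# 10 items at batch size 4 -> 3 batches, the last one right-padded.
assert list(batched(range(1, 11), 4)) == [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 0, 0]]
```

Note that exhaustion is only discovered on the next pull, so an exactly divisible stream still yields one trailing all-pad batch; the byte-packing tests below expect 4 batches for 12 sequences at batch size 4 for exactly this reason.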
```diff
@@ -207,29 +228,36 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
         pad_to_max_length = self.packing_args.pad_to_max_length
         enable_byte_ngrams = self.packing_args.enable_byte_ngrams
         max_length = self.packing_args.max_length
+        assert max_length is not None
+        tokens: list[list[int]] = []
+        masks: list[list[bool]] = []
+        patch_lengths: list[list[int]] = []
+        stop_iteration = False
         while True:
-            tokens: list[list[int]] = []
-            masks: list[list[bool]] = []
-            patch_lengths: list[list[int]] = []
-            for _ in range(self.packing_args.batch_size):
-                sequence = next(sequence_iter)
-                _tokens = sequence.tokens
-                _mask = sequence.mask
-                _patch_lengths = sequence.patch_lengths
-                assert (
-                    _patch_lengths is not None
-                ), "patch lengths are required for packing based on patches."
-                # Reminder: seq_len is in terms of patches
-                assert len(sequence.patch_lengths) == self.packing_args.seq_len
-                last_patch_length = 0
-                if _patch_lengths[0] > 1:
-                    last_patch_length = _patch_lengths[-1]
-                    _patch_lengths[0] -= 1
-                    _patch_lengths = [1] + _patch_lengths[:-1]
-                tokens.append(_tokens[: len(_tokens) - last_patch_length])
-                masks.append(_mask[: len(_mask) - last_patch_length])
-                patch_lengths.append(_patch_lengths)
+            try:
+                for _ in range(self.packing_args.batch_size):
+                    sequence = next(sequence_iter)
+                    _tokens = sequence.tokens
+                    _mask = sequence.mask
+                    _patch_lengths = sequence.patch_lengths
+                    assert (
+                        _patch_lengths is not None
+                    ), "patch lengths are required for packing based on patches."
+                    # Reminder: seq_len is in terms of patches
+                    assert len(sequence.patch_lengths) == self.packing_args.seq_len
+                    last_patch_length = 0
+                    if _patch_lengths[0] > 1:
+                        last_patch_length = _patch_lengths[-1]
+                        _patch_lengths[0] -= 1
+                        _patch_lengths = [1] + _patch_lengths[:-1]
+                    tokens.append(_tokens[: len(_tokens) - last_patch_length])
+                    masks.append(_mask[: len(_mask) - last_patch_length])
+                    patch_lengths.append(_patch_lengths)
+            except StopIteration:
+                stop_iteration = True
+            if len(tokens) == 0 and stop_iteration:
+                break

             x_patch_lengths = np.array(patch_lengths)
             # pad batch to same length
```
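The length shift in the loop above enforces the invariant that the first patch is exactly one byte (the new tests note this as "First patch length is forced to be 1"): one byte is carved off the old first patch, a leading 1 is prepended while the final patch is dropped from the list, and the dropped patch's bytes are trimmed from the token and mask sequences so the totals stay consistent. A hypothetical worked example of the same arithmetic:

```python
# Illustrative values; not taken from the repo.
patch_lengths = [3, 255, 256, 256, 254]  # 5 patches, 1024 bytes total
tokens = list(range(1024))

last_patch_length = 0
if patch_lengths[0] > 1:
    last_patch_length = patch_lengths[-1]           # 254 bytes to trim
    patch_lengths[0] -= 1                           # [2, 255, 256, 256, 254]
    patch_lengths = [1] + patch_lengths[:-1]        # [1, 2, 255, 256, 256]
tokens = tokens[: len(tokens) - last_patch_length]  # keep 1024 - 254 = 770

# First patch is now exactly 1 byte, and byte counts still line up.
assert patch_lengths[0] == 1
assert sum(patch_lengths) == len(tokens) == 770
```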
```diff
@@ -257,6 +285,10 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                 ngram_ids=ngram_ids,
                 mask=_merge_patch_seq_masks(batch_size, tok_seq_len, masks),
             )
+            tokens = []
+            masks = []
+            patch_lengths = []
+
             assert (
                 x_patch_lengths.sum() == x.size + batch_size
             ), f"{x_patch_lengths.sum()} != {x.size + batch_size}"
```
```diff
@@ -277,3 +309,5 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                 enable_byte_ngrams=enable_byte_ngrams,
             )
             yield batch
+            if stop_iteration:
+                break
```
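With the last batch handled instead of crashing, the number of emitted batches becomes predictable from n_seqs and batch_size. A quick sanity check matching what the new tests assert (the helper names here are illustrative, not repo code):

```python
import math

def n_byte_batches(n_seqs: int, batch_size: int) -> int:
    # Byte packing discovers exhaustion one pull late, so it always emits
    # one final batch that is partial or entirely padding.
    return n_seqs // batch_size + 1

def n_patch_batches(n_seqs: int, batch_size: int) -> int:
    # Patch packing breaks immediately once the buffer is empty at exhaustion.
    return math.ceil(n_seqs / batch_size)

assert n_byte_batches(10, 4) == 3    # test_last_batch_correctness_bytes
assert n_byte_batches(10, 12) == 1   # test_edgecase_batch_correctness_bytes
assert n_byte_batches(12, 4) == 4    # test_exact_batch_correctness_bytes
assert n_patch_batches(12, 4) == 3   # the patch-based tests
```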
bytelatent/data/iterators/test_packing_iterator.py (new file, 285 lines)
```python
import numpy as np

from bytelatent.data.data_types import BltSequence
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
from bytelatent.data.iterators.packing_iterator import (
    PackingArgs,
    PackingIterator,
    PackingMode,
)


class DummySequenceIterator(StatefulIterator):
    def __init__(
        self,
        *,
        seq_len: int,
        n_seqs: int,
        patch_lengths: list[int] | None = None,
        pad_id: int = 0,
    ):
        self.seq_len = seq_len
        self.n_seqs = n_seqs
        self.patch_lengths = patch_lengths
        self.pad_id = pad_id

    def get_state(self):
        raise NotImplementedError()

    def create_iter(self):
        for i in range(self.n_seqs):
            if self.patch_lengths is None:
                tokens = np.arange(
                    i * self.seq_len + 1, (i + 1) * self.seq_len + 1
                ).tolist()
                mask = [True] * self.seq_len  # type: ignore
                assert len(tokens) == self.seq_len
            else:
                n = sum(self.patch_lengths)
                tokens = np.arange(i * n + 1, (i + 1) * n + 1).tolist()
                assert len(tokens) == n
                mask = [True] * n
            assert len(mask) == len(tokens)
            yield BltSequence(
                tokens=tokens,
                mask=mask,
                patch_lengths=self.patch_lengths,
            )


def create_bytes_iter(*, seq_len: int, n_seqs: int, batch_size: int, pad_id: int):
    sequence_iterator = DummySequenceIterator(seq_len=seq_len, n_seqs=n_seqs)
    packing_iterator = PackingIterator(
        sequence_iterator,
        packing_args=PackingArgs(
            batch_size=batch_size,
            seq_len=seq_len,
            pad_id=pad_id,
            packing_mode=PackingMode.BYTES,
            max_length=None,
            pad_to_max_length=False,
            enable_byte_ngrams=False,
        ),
    )
    return packing_iterator.create_iter()


def create_patches_iter(
    *,
    seq_len: int,
    n_seqs: int,
    batch_size: int,
    pad_id: int,
    patch_lengths: list[int] | None,
    max_length: int,
):
    sequence_iterator = DummySequenceIterator(
        # seq_len=number of bytes, which for blt/patches, is max_length since seq_len is
        # in terms of number of patches
        seq_len=max_length,
        n_seqs=n_seqs,
        patch_lengths=patch_lengths,
    )
    packing_iterator = PackingIterator(
        sequence_iterator,
        packing_args=PackingArgs(
            batch_size=batch_size,
            seq_len=seq_len,
            pad_id=pad_id,
            packing_mode=PackingMode.PATCHING,
            max_length=max_length,
            pad_to_max_length=True,
            enable_byte_ngrams=False,
        ),
    )
    return packing_iterator.create_iter()


def test_last_batch_correctness_bytes():
    seq_len = 1024
    n_seqs = 10
    batch_size = 4
    pad_id = 0
    iterator = create_bytes_iter(
        seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id
    )
    batches = []
    n_nonpad = 0
    n_nonmask = 0
    for b in iterator:
        assert b.x.shape[0] == batch_size
        assert b.x.shape[1] == seq_len
        n_nonpad += (b.x != pad_id).sum()
        if b.mask is None:
            n_nonmask += b.x.size
        else:
            n_nonmask += b.mask.sum()
        batches.append(b)
    assert len(batches) == 3
    assert n_nonpad == n_nonmask == seq_len * n_seqs
    # The second half of the last batch should be all pads
    assert batches[-1].mask[2:].sum() == 0


def test_edgecase_batch_correctness_bytes():
    seq_len = 1024
    n_seqs = 10
    batch_size = 12
    pad_id = 0
    iterator = create_bytes_iter(
        seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id
    )
    batches = []
    n_nonpad = 0
    n_nonmask = 0
    for b in iterator:
        assert b.x.shape[0] == batch_size
        assert b.x.shape[1] == seq_len
        n_nonpad += (b.x != pad_id).sum()
        if b.mask is None:
            n_nonmask += b.x.size
        else:
            n_nonmask += b.mask.sum()
        batches.append(b)
    assert len(batches) == 1
    assert n_nonpad == n_nonmask == seq_len * n_seqs
    # Rows beyond the n_seqs real sequences should be all pads
    assert batches[0].mask[10:].sum() == 0


def test_exact_batch_correctness_bytes():
    seq_len = 1024
    n_seqs = 12
    batch_size = 4
    pad_id = 0
    iterator = create_bytes_iter(
        seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id
    )
    batches = []
    n_nonpad = 0
    n_nonmask = 0
    for b in iterator:
        assert b.x.shape[0] == batch_size
        assert b.x.shape[1] == seq_len
        n_nonpad += (b.x != pad_id).sum()
        if b.mask is None:
            n_nonmask += b.x.size
        else:
            n_nonmask += b.mask.sum()
        batches.append(b)
    assert len(batches) == 4
    assert n_nonpad == n_nonmask == seq_len * n_seqs


def test_exact_batch_correctness_patches():
    # First patch length is forced to be 1
    patch_lengths = [1, 255, 256, 256, 256]
    # Recall: This is in terms of bytes
    max_length = 1024
    # Recall: This is in terms of patches
    seq_len = 5
    n_seqs = 12
    batch_size = 4
    pad_id = 0
    iterator = create_patches_iter(
        seq_len=seq_len,
        n_seqs=n_seqs,
        batch_size=batch_size,
        pad_id=pad_id,
        patch_lengths=patch_lengths,
        max_length=max_length,
    )
    batches = []
    n_nonpad = 0
    n_nonmask = 0
    for batch in iterator:
        assert batch.x.shape[0] == batch_size
        assert batch.x.shape[1] == max_length
        n_nonpad += (batch.x != pad_id).sum()
        if batch.mask is None:
            n_nonmask += batch.x.size
        else:
            n_nonmask += batch.mask.sum()
        batches.append(batch)

    assert len(batches) == 3

    # max_length - 1 is due to chopping off the last byte for
    # having a y target
    assert n_nonpad == n_nonmask == (max_length - 1) * n_seqs


def test_short_batch_correctness_patches():
    # First patch length is forced to be 1
    # Total=48
    patch_lengths = [1, 11, 12, 12, 12]
    # Recall: This is in terms of bytes
    max_length = 1024
    # Recall: This is in terms of patches
    seq_len = 5
    n_seqs = 12
    batch_size = 4
    pad_id = 0
    iterator = create_patches_iter(
        seq_len=seq_len,
        n_seqs=n_seqs,
        batch_size=batch_size,
        pad_id=pad_id,
        patch_lengths=patch_lengths,
        max_length=max_length,
    )
    batches = []
    n_nonpad = 0
    n_nonmask = 0
    for batch in iterator:
        assert batch.x.shape[0] == batch_size
        assert batch.x.shape[1] == max_length
        n_nonpad += (batch.x != pad_id).sum()
        if batch.mask is None:
            n_nonmask += batch.x.size
        else:
            n_nonmask += batch.mask.sum()
        batches.append(batch)

    assert len(batches) == 3

    # We'll still always have one byte chopped off the end
    assert n_nonpad == n_nonmask == ((sum(patch_lengths) - 1) * n_seqs)


def test_long_batch_correctness_patches():
    # First patch length is forced to be 1
    # Total=1792
    patch_lengths = [1, 255, 256, 256, 1024]
    # Recall: This is in terms of bytes
    max_length = 1024
    # Recall: This is in terms of patches
    seq_len = 5
    n_seqs = 12
    batch_size = 4
    pad_id = 0
    iterator = create_patches_iter(
        seq_len=seq_len,
        n_seqs=n_seqs,
        batch_size=batch_size,
        pad_id=pad_id,
        patch_lengths=patch_lengths,
        max_length=max_length,
    )
    batches = []
    n_nonpad = 0
    n_nonmask = 0
    for batch in iterator:
        assert batch.x.shape[0] == batch_size
        assert batch.x.shape[1] == max_length
        n_nonpad += (batch.x != pad_id).sum()
        if batch.mask is None:
            n_nonmask += batch.x.size
        else:
            n_nonmask += batch.mask.sum()
        batches.append(batch)

    assert len(batches) == 3

    # No chop here since the next byte is available
    assert n_nonpad == n_nonmask == max_length * n_seqs
```
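A note on why the non-pad counts in these tests are exact: DummySequenceIterator hands out consecutive integers starting at 1, so no real token ever collides with pad_id = 0, and `(b.x != pad_id).sum()` counts exactly the real bytes. A standalone illustration of that property, with made-up sizes:

```python
import numpy as np

seq_len, pad_id = 8, 0
for i in range(3):
    # Same scheme as DummySequenceIterator: sequence i covers
    # [i * seq_len + 1, (i + 1) * seq_len].
    tokens = np.arange(i * seq_len + 1, (i + 1) * seq_len + 1)
    assert (tokens != pad_id).all()  # no collision with the pad value
```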