blt/bytelatent/data/iterators/test_packing_iterator.py

286 lines
8 KiB
Python
Raw Permalink Normal View History

import numpy as np
from bytelatent.data.data_types import BltSequence
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
from bytelatent.data.iterators.packing_iterator import (
PackingArgs,
PackingIterator,
PackingMode,
)
class DummySequenceIterator(StatefulIterator):
    """Deterministic fake sequence source for PackingIterator tests.

    Yields ``n_seqs`` BltSequence objects. Sequence ``k`` holds the
    consecutive integers ``k * width + 1 .. (k + 1) * width``. Tokens
    therefore start at 1, and none of them equals 0, the pad id used by
    the tests. Every mask entry is True.

    ``width`` is ``seq_len`` when ``patch_lengths`` is None. Otherwise it is
    ``sum(patch_lengths)`` and ``seq_len`` is ignored. The same
    ``patch_lengths`` list is attached to every yielded sequence.
    """

    def __init__(
        self,
        *,
        seq_len: int,
        n_seqs: int,
        patch_lengths: list[int] | None = None,
        pad_id: int = 0,
    ):
        self.seq_len = seq_len
        self.n_seqs = n_seqs
        self.patch_lengths = patch_lengths
        # Kept for signature parity; create_iter never reads it.
        self.pad_id = pad_id

    def get_state(self):
        # State checkpointing is not supported by this dummy.
        raise NotImplementedError()

    def create_iter(self):
        # Number of tokens in every emitted sequence.
        if self.patch_lengths is None:
            width = self.seq_len
        else:
            width = sum(self.patch_lengths)

        for seq_idx in range(self.n_seqs):
            first = seq_idx * width + 1
            tokens = np.arange(first, first + width).tolist()
            assert len(tokens) == width
            mask = [True] * width
            assert len(mask) == len(tokens)
            yield BltSequence(
                tokens=tokens,
                mask=mask,
                patch_lengths=self.patch_lengths,
            )
def create_bytes_iter(*, seq_len: int, n_seqs: int, batch_size: int, pad_id: int):
    """Build a byte-mode packing iteration over ``n_seqs`` dummy sequences.

    Each dummy sequence has ``seq_len`` tokens. The packer runs in
    ``PackingMode.BYTES`` with no ``max_length``, no padding to max length,
    and byte n-grams disabled. Returns the generator from ``create_iter()``.
    """
    source = DummySequenceIterator(seq_len=seq_len, n_seqs=n_seqs)
    bytes_args = PackingArgs(
        batch_size=batch_size,
        seq_len=seq_len,
        pad_id=pad_id,
        packing_mode=PackingMode.BYTES,
        max_length=None,
        pad_to_max_length=False,
        enable_byte_ngrams=False,
    )
    packer = PackingIterator(source, packing_args=bytes_args)
    return packer.create_iter()
def create_patches_iter(
    *,
    seq_len: int,
    n_seqs: int,
    batch_size: int,
    pad_id: int,
    patch_lengths: list[int] | None,
    max_length: int,
):
    """Build a patch-mode packing iteration over ``n_seqs`` dummy sequences.

    ``seq_len`` counts patches, while ``max_length`` counts bytes. The
    dummy source is therefore given ``max_length`` as its byte length.
    That length only takes effect when ``patch_lengths`` is None; otherwise
    the source uses ``sum(patch_lengths)``. The packer runs in
    ``PackingMode.PATCHING`` with padding to ``max_length`` and byte n-grams
    disabled. Returns the generator from ``create_iter()``.
    """
    source = DummySequenceIterator(
        seq_len=max_length,
        n_seqs=n_seqs,
        patch_lengths=patch_lengths,
    )
    patch_args = PackingArgs(
        batch_size=batch_size,
        seq_len=seq_len,
        pad_id=pad_id,
        packing_mode=PackingMode.PATCHING,
        max_length=max_length,
        pad_to_max_length=True,
        enable_byte_ngrams=False,
    )
    packer = PackingIterator(source, packing_args=patch_args)
    return packer.create_iter()
def test_last_batch_correctness_bytes():
    """A partially filled final batch is padded out to ``batch_size`` rows.

    10 sequences with a batch size of 4 give two full batches plus a final
    batch with 2 real sequences. The remaining rows of that final batch must
    have an all-zero mask.
    """
    seq_len = 1024
    n_seqs = 10
    batch_size = 4
    pad_id = 0
    iterator = create_bytes_iter(
        seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id
    )
    batches = []
    n_nonpad = 0
    n_nonmask = 0
    for b in iterator:
        assert b.x.shape[0] == batch_size
        assert b.x.shape[1] == seq_len
        n_nonpad += (b.x != pad_id).sum()
        if b.mask is None:
            n_nonmask += b.x.size
        else:
            n_nonmask += b.mask.sum()
        batches.append(b)
    # ceil(10 / 4) == 3 batches
    assert len(batches) == 3
    assert n_nonpad == n_nonmask == seq_len * n_seqs
    # Only n_seqs % batch_size sequences land in the final batch. Derive that
    # count from the parameters instead of hard-coding it, so the check stays
    # correct if n_seqs or batch_size change.
    n_real_in_last = n_seqs % batch_size
    assert batches[-1].mask[n_real_in_last:].sum() == 0
def test_edgecase_batch_correctness_bytes():
    """Fewer sequences than one batch still produce a single padded batch.

    10 sequences with a batch size of 12 fit in one batch. Rows 10 and 11
    carry no real data and must have an all-zero mask.
    """
    seq_len = 1024
    n_seqs = 10
    batch_size = 12
    pad_id = 0
    iterator = create_bytes_iter(
        seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id
    )
    batches = []
    n_nonpad = 0
    n_nonmask = 0
    for b in iterator:
        assert b.x.shape[0] == batch_size
        assert b.x.shape[1] == seq_len
        n_nonpad += (b.x != pad_id).sum()
        if b.mask is None:
            n_nonmask += b.x.size
        else:
            n_nonmask += b.mask.sum()
        batches.append(b)
    assert len(batches) == 1
    assert n_nonpad == n_nonmask == seq_len * n_seqs
    # Every row past the n_seqs real sequences (rows 10..11 here) is padding.
    assert batches[0].mask[n_seqs:].sum() == 0
def test_exact_batch_correctness_bytes():
    """n_seqs divisible by batch_size: every batch is full, nothing is padded."""
    seq_len = 1024
    n_seqs = 12
    batch_size = 4
    pad_id = 0
    iterator = create_bytes_iter(
        seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id
    )

    n_batches = 0
    total_nonpad = 0
    total_unmasked = 0
    for batch in iterator:
        assert tuple(batch.x.shape[:2]) == (batch_size, seq_len)
        total_nonpad += (batch.x != pad_id).sum()
        # A missing mask means every position counts as unmasked.
        total_unmasked += batch.x.size if batch.mask is None else batch.mask.sum()
        n_batches += 1

    assert n_batches == 4
    assert total_nonpad == total_unmasked == seq_len * n_seqs
def test_exact_batch_correctness_patches():
    """Patch lengths summing exactly to max_length fill each row minus one byte."""
    # The first patch length is forced to be 1; the lengths sum to 1024.
    patch_lengths = [1, 255, 256, 256, 256]
    max_length = 1024  # in bytes
    seq_len = 5  # in patches
    n_seqs = 12
    batch_size = 4
    pad_id = 0
    iterator = create_patches_iter(
        seq_len=seq_len,
        n_seqs=n_seqs,
        batch_size=batch_size,
        pad_id=pad_id,
        patch_lengths=patch_lengths,
        max_length=max_length,
    )

    n_batches = 0
    total_nonpad = 0
    total_unmasked = 0
    for batch in iterator:
        assert tuple(batch.x.shape[:2]) == (batch_size, max_length)
        total_nonpad += (batch.x != pad_id).sum()
        total_unmasked += batch.x.size if batch.mask is None else batch.mask.sum()
        n_batches += 1

    assert n_batches == 3
    # The last byte of each sequence is chopped off because it has no
    # following byte to serve as its y target.
    expected = (max_length - 1) * n_seqs
    assert total_nonpad == total_unmasked == expected
def test_short_batch_correctness_patches():
    """Sequences much shorter than max_length are padded out to max_length."""
    # The first patch length is forced to be 1; the lengths sum to 48.
    patch_lengths = [1, 11, 12, 12, 12]
    max_length = 1024  # in bytes
    seq_len = 5  # in patches
    n_seqs = 12
    batch_size = 4
    pad_id = 0
    iterator = create_patches_iter(
        seq_len=seq_len,
        n_seqs=n_seqs,
        batch_size=batch_size,
        pad_id=pad_id,
        patch_lengths=patch_lengths,
        max_length=max_length,
    )

    n_batches = 0
    total_nonpad = 0
    total_unmasked = 0
    for batch in iterator:
        assert tuple(batch.x.shape[:2]) == (batch_size, max_length)
        total_nonpad += (batch.x != pad_id).sum()
        total_unmasked += batch.x.size if batch.mask is None else batch.mask.sum()
        n_batches += 1

    assert n_batches == 3
    # One byte is still chopped off the end of every sequence.
    expected = (sum(patch_lengths) - 1) * n_seqs
    assert total_nonpad == total_unmasked == expected
def test_long_batch_correctness_patches():
    """Sequences longer than max_length fill their whole row with real data.

    Each sequence has more bytes than fit in one row of width max_length.
    Every row is therefore fully non-pad and fully unmasked.
    """
    # First patch length is forced to be 1
    # Total=1792, which exceeds max_length (1024)
    patch_lengths = [1, 255, 256, 256, 1024]
    # Recall: This is in terms of bytes
    max_length = 1024
    # Recall: This is in terms of patches
    seq_len = 5
    n_seqs = 12
    batch_size = 4
    pad_id = 0
    # Guard this test's premise: the raw sequence must overflow max_length,
    # otherwise the "no chop" expectation below does not apply.
    assert sum(patch_lengths) > max_length
    iterator = create_patches_iter(
        seq_len=seq_len,
        n_seqs=n_seqs,
        batch_size=batch_size,
        pad_id=pad_id,
        patch_lengths=patch_lengths,
        max_length=max_length,
    )
    batches = []
    n_nonpad = 0
    n_nonmask = 0
    for batch in iterator:
        assert batch.x.shape[0] == batch_size
        assert batch.x.shape[1] == max_length
        n_nonpad += (batch.x != pad_id).sum()
        if batch.mask is None:
            n_nonmask += batch.x.size
        else:
            n_nonmask += batch.mask.sum()
        batches.append(batch)
    assert len(batches) == 3
    # No chop here since the next byte is available
    assert n_nonpad == n_nonmask == max_length * n_seqs