"""Tests for batching behaviour of ``PackingIterator`` on synthetic sequences."""

import numpy as np

from bytelatent.data.data_types import BltSequence
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
from bytelatent.data.iterators.packing_iterator import (
    PackingArgs,
    PackingIterator,
    PackingMode,
)


class DummySequenceIterator(StatefulIterator):
    """Yield ``n_seqs`` synthetic ``BltSequence`` objects with fully-true masks.

    Token ids are consecutive integers starting at 1 and continuing across
    sequences, so with the default ``pad_id`` of 0 no generated token equals
    the pad id.
    """

    def __init__(
        self,
        *,
        seq_len: int,
        n_seqs: int,
        patch_lengths: list[int] | None = None,
        pad_id: int = 0,
    ):
        """Store the generation parameters.

        Args:
            seq_len: Number of tokens per sequence when ``patch_lengths`` is
                ``None``; ignored otherwise.
            n_seqs: Number of sequences to yield.
            patch_lengths: Optional patch lengths attached to every sequence.
                When given, each sequence has ``sum(patch_lengths)`` tokens.
            pad_id: Stored on the instance; not used when generating tokens.
        """
        self.seq_len = seq_len
        self.n_seqs = n_seqs
        self.patch_lengths = patch_lengths
        self.pad_id = pad_id

    def get_state(self):
        """State capture is not supported for this dummy iterator."""
        raise NotImplementedError()

    def create_iter(self):
        """Generate the synthetic sequences one at a time."""
        # The sequence length is the only thing that differs between the
        # "bytes" (no patch lengths) and "patches" configurations.
        if self.patch_lengths is None:
            n = self.seq_len
        else:
            n = sum(self.patch_lengths)

        for i in range(self.n_seqs):
            # Sequence i covers ids [i * n + 1, (i + 1) * n].
            tokens = np.arange(i * n + 1, (i + 1) * n + 1).tolist()
            mask = [True] * n  # type: ignore
            assert len(tokens) == n
            assert len(mask) == len(tokens)
            yield BltSequence(
                tokens=tokens,
                mask=mask,
                patch_lengths=self.patch_lengths,
            )


def create_bytes_iter(*, seq_len: int, n_seqs: int, batch_size: int, pad_id: int):
    """Build a batch iterator over dummy sequences using ``PackingMode.BYTES``.

    Args:
        seq_len: Token count of each dummy sequence and the packing ``seq_len``.
        n_seqs: Number of dummy sequences to generate.
        batch_size: Batch size handed to ``PackingArgs``.
        pad_id: Pad id handed to ``PackingArgs`` and to the dummy iterator.

    Returns:
        The iterator produced by ``PackingIterator.create_iter()``.
    """
    sequence_iterator = DummySequenceIterator(
        seq_len=seq_len,
        n_seqs=n_seqs,
        pad_id=pad_id,
    )
    packing_iterator = PackingIterator(
        sequence_iterator,
        packing_args=PackingArgs(
            batch_size=batch_size,
            seq_len=seq_len,
            pad_id=pad_id,
            packing_mode=PackingMode.BYTES,
            max_length=None,
            pad_to_max_length=False,
            enable_byte_ngrams=False,
        ),
    )
    return packing_iterator.create_iter()


def create_patches_iter(
    *,
    seq_len: int,
    n_seqs: int,
    batch_size: int,
    pad_id: int,
    patch_lengths: list[int] | None,
    max_length: int,
):
    """Build a batch iterator over dummy sequences using ``PackingMode.PATCHING``.

    Args:
        seq_len: Packing ``seq_len``; in patching mode this is in terms of
            number of patches.
        n_seqs: Number of dummy sequences to generate.
        batch_size: Batch size handed to ``PackingArgs``.
        pad_id: Pad id handed to ``PackingArgs`` and to the dummy iterator.
        patch_lengths: Patch lengths attached to every dummy sequence.
        max_length: Maximum length in bytes; batches are padded to it
            (``pad_to_max_length=True``).

    Returns:
        The iterator produced by ``PackingIterator.create_iter()``.
    """
    sequence_iterator = DummySequenceIterator(
        # seq_len=number of bytes, which for blt/patches, is max_length since
        # seq_len is in terms of number of patches
        seq_len=max_length,
        n_seqs=n_seqs,
        patch_lengths=patch_lengths,
        pad_id=pad_id,
    )
    packing_iterator = PackingIterator(
        sequence_iterator,
        packing_args=PackingArgs(
            batch_size=batch_size,
            seq_len=seq_len,
            pad_id=pad_id,
            packing_mode=PackingMode.PATCHING,
            max_length=max_length,
            pad_to_max_length=True,
            enable_byte_ngrams=False,
        ),
    )
    return packing_iterator.create_iter()


def _consume_batches(iterator, *, batch_size: int, row_len: int, pad_id: int):
    """Drain ``iterator``, checking each batch's shape and tallying counts.

    Args:
        iterator: Iterable of batches exposing ``x`` and ``mask`` attributes.
        batch_size: Expected size of the first dimension of every ``batch.x``.
        row_len: Expected size of the second dimension of every ``batch.x``.
        pad_id: Token id counted as padding.

    Returns:
        A ``(batches, n_nonpad, n_nonmask)`` tuple: the list of batches seen,
        the number of entries of ``x`` that differ from ``pad_id``, and the
        sum of the mask entries (a missing mask counts every entry of ``x``).
    """
    batches = []
    n_nonpad = 0
    n_nonmask = 0
    for batch in iterator:
        assert batch.x.shape[0] == batch_size
        assert batch.x.shape[1] == row_len
        n_nonpad += (batch.x != pad_id).sum()
        if batch.mask is None:
            # No mask means every position counts as unmasked.
            n_nonmask += batch.x.size
        else:
            n_nonmask += batch.mask.sum()
        batches.append(batch)
    return batches, n_nonpad, n_nonmask


def test_last_batch_correctness_bytes():
    seq_len = 1024
    n_seqs = 10
    batch_size = 4
    pad_id = 0
    iterator = create_bytes_iter(
        seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id
    )
    batches, n_nonpad, n_nonmask = _consume_batches(
        iterator, batch_size=batch_size, row_len=seq_len, pad_id=pad_id
    )
    assert len(batches) == 3
    assert n_nonpad == n_nonmask == seq_len * n_seqs
    # The second half of the last batch should be all pads
    assert batches[-1].mask[2:].sum() == 0


def test_edgecase_batch_correctness_bytes():
    seq_len = 1024
    n_seqs = 10
    batch_size = 12
    pad_id = 0
    iterator = create_bytes_iter(
        seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id
    )
    batches, n_nonpad, n_nonmask = _consume_batches(
        iterator, batch_size=batch_size, row_len=seq_len, pad_id=pad_id
    )
    assert len(batches) == 1
    assert n_nonpad == n_nonmask == seq_len * n_seqs
    # Only 10 of the 12 rows hold data, so the last two rows should be all pads
    assert batches[0].mask[10:].sum() == 0


def test_exact_batch_correctness_bytes():
    seq_len = 1024
    n_seqs = 12
    batch_size = 4
    pad_id = 0
    iterator = create_bytes_iter(
        seq_len=seq_len, n_seqs=n_seqs, batch_size=batch_size, pad_id=pad_id
    )
    batches, n_nonpad, n_nonmask = _consume_batches(
        iterator, batch_size=batch_size, row_len=seq_len, pad_id=pad_id
    )
    assert len(batches) == 4
    assert n_nonpad == n_nonmask == seq_len * n_seqs


def test_exact_batch_correctness_patches():
    # First patch length is forced to be 1
    # Total=1024
    patch_lengths = [1, 255, 256, 256, 256]
    # Recall: This is in terms of bytes
    max_length = 1024
    # Recall: This is in terms of patches
    seq_len = 5
    n_seqs = 12
    batch_size = 4
    pad_id = 0
    iterator = create_patches_iter(
        seq_len=seq_len,
        n_seqs=n_seqs,
        batch_size=batch_size,
        pad_id=pad_id,
        patch_lengths=patch_lengths,
        max_length=max_length,
    )
    batches, n_nonpad, n_nonmask = _consume_batches(
        iterator, batch_size=batch_size, row_len=max_length, pad_id=pad_id
    )
    assert len(batches) == 3
    # max_length - 1 is due to chopping off the last byte for
    # having a y target
    assert n_nonpad == n_nonmask == (max_length - 1) * n_seqs


def test_short_batch_correctness_patches():
    # First patch length is forced to be 1
    # Total=48
    patch_lengths = [1, 11, 12, 12, 12]
    # Recall: This is in terms of bytes
    max_length = 1024
    # Recall: This is in terms of patches
    seq_len = 5
    n_seqs = 12
    batch_size = 4
    pad_id = 0
    iterator = create_patches_iter(
        seq_len=seq_len,
        n_seqs=n_seqs,
        batch_size=batch_size,
        pad_id=pad_id,
        patch_lengths=patch_lengths,
        max_length=max_length,
    )
    batches, n_nonpad, n_nonmask = _consume_batches(
        iterator, batch_size=batch_size, row_len=max_length, pad_id=pad_id
    )
    assert len(batches) == 3
    # We'll still always have one byte chopped off the end
    assert n_nonpad == n_nonmask == ((sum(patch_lengths) - 1) * n_seqs)


def test_long_batch_correctness_patches():
    # First patch length is forced to be 1
    # Total=1792, which is longer than max_length
    patch_lengths = [1, 255, 256, 256, 1024]
    # Recall: This is in terms of bytes
    max_length = 1024
    # Recall: This is in terms of patches
    seq_len = 5
    n_seqs = 12
    batch_size = 4
    pad_id = 0
    iterator = create_patches_iter(
        seq_len=seq_len,
        n_seqs=n_seqs,
        batch_size=batch_size,
        pad_id=pad_id,
        patch_lengths=patch_lengths,
        max_length=max_length,
    )
    batches, n_nonpad, n_nonmask = _consume_batches(
        iterator, batch_size=batch_size, row_len=max_length, pad_id=pad_id
    )
    assert len(batches) == 3
    # No chop here since the next byte is available
    assert n_nonpad == n_nonmask == max_length * n_seqs