From 06df8e865152dd26dea7eeeddaa05e7f06b0524a Mon Sep 17 00:00:00 2001
From: Srini Iyer
Date: Wed, 9 Apr 2025 00:20:02 +0000
Subject: [PATCH] +1 to seq len for entropy model

---
 bytelatent/args.py                            | 6 +++++-
 bytelatent/data/iterators/packing_iterator.py | 5 +++--
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/bytelatent/args.py b/bytelatent/args.py
index ebe43ba..cb2d582 100644
--- a/bytelatent/args.py
+++ b/bytelatent/args.py
@@ -156,7 +156,11 @@ class DataloaderArgs(BaseModel):
         self, rank: int, world_size: int
     ) -> dict[str, SequenceIterator]:
         sequence_packing_args = SequencePackingArgs(
-            output_seq_len=self.seq_len,
+            output_seq_len=(
+                self.seq_len + 1
+                if self.patcher_args.patching_mode == PatchingModeEnum.byte
+                else self.seq_len
+            ),
             buffer_size=self.buffer_size,
         )
         source_to_sequence_iterator: dict[str, SequenceIterator] = {}

diff --git a/bytelatent/data/iterators/packing_iterator.py b/bytelatent/data/iterators/packing_iterator.py
index d220342..3bba5f3 100644
--- a/bytelatent/data/iterators/packing_iterator.py
+++ b/bytelatent/data/iterators/packing_iterator.py
@@ -201,9 +201,10 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
             m = np.zeros((batch_size, seq_len), dtype=np.bool)
 
             for i, tok_seq in enumerate(tokens):
-                x[i, : len(tok_seq)] = tok_seq
+                x[i, : len(tok_seq) - 1] = tok_seq[:-1]
                 y[i, : len(tok_seq) - 1] = tok_seq[1:]
-                m[i, : len(tok_seq)] = masks[i]
+                m[i, : len(tok_seq) - 1] = masks[i][1:]
+
             batch = Batch(x=x, y=y, mask=m)
             assert (
                 batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
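
A note on the args.py hunk: the entropy model is trained with plain next-byte
prediction, and the input/target shift in the packer consumes one position per
row, so packing seq_len + 1 bytes yields exactly seq_len (input, target) pairs.
A minimal sketch of that accounting; the helper name and the seq_len value are
invented for illustration, not taken from bytelatent:

    # Hypothetical helper, not part of bytelatent: number of (input, target)
    # pairs a packed row of n tokens yields after the one-position shift.
    def usable_positions(n_tokens: int) -> int:
        # x takes tokens[:-1] and y takes tokens[1:], so one position is lost
        return n_tokens - 1

    seq_len = 4096
    assert usable_positions(seq_len + 1) == seq_len  # the +1 restores full length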
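
Similarly for the packing_iterator.py hunk: x now drops the last byte of each
packed row, y drops the first, and the mask is shifted to stay aligned with the
targets. A self-contained sketch of the new loop with toy inputs; pad_id, the
token rows, and the sizes below are illustrative, not from the repo:

    import numpy as np

    pad_id = 0
    batch_size, seq_len = 2, 5                # seq_len slots per row after the shift
    tokens = [[3, 7, 2, 9, 4, 6], [5, 8, 1]]  # packed rows of up to seq_len + 1 bytes
    masks = [[True] * 6, [True] * 3]          # one flag per byte, aligned with tokens

    x = np.full((batch_size, seq_len), fill_value=pad_id)
    y = np.full((batch_size, seq_len), fill_value=pad_id)
    m = np.zeros((batch_size, seq_len), dtype=bool)  # np.bool is gone in NumPy >= 1.24

    for i, tok_seq in enumerate(tokens):
        x[i, : len(tok_seq) - 1] = tok_seq[:-1]  # inputs: every byte but the last
        y[i, : len(tok_seq) - 1] = tok_seq[1:]   # targets: every byte but the first
        m[i, : len(tok_seq) - 1] = masks[i][1:]  # mask follows the targets

    # The invariant the patch asserts: every non-pad input position has
    # exactly one valid (unmasked) target.
    assert np.sum(x != pad_id) == m.sum()

Note that masks[i][1:] has len(tok_seq) - 1 elements, so the mask assignment
needs the same - 1 bound as x and y; writing it into the wider slice
m[i, : len(tok_seq)] would raise a NumPy broadcast error.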