diff --git a/bytelatent/args.py b/bytelatent/args.py index ebe43ba..cb2d582 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -156,7 +156,11 @@ class DataloaderArgs(BaseModel): self, rank: int, world_size: int ) -> dict[str, SequenceIterator]: sequence_packing_args = SequencePackingArgs( - output_seq_len=self.seq_len, + output_seq_len=( + self.seq_len + 1 + if self.patcher_args.patching_mode == PatchingModeEnum.byte + else self.seq_len + ), buffer_size=self.buffer_size, ) source_to_sequence_iterator: dict[str, SequenceIterator] = {} diff --git a/bytelatent/data/iterators/packing_iterator.py b/bytelatent/data/iterators/packing_iterator.py index d220342..3bba5f3 100644 --- a/bytelatent/data/iterators/packing_iterator.py +++ b/bytelatent/data/iterators/packing_iterator.py @@ -201,9 +201,10 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]): m = np.zeros((batch_size, seq_len), dtype=np.bool) for i, tok_seq in enumerate(tokens): - x[i, : len(tok_seq)] = tok_seq + x[i, : len(tok_seq) - 1] = tok_seq[:-1] y[i, : len(tok_seq) - 1] = tok_seq[1:] - m[i, : len(tok_seq)] = masks[i] + m[i, : len(tok_seq)] = masks[i][1:] + batch = Batch(x=x, y=y, mask=m) assert ( batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()