This commit is contained in:
Srinivasan Iyer 2025-04-09 02:30:23 +00:00 committed by GitHub
commit 2749d0e435
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 8 additions and 3 deletions
bytelatent

View file

@@ -156,7 +156,11 @@ class DataloaderArgs(BaseModel):
self, rank: int, world_size: int
) -> dict[str, SequenceIterator]:
sequence_packing_args = SequencePackingArgs(
-output_seq_len=self.seq_len,
+output_seq_len=(
+self.seq_len + 1
+if self.patcher_args.patching_mode == PatchingModeEnum.byte
+else self.seq_len
+),
buffer_size=self.buffer_size,
)
source_to_sequence_iterator: dict[str, SequenceIterator] = {}

View file

@@ -201,9 +201,10 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
m = np.zeros((batch_size, seq_len), dtype=np.bool)
for i, tok_seq in enumerate(tokens):
-x[i, : len(tok_seq)] = tok_seq
+x[i, : len(tok_seq) - 1] = tok_seq[:-1]
+y[i, : len(tok_seq) - 1] = tok_seq[1:]
-m[i, : len(tok_seq)] = masks[i]
+m[i, : len(tok_seq)] = masks[i][1:]
batch = Batch(x=x, y=y, mask=m)
assert (
batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()