mirror of
https://github.com/facebookresearch/blt.git
synced 2025-04-23 10:09:09 +00:00
Merge 06df8e8651
into e299427ae4
This commit is contained in:
commit
2749d0e435
2 changed files with 8 additions and 3 deletions
bytelatent
|
@ -156,7 +156,11 @@ class DataloaderArgs(BaseModel):
|
|||
self, rank: int, world_size: int
|
||||
) -> dict[str, SequenceIterator]:
|
||||
sequence_packing_args = SequencePackingArgs(
|
||||
output_seq_len=self.seq_len,
|
||||
output_seq_len=(
|
||||
self.seq_len + 1
|
||||
if self.patcher_args.patching_mode == PatchingModeEnum.byte
|
||||
else self.seq_len
|
||||
),
|
||||
buffer_size=self.buffer_size,
|
||||
)
|
||||
source_to_sequence_iterator: dict[str, SequenceIterator] = {}
|
||||
|
|
|
@ -201,9 +201,10 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
|
|||
m = np.zeros((batch_size, seq_len), dtype=np.bool)
|
||||
|
||||
for i, tok_seq in enumerate(tokens):
|
||||
x[i, : len(tok_seq)] = tok_seq
|
||||
x[i, : len(tok_seq) - 1] = tok_seq[:-1]
|
||||
y[i, : len(tok_seq) - 1] = tok_seq[1:]
|
||||
m[i, : len(tok_seq)] = masks[i]
|
||||
m[i, : len(tok_seq)] = masks[i][1:]
|
||||
|
||||
batch = Batch(x=x, y=y, mask=m)
|
||||
assert (
|
||||
batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
|
||||
|
|
Loading…
Add table
Reference in a new issue