mirror of
https://github.com/facebookresearch/blt.git
synced 2025-09-01 18:19:06 +00:00
+1 to seq len for entropy model
This commit is contained in:
parent
138c2f3494
commit
06df8e8651
2 changed files with 8 additions and 3 deletions
|
@ -156,7 +156,11 @@ class DataloaderArgs(BaseModel):
|
||||||
self, rank: int, world_size: int
|
self, rank: int, world_size: int
|
||||||
) -> dict[str, SequenceIterator]:
|
) -> dict[str, SequenceIterator]:
|
||||||
sequence_packing_args = SequencePackingArgs(
|
sequence_packing_args = SequencePackingArgs(
|
||||||
output_seq_len=self.seq_len,
|
output_seq_len=(
|
||||||
|
self.seq_len + 1
|
||||||
|
if self.patcher_args.patching_mode == PatchingModeEnum.byte
|
||||||
|
else self.seq_len
|
||||||
|
),
|
||||||
buffer_size=self.buffer_size,
|
buffer_size=self.buffer_size,
|
||||||
)
|
)
|
||||||
source_to_sequence_iterator: dict[str, SequenceIterator] = {}
|
source_to_sequence_iterator: dict[str, SequenceIterator] = {}
|
||||||
|
|
|
@ -201,9 +201,10 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
|
||||||
m = np.zeros((batch_size, seq_len), dtype=np.bool)
|
m = np.zeros((batch_size, seq_len), dtype=np.bool)
|
||||||
|
|
||||||
for i, tok_seq in enumerate(tokens):
|
for i, tok_seq in enumerate(tokens):
|
||||||
x[i, : len(tok_seq)] = tok_seq
|
x[i, : len(tok_seq) - 1] = tok_seq[:-1]
|
||||||
y[i, : len(tok_seq) - 1] = tok_seq[1:]
|
y[i, : len(tok_seq) - 1] = tok_seq[1:]
|
||||||
m[i, : len(tok_seq)] = masks[i]
|
m[i, : len(tok_seq)] = masks[i][1:]
|
||||||
|
|
||||||
batch = Batch(x=x, y=y, mask=m)
|
batch = Batch(x=x, y=y, mask=m)
|
||||||
assert (
|
assert (
|
||||||
batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
|
batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
|
||||||
|
|
Loading…
Add table
Reference in a new issue