mirror of
https://github.com/facebookresearch/blt.git
synced 2025-09-15 00:29:43 +00:00
Changes for training entropy model and correcting attention in local models (#25)
Summary: - Refactor local model configs to be separate and clearer - Add attention arguments and correct which attention is used in local models - Preparation for being able to have an entropy train script - Fix failing unit tests Test Plan:
This commit is contained in:
parent
caec8d2621
commit
6ffeb66b53
15 changed files with 349 additions and 138 deletions
|
@ -24,6 +24,7 @@ def test_entropy_model():
|
|||
dataset_files=[ARROW_TEST_DATA],
|
||||
row_num=0,
|
||||
arrow_batch_size=100,
|
||||
s3_profile=None,
|
||||
)
|
||||
arrow_file = initial_state.build()
|
||||
tokenizer_args = TokenizerArgs(
|
||||
|
@ -38,7 +39,7 @@ def test_entropy_model():
|
|||
BLT_DATA,
|
||||
"entropy_model.pth",
|
||||
),
|
||||
)
|
||||
).cuda()
|
||||
preprocess_iter = PreprocessIterator(
|
||||
arrow_file,
|
||||
tokenizer_args=tokenizer_args,
|
||||
|
@ -48,8 +49,10 @@ def test_entropy_model():
|
|||
for example in preprocess_iter.create_iter():
|
||||
tokens = torch.tensor(example.tokens).unsqueeze(0)
|
||||
expected_entropies = torch.tensor(example.entropies).unsqueeze(0)
|
||||
preds = entropy_model(tokens)
|
||||
preds = entropy_model(tokens.cuda())
|
||||
pred_entropies = entropy(preds)
|
||||
assert pred_entropies.shape == expected_entropies.shape
|
||||
assert torch.allclose(pred_entropies, expected_entropies, rtol=1.0, atol=3.5)
|
||||
assert torch.allclose(
|
||||
pred_entropies.cpu(), expected_entropies, rtol=1.0, atol=3.5
|
||||
)
|
||||
break
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue