Changes for training entropy model and correcting attention in local models (#25)

Summary:

- Refactor local model configs to be separate and clearer
- Add attention arguments and correct which attention is used in local models (see the config sketch after this list)
- Prepare for adding an entropy-model training script
- Fix failing unit tests
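
As a concrete illustration of the attention-argument change above, here is a minimal sketch of a local model config with explicit attention fields. The field names and defaults (attn_impl, attn_bias_type) are assumptions for the sketch, not necessarily this commit's exact API:

from dataclasses import dataclass

@dataclass
class LocalModelArgs:
    # Illustrative local-model config; field names are hypothetical.
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    # Explicit attention arguments, so the attention implementation used by
    # local models is selected deliberately instead of inherited implicitly.
    attn_impl: str = "sdpa"  # e.g. "sdpa", "xformers", "flex_attention"
    attn_bias_type: str = "local_block_causal"  # which attention mask/bias to apply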

Test Plan:
Author: Pedro Rodriguez, 2025-01-17 14:23:01 -08:00 (committed by GitHub)
Commit: 6ffeb66b53 (parent: caec8d2621)
15 changed files with 349 additions and 138 deletions

@@ -24,6 +24,7 @@ def test_entropy_model():
         dataset_files=[ARROW_TEST_DATA],
         row_num=0,
         arrow_batch_size=100,
+        s3_profile=None,
     )
     arrow_file = initial_state.build()
     tokenizer_args = TokenizerArgs(
@@ -38,7 +39,7 @@ def test_entropy_model():
             BLT_DATA,
             "entropy_model.pth",
         ),
-    )
+    ).cuda()
     preprocess_iter = PreprocessIterator(
         arrow_file,
         tokenizer_args=tokenizer_args,
@@ -48,8 +49,10 @@ def test_entropy_model():
     for example in preprocess_iter.create_iter():
         tokens = torch.tensor(example.tokens).unsqueeze(0)
         expected_entropies = torch.tensor(example.entropies).unsqueeze(0)
-        preds = entropy_model(tokens)
+        preds = entropy_model(tokens.cuda())
         pred_entropies = entropy(preds)
         assert pred_entropies.shape == expected_entropies.shape
-        assert torch.allclose(pred_entropies, expected_entropies, rtol=1.0, atol=3.5)
+        assert torch.allclose(
+            pred_entropies.cpu(), expected_entropies, rtol=1.0, atol=3.5
+        )
         break
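
For context on what the test above asserts: entropy(preds) reduces the model's next-token logits to a per-position entropy. A minimal sketch of such a helper, assuming preds are logits of shape [batch, seq_len, vocab_size] (the repo's actual implementation may differ):

import torch
import torch.nn.functional as F

def entropy(scores: torch.Tensor) -> torch.Tensor:
    # scores: logits of shape [batch, seq_len, vocab_size].
    # Returns per-position Shannon entropy in nats, shape [batch, seq_len].
    log_probs = F.log_softmax(scores, dim=-1)
    return -(log_probs.exp() * log_probs).sum(dim=-1)

Note how the updated test moves inputs to the GPU (tokens.cuda()) to match the model loaded with .cuda(), then brings predictions back with .cpu() before comparing against the CPU-resident expected entropies.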