mirror of
https://github.com/facebookresearch/blt.git
synced 2025-02-23 05:22:16 +00:00
49 lines
1.7 KiB
Python
49 lines
1.7 KiB
Python
import os
|
|
import pickle
|
|
|
|
import pytest
|
|
from omegaconf import OmegaConf
|
|
|
|
from bytelatent.args import TrainArgs
|
|
from bytelatent.constants import BLT_DATA
|
|
|
|
|
|
def get_test_config():
    """Return the path to the internal ``tests.yaml`` config file.

    The directory holding the config comes from the ``BLT_INTERNAL``
    environment variable when set; otherwise it falls back to the
    relative path of the internal-blt checkout.
    """
    config_dir = os.environ.get("BLT_INTERNAL", "../internal-blt/configs")
    return os.path.join(config_dir, "tests.yaml")
|
|
|
|
|
|
@pytest.mark.skipif(
    not os.path.exists(get_test_config()),
    reason="Skipping since internal config is missing",
)
def test_first_batch_matches():
    """Check that the first batch produced by the dataloader matches a
    previously pickled golden batch, field by field.

    Builds ``TrainArgs`` from defaults merged with the internal test
    config, constructs the rank-0 dataloader for an emulated 1-node /
    8-GPU run, pulls one batch, and compares it against the fixture
    captured from a real training run.
    """
    test_config_path = get_test_config()
    # Merge package defaults with the internal YAML config; resolve
    # interpolations and fail loudly on any missing mandatory value.
    default_cfg = OmegaConf.create(TrainArgs().model_dump())
    file_cfg = OmegaConf.load(test_config_path)
    merged_cfg = OmegaConf.merge(default_cfg, file_cfg)
    merged_cfg = OmegaConf.to_container(merged_cfg, resolve=True, throw_on_missing=True)
    train_args = TrainArgs.model_validate(merged_cfg)
    # MP doesn't work with async very well, but it doesn't change logic
    train_args.data.load_async = False

    # Test data created by pickling first batch in train loop then exiting
    with open(os.path.join(BLT_DATA, "fixtures", "first_batch_0.pickle"), "rb") as f:
        first_batch = pickle.load(f)

    # Emulate 1 node, 8 gpu training
    data_loader = train_args.data.build_from_rank(0, 8)
    batch_iterator = data_loader.create_iter()
    print("Getting first batch")
    batch = next(batch_iterator)
    assert (batch.x == first_batch.x).all()
    assert (batch.y == first_batch.y).all()
    assert (batch.mask == first_batch.mask).all()
    assert (batch.patch_lengths == first_batch.patch_lengths).all()
    assert batch.ngram_ids is None and first_batch.ngram_ids is None
    # Bug fix: the original asserted `batch.is_final == False` twice and
    # never inspected the fixture's flag; compare both objects instead.
    assert batch.is_final is False and first_batch.is_final is False