blt/bytelatent/data/test_data.py
Pedro Rodriguez 85c2f28f26
Test first batch matches (#53)
Summary:

Test Plan:
2025-02-13 10:05:08 -08:00

49 lines
1.7 KiB
Python

import os
import pickle
import pytest
from omegaconf import OmegaConf
from bytelatent.args import TrainArgs
from bytelatent.constants import BLT_DATA
def get_test_config():
if "BLT_INTERNAL" in os.environ:
internal_dir = os.environ["BLT_INTERNAL"]
else:
internal_dir = "../internal-blt/configs"
test_config = os.path.join(internal_dir, "tests.yaml")
return test_config
@pytest.mark.skipif(
not os.path.exists(get_test_config()),
reason="Skipping since internal config is missing",
)
def test_first_batch_matches():
test_config_path = get_test_config()
default_cfg = OmegaConf.create(TrainArgs().model_dump())
file_cfg = OmegaConf.load(test_config_path)
merged_cfg = OmegaConf.merge(default_cfg, file_cfg)
merged_cfg = OmegaConf.to_container(merged_cfg, resolve=True, throw_on_missing=True)
train_args = TrainArgs.model_validate(merged_cfg)
# MP doesn't work with async very well, but it doesn't change logic
train_args.data.load_async = False
# Test data created by pickling first batch in train loop then exiting
with open(os.path.join(BLT_DATA, "fixtures", "first_batch_0.pickle"), "rb") as f:
first_batch = pickle.load(f)
# Emulate 1 node, 8 gpu training
data_loader = train_args.data.build_from_rank(0, 8)
batch_iterator = data_loader.create_iter()
print("Getting first batch")
batch = next(batch_iterator)
assert (batch.x == first_batch.x).all()
assert (batch.y == first_batch.y).all()
assert (batch.mask == first_batch.mask).all()
assert (batch.patch_lengths == first_batch.patch_lengths).all()
assert batch.ngram_ids is None and first_batch.ngram_ids is None
assert batch.is_final == False and batch.is_final == False