diff --git a/bytelatent/data/iterators/test_arrow_iterator.py b/bytelatent/data/iterators/test_arrow_iterator.py
index fd448eb..064217e 100644
--- a/bytelatent/data/iterators/test_arrow_iterator.py
+++ b/bytelatent/data/iterators/test_arrow_iterator.py
@@ -28,6 +28,7 @@ def test_basic_arrow_file():
         row_num=0,
         arrow_batch_size=100,
         s3_profile=None,
+        file_format="arrow",
     )
     arrow_file = initial_state.build()
     start_state = arrow_file.get_state()
@@ -57,6 +58,7 @@ def test_basic_arrow_file():
         row_num=251,
         arrow_batch_size=100,
         s3_profile=None,
+        file_format="arrow",
     )
     arrow_file = resumed_state.build()
     for example in arrow_file.create_iter():
@@ -77,6 +79,7 @@ def test_basic_arrow_file():
         row_num=0,
         arrow_batch_size=100,
         s3_profile=None,
+        file_format="arrow",
     )
     arrow_file = rank_state.build()
     expected_ids = []
diff --git a/bytelatent/data/test_data.py b/bytelatent/data/test_data.py
new file mode 100644
index 0000000..15f2996
--- /dev/null
+++ b/bytelatent/data/test_data.py
@@ -0,0 +1,46 @@
+import os
+import pickle
+import pytest
+from omegaconf import OmegaConf
+from bytelatent.args import TrainArgs
+from bytelatent.constants import BLT_DATA
+
+
+def get_test_config():
+    if "BLT_INTERNAL" in os.environ:
+        internal_dir = os.environ["BLT_INTERNAL"]
+    else:
+        internal_dir = "../internal-blt/configs"
+    test_config = os.path.join(internal_dir, "tests.yaml")
+    return test_config
+
+
+@pytest.mark.skipif(
+    not os.path.exists(get_test_config()),
+    reason="Skipping since internal config is missing",
+)
+def test_first_batch_matches():
+    test_config_path = get_test_config()
+    default_cfg = OmegaConf.create(TrainArgs().model_dump())
+    file_cfg = OmegaConf.load(test_config_path)
+    merged_cfg = OmegaConf.merge(default_cfg, file_cfg)
+    merged_cfg = OmegaConf.to_container(merged_cfg, resolve=True, throw_on_missing=True)
+    train_args = TrainArgs.model_validate(merged_cfg)
+    # MP doesn't work with async very well, but it doesn't change logic
+    train_args.data.load_async = False
+
+    # Test data created by pickling first batch in train loop then exiting
+    with open(os.path.join(BLT_DATA, "fixtures", "first_batch_0.pickle"), "rb") as f:
+        first_batch = pickle.load(f)
+
+    # Emulate 1 node, 8 gpu training
+    data_loader = train_args.data.build_from_rank(0, 8)
+    batch_iterator = data_loader.create_iter()
+    print("Getting first batch")
+    batch = next(batch_iterator)
+    assert (batch.x == first_batch.x).all()
+    assert (batch.y == first_batch.y).all()
+    assert (batch.mask == first_batch.mask).all()
+    assert (batch.patch_lengths == first_batch.patch_lengths).all()
+    assert batch.ngram_ids is None and first_batch.ngram_ids is None
+    assert batch.is_final == False and first_batch.is_final == False
diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py
index d6eab14..d92a1fb 100644
--- a/bytelatent/model/local_models.py
+++ b/bytelatent/model/local_models.py
@@ -74,12 +74,10 @@ class LocalModelBase(nn.Module):
 
         self.boe_id = BOE_ID
 
-        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.layers = nn.ModuleList(
             [TransformerBlock(args) for _ in range(args.n_layers)]
         )
 
-        self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)
         if not self.use_rope:
             self.pos_embeddings = nn.Embedding(args.max_length, args.dim)
         else:
@@ -131,16 +129,18 @@ class LocalModelBase(nn.Module):
     def init_weights(self, init_std=None):
         self.rope.reset_parameters()
-        self.norm.reset_parameters()
+        if hasattr(self, "norm"):
+            self.norm.reset_parameters()
 
         init_std = init_std or (self.dim ** (-0.5))
-        nn.init.trunc_normal_(
-            self.tok_embeddings.weight,
-            mean=0.0,
-            std=init_std,
-            a=-3 * init_std,
-            b=3 * init_std,
-        )
+        if hasattr(self, "tok_embeddings"):
+            nn.init.trunc_normal_(
+                self.tok_embeddings.weight,
+                mean=0.0,
+                std=init_std,
+                a=-3 * init_std,
+                b=3 * init_std,
+            )
 
         if self.pos_embeddings is not None:
             nn.init.trunc_normal_(
                 self.pos_embeddings.weight,
@@ -212,6 +212,8 @@ class LocalEncoder(LocalModelBase):
         self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
         self.cross_attn_nheads = args.cross_attn_nheads
 
+        self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)
+
         if self.cross_attn_encoder:
             self.cross_attn_layers = torch.nn.ModuleList()
             layers_to_add = args.n_layers if self.cross_attn_all_layers_encoder else 1
@@ -314,6 +316,8 @@ class LocalDecoder(LocalModelBase):
         self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
         self.cross_attn_nheads = args.cross_attn_nheads
 
+        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
+
         if self.cross_attn_decoder:
             self.cross_attn_layers = torch.nn.ModuleList()
             layers_to_add = args.n_layers if self.cross_attn_all_layers_decoder else 1
diff --git a/bytelatent/test_entropy_model.py b/bytelatent/test_entropy_model.py
index 9db7ff6..8623eb1 100644
--- a/bytelatent/test_entropy_model.py
+++ b/bytelatent/test_entropy_model.py
@@ -25,6 +25,7 @@ def test_entropy_model():
         row_num=0,
         arrow_batch_size=100,
         s3_profile=None,
+        file_format="arrow",
     )
     arrow_file = initial_state.build()
     tokenizer_args = TokenizerArgs(