From 48e4ad0bd2c21e676fb17be0e768356f68cfe68a Mon Sep 17 00:00:00 2001
From: Srinivasan Iyer
Date: Wed, 12 Feb 2025 18:27:22 -0800
Subject: [PATCH 1/3] make sure max_encoder_seq_length matches (#55)

* make sure max_encoder_seq_length matches

* black and assert comment

---------

Co-authored-by: Srini Iyer
---
 bytelatent/train.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/bytelatent/train.py b/bytelatent/train.py
index ed84233..0ee87df 100644
--- a/bytelatent/train.py
+++ b/bytelatent/train.py
@@ -130,6 +130,9 @@ def validate_train_args(args: TrainArgs, output_size: int):
     if args.model is not None:
         logger.info(f"Setting model output size to {args.model.vocab_size}")
         args.model.vocab_size = output_size
+        assert (
+            args.model.max_encoder_seq_length == args.data.max_encoder_seq_length
+        ), "max_encoder_seq_length for model and data should match"
 
     if args.entropy_model is not None:
         logger.info(f"Setting model output size to {args.entropy_model.vocab_size}")
@@ -610,7 +613,7 @@ def train(args: TrainArgs):
             interval_total_tok_loss_across_gpus = dist_sum(
                 interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16
             ).item()
-            interval_total_n_bytes_per_gpu = n_bytes
+            interval_total_n_bytes_per_gpu = n_bytes.item()
             interval_total_n_bytes_across_gpus = dist_sum(
                 n_bytes, reduce_dtype=torch.bfloat16
             ).item()

From 9d907fed1c94a42f04f262b18ba38e36780c9ddc Mon Sep 17 00:00:00 2001
From: Srinivasan Iyer
Date: Wed, 12 Feb 2025 18:33:53 -0800
Subject: [PATCH 2/3] disable reshard after forward (#56)

Co-authored-by: Srini Iyer
---
 bytelatent/transformer.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py
index ad8affa..b65e502 100644
--- a/bytelatent/transformer.py
+++ b/bytelatent/transformer.py
@@ -146,16 +146,16 @@ def build_fsdp_grouping_plan(model_args: LMTransformerArgs):
         group_plan.append(("output", True))
     else:
         for i in range(model_args.n_layers_local_encoder):
-            group_plan.append((f"local_encoder.layers.{i}", True))
-            group_plan.append((f"local_encoder.cross_attn_layers.{i}", True))
+            group_plan.append((f"local_encoder.layers.{i}", False))
+            group_plan.append((f"local_encoder.cross_attn_layers.{i}", False))
         for i in range(model_args.n_layers_local_decoder):
-            group_plan.append((f"local_decoder.layers.{i}", True))
-            group_plan.append((f"local_decoder.cross_attn_layers.{i}", True))
+            group_plan.append((f"local_decoder.layers.{i}", False))
+            group_plan.append((f"local_decoder.cross_attn_layers.{i}", False))
         for i in range(model_args.n_layers_global):
-            group_plan.append((f"global_transformer.layers.{i}", True))
+            group_plan.append((f"global_transformer.layers.{i}", False))
 
         for i in range(len(model_args.encoder_hash_byte_group_size)):
-            group_plan.append((f"encoder_hash_tok_embedding.{i}", True))
+            group_plan.append((f"encoder_hash_tok_embedding.{i}", False))
 
     return group_plan

From ab8f8a4412dd3b203d0d6bbc5fb73cfc1abbda97 Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez
Date: Thu, 13 Feb 2025 18:04:30 +0000
Subject: [PATCH 3/3] Test first batch matches

Summary:

Test Plan:
---
 .../data/iterators/test_arrow_iterator.py |  3 ++
 bytelatent/data/test_data.py              | 48 +++++++++++++++++++
 bytelatent/test_entropy_model.py          |  1 +
 pyproject.toml                            |  1 +
 4 files changed, 53 insertions(+)
 create mode 100644 bytelatent/data/test_data.py

diff --git a/bytelatent/data/iterators/test_arrow_iterator.py b/bytelatent/data/iterators/test_arrow_iterator.py
index fd448eb..064217e 100644
--- a/bytelatent/data/iterators/test_arrow_iterator.py
+++ b/bytelatent/data/iterators/test_arrow_iterator.py
@@ -28,6 +28,7 @@ def test_basic_arrow_file():
         row_num=0,
         arrow_batch_size=100,
         s3_profile=None,
+        file_format="arrow",
     )
     arrow_file = initial_state.build()
     start_state = arrow_file.get_state()
@@ -57,6 +58,7 @@ def test_basic_arrow_file():
         row_num=251,
         arrow_batch_size=100,
         s3_profile=None,
+        file_format="arrow",
     )
     arrow_file = resumed_state.build()
     for example in arrow_file.create_iter():
@@ -77,6 +79,7 @@ def test_basic_arrow_file():
         row_num=0,
         arrow_batch_size=100,
         s3_profile=None,
+        file_format="arrow",
     )
     arrow_file = rank_state.build()
     expected_ids = []
diff --git a/bytelatent/data/test_data.py b/bytelatent/data/test_data.py
new file mode 100644
index 0000000..efb8bcf
--- /dev/null
+++ b/bytelatent/data/test_data.py
@@ -0,0 +1,48 @@
+import os
+import pickle
+
+import pytest
+from omegaconf import OmegaConf
+
+from bytelatent.args import TrainArgs
+from bytelatent.constants import BLT_DATA
+
+
+def get_test_config():
+    if "BLT_INTERNAL" in os.environ:
+        internal_dir = os.environ["BLT_INTERNAL"]
+    else:
+        internal_dir = "../internal-blt/configs"
+    test_config = os.path.join(internal_dir, "tests.yaml")
+    return test_config
+
+
+@pytest.mark.skipif(
+    not os.path.exists(get_test_config()),
+    reason="Skipping since internal config is missing",
+)
+def test_first_batch_matches():
+    test_config_path = get_test_config()
+    default_cfg = OmegaConf.create(TrainArgs().model_dump())
+    file_cfg = OmegaConf.load(test_config_path)
+    merged_cfg = OmegaConf.merge(default_cfg, file_cfg)
+    merged_cfg = OmegaConf.to_container(merged_cfg, resolve=True, throw_on_missing=True)
+    train_args = TrainArgs.model_validate(merged_cfg)
+    # MP doesn't work with async very well, but it doesn't change logic
+    train_args.data.load_async = False
+
+    # Test data created by pickling first batch in train loop then exiting
+    with open(os.path.join(BLT_DATA, "fixtures", "first_batch_0.pickle"), "rb") as f:
+        first_batch = pickle.load(f)
+
+    # Emulate 1 node, 8 gpu training
+    data_loader = train_args.data.build_from_rank(0, 8)
+    batch_iterator = data_loader.create_iter()
+    print("Getting first batch")
+    batch = next(batch_iterator)
+    assert (batch.x == first_batch.x).all()
+    assert (batch.y == first_batch.y).all()
+    assert (batch.mask == first_batch.mask).all()
+    assert (batch.patch_lengths == first_batch.patch_lengths).all()
+    assert batch.ngram_ids is None and first_batch.ngram_ids is None
+    assert batch.is_final == False and first_batch.is_final == False
diff --git a/bytelatent/test_entropy_model.py b/bytelatent/test_entropy_model.py
index 9db7ff6..8623eb1 100644
--- a/bytelatent/test_entropy_model.py
+++ b/bytelatent/test_entropy_model.py
@@ -25,6 +25,7 @@ def test_entropy_model():
         row_num=0,
         arrow_batch_size=100,
         s3_profile=None,
+        file_format="arrow",
     )
     arrow_file = initial_state.build()
     tokenizer_args = TokenizerArgs(
diff --git a/pyproject.toml b/pyproject.toml
index e2ecd0d..814d8d1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -2,4 +2,5 @@
 profile = "black"
 known_bytelatent = "bytelatent"
 known_apps = "apps"
+known_third_party = "wandb"
 sections = "FUTURE,STDLIB,THIRDPARTY,BYTELATENT,APPS,FIRSTPARTY,LOCALFOLDER"