From 48e4ad0bd2c21e676fb17be0e768356f68cfe68a Mon Sep 17 00:00:00 2001
From: Srinivasan Iyer
Date: Wed, 12 Feb 2025 18:27:22 -0800
Subject: [PATCH] make sure max_encoder_seq_length matches (#55)

* make sure max_encoder_seq_length matches

* black and assert comment

---------

Co-authored-by: Srini Iyer
---
 bytelatent/train.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/bytelatent/train.py b/bytelatent/train.py
index ed84233..0ee87df 100644
--- a/bytelatent/train.py
+++ b/bytelatent/train.py
@@ -130,6 +130,9 @@ def validate_train_args(args: TrainArgs, output_size: int):
     if args.model is not None:
         logger.info(f"Setting model output size to {args.model.vocab_size}")
         args.model.vocab_size = output_size
+        assert (
+            args.model.max_encoder_seq_length == args.data.max_encoder_seq_length
+        ), "max_encoder_seq_length for model and data should match"
 
     if args.entropy_model is not None:
         logger.info(f"Setting model output size to {args.entropy_model.vocab_size}")
@@ -610,7 +613,7 @@ def train(args: TrainArgs):
             interval_total_tok_loss_across_gpus = dist_sum(
                 interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16
             ).item()
-            interval_total_n_bytes_per_gpu = n_bytes
+            interval_total_n_bytes_per_gpu = n_bytes.item()
             interval_total_n_bytes_across_gpus = dist_sum(
                 n_bytes, reduce_dtype=torch.bfloat16
             ).item()
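
The assert added to validate_train_args fails fast at config-validation time when the model config and the data config disagree on max_encoder_seq_length, instead of surfacing a shape error mid-training. The second hunk converts the per-GPU byte count to a Python scalar with .item(), while the tensor itself is still passed to dist_sum for the cross-GPU reduction. Below is a minimal, self-contained sketch of the validation check; ModelArgs, DataArgs, and validate are hypothetical stand-ins for the corresponding fields on the repo's TrainArgs, not its actual classes.

# Illustrative sketch only: hypothetical ModelArgs/DataArgs mirror the two
# max_encoder_seq_length fields the patch compares in validate_train_args.
from dataclasses import dataclass


@dataclass
class ModelArgs:
    max_encoder_seq_length: int = 1024


@dataclass
class DataArgs:
    max_encoder_seq_length: int = 2048


def validate(model: ModelArgs, data: DataArgs) -> None:
    # Same check as the patch: reject mismatched encoder sequence lengths.
    assert (
        model.max_encoder_seq_length == data.max_encoder_seq_length
    ), "max_encoder_seq_length for model and data should match"


if __name__ == "__main__":
    try:
        validate(ModelArgs(), DataArgs())  # mismatched lengths -> AssertionError
    except AssertionError as err:
        print(f"config rejected: {err}")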