make sure max_encoder_seq_length matches (#55)

* make sure max_encoder_seq_length matches

* black and assert comment

---------

Co-authored-by: Srini Iyer <sviyer@meta.com>
Author: Srinivasan Iyer <sviyer@meta.com>
Date: 2025-02-12 18:27:22 -08:00 (committed via GitHub)
parent 22c7fe1d1c
commit 48e4ad0bd2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -130,6 +130,9 @@ def validate_train_args(args: TrainArgs, output_size: int):
if args.model is not None:
logger.info(f"Setting model output size to {args.model.vocab_size}")
args.model.vocab_size = output_size
assert (
args.model.max_encoder_seq_length == args.data.max_encoder_seq_length
), "max_encoder_seq_length for model and data should match"
if args.entropy_model is not None:
logger.info(f"Setting model output size to {args.entropy_model.vocab_size}")
@@ -610,7 +613,7 @@ def train(args: TrainArgs):
interval_total_tok_loss_across_gpus = dist_sum(
interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16
).item()
interval_total_n_bytes_per_gpu = n_bytes
interval_total_n_bytes_per_gpu = n_bytes.item()
interval_total_n_bytes_across_gpus = dist_sum(
n_bytes, reduce_dtype=torch.bfloat16
).item()