make sure max_encoder_seq_length matches

Srini Iyer 2025-02-13 00:52:08 +00:00
parent 22c7fe1d1c
commit 0ce2cd45ef


@@ -130,6 +130,7 @@ def validate_train_args(args: TrainArgs, output_size: int):
     if args.model is not None:
         logger.info(f"Setting model output size to {args.model.vocab_size}")
         args.model.vocab_size = output_size
+        assert(args.model.max_encoder_seq_length == args.data.max_encoder_seq_length)
     if args.entropy_model is not None:
         logger.info(f"Setting model output size to {args.entropy_model.vocab_size}")
@@ -610,7 +611,7 @@ def train(args: TrainArgs):
         interval_total_tok_loss_across_gpus = dist_sum(
             interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16
         ).item()
-        interval_total_n_bytes_per_gpu = n_bytes
+        interval_total_n_bytes_per_gpu = n_bytes.item()
         interval_total_n_bytes_across_gpus = dist_sum(
             n_bytes, reduce_dtype=torch.bfloat16
         ).item()
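The second hunk changes only the per-GPU byte counter. A small sketch of the likely intent, assuming n_bytes is a 0-dim torch tensor as its use with dist_sum suggests:

# The cross-GPU sum already calls .item(); before this change the per-GPU
# value was kept as a tensor, after it it is a plain Python number.
import torch

n_bytes = torch.tensor(4096, dtype=torch.int64)  # stand-in for the per-GPU counter

per_gpu_as_tensor = n_bytes          # before: a tensor leaks into the metrics
per_gpu_as_number = n_bytes.item()   # after: a plain int, safe to log/serialize

print(type(per_gpu_as_tensor), type(per_gpu_as_number))
# <class 'torch.Tensor'> <class 'int'>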