Mirror of https://github.com/facebookresearch/blt.git
make sure max_encoder_seq_length matches

Commit: 0ce2cd45ef
Parent: 22c7fe1d1c
@@ -130,6 +130,7 @@ def validate_train_args(args: TrainArgs, output_size: int):
     if args.model is not None:
         logger.info(f"Setting model output size to {args.model.vocab_size}")
         args.model.vocab_size = output_size
+        assert(args.model.max_encoder_seq_length == args.data.max_encoder_seq_length)
 
     if args.entropy_model is not None:
         logger.info(f"Setting model output size to {args.entropy_model.vocab_size}")
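The added assert ties the model config to the data config: if the data pipeline may emit encoder sequences longer than the model was configured for, the mismatch would otherwise only surface later as a shape or positional error mid-training. A minimal sketch of the idea, with hypothetical config classes and values standing in for the repository's TrainArgs fields:

# Minimal sketch (not the repository's code): why the two settings must agree.
# ModelArgs/DataArgs and the default values here are hypothetical.
from dataclasses import dataclass

@dataclass
class ModelArgs:
    max_encoder_seq_length: int = 12288

@dataclass
class DataArgs:
    max_encoder_seq_length: int = 24576

def validate(model: ModelArgs, data: DataArgs) -> None:
    # Mirrors the assertion added in this commit: fail fast at config time
    # instead of overflowing the encoder's sequence budget during training.
    assert model.max_encoder_seq_length == data.max_encoder_seq_length, (
        f"model expects <= {model.max_encoder_seq_length} encoder positions, "
        f"data pipeline produces up to {data.max_encoder_seq_length}"
    )

validate(ModelArgs(), DataArgs(max_encoder_seq_length=12288))  # passes: lengths agree
# validate(ModelArgs(), DataArgs())  # would raise AssertionError: 12288 vs 24576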
@@ -610,7 +611,7 @@ def train(args: TrainArgs):
                 interval_total_tok_loss_across_gpus = dist_sum(
                     interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16
                 ).item()
-                interval_total_n_bytes_per_gpu = n_bytes
+                interval_total_n_bytes_per_gpu = n_bytes.item()
                 interval_total_n_bytes_across_gpus = dist_sum(
                     n_bytes, reduce_dtype=torch.bfloat16
                 ).item()
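The second hunk makes the per-GPU byte counter a plain Python number via .item(), matching how the across-GPU totals are already materialized after dist_sum. A hedged sketch of the pattern, assuming n_bytes is a 0-d tensor and that dist_sum wraps an all-reduce; the helper below is a stand-in, not the repository's implementation:

# Hedged sketch of the per-GPU vs. across-GPU accounting pattern.
import torch
import torch.distributed as dist

def dist_sum(x: torch.Tensor, reduce_dtype: torch.dtype) -> torch.Tensor:
    # Assumption: sums a scalar tensor across all ranks in the given dtype.
    y = x.to(reduce_dtype)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(y, op=dist.ReduceOp.SUM)
    return y

n_bytes = torch.tensor(4096)  # hypothetical per-interval byte count on this rank

# Per-GPU value: .item() turns the 0-d tensor into a plain Python number, so
# downstream logging/arithmetic does not hold on to a device tensor.
interval_total_n_bytes_per_gpu = n_bytes.item()

# Across-GPU total: reduced in bfloat16, then materialized as a Python number.
interval_total_n_bytes_across_gpus = dist_sum(n_bytes, reduce_dtype=torch.bfloat16).item()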