Merge ab8f8a4412 into sapling-pr-archive-EntilZha

Pedro Rodriguez 2025-02-13 10:04:43 -08:00 committed by GitHub
commit 45d52b7ae3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 10 additions and 7 deletions


@@ -131,6 +131,9 @@ def validate_train_args(args: TrainArgs, output_size: int):
     if args.model is not None:
         logger.info(f"Setting model output size to {args.model.vocab_size}")
         args.model.vocab_size = output_size
+        assert (
+            args.model.max_encoder_seq_length == args.data.max_encoder_seq_length
+        ), "max_encoder_seq_length for model and data should match"

     if args.entropy_model is not None:
         logger.info(f"Setting model output size to {args.entropy_model.vocab_size}")
@@ -611,7 +614,7 @@ def train(args: TrainArgs):
            interval_total_tok_loss_across_gpus = dist_sum(
                interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16
            ).item()
-           interval_total_n_bytes_per_gpu = n_bytes
+           interval_total_n_bytes_per_gpu = n_bytes.item()
            interval_total_n_bytes_across_gpus = dist_sum(
                n_bytes, reduce_dtype=torch.bfloat16
            ).item()
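The fix converts the per-GPU byte count to a plain Python number with .item() so the logged metric is a scalar rather than a live tensor, matching the other interval metrics. A hedged sketch of the pattern; the dist_sum helper below is an assumed stand-in for the repo's own reduction utility, not its actual implementation:

    import torch
    import torch.distributed as dist


    def dist_sum(x: torch.Tensor, reduce_dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
        # Assumed behaviour: sum a scalar tensor across ranks (no-op in a single process).
        t = x.detach().to(reduce_dtype)
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(t, op=dist.ReduceOp.SUM)
        return t


    n_bytes = torch.tensor(4096.0)

    interval_total_n_bytes_per_gpu = n_bytes.item()           # plain float, safe to log/serialize
    interval_total_n_bytes_across_gpus = dist_sum(
        n_bytes, reduce_dtype=torch.bfloat16
    ).item()                                                  # reduce the tensor, then unwrap
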


@@ -146,16 +146,16 @@ def build_fsdp_grouping_plan(model_args: LMTransformerArgs):
         group_plan.append(("output", True))
     else:
         for i in range(model_args.n_layers_local_encoder):
-            group_plan.append((f"local_encoder.layers.{i}", True))
-            group_plan.append((f"local_encoder.cross_attn_layers.{i}", True))
+            group_plan.append((f"local_encoder.layers.{i}", False))
+            group_plan.append((f"local_encoder.cross_attn_layers.{i}", False))
         for i in range(model_args.n_layers_local_decoder):
-            group_plan.append((f"local_decoder.layers.{i}", True))
-            group_plan.append((f"local_decoder.cross_attn_layers.{i}", True))
+            group_plan.append((f"local_decoder.layers.{i}", False))
+            group_plan.append((f"local_decoder.cross_attn_layers.{i}", False))
         for i in range(model_args.n_layers_global):
-            group_plan.append((f"global_transformer.layers.{i}", True))
+            group_plan.append((f"global_transformer.layers.{i}", False))
         for i in range(len(model_args.encoder_hash_byte_group_size)):
-            group_plan.append((f"encoder_hash_tok_embedding.{i}", True))
+            group_plan.append((f"encoder_hash_tok_embedding.{i}", False))
     return group_plan
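The grouping plan is a list of (module_path, flag) pairs. This diff does not show how the flag is consumed; the sketch below assumes it is forwarded to FSDP2's reshard_after_forward, so flipping True to False keeps those submodules' parameters gathered after the forward pass instead of resharding them. The apply_grouping_plan helper is hypothetical and requires an initialized device mesh to actually run:

    from typing import List, Tuple

    import torch.nn as nn
    from torch.distributed._composable.fsdp import fully_shard  # FSDP2 API


    def apply_grouping_plan(model: nn.Module, group_plan: List[Tuple[str, bool]]) -> None:
        # Hypothetical consumer of the plan built above (assumption, not from the diff).
        for module_path, reshard in group_plan:
            submodule = model.get_submodule(module_path)  # resolves dotted paths like "local_encoder.layers.0"
            fully_shard(submodule, reshard_after_forward=reshard)
        # Shard the root last so any remaining parameters are covered.
        fully_shard(model, reshard_after_forward=True)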