Mirror of https://github.com/facebookresearch/blt.git (synced 2025-02-23 05:22:16 +00:00)

Merge ab8f8a4412 into sapling-pr-archive-EntilZha

Commit 45d52b7ae3
@@ -131,6 +131,9 @@ def validate_train_args(args: TrainArgs, output_size: int):
     if args.model is not None:
         logger.info(f"Setting model output size to {args.model.vocab_size}")
         args.model.vocab_size = output_size
+        assert (
+            args.model.max_encoder_seq_length == args.data.max_encoder_seq_length
+        ), "max_encoder_seq_length for model and data should match"
 
     if args.entropy_model is not None:
         logger.info(f"Setting model output size to {args.entropy_model.vocab_size}")
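The added assert guards against a silent mismatch between the model's and the data loader's max_encoder_seq_length. Below is a minimal, self-contained sketch of the same check; the dataclasses are hypothetical stand-ins, not the repo's actual TrainArgs/model/data argument classes.

# Minimal sketch of the new validation step, using hypothetical stand-in
# dataclasses (_ModelArgsSketch / _DataArgsSketch are not the repo's classes).
from dataclasses import dataclass


@dataclass
class _ModelArgsSketch:
    vocab_size: int = -1
    max_encoder_seq_length: int = 1024


@dataclass
class _DataArgsSketch:
    max_encoder_seq_length: int = 1024


def _validate_sketch(model: _ModelArgsSketch, data: _DataArgsSketch, output_size: int) -> None:
    # Same shape as the patched validate_train_args: set the vocab size,
    # then fail fast if the model/data sequence lengths disagree.
    model.vocab_size = output_size
    assert (
        model.max_encoder_seq_length == data.max_encoder_seq_length
    ), "max_encoder_seq_length for model and data should match"


if __name__ == "__main__":
    _validate_sketch(_ModelArgsSketch(), _DataArgsSketch(), output_size=260)  # passes
    try:
        _validate_sketch(_ModelArgsSketch(max_encoder_seq_length=2048), _DataArgsSketch(), 260)
    except AssertionError as err:
        print(f"caught expected mismatch: {err}")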
@@ -611,7 +614,7 @@ def train(args: TrainArgs):
             interval_total_tok_loss_across_gpus = dist_sum(
                 interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16
             ).item()
-            interval_total_n_bytes_per_gpu = n_bytes
+            interval_total_n_bytes_per_gpu = n_bytes.item()
             interval_total_n_bytes_across_gpus = dist_sum(
                 n_bytes, reduce_dtype=torch.bfloat16
             ).item()
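The fix calls .item() so the per-GPU byte count is stored as a plain Python number rather than a tensor, while the tensor itself still feeds the cross-GPU reduction. Below is a hedged sketch of what a dist_sum-style helper could look like; the repo's actual dist_sum is not shown in this diff and may differ.

# Hedged sketch of a dist_sum-style helper: cast to the reduction dtype,
# all-reduce across ranks when a process group exists, return the reduced tensor.
import torch
import torch.distributed as dist


def dist_sum_sketch(x: torch.Tensor, reduce_dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    x = x.to(reduce_dtype)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(x, op=dist.ReduceOp.SUM)
    return x


if __name__ == "__main__":
    n_bytes = torch.tensor(1024, dtype=torch.int64)
    interval_total_n_bytes_per_gpu = n_bytes.item()           # plain int, as in the patch
    interval_total_n_bytes_across_gpus = dist_sum_sketch(n_bytes).item()
    print(interval_total_n_bytes_per_gpu, interval_total_n_bytes_across_gpus)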
@@ -146,16 +146,16 @@ def build_fsdp_grouping_plan(model_args: LMTransformerArgs):
         group_plan.append(("output", True))
     else:
         for i in range(model_args.n_layers_local_encoder):
-            group_plan.append((f"local_encoder.layers.{i}", True))
-            group_plan.append((f"local_encoder.cross_attn_layers.{i}", True))
+            group_plan.append((f"local_encoder.layers.{i}", False))
+            group_plan.append((f"local_encoder.cross_attn_layers.{i}", False))
         for i in range(model_args.n_layers_local_decoder):
-            group_plan.append((f"local_decoder.layers.{i}", True))
-            group_plan.append((f"local_decoder.cross_attn_layers.{i}", True))
+            group_plan.append((f"local_decoder.layers.{i}", False))
+            group_plan.append((f"local_decoder.cross_attn_layers.{i}", False))
         for i in range(model_args.n_layers_global):
-            group_plan.append((f"global_transformer.layers.{i}", True))
+            group_plan.append((f"global_transformer.layers.{i}", False))
 
         for i in range(len(model_args.encoder_hash_byte_group_size)):
-            group_plan.append((f"encoder_hash_tok_embedding.{i}", True))
+            group_plan.append((f"encoder_hash_tok_embedding.{i}", False))
 
     return group_plan
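This hunk flips the per-layer boolean in the FSDP grouping plan from True to False. The sketch below shows the post-patch plan for a toy config; _BLTArgsSketch is a hypothetical stand-in for the LMTransformerArgs fields used here, and the boolean is assumed (not confirmed by the diff) to be an FSDP reshard-after-forward style flag.

# Sketch of the post-patch grouping plan for a toy config.
from dataclasses import dataclass, field


@dataclass
class _BLTArgsSketch:
    n_layers_local_encoder: int = 1
    n_layers_local_decoder: int = 1
    n_layers_global: int = 2
    encoder_hash_byte_group_size: list = field(default_factory=lambda: [3, 4])


def build_group_plan_sketch(model_args: _BLTArgsSketch) -> list:
    group_plan = []
    # Per-layer groups now carry False instead of True, matching the hunk above.
    for i in range(model_args.n_layers_local_encoder):
        group_plan.append((f"local_encoder.layers.{i}", False))
        group_plan.append((f"local_encoder.cross_attn_layers.{i}", False))
    for i in range(model_args.n_layers_local_decoder):
        group_plan.append((f"local_decoder.layers.{i}", False))
        group_plan.append((f"local_decoder.cross_attn_layers.{i}", False))
    for i in range(model_args.n_layers_global):
        group_plan.append((f"global_transformer.layers.{i}", False))
    for i in range(len(model_args.encoder_hash_byte_group_size)):
        group_plan.append((f"encoder_hash_tok_embedding.{i}", False))
    return group_plan


if __name__ == "__main__":
    for module_path, flag in build_group_plan_sketch(_BLTArgsSketch()):
        print(module_path, flag)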