Disable reshard after forward

This commit is contained in:
Srini Iyer 2025-02-13 00:58:55 +00:00
parent 48e4ad0bd2
commit 67624845d0

View file

@ -146,16 +146,16 @@ def build_fsdp_grouping_plan(model_args: LMTransformerArgs):
group_plan.append(("output", True))
else:
for i in range(model_args.n_layers_local_encoder):
group_plan.append((f"local_encoder.layers.{i}", True))
group_plan.append((f"local_encoder.cross_attn_layers.{i}", True))
group_plan.append((f"local_encoder.layers.{i}", False))
group_plan.append((f"local_encoder.cross_attn_layers.{i}", False))
for i in range(model_args.n_layers_local_decoder):
group_plan.append((f"local_decoder.layers.{i}", True))
group_plan.append((f"local_decoder.cross_attn_layers.{i}", True))
group_plan.append((f"local_decoder.layers.{i}", False))
group_plan.append((f"local_decoder.cross_attn_layers.{i}", False))
for i in range(model_args.n_layers_global):
group_plan.append((f"global_transformer.layers.{i}", True))
group_plan.append((f"global_transformer.layers.{i}", False))
for i in range(len(model_args.encoder_hash_byte_group_size)):
group_plan.append((f"encoder_hash_tok_embedding.{i}", True))
group_plan.append((f"encoder_hash_tok_embedding.{i}", False))
return group_plan