mirror of
https://github.com/facebookresearch/blt.git
synced 2025-02-22 21:12:15 +00:00
disable reshard after forward (#56)
Co-authored-by: Srini Iyer <sviyer@meta.com>
This commit is contained in:
parent
48e4ad0bd2
commit
9d907fed1c
|
@ -146,16 +146,16 @@ def build_fsdp_grouping_plan(model_args: LMTransformerArgs):
|
|||
group_plan.append(("output", True))
|
||||
else:
|
||||
for i in range(model_args.n_layers_local_encoder):
|
||||
group_plan.append((f"local_encoder.layers.{i}", True))
|
||||
group_plan.append((f"local_encoder.cross_attn_layers.{i}", True))
|
||||
group_plan.append((f"local_encoder.layers.{i}", False))
|
||||
group_plan.append((f"local_encoder.cross_attn_layers.{i}", False))
|
||||
for i in range(model_args.n_layers_local_decoder):
|
||||
group_plan.append((f"local_decoder.layers.{i}", True))
|
||||
group_plan.append((f"local_decoder.cross_attn_layers.{i}", True))
|
||||
group_plan.append((f"local_decoder.layers.{i}", False))
|
||||
group_plan.append((f"local_decoder.cross_attn_layers.{i}", False))
|
||||
for i in range(model_args.n_layers_global):
|
||||
group_plan.append((f"global_transformer.layers.{i}", True))
|
||||
group_plan.append((f"global_transformer.layers.{i}", False))
|
||||
|
||||
for i in range(len(model_args.encoder_hash_byte_group_size)):
|
||||
group_plan.append((f"encoder_hash_tok_embedding.{i}", True))
|
||||
group_plan.append((f"encoder_hash_tok_embedding.{i}", False))
|
||||
|
||||
return group_plan
|
||||
|
||||
|
|
Loading…
Reference in a new issue