disable reshard after forward (#56)

Co-authored-by: Srini Iyer <sviyer@meta.com>
Srinivasan Iyer committed via GitHub on 2025-02-12 18:33:53 -08:00
commit 9d907fed1c (parent 48e4ad0bd2)


@@ -146,16 +146,16 @@ def build_fsdp_grouping_plan(model_args: LMTransformerArgs):
         group_plan.append(("output", True))
     else:
         for i in range(model_args.n_layers_local_encoder):
-            group_plan.append((f"local_encoder.layers.{i}", True))
-            group_plan.append((f"local_encoder.cross_attn_layers.{i}", True))
+            group_plan.append((f"local_encoder.layers.{i}", False))
+            group_plan.append((f"local_encoder.cross_attn_layers.{i}", False))
         for i in range(model_args.n_layers_local_decoder):
-            group_plan.append((f"local_decoder.layers.{i}", True))
-            group_plan.append((f"local_decoder.cross_attn_layers.{i}", True))
+            group_plan.append((f"local_decoder.layers.{i}", False))
+            group_plan.append((f"local_decoder.cross_attn_layers.{i}", False))
         for i in range(model_args.n_layers_global):
-            group_plan.append((f"global_transformer.layers.{i}", True))
+            group_plan.append((f"global_transformer.layers.{i}", False))
         for i in range(len(model_args.encoder_hash_byte_group_size)):
-            group_plan.append((f"encoder_hash_tok_embedding.{i}", True))
+            group_plan.append((f"encoder_hash_tok_embedding.{i}", False))
     return group_plan
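
The change flips the second element of each grouping-plan entry for the local encoder, local decoder, global transformer, and hash embeddings from True to False, i.e. it disables resharding of those modules after their forward pass. Below is a minimal sketch of how such a (module_path, reshard_after_forward) plan might be consumed with PyTorch's FSDP2 `fully_shard` API; the `apply_grouping_plan` name and the traversal details are assumptions for illustration, not the repository's actual wrapping code.

```python
# Hypothetical sketch: apply a grouping plan of (module_path, reshard_after_forward)
# pairs using PyTorch FSDP2's fully_shard. Not the repository's implementation.
from torch.distributed._composable.fsdp import fully_shard


def apply_grouping_plan(model, group_plan, mesh):
    for module_path, reshard_after_forward in group_plan:
        # Resolve a dotted path such as "local_encoder.layers.3" to a submodule.
        module = model.get_submodule(module_path)
        # With reshard_after_forward=False the unsharded parameters are kept
        # after forward, so backward skips a second all-gather at the cost of
        # holding more parameter memory between forward and backward.
        fully_shard(module, mesh=mesh, reshard_after_forward=reshard_after_forward)
    # Shard the root module last so any remaining parameters are covered.
    fully_shard(model, mesh=mesh)
    return model
```

Setting the flag to False is a memory-for-communication trade: each wrapped block stays gathered after its forward pass instead of being resharded and re-gathered in backward.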