From 9d907fed1c94a42f04f262b18ba38e36780c9ddc Mon Sep 17 00:00:00 2001
From: Srinivasan Iyer
Date: Wed, 12 Feb 2025 18:33:53 -0800
Subject: [PATCH] disable reshard after forward (#56)

Co-authored-by: Srini Iyer
---
 bytelatent/transformer.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py
index ad8affa..b65e502 100644
--- a/bytelatent/transformer.py
+++ b/bytelatent/transformer.py
@@ -146,16 +146,16 @@ def build_fsdp_grouping_plan(model_args: LMTransformerArgs):
         group_plan.append(("output", True))
     else:
         for i in range(model_args.n_layers_local_encoder):
-            group_plan.append((f"local_encoder.layers.{i}", True))
-            group_plan.append((f"local_encoder.cross_attn_layers.{i}", True))
+            group_plan.append((f"local_encoder.layers.{i}", False))
+            group_plan.append((f"local_encoder.cross_attn_layers.{i}", False))

         for i in range(model_args.n_layers_local_decoder):
-            group_plan.append((f"local_decoder.layers.{i}", True))
-            group_plan.append((f"local_decoder.cross_attn_layers.{i}", True))
+            group_plan.append((f"local_decoder.layers.{i}", False))
+            group_plan.append((f"local_decoder.cross_attn_layers.{i}", False))

         for i in range(model_args.n_layers_global):
-            group_plan.append((f"global_transformer.layers.{i}", True))
+            group_plan.append((f"global_transformer.layers.{i}", False))

         for i in range(len(model_args.encoder_hash_byte_group_size)):
-            group_plan.append((f"encoder_hash_tok_embedding.{i}", True))
+            group_plan.append((f"encoder_hash_tok_embedding.{i}", False))

     return group_plan
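
Note (not part of the patch): the second element of each group-plan tuple is the per-module
reshard-after-forward flag, which this change flips to False for the byte-latent encoder,
decoder, global transformer, and hash-embedding groups. Below is a minimal, hypothetical
sketch of how such a (fqn, reshard_after_forward) plan could be consumed with PyTorch FSDP2;
apply_grouping_plan and its signature are illustrative assumptions, not bytelatent APIs, and
the sketch assumes a process group and device mesh have already been initialized.

    # Hypothetical sketch: applying a grouping plan with FSDP2's fully_shard.
    import torch.nn as nn
    from torch.distributed.fsdp import fully_shard  # public in recent PyTorch releases

    def apply_grouping_plan(model: nn.Module, group_plan: list[tuple[str, bool]]) -> nn.Module:
        # Wrap each listed submodule as its own FSDP unit. With
        # reshard_after_forward=False, gathered parameters are kept between
        # forward and backward, trading memory for one fewer all-gather per layer.
        for fqn, reshard_after_forward in group_plan:
            submodule = model.get_submodule(fqn)
            fully_shard(submodule, reshard_after_forward=reshard_after_forward)
        # Shard the root module last so any remaining parameters are covered.
        fully_shard(model)
        return model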