From 67624845d066efa5d9951d55d46446a8ceb70f69 Mon Sep 17 00:00:00 2001
From: Srini Iyer
Date: Thu, 13 Feb 2025 00:58:55 +0000
Subject: [PATCH] disable reshard after forward

---
 bytelatent/transformer.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py
index ad8affa..b65e502 100644
--- a/bytelatent/transformer.py
+++ b/bytelatent/transformer.py
@@ -146,16 +146,16 @@ def build_fsdp_grouping_plan(model_args: LMTransformerArgs):
         group_plan.append(("output", True))
     else:
         for i in range(model_args.n_layers_local_encoder):
-            group_plan.append((f"local_encoder.layers.{i}", True))
-            group_plan.append((f"local_encoder.cross_attn_layers.{i}", True))
+            group_plan.append((f"local_encoder.layers.{i}", False))
+            group_plan.append((f"local_encoder.cross_attn_layers.{i}", False))

         for i in range(model_args.n_layers_local_decoder):
-            group_plan.append((f"local_decoder.layers.{i}", True))
-            group_plan.append((f"local_decoder.cross_attn_layers.{i}", True))
+            group_plan.append((f"local_decoder.layers.{i}", False))
+            group_plan.append((f"local_decoder.cross_attn_layers.{i}", False))

         for i in range(model_args.n_layers_global):
-            group_plan.append((f"global_transformer.layers.{i}", True))
+            group_plan.append((f"global_transformer.layers.{i}", False))

         for i in range(len(model_args.encoder_hash_byte_group_size)):
-            group_plan.append((f"encoder_hash_tok_embedding.{i}", True))
+            group_plan.append((f"encoder_hash_tok_embedding.{i}", False))
     return group_plan
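
Note: the boolean in each (module_path, flag) tuple of the grouping plan is the value passed to FSDP2's reshard_after_forward when the plan is applied, so flipping True to False keeps each group's gathered parameters resident after its forward pass instead of resharding them. A minimal sketch of how such a plan could be consumed with PyTorch's composable fully_shard API; the helper name apply_fsdp_grouping_plan and the model/mesh setup are assumptions for illustration, not the repository's actual wiring:

    # Sketch only: assumes a recent PyTorch with the composable FSDP2 API.
    from torch.distributed._composable.fsdp import fully_shard

    def apply_fsdp_grouping_plan(model, group_plan, mesh=None):
        # Each entry is (dotted_module_path, reshard_after_forward).
        for module_path, reshard in group_plan:
            submodule = model.get_submodule(module_path)
            # reshard_after_forward=False keeps the gathered parameters after
            # forward, avoiding a second all-gather in backward at the cost of
            # extra memory; True frees them right after forward.
            fully_shard(submodule, mesh=mesh, reshard_after_forward=reshard)
        # Shard the root module last so any remaining parameters are covered.
        fully_shard(model, mesh=mesh)
        return model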