diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py
index ad8affa..b65e502 100644
--- a/bytelatent/transformer.py
+++ b/bytelatent/transformer.py
@@ -146,16 +146,16 @@ def build_fsdp_grouping_plan(model_args: LMTransformerArgs):
         group_plan.append(("output", True))
     else:
         for i in range(model_args.n_layers_local_encoder):
-            group_plan.append((f"local_encoder.layers.{i}", True))
-            group_plan.append((f"local_encoder.cross_attn_layers.{i}", True))
+            group_plan.append((f"local_encoder.layers.{i}", False))
+            group_plan.append((f"local_encoder.cross_attn_layers.{i}", False))
         for i in range(model_args.n_layers_local_decoder):
-            group_plan.append((f"local_decoder.layers.{i}", True))
-            group_plan.append((f"local_decoder.cross_attn_layers.{i}", True))
+            group_plan.append((f"local_decoder.layers.{i}", False))
+            group_plan.append((f"local_decoder.cross_attn_layers.{i}", False))
         for i in range(model_args.n_layers_global):
-            group_plan.append((f"global_transformer.layers.{i}", True))
+            group_plan.append((f"global_transformer.layers.{i}", False))
         for i in range(len(model_args.encoder_hash_byte_group_size)):
-            group_plan.append((f"encoder_hash_tok_embedding.{i}", True))
+            group_plan.append((f"encoder_hash_tok_embedding.{i}", False))
     return group_plan
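
For context, a minimal sketch of how a (module_path, flag) grouping plan like this is typically consumed, assuming FSDP2's fully_shard API and assuming the boolean maps to reshard_after_forward; apply_grouping_plan is an illustrative name, not part of this repository.

    # Hypothetical sketch (not from the repo): apply a grouping plan with FSDP2,
    # assuming the second tuple element is the reshard_after_forward flag.
    # On PyTorch >= 2.6 the import is torch.distributed.fsdp.fully_shard instead.
    import torch.nn as nn
    from torch.distributed._composable.fsdp import fully_shard


    def apply_grouping_plan(model: nn.Module, group_plan: list[tuple[str, bool]]) -> None:
        for module_path, reshard_after_forward in group_plan:
            # Wrap each listed submodule as its own FSDP unit; the flag decides
            # whether its parameters are resharded again after the forward pass.
            submodule = model.get_submodule(module_path)
            fully_shard(submodule, reshard_after_forward=reshard_after_forward)
        # Finally shard whatever parameters remain at the top level.
        fully_shard(model)

Under that reading, flipping the flags from True to False keeps each wrapped block's gathered parameters between forward and backward, avoiding a second all-gather per block at the cost of higher peak memory.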