diff --git a/bytelatent/train.py b/bytelatent/train.py
index 9b35e58..f54b393 100644
--- a/bytelatent/train.py
+++ b/bytelatent/train.py
@@ -131,6 +131,9 @@ def validate_train_args(args: TrainArgs, output_size: int):
     if args.model is not None:
         logger.info(f"Setting model output size to {args.model.vocab_size}")
         args.model.vocab_size = output_size
+        assert (
+            args.model.max_encoder_seq_length == args.data.max_encoder_seq_length
+        ), "max_encoder_seq_length for model and data should match"
 
     if args.entropy_model is not None:
         logger.info(f"Setting model output size to {args.entropy_model.vocab_size}")
@@ -611,7 +614,7 @@ def train(args: TrainArgs):
                interval_total_tok_loss_across_gpus = dist_sum(
                    interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16
                ).item()
-                interval_total_n_bytes_per_gpu = n_bytes
+                interval_total_n_bytes_per_gpu = n_bytes.item()
                interval_total_n_bytes_across_gpus = dist_sum(
                    n_bytes, reduce_dtype=torch.bfloat16
                ).item()
diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py
index ad8affa..b65e502 100644
--- a/bytelatent/transformer.py
+++ b/bytelatent/transformer.py
@@ -146,16 +146,16 @@ def build_fsdp_grouping_plan(model_args: LMTransformerArgs):
         group_plan.append(("output", True))
     else:
         for i in range(model_args.n_layers_local_encoder):
-            group_plan.append((f"local_encoder.layers.{i}", True))
-            group_plan.append((f"local_encoder.cross_attn_layers.{i}", True))
+            group_plan.append((f"local_encoder.layers.{i}", False))
+            group_plan.append((f"local_encoder.cross_attn_layers.{i}", False))
         for i in range(model_args.n_layers_local_decoder):
-            group_plan.append((f"local_decoder.layers.{i}", True))
-            group_plan.append((f"local_decoder.cross_attn_layers.{i}", True))
+            group_plan.append((f"local_decoder.layers.{i}", False))
+            group_plan.append((f"local_decoder.cross_attn_layers.{i}", False))
         for i in range(model_args.n_layers_global):
-            group_plan.append((f"global_transformer.layers.{i}", True))
+            group_plan.append((f"global_transformer.layers.{i}", False))
         for i in range(len(model_args.encoder_hash_byte_group_size)):
-            group_plan.append((f"encoder_hash_tok_embedding.{i}", True))
+            group_plan.append((f"encoder_hash_tok_embedding.{i}", False))
 
     return group_plan
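
The booleans flipped in build_fsdp_grouping_plan read like per-module reshard-after-forward flags. The sketch below is a minimal illustration, assuming the plan's (module_path, flag) tuples are consumed via FSDP2's fully_shard; the function name parallelize_with_plan and the fsdp_config argument are illustrative, not this repo's actual API.

    # Hypothetical sketch, not part of this diff: consuming a grouping plan with FSDP2.
    from torch.distributed._composable.fsdp import fully_shard

    def parallelize_with_plan(model, group_plan, fsdp_config):
        for module_path, reshard_after_forward in group_plan:
            submodule = model.get_submodule(module_path)
            # reshard_after_forward=False keeps the gathered parameters resident between
            # forward and backward, skipping a second all-gather at the cost of memory.
            fully_shard(submodule, **fsdp_config, reshard_after_forward=reshard_after_forward)
        # Shard whatever parameters remain at the root module.
        fully_shard(model, **fsdp_config)
        return model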