diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py
index 3c4a3f8..298d3d4 100644
--- a/bytelatent/distributed.py
+++ b/bytelatent/distributed.py
@@ -463,7 +463,13 @@ def parallelize_model(
         raise ValueError(f"Invalid fsdp_type: {distributed_args.fsdp_type}")
 
     if distributed_args.selective_activation_checkpointing:
-        for module in [model.global_transformer, model.local_encoder, model.local_decoder]:
+        # only works for blt models
+        # assuming that entropy models will not use checkpointing
+        for module in [
+            model.global_transformer,
+            model.local_encoder,
+            model.local_decoder,
+        ]:
             for i in range(len(module.layers)):
                 module.layers[i] = checkpoint_wrapper(
                     module.layers[i],
diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py
index f088768..d6eab14 100644
--- a/bytelatent/model/local_models.py
+++ b/bytelatent/model/local_models.py
@@ -179,7 +179,7 @@ class LocalModelBase(nn.Module):
             )
 
         if self.patch_embedding_projection is not None:
-            patch_emb_std = self.dim_patch_emb ** (-0.5)
+            patch_emb_std = self.dim_patch_emb ** (-0.5)
             nn.init.trunc_normal_(
                 self.patch_embedding_projection.weight,
                 mean=0.0,