mirror of
https://github.com/facebookresearch/blt.git
synced 2025-02-23 05:22:16 +00:00
comment + black
This commit is contained in:
parent
30f82211c4
commit
ba922695b3
|
@ -463,7 +463,13 @@ def parallelize_model(
|
||||||
raise ValueError(f"Invalid fsdp_type: {distributed_args.fsdp_type}")
|
raise ValueError(f"Invalid fsdp_type: {distributed_args.fsdp_type}")
|
||||||
|
|
||||||
if distributed_args.selective_activation_checkpointing:
|
if distributed_args.selective_activation_checkpointing:
|
||||||
for module in [model.global_transformer, model.local_encoder, model.local_decoder]:
|
# only works for blt models
|
||||||
|
# assuming that entropy models will not use checkpointing
|
||||||
|
for module in [
|
||||||
|
model.global_transformer,
|
||||||
|
model.local_encoder,
|
||||||
|
model.local_decoder,
|
||||||
|
]:
|
||||||
for i in range(len(module.layers)):
|
for i in range(len(module.layers)):
|
||||||
module.layers[i] = checkpoint_wrapper(
|
module.layers[i] = checkpoint_wrapper(
|
||||||
module.layers[i],
|
module.layers[i],
|
||||||
|
|
|
@ -179,7 +179,7 @@ class LocalModelBase(nn.Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.patch_embedding_projection is not None:
|
if self.patch_embedding_projection is not None:
|
||||||
patch_emb_std = self.dim_patch_emb ** (-0.5)
|
patch_emb_std = self.dim_patch_emb ** (-0.5)
|
||||||
nn.init.trunc_normal_(
|
nn.init.trunc_normal_(
|
||||||
self.patch_embedding_projection.weight,
|
self.patch_embedding_projection.weight,
|
||||||
mean=0.0,
|
mean=0.0,
|
||||||
|
|
Loading…
Reference in a new issue