mirror of
https://github.com/facebookresearch/blt.git
synced 2025-02-22 13:02:14 +00:00
fix save and reload model state (#49)
Co-authored-by: Srini Iyer <sviyer@meta.com>
This commit is contained in:
parent
fe45f69fbf
commit
22c7fe1d1c
|
@ -74,12 +74,10 @@ class LocalModelBase(nn.Module):
|
|||
|
||||
self.boe_id = BOE_ID
|
||||
|
||||
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.layers = nn.ModuleList(
|
||||
[TransformerBlock(args) for _ in range(args.n_layers)]
|
||||
)
|
||||
|
||||
self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)
|
||||
if not self.use_rope:
|
||||
self.pos_embeddings = nn.Embedding(args.max_length, args.dim)
|
||||
else:
|
||||
|
@ -131,16 +129,18 @@ class LocalModelBase(nn.Module):
|
|||
|
||||
def init_weights(self, init_std=None):
|
||||
self.rope.reset_parameters()
|
||||
self.norm.reset_parameters()
|
||||
if hasattr(self, "norm"):
|
||||
self.norm.reset_parameters()
|
||||
|
||||
init_std = init_std or (self.dim ** (-0.5))
|
||||
nn.init.trunc_normal_(
|
||||
self.tok_embeddings.weight,
|
||||
mean=0.0,
|
||||
std=init_std,
|
||||
a=-3 * init_std,
|
||||
b=3 * init_std,
|
||||
)
|
||||
if hasattr(self, "tok_embeddings"):
|
||||
nn.init.trunc_normal_(
|
||||
self.tok_embeddings.weight,
|
||||
mean=0.0,
|
||||
std=init_std,
|
||||
a=-3 * init_std,
|
||||
b=3 * init_std,
|
||||
)
|
||||
if self.pos_embeddings is not None:
|
||||
nn.init.trunc_normal_(
|
||||
self.pos_embeddings.weight,
|
||||
|
@ -212,6 +212,8 @@ class LocalEncoder(LocalModelBase):
|
|||
self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
|
||||
self.cross_attn_nheads = args.cross_attn_nheads
|
||||
|
||||
self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)
|
||||
|
||||
if self.cross_attn_encoder:
|
||||
self.cross_attn_layers = torch.nn.ModuleList()
|
||||
layers_to_add = args.n_layers if self.cross_attn_all_layers_encoder else 1
|
||||
|
@ -314,6 +316,8 @@ class LocalDecoder(LocalModelBase):
|
|||
self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
|
||||
self.cross_attn_nheads = args.cross_attn_nheads
|
||||
|
||||
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
|
||||
if self.cross_attn_decoder:
|
||||
self.cross_attn_layers = torch.nn.ModuleList()
|
||||
layers_to_add = args.n_layers if self.cross_attn_all_layers_decoder else 1
|
||||
|
|
Loading…
Reference in a new issue