Initialize rope embeddings properly for the entropy model ()

Co-authored-by: Srini Iyer <sviyer@meta.com>
Srinivasan Iyer authored on 2025-02-25 15:35:25 -08:00; committed by GitHub
parent aeb95f12a1
commit 0da051f4f9
2 changed files with 7 additions and 8 deletions
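
This change moves the rope-embedding reset out of an overridable `reset_parameters()` hook on `BaseTransformer` and calls it directly from `init_weights()` (first hunk below), with `LMTransformer` updated to match (second file). The sketch below uses hypothetical stub classes, not the real bytelatent modules, to show the failure mode the hook pattern invites: once a subclass overrides `reset_parameters()` without chaining to `super()`, the base class's rope reset silently never runs.

```python
# Hypothetical stubs illustrating the hook-override failure mode; these are not
# the bytelatent classes, and the real bug path may differ in its details.
class Rope:
    def __init__(self):
        self.initialized = False

    def reset_parameters(self):
        self.initialized = True


class Base:
    def __init__(self):
        self.rope_embeddings = Rope()

    # Old pattern: rope reset hidden behind an overridable hook.
    def reset_parameters(self):
        self.rope_embeddings.reset_parameters()

    def init_weights(self):
        self.reset_parameters()  # dynamic dispatch picks the subclass override


class EntropyLM(Base):
    def reset_parameters(self):
        pass  # never chains to super().reset_parameters()


model = EntropyLM()
model.init_weights()
print(model.rope_embeddings.initialized)  # False: the rope was never reset
```

Calling `self.rope_embeddings.reset_parameters()` directly from `init_weights()` removes that dependence on every override remembering to chain up.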


@@ -617,12 +617,8 @@ class BaseTransformer(nn.Module, SequenceModelWithOutput):
             h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
         return h
 
-    def reset_parameters(self):
-        # Either use fixed base std or sqrt model dim
-        self.rope_embeddings.reset_parameters()
-
     def init_weights(self):
-        self.reset_parameters()
+        self.rope_embeddings.reset_parameters()
         for depth, layer in enumerate(self.layers):
             factor = {
                 InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
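
With `BaseTransformer.reset_parameters` removed, any subclass still chaining to it would now raise, because `torch.nn.Module` itself defines no `reset_parameters()`; that is presumably why the `LMTransformer` diff below also drops its `super().reset_parameters()` call. A minimal repro with hypothetical stand-in classes:

```python
import torch.nn as nn


class Base(nn.Module):  # stand-in for BaseTransformer after this commit
    pass                # reset_parameters() is no longer defined here


class Child(Base):      # stand-in for a subclass that was not updated
    def reset_parameters(self):
        super().reset_parameters()  # nn.Module provides no such method


try:
    Child().reset_parameters()
except AttributeError as err:
    print(err)  # 'super' object has no attribute 'reset_parameters'
```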


@@ -116,10 +116,11 @@ class LMTransformer(BaseTransformer):
         return logits
 
     def reset_parameters(self, init_std=None):
-        # Either use fixed base std or sqrt model dim
-        super().reset_parameters()
         init_std = init_std or (self.dim ** (-0.5))
-        self.norm.reset_parameters()
+
+    def init_weights(self):
+        self.reset_parameters()
+        init_std = self.dim ** (-0.5)
         nn.init.trunc_normal_(
             self.tok_embeddings.weight,
             mean=0.0,
@@ -127,6 +128,8 @@ class LMTransformer(BaseTransformer):
             a=-3 * init_std,
             b=3 * init_std,
         )
+
+        super().init_weights()
         if not self.weight_tying:
             nn.init.trunc_normal_(
                 self.output.weight,
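
Taken together, the entropy model's `LMTransformer.init_weights()` now runs its own parameter resets and token-embedding init, then hands off to `super().init_weights()`, which is where the rope embeddings and the per-layer weights get reset. The trace below uses hypothetical stubs (not the bytelatent classes, which carry far more detail) just to show the resulting call order:

```python
# Hypothetical stubs tracing the post-commit initialization order; the real
# classes live in the bytelatent package and do actual tensor work.
calls = []


class Rope:
    def reset_parameters(self):
        calls.append("rope_embeddings.reset_parameters")


class BaseTransformer:
    def __init__(self):
        self.rope_embeddings = Rope()
        self.layers = [object(), object()]

    def init_weights(self):
        self.rope_embeddings.reset_parameters()  # direct call, per the first file
        for depth, _layer in enumerate(self.layers):
            calls.append(f"layers[{depth}].init_weights")


class LMTransformer(BaseTransformer):
    def reset_parameters(self, init_std=None):
        calls.append("LMTransformer.reset_parameters")

    def init_weights(self):
        self.reset_parameters()
        calls.append("trunc_normal_(tok_embeddings)")  # nn.init.trunc_normal_ in the real code
        super().init_weights()                         # rope + layer init happen here
        calls.append("trunc_normal_(output)")          # skipped when weights are tied


LMTransformer().init_weights()
print(calls)
# ['LMTransformer.reset_parameters', 'trunc_normal_(tok_embeddings)',
#  'rope_embeddings.reset_parameters', 'layers[0].init_weights',
#  'layers[1].init_weights', 'trunc_normal_(output)']
```

Since the rope reset now sits on the `init_weights()` path itself, the entropy model ends up with initialized rope embeddings no matter how `reset_parameters()` is overridden.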