Initialize rope embeddings properly for the entropy model

Srini Iyer 2025-02-25 20:36:20 +00:00
parent aeb95f12a1
commit e668ac0280
2 changed files with 7 additions and 8 deletions


@@ -617,12 +617,8 @@ class BaseTransformer(nn.Module, SequenceModelWithOutput):
             h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
         return h
 
-    def reset_parameters(self):
-        # Either use fixed base std or sqrt model dim
-        self.rope_embeddings.reset_parameters()
-
     def init_weights(self):
-        self.reset_parameters()
+        self.rope_embeddings.reset_parameters()
         for depth, layer in enumerate(self.layers):
             factor = {
                 InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
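
One way the old indirection can go wrong, sketched below with toy classes (Rope, Base, EntropyLM are illustrative names, not the repo's): because init_weights() called the overridable reset_parameters(), a subclass override that does not chain to super().reset_parameters() silently skips the rope reset. Calling self.rope_embeddings.reset_parameters() directly in init_weights(), as the hunk above does, removes that dependency.

# Minimal runnable sketch of the dispatch hazard (toy classes, not the repo's).
class Rope:
    def reset_parameters(self):
        print("rope reset")

class Base:
    def __init__(self):
        self.rope_embeddings = Rope()

    def reset_parameters(self):
        # old behaviour: the rope reset lived behind an overridable method
        self.rope_embeddings.reset_parameters()

    def init_weights(self):
        # dispatches to the subclass override, so the rope reset above never runs
        self.reset_parameters()

class EntropyLM(Base):
    def reset_parameters(self):
        print("norm reset only")  # override that skips super().reset_parameters()

EntropyLM().init_weights()  # prints "norm reset only"; "rope reset" is never printed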


@@ -116,10 +116,11 @@ class LMTransformer(BaseTransformer):
         return logits
 
     def reset_parameters(self, init_std=None):
         # Either use fixed base std or sqrt model dim
-        super().reset_parameters()
-        init_std = init_std or (self.dim ** (-0.5))
         self.norm.reset_parameters()
+
+    def init_weights(self):
+        self.reset_parameters()
+        init_std = self.dim ** (-0.5)
         nn.init.trunc_normal_(
             self.tok_embeddings.weight,
             mean=0.0,
@@ -127,6 +128,8 @@ class LMTransformer(BaseTransformer):
             a=-3 * init_std,
             b=3 * init_std,
         )
+        super().init_weights()
+
         if not self.weight_tying:
             nn.init.trunc_normal_(
                 self.output.weight,
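
Putting the two files together, the init flow for the entropy model after this commit reads roughly as in the runnable sketch below. ToyBase, ToyLM, rope_scale, and the sizes are stand-ins (the real classes carry rope_embeddings, transformer layers, and an output head); the method names and the truncated-normal scheme with std = dim ** -0.5 clipped at three standard deviations come from the hunks above.

# Toy mini-version of the post-commit init order; rope-side init now happens
# unconditionally inside super().init_weights().
import torch
from torch import nn

class ToyBase(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.rope_scale = nn.Parameter(torch.empty(dim))  # stand-in for rope buffers

    def init_weights(self):
        nn.init.ones_(self.rope_scale)  # stand-in for rope_embeddings.reset_parameters()

class ToyLM(ToyBase):
    def __init__(self, dim: int, vocab: int):
        super().__init__(dim)
        self.tok_embeddings = nn.Embedding(vocab, dim)
        self.norm = nn.LayerNorm(dim)

    def reset_parameters(self, init_std=None):
        self.norm.reset_parameters()  # norm only, as in the new reset_parameters

    def init_weights(self):
        self.reset_parameters()
        init_std = self.dim ** (-0.5)
        nn.init.trunc_normal_(
            self.tok_embeddings.weight, mean=0.0, std=init_std,
            a=-3 * init_std, b=3 * init_std,
        )
        super().init_weights()  # triggers the rope-side reset

model = ToyLM(dim=64, vocab=260)
model.init_weights()
assert torch.all(model.rope_scale == 1.0)  # rope-side init ran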