Mirror of https://github.com/facebookresearch/blt.git (synced 2025-04-09 03:09:10 +00:00)
Initialize rope embeddings properly for the entropy model
parent aeb95f12a1
commit e668ac0280

2 changed files with 7 additions and 8 deletions

bytelatent
@@ -617,12 +617,8 @@ class BaseTransformer(nn.Module, SequenceModelWithOutput):
            h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
        return h

    def reset_parameters(self):
        # Either use fixed base std or sqrt model dim
        self.rope_embeddings.reset_parameters()

    def init_weights(self):
        self.reset_parameters()
        self.rope_embeddings.reset_parameters()
        for depth, layer in enumerate(self.layers):
            factor = {
                InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
@@ -116,10 +116,11 @@ class LMTransformer(BaseTransformer):
        return logits

    def reset_parameters(self, init_std=None):
        # Either use fixed base std or sqrt model dim
        super().reset_parameters()
        init_std = init_std or (self.dim ** (-0.5))
        self.norm.reset_parameters()

    def init_weights(self):
        self.reset_parameters()
        init_std = self.dim ** (-0.5)
        nn.init.trunc_normal_(
            self.tok_embeddings.weight,
            mean=0.0,

@@ -127,6 +128,8 @@ class LMTransformer(BaseTransformer):
            a=-3 * init_std,
            b=3 * init_std,
        )
        super().init_weights()

        if not self.weight_tying:
            nn.init.trunc_normal_(
                self.output.weight,
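For orientation, here is a minimal, self-contained sketch of the initialization pattern this diff touches: init_weights() recomputes the RoPE frequency table and then applies truncated-normal initialization with std = dim ** -0.5 to the embedding and output weights. The class names, the RotaryEmbedding internals, and the use of nn.LayerNorm below are illustrative assumptions, not the repository's actual code.

# Minimal sketch (assumed names, not the repository's code): resetting the RoPE
# buffers inside init_weights() so a freshly constructed model never runs with
# stale rotary frequencies.
import torch
import torch.nn as nn


class RotaryEmbedding(nn.Module):
    def __init__(self, dim: int, max_seq_len: int, theta: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.theta = theta
        self.register_buffer(
            "freqs_cis",
            torch.empty(max_seq_len, dim // 2, dtype=torch.complex64),
            persistent=False,
        )
        self.reset_parameters()

    def reset_parameters(self):
        # Recompute the complex rotation table from scratch.
        inv_freq = 1.0 / (
            self.theta ** (torch.arange(0, self.dim, 2).float() / self.dim)
        )
        angles = torch.outer(torch.arange(self.max_seq_len).float(), inv_freq)
        self.freqs_cis.copy_(torch.polar(torch.ones_like(angles), angles))


class TinyLM(nn.Module):
    # forward() is omitted; only the initialization path matters here.
    def __init__(self, vocab_size: int, dim: int, max_seq_len: int, weight_tying: bool = False):
        super().__init__()
        self.dim = dim
        self.weight_tying = weight_tying
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.rope_embeddings = RotaryEmbedding(dim, max_seq_len)
        self.norm = nn.LayerNorm(dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)
        if weight_tying:
            self.output.weight = self.tok_embeddings.weight

    def init_weights(self):
        # Reset RoPE frequencies together with the learned weights, mirroring the
        # intent of this commit for the entropy model's transformer.
        self.rope_embeddings.reset_parameters()
        self.norm.reset_parameters()
        init_std = self.dim ** (-0.5)
        nn.init.trunc_normal_(
            self.tok_embeddings.weight, mean=0.0, std=init_std,
            a=-3 * init_std, b=3 * init_std,
        )
        if not self.weight_tying:
            nn.init.trunc_normal_(
                self.output.weight, mean=0.0, std=init_std,
                a=-3 * init_std, b=3 * init_std,
            )


model = TinyLM(vocab_size=260, dim=64, max_seq_len=128)
model.init_weights()

The ordering here follows the diff above; the essential point is simply that rope_embeddings.reset_parameters() runs whenever init_weights() does.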