mirror of
https://github.com/facebookresearch/blt.git
synced 2025-04-21 00:59:09 +00:00
Initialize rope embeddings properly for the entropy model
This commit is contained in:
parent
aeb95f12a1
commit
e668ac0280
2 changed files with 7 additions and 8 deletions
bytelatent
|
@ -617,12 +617,8 @@ class BaseTransformer(nn.Module, SequenceModelWithOutput):
|
||||||
h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
|
h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
|
||||||
return h
|
return h
|
||||||
|
|
||||||
def reset_parameters(self):
|
|
||||||
# Either use fixed base std or sqrt model dim
|
|
||||||
self.rope_embeddings.reset_parameters()
|
|
||||||
|
|
||||||
def init_weights(self):
|
def init_weights(self):
|
||||||
self.reset_parameters()
|
self.rope_embeddings.reset_parameters()
|
||||||
for depth, layer in enumerate(self.layers):
|
for depth, layer in enumerate(self.layers):
|
||||||
factor = {
|
factor = {
|
||||||
InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
|
InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
|
||||||
|
|
|
@ -116,10 +116,11 @@ class LMTransformer(BaseTransformer):
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def reset_parameters(self, init_std=None):
|
def reset_parameters(self, init_std=None):
|
||||||
# Either use fixed base std or sqrt model dim
|
|
||||||
super().reset_parameters()
|
|
||||||
init_std = init_std or (self.dim ** (-0.5))
|
|
||||||
self.norm.reset_parameters()
|
self.norm.reset_parameters()
|
||||||
|
|
||||||
|
def init_weights(self):
|
||||||
|
self.reset_parameters()
|
||||||
|
init_std = self.dim ** (-0.5)
|
||||||
nn.init.trunc_normal_(
|
nn.init.trunc_normal_(
|
||||||
self.tok_embeddings.weight,
|
self.tok_embeddings.weight,
|
||||||
mean=0.0,
|
mean=0.0,
|
||||||
|
@ -127,6 +128,8 @@ class LMTransformer(BaseTransformer):
|
||||||
a=-3 * init_std,
|
a=-3 * init_std,
|
||||||
b=3 * init_std,
|
b=3 * init_std,
|
||||||
)
|
)
|
||||||
|
super().init_weights()
|
||||||
|
|
||||||
if not self.weight_tying:
|
if not self.weight_tying:
|
||||||
nn.init.trunc_normal_(
|
nn.init.trunc_normal_(
|
||||||
self.output.weight,
|
self.output.weight,
|
||||||
|
|
Loading…
Add table
Reference in a new issue