From e668ac0280e57429571912467df62689ed68e62b Mon Sep 17 00:00:00 2001
From: Srini Iyer <sviyer@meta.com>
Date: Tue, 25 Feb 2025 20:36:20 +0000
Subject: [PATCH] Initialize rope embeddings properly for the entropy model

---
 bytelatent/base_transformer.py | 6 +-----
 bytelatent/transformer.py      | 9 ++++++---
 2 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py
index 25fbf71..19b1b33 100644
--- a/bytelatent/base_transformer.py
+++ b/bytelatent/base_transformer.py
@@ -617,12 +617,8 @@ class BaseTransformer(nn.Module, SequenceModelWithOutput):
             h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
         return h
 
-    def reset_parameters(self):
-        # Either use fixed base std or sqrt model dim
-        self.rope_embeddings.reset_parameters()
-
     def init_weights(self):
-        self.reset_parameters()
+        self.rope_embeddings.reset_parameters()
         for depth, layer in enumerate(self.layers):
             factor = {
                 InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py
index da03761..906a54b 100644
--- a/bytelatent/transformer.py
+++ b/bytelatent/transformer.py
@@ -116,10 +116,11 @@ class LMTransformer(BaseTransformer):
         return logits
 
     def reset_parameters(self, init_std=None):
-        # Either use fixed base std or sqrt model dim
-        super().reset_parameters()
-        init_std = init_std or (self.dim ** (-0.5))
         self.norm.reset_parameters()
+
+    def init_weights(self):
+        self.reset_parameters()
+        init_std = self.dim ** (-0.5)
         nn.init.trunc_normal_(
             self.tok_embeddings.weight,
             mean=0.0,
@@ -127,6 +128,8 @@ class LMTransformer(BaseTransformer):
             a=-3 * init_std,
             b=3 * init_std,
         )
+        super().init_weights()
+
         if not self.weight_tying:
             nn.init.trunc_normal_(
                 self.output.weight,
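
For reference, below is a minimal, runnable sketch of the initialization flow this patch establishes: reset_parameters() now only handles a module's own parameters, while init_weights() is the single entry point that also resets the rope embeddings, initializes the token embeddings, and delegates to the base class. The classes here (BaseTransformerSketch, LMTransformerSketch, RopeStub) and the concrete std values for the per-layer init are illustrative stand-ins, not the actual bytelatent code; the std= argument for the token embeddings falls between the two hunks above and is assumed.

import torch.nn as nn


class RopeStub(nn.Module):
    def reset_parameters(self):
        pass  # the real rotary-embedding module recomputes its frequency buffer here


class BaseTransformerSketch(nn.Module):
    """Stand-in for BaseTransformer: rope init now happens inside init_weights()."""

    def __init__(self, dim: int = 16, n_layers: int = 2):
        super().__init__()
        self.rope_embeddings = RopeStub()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))

    def init_weights(self):
        # After the patch, rope embeddings are reset here directly instead of
        # via a separate BaseTransformer.reset_parameters().
        self.rope_embeddings.reset_parameters()
        for depth, layer in enumerate(self.layers):
            factor = (2 * (depth + 1)) ** 0.5  # InitStdFactor.CURRENT_DEPTH, as in the hunk
            nn.init.trunc_normal_(layer.weight, std=0.02 / factor)  # illustrative std only


class LMTransformerSketch(BaseTransformerSketch):
    """Stand-in for LMTransformer: reset_parameters() only touches the norm now."""

    def __init__(self, vocab: int = 100, dim: int = 16):
        super().__init__(dim=dim)
        self.dim = dim
        self.weight_tying = False
        self.norm = nn.LayerNorm(dim)
        self.tok_embeddings = nn.Embedding(vocab, dim)
        self.output = nn.Linear(dim, vocab, bias=False)

    def reset_parameters(self, init_std=None):
        self.norm.reset_parameters()

    def init_weights(self):
        self.reset_parameters()
        init_std = self.dim ** (-0.5)
        nn.init.trunc_normal_(
            self.tok_embeddings.weight, mean=0.0,
            std=init_std,  # assumed; this line sits between the two hunks in the patch
            a=-3 * init_std, b=3 * init_std,
        )
        # Key change: delegate to the base class so rope embeddings and layers
        # are initialized as well.
        super().init_weights()
        if not self.weight_tying:
            nn.init.trunc_normal_(
                self.output.weight, mean=0.0, std=init_std,
                a=-3 * init_std, b=3 * init_std,
            )


model = LMTransformerSketch()
model.init_weights()  # one call now covers norm, token embeddings, rope, and layers

The point of the restructuring is that the entropy model's rope embeddings were previously reset only through BaseTransformer.reset_parameters(), which LMTransformer.reset_parameters() overrode; routing everything through init_weights() and a super().init_weights() call keeps the rope initialization from being skipped.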