From e668ac0280e57429571912467df62689ed68e62b Mon Sep 17 00:00:00 2001
From: Srini Iyer <sviyer@meta.com>
Date: Tue, 25 Feb 2025 20:36:20 +0000
Subject: [PATCH] Initialize rope embeddings properly for the entropy model

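In the old layout the rope reset lived in BaseTransformer.reset_parameters,
which init_weights only reached through a dynamically dispatched
self.reset_parameters() call, so any subclass that overrode
reset_parameters without chaining up would silently skip the rope init.

Move the rope reset directly into BaseTransformer.init_weights, shrink
LMTransformer.reset_parameters down to the norm reset, and give
LMTransformer its own init_weights that resets the norm, initializes the
token embeddings, chains to the base class, and finally initializes the
untied output projection. The rope embeddings are now initialized every
time init_weights runs, which is what the entropy model needs.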
---
 bytelatent/base_transformer.py | 6 +-----
 bytelatent/transformer.py      | 9 ++++++---
 2 files changed, 7 insertions(+), 8 deletions(-)
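Reviewer note (this sits below the "---" scratch line, so git am ignores
it): the functional change is which method owns the rope reset. A toy,
runnable sketch of the resulting call order, using hypothetical stand-in
classes rather than the real bytelatent modules:

    # Stand-ins for BaseTransformer / LMTransformer; the prints mirror the
    # calls made in the diff below, everything else is illustrative only.
    class Base:
        def init_weights(self):
            # rope reset now happens unconditionally during init_weights
            print("rope_embeddings.reset_parameters()")
            print("per-layer init, std scaled by a depth factor")

    class LM(Base):
        def reset_parameters(self, init_std=None):
            # all that is left here after the patch
            print("norm.reset_parameters()")

        def init_weights(self):
            self.reset_parameters()
            print("trunc_normal_(tok_embeddings, std=dim ** -0.5)")
            super().init_weights()  # rope + layers
            print("trunc_normal_(output)  # skipped when weight_tying")

    LM().init_weights()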

diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py
index 25fbf71..19b1b33 100644
--- a/bytelatent/base_transformer.py
+++ b/bytelatent/base_transformer.py
@@ -617,12 +617,8 @@ class BaseTransformer(nn.Module, SequenceModelWithOutput):
             h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
         return h
 
-    def reset_parameters(self):
-        # Either use fixed base std or sqrt model dim
-        self.rope_embeddings.reset_parameters()
-
     def init_weights(self):
-        self.reset_parameters()
+        self.rope_embeddings.reset_parameters()
         for depth, layer in enumerate(self.layers):
             factor = {
                 InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py
index da03761..906a54b 100644
--- a/bytelatent/transformer.py
+++ b/bytelatent/transformer.py
@@ -116,10 +116,11 @@ class LMTransformer(BaseTransformer):
             return logits
 
     def reset_parameters(self, init_std=None):
-        # Either use fixed base std or sqrt model dim
-        super().reset_parameters()
-        init_std = init_std or (self.dim ** (-0.5))
         self.norm.reset_parameters()
+
+    def init_weights(self):
+        self.reset_parameters()
+        init_std = self.dim ** (-0.5)
         nn.init.trunc_normal_(
             self.tok_embeddings.weight,
             mean=0.0,
@@ -127,6 +128,8 @@ class LMTransformer(BaseTransformer):
             a=-3 * init_std,
             b=3 * init_std,
         )
+        super().init_weights()
+
         if not self.weight_tying:
             nn.init.trunc_normal_(
                 self.output.weight,