diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py
index 44ff5e9..328b716 100644
--- a/bytelatent/data/patcher.py
+++ b/bytelatent/data/patcher.py
@@ -91,7 +91,8 @@ def calculate_entropies(
             split = split.reshape(-1, max_length)
             if device is not None:
                 split = split.to(device)
-            assert torch.all(split >= 0) and torch.all(split < 260)
+            # Disabled: token ids may legitimately exceed 259 with a larger vocab.
+            # assert torch.all(split >= 0) and torch.all(split < 260)
             pred = entropy_model(split)
             pred = pred.reshape(-1, pred.shape[-1])[
                 : split.numel() - pad_size, :
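The hardcoded bound of 260 (256 byte values plus four special tokens) no longer holds once the entropy model is paired with a larger vocabulary, which is presumably why the check is disabled rather than fixed. If a guard is still wanted, a dynamic variant could derive the bound from the model itself; the sketch below is hypothetical and assumes a Llama-style `tok_embeddings` embedding table on `entropy_model`:

```python
# Hypothetical replacement for the hardcoded assert: bound the check by the
# embedding table's actual size instead of the literal 260.
vocab_size = entropy_model.tok_embeddings.num_embeddings
if split.min() < 0 or split.max() >= vocab_size:
    raise ValueError(
        f"token ids outside [0, {vocab_size}): "
        f"min={int(split.min())}, max={int(split.max())}"
    )
```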
@@ -103,7 +103,7 @@ def calculate_entropies(
         concat_entropies = torch.cat(entropies, dim=0)
         concat_entropies = concat_entropies.reshape(tokens.shape)
         concat_preds = torch.cat(preds, dim=0)
-        concat_preds = concat_preds.reshape(tokens.shape[0], tokens.shape[1], -1)
+        concat_preds = concat_preds.reshape(tokens.shape[0], -1)
     return concat_entropies, concat_preds
 
 
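The entropies are reshaped with `tokens.shape` directly, which works for any input rank, but the old preds reshape indexed `tokens.shape[1]` and therefore assumed a 2D `(batch, seq_len)` input; the patched line keeps only the leading dimension and folds everything else into the second, presumably so 1D token tensors no longer raise an `IndexError`. Note that the output shape changes for 2D inputs, as this illustrative sketch shows:

```python
import torch

# Illustrative shapes only: bs x seqlen tokens, V logits per token.
bs, seqlen, V = 2, 4, 260
concat_preds = torch.randn(bs * seqlen, V)

old = concat_preds.reshape(bs, seqlen, -1)  # (2, 4, 260)
new = concat_preds.reshape(bs, -1)          # (2, 1040): seq and vocab folded
print(old.shape, new.shape)
```

Any downstream consumer of `concat_preds` has to expect the folded layout for 2D inputs, or reshape it back using the known vocab size.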
diff --git a/bytelatent/entropy_model.py b/bytelatent/entropy_model.py
index 30754ee..0e11a60 100644
--- a/bytelatent/entropy_model.py
+++ b/bytelatent/entropy_model.py
@@ -15,7 +15,7 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
         reloaded = json.loads(fr.read())
 
     torch.set_default_dtype(torch.bfloat16)
-    model_params = reloaded["model"]
+    model_params = reloaded["entropy_model"]
     logger.warning(
         "Update checkpoint to load attn and sliding window args from checkpoint"
     )
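The config JSON evidently keys the entropy model's hyperparameters under `"entropy_model"` rather than the generic `"model"`; with the old key the loader would raise a `KeyError` (or read the wrong sub-config) on current checkpoints. A quick inspection snippet, assuming the `params.json` filename used by the surrounding loader:

```python
import json
import os

# Hypothetical sanity check: list the top-level keys of the checkpoint config
# before picking one, rather than assuming "model".
with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
    reloaded = json.loads(fr.read())
print(list(reloaded.keys()))  # "entropy_model" expected among them
```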
@@ -24,7 +24,7 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
             dim=model_params["dim"],
             n_layers=model_params["n_layers"],
             n_heads=model_params["n_heads"],
-            max_seqlen=model_params["max_length"],
+            max_seqlen=model_params["max_seqlen"],
             ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
             vocab_size=model_params["vocab_size"],
             attn_bias_type="local_block_causal",
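Same class of fix: the stored key is `max_seqlen`, matching the constructor argument, not `max_length`. If older checkpoints that still use the previous spelling need to keep loading, a tolerant lookup is a small addition (the fallback default here is illustrative, not from the source):

```python
# Hypothetical backward-compatible lookup for both key spellings.
max_seqlen = model_params.get("max_seqlen", model_params.get("max_length", 8192))
```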
@@ -34,7 +34,7 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
     )
 
     entropy_model.load_state_dict(
-        torch.load(state_dict_path, map_location=device), strict=False
+        torch.load(state_dict_path, map_location=device)["model"], strict=False
     )
     entropy_model.to(device)
     entropy_model = entropy_model.eval()
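The checkpoint at `state_dict_path` is evidently a wrapper dict with the weights nested under `"model"` (training checkpoints commonly bundle optimizer state and step counters alongside), not a bare state dict. Notably, because the call passes `strict=False`, the old code would not have errored: every key of the wrapper dict is "unexpected" and every model key is "missing", both of which `strict=False` tolerates, so the model would silently keep its random initialization. A sketch of the presumed layout:

```python
import torch

# Presumed checkpoint structure; key names besides "model" are illustrative.
ckpt = torch.load(state_dict_path, map_location="cpu")
print(list(ckpt.keys()))    # e.g. ["model", "optimizer", "step", ...]
state_dict = ckpt["model"]  # the actual parameter tensors
```

Given how much `strict=False` can hide, it may also be worth logging the `missing_keys`/`unexpected_keys` named tuple that `load_state_dict` returns.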