diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py
index 7b76b9e..d44676d 100644
--- a/bytelatent/base_transformer.py
+++ b/bytelatent/base_transformer.py
@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import logging
 import os
 from enum import Enum
 from typing import Optional, Tuple, Union
@@ -14,15 +15,16 @@ from torch.nn.attention.flex_attention import (
 )
 from xformers.ops import AttentionBias, fmha
 
-from bytelatent import probe
 from bytelatent.tokenizers.constants import EOS_ID
 
+logger = logging.getLogger()
+
 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm
 
     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm
 
 if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
diff --git a/bytelatent/model/latent_transformer.py b/bytelatent/model/latent_transformer.py
index 95b6d8b..a6cabdc 100644
--- a/bytelatent/model/latent_transformer.py
+++ b/bytelatent/model/latent_transformer.py
@@ -17,16 +17,16 @@ from bytelatent.base_transformer import (
 )
 from bytelatent.model.utils import create_causal_mask
 
+logger = logging.getLogger()
+
 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm
 
     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm
 
-logger = logging.getLogger()
-
 
 class CrossAttention(nn.Module):
     """
diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py
index 353c878..09a5a19 100644
--- a/bytelatent/model/local_models.py
+++ b/bytelatent/model/local_models.py
@@ -6,7 +6,7 @@ from typing import Any, List, Optional, Tuple, Union
 import torch
 import torch.nn
 import torch.nn as nn
-from pydantic import BaseModel, ConfigDict
+from pydantic import ConfigDict
 from torch.nn import functional as F
 from torch.nn.attention.flex_attention import BlockMask
 from xformers.ops import AttentionBias
@@ -21,16 +21,16 @@ from bytelatent.model.latent_transformer import CrossAttention
 from bytelatent.model.utils import create_causal_mask, downsample
 from bytelatent.tokenizers.blt_tokenizer import BOE_ID
 
+logger = logging.getLogger()
+
 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm
 
     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm
 
-logger = logging.getLogger()
-
 
 class LocalModelArgs(BaseTransformerArgs):
     model_config = ConfigDict(extra="forbid")
diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py
index 2e45ea5..da03761 100644
--- a/bytelatent/transformer.py
+++ b/bytelatent/transformer.py
@@ -1,6 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 
-from dataclasses import dataclass
+import logging
 from typing import Optional, Tuple, Union
 
 import torch
@@ -14,7 +14,7 @@ from torch.distributed.tensor.parallel import (
     parallelize_module,
 )
 from torch.nn.attention.flex_attention import BlockMask, create_block_mask
-from xformers.ops import AttentionBias, fmha
+from xformers.ops import AttentionBias
 
 from bytelatent.base_transformer import (
     BaseTransformer,
@@ -23,12 +23,14 @@ from bytelatent.base_transformer import (
 )
 from bytelatent.model.utils import create_causal_mask
 
+logger = logging.getLogger()
+
 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm
 
     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm