From f3e8125f7407581e841dd7bec9ab4c12138f8505 Mon Sep 17 00:00:00 2001
From: Srinivasan Iyer
Date: Fri, 14 Feb 2025 11:22:03 -0800
Subject: [PATCH] using apex rmsnorm (#57)

* using apex rmsnorm

* added message for missing apex

* black

* missed a print

---------

Co-authored-by: Srini Iyer
---
 bytelatent/base_transformer.py         | 39 ++++++--------------------
 bytelatent/model/latent_transformer.py | 11 ++++++--
 bytelatent/model/local_models.py       |  9 +++++-
 bytelatent/transformer.py              |  9 +++++-
 4 files changed, 33 insertions(+), 35 deletions(-)

diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py
index 217224f..7b76b9e 100644
--- a/bytelatent/base_transformer.py
+++ b/bytelatent/base_transformer.py
@@ -17,6 +17,14 @@ from xformers.ops import AttentionBias, fmha
 from bytelatent import probe
 from bytelatent.tokenizers.constants import EOS_ID
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+    RMSNorm = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+    print("Apex not found. Using nn.RMSNorm")
+    RMSNorm = nn.RMSNorm
+
 if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
     flex_attention_comp = torch.compile(flex_attention)
 else:
@@ -294,37 +302,6 @@ class RotaryEmbedding(torch.nn.Module):
         return self.freqs_cis[0:seqlen]
 
 
-class RMSNorm(nn.Module):
-    """
-    Initialize the RMSNorm normalization layer.
-
-    Args:
-        dim (int): The dimension of the input tensor.
-        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
-
-    Attributes:
-        eps (float): A small value added to the denominator for numerical stability.
-        weight (nn.Parameter): Learnable scaling parameter.
-
-    """
-
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
-
-    def _norm(self, x: torch.Tensor):
-        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x: torch.Tensor):
-        x = probe.log_stats(x, "resid")
-        output = self._norm(x.float())
-        return (output * self.weight.float()).type_as(x)
-
-    def reset_parameters(self):
-        torch.nn.init.ones_(self.weight)  # type: ignore
-
-
 def _reshape_for_attn_bias(
     attn_bias: AttentionBias | None,
     *tensors: torch.Tensor,
diff --git a/bytelatent/model/latent_transformer.py b/bytelatent/model/latent_transformer.py
index d91f49f..95b6d8b 100644
--- a/bytelatent/model/latent_transformer.py
+++ b/bytelatent/model/latent_transformer.py
@@ -12,12 +12,19 @@ from xformers.ops import AttentionBias
 from bytelatent.base_transformer import (
     BaseTransformer,
     BaseTransformerArgs,
-    RMSNorm,
     flex_attention_comp,
     repeat_kv,
 )
 from bytelatent.model.utils import create_causal_mask
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+    RMSNorm = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+    print("Apex not found. Using nn.RMSNorm")
+    RMSNorm = nn.RMSNorm
+
 logger = logging.getLogger()
@@ -44,7 +51,7 @@ class CrossAttention(nn.Module):
         self.n_kv_heads = n_kv_heads
         self.heads_per_group = self.n_heads // self.n_kv_heads
 
-        self.cross_attn_norm_q = RMSNorm(dim, eps=norm_eps)
+        self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps)
         self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)
 
         self.wq = nn.Linear(
diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py
index d92a1fb..353c878 100644
--- a/bytelatent/model/local_models.py
+++ b/bytelatent/model/local_models.py
@@ -14,7 +14,6 @@ from xformers.ops import AttentionBias
 from bytelatent.base_transformer import (
     BaseTransformerArgs,
     InitStdFactor,
-    RMSNorm,
     RotaryEmbedding,
     TransformerBlock,
 )
@@ -22,6 +21,14 @@ from bytelatent.model.latent_transformer import CrossAttention
 from bytelatent.model.utils import create_causal_mask, downsample
 from bytelatent.tokenizers.blt_tokenizer import BOE_ID
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+    RMSNorm = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+    print("Apex not found. Using nn.RMSNorm")
+    RMSNorm = nn.RMSNorm
+
 logger = logging.getLogger()
diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py
index b65e502..2e45ea5 100644
--- a/bytelatent/transformer.py
+++ b/bytelatent/transformer.py
@@ -19,11 +19,18 @@ from xformers.ops import AttentionBias, fmha
 from bytelatent.base_transformer import (
     BaseTransformer,
     BaseTransformerArgs,
-    RMSNorm,
     cross_entropy,
 )
 from bytelatent.model.utils import create_causal_mask
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+    RMSNorm = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+    print("Apex not found. Using nn.RMSNorm")
+    RMSNorm = nn.RMSNorm
+
 
 def attention_flops_per_token(n_layers, seq_len, dim, causal):
     # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30