From d3bf3a13832b5fcf8b708f29f997128df1643c65 Mon Sep 17 00:00:00 2001
From: Srini Iyer
Date: Thu, 13 Feb 2025 19:46:06 +0000
Subject: [PATCH] using apex rmsnorm

---
 bytelatent/base_transformer.py         | 38 +++++---------------------
 bytelatent/model/latent_transformer.py | 10 +++++--
 bytelatent/model/local_models.py       |  8 +++++-
 bytelatent/transformer.py              |  8 +++++-
 4 files changed, 29 insertions(+), 35 deletions(-)

diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py
index 217224f..1888658 100644
--- a/bytelatent/base_transformer.py
+++ b/bytelatent/base_transformer.py
@@ -17,6 +17,13 @@ from xformers.ops import AttentionBias, fmha
 from bytelatent import probe
 from bytelatent.tokenizers.constants import EOS_ID
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+    RMSNorm = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+    RMSNorm = nn.RMSNorm
+
 if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
     flex_attention_comp = torch.compile(flex_attention)
 else:
@@ -294,37 +301,6 @@ class RotaryEmbedding(torch.nn.Module):
         return self.freqs_cis[0:seqlen]
 
 
-class RMSNorm(nn.Module):
-    """
-    Initialize the RMSNorm normalization layer.
-
-    Args:
-        dim (int): The dimension of the input tensor.
-        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
-
-    Attributes:
-        eps (float): A small value added to the denominator for numerical stability.
-        weight (nn.Parameter): Learnable scaling parameter.
-
-    """
-
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
-
-    def _norm(self, x: torch.Tensor):
-        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x: torch.Tensor):
-        x = probe.log_stats(x, "resid")
-        output = self._norm(x.float())
-        return (output * self.weight.float()).type_as(x)
-
-    def reset_parameters(self):
-        torch.nn.init.ones_(self.weight)  # type: ignore
-
-
 def _reshape_for_attn_bias(
     attn_bias: AttentionBias | None,
     *tensors: torch.Tensor,
diff --git a/bytelatent/model/latent_transformer.py b/bytelatent/model/latent_transformer.py
index d91f49f..b1f72a7 100644
--- a/bytelatent/model/latent_transformer.py
+++ b/bytelatent/model/latent_transformer.py
@@ -12,12 +12,18 @@ from xformers.ops import AttentionBias
 from bytelatent.base_transformer import (
     BaseTransformer,
     BaseTransformerArgs,
-    RMSNorm,
     flex_attention_comp,
     repeat_kv,
 )
 from bytelatent.model.utils import create_causal_mask
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+    RMSNorm = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+    RMSNorm = nn.RMSNorm
+
 logger = logging.getLogger()
 
 
@@ -44,7 +50,7 @@ class CrossAttention(nn.Module):
         self.n_kv_heads = n_kv_heads
         self.heads_per_group = self.n_heads // self.n_kv_heads
 
-        self.cross_attn_norm_q = RMSNorm(dim, eps=norm_eps)
+        self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps)
         self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)
 
         self.wq = nn.Linear(
diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py
index d92a1fb..1d8db07 100644
--- a/bytelatent/model/local_models.py
+++ b/bytelatent/model/local_models.py
@@ -14,7 +14,6 @@ from xformers.ops import AttentionBias
 from bytelatent.base_transformer import (
     BaseTransformerArgs,
     InitStdFactor,
-    RMSNorm,
     RotaryEmbedding,
     TransformerBlock,
 )
@@ -22,6 +21,13 @@ from bytelatent.model.latent_transformer import CrossAttention
 from bytelatent.model.utils import create_causal_mask, downsample
 from bytelatent.tokenizers.blt_tokenizer import BOE_ID
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+    RMSNorm = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+    RMSNorm = nn.RMSNorm
+
 logger = logging.getLogger()
 
 
diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py
index b65e502..cc9e1f3 100644
--- a/bytelatent/transformer.py
+++ b/bytelatent/transformer.py
@@ -19,11 +19,17 @@ from xformers.ops import AttentionBias, fmha
 from bytelatent.base_transformer import (
     BaseTransformer,
     BaseTransformerArgs,
-    RMSNorm,
     cross_entropy,
 )
 from bytelatent.model.utils import create_causal_mask
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+    RMSNorm = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+    RMSNorm = nn.RMSNorm
+
 
 def attention_flops_per_token(n_layers, seq_len, dim, causal):
     # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30
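
For reference, below is a minimal, self-contained sketch (not part of the patch) of the import
fallback the patch repeats in each touched module: prefer apex's FusedRMSNorm when apex is
installed, otherwise fall back to torch.nn.RMSNorm (available in PyTorch 2.4+). The dimension,
eps value, and test tensor are illustrative only.

    import torch
    import torch.nn as nn

    try:
        # apex provides a fused CUDA kernel for RMSNorm
        from apex.normalization.fused_layer_norm import FusedRMSNorm

        RMSNorm = FusedRMSNorm
    except (ImportError, ModuleNotFoundError):
        # pure-PyTorch fallback when apex is not installed
        RMSNorm = nn.RMSNorm

    # Call sites such as CrossAttention construct either class the same way.
    norm = RMSNorm(512, eps=1e-5)        # illustrative dim and eps
    x = torch.randn(2, 16, 512)
    print(norm(x).shape)                 # torch.Size([2, 16, 512])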