using apex rmsnorm (#57)

* using apex rmsnorm

* added message for missing apex

* black

* missed a print

---------

Co-authored-by: Srini Iyer <sviyer@meta.com>
Authored by Srinivasan Iyer on 2025-02-14 11:22:03 -08:00; committed by GitHub
parent c49e25171e
commit f3e8125f74
4 changed files with 33 additions and 35 deletions


@@ -17,6 +17,14 @@ from xformers.ops import AttentionBias, fmha
 from bytelatent import probe
 from bytelatent.tokenizers.constants import EOS_ID
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+    RMSNorm = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+    print("Apex not found. Using nn.RMSNorm")
+    RMSNorm = nn.RMSNorm
+
 if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
     flex_attention_comp = torch.compile(flex_attention)
 else:
@@ -294,37 +302,6 @@ class RotaryEmbedding(torch.nn.Module):
         return self.freqs_cis[0:seqlen]
 
 
-class RMSNorm(nn.Module):
-    """
-    Initialize the RMSNorm normalization layer.
-
-    Args:
-        dim (int): The dimension of the input tensor.
-        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
-
-    Attributes:
-        eps (float): A small value added to the denominator for numerical stability.
-        weight (nn.Parameter): Learnable scaling parameter.
-    """
-
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
-
-    def _norm(self, x: torch.Tensor):
-        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x: torch.Tensor):
-        x = probe.log_stats(x, "resid")
-        output = self._norm(x.float())
-        return (output * self.weight.float()).type_as(x)
-
-    def reset_parameters(self):
-        torch.nn.init.ones_(self.weight)  # type: ignore
-
-
 def _reshape_for_attn_bias(
     attn_bias: AttentionBias | None,
     *tensors: torch.Tensor,

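The removed class and the nn.RMSNorm fallback compute the same normalization: scale the input by the reciprocal root mean square over its last dimension (plus eps), then multiply by a learned per-dimension weight; apex's FusedRMSNorm implements the same math as a fused kernel. A minimal standalone sketch, not part of the diff (manual_rms_norm is a hypothetical helper, and nn.RMSNorm requires a recent PyTorch), comparing the removed implementation's formula against the fallback:

import torch
from torch import nn

def manual_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Same computation as the removed class: normalize by the root mean square
    # over the last dimension, then apply the learned per-dimension scale.
    out = x * torch.rsqrt((x * x).mean(-1, keepdim=True) + eps)
    return out * weight

dim = 16
x = torch.randn(2, 4, dim)
ref = nn.RMSNorm(dim, eps=1e-6)  # the fallback picked when apex is missing
torch.testing.assert_close(ref(x), manual_rms_norm(x, ref.weight, eps=1e-6))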

@@ -12,12 +12,19 @@ from xformers.ops import AttentionBias
 from bytelatent.base_transformer import (
     BaseTransformer,
     BaseTransformerArgs,
-    RMSNorm,
     flex_attention_comp,
     repeat_kv,
 )
 from bytelatent.model.utils import create_causal_mask
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+    RMSNorm = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+    print("Apex not found. Using nn.RMSNorm")
+    RMSNorm = nn.RMSNorm
+
 logger = logging.getLogger()
@@ -44,7 +51,7 @@ class CrossAttention(nn.Module):
         self.n_kv_heads = n_kv_heads
         self.heads_per_group = self.n_heads // self.n_kv_heads
 
-        self.cross_attn_norm_q = RMSNorm(dim, eps=norm_eps)
+        self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps)
         self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)
 
         self.wq = nn.Linear(

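Both implementations accept the constructor shape used at these call sites: apex's FusedRMSNorm and torch.nn.RMSNorm each take the normalized dimension plus an eps keyword, so RMSNorm(dim, eps=norm_eps) resolves correctly whichever branch of the try/except wins. A small isolated sketch of that selection pattern (illustrative only; the dim and eps values are arbitrary):

from torch import nn

try:
    from apex.normalization.fused_layer_norm import FusedRMSNorm

    RMSNorm = FusedRMSNorm  # fused CUDA kernel when apex is installed
except (ImportError, ModuleNotFoundError):
    RMSNorm = nn.RMSNorm  # pure-PyTorch fallback with a compatible signature

norm = RMSNorm(512, eps=1e-5)  # the same call works with either implementation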

@@ -14,7 +14,6 @@ from xformers.ops import AttentionBias
 from bytelatent.base_transformer import (
     BaseTransformerArgs,
     InitStdFactor,
-    RMSNorm,
     RotaryEmbedding,
     TransformerBlock,
 )
@@ -22,6 +21,14 @@ from bytelatent.model.latent_transformer import CrossAttention
 from bytelatent.model.utils import create_causal_mask, downsample
 from bytelatent.tokenizers.blt_tokenizer import BOE_ID
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+    RMSNorm = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+    print("Apex not found. Using nn.RMSNorm")
+    RMSNorm = nn.RMSNorm
+
 logger = logging.getLogger()


@@ -19,11 +19,18 @@ from xformers.ops import AttentionBias, fmha
 from bytelatent.base_transformer import (
     BaseTransformer,
     BaseTransformerArgs,
-    RMSNorm,
     cross_entropy,
 )
 from bytelatent.model.utils import create_causal_mask
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+    RMSNorm = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+    print("Apex not found. Using nn.RMSNorm")
+    RMSNorm = nn.RMSNorm
+
 
 def attention_flops_per_token(n_layers, seq_len, dim, causal):
     # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30
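Because each touched module resolves the RMSNorm alias at import time, the simplest way to see which implementation a given environment ended up with is to inspect the attribute after import. A quick check, assuming bytelatent.base_transformer is importable as shown in the first diff above:

import bytelatent.base_transformer as bt

# Prints apex's FusedRMSNorm when apex is installed, otherwise torch.nn.RMSNorm.
print(bt.RMSNorm, bt.RMSNorm.__module__)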