Mirror of https://github.com/facebookresearch/blt.git (synced 2025-09-16 00:59:43 +00:00)
using apex rmsnorm (#57)

* using apex rmsnorm
* added message for missing apex
* black
* missed a print

Co-authored-by: Srini Iyer <sviyer@meta.com>
parent c49e25171e
commit f3e8125f74

4 changed files with 33 additions and 35 deletions
@@ -17,6 +17,14 @@ from xformers.ops import AttentionBias, fmha
 from bytelatent import probe
 from bytelatent.tokenizers.constants import EOS_ID
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+    RMSNorm = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+    print("Apex not found. Using nn.RMSNorm")
+    RMSNorm = nn.RMSNorm
+
 if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
     flex_attention_comp = torch.compile(flex_attention)
 else:
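The try/except added above turns RMSNorm into an alias for whichever backend is available, so downstream modules construct and call it the same way in both cases. A minimal usage sketch of the resulting alias (the hidden size and tensor shapes are illustrative, and the nn.RMSNorm fallback assumes a PyTorch release that ships it):

import torch
import torch.nn as nn

# Prefer Apex's fused kernel when it is installed; otherwise fall back to
# PyTorch's built-in nn.RMSNorm, as the commit does.
try:
    from apex.normalization.fused_layer_norm import FusedRMSNorm as RMSNorm
except (ImportError, ModuleNotFoundError):
    RMSNorm = nn.RMSNorm

norm = RMSNorm(512, eps=1e-6)       # 512 is an illustrative model dim
x = torch.randn(2, 16, 512)         # (batch, seq_len, dim)
print(norm(x).shape)                # torch.Size([2, 16, 512])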
@@ -294,37 +302,6 @@ class RotaryEmbedding(torch.nn.Module):
         return self.freqs_cis[0:seqlen]
 
 
-class RMSNorm(nn.Module):
-    """
-    Initialize the RMSNorm normalization layer.
-
-    Args:
-        dim (int): The dimension of the input tensor.
-        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
-
-    Attributes:
-        eps (float): A small value added to the denominator for numerical stability.
-        weight (nn.Parameter): Learnable scaling parameter.
-
-    """
-
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
-
-    def _norm(self, x: torch.Tensor):
-        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x: torch.Tensor):
-        x = probe.log_stats(x, "resid")
-        output = self._norm(x.float())
-        return (output * self.weight.float()).type_as(x)
-
-    def reset_parameters(self):
-        torch.nn.init.ones_(self.weight)  # type: ignore
-
-
 def _reshape_for_attn_bias(
     attn_bias: AttentionBias | None,
     *tensors: torch.Tensor,
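For reference, the deleted class normalized in float32 as x * rsqrt(mean(x^2) + eps) and then applied a learned per-dimension weight. A quick equivalence sketch (not part of the commit, and assuming a PyTorch build that provides nn.RMSNorm) showing that the built-in fallback computes the same thing, aside from the probe.log_stats call on the residual stream:

import torch
import torch.nn as nn

def handrolled_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Same math as the deleted class: RMS-normalize over the last dim in float32, then scale.
    out = x.float() * torch.rsqrt((x.float() * x.float()).mean(-1, keepdim=True) + eps)
    return (out * weight.float()).type_as(x)

dim = 64                             # illustrative hidden size
x = torch.randn(4, dim)
ref = nn.RMSNorm(dim, eps=1e-6)      # weight initialized to ones, like the deleted class
torch.testing.assert_close(ref(x), handrolled_rmsnorm(x, ref.weight, eps=1e-6))

Apex's FusedRMSNorm implements the same root-mean-square normalization as a fused kernel, which is why the alias can be swapped in without touching call sites.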