mirror of
https://github.com/facebookresearch/blt.git
synced 2025-09-10 06:14:35 +00:00
Add rope fp32 (#43)
* Log model
* Add flag for rope outer in fp32
---------
Co-authored-by: Srini Iyer <sviyer@meta.com>
This commit is contained in:
parent
6fbaf7266f
commit
739dc71a0a
4 changed files with 34 additions and 9 deletions
|
@ -45,6 +45,7 @@ class BaseTransformerArgs(BaseModel):
|
|||
norm_eps: float = 1e-5
|
||||
|
||||
rope_theta: float = 10000.0
|
||||
rope_use_fp32_in_outer_product: bool = False
|
||||
|
||||
init_base_std: float | None = None
|
||||
init_std_factor: InitStdFactor = InitStdFactor.DISABLED
|
||||
|
@ -78,7 +79,12 @@ def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor:
|
|||
)
|
||||
|
||||
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
|
||||
def precompute_freqs_cis(
|
||||
dim: int,
|
||||
end: int,
|
||||
theta: float = 10000.0,
|
||||
rope_use_fp32_in_outer_product: bool = False,
|
||||
):
|
||||
"""
|
||||
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
|
||||
|
||||
|
@ -96,6 +102,9 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
|
|||
"""
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
t = torch.arange(end, device=freqs.device)
|
||||
if rope_use_fp32_in_outer_product:
|
||||
t = t.to(torch.float32)
|
||||
|
||||
freqs = torch.outer(t, freqs).float()
|
||||
|
||||
cos, sin = freqs.cos(), freqs.sin()
|
||||
|
@ -232,22 +241,37 @@ class RotaryEmbedding(torch.nn.Module):
|
|||
RotaryEmbedding Module
|
||||
"""
|
||||
|
||||
def __init__(self, theta: float, head_dim: int, max_seqlen: int = 1024):
|
||||
def __init__(
|
||||
self,
|
||||
theta: float,
|
||||
head_dim: int,
|
||||
max_seqlen: int = 1024,
|
||||
rope_use_fp32_in_outer_product: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.theta = theta
|
||||
self.head_dim = head_dim
|
||||
self.max_seqlen = max_seqlen
|
||||
self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product
|
||||
|
||||
self.register_buffer(
|
||||
"freqs_cis",
|
||||
precompute_freqs_cis(dim=head_dim, end=max_seqlen, theta=theta),
|
||||
precompute_freqs_cis(
|
||||
dim=head_dim,
|
||||
end=max_seqlen,
|
||||
theta=theta,
|
||||
rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
|
||||
),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def reset_parameters(self):
|
||||
self.freqs_cis[...] = precompute_freqs_cis(
|
||||
dim=self.head_dim, end=self.max_seqlen, theta=self.theta
|
||||
dim=self.head_dim,
|
||||
end=self.max_seqlen,
|
||||
theta=self.theta,
|
||||
rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
@ -577,6 +601,7 @@ class BaseTransformer(nn.Module):
|
|||
theta=args.rope_theta,
|
||||
head_dim=args.head_dim or args.dim // args.n_heads,
|
||||
max_seqlen=args.max_seqlen,
|
||||
rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
|
||||
)
|
||||
self.eos_id = args.eos_id
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue