Add flag for rope outer in fp32

Author: Srini Iyer (2025-02-06 00:40:51 +00:00)
parent 162b99b4a3
commit b28ceb624d
3 changed files with 33 additions and 9 deletions

Changed file 1 of 3:

@@ -45,6 +45,7 @@ class BaseTransformerArgs(BaseModel):
     norm_eps: float = 1e-5

     rope_theta: float = 10000.0
+    rope_use_fp32_in_outer_product: bool = False

     init_base_std: float | None = None
     init_std_factor: InitStdFactor = InitStdFactor.DISABLED
@@ -78,7 +79,12 @@ def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor:
     )


-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+def precompute_freqs_cis(
+    dim: int,
+    end: int,
+    theta: float = 10000.0,
+    rope_use_fp32_in_outer_product: bool = False,
+):
     """
     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@@ -96,6 +102,9 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
     """
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device)
+    if rope_use_fp32_in_outer_product:
+        t = t.to(torch.float32)
+
     freqs = torch.outer(t, freqs).float()
     cos, sin = freqs.cos(), freqs.sin()
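
The guarded cast is the entire behavioural change: when the flag is set, the position vector t is forced to fp32 before the outer product with the per-dimension frequencies. As a rough, standalone illustration (not taken from the repo, and not what the unflagged code does, since torch.arange already yields exact integer positions) of the kind of rounding an explicit fp32 cast guards against if positions end up in a lower-precision float such as bf16:

import torch

dim, end, theta = 128, 4096, 10000.0
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))

t_fp32 = torch.arange(end, dtype=torch.float32)  # exact integer positions
t_bf16 = torch.arange(end).to(torch.bfloat16)    # 8-bit mantissa: positions above 256 round

print(t_fp32.unique().numel())  # 4096 distinct positions
print(t_bf16.unique().numel())  # far fewer: neighbouring positions collapse

# The rounding survives the cast back to fp32, so the rotary angles from the
# outer product drift at long range.
drift = (torch.outer(t_fp32, freqs) - torch.outer(t_bf16.float(), freqs)).abs().max()
print(drift)
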
@@ -232,22 +241,37 @@ class RotaryEmbedding(torch.nn.Module):
     RotaryEmbedding Module
     """

-    def __init__(self, theta: float, head_dim: int, max_seqlen: int = 1024):
+    def __init__(
+        self,
+        theta: float,
+        head_dim: int,
+        max_seqlen: int = 1024,
+        rope_use_fp32_in_outer_product: bool = False,
+    ):
         super().__init__()

         self.theta = theta
         self.head_dim = head_dim
         self.max_seqlen = max_seqlen
+        self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product

         self.register_buffer(
             "freqs_cis",
-            precompute_freqs_cis(dim=head_dim, end=max_seqlen, theta=theta),
+            precompute_freqs_cis(
+                dim=head_dim,
+                end=max_seqlen,
+                theta=theta,
+                rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
+            ),
             persistent=False,
         )

     def reset_parameters(self):
         self.freqs_cis[...] = precompute_freqs_cis(
-            dim=self.head_dim, end=self.max_seqlen, theta=self.theta
+            dim=self.head_dim,
+            end=self.max_seqlen,
+            theta=self.theta,
+            rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
         )

     def forward(
@@ -577,6 +601,7 @@ class BaseTransformer(nn.Module):
             theta=args.rope_theta,
             head_dim=args.head_dim or args.dim // args.n_heads,
             max_seqlen=args.max_seqlen,
+            rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
         )

         self.eos_id = args.eos_id
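
For completeness, a minimal usage sketch of the widened constructor. The values are placeholders, and RotaryEmbedding is assumed to be imported from the first changed file (its path is not shown in this view):

import torch
# Assumes RotaryEmbedding from the first changed file is already in scope.

rope_fp32 = RotaryEmbedding(
    theta=10000.0,
    head_dim=64,
    max_seqlen=2048,
    rope_use_fp32_in_outer_product=True,
)
rope_default = RotaryEmbedding(theta=10000.0, head_dim=64, max_seqlen=2048)

# The buffer stays fp32 either way (precompute_freqs_cis calls .float());
# the flag only controls the dtype of the positions fed into torch.outer.
print(rope_fp32.freqs_cis.dtype, rope_default.freqs_cis.dtype)
print(torch.allclose(rope_fp32.freqs_cis, rope_default.freqs_cis))
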

Changed file 2 of 3:

@@ -414,7 +414,7 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     patch_in_forward: bool = False

     # Architecture and dimensions
-    dim_token: int = 256
+    dim_token: int | None = None
     dim_global: int = 512
     dim_local_decoder: int = 512
     dim_local_encoder: int = 512
@@ -523,10 +523,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     use_fsdp: bool = True
     attn_to_keep: str = "all"

-    # RoPE parameters
-    rope_theta: float = 10000.0
-    rope_use_fp32_in_outer_product: bool = False
-
     # Parameter mixing
     pm_size: int = 0
@@ -619,6 +615,7 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
         sliding_window=args.local_attention_window_len,
         use_rope=args.use_rope,
         rope_theta=args.rope_theta,
+        rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
         init_base_std=args.init_base_std,
         init_std_factor=args.init_std_factor,
         n_kv_heads=args.n_kv_heads,
@@ -661,6 +658,7 @@ def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
         sliding_window=args.local_attention_window_len,
         use_rope=args.use_rope,
         rope_theta=args.rope_theta,
+        rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
         init_base_std=args.init_base_std,
         init_std_factor=args.init_std_factor,
         n_kv_heads=args.n_kv_heads,
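
With the duplicate declaration removed, ByteLatentTransformerArgs now inherits rope_use_fp32_in_outer_product (and rope_theta) from BaseTransformerArgs, so the flag is set once on the top-level config and forwarded to both local models by the two call sites above. A hedged sketch, assuming the config class and factory functions are imported from this file and that any required fields omitted here are filled in:

args = ByteLatentTransformerArgs(
    rope_use_fp32_in_outer_product=True,  # field inherited from BaseTransformerArgs
    # ... other model dimensions / required fields go here ...
)

local_encoder = create_local_encoder(args)  # forwards args.rope_use_fp32_in_outer_product
local_decoder = create_local_decoder(args)  # forwards args.rope_use_fp32_in_outer_product
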

Changed file 3 of 3:

@@ -86,6 +86,7 @@ class LocalModelBase(nn.Module):
                 theta=args.rope_theta,
                 head_dim=args.head_dim or args.dim // args.n_heads,
                 max_seqlen=args.max_seqlen,
+                rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
             )
             self.pos_embeddings = None