From b28ceb624d87a129d92552ecb4f74acb422c3437 Mon Sep 17 00:00:00 2001
From: Srini Iyer
Date: Thu, 6 Feb 2025 00:40:51 +0000
Subject: [PATCH] Add flag for rope outer in fp32

---
 bytelatent/base_transformer.py   | 33 ++++++++++++++++++++++++++++----
 bytelatent/model/blt.py          |  8 +++-----
 bytelatent/model/local_models.py |  1 +
 3 files changed, 33 insertions(+), 9 deletions(-)

diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py
index dd0cce6..87d7334 100644
--- a/bytelatent/base_transformer.py
+++ b/bytelatent/base_transformer.py
@@ -45,6 +45,7 @@ class BaseTransformerArgs(BaseModel):
     norm_eps: float = 1e-5
 
     rope_theta: float = 10000.0
+    rope_use_fp32_in_outer_product: bool = False
 
     init_base_std: float | None = None
     init_std_factor: InitStdFactor = InitStdFactor.DISABLED
@@ -78,7 +79,12 @@ def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor:
     )
 
 
-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+def precompute_freqs_cis(
+    dim: int,
+    end: int,
+    theta: float = 10000.0,
+    rope_use_fp32_in_outer_product: bool = False,
+):
     """
     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
 
@@ -96,6 +102,9 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
     """
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device)
+    if rope_use_fp32_in_outer_product:
+        t = t.to(torch.float32)
+
     freqs = torch.outer(t, freqs).float()
 
     cos, sin = freqs.cos(), freqs.sin()
@@ -232,22 +241,37 @@ class RotaryEmbedding(torch.nn.Module):
     RotaryEmbedding Module
     """
 
-    def __init__(self, theta: float, head_dim: int, max_seqlen: int = 1024):
+    def __init__(
+        self,
+        theta: float,
+        head_dim: int,
+        max_seqlen: int = 1024,
+        rope_use_fp32_in_outer_product: bool = False,
+    ):
         super().__init__()
 
         self.theta = theta
         self.head_dim = head_dim
         self.max_seqlen = max_seqlen
+        self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product
 
         self.register_buffer(
             "freqs_cis",
-            precompute_freqs_cis(dim=head_dim, end=max_seqlen, theta=theta),
+            precompute_freqs_cis(
+                dim=head_dim,
+                end=max_seqlen,
+                theta=theta,
+                rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
+            ),
             persistent=False,
         )
 
     def reset_parameters(self):
         self.freqs_cis[...] = precompute_freqs_cis(
-            dim=self.head_dim, end=self.max_seqlen, theta=self.theta
+            dim=self.head_dim,
+            end=self.max_seqlen,
+            theta=self.theta,
+            rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
         )
 
     def forward(
@@ -577,6 +601,7 @@ class BaseTransformer(nn.Module):
             theta=args.rope_theta,
             head_dim=args.head_dim or args.dim // args.n_heads,
             max_seqlen=args.max_seqlen,
+            rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
         )
         self.eos_id = args.eos_id
 
diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py
index a62be23..53a3be6 100644
--- a/bytelatent/model/blt.py
+++ b/bytelatent/model/blt.py
@@ -414,7 +414,7 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     patch_in_forward: bool = False
 
     # Architecture and dimensions
-    dim_token: int = 256
+    dim_token: int | None = None
     dim_global: int = 512
     dim_local_decoder: int = 512
     dim_local_encoder: int = 512
@@ -523,10 +523,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     use_fsdp: bool = True
     attn_to_keep: str = "all"
 
-    # RoPE parameters
-    rope_theta: float = 10000.0
-    rope_use_fp32_in_outer_product: bool = False
-
     # Parameter mixing
     pm_size: int = 0
 
@@ -619,6 +615,7 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
             sliding_window=args.local_attention_window_len,
             use_rope=args.use_rope,
             rope_theta=args.rope_theta,
+            rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
             init_base_std=args.init_base_std,
             init_std_factor=args.init_std_factor,
             n_kv_heads=args.n_kv_heads,
@@ -661,6 +658,7 @@ def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
             sliding_window=args.local_attention_window_len,
             use_rope=args.use_rope,
             rope_theta=args.rope_theta,
+            rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
             init_base_std=args.init_base_std,
             init_std_factor=args.init_std_factor,
             n_kv_heads=args.n_kv_heads,
diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py
index c16f62e..d0e24c0 100644
--- a/bytelatent/model/local_models.py
+++ b/bytelatent/model/local_models.py
@@ -86,6 +86,7 @@ class LocalModelBase(nn.Module):
             theta=args.rope_theta,
             head_dim=args.head_dim or args.dim // args.n_heads,
             max_seqlen=args.max_seqlen,
+            rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
         )
 
         self.pos_embeddings = None
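
Note (not part of the patch): the commit message does not spell out the motivation, so below is a minimal, self-contained sketch, assuming standard bfloat16 semantics, of why casting the positions t to fp32 before the torch.outer call can matter. The variable names are illustrative only and do not come from the repository.

import torch

# bfloat16 carries an 8-bit significand, so integer positions above 256 cannot
# all be represented exactly; rotation angles built from rounded positions
# drift for long sequences.
dim, end, theta = 128, 8192, 10000.0
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))

t_fp32 = torch.arange(end, dtype=torch.float32)   # exact positions, as with the flag enabled
t_bf16 = torch.arange(end, dtype=torch.bfloat16)  # rounded positions from a low-precision path

exact = torch.outer(t_fp32, freqs)
rounded = torch.outer(t_bf16.float(), freqs)
print((exact - rounded).abs().max())  # nonzero, growing with position

With the patch applied, the fp32 path is opted into via the constructor shown in the diff, e.g. RotaryEmbedding(theta=10000.0, head_dim=128, max_seqlen=8192, rope_use_fp32_in_outer_product=True), or by setting rope_use_fp32_in_outer_product in the BaseTransformerArgs config (default False).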