# Copyright (c) Meta Platforms, Inc. and affiliates.

import logging
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import BlockMask
from xformers.ops import AttentionBias

from bytelatent.base_transformer import (
    BaseTransformer,
    BaseTransformerArgs,
    RMSNorm,
    flex_attention_comp,
    repeat_kv,
)
from bytelatent.model.utils import create_causal_mask

logger = logging.getLogger()


class CrossAttention(nn.Module):
    """
    Cross-attention block that lets the decoder attend to the encoder states.
    Rotary position embeddings (RoPE) are not supported.
    """

    def __init__(
        self,
        dim: int,
        head_dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
    ):
        super().__init__()
        self.dim = dim
        self.head_dim = head_dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.heads_per_group = self.n_heads // self.n_kv_heads

        self.cross_attn_norm_q = RMSNorm(dim, eps=norm_eps)
        self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)

        self.wq = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.wo = nn.Linear(n_heads * head_dim, dim, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        kv: torch.Tensor,
        mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
    ) -> torch.Tensor:
        # x: B S D (queries), kv: B S_kv D (keys/values)
        bsz, seq_len, _ = x.shape
        _, slen_kv, _ = kv.shape

        # Pre-norm: normalize queries and keys/values, but keep the raw input
        # around for the residual connection at the end.
        x_norm = self.cross_attn_norm_q(x)
        kv = self.cross_attn_norm_kv(kv)

        xq = self.wq(x_norm)
        xk = self.wk(kv)
        xv = self.wv(kv)

        output_shape = xq.shape
        # B S D -> B S H D
        xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
        xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
        xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)

        # Grouped-query attention: repeat KV heads to match the query heads.
        xk = repeat_kv(xk, self.heads_per_group, dim=2)
        xv = repeat_kv(xv, self.heads_per_group, dim=2)

        assert mask is None or isinstance(mask, BlockMask)
        xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
        output = flex_attention_comp(xq, xk, xv, block_mask=mask)
        output = output.transpose(1, 2).contiguous()  # B H S D -> B S H D

        output = self.wo(output.reshape(output_shape))

        return x + output

    def init_weights(self, base_std: float, factor: float = 1.0):
        std = base_std * factor

        nn.init.trunc_normal_(
            self.wq.weight,
            mean=0.0,
            std=std,
            a=-3 * std,
            b=3 * std,
        )
        nn.init.trunc_normal_(
            self.wk.weight,
            mean=0.0,
            std=std,
            a=-3 * std,
            b=3 * std,
        )
        nn.init.trunc_normal_(
            self.wv.weight,
            mean=0.0,
            std=std,
            a=-3 * std,
            b=3 * std,
        )

        output_std = std / (2**0.5)
        nn.init.trunc_normal_(
            self.wo.weight,
            mean=0.0,
            std=output_std,
            a=-3 * output_std,
            b=3 * output_std,
        )

        self.cross_attn_norm_q.reset_parameters()
        self.cross_attn_norm_kv.reset_parameters()


class GlobalTransformer(BaseTransformer):
    """
    BaseTransformer that runs on precomputed embeddings and, when needed,
    projects them from ``dim_token_emb`` to the model dimension.
    """

    def __init__(self, args: BaseTransformerArgs):
        super().__init__(args)
        self.dropout = args.dropout
        self.eos_id = args.eos_id

        self.token_embedding_projection = None
        if args.dim_token_emb is not None and args.dim_token_emb != self.dim:
            self.token_embedding_projection = nn.Linear(
                args.dim_token_emb,
                args.dim,
                bias=False,
            )

    def forward(
        self,
        tokens: torch.Tensor,
        tok_idx: Optional[torch.Tensor] = None,
        embeds: Optional[torch.Tensor] = None,
        mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None,
        cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
    ):
        """
        Similar to BaseTransformer.forward, but with an additional embeds
        argument and a projection to the token space.
""" bs, seqlen = tokens.shape h = embeds mask = ( mask if mask is not None else create_causal_mask( seqlen, self.attn_impl, self.attn_bias_type, tokens=tokens, eos_id=self.eos_id, ) ) if self.token_embedding_projection is not None and h.shape[-1] != self.dim: h = self.token_embedding_projection(h) h = F.dropout(h, p=self.dropout, training=self.training) h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl) return h, cache def init_weights(self, init_base_std: float): super().init_weights() if self.token_embedding_projection is not None: nn.init.trunc_normal_( self.token_embedding_projection.weight, mean=0.0, std=init_base_std, a=-3 * init_base_std, b=3 * init_base_std, )