blt/bytelatent/model/latent_transformer.py

205 lines
5.5 KiB
Python
Raw Permalink Normal View History

2024-12-12 23:32:30 +00:00
# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
from typing import List, Optional, Tuple, Union
import torch
import torch.nn
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import BlockMask
from xformers.ops import AttentionBias
from bytelatent.base_transformer import (
BaseTransformer,
BaseTransformerArgs,
2024-12-12 23:32:30 +00:00
RMSNorm,
flex_attention_comp,
repeat_kv,
)
from bytelatent.model.utils import create_causal_mask
# Module-level logger; getLogger called without a name hands back the root logger.
logger = logging.getLogger(name=None)
class CrossAttention(nn.Module):
    """
    CrossAttention block to attend to the encoder states from the decoder.

    Queries are projected from ``x`` and keys/values from ``kv``. Each input
    goes through its own RMSNorm before projection. The attention output is
    added back onto the un-normalized ``x`` (pre-norm residual).

    Grouped-query attention is supported via ``n_kv_heads``: key/value heads
    are repeated ``n_heads // n_kv_heads`` times with ``repeat_kv``.
    Rope is not supported.

    Args:
        dim: Input/output feature dimension of ``x`` and ``kv``.
        head_dim: Per-head width of the query/key/value projections.
        n_heads: Number of query heads.
        n_kv_heads: Number of key/value heads; must be positive and divide
            ``n_heads``.
        norm_eps: Epsilon used by both RMSNorm layers.

    Raises:
        ValueError: if ``n_kv_heads`` is not positive or does not divide
            ``n_heads``.
    """

    def __init__(
        self,
        dim: int,
        head_dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
    ):
        super().__init__()
        # repeat_kv expands n_kv_heads to n_kv_heads * heads_per_group heads;
        # that only equals n_heads when the division is exact.
        if n_kv_heads <= 0 or n_heads % n_kv_heads != 0:
            raise ValueError(
                f"n_kv_heads ({n_kv_heads}) must be positive and divide "
                f"n_heads ({n_heads})"
            )
        self.dim = dim
        self.head_dim = head_dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.heads_per_group = self.n_heads // self.n_kv_heads

        self.cross_attn_norm_q = RMSNorm(dim, eps=norm_eps)
        self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)

        self.wq = nn.Linear(
            dim,
            n_heads * head_dim,
            bias=False,
        )
        self.wk = nn.Linear(
            dim,
            n_kv_heads * head_dim,
            bias=False,
        )
        self.wv = nn.Linear(
            dim,
            n_kv_heads * head_dim,
            bias=False,
        )

        self.wo = nn.Linear(
            n_heads * head_dim,
            dim,
            bias=False,
        )

    def forward(
        self,
        x: torch.Tensor,
        kv: torch.Tensor,
        mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
    ) -> torch.Tensor:
        """
        Attend from ``x`` (queries) to ``kv`` (keys/values).

        Args:
            x: Query-side states, unpacked as (batch, seq_len, dim).
            kv: Key/value-side states, unpacked as (batch, kv_seq_len, dim).
            mask: ``None`` or a ``BlockMask`` forwarded to
                ``flex_attention_comp``. Despite the wider annotation, other
                mask types are rejected by an assertion.

        Returns:
            ``x + attention_output``, with the same shape as ``x``.
        """
        # Only the flex-attention path is implemented; fail before any compute.
        assert mask is None or isinstance(mask, BlockMask)

        # B S D
        bsz, seq_len, _ = x.shape
        _, slen_kv, _ = kv.shape

        # Normalize into separate tensors so the raw `x` remains available for
        # the residual connection below.
        x_norm = self.cross_attn_norm_q(x)
        kv_norm = self.cross_attn_norm_kv(kv)

        xq = self.wq(x_norm)
        xk = self.wk(kv_norm)
        xv = self.wv(kv_norm)

        # Remember the flat (B, S, n_heads * head_dim) shape for the output.
        output_shape = xq.shape
        # B S D -> B S H D
        xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
        xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
        xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)

        # Expand the key/value heads so every query head has a partner.
        xk = repeat_kv(xk, self.heads_per_group, dim=2)
        xv = repeat_kv(xv, self.heads_per_group, dim=2)

        # B S H D -> B H S D
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)
        output = flex_attention_comp(xq, xk, xv, block_mask=mask)
        output = output.transpose(1, 2).contiguous()  # B H S D -> B S H D

        output = self.wo(output.reshape(output_shape))

        # Pre-norm residual: add the attention output to the un-normalized input.
        return x + output

    def init_weights(self, base_std: float, factor: float = 1.0):
        """
        Re-initialize the projections and reset both RMSNorm layers.

        ``wq``, ``wk`` and ``wv`` are drawn from a truncated normal with
        std ``base_std * factor``, truncated at +/-3 std. ``wo`` uses that std
        divided by sqrt(2).

        Args:
            base_std: Base standard deviation.
            factor: Multiplier applied to ``base_std``.
        """
        std = base_std * factor

        nn.init.trunc_normal_(
            self.wq.weight,
            mean=0.0,
            std=std,
            a=-3 * std,
            b=3 * std,
        )

        nn.init.trunc_normal_(
            self.wk.weight,
            mean=0.0,
            std=std,
            a=-3 * std,
            b=3 * std,
        )

        nn.init.trunc_normal_(
            self.wv.weight,
            mean=0.0,
            std=std,
            a=-3 * std,
            b=3 * std,
        )

        output_std = std / (2**0.5)
        nn.init.trunc_normal_(
            self.wo.weight,
            mean=0.0,
            std=output_std,
            a=-3 * output_std,
            b=3 * output_std,
        )
        self.cross_attn_norm_q.reset_parameters()
        self.cross_attn_norm_kv.reset_parameters()
class GlobalTransformer(BaseTransformer):
    """
    BaseTransformer variant whose forward pass runs on precomputed embeddings.

    ``forward`` feeds ``embeds`` (not token ids) into
    ``BaseTransformer.forward``. When ``args.dim_token_emb`` is set and
    differs from the model width, the embeddings are first passed through a
    bias-free linear projection. Token ids are used only for the sequence
    length and to build the default causal mask.
    """

    def __init__(self, args: BaseTransformerArgs):
        super().__init__(args)
        self.dropout = args.dropout
        self.eos_id = args.eos_id

        # Optional projection from the incoming embedding width to the model
        # width; only created when the two widths differ.
        self.token_embedding_projection = None
        if args.dim_token_emb is not None and args.dim_token_emb != self.dim:
            self.token_embedding_projection = nn.Linear(
                args.dim_token_emb,
                args.dim,
                bias=False,
            )

    def forward(
        self,
        tokens: torch.Tensor,
        tok_idx: Optional[torch.Tensor] = None,
        embeds: Optional[torch.Tensor] = None,
        mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None,
        cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
    ):
        """
        Similar to BaseTransformer.forward, but with an additional embeds argument
        and projection to the token space.

        Args:
            tokens: Token ids, unpacked as (batch, seq_len). Used for the
                sequence length and, when ``mask`` is None, to build a causal
                mask that uses ``self.eos_id``.
            tok_idx: Optional positions forwarded to ``BaseTransformer.forward``.
            embeds: Input embeddings; required despite the Optional annotation.
            mask: Attention mask; a causal mask is created when None.
            cache: Returned unchanged; not read by this method.

        Returns:
            Tuple ``(h, cache)``, where ``h`` is the output of
            ``BaseTransformer.forward``.

        Raises:
            ValueError: if ``embeds`` is None.
        """
        # Without this check a missing `embeds` surfaced as an opaque
        # AttributeError/TypeError deeper in the pipeline.
        if embeds is None:
            raise ValueError("GlobalTransformer.forward requires `embeds`, got None")

        _, seqlen = tokens.shape

        if mask is None:
            mask = create_causal_mask(
                seqlen,
                self.attn_impl,
                self.attn_bias_type,
                tokens=tokens,
                eos_id=self.eos_id,
            )

        h = embeds
        # Project only when a projection exists and the width is not already
        # the model width.
        if self.token_embedding_projection is not None and h.shape[-1] != self.dim:
            h = self.token_embedding_projection(h)

        h = F.dropout(h, p=self.dropout, training=self.training)

        h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl)
        return h, cache

    def init_weights(self, init_base_std: float):
        """
        Run ``BaseTransformer.init_weights``, then initialize the optional
        embedding projection.

        The projection is drawn from a truncated normal with std
        ``init_base_std``, truncated at +/-3 std.
        """
        super().init_weights()
        if self.token_embedding_projection is not None:
            nn.init.trunc_normal_(
                self.token_embedding_projection.weight,
                mean=0.0,
                std=init_base_std,
                a=-3 * init_base_std,
                b=3 * init_base_std,
            )