mirror of https://github.com/facebookresearch/blt.git
synced 2025-01-18 16:37:46 +00:00
commit 6ffeb66b53
Summary:
- Refactor local model configs to be separate and clearer
- Add attention arguments and correct which attention is used in local models
- Preparation for being able to have an entropy train script
- Fix failing unit tests
Test Plan:
205 lines | 5.5 KiB | Python
# Copyright (c) Meta Platforms, Inc. and affiliates.

import logging
from typing import List, Optional, Tuple, Union

import torch
import torch.nn
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import BlockMask
from xformers.ops import AttentionBias

from bytelatent.base_transformer import (
    BaseTransformer,
    BaseTransformerArgs,
    RMSNorm,
    flex_attention_comp,
    repeat_kv,
)
from bytelatent.model.utils import create_causal_mask

logger = logging.getLogger()


class CrossAttention(nn.Module):
    """
    CrossAttention block to attend to the encoder states from the decoder.
    Rope is not supported.
    """

    def __init__(
        self,
        dim: int,
        head_dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
    ):
        super().__init__()

        self.dim = dim
        self.head_dim = head_dim

        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.heads_per_group = self.n_heads // self.n_kv_heads

        self.cross_attn_norm_q = RMSNorm(dim, eps=norm_eps)
        self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)

        self.wq = nn.Linear(
            dim,
            n_heads * head_dim,
            bias=False,
        )
        self.wk = nn.Linear(
            dim,
            n_kv_heads * head_dim,
            bias=False,
        )
        self.wv = nn.Linear(
            dim,
            n_kv_heads * head_dim,
            bias=False,
        )

        self.wo = nn.Linear(
            n_heads * head_dim,
            dim,
            bias=False,
        )

    def forward(
        self,
        x: torch.Tensor,
        kv: torch.Tensor,
        mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
    ) -> torch.Tensor:
        # B S D
        bsz, seq_len, _ = x.shape
        _, slen_kv, _ = kv.shape
        x = self.cross_attn_norm_q(x)
        kv = self.cross_attn_norm_kv(kv)

        xq = self.wq(x)
        xk = self.wk(kv)
        xv = self.wv(kv)

        output_shape = xq.shape
        # B S D -> B S H D
        xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
        xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
        xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)

        xk = repeat_kv(xk, self.heads_per_group, dim=2)
        xv = repeat_kv(xv, self.heads_per_group, dim=2)

        assert mask is None or isinstance(mask, BlockMask)
        xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
        output = flex_attention_comp(xq, xk, xv, block_mask=mask)
        output = output.transpose(1, 2).contiguous()  # B H S D -> B S H D

        output = self.wo(output.reshape(output_shape))

        return x + output

    def init_weights(self, base_std: float, factor: float = 1.0):
        std = base_std * factor

        nn.init.trunc_normal_(
            self.wq.weight,
            mean=0.0,
            std=std,
            a=-3 * std,
            b=3 * std,
        )

        nn.init.trunc_normal_(
            self.wk.weight,
            mean=0.0,
            std=std,
            a=-3 * std,
            b=3 * std,
        )

        nn.init.trunc_normal_(
            self.wv.weight,
            mean=0.0,
            std=std,
            a=-3 * std,
            b=3 * std,
        )

        output_std = std / (2**0.5)
        nn.init.trunc_normal_(
            self.wo.weight,
            mean=0.0,
            std=output_std,
            a=-3 * output_std,
            b=3 * output_std,
        )
        self.cross_attn_norm_q.reset_parameters()
        self.cross_attn_norm_kv.reset_parameters()

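
# --------------------------------------------------------------------------- #
# Usage sketch (illustrative, not part of the original file). The hyper-
# parameters and tensor shapes below are assumptions chosen for the example.
# CrossAttention adds the residual internally (out = x + attn(x, kv)) and its
# forward only accepts a flex-attention BlockMask or None as the mask, so the
# sketch assumes a flex-attention-capable setup when actually run.
# --------------------------------------------------------------------------- #
def _cross_attention_usage_sketch() -> torch.Tensor:
    # Decoder states (queries) of shape (B, S, D) attend to encoder states
    # (keys/values) of shape (B, S_kv, D). With n_kv_heads < n_heads, the
    # key/value heads are repeated via repeat_kv (grouped-query attention).
    cross_attn = CrossAttention(
        dim=512, head_dim=64, n_heads=8, n_kv_heads=2, norm_eps=1e-5
    )
    cross_attn.init_weights(base_std=0.02)

    x = torch.randn(2, 16, 512)  # decoder / query states (B, S, D)
    kv = torch.randn(2, 64, 512)  # encoder / key-value states (B, S_kv, D)

    out = cross_attn(x, kv, mask=None)  # None satisfies the BlockMask assert
    assert out.shape == x.shape  # residual output keeps the query shape
    return out

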
class GlobalTransformer(BaseTransformer):
    def __init__(self, args: BaseTransformerArgs):
        super().__init__(args)
        self.dropout = args.dropout
        self.eos_id = args.eos_id

        self.token_embedding_projection = None
        if args.dim_token_emb is not None and args.dim_token_emb != self.dim:
            self.token_embedding_projection = nn.Linear(
                args.dim_token_emb,
                args.dim,
                bias=False,
            )

    def forward(
        self,
        tokens: torch.Tensor,
        tok_idx: Optional[torch.Tensor] = None,
        embeds: Optional[torch.Tensor] = None,
        mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None,
        cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
    ):
        """
        Similar to BaseTransformer.forward, but with an additional embeds argument
        and projection to the token space.
        """
        bs, seqlen = tokens.shape

        h = embeds

        mask = (
            mask
            if mask is not None
            else create_causal_mask(
                seqlen,
                self.attn_impl,
                self.attn_bias_type,
                tokens=tokens,
                eos_id=self.eos_id,
            )
        )

        if self.token_embedding_projection is not None and h.shape[-1] != self.dim:
            h = self.token_embedding_projection(h)

        h = F.dropout(h, p=self.dropout, training=self.training)

        h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl)
        return h, cache

    def init_weights(self, init_base_std: float):
        super().init_weights()
        if self.token_embedding_projection is not None:
            nn.init.trunc_normal_(
                self.token_embedding_projection.weight,
                mean=0.0,
                std=init_base_std,
                a=-3 * init_base_std,
                b=3 * init_base_std,
            )
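

# --------------------------------------------------------------------------- #
# Usage sketch (illustrative, not part of the original file). GlobalTransformer
# consumes precomputed patch/byte embeddings through `embeds` rather than doing
# its own token lookup; `tokens` is still required so create_causal_mask can
# build the block/causal mask (e.g. eos-based document masking). The shapes in
# the docstring are assumptions for the example.
# --------------------------------------------------------------------------- #
def _global_transformer_usage_sketch(
    model: GlobalTransformer,
    tokens: torch.Tensor,
    embeds: torch.Tensor,
) -> torch.Tensor:
    """
    Assumed shapes: tokens is (B, S) of token ids, embeds is (B, S, dim_token_emb).
    If dim_token_emb != model.dim, token_embedding_projection maps embeds to the
    model width before the transformer layers run.
    """
    # mask=None lets forward() build a causal mask from tokens/eos_id itself.
    h, _ = model(tokens, embeds=embeds)
    return h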