blt/bytelatent/base_transformer.py

616 lines
19 KiB
Python
Raw Normal View History

2024-12-12 23:32:30 +00:00
# Copyright (c) Meta Platforms, Inc. and affiliates.
import os
2024-12-12 23:32:30 +00:00
from enum import Enum
from typing import Optional, Tuple, Union
import torch
from pydantic import BaseModel, ConfigDict
2024-12-12 23:32:30 +00:00
from torch import nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import (
BlockMask,
_mask_mod_signature,
flex_attention,
)
from xformers.ops import AttentionBias, fmha
from bytelatent import probe
from bytelatent.tokenizers.constants import EOS_ID
2024-12-12 23:32:30 +00:00
if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
flex_attention_comp = torch.compile(flex_attention)
else:
flex_attention_comp = None
2024-12-12 23:32:30 +00:00
class InitStdFactor(Enum):
    """Per-layer divisor applied to the output-projection init std.

    The divisor for each mode is computed in ``BaseTransformer.init_weights``
    and passed to each layer as ``factor``.
    """

    DISABLED = "disabled"  # divisor 1.0 (no scaling)
    GLOBAL_DEPTH = "global_depth"  # divisor sqrt(2 * (n_layers + 1))
    CURRENT_DEPTH = "current_depth"  # divisor sqrt(2 * (depth + 1)), depth is the 0-based layer index
    DIM_RATIO = "dim_ratio"  # divisor model_dim / 4096
class BaseTransformerArgs(BaseModel):
    """Hyper-parameters shared by every model built on ``BaseTransformer``.

    Unknown keys are rejected (``extra="forbid"``), so misspelled config
    entries raise a validation error instead of being silently ignored.
    """

    model_config = ConfigDict(extra="forbid")

    # --- Model shape ---
    dim: int = 512
    n_layers: int = 8
    # At least one of head_dim / n_heads must be set; the other is derived from dim.
    head_dim: Optional[int] = None
    n_heads: Optional[int] = None
    # Falls back to the number of query heads when unset.
    n_kv_heads: Optional[int] = None

    # --- Feed-forward sizing ---
    ffn_dim_multiplier: Optional[float] = None
    multiple_of: int = 256

    # --- Normalization / positional encoding ---
    norm_eps: float = 1e-5
    rope_theta: float = 10000.0

    # --- Initialization ---
    init_base_std: Optional[float] = None
    init_std_factor: InitStdFactor = InitStdFactor.DISABLED

    # --- Sequence length / attention backend ---
    max_seqlen: int = 1024
    attn_impl: Optional[str] = "sdpa"
    attn_bias_type: Optional[str] = None

    # Special token config
    eos_id: Optional[int] = EOS_ID
2024-12-12 23:32:30 +00:00
def cross_entropy(pred, target, **kwargs):
    """Cross-entropy over all leading dimensions, computed in float32.

    ``pred`` is flattened to ``(N, num_classes)`` and ``target`` to ``(N,)``.
    Extra keyword arguments (e.g. ``reduction``, ``ignore_index``) are
    forwarded to ``F.nll_loss``.
    """
    logits_2d = pred.flatten(end_dim=-2).float()
    log_probs = F.log_softmax(logits_2d, -1)
    flat_target = target.flatten(end_dim=-1)
    return F.nll_loss(log_probs, flat_target, **kwargs)
def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along axis 2.

    Same result as ``torch.repeat_interleave(x, dim=2, repeats=n_rep)`` for a
    4-D ``(batch, seq, n_kv_heads, head_dim)`` tensor, built with
    ``expand`` + ``reshape``. Only ``dim == 2`` is supported.
    """
    assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims."
    batch, seqlen, kv_heads, hdim = x.shape
    if n_rep != 1:
        widened = x.unsqueeze(3).expand(batch, seqlen, kv_heads, n_rep, hdim)
        return widened.reshape(batch, seqlen, kv_heads * n_rep, hdim)
    return x
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Precompute RoPE rotation matrices for positions ``0 .. end - 1``.

    Despite the historical name, no complex numbers are produced. For each
    position ``t`` and frequency index ``i`` (``0 <= i < dim // 2``) this builds
    the real 2x2 rotation matrix

        [[cos(t * f_i), -sin(t * f_i)],
         [sin(t * f_i),  cos(t * f_i)]]     with  f_i = theta ** (-2 * i / dim)

    Args:
        dim (int): Head dimension the rotation is applied to.
        end (int): Number of positions to precompute.
        theta (float, optional): RoPE base frequency. Defaults to 10000.0.

    Returns:
        torch.Tensor: float32 tensor of shape ``(end, dim // 2, 2, 2)``.
    """
    # Even indices 0, 2, 4, ... (truncated to dim // 2 entries for odd dim).
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    freqs = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=freqs.device)
    angles = torch.outer(positions, freqs).float()
    cos, sin = angles.cos(), angles.sin()
    # Row-major [[cos, -sin], [sin, cos]] for every (position, frequency).
    return torch.stack((cos, -sin, sin, cos), dim=-1).view(*angles.size(), 2, 2)
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int):
    """
    View ``freqs_cis`` so that it broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[seq_dim], x.shape[-3], 2, 2)``.
    The result keeps the sequence axis and axis ``x.ndim - 3`` of ``x``, sets
    every other leading axis to 1, and ends with ``(2, 2)``.

    Args:
        freqs_cis (torch.Tensor): Rotation table to view.
        x (torch.Tensor): Tensor the result must broadcast with.
        seq_dim (int): Index of the sequence axis in ``x``.

    Returns:
        torch.Tensor: A view of ``freqs_cis`` with ``x.ndim`` dimensions.
    """
    rank = x.ndim
    assert 0 <= seq_dim < rank
    expected = (x.shape[seq_dim], x.shape[-3], 2, 2)
    assert freqs_cis.shape == expected, f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
    kept_axes = {seq_dim, rank - 3}
    target_shape = [
        size if axis in kept_axes else 1 for axis, size in enumerate(x.shape[:-2])
    ]
    target_shape.extend((2, 2))
    return freqs_cis.view(*target_shape)
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    seq_dim: int,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate consecutive feature pairs of queries and keys with RoPE matrices.

    Each pair of adjacent features in the last dimension is multiplied by the
    matching 2x2 matrix from ``freqs_cis``, which is cast to float32. The
    outputs are cast back to the dtypes of ``xq`` and ``xk``.
    The final ``flatten(3)`` assumes 4-D ``(B, S, H, D)`` inputs.
    """
    # B S H D -> B S H D/2 1 2: each feature pair becomes a row vector.
    q_pairs = xq.reshape(*xq.shape[:-1], -1, 1, 2)
    k_pairs = xk.reshape(*xk.shape[:-1], -1, 1, 2)
    # S D/2 2 2 -> 1 S 1 D/2 2 2 (for seq_dim == 1)
    rotation = reshape_for_broadcast(freqs_cis, q_pairs, seq_dim).float()

    def _rotate(pairs: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
        # Summing over the last axis computes rotation @ pair; flatten restores B S H D.
        return (pairs * rotation).sum(5).flatten(3).type_as(like)

    return _rotate(q_pairs, xq), _rotate(k_pairs, xk)
def causal_mask(b, h, q_idx, kv_idx):
    """Mask mod allowing a query to see its own position and earlier keys.

    ``b`` (batch index) and ``h`` (head index) are ignored.
    """
    return kv_idx <= q_idx
def lengths_to_start_ids(lengths):
    """Return the offset at which each document starts in the stacked sequence.

    This is the exclusive prefix sum of ``lengths`` along dim 0, e.g.
    ``[2, 3, 2] -> [0, 2, 5]``. An empty ``lengths`` yields an empty tensor.
    """
    ends = lengths.cumsum(0)
    starts = torch.zeros_like(ends)
    # Shift the inclusive prefix sum right by one; the first document starts at 0.
    # Slicing (rather than writing starts[0]) keeps this valid for empty input.
    starts[1:] = ends[:-1]
    return starts
def lengths_to_local_ids(lengths):
    """Map every token of a stacked sequence to (document id, position in document).

    Args:
        lengths: 1-D tensor holding the length of each document.

    Returns:
        Tuple ``(doc_id, tok_id)`` of 1-D tensors with ``lengths.sum()`` entries.
        For ``lengths = [2, 3]``: ``doc_id = [0, 0, 1, 1, 1]`` and
        ``tok_id = [0, 1, 0, 1, 2]``.
    """
    assert lengths.ndim == 1
    total_seqlen = lengths.sum()
    # Document id of each token: index i repeated lengths[i] times.
    doc_id = torch.repeat_interleave(lengths)
    # Global offset of the document each token belongs to.
    doc_start = lengths_to_start_ids(lengths)[doc_id]
    # Position of each token within its own document.
    tok_id = torch.arange(total_seqlen, device=lengths.device) - doc_start
    return doc_id, tok_id
def generate_doc_mask_mod(
    mask_mod: _mask_mod_signature,
    lengths: torch.Tensor,
    kv_lengths: Optional[torch.Tensor] = None,
) -> _mask_mod_signature:
    """Generates mask mods that apply to inputs to flex attention in the sequence stacked
    format.

    Args:
        mask_mod: The mask mod to apply inside each document. It is called with
            document-local query/key positions instead of global indices.
        lengths: 1-D tensor with the length of each query-side document.
        kv_lengths: 1-D tensor with the length of each key/value-side document.
            Defaults to ``lengths``, i.e. queries and keys share one packing.

    Returns:
        A mask mod ``(b, h, q_idx, kv_idx)`` that is True only when:
        ``q_idx < lengths.sum()`` and ``kv_idx < kv_lengths.sum()``; both
        indices fall in documents with the same index; and ``mask_mod`` allows
        the pair at their in-document positions.

    Note:
        What is the sequence stacked format? When assembling batches of inputs, we
        take multiple sequences and stack them together to form 1 large sequence. We then
        use masking to ensure that the attention scores are only applied to tokens within
        the same document.

    Example:
        - Square mask produced with ``mask_mod=causal_mask``

          doc_mask         lengths
          a a b b b c c    2 3 2
        a 1 0 0 0 0 0 0
        a 1 1 0 0 0 0 0
        b 0 0 1 0 0 0 0
        b 0 0 1 1 0 0 0
        b 0 0 1 1 1 0 0
        c 0 0 0 0 0 1 0
        c 0 0 0 0 0 1 1
    """
    kv_lengths = kv_lengths if kv_lengths is not None else lengths
    q_document_id, q_token_id = lengths_to_local_ids(lengths)
    kv_document_id, kv_token_id = lengths_to_local_ids(kv_lengths)
    q_max_idx = lengths.sum() - 1
    kv_max_idx = kv_lengths.sum() - 1

    def doc_mask_mod(b, h, q_idx, kv_idx):
        # Clamp so positions past the packed length (e.g. padding up to the
        # block size) can still index the lookup tables; valid_idx masks them out.
        q_idx_cap = torch.minimum(q_max_idx, q_idx)
        kv_idx_cap = torch.minimum(kv_max_idx, kv_idx)
        valid_idx = (q_idx <= q_max_idx) & (kv_idx <= kv_max_idx)
        same_doc = q_document_id[q_idx_cap] == kv_document_id[kv_idx_cap]
        q_logical = q_token_id[q_idx_cap]
        kv_logical = kv_token_id[kv_idx_cap]
        inner_mask = mask_mod(b, h, q_logical, kv_logical)
        return same_doc & inner_mask & valid_idx

    return doc_mask_mod
# Rotary embedding as in xformers; check whether the torchtrain implementation is better. It might also be useful to make it work with batch*seqlen collapsed.
class RotaryEmbedding(torch.nn.Module):
    """
    Holds the RoPE rotation table for positions ``0 .. max_seqlen - 1``.

    The table comes from ``precompute_freqs_cis(head_dim, max_seqlen, theta)``.
    It is registered as a non-persistent buffer, so it is not written to
    ``state_dict``.
    """

    def __init__(self, theta: float, head_dim: int, max_seqlen: int = 1024):
        super().__init__()
        self.theta = theta
        self.head_dim = head_dim
        self.max_seqlen = max_seqlen

        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(dim=head_dim, end=max_seqlen, theta=theta),
            persistent=False,
        )

    def reset_parameters(self):
        # Recompute the table by writing into the existing buffer in place.
        self.freqs_cis[...] = precompute_freqs_cis(
            dim=self.head_dim, end=self.max_seqlen, theta=self.theta
        )

    def forward(
        self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None
    ):
        """
        Return the rows of the rotation table for the requested positions.

        Args:
            seqlen (int): Return the rows for positions ``0 .. seqlen - 1``.
                Values above ``max_seqlen`` are silently truncated by slicing.
            tok_idx (torch.Tensor[int]): Explicit positions to gather. When
                given, it takes precedence over ``seqlen``.

        Returns:
            torch.Tensor: ``self.freqs_cis[tok_idx]`` if ``tok_idx`` is given,
            otherwise ``self.freqs_cis[0:seqlen]``.

        Raises:
            AssertionError: If neither ``seqlen`` nor ``tok_idx`` is given.
        """
        test = (seqlen is not None) or (tok_idx is not None)
        assert test, "Should provide atleast seqlen or tok_idx"
        if tok_idx is not None:
            return self.freqs_cis[tok_idx]
        elif seqlen is not None:
            return self.freqs_cis[0:seqlen]
class RMSNorm(nn.Module):
    """
    Root-mean-square normalization with a learnable per-feature scale.

    Computes ``x * rsqrt(mean(x * x, -1) + eps) * weight`` in float32 and
    casts the result back to the input dtype.

    Args:
        dim (int): Size of the last (normalized) dimension.
        eps (float, optional): Added under the square root for numerical stability. Default is 1e-6.

    Attributes:
        eps (float): The stability constant.
        weight (nn.Parameter): Learnable scale of shape ``(dim,)``, initialized to ones.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor):
        inv_rms = torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x: torch.Tensor):
        x = probe.log_stats(x, "resid")
        normed = self._norm(x.float())
        scale = self.weight.float()
        return (normed * scale).type_as(x)

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)  # type: ignore
def _reshape_for_attn_bias(
    attn_bias: AttentionBias | None,
    *tensors: torch.Tensor,
) -> list[torch.Tensor]:
    """Prepare tensors for xformers' memory-efficient attention.

    With a ``BlockDiagonalCausalMask`` bias, every tensor of shape
    ``(B, S, ...)`` is folded into ``(1, B * S, ...)``. Otherwise the tensors
    are returned unchanged, as a list.
    """
    if not isinstance(attn_bias, fmha.attn_bias.BlockDiagonalCausalMask):
        return list(tensors)
    # `reshape` rather than `view`: at inference time the strides may not
    # allow a view.
    return [t.reshape(1, -1, *t.shape[2:]) for t in tensors]
2024-12-12 23:32:30 +00:00
class Attention(nn.Module):
def __init__(
self,
dim: int,
head_dim: int,
n_heads: int,
n_kv_heads: int,
rope_theta: float,
):
super().__init__()
self.dim = dim
self.head_dim = head_dim
self.rope_theta = rope_theta
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.heads_per_group = self.n_heads // self.n_kv_heads
self.wq = nn.Linear(
dim,
n_heads * head_dim,
bias=False,
)
self.wk = nn.Linear(
dim,
n_kv_heads * head_dim,
bias=False,
)
self.wv = nn.Linear(
dim,
n_kv_heads * head_dim,
bias=False,
)
self.wo = nn.Linear(
n_heads * head_dim,
dim,
bias=False,
)
def forward(
self,
x: torch.Tensor,
freq_cis: torch.Tensor,
tok_idx: Optional[torch.Tensor] = None,
mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
attn_impl: str = "sdpa",
) -> torch.Tensor:
# B S D
bsz, seq_len, dim = x.shape
xq = self.wq(x.view_as(x))
xk = self.wk(x.view_as(x))
xv = self.wv(x.view_as(x))
output_shape = xq.shape
# B S D -> B S H D
xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len])
# This condition helps us be easily compatible
# with inference by adding a pluggable KVCache
if hasattr(self, "kv_cache"):
xk, xv = self.kv_cache.update(xk, xv, tok_idx)
xk = repeat_kv(xk, self.heads_per_group, dim=2)
xv = repeat_kv(xv, self.heads_per_group, dim=2)
if attn_impl == "flex_attention":
assert mask is None or isinstance(mask, BlockMask)
xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
output = flex_attention_comp(xq, xk, xv, block_mask=mask)
output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
elif attn_impl == "xformers":
2024-12-12 23:32:30 +00:00
assert mask is None or isinstance(mask, AttentionBias)
query_shape = xq.shape
xq, xk, xv = _reshape_for_attn_bias(mask, xq, xk, xv)
2024-12-12 23:32:30 +00:00
output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask)
output = output.view(query_shape)
2024-12-12 23:32:30 +00:00
# This uses B S H D instead of B H S D of pytorch
elif attn_impl == "sdpa":
xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
assert mask is None or isinstance(mask, (str, torch.Tensor))
is_causal = (mask == "causal") if isinstance(mask, str) else False
mask = mask if isinstance(mask, torch.Tensor) else None
output = F.scaled_dot_product_attention(
xq,
xk,
xv,
is_causal=is_causal,
attn_mask=mask,
)
output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
else:
raise NotImplementedError(
f"Attention implementation {attn_impl} not supported"
)
output = self.wo(output.reshape(output_shape))
return output
def reset_parameters(self, init_std=None, factor=1.0):
init_std = init_std or (self.dim ** (-0.5))
for w in [self.wq, self.wk, self.wv]:
nn.init.trunc_normal_(
w.weight,
mean=0.0,
std=init_std,
a=-3 * init_std,
b=3 * init_std,
)
nn.init.trunc_normal_(
self.wo.weight,
mean=0.0,
std=init_std / factor,
a=-3 * init_std,
b=3 * init_std,
)
class FeedForward(nn.Module):
    """SwiGLU feed-forward block computing ``w2(silu(w1(x)) * w3(x))``.

    The inner width starts at ``int(2 * hidden_dim / 3)``. It is optionally
    scaled by ``ffn_dim_multiplier`` and then rounded up to a multiple of
    ``multiple_of``. The final width is stored as ``self.hidden_dim``.

    Args:
        dim: Input/output dimension.
        hidden_dim: Nominal hidden size before the 2/3 reduction.
        multiple_of: Granularity the inner width is rounded up to.
        ffn_dim_multiplier: Optional extra scale on the inner width.
        mp_size: The final inner width must be divisible by this.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
        mp_size: int = 1,
    ):
        super().__init__()

        inner = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            inner = int(ffn_dim_multiplier * inner)
        n_chunks = (inner + multiple_of - 1) // multiple_of
        inner = multiple_of * n_chunks
        assert inner % mp_size == 0

        self.dim = dim
        self.hidden_dim = inner

        # Creation order (w1, w3, w2) fixes default-init RNG consumption.
        self.w1 = nn.Linear(dim, inner, bias=False)
        self.w3 = nn.Linear(dim, inner, bias=False)
        self.w2 = nn.Linear(inner, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # B S D
        gate = F.silu(self.w1(x.view_as(x)))
        up = self.w3(x.view_as(x))
        return self.w2(gate * up)

    def reset_parameters(self, init_std=None, factor=1.0):
        """Truncated-normal init, clipped at ±3 of each weight's own std.

        ``w1``/``w3`` use ``init_std`` (default ``dim ** -0.5``). ``w2`` uses
        ``init_std`` (default ``hidden_dim ** -0.5``) divided by ``factor``.
        """
        std_in = init_std or (self.dim ** (-0.5))
        std_out = (init_std or (self.hidden_dim ** (-0.5))) / factor

        for proj in (self.w1, self.w3):
            nn.init.trunc_normal_(
                proj.weight, mean=0.0, std=std_in, a=-3 * std_in, b=3 * std_in
            )
        nn.init.trunc_normal_(
            self.w2.weight, mean=0.0, std=std_out, a=-3 * std_out, b=3 * std_out
        )
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention then SwiGLU feed-forward, each with a residual.

    Args:
        args: Model hyper-parameters. At least one of ``args.head_dim`` and
            ``args.n_heads`` must be set; the other is derived from ``args.dim``.
            ``args.n_kv_heads`` falls back to the number of query heads.
    """

    def __init__(self, args: BaseTransformerArgs):
        super().__init__()

        assert (args.head_dim is not None) or (
            args.n_heads is not None
        ), "Should specify at least head_dim or n_heads"
        self.head_dim = args.head_dim or args.dim // args.n_heads
        self.n_heads = args.n_heads or args.dim // args.head_dim
        self.n_kv_heads = args.n_kv_heads or self.n_heads

        # Validate the *resolved* head count. args.n_heads is None when only
        # head_dim is configured, and `None % int` would raise a TypeError.
        assert self.n_heads % self.n_kv_heads == 0
        assert args.dim % self.n_heads == 0

        self.attention = Attention(
            dim=args.dim,
            head_dim=self.head_dim,
            n_heads=self.n_heads,
            n_kv_heads=self.n_kv_heads,
            rope_theta=args.rope_theta,
        )
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        freq_cis: torch.Tensor,
        tok_idx: Optional[torch.Tensor] = None,
        mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
        attn_impl: str = "sdpa",
    ) -> torch.Tensor:
        """Apply ``x + attn(norm(x))``, then ``h + ffn(norm(h))``."""
        attn_out = self.attention(
            self.attention_norm(x),
            freq_cis,
            tok_idx=tok_idx,
            mask=mask,
            attn_impl=attn_impl,
        )
        h = x + attn_out
        h_norm = self.ffn_norm(h)
        out = h + self.feed_forward(h_norm)
        return out

    def init_weights(self, init_std=None, factor=1.0):
        """Re-initialize all sub-modules. ``factor`` scales the output projections."""
        self.attention.reset_parameters(init_std, factor)
        self.attention_norm.reset_parameters()

        self.feed_forward.reset_parameters(init_std, factor)
        self.ffn_norm.reset_parameters()
class BaseTransformer(nn.Module):
    """Stack of ``args.n_layers`` TransformerBlocks sharing one rotary table.

    Note: ``forward`` selects the attention backend from its ``attn_impl``
    argument; the ``attn_impl`` stored from ``args`` is not consulted there.
    """

    def __init__(self, args: BaseTransformerArgs):
        super().__init__()
        self.dim = args.dim
        self.init_base_std = args.init_base_std
        self.attn_impl = args.attn_impl
        self.attn_bias_type = args.attn_bias_type
        self.init_std_factor = InitStdFactor(args.init_std_factor)
        self.max_seqlen = args.max_seqlen
        self.rope_embeddings = RotaryEmbedding(
            theta=args.rope_theta,
            head_dim=args.head_dim or args.dim // args.n_heads,
            max_seqlen=args.max_seqlen,
        )
        self.eos_id = args.eos_id

        self.layers = nn.ModuleList(
            TransformerBlock(args) for _ in range(args.n_layers)
        )

    def forward(
        self,
        h,
        tok_idx: Optional[torch.Tensor] = None,
        mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
        attn_impl: str = "sdpa",
    ):
        """Run ``h`` through every layer with the shared rotary table."""
        freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx)
        for layer in self.layers:
            h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
        return h

    def reset_parameters(self):
        # Only the rotary table is owned directly by this module; layer weights
        # are initialized in init_weights.
        self.rope_embeddings.reset_parameters()

    def init_weights(self):
        """Reset the rotary table, then each layer with its depth-dependent factor."""
        self.reset_parameters()
        n_layers = len(self.layers)
        for depth, layer in enumerate(self.layers):
            mode = self.init_std_factor
            if mode is InitStdFactor.CURRENT_DEPTH:
                factor = (2 * (depth + 1)) ** 0.5
            elif mode is InitStdFactor.GLOBAL_DEPTH:
                factor = (2 * (n_layers + 1)) ** 0.5
            elif mode is InitStdFactor.DIM_RATIO:
                factor = self.dim / 4096
            elif mode is InitStdFactor.DISABLED:
                factor = 1.0
            else:
                raise KeyError(mode)
            layer.init_weights(self.init_base_std, factor)