blt/blt_one_file.py
ita.zaporozhets@huggingface.co 4f86b6e7ab consolidated model file
2025-06-03 13:30:02 +00:00

2618 lines
89 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
from enum import Enum, auto
from typing import Any, List, Optional, Tuple, Union
import torch
from huggingface_hub import PyTorchModelHubMixin
from pydantic import model_validator
from torch import nn
from torch.nn.attention.flex_attention import create_block_mask, BlockMask, flex_attention
from typing_extensions import Self
import json
import logging
import torch
import torch.nn
import torch.nn as nn
from pydantic import ConfigDict
from torch.nn import functional as F
from xformers.ops import AttentionBias, fmha
import abc
import os
import time
from collections import defaultdict
from pydantic import BaseModel
from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
RMSNorm = nn.RMSNorm
from bytelatent.distributed import get_local_rank
logger = logging.getLogger()
if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
flex_attention_comp = torch.compile(flex_attention)
else:
flex_attention_comp = None
def patch_reduce(h, max_num_patches, reduction, patch_ids):
    """
    Reduce variable-length patches to a single embedding per patch.

    Position ``j`` of sequence ``b`` in ``h`` is reduced (with ``reduction``) into
    output row ``patch_ids[b, j]``. Sequences in the batch may have different
    numbers of patches; output rows that receive no positions stay zero.

    Positions whose patch id equals ``max_num_patches`` (e.g. trailing positions
    not covered by any patch when sum(patch_lengths[b]) < seq_len) are collected
    in an extra dummy patch, which is trimmed before returning.

    Args:
        h: Tensor of shape [bs, seq_len, emb_dim].
        max_num_patches: Number of patches returned per sequence.
        reduction: Reduction passed to ``Tensor.scatter_reduce``
            (e.g. "mean", "amin", "amax").
        patch_ids: Integer tensor of shape [bs, seq_len] with values in
            [0, max_num_patches].

    Returns:
        Tensor of shape [bs, max_num_patches, emb_dim].
    """
    bs, _, emb_dim = h.shape
    # scatter_reduce needs an index with the same shape as src.
    index = patch_ids.unsqueeze(-1).expand(-1, -1, emb_dim)
    # One extra row so the dummy patch id (== max_num_patches) stays in range.
    reduced_embs = torch.zeros(
        (bs, max_num_patches + 1, emb_dim), dtype=h.dtype, device=h.device
    )
    reduced_embs = reduced_embs.scatter_reduce(
        src=h,
        dim=1,
        index=index,
        reduce=reduction,
        include_self=False,
    )
    # Drop the dummy patch.
    return reduced_embs[:, :max_num_patches, :]
def concat_downsample(h, patch_lengths, patch_size):
    """
    Build one embedding per patch by concatenating the embeddings of the last
    ``patch_size`` positions of each patch.

    Assumes seq_len == patch_size * num_patches (fixed-size patching). Positions
    that would fall before the start of the sequence are clamped to index 0.

    Args:
        h: Tensor of shape [bs, seq_len, emb_dim].
        patch_lengths: Integer tensor of shape [bs, num_patches].
        patch_size: Number of trailing positions concatenated per patch.

    Returns:
        Tensor of shape [bs, num_patches, patch_size * emb_dim].
    """
    bs, _, emb_dim = h.shape
    num_patches = patch_lengths.shape[1]
    ends = patch_lengths.cumsum(dim=1)
    offsets = torch.arange(patch_size, 0, -1).to(ends.device)
    # [bs, num_patches, patch_size]: positions end - patch_size ... end - 1.
    # TODO: confirm clamping to 0 is the desired behavior for short patches.
    positions = (ends.unsqueeze(-1) - offsets).clamp(min=0)
    gather_index = (
        positions.unsqueeze(-1).expand(-1, -1, -1, emb_dim).view(bs, -1, emb_dim)
    )
    # [bs, num_patches * patch_size, emb_dim]
    gathered = torch.gather(h, 1, gather_index)
    return gathered.reshape(bs, num_patches, patch_size * gathered.size(-1))
def pooling_downsample(h, max_num_patches, pooling_mode, patch_ids):
    """
    Pool per-position embeddings into per-patch embeddings.

    ``pooling_mode`` is checked with ``in`` for "avg"/"mean", "min" and "max".
    Each requested reduction is computed with ``patch_reduce``, and the results
    are concatenated along the last dimension in the order mean, min, max.
    At least one reduction must be requested (asserted).
    """
    reductions = []
    if "avg" in pooling_mode or "mean" in pooling_mode:
        reductions.append("mean")
    if "min" in pooling_mode:
        reductions.append("amin")
    if "max" in pooling_mode:
        reductions.append("amax")
    assert len(reductions) > 0
    pooled = [
        patch_reduce(h, max_num_patches, reduction, patch_ids)
        for reduction in reductions
    ]
    return torch.cat(pooled, dim=-1)
def downsample(
    h,
    num_patches,
    patch_lengths=None,
    patch_ids=None,
    downsampling_by_pooling=None,
    patch_size=4,
):
    """
    Turn per-byte embeddings into per-patch embeddings.

    Two strategies are supported:
      a. Pooling, when ``downsampling_by_pooling`` is non-empty: reduce the
         embeddings of each patch, using ``patch_ids`` to map positions to patches.
         The output has ``num_patches`` patches.
      b. Concatenation, otherwise: concatenate the embeddings of the last
         ``patch_size`` positions of every patch, using ``patch_lengths``.

    Args:
        h: Tensor of shape [batch_size, seq_len, dim].
        num_patches: Number of patches to produce when pooling.
        patch_lengths: Patch lengths; required for concatenation.
        patch_ids: Patch id per position; required for pooling.
        downsampling_by_pooling: Pooling specification (see ``pooling_downsample``).
        patch_size: Positions concatenated per patch when not pooling.
    """
    use_pooling = (
        downsampling_by_pooling is not None and len(downsampling_by_pooling) > 0
    )
    if not use_pooling:
        # TODO: remove this condition
        # Concatenation path (fixed-length patching).
        assert patch_lengths is not None
        return concat_downsample(h, patch_lengths, patch_size)
    assert patch_ids is not None
    return pooling_downsample(h, num_patches, downsampling_by_pooling, patch_ids)
def causal_mask(b, h, q_idx, kv_idx):
    """flex_attention mask_mod: a query may attend to keys at or before its position.

    ``b`` (batch index) and ``h`` (head index) belong to the mask_mod signature
    but are not used.
    """
    return kv_idx <= q_idx
def tokens_to_seqlen(batch: torch.Tensor, eos_id: int):
    """
    Split every row of ``batch`` into documents ending at ``eos_id`` and return
    the flat list of document lengths, row by row.

    The last position of every row is treated as a (virtual) document end, so
    the lengths of each row sum to the row width.

    Example (eos_id = 1):
        0 0 0 1 0 0 0 1 0 0 0
        0 1 0 0 0 1 0 0 0 0 0
        -> [4, 4, 3, 2, 4, 5]
    """
    is_end = batch.eq(eos_id)  # fresh tensor, safe to modify
    is_end[:, -1] = True  # virtual end at the end of each row
    rows, cols = torch.where(is_end)
    width = is_end.shape[1]
    # Position of every end in the row-major flattened batch; consecutive
    # differences are the document lengths (also across row boundaries).
    flat_ends = rows * width + cols
    lengths = flat_ends[1:] - flat_ends[:-1]
    return [int(cols[0].item() + 1)] + lengths.tolist()
def create_causal_mask(
seqlen,
attn_impl: str,
attn_bias_type: str | None,
*,
eos_id: int | None = None,
tokens: torch.Tensor | None = None,
sliding_window: int | None = None,
):
if attn_impl == "xformers":
if attn_bias_type is None:
return fmha.attn_bias.LowerTriangularMask()
elif attn_bias_type == "causal":
assert sliding_window is None
return fmha.attn_bias.LowerTriangularMask()
elif attn_bias_type == "block_causal":
assert sliding_window is None
assert eos_id is not None
assert tokens is not None
return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
q_seqlen=tokens_to_seqlen(tokens, eos_id)
)
elif attn_bias_type == "local_block_causal":
assert sliding_window is not None
assert eos_id is not None
assert tokens is not None
return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
q_seqlen=tokens_to_seqlen(tokens, eos_id)
).make_local_attention(sliding_window)
else:
return fmha.attn_bias.LocalAttentionFromBottomRightMask(
window_left=sliding_window - 1, window_right=0
)
elif attn_impl == "sdpa":
BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0))
if attn_bias_type == "causal":
return "causal"
if BLT_SUPPRESS_ATTN_ERROR == 1:
return "causal"
else:
raise ValueError(
"SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1"
)
elif attn_impl == "flex_attention":
return create_block_mask(causal_mask, None, None, seqlen, seqlen)
elif attn_impl == "fmha":
return None
else:
raise NotImplementedError(
f"Attention {attn_impl} with {sliding_window} sliding window not implemented"
)
class InitStdFactor(str, Enum):
    """
    Selects the divisor applied to the base init std of each layer.

    The divisor is computed in ``BaseTransformer.init_weights``:
      - DISABLED: 1.0
      - GLOBAL_DEPTH: sqrt(2 * (n_layers + 1))
      - CURRENT_DEPTH: sqrt(2 * (depth + 1)), with depth the 0-based layer index
      - DIM_RATIO: model_dim / 4096
    """

    DISABLED = "disabled"
    GLOBAL_DEPTH = "global_depth"
    CURRENT_DEPTH = "current_depth"
    DIM_RATIO = "dim_ratio"
class BaseTransformerArgs(BaseModel):
    """Configuration shared by the transformer stacks in this file.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")
    # Model (residual stream) width.
    dim: int = 512
    # Number of TransformerBlocks.
    n_layers: int = 8
    # Per-head size; when None, TransformerBlock derives it as dim // n_heads.
    head_dim: int | None = None
    # Number of query heads; when None, derived as dim // head_dim.
    n_heads: int | None = None
    # Number of key/value heads; when None, equals n_heads.
    n_kv_heads: int | None = None
    # Optional extra scaling of the FeedForward hidden size.
    ffn_dim_multiplier: float | None = None
    # FeedForward hidden size is rounded up to a multiple of this.
    multiple_of: int = 256
    # Epsilon for the RMSNorm layers.
    norm_eps: float = 1e-5
    # Base of the rotary-embedding frequency schedule.
    rope_theta: float = 10000.0
    # Passed through to precompute_freqs_cis.
    rope_use_fp32_in_outer_product: bool = False
    # Explicit init std for layer weights; None lets each layer derive one.
    init_base_std: float | None = None
    # How the per-layer init std divisor is chosen (see InitStdFactor).
    init_std_factor: InitStdFactor = InitStdFactor.DISABLED
    # Number of positions for which rotary embeddings are precomputed.
    max_seqlen: int = 1024
    # Default attention implementation ("sdpa", "xformers", "flex_attention", ...).
    attn_impl: str | None = "sdpa"
    # Mask type passed to create_causal_mask (e.g. "causal", "block_causal").
    attn_bias_type: str | None = None
    # Special token config
    # Document separator used to build block-causal masks.
    eos_id: int | None = EOS_ID
def cross_entropy(pred, target, **kwargs):
    """Cross-entropy loss computed in float32.

    All leading dims of ``pred`` are merged, giving [N, vocab], and ``target``
    is flattened to [N]. Extra keyword arguments are forwarded to
    ``F.nll_loss`` (e.g. ``reduction``, ``ignore_index``).
    """
    log_probs = F.log_softmax(pred.flatten(end_dim=-2).float(), -1)
    flat_target = target.flatten(end_dim=-1)
    return F.nll_loss(log_probs, flat_target, **kwargs)
def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor:
    """Equivalent to ``torch.repeat_interleave(x, dim=2, repeats=n_rep)``.

    Expands [bs, slen, n_kv_heads, head_dim] to
    [bs, slen, n_kv_heads * n_rep, head_dim], so each KV head is repeated
    ``n_rep`` times consecutively. Only ``dim=2`` is supported. When
    ``n_rep == 1``, ``x`` itself is returned.
    """
    assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims."
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    expanded = x.unsqueeze(3).expand(bs, slen, n_kv_heads, n_rep, head_dim)
    return expanded.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
def precompute_freqs_cis(
    dim: int,
    end: int,
    theta: float = 10000.0,
    rope_use_fp32_in_outer_product: bool = False,
):
    """
    Precompute the rotary-embedding rotation matrices for positions ``0 .. end-1``.

    For position ``t`` and frequency index ``i`` (``0 <= i < dim // 2``), the angle
    is ``t * theta ** (-2i / dim)``. The stored entry is the 2x2 rotation matrix
    ``[[cos, -sin], [sin, cos]]`` of that angle.

    Note: despite the "cis" name, the result is a real float32 tensor of rotation
    matrices, not a complex tensor.

    Args:
        dim (int): Head dimension the rotation is applied to.
        end (int): Number of positions to precompute.
        theta (float, optional): Base of the frequency schedule. Defaults to 10000.0.
        rope_use_fp32_in_outer_product (bool, optional): If True, cast the position
            indices to float32 before the outer product with the frequencies.

    Returns:
        torch.Tensor: float32 tensor of shape [end, dim // 2, 2, 2].
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    positions = torch.arange(end, device=freqs.device)
    if rope_use_fp32_in_outer_product:
        positions = positions.to(torch.float32)
    angles = torch.outer(positions, freqs).float()  # [end, dim // 2]
    cos, sin = angles.cos(), angles.sin()
    return torch.stack((cos, -sin, sin, cos), dim=-1).view(*angles.size(), 2, 2)
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int):
    """
    Reshape ``freqs_cis`` so it broadcasts against ``x``.

    ``freqs_cis`` must have shape (x.shape[seq_dim], x.shape[-3], 2, 2). The
    result keeps the sizes of axis ``seq_dim`` and axis ``ndim - 3`` of ``x``,
    sets every other leading axis to 1, and ends with the trailing (2, 2) matrix
    dims.

    Args:
        freqs_cis (torch.Tensor): Rotation-matrix table to reshape.
        x (torch.Tensor): Tensor the result must broadcast against.
        seq_dim (int): Index of the sequence axis in ``x``.

    Returns:
        torch.Tensor: A view of ``freqs_cis`` with ``x.ndim`` dimensions.
    """
    ndim = x.ndim
    assert 0 <= seq_dim < ndim
    expected = (x.shape[seq_dim], x.shape[-3], 2, 2)
    assert freqs_cis.shape == expected, f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
    kept_axes = {seq_dim, ndim - 3}
    shape = [size if axis in kept_axes else 1 for axis, size in enumerate(x.shape[:-2])]
    shape.extend([2, 2])
    return freqs_cis.view(*shape)
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    seq_dim: int,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate consecutive channel pairs of ``xq`` and ``xk`` by the per-position
    2x2 matrices in ``freqs_cis``.

    ``xq``/``xk`` are laid out as B S H D, and ``freqs_cis`` as S D/2 2 2. The
    rotation is done in float32 and the outputs are cast back to the input dtypes.
    """
    # B S H D -> B S H D/2 1 2
    q_pairs = xq.reshape(*xq.shape[:-1], -1, 1, 2)
    k_pairs = xk.reshape(*xk.shape[:-1], -1, 1, 2)
    # S D/2 2 2 -> 1 S 1 D/2 2 2
    rotation = reshape_for_broadcast(freqs_cis, q_pairs, seq_dim).float()
    # Matrix-vector product per pair (sum over the last axis), then merge pairs back into D.
    rotated_q = (q_pairs * rotation).sum(5).flatten(3)
    rotated_k = (k_pairs * rotation).sum(5).flatten(3)
    return rotated_q.type_as(xq), rotated_k.type_as(xk)
# Rotary embedding as in xformers; check whether the torchtrain implementation is better. It might also be useful to make it work with batch*seqlen collapsed.
class RotaryEmbedding(torch.nn.Module):
    """
    Holds a precomputed table of rotary-embedding rotation matrices.

    The table comes from ``precompute_freqs_cis`` for ``max_seqlen`` positions and
    is registered as a non-persistent buffer: it follows the module across
    devices but is not stored in the state dict.
    """

    def __init__(
        self,
        theta: float,
        head_dim: int,
        max_seqlen: int = 1024,
        rope_use_fp32_in_outer_product: bool = False,
    ):
        super().__init__()
        self.theta = theta
        self.head_dim = head_dim
        self.max_seqlen = max_seqlen
        self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product
        self.register_buffer("freqs_cis", self._build_table(), persistent=False)

    def _build_table(self):
        """Compute the [max_seqlen, head_dim // 2, 2, 2] rotation table from the stored settings."""
        return precompute_freqs_cis(
            dim=self.head_dim,
            end=self.max_seqlen,
            theta=self.theta,
            rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
        )

    def reset_parameters(self):
        """Recompute the table in place (keeps the existing buffer's device/dtype)."""
        self.freqs_cis[...] = self._build_table()

    def forward(
        self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None
    ):
        """
        Select rows of the rotation table.

        Args:
            seqlen (int): Return the rows for positions 0 .. seqlen-1.
            tok_idx (torch.Tensor[int]): Explicit positions to gather; takes
                precedence over ``seqlen`` when given.

        Returns:
            torch.Tensor: the selected rows of the table (a single tensor).
        """
        assert (seqlen is not None) or (
            tok_idx is not None
        ), "Should provide atleast seqlen or tok_idx"
        if tok_idx is not None:
            return self.freqs_cis[tok_idx]
        if seqlen is not None:
            return self.freqs_cis[:seqlen]
def _reshape_for_attn_bias(
    attn_bias: AttentionBias | None,
    *tensors: torch.Tensor,
) -> list[torch.Tensor]:
    """
    Prepare tensors for xformers attention with ``attn_bias``.

    For a ``BlockDiagonalCausalMask``, every tensor is packed into a single batch
    row (shape [1, B * S, ...]). Otherwise, the tensors are returned unchanged as
    a list.
    """
    if not isinstance(attn_bias, fmha.attn_bias.BlockDiagonalCausalMask):
        return list(tensors)
    # `view` would suffice during training, but inference needs `reshape`
    # because of stride mismatches.
    return [t.reshape(1, -1, *t.shape[2:]) for t in tensors]
class Attention(nn.Module):
    """
    Multi-head self-attention with rotary embeddings and grouped key/value heads.

    ``n_heads`` query heads share ``n_kv_heads`` key/value heads. Each KV head is
    repeated ``n_heads // n_kv_heads`` times before attention. The kernel is
    chosen per call through ``attn_impl``: "flex_attention", "xformers" or "sdpa".

    NOTE(review): ``rope_theta`` is stored but not used inside this module.
    """

    def __init__(
        self,
        dim: int,
        head_dim: int,
        n_heads: int,
        n_kv_heads: int,
        rope_theta: float,
    ):
        super().__init__()
        self.dim = dim
        self.head_dim = head_dim
        self.rope_theta = rope_theta
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.heads_per_group = self.n_heads // self.n_kv_heads
        q_width = n_heads * head_dim
        kv_width = n_kv_heads * head_dim
        self.wq = nn.Linear(dim, q_width, bias=False)
        self.wk = nn.Linear(dim, kv_width, bias=False)
        self.wv = nn.Linear(dim, kv_width, bias=False)
        self.wo = nn.Linear(q_width, dim, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        freq_cis: torch.Tensor,
        tok_idx: Optional[torch.Tensor] = None,
        mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
        attn_impl: str = "sdpa",
    ) -> torch.Tensor:
        """
        Attend over ``x`` (B S D) and return a tensor of the same shape.

        ``freq_cis`` holds rotary matrices; its first ``S`` rows are used. If a
        ``kv_cache`` attribute has been attached to this module, keys and values
        are routed through ``kv_cache.update(xk, xv, tok_idx)``.
        """
        bsz, seq_len, _ = x.shape
        xq = self.wq(x.view_as(x))
        xk = self.wk(x.view_as(x))
        xv = self.wv(x.view_as(x))
        output_shape = xq.shape
        # B S D -> B S H D
        xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len])
        # Pluggable KV cache hook for inference.
        if hasattr(self, "kv_cache"):
            xk, xv = self.kv_cache.update(xk, xv, tok_idx)
        xk = repeat_kv(xk, self.heads_per_group, dim=2)
        xv = repeat_kv(xv, self.heads_per_group, dim=2)
        output = self._run_kernel(xq, xk, xv, mask, attn_impl)
        return self.wo(output.reshape(output_shape))

    def _run_kernel(self, xq, xk, xv, mask, attn_impl):
        """Run the selected attention kernel on B S H D tensors; returns B S H D."""
        if attn_impl == "flex_attention":
            assert mask is None or isinstance(mask, BlockMask)
            q, k, v = (t.transpose(1, 2) for t in (xq, xk, xv))
            out = flex_attention_comp(q, k, v, block_mask=mask)
            return out.transpose(1, 2).contiguous()  # B H S D -> B S H D
        if attn_impl == "xformers":
            # xformers works in B S H D directly.
            assert mask is None or isinstance(mask, AttentionBias)
            query_shape = xq.shape
            xq, xk, xv = _reshape_for_attn_bias(mask, xq, xk, xv)
            out = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask)
            return out.view(query_shape)
        if attn_impl == "sdpa":
            q, k, v = (t.transpose(1, 2) for t in (xq, xk, xv))
            assert mask is None or isinstance(mask, (str, torch.Tensor))
            # A string mask only signals causality; a tensor mask is passed through.
            is_causal = (mask == "causal") if isinstance(mask, str) else False
            attn_mask = mask if isinstance(mask, torch.Tensor) else None
            out = F.scaled_dot_product_attention(
                q,
                k,
                v,
                is_causal=is_causal,
                attn_mask=attn_mask,
            )
            return out.transpose(1, 2).contiguous()  # B H S D -> B S H D
        raise NotImplementedError(
            f"Attention implementation {attn_impl} not supported"
        )

    def reset_parameters(self, init_std=None, factor=1.0):
        """Truncated-normal (±3 std) init of wq, wk, wv, wo.

        The std is ``init_std`` if given; otherwise ``dim ** -0.5 / factor``.
        NOTE(review): ``factor`` has no effect when ``init_std`` is provided — confirm intended.
        """
        std = init_std or (self.dim ** (-0.5)) / factor
        for linear in (self.wq, self.wk, self.wv, self.wo):
            nn.init.trunc_normal_(
                linear.weight,
                mean=0.0,
                std=std,
                a=-3 * std,
                b=3 * std,
            )
class FeedForward(nn.Module):
    """
    SwiGLU feed-forward layer: ``w2(silu(w1(x)) * w3(x))``.

    The hidden width is ``2/3 * hidden_dim``, optionally scaled by
    ``ffn_dim_multiplier`` and then rounded up to a multiple of ``multiple_of``.
    The result must be divisible by ``mp_size`` (asserted).
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
        mp_size: int = 1,
    ):
        super().__init__()
        width = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            width = int(ffn_dim_multiplier * width)
        # Ceiling division: round width up to the next multiple of multiple_of.
        width = -(-width // multiple_of) * multiple_of
        assert width % mp_size == 0
        self.dim = dim
        self.hidden_dim = width
        self.w1 = nn.Linear(dim, width, bias=False)
        self.w3 = nn.Linear(dim, width, bias=False)
        self.w2 = nn.Linear(width, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to ``x`` (B S D) and return B S D."""
        gate = F.silu(self.w1(x.view_as(x)))
        up = self.w3(x.view_as(x))
        return self.w2(gate * up)

    def reset_parameters(self, init_std=None, factor=1.0):
        """Truncated-normal (±3 std) init of w1, w2, w3, in that order.

        Input projections (w1, w3) use ``dim ** -0.5 / factor``, and the output
        projection (w2) uses ``hidden_dim ** -0.5 / factor``. Both are overridden
        by ``init_std`` when it is given.
        """
        in_std = init_std or (self.dim ** (-0.5)) / factor
        out_std = init_std or (self.hidden_dim ** (-0.5)) / factor
        for layer, std in ((self.w1, in_std), (self.w2, out_std), (self.w3, in_std)):
            nn.init.trunc_normal_(
                layer.weight,
                mean=0.0,
                std=std,
                a=-3 * std,
                b=3 * std,
            )
class TransformerBlock(nn.Module):
    """
    Pre-norm transformer layer:
        h = x + attention(attention_norm(x))
        out = h + feed_forward(ffn_norm(h))

    Either ``args.head_dim`` or ``args.n_heads`` must be set; the missing one is
    derived from ``args.dim``.
    """

    def __init__(self, args: BaseTransformerArgs):
        super().__init__()
        assert (args.head_dim is not None) or (
            args.n_heads is not None
        ), "Should specify at least head_dim or n_heads"
        self.head_dim = args.head_dim or args.dim // args.n_heads
        self.n_heads = args.n_heads or args.dim // args.head_dim
        self.n_kv_heads = args.n_kv_heads or self.n_heads
        # Validate against the resolved head count: args.n_heads is None when
        # only head_dim was given, so using it here would raise a TypeError.
        assert self.n_heads % self.n_kv_heads == 0
        assert args.dim % self.n_heads == 0
        self.attention = Attention(
            dim=args.dim,
            head_dim=self.head_dim,
            n_heads=self.n_heads,
            n_kv_heads=self.n_kv_heads,
            rope_theta=args.rope_theta,
        )
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        freq_cis: torch.Tensor,
        tok_idx: Optional[torch.Tensor] = None,
        mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
        attn_impl: str = "sdpa",
    ) -> torch.Tensor:
        """Apply the attention and feed-forward residual sub-blocks to ``x``."""
        norm_x = self.attention_norm(x)
        attn_out = self.attention(
            norm_x,
            freq_cis,
            tok_idx=tok_idx,
            mask=mask,
            attn_impl=attn_impl,
        )
        h = x + attn_out
        h_norm = self.ffn_norm(h)
        out = h + self.feed_forward(h_norm)
        return out

    def init_weights(self, init_std=None, factor=1.0):
        """Re-initialize attention, feed-forward and both norms."""
        self.attention.reset_parameters(init_std, factor)
        self.attention_norm.reset_parameters()
        self.feed_forward.reset_parameters(init_std, factor)
        self.ffn_norm.reset_parameters()
class SequenceModelWithOutput(abc.ABC):
    """Interface for sequence models that report the length of their output sequence."""

    @abc.abstractmethod
    def get_output_seq_len(self) -> int:
        """Return the output sequence length of the model."""
class BaseTransformer(nn.Module, SequenceModelWithOutput):
    """
    A stack of ``args.n_layers`` TransformerBlocks sharing one RotaryEmbedding table.

    ``forward`` looks up the rotary table (for ``max_seqlen`` positions, or for
    ``tok_idx`` when given) and runs the layers in order.
    """

    def __init__(self, args: BaseTransformerArgs):
        super().__init__()
        self.dim = args.dim
        self.init_base_std = args.init_base_std
        self.attn_impl = args.attn_impl
        self.attn_bias_type = args.attn_bias_type
        self.init_std_factor = InitStdFactor(args.init_std_factor)
        self.max_seqlen = args.max_seqlen
        self.rope_embeddings = RotaryEmbedding(
            theta=args.rope_theta,
            head_dim=args.head_dim or args.dim // args.n_heads,
            max_seqlen=args.max_seqlen,
            rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
        )
        self.eos_id = args.eos_id
        self.layers = nn.ModuleList(
            TransformerBlock(args) for _ in range(args.n_layers)
        )

    def get_output_seq_len(self):
        """The output length reported by this model is ``max_seqlen``."""
        return self.max_seqlen

    def forward(
        self,
        h,
        tok_idx: Optional[torch.Tensor] = None,
        mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
        attn_impl: str = "sdpa",
    ):
        """Run every layer on ``h`` with the shared rotary table and mask."""
        freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx)
        for layer in self.layers:
            h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
        return h

    def init_weights(self):
        """Reset the rotary table, then initialize each layer with its depth-dependent factor."""
        self.rope_embeddings.reset_parameters()
        n_layers = len(self.layers)
        for depth, layer in enumerate(self.layers):
            if self.init_std_factor == InitStdFactor.CURRENT_DEPTH:
                factor = (2 * (depth + 1)) ** 0.5
            elif self.init_std_factor == InitStdFactor.GLOBAL_DEPTH:
                factor = (2 * (n_layers + 1)) ** 0.5
            elif self.init_std_factor == InitStdFactor.DIM_RATIO:
                factor = self.dim / 4096
            elif self.init_std_factor == InitStdFactor.DISABLED:
                factor = 1.0
            else:
                raise KeyError(self.init_std_factor)
            layer.init_weights(self.init_base_std, factor)
class LMTransformerArgs(BaseTransformerArgs):
    """Arguments for LMTransformer: BaseTransformerArgs plus language-model fields."""

    # NOTE(review): not read by any code visible here; presumably used by training scripts.
    seed: int = 42
    # Vocabulary size; must be > 0 (asserted in LMTransformer.__init__).
    vocab_size: int = -1
    # Requests tying the output projection weight to the token embedding weight.
    weight_tying: bool = False
    # Window size passed to create_causal_mask for local attention mask types.
    sliding_window: int | None = None
class LMTransformer(
    BaseTransformer,
    PyTorchModelHubMixin,
    repo_url="https://github.com/facebookresearch/blt",
    # paper_url="https://arxiv.org/abs/2412.09871",
    pipeline_tag="text-generation",
    license="other",
    license_name="fair-noncommercial-research-license",
    license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE",
    coders={
        LMTransformerArgs: (
            lambda x: {"args": x.model_dump()},
            lambda data: LMTransformerArgs(**data),
        )
    },
):
    """
    Decoder-only language model:
    token embedding -> BaseTransformer -> RMSNorm -> linear output head.

    ``forward`` returns logits, or the cross-entropy loss when ``target`` is given.
    """

    def __init__(self, args: LMTransformerArgs):
        super().__init__(args)
        self.weight_tying = args.weight_tying
        self.sliding_window = args.sliding_window
        assert args.vocab_size > 0
        self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(
            args.dim,
            args.vocab_size,
            bias=False,
        )
        if args.weight_tying:
            # Both weights are [vocab_size, dim]. This module has no `embeddings`
            # attribute, so the token embedding must be referenced directly.
            self.output.weight = self.tok_embeddings.weight

    def push_to_hub(self, *args, **kwargs):
        """Disabled on purpose; always raises ValueError."""
        raise ValueError(
            "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct."
        )

    def forward(
        self,
        token_values: torch.Tensor,
        target: Optional[torch.Tensor] = None,
        tok_idx: Optional[torch.Tensor] = None,
        mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None,
        attn_impl: str | None = None,
    ):
        """
        Args:
            token_values: [bsz, seqlen] token ids.
            target: Optional targets; when given, the cross-entropy loss is returned.
            tok_idx: Optional explicit positions for the rotary embeddings.
            mask: Precomputed mask; built with ``create_causal_mask`` when None.
            attn_impl: Attention implementation; defaults to ``self.attn_impl``.
        """
        if attn_impl is None:
            attn_impl = self.attn_impl
        _, seqlen = token_values.shape
        h = self.tok_embeddings(token_values)
        if mask is None:
            mask = create_causal_mask(
                seqlen,
                attn_impl,
                self.attn_bias_type,
                sliding_window=self.sliding_window,
                tokens=token_values,
                eos_id=self.eos_id,
            )
        h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
        logits = self.output(self.norm(h))
        if target is not None:
            return cross_entropy(logits, target)
        return logits

    def reset_parameters(self, init_std=None):
        """Reset the final norm (``init_std`` is accepted but unused)."""
        self.norm.reset_parameters()

    def init_weights(self):
        """Initialize embeddings, all layers, and (when untied) the output head."""
        self.reset_parameters()
        init_std = self.dim ** (-0.5)
        nn.init.trunc_normal_(
            self.tok_embeddings.weight,
            mean=0.0,
            std=init_std,
            a=-3 * init_std,
            b=3 * init_std,
        )
        super().init_weights()
        # A tied output head shares the embedding weight initialized above.
        if not self.weight_tying:
            nn.init.trunc_normal_(
                self.output.weight,
                mean=0.0,
                std=init_std,
                a=-3 * init_std,
                b=3 * init_std,
            )
class PatchingModeEnum(str, Enum):
    """Strategies for splitting a byte sequence into patches (string-valued for configs)."""

    entropy = "entropy"
    bpe = "bpe"
    bpe_patcher = "bpe_patcher"
    space = "space"
    static = "static"
    byte = "byte"
class PatcherArgs(BaseModel):
    """Configuration for Patcher; ``build()`` constructs a Patcher from it."""

    # NOTE(review): Patcher.patch in this file only implements PatchingModeEnum.byte,
    # so this default raises NotImplementedError there — confirm.
    patching_mode: PatchingModeEnum = PatchingModeEnum.entropy
    # Device the entropy model is moved to (via to_device) when realtime_patching is on.
    patching_device: str = "cuda"
    # Directory with params.json and the consolidated entropy-model checkpoint.
    entropy_model_checkpoint_dir: str | None = None
    # If True, Patcher loads the entropy model when constructed.
    realtime_patching: bool = False
    threshold: float = 1.335442066192627
    threshold_add: float | None = None
    # Patches longer than this are split (see split_large_numbers); None disables it.
    max_patch_length: int | None = None
    patch_size: float = 4.5
    patching_batch_size: int = 1
    device: str = "cuda"
    monotonicity: bool = False
    # If True, Patcher accumulates timings/counters in its `log` dict.
    log_time: bool = False

    def build(self) -> "Patcher":
        """Create a Patcher configured by these arguments."""
        return Patcher(self)
def rightpad(seq, pad_id, max_len):
    """Return a new list: ``seq`` followed by ``pad_id`` up to length ``max_len``.

    If ``seq`` is already at least ``max_len`` long, no padding is added.
    """
    padding = [pad_id] * (max_len - len(seq))
    return seq + padding
def check_non_zero_after_zero(tensor):
    """
    Return a scalar bool tensor that is True if, in any row of the 2-D ``tensor``,
    a non-zero value immediately follows a zero.

    Used by Patcher.patch to check that patch lengths are zero-padded only on the right.
    """
    is_zero = tensor.eq(0)
    # prev_is_zero[:, j] tells whether column j - 1 is zero (False for column 0).
    prev_is_zero = torch.zeros_like(is_zero)
    prev_is_zero[:, 1:] = is_zero[:, :-1]
    return (tensor.ne(0) & prev_is_zero).any()
def to_device(entropy_model, device=None):
    """
    Move ``entropy_model`` to ``device`` and return ``(model, resolved_device)``.

    The bare string "cuda" is resolved to ``f"cuda:{get_local_rank()}"``; any
    other value is passed to ``.to()`` unchanged.
    """
    target = f"cuda:{get_local_rank()}" if device == "cuda" else device
    return entropy_model.to(target), target
def split_large_numbers(lst, m):
    """
    Split every value larger than ``m`` into chunks of size ``m`` plus a remainder.

    Values ``<= m`` are kept as they are. A value ``i > m`` becomes
    ``[m] * k + [r]`` with ``0 < r <= m`` and ``k * m + r == i``, so the total
    sum is preserved.

    Example:
        split_large_numbers([9, 3], 4) == [4, 4, 1, 3]

    Raises:
        ValueError: if some value must be split but ``m <= 0`` (splitting could
            never terminate).
    """
    new_lst = []
    for i in lst:
        if i > m:
            if m <= 0:
                raise ValueError(
                    f"Cannot split {i} into chunks of non-positive size {m}"
                )
            full_chunks, rest = divmod(i, m)
            if rest == 0:
                # Keep the last chunk non-empty: 8 with m=4 -> [4, 4], not [4, 4, 0].
                full_chunks, rest = full_chunks - 1, m
            new_lst.extend([m] * int(full_chunks))
            new_lst.append(rest)
        else:
            new_lst.append(i)
    assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}"
    return new_lst
class Patcher:
    """
    Computes patch lengths for batches of tokens according to ``PatcherArgs``.

    Note: in this file ``patch`` only implements ``PatchingModeEnum.byte``, where
    every token is its own patch. Other modes raise NotImplementedError.
    """

    def __init__(self, patcher_args: PatcherArgs):
        self.patcher_args = patcher_args
        self.patching_mode = patcher_args.patching_mode
        self.realtime_patching = patcher_args.realtime_patching
        if self.realtime_patching:
            assert (
                patcher_args.entropy_model_checkpoint_dir is not None
            ), "Cannot require realtime patching without an entropy model checkpoint"
            # Prefer <dir>/consolidated/consolidated.pth; fall back to <dir>/consolidated.pth.
            maybe_consolidated = os.path.join(
                patcher_args.entropy_model_checkpoint_dir,
                "consolidated/consolidated.pth",
            )
            if os.path.exists(maybe_consolidated):
                state_path = maybe_consolidated
            else:
                state_path = os.path.join(
                    patcher_args.entropy_model_checkpoint_dir, "consolidated.pth"
                )
            entropy_model, _ = load_entropy_model(
                patcher_args.entropy_model_checkpoint_dir,
                state_path,
            )
            entropy_model, _ = to_device(entropy_model, patcher_args.patching_device)
            self.entropy_model = entropy_model
        else:
            self.entropy_model = None
        self.threshold = patcher_args.threshold
        self.threshold_add = patcher_args.threshold_add
        self.max_patch_length = patcher_args.max_patch_length
        self.patch_size = patcher_args.patch_size
        self.patching_batch_size = patcher_args.patching_batch_size
        self.device = patcher_args.device
        self.monotonicity = patcher_args.monotonicity
        self.log_time = patcher_args.log_time
        if self.log_time:
            # Accumulated timings and token counts, keyed by name.
            self.log = defaultdict(float)

    def patch(
        self,
        tokens: torch.Tensor,
        include_next_token: bool = False,
        preds: torch.Tensor | None = None,
        entropies: torch.Tensor | None = None,
        threshold: float | None = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Compute patch lengths for ``tokens``.

        Args:
            tokens: 2-D tensor [batch_size, seq_len] to patch.
            include_next_token: If True, the patch lengths also cover one extra
                (next) token per row, so each row sums to seq_len + 1.
            preds, entropies, threshold: Accepted for API compatibility; the
                byte mode implemented here does not use them.

        Returns:
            (patch_lengths, scores). patch_lengths is [batch_size, max_num_patches],
            right-padded with zeros, with trailing columns that are zero in every
            row removed. scores is always None here.

        When ``max_patch_length`` is set, patches longer than it are split
        into chunks of at most that length.

        Raises:
            NotImplementedError: if ``patching_mode`` is not ``PatchingModeEnum.byte``.
        """
        bs, seq_len = tokens.shape
        seq_len_next_tok = seq_len + 1 if include_next_token else seq_len
        scores = None
        # STATIC
        if self.patching_mode == PatchingModeEnum.byte:
            patch_lengths = torch.ones(
                (bs, seq_len_next_tok), dtype=tokens.dtype, device=tokens.device
            )
        else:
            raise NotImplementedError(f"self.patching_mode {self.patching_mode}")

        # Start the post-processing timer; its elapsed time is recorded below.
        if self.log_time:
            s = time.time()

        # Apply any processing to patch lengths
        if self.max_patch_length is not None:
            # TODO: avoid going back to a list here.
            patch_lengths = [
                split_large_numbers(pl, self.max_patch_length)
                for pl in patch_lengths.tolist()
            ]
            max_len = max([len(pl) for pl in patch_lengths])
            patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths]
            patch_lengths = torch.tensor(
                patch_lengths, dtype=tokens.dtype, device=tokens.device
            )
        assert not check_non_zero_after_zero(patch_lengths)
        # Number of trailing all-zero columns: the smallest count of trailing
        # zeros over all rows (argmax of the reversed non-zero mask).
        last_non_zero_col_reversed = (
            (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min()
        )
        # Slice the tensor up to the last non-zero column
        patch_lengths = patch_lengths[
            :, : patch_lengths.shape[1] - last_non_zero_col_reversed
        ]
        assert (
            torch.sum(patch_lengths)
            == tokens.numel() + include_next_token * tokens.shape[0]
        ), f"{torch.sum(patch_lengths)} != {tokens.numel() + include_next_token * tokens.shape[0]}"
        if self.log_time:
            self.log["postprocessing_patch_lengths"] += time.time() - s
            self.log["tokens"] += patch_lengths.sum().item()
        return patch_lengths, scores
def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):
    """
    Build the entropy LMTransformer from ``<dir>/params.json`` and load its weights.

    Args:
        entropy_model_checkpoint_dir: Directory containing ``params.json``. Its
            ``"entropy_model"`` entry supplies dim, n_layers, n_heads, max_seqlen,
            ffn_dim_multiplier and vocab_size.
        state_dict_path: Checkpoint file whose ``"model"`` entry holds the state dict.
        device: ``map_location`` for ``torch.load`` and the device the model is moved to.

    Returns:
        (entropy_model, entropy_model_args): the model in eval mode with all
        parameters frozen (requires_grad=False), and the LMTransformerArgs used.

    Side effects:
        Calls ``torch.set_default_dtype(torch.bfloat16)`` and never restores the
        previous value, so the process-wide default dtype stays bfloat16.
    """
    with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
        reloaded = json.loads(fr.read())
    # Runs before LMTransformer(...) below, so parameters created there default
    # to bfloat16. NOTE(review): the global setting is not reset — confirm callers expect that.
    torch.set_default_dtype(torch.bfloat16)
    model_params = reloaded["entropy_model"]
    logger.warning(
        "Update checkpoint to load attn and sliding window args from checkpoint"
    )
    # The attention settings below are hard-coded, not read from params.json.
    entropy_model_args = LMTransformerArgs(
        dim=model_params["dim"],
        n_layers=model_params["n_layers"],
        n_heads=model_params["n_heads"],
        max_seqlen=model_params["max_seqlen"],
        ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
        vocab_size=model_params["vocab_size"],
        attn_bias_type="local_block_causal",
        attn_impl="xformers",
        sliding_window=512,
    )
    entropy_model = LMTransformer(entropy_model_args)
    # strict=False: missing/unexpected keys are ignored silently.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    entropy_model.load_state_dict(
        torch.load(state_dict_path, map_location=device)["model"], strict=False
    )
    entropy_model.to(device)
    entropy_model = entropy_model.eval()
    # no grads for the model:
    for param in entropy_model.parameters():
        param.requires_grad = False
    return entropy_model, entropy_model_args
def get_encoder_dim_token_emb(args):
    """
    Token embedding width for the local encoder.

    Priority: an explicit ``dim_token``; else ``dim_local_encoder`` when a local
    encoder transformer is used; else ``dim_global // patch_size``.
    """
    if args.dim_token is not None:
        return args.dim_token
    if args.use_local_encoder_transformer:
        return args.dim_local_encoder
    return args.dim_global // args.patch_size
def get_encoder_dim_patch_emb(args):
    """
    Patch embedding width for the encoder's cross-attention, or None without it.

    With ``cross_attn_encoder`` set, the width is ``dim_local_encoder`` when the
    cross-attention is initialized by pooling, otherwise ``dim_global``.
    """
    if not args.cross_attn_encoder:
        return None
    if args.cross_attn_init_by_pooling:
        return args.dim_local_encoder
    return args.dim_global
def get_global_dim_patch_emb(args):
    """
    Width of the patch embeddings fed to the global model.

    - With ``cross_attn_encoder``: ``dim_token_emb * cross_attn_k``.
    - Without pooling (``downsampling_by_pooling`` empty/None):
      ``dim_token_emb * patch_size``, since patches are concatenations.
    - With pooling: ``dim_token_emb`` times the number of reductions requested.
      This mirrors ``pooling_downsample``, which treats "avg" and "mean" as the
      same reduction, so a "mean" config no longer yields a width of 0.
    """
    dim_token_emb = get_encoder_dim_token_emb(args)
    if args.cross_attn_encoder:
        return dim_token_emb * args.cross_attn_k
    pooling = args.downsampling_by_pooling
    if not pooling:
        return dim_token_emb * args.patch_size
    num_reductions = (
        int("avg" in pooling or "mean" in pooling)
        + int("min" in pooling)
        + int("max" in pooling)
    )
    return dim_token_emb * num_reductions
def get_decoder_dim_token_emb(args):
    """
    Token embedding width for the local decoder.

    When encoder and decoder share embeddings, this is the encoder's width.
    Otherwise it is an explicit ``dim_token`` if set, else ``dim_local_decoder``.
    """
    if args.share_encoder_decoder_emb:
        return get_encoder_dim_token_emb(args)
    return args.dim_token if args.dim_token is not None else args.dim_local_decoder
def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int]:
if ngram_to_size_str is None:
return None
ngram_to_size = {}
for entry in ngram_to_size_str.split(","):
ngram, size = entry.split(":")
ngram = int(ngram)
size = int(size)
ngram_to_size[ngram] = size
return ngram_to_size
def fill_tokens(tokens, patch_size, fill_id):
    """
    Right-pad ``tokens`` ([batch_size, seq_len]) with ``fill_id`` so that seq_len
    becomes a multiple of ``patch_size``.

    ``tokens`` itself is returned when no padding is needed. The padding has
    the same dtype and device as ``tokens``.
    """
    batch_size, seq_len = tokens.shape
    overflow = seq_len % patch_size
    if overflow == 0:
        return tokens
    padding = tokens.new_full((batch_size, patch_size - overflow), fill_id)
    return torch.cat((tokens, padding), dim=1)
def decoder_patch_ids_from_lengths(patch_lengths, nb_boe, seq_len):
    """
    Patch id per position for the local decoder inputs.

    The first patch must have the same length in every row and contain exactly
    ``nb_boe`` BOE tokens plus one other token (both asserted). That first patch
    is dropped, and ``patch_ids_from_lengths`` is applied to the remaining
    lengths for ``seq_len`` positions.
    """
    first_patch_length = patch_lengths[0, 0]
    assert torch.all(
        first_patch_length == patch_lengths[:, 0]
    ), "first patch should always be the same size (1 for dynamic, patch_size for static)."
    assert (
        first_patch_length - nb_boe == 1
    ), f"First patch (patch length: {first_patch_length}) should have one non-boe token (boe toks: {nb_boe})"
    # Drop the first patch; the check below confirms exactly nb_boe + 1 tokens
    # per row were removed.
    decoder_patch_lengths = patch_lengths[:, 1:]
    assert (
        decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0]
        == patch_lengths.sum()
    ), f"{decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0]} != {patch_lengths.sum()}"
    assert torch.all(decoder_patch_lengths >= 0), f"{decoder_patch_lengths}"
    return patch_ids_from_lengths(patch_lengths=decoder_patch_lengths, seq_len=seq_len)
# Polynomial bases for rolling_polynomial_hash, selected by index ``hash_func_nb``.
# Powers of these values are computed in int64 and can overflow for larger hash
# window sizes (e.g. 9576890767 ** 2 already exceeds 2 ** 63).
primes = [
    1000000007,
    5915587277,
    1500450271,
    3267000013,
    5754853343,
    4093082899,
    9576890767,
    3628273133,
    2860486313,
    5463458053,
    3367900313,
]
def rolling_polynomial_hash(t, hash_func_nb: int = 0):
    """
    Polynomial hash over the last dimension of ``t``:
    ``sum_i t[..., i] * p ** i`` with ``p = primes[hash_func_nb]``.

    Computed in int64 on ``t``'s device; intermediate powers may overflow for
    long windows.
    """
    base = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device)
    window = t.shape[-1]
    powers = torch.stack([base**exponent for exponent in range(window)])
    return (t * powers).sum(dim=-1)
def byte_group_hash_function(
    x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000
):
    """
    Hash the ``group_size``-long window of token ids ending at each position.

    Position ``j`` is hashed over ``x[:, j - group_size + 1 : j + 1]``. The
    sequence is left-padded with ``group_size - 1`` zeros so early positions
    still see a full window. Each window goes through
    ``rolling_polynomial_hash`` (base chosen by ``hash_func_nb``) and is then
    reduced modulo ``max_hash``.

    Args:
        x: Tensor of shape (batch_size, seq_len) holding token ids.
        group_size: Window length (number of consecutive tokens hashed).
        hash_func_nb: Index of the hash base in ``primes``.
        max_hash: Number of hash buckets.

    Returns:
        Tensor of shape (batch_size, seq_len) with values in ``[0, max_hash)``.

    Note: max_hash can make a big difference on the number of collisions.
    """
    with torch.no_grad():
        bs, _ = x.shape
        prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device)
        x = torch.cat([prefix, x], dim=1)
        # (bs, seq_len, group_size): one sliding window per original position.
        windows = x.unfold(1, group_size, 1)
        hashes = rolling_polynomial_hash(windows, hash_func_nb)
        # Tensor `%` is torch.remainder (result takes the divisor's sign), so a
        # hash that wrapped negative still maps into [0, max_hash).
        hash_values_range = hashes % max_hash
    return hash_values_range
def create_patch_mask_from_ids(
    patch_ids, num_patches, window=None, patches_as_queries=False
):
    """
    Build a boolean byte/patch attention mask from per-byte patch ids.

    Let ``p = patch_ids[b, j]`` be the patch id of byte ``j`` and ``k`` a
    patch index in ``[0, num_patches)``.

    * ``patches_as_queries=False``: the mask has shape [bs, seq_len, num_patches]
      and entry (b, j, k) is
        - ``p == k`` when ``window`` is None (a byte sees only its own patch);
        - ``p - window < k <= p`` otherwise (its own patch and the
          ``window - 1`` patches before it).
    * ``patches_as_queries=True``: the mask has shape [bs, num_patches, seq_len]
      and entry (b, k, j) is
        - ``p == k`` when ``window`` is None;
        - ``k - window < p <= k`` otherwise.

    Args:
        patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids.
        num_patches (int): Total number of patches.
        window (int | None): Optional patch window size (see above).
        patches_as_queries (bool): If True, patches are the queries and bytes
            the keys/values.

    Returns:
        torch.Tensor: Boolean tensor of shape [bs, q_len, kv_len].
    """
    bs, seq_len = patch_ids.shape
    patch_range = torch.arange(num_patches, device=patch_ids.device)
    if not patches_as_queries:
        q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches)
        kv_ids = patch_range.unsqueeze(0).unsqueeze(0).expand(bs, seq_len, num_patches)
    else:
        kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len)
        q_ids = patch_range.unsqueeze(0).unsqueeze(-1).expand(bs, num_patches, seq_len)
    if window is None:
        return q_ids == kv_ids
    return (kv_ids <= q_ids) & (q_ids < kv_ids + window)
def cross_attn_mask(
    patch_ids,
    patch_lengths,
    N,
    patches_as_queries=False,
    cross_attn_k=1,
    window=None,
    block_mask=True,
):
    """
    Build the byte<->patch cross-attention mask.

    The boolean mask from ``create_patch_mask_from_ids`` is repeated
    ``cross_attn_k`` times along the patch axis, since each patch is
    represented by ``cross_attn_k`` slots.

    Args:
        patch_ids: [bs, N] patch id per byte (N must match the byte axis; the
            shape assertion below enforces it).
        patch_lengths: [bs, num_patches]; only ``shape[1]`` is used.
        N: Number of byte positions.
        patches_as_queries: If True, patches are queries and bytes keys/values.
        cross_attn_k: Slots per patch.
        window: Optional patch window (see ``create_patch_mask_from_ids``).
        block_mask: If True, return a flex-attention ``BlockMask``; otherwise
            return an additive float mask of shape [bs, 1, q_len, kv_len] with
            0 where attention is allowed and -inf elsewhere.
    """
    bs = patch_ids.shape[0]
    num_patches = patch_lengths.shape[1]
    with torch.no_grad():
        cross_mask = create_patch_mask_from_ids(
            patch_ids,
            num_patches,
            window=window,
            patches_as_queries=patches_as_queries,
        ).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1)
        q_len = num_patches * cross_attn_k if patches_as_queries else N
        kv_len = N if patches_as_queries else num_patches * cross_attn_k
        assert cross_mask.shape == (
            bs,
            q_len,
            kv_len,
        ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}"
        if block_mask:

            def patch_mask(b, h, q_idx, kv_idx):
                return cross_mask[b, q_idx, kv_idx]

            # Distinct name: the `block_mask` flag is a bool, the result a BlockMask.
            flex_block_mask = create_block_mask(
                patch_mask,
                B=bs,
                H=None,
                Q_LEN=q_len,
                KV_LEN=kv_len,
                _compile=True,
            )
            return flex_block_mask
        # Python scalars make the result live on cross_mask's device with the
        # default float dtype, avoiding CPU scalar tensors.
        additive_mask = torch.where(cross_mask, 0.0, float("-inf"))
        return additive_mask.unsqueeze(1)  # [bs, 1, q_len, kv_len]
def get_blt_input(
tokens: torch.Tensor,
enforce_patch_size_multiple: bool,
nb_boe: torch.Tensor,
patch_size: int,
boe_id: int,
):
"""
This function returns X_et, X_gt and X_dt, the encoder, global, and decoder
tokens respectively.
Consider the input and target sequences:
X=[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13]
Y=[4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13,14]
with patch_size=4
Note 1: that there will be no special tokens introduced at the patch level.
Note 2: X_e needs to be trimmed to be passed to Global
Current without boe:
X_et = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]]
X_g = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] # remove last glob patch
X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]]
Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]]
--> lag fix:
X_et = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11] [12,13,pad,pad]]
X_g = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11]]
X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]]
Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]]
Dynamic (current):
X = [3,4,5,6,7,eos,bos,8,9,10,eos,bos]
Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11]
entropy patching:
input: 7, bos, 9, 10
pred (high entropy): eos, 8, 10, eos
X_et = [[boe,3,4,5,6,7,eos,bos,8,9,10,eos,bos]
X_g = [[boe], [3,4,5,6], [7,eos],[bos,8],[9], [10,eos]]
X_dt = [[3,4,5,6], [7,eos], [bos,8],[9], [10,eos],[bos]]
Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11]
--> lag fix no boe (force single byte first patch):
X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
X_g = [[3], [4,5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch
X_dt = [[3,4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]]
Y = [4,5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13]
input: 4, 7, bos, 9, 10
pred (high entropy): 5, eos, 8, 10, eos
X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch
X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]]
Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13]
Handle the last byte properly.
patch_lengths = [1, 1, 3, 2, 2 1 2 2 1]
X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # do not remove last global patch
X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11] [12]]
Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12, 13]]
bpe delim
X_et = [[3,4,5,6,7,<d>,eos,bos,<d>,8,9,<d>,10,<d>,eos,bos,11,12]
X_g = [[3], [4,5,6,7,<d>], [eos,bos,<d>], ..
X_dt = [[3,4,5,6,7], [<d>,eos,bos], [<d>,bos,8], ..
Y = [4,5,6,7,<d>, eos,bos,<d> 8,9,<d>, ..
Note 1: that there will be no special tokens introduced at the patch level.
Note 2: X_e needs to be trimmed to be passed to Global
"""
batch_size, seq_len = tokens.shape
local_encoder_tokens = tokens
local_decoder_tokens = tokens
if nb_boe > 0:
padded_patch = tokens.new(batch_size, nb_boe).fill_(boe_id)
local_encoder_tokens = torch.cat((padded_patch, local_encoder_tokens), dim=1)
# global_tokens = tokens.new(batch_size, ((seq_len-1) // patch_size)+1).fill_(boe_id)
# create global tokens, contains boe tokens and eos
# padded_local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id)
# patches = padded_local_encoder_tokens.view(batch_size, -1, patch_size)
# global_tokens = (patches.eq(eos_id).any(dim=2).int() * eos_id)[:, 1:]
# global_tokens += global_tokens.eq(0).int() * boe_id
# TODO: fix this when we want to use block causal in the global.
if enforce_patch_size_multiple and local_encoder_tokens.shape[-1] % patch_size != 0:
local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id)
return local_encoder_tokens, None, local_decoder_tokens
def patch_ids_from_lengths(patch_lengths, seq_len):
    """
    Map every position in ``[0, seq_len)`` to the index of the patch holding it.

    ``patch_lengths`` is (bs, num_patches). Patch ``k`` starts at the sum of
    the lengths before it. Positions past the total length get id
    ``num_patches``; the assertion allows ids up to that value.

    Returns:
        Integer tensor of shape (bs, seq_len).
    """
    _, num_patches = patch_lengths.shape
    # Start offset of every patch, with a leading 0: (bs, num_patches + 1).
    starts = F.pad(patch_lengths.cumsum(dim=-1), (1, 0))
    positions = torch.arange(seq_len, device=starts.device)
    # Number of starts at or before each position, minus one, is the patch id.
    patch_ids = (starts[..., None] <= positions).sum(dim=-2) - 1
    assert not (
        torch.max(patch_ids) > num_patches or torch.min(patch_ids) < 0
    ), f"{torch.max(patch_ids)} > {num_patches} or {torch.min(patch_ids)} < 0"
    return patch_ids
class ByteLatentTransformerArgs(BaseTransformerArgs):
    """
    Top-level pydantic configuration for the Byte Latent Transformer.

    Holds the settings for the local encoder, global transformer, local
    decoder, patching and hash embeddings. ``encoder_hash_byte_group_size``
    may be given as a comma-separated string; the validator below turns it
    into a list of ints.
    """

    # Basic model configuration
    seed: int = 42
    vocab_size: int = -1
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    # TODO: What is the purpose of this parameter?
    weight_tying: bool = False
    patch_in_forward: bool = False

    # Architecture and dimensions
    dim_token: int | None = None
    dim_global: int = 512
    dim_local_decoder: int = 512
    dim_local_encoder: int = 512
    n_layers_global: int = 8
    n_layers_local_decoder: int = 8
    n_layers_local_encoder: int = 8

    # Tokenization and patching
    patch_size: float | None = None
    patching_mode: str | None = None
    patching_threshold: float | None = None
    patching_threshold_add: float | None = None
    monotonicity: bool = False
    patching_batch_size: int = 1
    patching_device: str = "cuda"
    max_patch_length: int | None = None

    # Encoder/Decoder configuration
    tie_local_encoder_decoder_logits: bool = False
    use_local_encoder_transformer: bool = False
    encoder_lm_loss: bool = False
    max_encoder_seq_length: int | None = None
    pad_to_max_length: bool = False
    encoder_enable_byte_ngrams: bool = False
    encoder_enable_byte_group_hash: bool = False
    ngram_vocab_sizes: int | None = None

    # Cross attention configurations
    cross_attn_encoder: bool = False
    cross_attn_decoder: bool = False
    cross_attn_window_encoder: int | None = None
    cross_attn_window_decoder: int | None = None
    cross_attn_k: int | None = None
    cross_attn_nheads: int | None = None
    cross_attn_all_layers_decoder: bool = False
    cross_attn_all_layers_encoder: bool = False
    cross_attn_use_flex_attention: bool = True
    cross_attn_init_by_pooling: bool = False

    # Encoder hash configurations
    # Either a list of ints or a comma-separated string (see validator).
    encoder_hash_byte_group_size: Any | None = None
    encoder_hash_byte_group_vocab: int = 30000
    encoder_hash_byte_group_nb_functions: int = 3

    # Model behavior and optimization
    log_patch_lengths: bool = False
    non_linearity: str = "swiglu"
    use_rope: bool = True
    recompute_fc1_out: bool = False
    recompute_fc3_out: bool = False
    recompute_attn: bool = True
    custom_bwd: bool = False
    layer_ckpt: str = "all"

    # Initialization and attention
    init_use_gaussian: bool = True
    init_use_depth: str = "current"
    attn_bias_type: str = "causal"
    alpha_depth: str = "disabled"
    max_length: int = 2048

    # Norm configuration
    norm_eps: float = 1e-5
    norm_affine: bool = True
    pre_norm: bool = True
    norm_type: str = "rmsnorm"

    # Additional configurations
    multiple_of: int = 256
    ffn_dim_multiplier: float = 1.0
    dropout: float = 0
    output_size: int = -1

    # Additional parameters from ModelArgs
    architecture: str = "vanilla"
    share_encoder_decoder_emb: bool = True
    global_local_decoder_residual_layer: str | None = None
    tokenize_with_bpe_delimiter: bool = False
    patching_thresholds_str: str | None = None
    tie_local_encoder_decoder: bool = False
    encoder_preds_low_entropy_toks: float | None = None
    encoder_preds_random_toks: float | None = None
    dim_token_emb: int | None = None
    dim_patch_emb: int | None = None

    encoder_ngram_table_dir: str | None = None
    encoder_ngram_to_size_str: str | None = None

    # Model architecture params
    entropy_model_checkpoint_dir: str | None = None
    entropy_model_is_ngram_model: bool = False
    downsampling_by_pooling: str | None = None
    n_heads_global: int = 8
    n_heads_local_decoder: int = 8
    n_heads_local_encoder: int = 8
    n_kv_heads: int | None = None
    n_kv_heads_global: int | None = None
    conv_kernel_size: int | None = None
    local_attention_window_len: int | None = None

    # Performance optimization
    sequence_parallel: bool = False
    loss_parallel: bool = False
    fuse_sequence_parallel: bool = False
    use_fsdp: bool = True
    attn_to_keep: str = "all"

    # Parameter mixing
    pm_size: int = 0

    # Logging
    full_logging_n_layers: int = 4

    @model_validator(mode="after")
    def check_hash_byte_sizes(self) -> Self:
        """Parse a comma-separated ``encoder_hash_byte_group_size`` into a list of ints."""
        if isinstance(self.encoder_hash_byte_group_size, str):
            # Blank entries (e.g. trailing commas, stray spaces) are skipped.
            self.encoder_hash_byte_group_size = [
                int(x)
                for x in self.encoder_hash_byte_group_size.split(",")
                if x.strip()
            ]
        return self
class GlobalTransformerArgs(ByteLatentTransformerArgs):
    """
    ByteLatentTransformerArgs specialised for the global (latent) transformer.

    After validation, the overrides below are applied:

    * ``dim``, ``n_layers``, ``n_heads`` and ``n_kv_heads`` are replaced by
      their ``*_global`` counterparts;
    * the local attention window is cleared;
    * cross attention is disabled.
    """

    # Global encoder specific dimensions
    dim_token_emb: int | None = None
    dim_patch_emb: int | None = None

    def model_post_init(self, context: Any, /) -> None:
        # pydantic models never call a dataclass-style ``__post_init__``.
        # ``model_post_init`` is the hook pydantic v2 runs after validation.
        super().model_post_init(context)
        self._apply_global_overrides()

    def __post_init__(self):
        # Kept for callers that invoked the former hook explicitly; idempotent.
        self._apply_global_overrides()

    def _apply_global_overrides(self) -> None:
        """Override base args with global encoder specific values."""
        self.dim = self.dim_global
        self.n_layers = self.n_layers_global
        self.n_heads = self.n_heads_global
        self.n_kv_heads = self.n_kv_heads_global
        self.local_attention_window_len = None
        self.cross_attn_encoder = False
        self.cross_attn_decoder = False
class LocalDecoderArgs(ByteLatentTransformerArgs):
    """
    ByteLatentTransformerArgs specialised for the local decoder.

    After validation, the overrides below are applied:

    * ``dim``, ``n_layers`` and ``n_heads`` are replaced by their
      ``*_local_decoder`` counterparts;
    * encoder-side cross attention and pooling initialisation are turned off;
    * the attention bias type is set to ``"local_block_causal"``.
    """

    # Local decoder specific dimensions
    dim_token_emb: int | None = None
    dim_patch_emb: int | None = None

    def model_post_init(self, context: Any, /) -> None:
        # pydantic models never call a dataclass-style ``__post_init__``.
        # ``model_post_init`` is the hook pydantic v2 runs after validation.
        super().model_post_init(context)
        self._apply_local_decoder_overrides()

    def __post_init__(self):
        # Kept for callers that invoked the former hook explicitly; idempotent.
        self._apply_local_decoder_overrides()

    def _apply_local_decoder_overrides(self) -> None:
        """Override base args with local decoder specific values."""
        self.dim = self.dim_local_decoder
        self.n_layers = self.n_layers_local_decoder
        self.n_heads = self.n_heads_local_decoder
        self.cross_attn_encoder = False
        self.cross_attn_init_by_pooling = False
        self.attn_bias_type = "local_block_causal"
def create_global_transformer(args: ByteLatentTransformerArgs):
    """
    Build the GlobalTransformer from a deep copy of ``args``.

    In the copy, the ``*_global`` settings replace the generic
    dim/n_layers/n_heads/n_kv_heads fields, the local attention window and
    cross attention are disabled, the token-embedding width comes from
    ``get_global_dim_patch_emb(args)``, and no patch-embedding width is set.
    """
    global_overrides = {
        "dim": args.dim_global,
        "n_layers": args.n_layers_global,
        "n_heads": args.n_heads_global,
        "n_kv_heads": args.n_kv_heads_global,
        "local_attention_window_len": None,
        "dim_token_emb": get_global_dim_patch_emb(args),
        "dim_patch_emb": None,
        "cross_attn_encoder": False,
        "cross_attn_decoder": False,
    }
    global_args = args.model_copy(update=global_overrides, deep=True)
    return GlobalTransformer(global_args)
class LocalModelArgs(BaseTransformerArgs):
    """
    Configuration for LocalEncoder / LocalDecoder.

    Instances are built by create_local_encoder / create_local_decoder.
    Unknown fields are rejected (``extra="forbid"``). Fields without a default
    must be supplied explicitly.
    """

    model_config = ConfigDict(extra="forbid")
    # Override defaults
    attn_impl: str | None = "xformers"
    attn_bias_type: str | None = "local_block_causal"

    # Local encoder specific dimensions
    dropout: float
    vocab_size: int
    patch_size: float
    # Passed as ``sliding_window`` to create_causal_mask by the local models.
    sliding_window: int | None
    # True: rotary embeddings; False: learned position table (see LocalModelBase).
    use_rope: bool
    cross_attn_encoder: bool | None
    cross_attn_decoder: bool | None
    # Number of slots per patch in cross attention (None is treated as 1 when sizing projections).
    cross_attn_k: int | None
    cross_attn_init_by_pooling: bool
    patching_mode: str
    use_local_encoder_transformer: bool
    downsampling_by_pooling: str | None
    # Non-None means the encoder expects externally computed hash embeddings.
    encoder_hash_byte_group_size: Any | None = None
    cross_attn_all_layers_encoder: bool = False
    cross_attn_all_layers_decoder: bool = False
    cross_attn_nheads: int | None

    dim_token_emb: int
    dim_patch_emb: int | None
class LocalModelBase(nn.Module):
    """
    Shared base for the byte-level LocalEncoder and LocalDecoder.

    It builds:

    * a stack of ``args.n_layers`` TransformerBlocks;
    * either RoPE (``use_rope=True``) or a learned ``nn.Embedding`` position
      table;
    * optional linear projections for token embeddings and patch embeddings.

    Subclasses may add ``tok_embeddings``, ``norm``, ``output`` and
    ``cross_attn_layers``; ``init_weights`` handles each only if present.
    """

    def __init__(self, args: LocalModelArgs):
        super().__init__()

        self.dim = args.dim
        self.dropout = args.dropout
        self.vocab_size = args.vocab_size
        self.patch_size = args.patch_size
        self.dim_patch_emb = args.dim_patch_emb

        self.attn_impl = args.attn_impl
        self.sliding_window = args.sliding_window
        self.use_rope = args.use_rope
        self.init_std_factor = args.init_std_factor
        self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None)
        self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None)
        self.cross_attn_k = getattr(args, "cross_attn_k", None)
        self.eos_id = args.eos_id

        self.boe_id = BOE_ID

        self.layers = nn.ModuleList(
            [TransformerBlock(args) for _ in range(args.n_layers)]
        )

        if not self.use_rope:
            self.pos_embeddings = nn.Embedding(args.max_length, args.dim)
        else:
            self.rope = RotaryEmbedding(
                theta=args.rope_theta,
                head_dim=args.head_dim or args.dim // args.n_heads,
                max_seqlen=args.max_seqlen,
                rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
            )
            self.pos_embeddings = None

        self.token_embedding_projection = (
            nn.Linear(args.dim_token_emb, args.dim, bias=False)
            if hasattr(args, "dim_token_emb") and args.dim_token_emb != self.dim
            else None
        )

        self.patch_embedding_projection = self._create_patch_projection(args)

    def _should_create_patch_projection(self, args: LocalModelArgs):
        """True when patch embeddings need projecting into this model's width."""
        dimension_mismatch = args.dim_patch_emb and args.dim_patch_emb != self.dim
        # Pooling-initialised cross attention always projects, on either side.
        cross_attn_conditions = args.cross_attn_init_by_pooling and (
            args.cross_attn_encoder or args.cross_attn_decoder
        )
        return bool(dimension_mismatch or cross_attn_conditions)

    def _create_patch_projection(self, args):
        """Linear map dim_patch_emb -> dim_token_emb * k (k = cross_attn_k or 1), or None."""
        if not self._should_create_patch_projection(args):
            return None

        output_dim = args.dim_token_emb * (self.cross_attn_k or 1)

        return nn.Linear(
            in_features=args.dim_patch_emb,
            out_features=output_dim,
            bias=False,
        )

    def apply_embedding(self, tokens, embeds):
        """Return ``embeds`` if given, else look ``tokens`` up in ``tok_embeddings``."""
        if embeds is not None:
            return embeds
        else:
            return self.tok_embeddings(tokens)

    def _depth_init_factor(self, depth: int) -> float:
        """Std divisor for the layer at ``depth`` under ``self.init_std_factor``."""
        return {
            InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
            InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
            InitStdFactor.DIM_RATIO: self.dim / 4096,
            InitStdFactor.DISABLED: 1.0,
        }[self.init_std_factor]

    def init_weights(self, init_std=None):
        """
        (Re)initialise all parameters.

        Embeddings and linear projections use a truncated normal (±3 std).
        ``init_std`` defaults to ``dim ** -0.5``; the patch projection uses
        ``dim_patch_emb ** -0.5``. Transformer and cross-attention layers get
        a depth-dependent factor.
        """
        # `self.rope` only exists when use_rope=True.
        if self.use_rope:
            self.rope.reset_parameters()
        if hasattr(self, "norm"):
            self.norm.reset_parameters()

        init_std = init_std or (self.dim ** (-0.5))
        if hasattr(self, "tok_embeddings"):
            nn.init.trunc_normal_(
                self.tok_embeddings.weight,
                mean=0.0,
                std=init_std,
                a=-3 * init_std,
                b=3 * init_std,
            )
        if self.pos_embeddings is not None:
            nn.init.trunc_normal_(
                self.pos_embeddings.weight,
                mean=0.0,
                std=init_std,
                a=-3 * init_std,
                b=3 * init_std,
            )

        for depth, layer in enumerate(self.layers):
            layer.init_weights(None, self._depth_init_factor(depth))

        if hasattr(self, "output"):
            nn.init.trunc_normal_(
                self.output.weight,
                mean=0.0,
                std=init_std,
                a=-3 * init_std,
                b=3 * init_std,
            )

        if self.token_embedding_projection is not None:
            nn.init.trunc_normal_(
                self.token_embedding_projection.weight,
                mean=0.0,
                std=init_std,
                a=-3 * init_std,
                b=3 * init_std,
            )

        if self.patch_embedding_projection is not None:
            patch_emb_std = self.dim_patch_emb ** (-0.5)
            nn.init.trunc_normal_(
                self.patch_embedding_projection.weight,
                mean=0.0,
                std=patch_emb_std,
                a=-3 * patch_emb_std,
                b=3 * patch_emb_std,
            )

        # Subclasses only create `cross_attn_layers` when cross attention is on.
        cross_attn_layers = getattr(self, "cross_attn_layers", None)
        if cross_attn_layers is not None:
            for depth, layer in enumerate(cross_attn_layers):
                layer.init_weights(None, self._depth_init_factor(depth))
class LocalEncoder(LocalModelBase):
    """
    Byte-level encoder: embeds bytes, runs local transformer blocks and,
    when ``cross_attn_encoder`` is set, updates patch embeddings by cross
    attending from patches to byte states.
    """

    def __init__(self, args: LocalModelArgs):
        super().__init__(args)
        self.apply_transformer = args.use_local_encoder_transformer
        self.downsampling_by_pooling = args.downsampling_by_pooling
        self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None
        self.cross_attn_encoder = args.cross_attn_encoder
        self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder
        self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
        self.cross_attn_nheads = args.cross_attn_nheads

        self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)

        if self.cross_attn_encoder:
            self.cross_attn_layers = torch.nn.ModuleList()
            # One cross-attention layer per transformer layer, or just one.
            layers_to_add = args.n_layers if self.cross_attn_all_layers_encoder else 1
            for _ in range(layers_to_add):
                self.cross_attn_layers.append(
                    CrossAttention(
                        dim=self.dim,
                        head_dim=self.dim // self.cross_attn_nheads,
                        n_heads=self.cross_attn_nheads,
                        n_kv_heads=self.cross_attn_nheads,
                        norm_eps=args.norm_eps,
                    )
                )

    def apply_embedding(self, tokens, embeds):
        """Use precomputed (hash) ``embeds`` if given, else the byte embedding table."""
        if embeds is not None:
            assert (
                self.expects_hash_embeddings
            ), "Not expecting embeddings to be passed."
            return embeds
        else:
            return self.tok_embeddings(tokens)

    def forward(
        self,
        tokens: torch.Tensor,
        embeds: Optional[torch.Tensor] = None,
        patch_embeds: Optional[torch.Tensor] = None,
        mask: Optional[Union["BlockMask", "AttentionBias", torch.Tensor, str]] = None,
        cross_mask: Optional[torch.Tensor] = None,
        num_patches: Optional[int] = None,
        patch_ids: Optional[torch.Tensor] = None,
        cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
    ):
        """
        Encode ``tokens`` of shape (bs, seqlen).

        A local-block-causal mask is built when ``mask`` is None. Cross
        attention runs after the last layer, or after every layer when
        ``cross_attn_all_layers_encoder`` is set.

        Returns:
            ``((h, patch_embeds_or_None), cache)``: the byte hidden states and
            the cross-attended patch embeddings (None without cross attention).
        """
        bs, seqlen = tokens.shape
        if mask is None:
            mask = create_causal_mask(
                seqlen,
                self.attn_impl,
                "local_block_causal",
                sliding_window=self.sliding_window,
                tokens=tokens,
                eos_id=self.eos_id,
            )

        h = self.apply_embedding(tokens, embeds)
        freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None

        h = F.dropout(h, p=self.dropout, training=self.training)

        for i, layer in enumerate(self.layers):
            h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl)
            # check if cross attention should be applied to either all layer or only the last layer
            if self.cross_attn_encoder and (
                i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder
            ):
                patch_embeds = self.apply_cross_attention(
                    h, patch_embeds, i, bs, num_patches, patch_ids, cross_mask
                )

        h_residual = patch_embeds if self.cross_attn_encoder else None
        return (h, h_residual), cache

    def apply_cross_attention(
        self, h, patch_embeds, layer_idx, bs, num_patches, patch_ids, cross_mask
    ):
        """
        Let patch embeddings (queries) attend to byte states ``h`` (keys/values).

        On first use with ``cross_attn_init_by_pooling``, patch embeddings are
        created by pooling ``h`` and, if configured, projected and unflattened
        into ``cross_attn_k`` slots per patch.
        """
        # apply pooling and project
        if self.cross_attn_init_by_pooling and patch_embeds is None:
            patch_embeds = downsample(
                h,
                num_patches,
                patch_ids=patch_ids,
                downsampling_by_pooling=self.downsampling_by_pooling,
                patch_size=self.patch_size,
            )
            if self.patch_embedding_projection is not None:
                patch_embeds = self.patch_embedding_projection(patch_embeds)
                # The projection was sized with `cross_attn_k or 1`; match it here
                # so a missing cross_attn_k does not raise a TypeError.
                slots_per_patch = self.cross_attn_k or 1
                patch_embeds = patch_embeds.reshape(
                    bs, patch_embeds.shape[1] * slots_per_patch, self.dim
                )

        layer_idx = layer_idx if self.cross_attn_all_layers_encoder else 0
        patch_embeds_cross = self.cross_attn_layers[layer_idx](
            x=patch_embeds,
            kv=h,
            mask=cross_mask,
        )
        # NOTE(review): CrossAttention already returns x + attn, so patch_embeds
        # is counted twice here; presumably matches trained checkpoints — confirm
        # before changing.
        return patch_embeds + patch_embeds_cross
class LocalDecoder(LocalModelBase):
    """
    Byte-level decoder that turns patch states back into byte logits.

    Patch embeddings are either added to the byte embeddings, or (with
    ``cross_attn_decoder``) read via cross attention. Cross attention runs
    before the first layer, or before every layer when
    ``cross_attn_all_layers_decoder`` is set. Logits come from RMSNorm, then
    dropout, then a linear head of size ``vocab_size``.
    """

    def __init__(self, args: LocalModelArgs):
        super().__init__(args)

        self.cross_attn_decoder = args.cross_attn_decoder
        self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder
        self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
        self.cross_attn_nheads = args.cross_attn_nheads

        self.norm = RMSNorm(args.dim, eps=args.norm_eps)

        if self.cross_attn_decoder:
            n_cross = args.n_layers if self.cross_attn_all_layers_decoder else 1
            self.cross_attn_layers = torch.nn.ModuleList(
                CrossAttention(
                    dim=self.dim,
                    head_dim=self.dim // self.cross_attn_nheads,
                    n_heads=self.cross_attn_nheads,
                    n_kv_heads=self.cross_attn_nheads,
                    norm_eps=args.norm_eps,
                )
                for _ in range(n_cross)
            )

        self.output = nn.Linear(self.dim, args.vocab_size, bias=False)

    def _project_patch_embeds(self, patch_embeds, bs):
        """Project patch states to decoder width, unflattening cross_attn_k slots per patch."""
        if self.patch_embedding_projection is None:
            return patch_embeds
        assert patch_embeds is not None, "Patch embeddings must be passed."
        projected = self.patch_embedding_projection(patch_embeds)
        if self.cross_attn_k is None:
            return projected
        return projected.reshape(bs, projected.shape[1] * self.cross_attn_k, self.dim)

    def forward(
        self,
        tokens: torch.Tensor,
        embeds: Optional[torch.Tensor],
        patch_embeds: Optional[torch.Tensor] = None,
        mask: Optional[Union["BlockMask", "AttentionBias", torch.Tensor, str]] = None,
        cross_mask: Optional[torch.Tensor] = None,
        cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
    ):
        """
        Decode byte embeddings ``embeds`` into logits.

        ``tokens`` (bs, seqlen) is only used for shapes and for building the
        default local-block-causal mask.

        Returns:
            ``(logits, cache)`` with logits cast to float32.
        """
        bs, seqlen = tokens.shape
        assert embeds is not None, "Embeddings must be provided"

        if mask is None:
            mask = create_causal_mask(
                seqlen,
                self.attn_impl,
                "local_block_causal",
                sliding_window=self.sliding_window,
                tokens=tokens,
                eos_id=self.eos_id,
            )

        patch_embeds = self._project_patch_embeds(patch_embeds, bs)

        h = embeds
        # Without cross attention, patch information is injected additively.
        if not self.cross_attn_decoder and patch_embeds is not None:
            h = h + patch_embeds

        freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
        h = F.dropout(h, p=self.dropout, training=self.training)

        for depth, block in enumerate(self.layers):
            wants_cross = self.cross_attn_decoder and (
                depth == 0 or self.cross_attn_all_layers_decoder
            )
            if wants_cross:
                # NOTE(review): CrossAttention already returns x + attn, so h is
                # counted twice; presumably matches trained checkpoints.
                h = h + self.cross_attn_layers[depth](
                    x=h, kv=patch_embeds, mask=cross_mask
                )
            h = block(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl)

        normed = F.dropout(self.norm(h), p=self.dropout, training=self.training)
        logits = self.output(normed).float()
        return logits, cache
class CrossAttention(nn.Module):
    """
    CrossAttention block to attend to the encoder states from the decoder.
    Rope is not supported.

    Queries come from ``x`` and keys/values from ``kv``; both are
    RMS-normalised first. KV heads are repeated ``n_heads // n_kv_heads``
    times, attention runs through flex attention with an optional
    ``BlockMask``, and the block returns ``x + wo(attn)``.
    """

    def __init__(
        self,
        dim: int,
        head_dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
    ):
        super().__init__()

        self.dim = dim
        self.head_dim = head_dim

        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.heads_per_group = self.n_heads // self.n_kv_heads

        self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps)
        self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)

        self.wq = nn.Linear(
            dim,
            n_heads * head_dim,
            bias=False,
        )
        self.wk = nn.Linear(
            dim,
            n_kv_heads * head_dim,
            bias=False,
        )
        self.wv = nn.Linear(
            dim,
            n_kv_heads * head_dim,
            bias=False,
        )

        self.wo = nn.Linear(
            n_heads * head_dim,
            dim,
            bias=False,
        )

    def forward(
        self,
        x: torch.Tensor,
        kv: torch.Tensor,
        mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: (B, S_q, dim) query-side states.
            kv: (B, S_kv, dim) key/value-side states.
            mask: None or a flex-attention ``BlockMask`` (others are rejected).

        Returns:
            (B, S_q, dim) tensor ``x + output``.
        """
        # B S D
        bsz, seq_len, _ = x.shape
        _, slen_kv, _ = kv.shape
        x_norm = self.cross_attn_norm_q(x)
        kv = self.cross_attn_norm_kv(kv)

        xq = self.wq(x_norm)
        xk = self.wk(kv)
        xv = self.wv(kv)

        output_shape = xq.shape
        # B S D -> B S H D
        xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
        xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
        xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)

        xk = repeat_kv(xk, self.heads_per_group, dim=2)
        xv = repeat_kv(xv, self.heads_per_group, dim=2)

        assert mask is None or isinstance(mask, BlockMask)
        xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
        # flex_attention_comp is None when BLT_ALLOW_MISSING_FLEX_ATTENTION is
        # set; fall back to the eager kernel instead of calling None.
        attn_fn = flex_attention_comp if flex_attention_comp is not None else flex_attention
        output = attn_fn(xq, xk, xv, block_mask=mask)
        output = output.transpose(1, 2).contiguous()  # B H S D -> B S H D

        output = self.wo(output.reshape(output_shape))

        return x + output

    def init_weights(self, base_std: Optional[float], factor: float = 1.0):
        """
        Truncated-normal (±3 std) init of the q/k/v/o projections, then reset
        both RMSNorms. A truthy ``base_std`` is used as-is; otherwise the std
        is ``dim ** -0.5 / factor``.
        """
        std = base_std or ((self.dim ** (-0.5)) / factor)
        for projection in (self.wq, self.wk, self.wv, self.wo):
            nn.init.trunc_normal_(
                projection.weight,
                mean=0.0,
                std=std,
                a=-3 * std,
                b=3 * std,
            )
        self.cross_attn_norm_q.reset_parameters()
        self.cross_attn_norm_kv.reset_parameters()
class GlobalTransformer(BaseTransformer):
    """
    Patch-level (latent) transformer of BLT.

    Works like BaseTransformer but consumes precomputed ``embeds``. If their
    width (``dim_token_emb``) differs from ``dim``, they are projected to
    ``dim`` first.
    """

    def __init__(self, args: BaseTransformerArgs):
        super().__init__(args)
        self.dropout = args.dropout
        self.eos_id = args.eos_id
        self.dim_token_emb = args.dim_token_emb

        self.token_embedding_projection = None
        if args.dim_token_emb is not None and args.dim_token_emb != self.dim:
            self.token_embedding_projection = nn.Linear(
                args.dim_token_emb,
                args.dim,
                bias=False,
            )

    def forward(
        self,
        tokens: torch.Tensor,
        tok_idx: Optional[torch.Tensor] = None,
        embeds: Optional[torch.Tensor] = None,
        mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None,
        cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
    ):
        """
        Similar to BaseTransformer.forward, but with an additional embeds argument
        and projection to the token space.

        ``tokens`` (bs, seqlen) is used for the default causal mask; the
        hidden states come from ``embeds``. Returns ``(h, cache)``.
        """
        _, seqlen = tokens.shape

        h = embeds

        mask = (
            mask
            if mask is not None
            else create_causal_mask(
                seqlen,
                self.attn_impl,
                self.attn_bias_type,
                tokens=tokens,
                eos_id=self.eos_id,
            )
        )

        if self.token_embedding_projection is not None and h.shape[-1] != self.dim:
            h = self.token_embedding_projection(h)

        h = F.dropout(h, p=self.dropout, training=self.training)

        h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl)
        return h, cache

    def init_weights(self):
        """Initialise base weights, then the optional token projection."""
        super().init_weights()
        if self.token_embedding_projection is not None:
            # Only compute the std when a projection exists: dim_token_emb may
            # be None, and None ** -0.5 raises a TypeError.
            std = self.dim_token_emb ** (-0.5)
            nn.init.trunc_normal_(
                self.token_embedding_projection.weight,
                mean=0.0,
                std=std,
                a=-3 * std,
                b=3 * std,
            )
def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
    """
    Build the byte-level LocalEncoder from the BLT config.

    The encoder uses the ``*_local_encoder`` dims, has encoder-side cross
    attention only (``cross_attn_k`` is kept only when it is enabled), a
    byte vocabulary of ``vocab_size + pm_size`` and a local-block-causal bias.
    """
    token_emb_dim = get_encoder_dim_token_emb(args)
    patch_emb_dim = get_encoder_dim_patch_emb(args)
    encoder_cross_attn_k = args.cross_attn_k if args.cross_attn_encoder else None

    encoder_args = LocalModelArgs(
        # Encoder-specific sizes and cross-attention setup
        dim=args.dim_local_encoder,
        n_layers=args.n_layers_local_encoder,
        n_heads=args.n_heads_local_encoder,
        dim_token_emb=token_emb_dim,
        dim_patch_emb=patch_emb_dim,
        cross_attn_encoder=args.cross_attn_encoder,
        cross_attn_decoder=False,
        cross_attn_k=encoder_cross_attn_k,
        cross_attn_init_by_pooling=args.cross_attn_init_by_pooling,
        # Settings forwarded unchanged from the BLT config
        head_dim=args.head_dim,
        max_seqlen=args.max_encoder_seq_length,
        dropout=args.dropout,
        vocab_size=args.vocab_size + args.pm_size,
        norm_eps=args.norm_eps,
        patch_size=args.patch_size,
        sliding_window=args.local_attention_window_len,
        use_rope=args.use_rope,
        rope_theta=args.rope_theta,
        rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
        init_base_std=args.init_base_std,
        init_std_factor=args.init_std_factor,
        n_kv_heads=args.n_kv_heads,
        attn_impl=args.attn_impl,
        attn_bias_type="local_block_causal",
        multiple_of=args.multiple_of,
        ffn_dim_multiplier=args.ffn_dim_multiplier,
        patching_mode=args.patching_mode,
        use_local_encoder_transformer=args.use_local_encoder_transformer,
        downsampling_by_pooling=args.downsampling_by_pooling,
        encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
        cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
        cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
        cross_attn_nheads=args.cross_attn_nheads,
        eos_id=args.eos_id,
    )
    return LocalEncoder(encoder_args)
def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
    """
    Build the byte-level LocalDecoder from the BLT config.

    The decoder:

    * uses the ``*_local_decoder`` dims;
    * reads patch states of width ``dim_global``;
    * has decoder-side cross attention only (``cross_attn_k`` is kept only
      when it is enabled) and no pooling initialisation;
    * uses a byte vocabulary of ``vocab_size + pm_size``.
    """
    decoder_cross_attn_k = args.cross_attn_k if args.cross_attn_decoder else None
    token_emb_dim = get_decoder_dim_token_emb(args)

    decoder_args = LocalModelArgs(
        # Decoder-specific sizes and cross-attention setup
        dim=args.dim_local_decoder,
        n_layers=args.n_layers_local_decoder,
        n_heads=args.n_heads_local_decoder,
        dim_token_emb=token_emb_dim,
        dim_patch_emb=args.dim_global,
        cross_attn_encoder=False,
        cross_attn_decoder=args.cross_attn_decoder,
        cross_attn_init_by_pooling=False,  # patch states already exist
        cross_attn_k=decoder_cross_attn_k,
        # Settings forwarded unchanged from the BLT config
        head_dim=args.head_dim,
        max_seqlen=args.max_encoder_seq_length,
        dropout=args.dropout,
        vocab_size=args.vocab_size + args.pm_size,
        norm_eps=args.norm_eps,
        patch_size=args.patch_size,
        sliding_window=args.local_attention_window_len,
        use_rope=args.use_rope,
        rope_theta=args.rope_theta,
        rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
        init_base_std=args.init_base_std,
        init_std_factor=args.init_std_factor,
        n_kv_heads=args.n_kv_heads,
        attn_impl=args.attn_impl,
        attn_bias_type="local_block_causal",
        multiple_of=args.multiple_of,
        ffn_dim_multiplier=args.ffn_dim_multiplier,
        patching_mode=args.patching_mode,
        use_local_encoder_transformer=args.use_local_encoder_transformer,
        downsampling_by_pooling=args.downsampling_by_pooling,
        encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
        cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
        cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
        cross_attn_nheads=args.cross_attn_nheads,
        eos_id=args.eos_id,
    )
    return LocalDecoder(decoder_args)
class EmbeddingType(Enum):
    """Kinds of auxiliary byte embeddings that ``init_embeddings`` can build."""

    HASH_TOK = 1  # hashed byte-group (n-gram) embeddings
    NGRAM = 2  # explicit n-gram vocabulary embeddings
def init_embeddings(
    args,
    embedding_type: EmbeddingType,
    local_encoder_dim: int,
    encoder_hash_byte_group_size: list | None = None,
):
    """
    Create the auxiliary embedding tables for the local encoder.

    * ``HASH_TOK``: one ``nn.Embedding(encoder_hash_byte_group_vocab,
      local_encoder_dim)`` per (hash function, byte-group size) pair. The
      outer loop is over hash functions, matching ``compute_hash_embeddings``.
      Returns None when ``args.encoder_hash_byte_group_size`` is None.
    * ``NGRAM``: one ``nn.Embedding(size + 4, local_encoder_dim)`` per entry
      of ``args.encoder_ngram_to_size_str``. Returns None when that string is
      None.

    Args:
        args: BLT config providing the fields read above.
        embedding_type: Which family of tables to build.
        local_encoder_dim: Embedding width.
        encoder_hash_byte_group_size: Byte-group sizes for ``HASH_TOK``;
            defaults to ``args.encoder_hash_byte_group_size`` when None.

    Returns:
        ``nn.ModuleList`` of embeddings, or None (see above).
    """
    if (
        embedding_type == EmbeddingType.HASH_TOK
        and args.encoder_hash_byte_group_size is None
    ):
        return None
    if embedding_type == EmbeddingType.NGRAM and args.encoder_ngram_to_size_str is None:
        return None

    embeddings = []

    if embedding_type == EmbeddingType.HASH_TOK:
        # Previously a None argument was iterated directly (TypeError).
        group_sizes = (
            encoder_hash_byte_group_size
            if encoder_hash_byte_group_size is not None
            else args.encoder_hash_byte_group_size
        )
        embeddings = [
            nn.Embedding(args.encoder_hash_byte_group_vocab, local_encoder_dim)
            for _ in range(args.encoder_hash_byte_group_nb_functions)
            for _ in group_sizes
        ]
    elif embedding_type == EmbeddingType.NGRAM:
        encoder_ngram_to_size = parse_ngram_to_size(args.encoder_ngram_to_size_str)
        # Reserved ids added to every n-gram vocabulary. Named locally so it no
        # longer shadows the module-level OFFSET import.
        # NOTE(review): probably meant to equal the tokenizer OFFSET — confirm.
        ngram_offset = 4
        embeddings = [
            nn.Embedding(ngram_vocab_size + ngram_offset, local_encoder_dim)
            for ngram_vocab_size in encoder_ngram_to_size.values()
        ]

    return nn.ModuleList(embeddings)
def compute_hash_embeddings(
    local_encoder_tokens: torch.Tensor,
    local_encoder,
    encoder_hash_tok_embedding: nn.ModuleList,
    encoder_hash_byte_group_nb_functions: int,
    encoder_hash_byte_group_size: list,
    encoder_hash_byte_group_vocab: int,
) -> torch.Tensor:
    """
    Add hashed byte-group embeddings to the encoder's byte embeddings.

    The result starts from ``local_encoder.tok_embeddings(tokens)``. For every
    (hash function, group size) pair (hash function in the outer loop), the
    tokens are hashed with ``byte_group_hash_function`` and the matching table
    of ``encoder_hash_tok_embedding`` is looked up and summed in.

    Args:
        local_encoder_tokens: Input tokens tensor.
        local_encoder: Object exposing a ``tok_embeddings`` module.
        encoder_hash_tok_embedding: ModuleList of hash token embeddings, or None.
        encoder_hash_byte_group_nb_functions: Number of hash functions.
        encoder_hash_byte_group_size: List of byte group sizes.
        encoder_hash_byte_group_vocab: Number of hash buckets.

    Returns:
        The summed embeddings, or None when ``encoder_hash_tok_embedding`` is None.
    """
    if encoder_hash_tok_embedding is None:
        return None

    embeds = local_encoder.tok_embeddings(local_encoder_tokens)
    # (hash function, group size) pairs, in the order the tables were created.
    combos = [
        (func_nb, group_size)
        for func_nb in range(encoder_hash_byte_group_nb_functions)
        for group_size in encoder_hash_byte_group_size
    ]
    for table_idx, (func_nb, group_size) in enumerate(combos):
        hash_ids = byte_group_hash_function(
            local_encoder_tokens,
            group_size,
            hash_func_nb=func_nb,
            max_hash=encoder_hash_byte_group_vocab,
        )
        embeds = embeds + encoder_hash_tok_embedding[table_idx](hash_ids)
    assert len(combos) == len(encoder_hash_tok_embedding)
    return embeds
class ByteLatentTransformer(
nn.Module,
SequenceModelWithOutput,
PyTorchModelHubMixin,
repo_url="https://github.com/facebookresearch/blt",
# paper_url="https://arxiv.org/abs/2412.09871",
pipeline_tag="text-generation",
license="other",
license_name="fair-noncommercial-research-license",
license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE",
coders={
ByteLatentTransformerArgs: (
lambda x: {"args": x.model_dump()},
lambda data: ByteLatentTransformerArgs(**data),
)
},
):
"""
The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences
by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,
and local decoders to efficiently encode and decode byte sequences, leveraging patch-based processing for
improved performance and inference efficiency.
"""
def __init__(self, args: ByteLatentTransformerArgs):
    """
    Build the BLT sub-modules from ``args``.

    Creates the following:
      - the local encoder, the global transformer and the local decoder;
      - the hash-token embedding tables;
      - the byte n-gram embedding tables, when ``args.encoder_enable_byte_ngrams``
        is set;
      - a ``Patcher``, when ``args.patch_in_forward`` is set.

    Args:
        args: Model configuration.

    Raises:
        AssertionError: If n-grams are enabled while ``args.ngram_vocab_sizes``
            is None, or if ``args.vocab_size`` is not positive.
    """
    super().__init__()

    # General configuration
    self.weight_tying = args.weight_tying
    self.patch_size = args.patch_size
    self.patching_mode = args.patching_mode
    self.boe_id, self.bos_id, self.pad_id, self.eos_id = (
        BOE_ID,
        BOS_ID,
        PAD_ID,
        EOS_ID,
    )
    self.downsampling_by_pooling = args.downsampling_by_pooling
    self.patching_threshold = args.patching_threshold
    self.dim = args.dim
    self.init_base_std = args.init_base_std
    self.init_std_factor = InitStdFactor(args.init_std_factor)
    self.max_seqlen = args.max_seqlen

    # Cross attention configuration
    self.cross_attn_encoder = args.cross_attn_encoder
    self.cross_attn_decoder = args.cross_attn_decoder
    self.cross_attn_k = args.cross_attn_k
    self.cross_attn_window_encoder = args.cross_attn_window_encoder
    self.cross_attn_window_decoder = args.cross_attn_window_decoder
    self.cross_attn_use_flex_attention = args.cross_attn_use_flex_attention

    # Encoder hash configuration
    self.encoder_hash_byte_group_size = args.encoder_hash_byte_group_size
    self.encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab
    self.encoder_hash_byte_group_nb_functions = (
        args.encoder_hash_byte_group_nb_functions
    )

    # ByteLatent modules
    self.local_encoder = create_local_encoder(args)
    self.global_transformer = create_global_transformer(args)
    self.local_decoder = create_local_decoder(args)
    self.encoder_hash_tok_embedding = init_embeddings(
        args,
        EmbeddingType.HASH_TOK,
        local_encoder_dim=self.local_encoder.dim,
        encoder_hash_byte_group_size=self.encoder_hash_byte_group_size,
    )

    # Encoder ngram embedding tables. They are built only here, and only when
    # byte n-grams are enabled. A separate init_embeddings(..., NGRAM, ...)
    # call would allocate tables that this assignment immediately discards.
    self.encoder_ngram_embedding = None
    if args.encoder_enable_byte_ngrams:
        assert args.ngram_vocab_sizes is not None
        self.encoder_ngram_to_size = parse_ngram_to_size(
            args.encoder_ngram_to_size_str
        )
        ngram_emb_dim = self.local_encoder.dim
        self.encoder_ngram_embedding = nn.ModuleList(
            [
                nn.Embedding(ngram_vocab_size + OFFSET, ngram_emb_dim)
                for ngram_vocab_size in self.encoder_ngram_to_size.values()
            ]
        )

    # Output layer
    assert args.vocab_size > 0, "vocab_size must be greater than 0"

    # Patcher module. It exists only when patching happens inside forward().
    # Otherwise the attribute is absent, and forward() checks for it with
    # getattr(self, "patcher", None).
    if args.patch_in_forward:
        self.patcher = Patcher(
            PatcherArgs(
                patch_size=args.patch_size,
                patching_mode=args.patching_mode,
                patching_threshold=args.patching_threshold,
                patching_threshold_add=args.patching_threshold_add,
                monotonicity=args.monotonicity,
                max_patch_length=args.max_patch_length,
            )
        )
def push_to_hub(self, *args, **kwargs):
    """
    Intentionally disabled; every call raises ``ValueError``.

    Weights are expected to be written with ``save_pretrained()`` and uploaded
    to the Hub by hand, so that the repository metadata stays correct.
    """
    message = (
        "For meta authors: Do not push BLT weights with this, save weights with "
        "save_pretrained() then push them manually to HF hub to ensure the "
        "repository metadata is correct."
    )
    raise ValueError(message)
def get_output_seq_len(self):
    """Return the configured maximum sequence length (``self.max_seqlen``)."""
    seq_len = self.max_seqlen
    return seq_len
def forward(
    self,
    tokens: torch.Tensor,
    patch_lengths: Optional[torch.Tensor] = None,
    ngram_ids: Optional[torch.Tensor] = None,
):
    """
    Run the full BLT pipeline on a batch of byte tokens.

    The stages are:
      1. patching (optional);
      2. hash and n-gram byte embeddings;
      3. the local encoder;
      4. downsampling or cross-attention, giving one vector per patch;
      5. the global transformer over patches;
      6. the local decoder over bytes.

    Args:
        tokens: Token ids, unpacked as ``bs, N = tokens.shape``.
        patch_lengths: Optional patch lengths per sequence. When None, they are
            computed by ``self.patcher``, which must exist. A caller-supplied
            tensor is never modified.
        ngram_ids: Optional n-gram ids. Dim 0 holds one entry per n-gram
            embedding table. Required when n-gram tables are enabled.

    Returns:
        The first output of ``self.local_decoder``.

    Raises:
        AssertionError: On inconsistent inputs, such as a missing patcher or
            missing ``ngram_ids``, or on shape mismatches.
    """
    # Ensure ngram_ids is either a tensor or None
    assert (
        isinstance(ngram_ids, torch.Tensor) or ngram_ids is None
    ), f"ngram_ids must be a tensor or None, but was: {type(ngram_ids)}"
    bs, N = tokens.shape  # Batch size and sequence length

    # Get megabyte inputs. nb_boe is non-zero only when patching_mode == "".
    nb_boe = int(0 if self.patching_mode != "" else self.patch_size - 1)
    local_encoder_tokens, _, local_decoder_tokens = get_blt_input(
        tokens=tokens,
        enforce_patch_size_multiple=False,
        nb_boe=nb_boe,
        patch_size=self.patch_size,
        boe_id=self.boe_id,
    )

    # Patching
    if patch_lengths is None:
        assert (
            getattr(self, "patcher", None) is not None
        ), "Patcher not defined and no patch_lengths passed."
        patch_lengths, _ = self.patcher.patch(
            local_encoder_tokens,
            include_next_token=True,
            threshold=self.patcher.threshold,
        )
    elif nb_boe > 0:
        # The first patch is widened by nb_boe. Do this on a copy: an in-place
        # update would leak into the caller's tensor and accumulate across
        # repeated calls that reuse it.
        patch_lengths = patch_lengths.clone()
        patch_lengths[:, 0] += nb_boe

    assert torch.min(patch_lengths) >= 0

    # Generate patch IDs from patch_lengths
    patch_ids = patch_ids_from_lengths(
        patch_lengths, local_encoder_tokens.shape[-1]
    )
    assert torch.max(patch_ids) + 1 <= torch.max(
        (patch_lengths != 0).sum(dim=-1)
    ), f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}"

    # Cross-attention encoder
    cross_attn_mask_enc = None
    if self.cross_attn_encoder:
        cross_attn_mask_enc = cross_attn_mask(
            patch_ids,
            patch_lengths,
            N,
            patches_as_queries=True,
            cross_attn_k=self.cross_attn_k,
            window=self.cross_attn_window_encoder,
            block_mask=self.cross_attn_use_flex_attention,
        )

    # Hashing and embedding (None when there are no hash embedding tables)
    local_encoder_embeds = compute_hash_embeddings(
        local_encoder_tokens=local_encoder_tokens,
        local_encoder=self.local_encoder,
        encoder_hash_tok_embedding=self.encoder_hash_tok_embedding,
        encoder_hash_byte_group_nb_functions=self.encoder_hash_byte_group_nb_functions,
        encoder_hash_byte_group_size=self.encoder_hash_byte_group_size,
        encoder_hash_byte_group_vocab=self.encoder_hash_byte_group_vocab,
    )

    # N-gram table embeddings: one table per entry along dim 0 of ngram_ids
    if self.encoder_ngram_embedding is not None:
        assert ngram_ids is not None, "ngram_ids must be provided"
        if local_encoder_embeds is None:
            local_encoder_embeds = self.local_encoder.tok_embeddings(
                local_encoder_tokens
            )
        assert len(ngram_ids) == len(
            self.encoder_ngram_embedding
        ), f"ngram_ids.shape[0]={ngram_ids.shape[0]} versus len(encoder_ngram_embedding)={len(self.encoder_ngram_embedding)}, ngram_ids.shape={ngram_ids.shape}"
        for ngram_embedding, table_ids in zip(self.encoder_ngram_embedding, ngram_ids):
            ngram_embeds = ngram_embedding(table_ids)
            assert (
                local_encoder_embeds.shape == ngram_embeds.shape
            ), f"Shape mismatch: {local_encoder_embeds.shape} vs {ngram_embeds.shape}, ngram_ids.shape={ngram_ids.shape}"
            local_encoder_embeds = local_encoder_embeds + ngram_embeds

    # Local encoder (its cache output is not used here)
    (h_encoder, h_cross), _ = self.local_encoder(
        tokens=local_encoder_tokens,
        embeds=local_encoder_embeds,
        patch_embeds=None,
        cross_mask=cross_attn_mask_enc,
        num_patches=patch_lengths.shape[1],
        patch_ids=patch_ids,
    )

    # Downsampling to one vector per patch
    if not self.cross_attn_encoder:
        assert (
            patch_ids.shape[1] == h_encoder.shape[1]
        ), f"{patch_ids.shape[1]} != {h_encoder.shape[1]}"
        h = downsample(
            h_encoder,
            patch_lengths.shape[1],
            patch_lengths,
            patch_ids,
            downsampling_by_pooling=self.downsampling_by_pooling,
            patch_size=self.patch_size,
        )
    else:
        # Reshape h_cross
        h = h_cross.view(bs, patch_lengths.shape[1], -1)

    # Global transformer. Every patch slot is filled with boe_id, except patches
    # containing an eos_id byte, which are tagged with eos_id.
    global_tokens = tokens.new_full((h.shape[0], h.shape[1]), self.boe_id)
    rows, cols = torch.where(local_encoder_tokens == self.eos_id)
    eos_patch_ids = patch_ids[rows, cols]
    global_tokens[rows, eos_patch_ids] = self.eos_id
    h, _ = self.global_transformer(
        embeds=h,
        tokens=global_tokens,
    )

    # Unpatching: skip the nb_boe leading positions of the encoder output
    dec_embeds = h_encoder[:, nb_boe : nb_boe + N, :]

    # Generate decoder patch IDs
    decoder_patch_ids = decoder_patch_ids_from_lengths(
        patch_lengths, nb_boe, local_decoder_tokens.shape[-1]
    )
    assert (
        torch.max(decoder_patch_ids) + 1 <= h.shape[1]
    ), f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}"
    assert (
        decoder_patch_ids.shape[1] == dec_embeds.shape[1]
    ), f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}"

    # Cross-attention decoder, or gather each byte's patch vector directly
    if not self.cross_attn_decoder:
        h = torch.gather(
            h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])
        )
        cross_attn_mask_dec = None
        assert local_decoder_tokens.shape == h.shape[:-1]
    else:
        cross_attn_mask_dec = cross_attn_mask(
            decoder_patch_ids,
            patch_lengths,
            N,
            patches_as_queries=False,
            cross_attn_k=self.cross_attn_k,
            window=self.cross_attn_window_decoder,
            block_mask=self.cross_attn_use_flex_attention,
        )

    # Local decoder
    output, _ = self.local_decoder(
        embeds=dec_embeds,
        patch_embeds=h,
        tokens=local_decoder_tokens,
        cross_mask=cross_attn_mask_dec,
    )
    return output
def init_weights(self):
self.local_encoder.init_weights()
self.global_transformer.init_weights()
self.local_decoder.init_weights()
emb_std = self.local_encoder.dim ** (-0.5)
for emb in self.encoder_hash_tok_embedding:
nn.init.trunc_normal_(
emb.weight,
mean=0.0,
std=emb_std,
a=-3 * emb_std,
b=3 * emb_std,
)