# blt/bytelatent/model/blt.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
from enum import Enum, auto
from typing import Any, Optional
import torch
from pydantic import ConfigDict, model_validator
from torch import nn
from torch.nn.attention.flex_attention import create_block_mask
from typing_extensions import Self
from bytelatent.base_transformer import (
BaseTransformerArgs,
InitStdFactor,
TransformerBlock,
)
from bytelatent.data.patcher import Patcher, PatcherArgs
from bytelatent.model.latent_transformer import GlobalTransformer
from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs
from bytelatent.model.utils import downsample
from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
def attention_flops_per_token(n_layers, seq_len, dim, causal):
# Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30
return 3.5 * (4 * n_layers * seq_len * dim // (2 if causal else 1))
def get_num_flop_per_token(
num_non_embed_params: int, n_layers: int, dim: int, seq_len: int
) -> int:
return 6 * num_non_embed_params + attention_flops_per_token(
n_layers, seq_len, dim, True
)
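# Illustrative sketch (hypothetical helper, not part of the original module): a worked
# example of the FLOP estimate above, with sizes chosen only for the arithmetic.
def _example_flops_per_token():
    # Attention term: 3.5 * (4 * n_layers * seq_len * dim // 2) for the causal case.
    attn = attention_flops_per_token(n_layers=2, seq_len=1024, dim=512, causal=True)
    assert attn == 3.5 * (4 * 2 * 1024 * 512 // 2)  # 7_340_032.0
    # Dense term: 6 FLOPs per non-embedding parameter, plus the attention term.
    total = get_num_flop_per_token(
        num_non_embed_params=1_000_000, n_layers=2, dim=512, seq_len=1024
    )
    assert total == 6 * 1_000_000 + attn
    return total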
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
def setattrs(_self, **kwargs):
for k, v in kwargs.items():
setattr(_self, k, v)
def get_encoder_dim_token_emb(args):
if args.dim_token is not None:
dim_token_emb = args.dim_token
elif args.use_local_encoder_transformer:
dim_token_emb = args.dim_local_encoder
else:
dim_token_emb = args.dim_global // args.patch_size
return dim_token_emb
def get_encoder_dim_patch_emb(args):
dim_patch_emb = None
if args.cross_attn_encoder:
if args.cross_attn_init_by_pooling:
dim_patch_emb = args.dim_local_encoder
else:
dim_patch_emb = args.dim_global
return dim_patch_emb
def get_global_dim_patch_emb(args):
dim_token_emb = get_encoder_dim_token_emb(args)
if args.cross_attn_encoder:
dim_patch_emb = dim_token_emb * args.cross_attn_k
elif (
args.downsampling_by_pooling is None
or not args.downsampling_by_pooling
or len(args.downsampling_by_pooling) == 0
):
dim_patch_emb = dim_token_emb * args.patch_size
else:
dim_patch_emb = dim_token_emb * sum(
[
pooling in args.downsampling_by_pooling
for pooling in ["avg", "min", "max"]
]
)
return dim_patch_emb
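# Illustrative sketch (hypothetical helper, not part of the original module): when the
# cross-attention encoder is enabled, each patch embedding concatenates cross_attn_k
# token-dimension vectors. The SimpleNamespace below stands in for the real args object.
def _example_global_dim_patch_emb():
    from types import SimpleNamespace

    args = SimpleNamespace(
        dim_token=None,
        use_local_encoder_transformer=True,
        dim_local_encoder=512,
        dim_global=1024,
        patch_size=4,
        cross_attn_encoder=True,
        cross_attn_k=2,
        downsampling_by_pooling=None,
    )
    assert get_global_dim_patch_emb(args) == 512 * 2
    return get_global_dim_patch_emb(args)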
def get_decoder_dim_token_emb(args):
if args.share_encoder_decoder_emb:
dim_token_emb = get_encoder_dim_token_emb(args)
elif args.dim_token is not None:
dim_token_emb = args.dim_token
else:
dim_token_emb = args.dim_local_decoder
return dim_token_emb
def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int] | None:
if ngram_to_size_str is None:
return None
ngram_to_size = {}
for entry in ngram_to_size_str.split(","):
ngram, size = entry.split(":")
ngram = int(ngram)
size = int(size)
ngram_to_size[ngram] = size
return ngram_to_size
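# Illustrative sketch (hypothetical helper, not part of the original module):
# parse_ngram_to_size maps a comma-separated "ngram:size" spec to a dict.
def _example_parse_ngram_to_size():
    assert parse_ngram_to_size("2:1000,3:2000") == {2: 1000, 3: 2000}
    assert parse_ngram_to_size(None) is None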
def fill_tokens(tokens, patch_size, fill_id):
batch_size, seq_len = tokens.shape
if seq_len % patch_size == 0:
return tokens
else:
remaining = patch_size - seq_len % patch_size
final_padding = tokens.new(batch_size, remaining).fill_(fill_id)
return torch.cat((tokens, final_padding), dim=1)
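# Illustrative sketch (hypothetical helper, not part of the original module): fill_tokens
# pads the sequence dimension up to the next multiple of patch_size with fill_id.
def _example_fill_tokens():
    tokens = torch.tensor([[3, 4, 5, 6, 7]])  # seq_len=5
    padded = fill_tokens(tokens, patch_size=4, fill_id=0)
    assert padded.tolist() == [[3, 4, 5, 6, 7, 0, 0, 0]]  # padded to seq_len=8
    return padded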
def decoder_patch_ids_from_lengths(patch_lengths, nb_boe, seq_len):
first_patch_length = patch_lengths[0, 0]
assert torch.all(
first_patch_length == patch_lengths[:, 0]
), "first patch should always be the same size (1 for dynamic, patch_size for static)."
assert (
first_patch_length - nb_boe == 1
), f"First patch (patch length: {first_patch_length}) should have one non-boe token (boe toks: {nb_boe})"
# Remove first patch from patch_ids for local decoder inputs and shift the last patch.
# decoder_patch_lengths = patch_lengths[:, 1:].clone()
# decoder_patch_lengths = add_to_last_nonzero_patch(decoder_patch_lengths, 1)
decoder_patch_lengths = patch_lengths[:, 1:]
assert (
decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0]
== patch_lengths.sum()
), f"{decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0]} != {patch_lengths.sum()}"
assert torch.all(decoder_patch_lengths >= 0), f"{decoder_patch_lengths}"
decoder_patch_ids = patch_ids_from_lengths(
patch_lengths=decoder_patch_lengths, seq_len=seq_len
)
return decoder_patch_ids
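# Illustrative sketch (hypothetical helper, not part of the original module): with no boe
# tokens the first patch must hold exactly one byte; the decoder drops that patch and maps
# each remaining byte position to the patch it belongs to.
def _example_decoder_patch_ids_from_lengths():
    patch_lengths = torch.tensor([[1, 3, 2]])  # 5 decoder bytes plus the next-token slot
    ids = decoder_patch_ids_from_lengths(patch_lengths, nb_boe=0, seq_len=5)
    assert ids.tolist() == [[0, 0, 0, 1, 1]]
    return ids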
primes = [
1000000007,
5915587277,
1500450271,
3267000013,
5754853343,
4093082899,
9576890767,
3628273133,
2860486313,
5463458053,
3367900313,
]
def rolling_polynomial_hash(t, hash_func_nb: int = 0):
prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device)
prime_powers = torch.stack([prime**i for i in range(t.shape[-1])])
return torch.sum(t * prime_powers, dim=-1)
def get_rolling_polynomial_hash_fn(hash_func_nb: int = 0, group_size: int = 2):
prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64)
prime_powers = torch.stack([prime**i for i in range(group_size)])
def rolling_polynomial_hash_fn(t):
return torch.sum(t * prime_powers, dim=-1)
return rolling_polynomial_hash_fn
def byte_group_hash_function(
x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000
):
"""
Returns a hash of the input x and maps it to a value in the range [0, max_hash].
expects: x of shape (batch_size, seq_len) with values as ids in the token vocab.
returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash].
Note: max hash can make a big difference on the number of collisions.
"""
with torch.no_grad():
bs, seq_len = x.shape
# x_numpy = x.numpy()
# hash_values = torch.zeros(bs, seq_len, dtype=torch.int64, requires_grad=False)
# for i in range(bs):
# for j in range(seq_len):
# start = max(j, j-group_size+1)
# end = j+1
# hash_values[i, j] = hash_array(x_numpy[i, start:end], max_hash)
prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device)
x = torch.cat([prefix, x], dim=1)
windows = x.unfold(1, group_size, 1)
# hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows)
hashes = rolling_polynomial_hash(windows, hash_func_nb)
hash_values_range = hashes % max_hash
hash_values_range.requires_grad = False
return hash_values_range
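# Illustrative sketch (hypothetical helper, not part of the original module): each position
# is hashed together with the (group_size - 1) bytes before it (zero-padded at the start),
# then bucketed modulo max_hash.
def _example_byte_group_hash():
    x = torch.tensor([[3, 5, 7]])
    hashes = byte_group_hash_function(x, group_size=2, hash_func_nb=0, max_hash=30000)
    # Windows are [0,3], [3,5], [5,7]; with prime p = 1000000007 the raw hashes are
    # 3*p, 3 + 5*p, and 5 + 7*p, which reduce mod 30000 to 21, 20038, and 10054.
    assert hashes.tolist() == [[21, 20038, 10054]]
    return hashes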
def create_patch_mask_from_ids(
patch_ids, num_patches, window=None, patches_as_queries=False
):
"""
Creates a tensor of shape [bs, seq_len, num_patches] where each element at position (i, j, k)
is True if the patch id at position (i, j) is less than or equal to k.
Args:
patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids.
num_patches (int): Total number of patches.
window (int): If not None, only considers patches within a window of size window.
patches_as_queries (bool): If True, the patches are used as queries
Returns:
torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask.
"""
bs, seq_len = patch_ids.shape
if not patches_as_queries:
q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches)
kv_ids = (
torch.arange(num_patches, device=patch_ids.device)
.unsqueeze(0)
.unsqueeze(0)
.expand(bs, seq_len, num_patches)
)
else:
kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len)
q_ids = (
torch.arange(num_patches, device=patch_ids.device)
.unsqueeze(0)
.unsqueeze(-1)
.expand(bs, num_patches, seq_len)
)
if window is None:
mask = q_ids == kv_ids
else:
mask = (kv_ids <= q_ids) & (q_ids < kv_ids + window)
return mask
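# Illustrative sketch (hypothetical helper, not part of the original module): with bytes as
# queries, each byte position is paired only with its own patch (or a window of patches
# when window is set).
def _example_create_patch_mask_from_ids():
    patch_ids = torch.tensor([[0, 0, 1]])  # first two bytes in patch 0, last byte in patch 1
    mask = create_patch_mask_from_ids(patch_ids, num_patches=2, patches_as_queries=False)
    assert mask.tolist() == [[[True, False], [True, False], [False, True]]]
    return mask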
def cross_attn_mask(
patch_ids,
patch_lengths,
N,
patches_as_queries=False,
cross_attn_k=1,
window=None,
block_mask=True,
):
bs = patch_ids.shape[0]
with torch.no_grad():
# Create the patch mask
cross_mask = create_patch_mask_from_ids(
patch_ids,
patch_lengths.shape[1],
window=window,
patches_as_queries=patches_as_queries,
).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1)
q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N
kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k
assert cross_mask.shape == (
bs,
q_len,
kv_len,
), f"{cross_mask.shape} != {(bs, q_len, kv_len)}"
if block_mask:
def patch_mask(b, h, q_idx, kv_idx):
return cross_mask[b, q_idx, kv_idx]
block_mask = create_block_mask(
patch_mask,
B=bs,
H=None,
Q_LEN=q_len,
KV_LEN=kv_len,
_compile=True,
)
return block_mask
else:
return torch.where(
cross_mask, torch.tensor(0.0), torch.tensor(float("-inf"))
).unsqueeze(
1
) # [bs, 1, q_len, kv_len]
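# Illustrative sketch (hypothetical helper, not part of the original module): with
# block_mask=False the function returns an additive float mask (0 where attention is
# allowed, -inf elsewhere) of shape [bs, 1, q_len, kv_len], here with bytes as queries and
# patches as keys/values.
def _example_cross_attn_mask_dense():
    patch_ids = torch.tensor([[0, 0, 1]])
    patch_lengths = torch.tensor([[2, 1]])
    mask = cross_attn_mask(
        patch_ids,
        patch_lengths,
        N=3,
        patches_as_queries=False,
        cross_attn_k=1,
        window=None,
        block_mask=False,
    )
    assert mask.shape == (1, 1, 3, 2)
    assert torch.isinf(mask[0, 0, 0, 1])  # byte 0 cannot attend to patch 1
    return mask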
def get_blt_input(
tokens: torch.Tensor,
enforce_patch_size_multiple: bool,
nb_boe: torch.Tensor,
patch_size: int,
boe_id: int,
):
"""
This function returns X_et, X_gt and X_dt, the encoder, global, and decoder
tokens respectively.
Consider the input and target sequences:
X=[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13]
Y=[4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13,14]
with patch_size=4
Note 1: that there will be no special tokens introduced at the patch level.
Note 2: X_e needs to be trimmed to be passed to Global
Current without boe:
X_et = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]]
X_g = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] # remove last glob patch
X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]]
Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]]
--> lag fix:
X_et = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11] [12,13,pad,pad]]
X_g = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11]]
X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]]
Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]]
Dynamic (current):
X = [3,4,5,6,7,eos,bos,8,9,10,eos,bos]
Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11]
entropy patching:
input: 7, bos, 9, 10
pred (high entropy): eos, 8, 10, eos
X_et = [[boe,3,4,5,6,7,eos,bos,8,9,10,eos,bos]
X_g = [[boe], [3,4,5,6], [7,eos],[bos,8],[9], [10,eos]]
X_dt = [[3,4,5,6], [7,eos], [bos,8],[9], [10,eos],[bos]]
Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11]
--> lag fix no boe (force single byte first patch):
X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
X_g = [[3], [4,5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch
X_dt = [[3,4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]]
Y = [4,5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13]
input: 4, 7, bos, 9, 10
pred (high entropy): 5, eos, 8, 10, eos
X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch
X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]]
Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13]
Handle the last byte properly.
patch_lengths = [1, 1, 3, 2, 2, 1, 2, 2, 1]
X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # do not remove last global patch
X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11] [12]]
Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12, 13]]
bpe delim
X_et = [[3,4,5,6,7,<d>,eos,bos,<d>,8,9,<d>,10,<d>,eos,bos,11,12]
X_g = [[3], [4,5,6,7,<d>], [eos,bos,<d>], ..
X_dt = [[3,4,5,6,7], [<d>,eos,bos], [<d>,bos,8], ..
Y = [4,5,6,7,<d>, eos,bos,<d> 8,9,<d>, ..
Note 1: that there will be no special tokens introduced at the patch level.
Note 2: X_e needs to be trimmed to be passed to Global
"""
batch_size, seq_len = tokens.shape
local_encoder_tokens = tokens
local_decoder_tokens = tokens
if nb_boe > 0:
padded_patch = tokens.new(batch_size, nb_boe).fill_(boe_id)
local_encoder_tokens = torch.cat((padded_patch, local_encoder_tokens), dim=1)
# global_tokens = tokens.new(batch_size, ((seq_len-1) // patch_size)+1).fill_(boe_id)
# create global tokens, contains boe tokens and eos
# padded_local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id)
# patches = padded_local_encoder_tokens.view(batch_size, -1, patch_size)
# global_tokens = (patches.eq(eos_id).any(dim=2).int() * eos_id)[:, 1:]
# global_tokens += global_tokens.eq(0).int() * boe_id
# TODO: fix this when we want to use block causal in the global.
if enforce_patch_size_multiple and local_encoder_tokens.shape[-1] % patch_size != 0:
local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id)
return local_encoder_tokens, None, local_decoder_tokens
def patch_ids_from_lengths(patch_lengths, seq_len):
bs, num_patches = patch_lengths.shape
# Create a tensor of cumulative sums of the patch lengths
cum_d = torch.cat(
[
torch.zeros(bs, 1, dtype=patch_lengths.dtype, device=patch_lengths.device),
patch_lengths.cumsum(dim=-1),
],
dim=-1,
)
patch_ids = (cum_d.unsqueeze(-1) <= torch.arange(seq_len, device=cum_d.device)).sum(
dim=-2
) - 1
assert not (
torch.max(patch_ids) > patch_lengths.shape[-1] or torch.min(patch_ids) < 0
), f"{torch.max(patch_ids)} > {patch_lengths.shape[-1]} or {torch.min(patch_ids)} < 0"
return patch_ids
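# Illustrative sketch (hypothetical helper, not part of the original module): cumulative
# patch lengths are compared against position indices so every byte position receives the
# id of the patch it falls in.
def _example_patch_ids_from_lengths():
    patch_lengths = torch.tensor([[1, 3, 2]])  # three patches covering 6 byte positions
    ids = patch_ids_from_lengths(patch_lengths, seq_len=6)
    assert ids.tolist() == [[0, 1, 1, 1, 2, 2]]
    return ids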
class ByteLatentTransformerArgs(BaseTransformerArgs):
# Basic model configuration
seed: int = 42
vocab_size: int = -1
dim: int = 512
n_layers: int = 8
n_heads: int = 8
# TODO: What is the purpose of this parameter?
weight_tying: bool = False
# Architecture and dimensions
dim_token: int = 256
dim_global: int = 512
dim_local_decoder: int = 512
dim_local_encoder: int = 512
n_layers_global: int = 8
n_layers_local_decoder: int = 8
n_layers_local_encoder: int = 8
# Tokenization and patching
tokenization_mode: str = "bpe"
patch_size: float | None = None
patching_mode: str | None = None
patching_threshold: float | None = None
patching_threshold_add: float | None = None
monotonicity: bool = False
patching_batch_size: int = 1
patching_device: str = "cuda"
data_loader_patching: bool = False
max_patch_length: int | None = None
# Encoder/Decoder configuration
tie_local_encoder_decoder_logits: bool = False
use_local_encoder_transformer: bool = False
encoder_lm_loss: bool = False
max_encoder_seq_length: int | None = None
pad_to_max_length: bool = False
encoder_enable_byte_ngrams: bool = False
encoder_enable_byte_group_hash: bool = False
ngram_vocab_sizes: int | None = None
# Cross attention configurations
cross_attn_encoder: bool = False
cross_attn_decoder: bool = False
cross_attn_window_encoder: int | None = None
cross_attn_window_decoder: int | None = None
cross_attn_k: int | None = None
cross_attn_nheads: int | None = None
cross_attn_all_layers_decoder: bool = False
cross_attn_all_layers_encoder: bool = False
cross_attn_use_flex_attention: bool = True
cross_attn_init_by_pooling: bool = False
# Encoder hash configurations
encoder_hash_byte_group_size: Any | None = None
encoder_hash_byte_group_vocab: int = 30000
encoder_hash_byte_group_nb_functions: int = 3
# Model behavior and optimization
log_patch_lengths: bool = False
non_linearity: str = "swiglu"
use_rope: bool = True
recompute_fc1_out: bool = False
recompute_fc3_out: bool = False
recompute_attn: bool = True
custom_bwd: bool = False
layer_ckpt: str = "all"
# Initialization and attention
init_use_gaussian: bool = True
init_use_depth: str = "current"
attn_bias_type: str = "causal"
alpha_depth: str = "disabled"
max_length: int = 2048
# Norm configuration
norm_eps: float = 1e-5
norm_affine: bool = True
pre_norm: bool = True
norm_type: str = "rmsnorm"
# Additional configurations
multiple_of: int = 256
ffn_dim_multiplier: float = 1.0
dropout: float = 0
output_size: int = -1
# Additional parameters from ModelArgs
architecture: str = "vanilla"
share_encoder_decoder_emb: bool = True
global_local_decoder_residual_layer: str | None = None
tokenize_with_bpe_delimiter: bool = False
patching_thresholds_str: str | None = None
tie_local_encoder_decoder: bool = False
encoder_preds_low_entropy_toks: float | None = None
encoder_preds_random_toks: float | None = None
dim_token_emb: int | None = None
dim_patch_emb: int | None = None
encoder_ngram_table_dir: str | None = None
encoder_ngram_to_size_str: str | None = None
# Model architecture params
entropy_model_checkpoint_dir: str | None = None
entropy_model_is_ngram_model: bool = False
downsampling_by_pooling: str | None = None
n_heads_global: int = 8
n_heads_local_decoder: int = 8
n_heads_local_encoder: int = 8
n_kv_heads: int | None = None
n_kv_heads_global: int | None = None
conv_kernel_size: int | None = None
local_attention_window_len: int | None = None
# Performance optimization
sequence_parallel: bool = False
loss_parallel: bool = False
fuse_sequence_parallel: bool = False
use_fsdp: bool = True
attn_to_keep: str = "all"
# RoPE parameters
rope_theta: float = 10000.0
rope_use_fp32_in_outer_product: bool = False
# Parameter mixing
pm_size: int = 0
# Logging
full_logging_n_layers: int = 4
@model_validator(mode="after")
def check_hash_byte_sizes(self) -> Self:
if (
self.encoder_hash_byte_group_size is not None
and isinstance(self.encoder_hash_byte_group_size, str)
):
self.encoder_hash_byte_group_size = [
int(x)
for x in self.encoder_hash_byte_group_size.split(",")
if len(x) > 0
]
return self
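# Illustrative sketch (hypothetical helper, not part of the original module): the validator
# above normalizes a comma-separated string such as "3,4," into [3, 4]; the same expression
# is repeated standalone here so it can be checked without building a full args object.
def _example_hash_byte_group_size_parsing():
    raw = "3,4,"
    parsed = [int(x) for x in raw.split(",") if len(x) > 0]
    assert parsed == [3, 4]
    return parsed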
class GlobalTransformerArgs(ByteLatentTransformerArgs):
# Global encoder specific dimensions
dim_token_emb: int | None = None
dim_patch_emb: int | None = None
def __post_init__(self):
# Override base args with global encoder specific values
self.dim = self.dim_global
self.n_layers = self.n_layers_global
self.n_heads = self.n_heads_global
self.n_kv_heads = self.n_kv_heads_global
self.local_attention_window_len = None
self.cross_attn_encoder = False
self.cross_attn_decoder = False
class LocalDecoderArgs(ByteLatentTransformerArgs):
# Local decoder specific dimensions
dim_token_emb: int | None = None
dim_patch_emb: int | None = None
def __post_init__(self):
# Override base args with local decoder specific values
self.dim = self.dim_local_decoder
self.n_layers = self.n_layers_local_decoder
self.n_heads = self.n_heads_local_decoder
self.cross_attn_encoder = False
self.cross_attn_init_by_pooling = False
self.attn_bias_type = "local_block_causal"
def create_global_transformer(args: ByteLatentTransformerArgs) -> GlobalTransformer:
global_args = args.model_copy(
deep=True,
update=dict(
dim=args.dim_global,
n_layers=args.n_layers_global,
n_heads=args.n_heads_global,
n_kv_heads=args.n_kv_heads_global,
local_attention_window_len=None,
dim_token_emb=get_global_dim_patch_emb(args),
dim_patch_emb=None,
cross_attn_encoder=False,
cross_attn_decoder=False,
),
)
return GlobalTransformer(global_args)
def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
local_encoder_args = LocalModelArgs(
# Updated args
dim=args.dim_local_encoder,
n_layers=args.n_layers_local_encoder,
n_heads=args.n_heads_local_encoder,
dim_token_emb=get_encoder_dim_token_emb(args),
dim_patch_emb=get_encoder_dim_patch_emb(args),
cross_attn_encoder=args.cross_attn_encoder,
cross_attn_decoder=False,
cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None,
cross_attn_init_by_pooling=args.cross_attn_init_by_pooling,
# Defaults
head_dim=args.head_dim,
max_seqlen=args.max_encoder_seq_length,
dropout=args.dropout,
vocab_size=args.vocab_size + args.pm_size,
norm_eps=args.norm_eps,
patch_size=args.patch_size,
sliding_window=args.local_attention_window_len,
use_rope=args.use_rope,
rope_theta=args.rope_theta,
init_base_std=args.init_base_std,
init_std_factor=args.init_std_factor,
n_kv_heads=args.n_kv_heads,
attn_impl=args.attn_impl,
attn_bias_type="local_block_causal",
multiple_of=args.multiple_of,
ffn_dim_multiplier=args.ffn_dim_multiplier,
patching_mode=args.patching_mode,
use_local_encoder_transformer=args.use_local_encoder_transformer,
downsampling_by_pooling=args.downsampling_by_pooling,
encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
cross_attn_nheads=args.cross_attn_nheads,
eos_id=args.eos_id,
)
return LocalEncoder(local_encoder_args)
def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
# First deep copy the original args
local_decoder_args = LocalModelArgs(
dim=args.dim_local_decoder,
n_layers=args.n_layers_local_decoder,
n_heads=args.n_heads_local_decoder,
dim_token_emb=get_decoder_dim_token_emb(args),
dim_patch_emb=args.dim_global,
cross_attn_encoder=False,
cross_attn_decoder=args.cross_attn_decoder,
cross_attn_init_by_pooling=False, # states are already defined
cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None,
# Defaults
head_dim=args.head_dim,
max_seqlen=args.max_encoder_seq_length,
dropout=args.dropout,
vocab_size=args.vocab_size + args.pm_size,
norm_eps=args.norm_eps,
patch_size=args.patch_size,
sliding_window=args.local_attention_window_len,
use_rope=args.use_rope,
rope_theta=args.rope_theta,
init_base_std=args.init_base_std,
init_std_factor=args.init_std_factor,
n_kv_heads=args.n_kv_heads,
attn_impl=args.attn_impl,
attn_bias_type="local_block_causal",
multiple_of=args.multiple_of,
ffn_dim_multiplier=args.ffn_dim_multiplier,
patching_mode=args.patching_mode,
use_local_encoder_transformer=args.use_local_encoder_transformer,
downsampling_by_pooling=args.downsampling_by_pooling,
encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
cross_attn_nheads=args.cross_attn_nheads,
eos_id=args.eos_id,
)
return LocalDecoder(local_decoder_args)
class EmbeddingType(Enum):
HASH_TOK = auto()
NGRAM = auto()
def init_embeddings(
args,
embedding_type: EmbeddingType,
local_encoder_dim: int,
encoder_hash_byte_group_size: list | None = None,
):
if (
embedding_type == EmbeddingType.HASH_TOK
and args.encoder_hash_byte_group_size is None
):
return None
if embedding_type == EmbeddingType.NGRAM and args.encoder_ngram_to_size_str is None:
return None
embeddings = []
if embedding_type == EmbeddingType.HASH_TOK:
emb_dim = local_encoder_dim
encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab
for _ in range(args.encoder_hash_byte_group_nb_functions):
for _ in encoder_hash_byte_group_size:
embeddings.append(
nn.Embedding(
encoder_hash_byte_group_vocab,
emb_dim,
)
)
elif embedding_type == EmbeddingType.NGRAM:
encoder_ngram_to_size = parse_ngram_to_size(args.encoder_ngram_to_size_str)
emb_dim = local_encoder_dim
OFFSET = 4  # NOTE: shadows the OFFSET imported from bytelatent.tokenizers.constants; should be passed as a parameter if it's variable.
for ngram_vocab_size in encoder_ngram_to_size.values():
embeddings.append(nn.Embedding(ngram_vocab_size + OFFSET, emb_dim))
return nn.ModuleList(embeddings)
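# Illustrative sketch (hypothetical helper, not part of the original module): for hash-token
# embeddings the ModuleList holds one table per (hash function, byte group size) pair, each
# of size encoder_hash_byte_group_vocab x local_encoder_dim.
def _example_init_hash_embeddings_count(args, local_encoder_dim: int):
    tables = init_embeddings(
        args,
        EmbeddingType.HASH_TOK,
        local_encoder_dim=local_encoder_dim,
        encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
    )
    if tables is not None:
        expected = args.encoder_hash_byte_group_nb_functions * len(
            args.encoder_hash_byte_group_size
        )
        assert len(tables) == expected
    return tables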
def compute_hash_embeddings(
local_encoder_tokens: torch.Tensor,
local_encoder,
encoder_hash_tok_embedding: nn.ModuleList,
encoder_hash_byte_group_nb_functions: int,
encoder_hash_byte_group_size: list,
encoder_hash_byte_group_vocab: int,
) -> torch.Tensor:
"""
Compute embeddings using hash token embeddings.
Args:
local_encoder_tokens: Input tokens tensor
local_encoder: Encoder object with tok_embeddings method
encoder_hash_tok_embedding: ModuleList of hash token embeddings
encoder_hash_byte_group_nb_functions: Number of hash functions
encoder_hash_byte_group_size: List of byte group sizes
encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings
Returns:
torch.Tensor: Combined embeddings
"""
if encoder_hash_tok_embedding is None:
return None
local_encoder_embeds = local_encoder.tok_embeddings(local_encoder_tokens)
i = 0
for func_nb in range(encoder_hash_byte_group_nb_functions):
for byte_group_size in encoder_hash_byte_group_size:
hash_ids = byte_group_hash_function(
local_encoder_tokens,
byte_group_size,
hash_func_nb=func_nb,
max_hash=encoder_hash_byte_group_vocab,
)
hash_tok_embedding = encoder_hash_tok_embedding[i]
local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids)
i += 1
assert i == len(encoder_hash_tok_embedding)
return local_encoder_embeds
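# Illustrative sketch (hypothetical helper, not part of the original module):
# compute_hash_embeddings adds one hash-bucket embedding per (hash function, byte group
# size) pair on top of the plain token embedding. The tiny stand-in encoder and the sizes
# below exist only for this example.
def _example_compute_hash_embeddings():
    class _TinyEncoder(nn.Module):
        def __init__(self, vocab_size: int, dim: int):
            super().__init__()
            self.tok_embeddings = nn.Embedding(vocab_size, dim)

    dim, hash_vocab, group_sizes, nb_functions = 8, 300, [2, 3], 2
    encoder = _TinyEncoder(vocab_size=260, dim=dim)
    tables = nn.ModuleList(
        [nn.Embedding(hash_vocab, dim) for _ in range(nb_functions * len(group_sizes))]
    )
    tokens = torch.randint(0, 260, (1, 5))
    embeds = compute_hash_embeddings(
        local_encoder_tokens=tokens,
        local_encoder=encoder,
        encoder_hash_tok_embedding=tables,
        encoder_hash_byte_group_nb_functions=nb_functions,
        encoder_hash_byte_group_size=group_sizes,
        encoder_hash_byte_group_vocab=hash_vocab,
    )
    assert embeds.shape == (1, 5, dim)
    return embeds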
class ByteLatentTransformer(nn.Module):
"""
The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences
by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,
and local decoders to efficiently encode and decode byte sequences, leveraging patch-based processing for
improved performance and inference efficiency.
"""
def __init__(self, args: ByteLatentTransformerArgs):
super().__init__()
# General configuration
self.weight_tying = args.weight_tying
self.patch_size = args.patch_size
self.patching_mode = args.patching_mode
self.boe_id, self.bos_id, self.pad_id, self.eos_id = (
BOE_ID,
BOS_ID,
PAD_ID,
EOS_ID,
)
self.downsampling_by_pooling = args.downsampling_by_pooling
self.patching_threshold = args.patching_threshold
self.dim = args.dim
self.init_base_std = args.init_base_std
self.init_std_factor = InitStdFactor(args.init_std_factor)
self.max_seqlen = args.max_seqlen
# Cross attention configuration
self.cross_attn_encoder = args.cross_attn_encoder
self.cross_attn_decoder = args.cross_attn_decoder
self.cross_attn_k = args.cross_attn_k
self.cross_attn_window_encoder = args.cross_attn_window_encoder
self.cross_attn_window_decoder = args.cross_attn_window_decoder
self.cross_attn_use_flex_attention = args.cross_attn_use_flex_attention
# Encoder hash configuration
self.encoder_hash_byte_group_size = args.encoder_hash_byte_group_size
self.encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab
self.encoder_hash_byte_group_nb_functions = (
args.encoder_hash_byte_group_nb_functions
)
# ByteLatent modules
self.local_encoder = create_local_encoder(args)
self.global_transformer = create_global_transformer(args)
self.local_decoder = create_local_decoder(args)
self.encoder_hash_tok_embedding = init_embeddings(
args,
EmbeddingType.HASH_TOK,
local_encoder_dim=self.local_encoder.dim,
encoder_hash_byte_group_size=self.encoder_hash_byte_group_size,
)
self.encoder_ngram_embedding = init_embeddings(
args,
EmbeddingType.NGRAM,
local_encoder_dim=self.local_encoder.dim,
encoder_hash_byte_group_size=None,
)
self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)
# Transformer layers
self.layers = nn.ModuleList(
[TransformerBlock(args) for _ in range(args.n_layers)]
)
# Encoder ngram embedding tables
self.encoder_ngram_embedding = None
if args.encoder_enable_byte_ngrams:
self.encoder_ngram_embedding = nn.ModuleList()
assert args.ngram_vocab_sizes is not None
self.encoder_ngram_to_size = parse_ngram_to_size(
args.encoder_ngram_to_size_str
)
ngram_emb_dim = self.local_encoder.dim
for ngram_vocab_size in self.encoder_ngram_to_size.values():
self.encoder_ngram_embedding.append(
nn.Embedding(ngram_vocab_size + OFFSET, ngram_emb_dim)
)
# Output layer
assert args.vocab_size > 0, "vocab_size must be greater than 0"
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
if args.weight_tying:
self.output.weight = self.tok_embeddings.weight
# Patcher module
if not args.data_loader_patching:
self.patcher = Patcher(
PatcherArgs(
patch_size=args.patch_size,
patching_mode=args.patching_mode,
patching_threshold=args.patching_threshold,
patching_threshold_add=args.patching_threshold_add,
monotonicity=args.monotonicity,
max_patch_length=args.max_patch_length,
)
)
def forward(
self,
tokens: torch.Tensor,
patch_lengths: Optional[torch.Tensor] = None,
ngram_ids: Optional[torch.Tensor] = None,
):
# Ensure ngram_ids is either a tensor or None
assert (
isinstance(ngram_ids, torch.Tensor) or ngram_ids is None
), f"ngram_ids must be a tensor or None, but was: {type(ngram_ids)}"
bs, N = tokens.shape # Batch size and sequence length
# Get megabyte inputs
nb_boe = int(0 if self.patching_mode != "" else self.patch_size - 1)
local_encoder_tokens, _, local_decoder_tokens = get_blt_input(
tokens=tokens,
enforce_patch_size_multiple=False,
nb_boe=nb_boe,
patch_size=self.patch_size,
boe_id=self.boe_id,
)
# Patching
if patch_lengths is None:
assert (
getattr(self, "patcher", None) is not None
), "Patcher not defined and no patch_lengths passed."
patch_lengths, tok_scores = self.patcher.patch(
local_encoder_tokens,
include_next_token=True,
threshold=self.patcher.threshold,
)
else:
if nb_boe > 0:
patch_lengths[:, 0] += nb_boe
assert torch.min(patch_lengths) >= 0
# Generate patch IDs from patch_lengths
patch_ids = patch_ids_from_lengths(
patch_lengths, local_encoder_tokens.shape[-1]
)
assert torch.max(patch_ids) + 1 <= torch.max(
(patch_lengths != 0).sum(dim=-1)
), f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}"
cross_attn_mask_enc = None
# Cross-attention encoder
if self.cross_attn_encoder:
cross_attn_mask_enc = cross_attn_mask(
patch_ids,
patch_lengths,
N,
patches_as_queries=True,
cross_attn_k=self.cross_attn_k,
window=self.cross_attn_window_encoder,
block_mask=self.cross_attn_use_flex_attention,
)
# Hashing and embedding
local_encoder_embeds = compute_hash_embeddings(
local_encoder_tokens=local_encoder_tokens,
local_encoder=self.local_encoder,
encoder_hash_tok_embedding=self.encoder_hash_tok_embedding,
encoder_hash_byte_group_nb_functions=self.encoder_hash_byte_group_nb_functions,
encoder_hash_byte_group_size=self.encoder_hash_byte_group_size,
encoder_hash_byte_group_vocab=self.encoder_hash_byte_group_vocab,
)
# N-gram table embeddings
if self.encoder_ngram_embedding is not None:
assert ngram_ids is not None, "ngram_ids must be provided"
if local_encoder_embeds is None:
local_encoder_embeds = self.local_encoder.tok_embeddings(
local_encoder_tokens
)
assert len(ngram_ids) == len(
self.encoder_ngram_embedding
), f"ngram_ids.shape[0]={ngram_ids.shape[0]} versus len(encoder_ngram_embedding)={len(self.encoder_ngram_embedding)}, ngram_ids.shape={ngram_ids.shape}"
for i in range(ngram_ids.shape[0]):
ngram_embedding = self.encoder_ngram_embedding[i]
ngram_embeds = ngram_embedding(ngram_ids[i])
assert (
local_encoder_embeds.shape == ngram_embeds.shape
), f"Shape mismatch: {local_encoder_embeds.shape} vs {ngram_embeds.shape}, ngram_ids.shape={ngram_ids.shape}"
local_encoder_embeds = local_encoder_embeds + ngram_embeds
# Local encoder
h_cross = None
(h_encoder, h_cross), cache_encoder = self.local_encoder(
tokens=local_encoder_tokens,
embeds=local_encoder_embeds,
patch_embeds=h_cross if self.cross_attn_encoder else None,
cross_mask=cross_attn_mask_enc,
num_patches=patch_lengths.shape[1],
patch_ids=patch_ids,
)
# Downsampling
if not self.cross_attn_encoder:
assert (
patch_ids.shape[1] == h_encoder.shape[1]
), f"{patch_ids.shape[1]} != {h_encoder.shape[1]}"
h = downsample(
h_encoder,
patch_lengths.shape[1],
patch_lengths,
patch_ids,
downsampling_by_pooling=self.downsampling_by_pooling,
patch_size=self.patch_size,
)
else:
# Reshape h_cross
h = h_cross.view(bs, patch_lengths.shape[1], -1)
# Global transformer
global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(self.boe_id)
rows, cols = torch.where(local_encoder_tokens == self.eos_id)
eos_patch_ids = patch_ids[rows, cols]
global_tokens[rows, eos_patch_ids] = self.eos_id
h, _ = self.global_transformer(
embeds=h,
tokens=global_tokens,
)
# Unpatching
dec_embeds = h_encoder[:, nb_boe : nb_boe + N, :]
# Generate decoder patch IDs
decoder_patch_ids = decoder_patch_ids_from_lengths(
patch_lengths, nb_boe, local_decoder_tokens.shape[-1]
)
assert (
torch.max(decoder_patch_ids) + 1 <= h.shape[1]
), f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}"
assert (
decoder_patch_ids.shape[1] == dec_embeds.shape[1]
), f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}"
# Cross-attention decoder
if not self.cross_attn_decoder:
h = torch.gather(
h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])
)
cross_attn_mask_dec = None
assert local_decoder_tokens.shape == h.shape[:-1]
else:
cross_attn_mask_dec = cross_attn_mask(
decoder_patch_ids,
patch_lengths,
N,
patches_as_queries=False,
cross_attn_k=self.cross_attn_k,
window=self.cross_attn_window_decoder,
block_mask=self.cross_attn_use_flex_attention,
)
# Local decoder
output, _ = self.local_decoder(
embeds=dec_embeds,
patch_embeds=h,
tokens=local_decoder_tokens,
cross_mask=cross_attn_mask_dec,
)
return output
def reset_parameters(self, init_std=None):
# Either use fixed base std or sqrt model dim
init_std = init_std or (self.dim ** (-0.5))
nn.init.trunc_normal_(
self.tok_embeddings.weight,
mean=0.0,
std=init_std,
a=-3 * init_std,
b=3 * init_std,
)
if not self.weight_tying:
nn.init.trunc_normal_(
self.output.weight,
mean=0.0,
std=init_std,
a=-3 * init_std,
b=3 * init_std,
)
def init_weights(self):
self.reset_parameters()
self.init_base_std = self.init_base_std or (self.dim ** (-0.5))
for depth, layer in enumerate(self.layers):
factor = {
InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
InitStdFactor.DIM_RATIO: self.dim / 4096,
InitStdFactor.DISABLED: 1.0,
}[self.init_std_factor]
layer.init_weights(self.init_base_std, factor)
self.local_decoder.init_weights(self.init_base_std)
self.global_transformer.init_weights(self.init_base_std)
self.local_encoder.init_weights(self.init_base_std)
for emb in self.encoder_hash_tok_embedding:
nn.init.trunc_normal_(
emb.weight,
mean=0.0,
std=self.init_base_std,
a=-3 * self.init_base_std,
b=3 * self.init_base_std,
)