# Copyright (c) Meta Platforms, Inc. and affiliates.
from enum import Enum, auto
from typing import Any, Optional

import torch
from pydantic import ConfigDict, model_validator
from torch import nn
from torch.nn.attention.flex_attention import create_block_mask
from typing_extensions import Self

from bytelatent.base_transformer import (
    BaseTransformerArgs,
    InitStdFactor,
    TransformerBlock,
)
from bytelatent.data.patcher import Patcher, PatcherArgs
from bytelatent.model.latent_transformer import GlobalTransformer
from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs
from bytelatent.model.utils import downsample
from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
def attention_flops_per_token(n_layers, seq_len, dim, causal):
    # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30
    return 3.5 * (4 * n_layers * seq_len * dim // (2 if causal else 1))


def get_num_flop_per_token(
    num_non_embed_params: int, n_layers: int, dim: int, seq_len: int
) -> float:
    return 6 * num_non_embed_params + attention_flops_per_token(
        n_layers, seq_len, dim, True
    )
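# Toy example (illustrative only, with made-up numbers): for n_layers=8,
# seq_len=1024, dim=512 the causal attention term is
# 3.5 * (4 * 8 * 1024 * 512 // 2) = 29,360,128 FLOPs per token, added to
# 6 FLOPs per non-embedding parameter.
def _demo_get_num_flop_per_token() -> float:
    flops = get_num_flop_per_token(
        num_non_embed_params=25_000_000, n_layers=8, dim=512, seq_len=1024
    )
    assert flops == 6 * 25_000_000 + 29_360_128
    return flops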
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def setattrs(_self, **kwargs):
    for k, v in kwargs.items():
        setattr(_self, k, v)
def get_encoder_dim_token_emb(args):
    if args.dim_token is not None:
        dim_token_emb = args.dim_token
    elif args.use_local_encoder_transformer:
        dim_token_emb = args.dim_local_encoder
    else:
        dim_token_emb = args.dim_global // args.patch_size
    return dim_token_emb


def get_encoder_dim_patch_emb(args):
    dim_patch_emb = None
    if args.cross_attn_encoder:
        if args.cross_attn_init_by_pooling:
            dim_patch_emb = args.dim_local_encoder
        else:
            dim_patch_emb = args.dim_global
    return dim_patch_emb


def get_global_dim_patch_emb(args):
    dim_token_emb = get_encoder_dim_token_emb(args)
    if args.cross_attn_encoder:
        dim_patch_emb = dim_token_emb * args.cross_attn_k
    elif (
        args.downsampling_by_pooling is None
        or not args.downsampling_by_pooling
        or len(args.downsampling_by_pooling) == 0
    ):
        dim_patch_emb = dim_token_emb * args.patch_size
    else:
        dim_patch_emb = dim_token_emb * sum(
            [
                pooling in args.downsampling_by_pooling
                for pooling in ["avg", "min", "max"]
            ]
        )
    return dim_patch_emb


def get_decoder_dim_token_emb(args):
    if args.share_encoder_decoder_emb:
        dim_token_emb = get_encoder_dim_token_emb(args)
    elif args.dim_token is not None:
        dim_token_emb = args.dim_token
    else:
        dim_token_emb = args.dim_local_decoder
    return dim_token_emb
def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int] | None:
    if ngram_to_size_str is None:
        return None
    ngram_to_size = {}
    for entry in ngram_to_size_str.split(","):
        ngram, size = entry.split(":")
        ngram = int(ngram)
        size = int(size)
        ngram_to_size[ngram] = size
    return ngram_to_size
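# Illustrative example (not used by the model): the spec string maps each
# ngram order to its vocabulary size.
def _demo_parse_ngram_to_size() -> dict[int, int]:
    sizes = parse_ngram_to_size("2:38396,3:50000")
    assert sizes == {2: 38396, 3: 50000}
    return sizes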
def fill_tokens(tokens, patch_size, fill_id):
    batch_size, seq_len = tokens.shape
    if seq_len % patch_size == 0:
        return tokens
    else:
        remaining = patch_size - seq_len % patch_size
        final_padding = tokens.new(batch_size, remaining).fill_(fill_id)
        return torch.cat((tokens, final_padding), dim=1)
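# Illustrative example (not used by the model): a (1, 10) batch is padded up to
# the next multiple of patch_size=4, giving shape (1, 12), with the two extra
# positions set to fill_id.
def _demo_fill_tokens() -> torch.Tensor:
    tokens = torch.arange(10).unsqueeze(0)
    padded = fill_tokens(tokens, patch_size=4, fill_id=0)
    assert padded.shape == (1, 12)
    return padded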
def decoder_patch_ids_from_lengths(patch_lengths, nb_boe, seq_len):
    first_patch_length = patch_lengths[0, 0]
    assert torch.all(
        first_patch_length == patch_lengths[:, 0]
    ), "first patch should always be the same size (1 for dynamic, patch_size for static)."
    assert (
        first_patch_length - nb_boe == 1
    ), f"First patch (patch length: {first_patch_length}) should have one non-boe token (boe toks: {nb_boe})"
    # Remove first patch from patch_ids for local decoder inputs and shift the last patch.
    # decoder_patch_lengths = patch_lengths[:, 1:].clone()
    # decoder_patch_lengths = add_to_last_nonzero_patch(decoder_patch_lengths, 1)
    decoder_patch_lengths = patch_lengths[:, 1:]
    assert (
        decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0]
        == patch_lengths.sum()
    ), f"{decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0]} != {patch_lengths.sum()}"
    assert torch.all(decoder_patch_lengths >= 0), f"{decoder_patch_lengths}"
    decoder_patch_ids = patch_ids_from_lengths(
        patch_lengths=decoder_patch_lengths, seq_len=seq_len
    )
    return decoder_patch_ids
primes = [
    1000000007,
    5915587277,
    1500450271,
    3267000013,
    5754853343,
    4093082899,
    9576890767,
    3628273133,
    2860486313,
    5463458053,
    3367900313,
]
def rolling_polynomial_hash(t, hash_func_nb: int = 0):
    prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device)
    prime_powers = torch.stack([prime**i for i in range(t.shape[-1])])
    return torch.sum(t * prime_powers, dim=-1)


def get_rolling_polynomial_hash_fn(hash_func_nb: int = 0, group_size: int = 2):
    prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64)
    prime_powers = torch.stack([prime**i for i in range(group_size)])

    def rolling_polynomial_hash_fn(t):
        return torch.sum(t * prime_powers, dim=-1)

    return rolling_polynomial_hash_fn
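# Illustrative example (not used by the model): with hash function 0 the window
# [3, 5] hashes to 3 * prime**0 + 5 * prime**1 = 3 + 5 * 1000000007 = 5000000038.
def _demo_rolling_polynomial_hash() -> torch.Tensor:
    windows = torch.tensor([[3, 5]], dtype=torch.int64)
    hashes = rolling_polynomial_hash(windows, hash_func_nb=0)
    assert hashes.tolist() == [5000000038]
    return hashes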
def byte_group_hash_function(
    x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000
):
    """
    Returns a hash of the input x, mapped to a value in the range [0, max_hash).

    Expects: x of shape (batch_size, seq_len) with values as ids in the token vocab.
    Returns: a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash).

    Note: max_hash can make a big difference on the number of collisions.
    """
    with torch.no_grad():
        bs, seq_len = x.shape
        # x_numpy = x.numpy()
        # hash_values = torch.zeros(bs, seq_len, dtype=torch.int64, requires_grad=False)
        # for i in range(bs):
        #     for j in range(seq_len):
        #         start = max(j, j - group_size + 1)
        #         end = j + 1
        #         hash_values[i, j] = hash_array(x_numpy[i, start:end], max_hash)
        prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device)
        x = torch.cat([prefix, x], dim=1)
        windows = x.unfold(1, group_size, 1)
        # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows)
        hashes = rolling_polynomial_hash(windows, hash_func_nb)
        hash_values_range = hashes % max_hash
    hash_values_range.requires_grad = False
    return hash_values_range
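# Illustrative example (not used by the model): ids are hashed over overlapping
# windows of group_size bytes (left-padded with zeros) and bucketed into
# [0, max_hash).
def _demo_byte_group_hash_function() -> torch.Tensor:
    x = torch.tensor([[5, 6, 7, 8]])
    hashes = byte_group_hash_function(x, group_size=2, hash_func_nb=0, max_hash=30000)
    assert hashes.shape == (1, 4)
    assert 0 <= int(hashes.min()) and int(hashes.max()) < 30000
    return hashes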
def create_patch_mask_from_ids(
    patch_ids, num_patches, window=None, patches_as_queries=False
):
    """
    Creates a boolean mask of shape [bs, seq_len, num_patches] (transposed when
    patches_as_queries is True) where element (i, j, k) is True if the token at
    position (i, j) belongs to patch k, or, when a window is given, to one of the
    `window` patches starting at k.

    Args:
        patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids.
        num_patches (int): Total number of patches.
        window (int): If not None, only considers patches within a window of size `window`.
        patches_as_queries (bool): If True, the patches are used as queries.
    Returns:
        torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask.
    """
    bs, seq_len = patch_ids.shape
    if not patches_as_queries:
        q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches)
        kv_ids = (
            torch.arange(num_patches, device=patch_ids.device)
            .unsqueeze(0)
            .unsqueeze(0)
            .expand(bs, seq_len, num_patches)
        )
    else:
        kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len)
        q_ids = (
            torch.arange(num_patches, device=patch_ids.device)
            .unsqueeze(0)
            .unsqueeze(-1)
            .expand(bs, num_patches, seq_len)
        )
    if window is None:
        mask = q_ids == kv_ids
    else:
        mask = (kv_ids <= q_ids) & (q_ids < kv_ids + window)
    return mask
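# Illustrative example (not used by the model): with patches_as_queries=False
# and no window, the mask simply marks which patch each byte position belongs to.
def _demo_create_patch_mask_from_ids() -> torch.Tensor:
    patch_ids = torch.tensor([[0, 0, 1, 2]])
    mask = create_patch_mask_from_ids(patch_ids, num_patches=3)
    # mask[0, j, k] is True iff byte j lies in patch k:
    # [[True, False, False],
    #  [True, False, False],
    #  [False, True, False],
    #  [False, False, True]]
    assert mask.shape == (1, 4, 3)
    return mask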
def cross_attn_mask(
    patch_ids,
    patch_lengths,
    N,
    patches_as_queries=False,
    cross_attn_k=1,
    window=None,
    block_mask=True,
):
    bs = patch_ids.shape[0]
    with torch.no_grad():
        # Create the patch mask
        cross_mask = create_patch_mask_from_ids(
            patch_ids,
            patch_lengths.shape[1],
            window=window,
            patches_as_queries=patches_as_queries,
        ).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1)
        q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N
        kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k
        assert cross_mask.shape == (
            bs,
            q_len,
            kv_len,
        ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}"
        if block_mask:

            def patch_mask(b, h, q_idx, kv_idx):
                return cross_mask[b, q_idx, kv_idx]

            block_mask = create_block_mask(
                patch_mask,
                B=bs,
                H=None,
                Q_LEN=q_len,
                KV_LEN=kv_len,
                _compile=True,
            )
            return block_mask
        else:
            return torch.where(
                cross_mask, torch.tensor(0.0), torch.tensor(float("-inf"))
            ).unsqueeze(
                1
            )  # [bs, 1, q_len, kv_len]
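# Illustrative example (not used by the model): three patches over four byte
# positions, with patches as queries and block_mask=False, yield an additive
# float mask of shape [bs, 1, q_len=3, kv_len=4] that is 0.0 where attention is
# allowed and -inf elsewhere.
def _demo_cross_attn_mask() -> torch.Tensor:
    patch_ids = torch.tensor([[0, 0, 1, 2]])
    patch_lengths = torch.tensor([[2, 1, 1]])
    mask = cross_attn_mask(
        patch_ids,
        patch_lengths,
        N=4,
        patches_as_queries=True,
        cross_attn_k=1,
        window=None,
        block_mask=False,
    )
    assert mask.shape == (1, 1, 3, 4)
    return mask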
def get_blt_input(
    tokens: torch.Tensor,
    enforce_patch_size_multiple: bool,
    nb_boe: torch.Tensor,
    patch_size: int,
    boe_id: int,
):
    """
    This function returns X_et, X_gt and X_dt: the encoder, global, and decoder
    tokens respectively.

    Consider the input and target sequences:
    X = [3, 4, 5, 6, 7, eos, bos, 8, 9, 10, eos, bos, 11, 12, 13]
    Y = [4, 5, 6, 7, eos, bos, 8, 9, 10, eos, bos, 11, 12, 13, 14]
    with patch_size = 4

    Note 1: there will be no special tokens introduced at the patch level.
    Note 2: X_e needs to be trimmed to be passed to Global.

    Current without boe:
    X_et = [[boe, boe, boe, boe], [3, 4, 5, 6], [7, eos, bos, 8], [9, 10, eos, bos], [11, 12, 13, pad]]
    X_g  = [[boe, boe, boe, boe], [3, 4, 5, 6], [7, eos, bos, 8], [9, 10, eos, bos], [11, 12, 13, pad]]  # remove last glob patch
    X_dt = [[3, 4, 5, 6], [7, eos, bos, 8], [9, 10, eos, bos], [11, 12, 13]]
    Y    = [[4, 5, 6, 7], [eos, bos, 8, 9], [10, eos, bos, 11], [12, 13, 14]]

    --> lag fix:
    X_et = [[boe, boe, boe, 3], [4, 5, 6, 7], [eos, bos, 8, 9], [10, eos, bos, 11], [12, 13, pad, pad]]
    X_g  = [[boe, boe, boe, 3], [4, 5, 6, 7], [eos, bos, 8, 9], [10, eos, bos, 11]]
    X_dt = [[3, 4, 5, 6], [7, eos, bos, 8], [9, 10, eos, bos], [11, 12, 13]]
    Y    = [[4, 5, 6, 7], [eos, bos, 8, 9], [10, eos, bos, 11], [12, 13, 14]]

    Dynamic (current):
    X = [3, 4, 5, 6, 7, eos, bos, 8, 9, 10, eos, bos]
    Y = [4, 5, 6, 7, eos, bos, 8, 9, 10, eos, bos, 11]
    entropy patching:
        input: 7, bos, 9, 10
        pred (high entropy): eos, 8, 10, eos
    X_et = [[boe, 3, 4, 5, 6, 7, eos, bos, 8, 9, 10, eos, bos]]
    X_g  = [[boe], [3, 4, 5, 6], [7, eos], [bos, 8], [9], [10, eos]]
    X_dt = [[3, 4, 5, 6], [7, eos], [bos, 8], [9], [10, eos], [bos]]
    Y    = [4, 5, 6, 7, eos, bos, 8, 9, 10, eos, bos, 11]

    --> lag fix no boe (force single byte first patch):
    X_et = [[3, 4, 5, 6, 7, eos, bos, 8, 9, 10, eos, bos, 11, 12]]
    X_g  = [[3], [4, 5, 6, 7], [eos, bos], [8, 9], [10], [eos, bos], [11, 12]]  # remove last global patch
    X_dt = [[3, 4, 5, 6], [7, eos], [bos, 8], [9], [10, eos], [bos, 11, 12]]
    Y    = [4, 5, 6, 7, eos, bos, 8, 9, 10, eos, bos, 11, 12, 13]

    entropy patching:
        input: 4, 7, bos, 9, 10
        pred (high entropy): 5, eos, 8, 10, eos
    X_et = [[3, 4, 5, 6, 7, eos, bos, 8, 9, 10, eos, bos, 11, 12]]
    X_g  = [[3], [4], [5, 6, 7], [eos, bos], [8, 9], [10], [eos, bos], [11, 12]]  # remove last global patch
    X_dt = [[3], [4, 5, 6], [7, eos], [bos, 8], [9], [10, eos], [bos, 11, 12]]
    Y    = [4], [5, 6, 7, eos, bos, 8, 9, 10, eos, bos, 11, 12, 13]

    Handle the last byte properly.
    patch_lengths = [1, 1, 3, 2, 2, 1, 2, 2, 1]
    X_et = [[3, 4, 5, 6, 7, eos, bos, 8, 9, 10, eos, bos, 11, 12]]
    X_g  = [[3], [4], [5, 6, 7], [eos, bos], [8, 9], [10], [eos, bos], [11, 12]]  # do not remove last global patch
    X_dt = [[3], [4, 5, 6], [7, eos], [bos, 8], [9], [10, eos], [bos, 11], [12]]
    Y    = [4], [5, 6, 7, eos, bos, 8, 9, 10, eos, bos, 11, 12, 13]

    bpe delim:
    X_et = [[3, 4, 5, 6, 7, <d>, eos, bos, <d>, 8, 9, <d>, 10, <d>, eos, bos, 11, 12]]
    X_g  = [[3], [4, 5, 6, 7, <d>], [eos, bos, <d>], ...
    X_dt = [[3, 4, 5, 6, 7], [<d>, eos, bos], [<d>, bos, 8], ...
    Y    = [4, 5, 6, 7, <d>, eos, bos, <d>, 8, 9, <d>, ...

    Note 1: there will be no special tokens introduced at the patch level.
    Note 2: X_e needs to be trimmed to be passed to Global.
    """
    batch_size, seq_len = tokens.shape
    local_encoder_tokens = tokens
    local_decoder_tokens = tokens

    if nb_boe > 0:
        padded_patch = tokens.new(batch_size, nb_boe).fill_(boe_id)
        local_encoder_tokens = torch.cat((padded_patch, local_encoder_tokens), dim=1)
    # global_tokens = tokens.new(batch_size, ((seq_len-1) // patch_size)+1).fill_(boe_id)

    # create global tokens, contains boe tokens and eos
    # padded_local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id)
    # patches = padded_local_encoder_tokens.view(batch_size, -1, patch_size)
    # global_tokens = (patches.eq(eos_id).any(dim=2).int() * eos_id)[:, 1:]
    # global_tokens += global_tokens.eq(0).int() * boe_id
    # TODO: fix this when we want to use block causal in the global.

    if enforce_patch_size_multiple and local_encoder_tokens.shape[-1] % patch_size != 0:
        local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id)

    return local_encoder_tokens, None, local_decoder_tokens
def patch_ids_from_lengths(patch_lengths, seq_len):
    bs, num_patches = patch_lengths.shape
    # Create a tensor of cumulative sums of the patch lengths
    cum_d = torch.cat(
        [
            torch.zeros(bs, 1, dtype=patch_lengths.dtype, device=patch_lengths.device),
            patch_lengths.cumsum(dim=-1),
        ],
        dim=-1,
    )
    patch_ids = (cum_d.unsqueeze(-1) <= torch.arange(seq_len, device=cum_d.device)).sum(
        dim=-2
    ) - 1
    assert not (
        torch.max(patch_ids) > patch_lengths.shape[-1] or torch.min(patch_ids) < 0
    ), f"{torch.max(patch_ids)} > {patch_lengths.shape[-1]} or {torch.min(patch_ids)} < 0"
    return patch_ids
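# Illustrative example (not used by the model): patch lengths [2, 3, 1] over a
# length-6 sequence assign byte positions to patch ids [0, 0, 1, 1, 1, 2].
def _demo_patch_ids_from_lengths() -> torch.Tensor:
    patch_lengths = torch.tensor([[2, 3, 1]])
    patch_ids = patch_ids_from_lengths(patch_lengths, seq_len=6)
    assert patch_ids.tolist() == [[0, 0, 1, 1, 1, 2]]
    return patch_ids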
class ByteLatentTransformerArgs(BaseTransformerArgs):
    # Basic model configuration
    seed: int = 42
    vocab_size: int = -1
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    # TODO: What is the purpose of this parameter?
    weight_tying: bool = False

    # Architecture and dimensions
    dim_token: int = 256
    dim_global: int = 512
    dim_local_decoder: int = 512
    dim_local_encoder: int = 512
    n_layers_global: int = 8
    n_layers_local_decoder: int = 8
    n_layers_local_encoder: int = 8

    # Tokenization and patching
    tokenization_mode: str = "bpe"
    patch_size: float | None = None
    patching_mode: str | None = None
    patching_threshold: float | None = None
    patching_threshold_add: float | None = None
    monotonicity: bool = False
    patching_batch_size: int = 1
    patching_device: str = "cuda"
    data_loader_patching: bool = False
    max_patch_length: int | None = None

    # Encoder/Decoder configuration
    tie_local_encoder_decoder_logits: bool = False
    use_local_encoder_transformer: bool = False
    encoder_lm_loss: bool = False
    max_encoder_seq_length: int | None = None
    pad_to_max_length: bool = False
    encoder_enable_byte_ngrams: bool = False
    encoder_enable_byte_group_hash: bool = False
    ngram_vocab_sizes: int | None = None

    # Cross attention configurations
    cross_attn_encoder: bool = False
    cross_attn_decoder: bool = False
    cross_attn_window_encoder: int | None = None
    cross_attn_window_decoder: int | None = None
    cross_attn_k: int | None = None
    cross_attn_nheads: int | None = None
    cross_attn_all_layers_decoder: bool = False
    cross_attn_all_layers_encoder: bool = False
    cross_attn_use_flex_attention: bool = True
    cross_attn_init_by_pooling: bool = False

    # Encoder hash configurations
    encoder_hash_byte_group_size: Any | None = None
    encoder_hash_byte_group_vocab: int = 30000
    encoder_hash_byte_group_nb_functions: int = 3

    # Model behavior and optimization
    log_patch_lengths: bool = False
    non_linearity: str = "swiglu"
    use_rope: bool = True
    recompute_fc1_out: bool = False
    recompute_fc3_out: bool = False
    recompute_attn: bool = True
    custom_bwd: bool = False
    layer_ckpt: str = "all"

    # Initialization and attention
    init_use_gaussian: bool = True
    init_use_depth: str = "current"
    attn_bias_type: str = "causal"
    alpha_depth: str = "disabled"
    max_length: int = 2048

    # Norm configuration
    norm_eps: float = 1e-5
    norm_affine: bool = True
    pre_norm: bool = True
    norm_type: str = "rmsnorm"

    # Additional configurations
    multiple_of: int = 256
    ffn_dim_multiplier: float = 1.0
    dropout: float = 0
    output_size: int = -1

    # Additional parameters from ModelArgs
    architecture: str = "vanilla"
    share_encoder_decoder_emb: bool = True
    global_local_decoder_residual_layer: str | None = None
    tokenize_with_bpe_delimiter: bool = False
    patching_thresholds_str: str | None = None
    tie_local_encoder_decoder: bool = False
    encoder_preds_low_entropy_toks: float | None = None
    encoder_preds_random_toks: float | None = None
    dim_token_emb: int | None = None
    dim_patch_emb: int | None = None
    encoder_ngram_table_dir: str | None = None
    encoder_ngram_to_size_str: str | None = None

    # Model architecture params
    entropy_model_checkpoint_dir: str | None = None
    entropy_model_is_ngram_model: bool = False
    downsampling_by_pooling: str | None = None
    n_heads_global: int = 8
    n_heads_local_decoder: int = 8
    n_heads_local_encoder: int = 8
    n_kv_heads: int | None = None
    n_kv_heads_global: int | None = None
    conv_kernel_size: int | None = None
    local_attention_window_len: int | None = None

    # Performance optimization
    sequence_parallel: bool = False
    loss_parallel: bool = False
    fuse_sequence_parallel: bool = False
    use_fsdp: bool = True
    attn_to_keep: str = "all"

    # RoPE parameters
    rope_theta: float = 10000.0
    rope_use_fp32_in_outer_product: bool = False

    # Parameter mixing
    pm_size: int = 0

    # Logging
    full_logging_n_layers: int = 4

    @model_validator(mode="after")
    def check_hash_byte_sizes(self) -> Self:
        if (
            self.encoder_hash_byte_group_size is not None
            and isinstance(self.encoder_hash_byte_group_size, str)
        ):
            self.encoder_hash_byte_group_size = [
                int(x)
                for x in self.encoder_hash_byte_group_size.split(",")
                if len(x) > 0
            ]
        return self
class GlobalTransformerArgs(ByteLatentTransformerArgs):
    # Global encoder specific dimensions
    dim_token_emb: int | None = None
    dim_patch_emb: int | None = None

    def __post_init__(self):
        # Override base args with global encoder specific values
        self.dim = self.dim_global
        self.n_layers = self.n_layers_global
        self.n_heads = self.n_heads_global
        self.n_kv_heads = self.n_kv_heads_global
        self.local_attention_window_len = None
        self.cross_attn_encoder = False
        self.cross_attn_decoder = False


class LocalDecoderArgs(ByteLatentTransformerArgs):
    # Local decoder specific dimensions
    dim_token_emb: int | None = None
    dim_patch_emb: int | None = None

    def __post_init__(self):
        # Override base args with local decoder specific values
        self.dim = self.dim_local_decoder
        self.n_layers = self.n_layers_local_decoder
        self.n_heads = self.n_heads_local_decoder
        self.cross_attn_encoder = False
        self.cross_attn_init_by_pooling = False
        self.attn_bias_type = "local_block_causal"
def create_global_transformer(args: ByteLatentTransformerArgs) -> GlobalTransformer:
    global_args = args.model_copy(
        deep=True,
        update=dict(
            dim=args.dim_global,
            n_layers=args.n_layers_global,
            n_heads=args.n_heads_global,
            n_kv_heads=args.n_kv_heads_global,
            local_attention_window_len=None,
            dim_token_emb=get_global_dim_patch_emb(args),
            dim_patch_emb=None,
            cross_attn_encoder=False,
            cross_attn_decoder=False,
        ),
    )
    return GlobalTransformer(global_args)
def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
    local_encoder_args = LocalModelArgs(
        # Updated args
        dim=args.dim_local_encoder,
        n_layers=args.n_layers_local_encoder,
        n_heads=args.n_heads_local_encoder,
        dim_token_emb=get_encoder_dim_token_emb(args),
        dim_patch_emb=get_encoder_dim_patch_emb(args),
        cross_attn_encoder=args.cross_attn_encoder,
        cross_attn_decoder=False,
        cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None,
        cross_attn_init_by_pooling=args.cross_attn_init_by_pooling,
        # Defaults
        head_dim=args.head_dim,
        max_seqlen=args.max_encoder_seq_length,
        dropout=args.dropout,
        vocab_size=args.vocab_size + args.pm_size,
        norm_eps=args.norm_eps,
        patch_size=args.patch_size,
        sliding_window=args.local_attention_window_len,
        use_rope=args.use_rope,
        rope_theta=args.rope_theta,
        init_base_std=args.init_base_std,
        init_std_factor=args.init_std_factor,
        n_kv_heads=args.n_kv_heads,
        attn_impl=args.attn_impl,
        attn_bias_type="local_block_causal",
        multiple_of=args.multiple_of,
        ffn_dim_multiplier=args.ffn_dim_multiplier,
        patching_mode=args.patching_mode,
        use_local_encoder_transformer=args.use_local_encoder_transformer,
        downsampling_by_pooling=args.downsampling_by_pooling,
        encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
        cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
        cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
        cross_attn_nheads=args.cross_attn_nheads,
        eos_id=args.eos_id,
    )
    return LocalEncoder(local_encoder_args)
def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
    local_decoder_args = LocalModelArgs(
        dim=args.dim_local_decoder,
        n_layers=args.n_layers_local_decoder,
        n_heads=args.n_heads_local_decoder,
        dim_token_emb=get_decoder_dim_token_emb(args),
        dim_patch_emb=args.dim_global,
        cross_attn_encoder=False,
        cross_attn_decoder=args.cross_attn_decoder,
        cross_attn_init_by_pooling=False,  # states are already defined
        cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None,
        # Defaults
        head_dim=args.head_dim,
        max_seqlen=args.max_encoder_seq_length,
        dropout=args.dropout,
        vocab_size=args.vocab_size + args.pm_size,
        norm_eps=args.norm_eps,
        patch_size=args.patch_size,
        sliding_window=args.local_attention_window_len,
        use_rope=args.use_rope,
        rope_theta=args.rope_theta,
        init_base_std=args.init_base_std,
        init_std_factor=args.init_std_factor,
        n_kv_heads=args.n_kv_heads,
        attn_impl=args.attn_impl,
        attn_bias_type="local_block_causal",
        multiple_of=args.multiple_of,
        ffn_dim_multiplier=args.ffn_dim_multiplier,
        patching_mode=args.patching_mode,
        use_local_encoder_transformer=args.use_local_encoder_transformer,
        downsampling_by_pooling=args.downsampling_by_pooling,
        encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
        cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
        cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
        cross_attn_nheads=args.cross_attn_nheads,
        eos_id=args.eos_id,
    )
    return LocalDecoder(local_decoder_args)
class EmbeddingType(Enum):
    HASH_TOK = auto()
    NGRAM = auto()


def init_embeddings(
    args,
    embedding_type: EmbeddingType,
    local_encoder_dim: int,
    encoder_hash_byte_group_size: list | None = None,
):
    if (
        embedding_type == EmbeddingType.HASH_TOK
        and args.encoder_hash_byte_group_size is None
    ):
        return None
    if embedding_type == EmbeddingType.NGRAM and args.encoder_ngram_to_size_str is None:
        return None

    embeddings = []

    if embedding_type == EmbeddingType.HASH_TOK:
        emb_dim = local_encoder_dim
        encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab
        for _ in range(args.encoder_hash_byte_group_nb_functions):
            for _ in encoder_hash_byte_group_size:
                embeddings.append(
                    nn.Embedding(
                        encoder_hash_byte_group_vocab,
                        emb_dim,
                    )
                )
    elif embedding_type == EmbeddingType.NGRAM:
        encoder_ngram_to_size = parse_ngram_to_size(args.encoder_ngram_to_size_str)
        emb_dim = local_encoder_dim
        OFFSET = 4  # This should be passed as a parameter if it's variable
        for ngram_vocab_size in encoder_ngram_to_size.values():
            embeddings.append(nn.Embedding(ngram_vocab_size + OFFSET, emb_dim))

    return nn.ModuleList(embeddings)
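# Illustrative example (not used by the model; a SimpleNamespace stands in for
# the real pydantic args): HASH_TOK init creates one embedding table per
# (hash function, byte-group size) pair, so 2 functions x 2 sizes -> 4 tables.
def _demo_init_hash_embeddings() -> nn.ModuleList:
    from types import SimpleNamespace

    toy_args = SimpleNamespace(
        encoder_hash_byte_group_size=[3, 4],
        encoder_hash_byte_group_vocab=30000,
        encoder_hash_byte_group_nb_functions=2,
        encoder_ngram_to_size_str=None,
    )
    tables = init_embeddings(
        toy_args,
        EmbeddingType.HASH_TOK,
        local_encoder_dim=16,
        encoder_hash_byte_group_size=toy_args.encoder_hash_byte_group_size,
    )
    assert len(tables) == 4
    return tables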
def compute_hash_embeddings(
    local_encoder_tokens: torch.Tensor,
    local_encoder,
    encoder_hash_tok_embedding: nn.ModuleList,
    encoder_hash_byte_group_nb_functions: int,
    encoder_hash_byte_group_size: list,
    encoder_hash_byte_group_vocab: int,
) -> torch.Tensor:
    """
    Compute embeddings using hash token embeddings.

    Args:
        local_encoder_tokens: Input tokens tensor
        local_encoder: Encoder object with tok_embeddings method
        encoder_hash_tok_embedding: ModuleList of hash token embeddings
        encoder_hash_byte_group_nb_functions: Number of hash functions
        encoder_hash_byte_group_size: List of byte group sizes
        encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings

    Returns:
        torch.Tensor: Combined embeddings
    """
    if encoder_hash_tok_embedding is None:
        return None

    local_encoder_embeds = local_encoder.tok_embeddings(local_encoder_tokens)

    i = 0
    for func_nb in range(encoder_hash_byte_group_nb_functions):
        for byte_group_size in encoder_hash_byte_group_size:
            hash_ids = byte_group_hash_function(
                local_encoder_tokens,
                byte_group_size,
                hash_func_nb=func_nb,
                max_hash=encoder_hash_byte_group_vocab,
            )
            hash_tok_embedding = encoder_hash_tok_embedding[i]
            local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids)
            i += 1

    assert i == len(encoder_hash_tok_embedding)
    return local_encoder_embeds
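# Illustrative example (not used by the model; a plain nn.Embedding stands in
# for the local encoder's tok_embeddings): the byte embeddings are summed with
# one hash embedding per (hash function, byte-group size) pair.
def _demo_compute_hash_embeddings() -> torch.Tensor:
    from types import SimpleNamespace

    vocab, dim = 260, 16
    stand_in_encoder = SimpleNamespace(tok_embeddings=nn.Embedding(vocab, dim))
    sizes, nb_functions, hash_vocab = [1, 2], 1, 30000
    tables = nn.ModuleList(
        [nn.Embedding(hash_vocab, dim) for _ in range(nb_functions * len(sizes))]
    )
    tokens = torch.randint(0, vocab, (1, 8))
    embeds = compute_hash_embeddings(
        local_encoder_tokens=tokens,
        local_encoder=stand_in_encoder,
        encoder_hash_tok_embedding=tables,
        encoder_hash_byte_group_nb_functions=nb_functions,
        encoder_hash_byte_group_size=sizes,
        encoder_hash_byte_group_vocab=hash_vocab,
    )
    assert embeds.shape == (1, 8, dim)
    return embeds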
class ByteLatentTransformer(nn.Module):
    """
    The ByteLatentTransformer (BLT) is a byte-level language model architecture that
    processes byte sequences by dynamically segmenting them into patches. It uses a
    combination of local encoders, global transformers, and local decoders to
    efficiently encode and decode byte sequences, leveraging patch-based processing
    for improved performance and inference efficiency.
    """

    def __init__(self, args: ByteLatentTransformerArgs):
        super().__init__()

        # General configuration
        self.weight_tying = args.weight_tying
        self.patch_size = args.patch_size
        self.patching_mode = args.patching_mode
        self.boe_id, self.bos_id, self.pad_id, self.eos_id = (
            BOE_ID,
            BOS_ID,
            PAD_ID,
            EOS_ID,
        )
        self.downsampling_by_pooling = args.downsampling_by_pooling
        self.patching_threshold = args.patching_threshold
        self.dim = args.dim
        self.init_base_std = args.init_base_std
        self.init_std_factor = InitStdFactor(args.init_std_factor)
        self.max_seqlen = args.max_seqlen

        # Cross attention configuration
        self.cross_attn_encoder = args.cross_attn_encoder
        self.cross_attn_decoder = args.cross_attn_decoder
        self.cross_attn_k = args.cross_attn_k
        self.cross_attn_window_encoder = args.cross_attn_window_encoder
        self.cross_attn_window_decoder = args.cross_attn_window_decoder
        self.cross_attn_use_flex_attention = args.cross_attn_use_flex_attention

        # Encoder hash configuration
        self.encoder_hash_byte_group_size = args.encoder_hash_byte_group_size
        self.encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab
        self.encoder_hash_byte_group_nb_functions = (
            args.encoder_hash_byte_group_nb_functions
        )

        # ByteLatent modules
        self.local_encoder = create_local_encoder(args)
        self.global_transformer = create_global_transformer(args)
        self.local_decoder = create_local_decoder(args)
        self.encoder_hash_tok_embedding = init_embeddings(
            args,
            EmbeddingType.HASH_TOK,
            local_encoder_dim=self.local_encoder.dim,
            encoder_hash_byte_group_size=self.encoder_hash_byte_group_size,
        )
        self.encoder_ngram_embedding = init_embeddings(
            args,
            EmbeddingType.NGRAM,
            local_encoder_dim=self.local_encoder.dim,
            encoder_hash_byte_group_size=None,
        )
        self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)

        # Transformer layers
        self.layers = nn.ModuleList(
            [TransformerBlock(args) for _ in range(args.n_layers)]
        )

        # Encoder ngram embedding tables
        self.encoder_ngram_embedding = None
        if args.encoder_enable_byte_ngrams:
            self.encoder_ngram_embedding = nn.ModuleList()
            assert args.ngram_vocab_sizes is not None
            self.encoder_ngram_to_size = parse_ngram_to_size(
                args.encoder_ngram_to_size_str
            )
            ngram_emb_dim = self.local_encoder.dim
            for ngram_vocab_size in self.encoder_ngram_to_size.values():
                self.encoder_ngram_embedding.append(
                    nn.Embedding(ngram_vocab_size + OFFSET, ngram_emb_dim)
                )

        # Output layer
        assert args.vocab_size > 0, "vocab_size must be greater than 0"
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
        if args.weight_tying:
            self.output.weight = self.tok_embeddings.weight

        # Patcher module
        if not args.data_loader_patching:
            self.patcher = Patcher(
                PatcherArgs(
                    patch_size=args.patch_size,
                    patching_mode=args.patching_mode,
                    patching_threshold=args.patching_threshold,
                    patching_threshold_add=args.patching_threshold_add,
                    monotonicity=args.monotonicity,
                    max_patch_length=args.max_patch_length,
                )
            )
    def forward(
        self,
        tokens: torch.Tensor,
        patch_lengths: Optional[torch.Tensor] = None,
        ngram_ids: Optional[torch.Tensor] = None,
    ):
        # Ensure ngram_ids is either a tensor or None
        assert (
            isinstance(ngram_ids, torch.Tensor) or ngram_ids is None
        ), f"ngram_ids must be a tensor or None, but was: {type(ngram_ids)}"

        bs, N = tokens.shape  # Batch size and sequence length

        # Get megabyte inputs
        nb_boe = int(0 if self.patching_mode != "" else self.patch_size - 1)
        local_encoder_tokens, _, local_decoder_tokens = get_blt_input(
            tokens=tokens,
            enforce_patch_size_multiple=False,
            nb_boe=nb_boe,
            patch_size=self.patch_size,
            boe_id=self.boe_id,
        )

        # Patching
        if patch_lengths is None:
            assert (
                getattr(self, "patcher", None) is not None
            ), "Patcher not defined and no patch_lengths passed."
            patch_lengths, tok_scores = self.patcher.patch(
                local_encoder_tokens,
                include_next_token=True,
                threshold=self.patcher.threshold,
            )
        else:
            if nb_boe > 0:
                patch_lengths[:, 0] += nb_boe

        assert torch.min(patch_lengths) >= 0

        # Generate patch IDs from patch_lengths
        patch_ids = patch_ids_from_lengths(
            patch_lengths, local_encoder_tokens.shape[-1]
        )
        assert torch.max(patch_ids) + 1 <= torch.max(
            (patch_lengths != 0).sum(dim=-1)
        ), f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}"

        cross_attn_mask_enc = None
        # Cross-attention encoder
        if self.cross_attn_encoder:
            cross_attn_mask_enc = cross_attn_mask(
                patch_ids,
                patch_lengths,
                N,
                patches_as_queries=True,
                cross_attn_k=self.cross_attn_k,
                window=self.cross_attn_window_encoder,
                block_mask=self.cross_attn_use_flex_attention,
            )

        # Hashing and embedding
        local_encoder_embeds = compute_hash_embeddings(
            local_encoder_tokens=local_encoder_tokens,
            local_encoder=self.local_encoder,
            encoder_hash_tok_embedding=self.encoder_hash_tok_embedding,
            encoder_hash_byte_group_nb_functions=self.encoder_hash_byte_group_nb_functions,
            encoder_hash_byte_group_size=self.encoder_hash_byte_group_size,
            encoder_hash_byte_group_vocab=self.encoder_hash_byte_group_vocab,
        )

        # N-gram table embeddings
        if self.encoder_ngram_embedding is not None:
            assert ngram_ids is not None, "ngram_ids must be provided"
            if local_encoder_embeds is None:
                local_encoder_embeds = self.local_encoder.tok_embeddings(
                    local_encoder_tokens
                )
            assert len(ngram_ids) == len(
                self.encoder_ngram_embedding
            ), f"ngram_ids.shape[0]={ngram_ids.shape[0]} versus len(encoder_ngram_embedding)={len(self.encoder_ngram_embedding)}, ngram_ids.shape={ngram_ids.shape}"
            for i in range(ngram_ids.shape[0]):
                ngram_embedding = self.encoder_ngram_embedding[i]
                ngram_embeds = ngram_embedding(ngram_ids[i])
                assert (
                    local_encoder_embeds.shape == ngram_embeds.shape
                ), f"Shape mismatch: {local_encoder_embeds.shape} vs {ngram_embeds.shape}, ngram_ids.shape={ngram_ids.shape}"
                local_encoder_embeds = local_encoder_embeds + ngram_embeds

        # Local encoder
        h_cross = None
        (h_encoder, h_cross), cache_encoder = self.local_encoder(
            tokens=local_encoder_tokens,
            embeds=local_encoder_embeds,
            patch_embeds=h_cross if self.cross_attn_encoder else None,
            cross_mask=cross_attn_mask_enc,
            num_patches=patch_lengths.shape[1],
            patch_ids=patch_ids,
        )

        # Downsampling
        if not self.cross_attn_encoder:
            assert (
                patch_ids.shape[1] == h_encoder.shape[1]
            ), f"{patch_ids.shape[1]} != {h_encoder.shape[1]}"
            h = downsample(
                h_encoder,
                patch_lengths.shape[1],
                patch_lengths,
                patch_ids,
                downsampling_by_pooling=self.downsampling_by_pooling,
                patch_size=self.patch_size,
            )
        else:
            # Reshape h_cross
            h = h_cross.view(bs, patch_lengths.shape[1], -1)

        # Global transformer
        global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(self.boe_id)
        rows, cols = torch.where(local_encoder_tokens == self.eos_id)
        eos_patch_ids = patch_ids[rows, cols]
        global_tokens[rows, eos_patch_ids] = self.eos_id

        h, _ = self.global_transformer(
            embeds=h,
            tokens=global_tokens,
        )

        # Unpatching
        dec_embeds = h_encoder[:, nb_boe : nb_boe + N, :]

        # Generate decoder patch IDs
        decoder_patch_ids = decoder_patch_ids_from_lengths(
            patch_lengths, nb_boe, local_decoder_tokens.shape[-1]
        )
        assert (
            torch.max(decoder_patch_ids) + 1 <= h.shape[1]
        ), f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}"
        assert (
            decoder_patch_ids.shape[1] == dec_embeds.shape[1]
        ), f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}"

        # Cross-attention decoder
        if not self.cross_attn_decoder:
            h = torch.gather(
                h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])
            )
            cross_attn_mask_dec = None
            assert local_decoder_tokens.shape == h.shape[:-1]
        else:
            cross_attn_mask_dec = cross_attn_mask(
                decoder_patch_ids,
                patch_lengths,
                N,
                patches_as_queries=False,
                cross_attn_k=self.cross_attn_k,
                window=self.cross_attn_window_decoder,
                block_mask=self.cross_attn_use_flex_attention,
            )

        # Local decoder
        output, _ = self.local_decoder(
            embeds=dec_embeds,
            patch_embeds=h,
            tokens=local_decoder_tokens,
            cross_mask=cross_attn_mask_dec,
        )
        return output
    def reset_parameters(self, init_std=None):
        # Either use fixed base std or sqrt model dim
        init_std = init_std or (self.dim ** (-0.5))
        nn.init.trunc_normal_(
            self.tok_embeddings.weight,
            mean=0.0,
            std=init_std,
            a=-3 * init_std,
            b=3 * init_std,
        )
        if not self.weight_tying:
            nn.init.trunc_normal_(
                self.output.weight,
                mean=0.0,
                std=init_std,
                a=-3 * init_std,
                b=3 * init_std,
            )

    def init_weights(self):
        self.reset_parameters()
        self.init_base_std = self.init_base_std or (self.dim ** (-0.5))
        for depth, layer in enumerate(self.layers):
            factor = {
                InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
                InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
                InitStdFactor.DIM_RATIO: self.dim / 4096,
                InitStdFactor.DISABLED: 1.0,
            }[self.init_std_factor]
            layer.init_weights(self.init_base_std, factor)

        self.local_decoder.init_weights(self.init_base_std)
        self.global_transformer.init_weights(self.init_base_std)
        self.local_encoder.init_weights(self.init_base_std)

        for emb in self.encoder_hash_tok_embedding:
            nn.init.trunc_normal_(
                emb.weight,
                mean=0.0,
                std=self.init_base_std,
                a=-3 * self.init_base_std,
                b=3 * self.init_base_std,
            )