Mirror of https://github.com/facebookresearch/blt.git, synced 2025-01-18 16:37:46 +00:00

Commit 6ffeb66b53

Summary:
- Refactor local model configs to be separate and clearer
- Add attention arguments and correct which attention is used in local models
- Preparation for being able to have an entropy train script
- Fix failing unit tests

Test Plan:

1083 lines · 39 KiB · Python
# Copyright (c) Meta Platforms, Inc. and affiliates.

from enum import Enum, auto
from typing import Any, Optional

import torch
from pydantic import ConfigDict, model_validator
from torch import nn
from torch.nn.attention.flex_attention import create_block_mask
from typing_extensions import Self

from bytelatent.base_transformer import (
    BaseTransformerArgs,
    InitStdFactor,
    TransformerBlock,
)
from bytelatent.data.patcher import Patcher, PatcherArgs
from bytelatent.model.latent_transformer import GlobalTransformer
from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs
from bytelatent.model.utils import downsample
from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
def attention_flops_per_token(n_layers, seq_len, dim, causal):
    # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30
    return 3.5 * (4 * n_layers * seq_len * dim // (2 if causal else 1))


def get_num_flop_per_token(
    num_non_embed_params: int, n_layers: int, dim: int, seq_len: int
) -> int:
    return 6 * num_non_embed_params + attention_flops_per_token(
        n_layers, seq_len, dim, True
    )
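# Illustrative sanity check (not part of the original source; numbers chosen arbitrarily):
# with causal=True the attention term reduces to 3.5 * (4 * n_layers * seq_len * dim // 2)
# = 7 * n_layers * seq_len * dim. For an 8-layer, dim=512 model at seq_len=1024 with
# 25M non-embedding parameters:
#   attention_flops_per_token(8, 1024, 512, True)      -> 3.5 * 8_388_608 = 29_360_128.0
#   get_num_flop_per_token(25_000_000, 8, 512, 1024)   -> 6 * 25_000_000 + 29_360_128.0
#                                                       = 179_360_128.0 FLOPs per token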
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def setattrs(_self, **kwargs):
    for k, v in kwargs.items():
        setattr(_self, k, v)
def get_encoder_dim_token_emb(args):
    if args.dim_token is not None:
        dim_token_emb = args.dim_token
    elif args.use_local_encoder_transformer:
        dim_token_emb = args.dim_local_encoder
    else:
        dim_token_emb = args.dim_global // args.patch_size
    return dim_token_emb


def get_encoder_dim_patch_emb(args):
    dim_patch_emb = None
    if args.cross_attn_encoder:
        if args.cross_attn_init_by_pooling:
            dim_patch_emb = args.dim_local_encoder
        else:
            dim_patch_emb = args.dim_global
    return dim_patch_emb


def get_global_dim_patch_emb(args):
    dim_token_emb = get_encoder_dim_token_emb(args)
    if args.cross_attn_encoder:
        dim_patch_emb = dim_token_emb * args.cross_attn_k
    elif (
        args.downsampling_by_pooling is None
        or not args.downsampling_by_pooling
        or len(args.downsampling_by_pooling) == 0
    ):
        dim_patch_emb = dim_token_emb * args.patch_size
    else:
        dim_patch_emb = dim_token_emb * sum(
            [
                pooling in args.downsampling_by_pooling
                for pooling in ["avg", "min", "max"]
            ]
        )
    return dim_patch_emb


def get_decoder_dim_token_emb(args):
    if args.share_encoder_decoder_emb:
        dim_token_emb = get_encoder_dim_token_emb(args)
    elif args.dim_token is not None:
        dim_token_emb = args.dim_token
    else:
        dim_token_emb = args.dim_local_decoder
    return dim_token_emb
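# Illustrative example (not part of the original source), assuming an args object with
# dim_token=None, use_local_encoder_transformer=True, dim_local_encoder=512,
# cross_attn_encoder=True, cross_attn_k=2:
#   get_encoder_dim_token_emb(args) -> 512   (falls through to dim_local_encoder)
#   get_global_dim_patch_emb(args)  -> 1024  (512 * cross_attn_k, since the encoder
#                                              cross-attention branch is taken)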
def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int] | None:
    if ngram_to_size_str is None:
        return None
    ngram_to_size = {}
    for entry in ngram_to_size_str.split(","):
        ngram, size = entry.split(":")
        ngram = int(ngram)
        size = int(size)
        ngram_to_size[ngram] = size
    return ngram_to_size
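# Illustrative example (not part of the original source; arbitrary sizes): the n-gram spec
# string maps each n-gram order to its vocabulary size.
#   parse_ngram_to_size("2:1000,3:2000") -> {2: 1000, 3: 2000}
#   parse_ngram_to_size(None)            -> None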
def fill_tokens(tokens, patch_size, fill_id):
    batch_size, seq_len = tokens.shape
    if seq_len % patch_size == 0:
        return tokens
    else:
        remaining = patch_size - seq_len % patch_size
        final_padding = tokens.new(batch_size, remaining).fill_(fill_id)
        return torch.cat((tokens, final_padding), dim=1)
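# Illustrative example (not part of the original source): pad the sequence on the right so
# its length becomes a multiple of patch_size.
#   tokens = torch.tensor([[3, 4, 5, 6, 7]])      # shape (1, 5)
#   fill_tokens(tokens, patch_size=4, fill_id=0)
#   -> tensor([[3, 4, 5, 6, 7, 0, 0, 0]])         # shape (1, 8)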
def decoder_patch_ids_from_lengths(patch_lengths, nb_boe, seq_len):
    first_patch_length = patch_lengths[0, 0]
    assert torch.all(
        first_patch_length == patch_lengths[:, 0]
    ), "first patch should always be the same size (1 for dynamic, patch_size for static)."
    assert (
        first_patch_length - nb_boe == 1
    ), f"First patch (patch length: {first_patch_length}) should have one non-boe token (boe toks: {nb_boe})"
    # Remove first patch from patch_ids for local decoder inputs and shift the last patch.
    # decoder_patch_lengths = patch_lengths[:, 1:].clone()
    # decoder_patch_lengths = add_to_last_nonzero_patch(decoder_patch_lengths, 1)
    decoder_patch_lengths = patch_lengths[:, 1:]
    assert (
        decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0]
        == patch_lengths.sum()
    ), f"{decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0]} != {patch_lengths.sum()}"
    assert torch.all(decoder_patch_lengths >= 0), f"{decoder_patch_lengths}"
    decoder_patch_ids = patch_ids_from_lengths(
        patch_lengths=decoder_patch_lengths, seq_len=seq_len
    )
    return decoder_patch_ids
primes = [
    1000000007,
    5915587277,
    1500450271,
    3267000013,
    5754853343,
    4093082899,
    9576890767,
    3628273133,
    2860486313,
    5463458053,
    3367900313,
]
def rolling_polynomial_hash(t, hash_func_nb: int = 0):
    prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device)
    prime_powers = torch.stack([prime**i for i in range(t.shape[-1])])
    return torch.sum(t * prime_powers, dim=-1)


def get_rolling_polynomial_hash_fn(hash_func_nb: int = 0, group_size: int = 2):
    prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64)
    prime_powers = torch.stack([prime**i for i in range(group_size)])

    def rolling_polynomial_hash_fn(t):
        return torch.sum(t * prime_powers, dim=-1)

    return rolling_polynomial_hash_fn
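# Illustrative example (not part of the original source): with hash_func_nb=0 the prime is
# 1000000007, so a window t = [[3, 4]] hashes to 3 * 1000000007**0 + 4 * 1000000007**1
# = 4000000031 (before any modulo is applied by the caller).
#   rolling_polynomial_hash(torch.tensor([[3, 4]]), hash_func_nb=0)
#   -> tensor([4000000031])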
def byte_group_hash_function(
    x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000
):
    """
    Returns a hash of the input x and maps it to a value in the range [0, max_hash).

    expects: x of shape (batch_size, seq_len) with values as ids in the token vocab.
    returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash).

    Note: max_hash can make a big difference on the number of collisions.
    """
    with torch.no_grad():
        bs, seq_len = x.shape
        # x_numpy = x.numpy()
        # hash_values = torch.zeros(bs, seq_len, dtype=torch.int64, requires_grad=False)
        # for i in range(bs):
        #     for j in range(seq_len):
        #         start = max(j, j-group_size+1)
        #         end = j+1
        #         hash_values[i, j] = hash_array(x_numpy[i, start:end], max_hash)

        prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device)
        x = torch.cat([prefix, x], dim=1)
        windows = x.unfold(1, group_size, 1)
        # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows)
        hashes = rolling_polynomial_hash(windows, hash_func_nb)
        hash_values_range = hashes % max_hash
    hash_values_range.requires_grad = False
    return hash_values_range
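# Illustrative example (not part of the original source): with group_size=2 each position is
# hashed together with the byte before it (a zero prefix pads the first position), so the
# output keeps the input's (batch_size, seq_len) shape.
#   x = torch.tensor([[3, 4, 5, 6]])
#   byte_group_hash_function(x, group_size=2, hash_func_nb=0, max_hash=30000).shape
#   -> torch.Size([1, 4])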
def create_patch_mask_from_ids(
    patch_ids, num_patches, window=None, patches_as_queries=False
):
    """
    Creates a tensor of shape [bs, seq_len, num_patches] where each element at position
    (i, j, k) is True if the patch id at position (i, j) matches patch k (an exact match
    when window is None, otherwise a match within a window of `window` patches).
    Args:
        patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids.
        num_patches (int): Total number of patches.
        window (int): If not None, only considers patches within a window of size window.
        patches_as_queries (bool): If True, the patches are used as queries.
    Returns:
        torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask.
    """
    bs, seq_len = patch_ids.shape
    if not patches_as_queries:
        q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches)
        kv_ids = (
            torch.arange(num_patches, device=patch_ids.device)
            .unsqueeze(0)
            .unsqueeze(0)
            .expand(bs, seq_len, num_patches)
        )
    else:
        kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len)
        q_ids = (
            torch.arange(num_patches, device=patch_ids.device)
            .unsqueeze(0)
            .unsqueeze(-1)
            .expand(bs, num_patches, seq_len)
        )
    if window is None:
        mask = q_ids == kv_ids
    else:
        mask = (kv_ids <= q_ids) & (q_ids < kv_ids + window)
    return mask
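# Illustrative example (not part of the original source): with window=None the mask simply
# marks, for every token position, the single patch it belongs to.
#   patch_ids = torch.tensor([[0, 0, 1, 2]])
#   create_patch_mask_from_ids(patch_ids, num_patches=3)
#   -> tensor([[[ True, False, False],
#               [ True, False, False],
#               [False,  True, False],
#               [False, False,  True]]])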
def cross_attn_mask(
    patch_ids,
    patch_lengths,
    N,
    patches_as_queries=False,
    cross_attn_k=1,
    window=None,
    block_mask=True,
):
    bs = patch_ids.shape[0]
    with torch.no_grad():
        # Create the patch mask
        cross_mask = create_patch_mask_from_ids(
            patch_ids,
            patch_lengths.shape[1],
            window=window,
            patches_as_queries=patches_as_queries,
        ).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1)
        q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N
        kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k
        assert cross_mask.shape == (
            bs,
            q_len,
            kv_len,
        ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}"
        if block_mask:

            def patch_mask(b, h, q_idx, kv_idx):
                return cross_mask[b, q_idx, kv_idx]

            block_mask = create_block_mask(
                patch_mask,
                B=bs,
                H=None,
                Q_LEN=q_len,
                KV_LEN=kv_len,
                _compile=True,
            )
            return block_mask
        else:
            return torch.where(
                cross_mask, torch.tensor(0.0), torch.tensor(float("-inf"))
            ).unsqueeze(1)  # [bs, 1, q_len, kv_len]
def get_blt_input(
    tokens: torch.Tensor,
    enforce_patch_size_multiple: bool,
    nb_boe: torch.Tensor,
    patch_size: int,
    boe_id: int,
):
    """
    This function returns X_et, X_gt and X_dt, the encoder, global, and decoder
    tokens respectively.

    Consider the input and target sequences:
    X=[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13]
    Y=[4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13,14]
    with patch_size=4

    Note 1: no special tokens are introduced at the patch level.
    Note 2: X_e needs to be trimmed to be passed to Global.

    Current without boe:
    X_et = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]]
    X_g = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] # remove last glob patch
    X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]]
    Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]]

    --> lag fix:
    X_et = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11] [12,13,pad,pad]]
    X_g = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11]]
    X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]]
    Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]]

    Dynamic (current):
    X = [3,4,5,6,7,eos,bos,8,9,10,eos,bos]
    Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11]

    entropy patching:
    input: 7, bos, 9, 10
    pred (high entropy): eos, 8, 10, eos

    X_et = [[boe,3,4,5,6,7,eos,bos,8,9,10,eos,bos]
    X_g = [[boe], [3,4,5,6], [7,eos],[bos,8],[9], [10,eos]]
    X_dt = [[3,4,5,6], [7,eos], [bos,8],[9], [10,eos],[bos]]
    Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11]

    --> lag fix no boe (force single byte first patch):
    X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
    X_g = [[3], [4,5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch
    X_dt = [[3,4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]]
    Y = [4,5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13]

    input: 4, 7, bos, 9, 10
    pred (high entropy): 5, eos, 8, 10, eos

    X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
    X_g = [[3], [4], [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch
    X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]]
    Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13]

    Handle the last byte properly.
    patch_lengths = [1, 1, 3, 2, 2, 1, 2, 2, 1]
    X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
    X_g = [[3], [4], [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # do not remove last global patch
    X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11] [12]]
    Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12, 13]

    BPE delimiter:
    X_et = [[3,4,5,6,7,<d>,eos,bos,<d>,8,9,<d>,10,<d>,eos,bos,11,12]
    X_g = [[3], [4,5,6,7,<d>], [eos,bos,<d>], ..
    X_dt = [[3,4,5,6,7], [<d>,eos,bos], [<d>,bos,8], ..
    Y = [4,5,6,7,<d>, eos,bos,<d> 8,9,<d>, ..

    Note 1: no special tokens are introduced at the patch level.
    Note 2: X_e needs to be trimmed to be passed to Global.
    """
    batch_size, seq_len = tokens.shape
    local_encoder_tokens = tokens
    local_decoder_tokens = tokens

    if nb_boe > 0:
        padded_patch = tokens.new(batch_size, nb_boe).fill_(boe_id)
        local_encoder_tokens = torch.cat((padded_patch, local_encoder_tokens), dim=1)
    # global_tokens = tokens.new(batch_size, ((seq_len-1) // patch_size)+1).fill_(boe_id)

    # create global tokens, contains boe tokens and eos
    # padded_local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id)
    # patches = padded_local_encoder_tokens.view(batch_size, -1, patch_size)
    # global_tokens = (patches.eq(eos_id).any(dim=2).int() * eos_id)[:, 1:]
    # global_tokens += global_tokens.eq(0).int() * boe_id
    # TODO: fix this when we want to use block causal in the global.

    if enforce_patch_size_multiple and local_encoder_tokens.shape[-1] % patch_size != 0:
        local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id)

    return local_encoder_tokens, None, local_decoder_tokens
def patch_ids_from_lengths(patch_lengths, seq_len):
    bs, num_patches = patch_lengths.shape
    # Create a tensor of cumulative sums of the patch lengths
    cum_d = torch.cat(
        [
            torch.zeros(bs, 1, dtype=patch_lengths.dtype, device=patch_lengths.device),
            patch_lengths.cumsum(dim=-1),
        ],
        dim=-1,
    )
    patch_ids = (cum_d.unsqueeze(-1) <= torch.arange(seq_len, device=cum_d.device)).sum(
        dim=-2
    ) - 1
    assert not (
        torch.max(patch_ids) > patch_lengths.shape[-1] or torch.min(patch_ids) < 0
    ), f"{torch.max(patch_ids)} > {patch_lengths.shape[-1]} or {torch.min(patch_ids)} < 0"
    return patch_ids
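# Illustrative example (not part of the original source): each token position is mapped to
# the index of the patch that contains it.
#   patch_lengths = torch.tensor([[1, 4, 3]])
#   patch_ids_from_lengths(patch_lengths, seq_len=8)
#   -> tensor([[0, 1, 1, 1, 1, 2, 2, 2]])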
class ByteLatentTransformerArgs(BaseTransformerArgs):
    # Basic model configuration
    seed: int = 42
    vocab_size: int = -1
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    # TODO: What is the purpose of this parameter?
    weight_tying: bool = False

    # Architecture and dimensions
    dim_token: int = 256
    dim_global: int = 512
    dim_local_decoder: int = 512
    dim_local_encoder: int = 512
    n_layers_global: int = 8
    n_layers_local_decoder: int = 8
    n_layers_local_encoder: int = 8

    # Tokenization and patching
    tokenization_mode: str = "bpe"
    patch_size: float | None = None
    patching_mode: str | None = None
    patching_threshold: float | None = None
    patching_threshold_add: float | None = None
    monotonicity: bool = False
    patching_batch_size: int = 1
    patching_device: str = "cuda"
    data_loader_patching: bool = False
    max_patch_length: int | None = None

    # Encoder/Decoder configuration
    tie_local_encoder_decoder_logits: bool = False
    use_local_encoder_transformer: bool = False
    encoder_lm_loss: bool = False
    max_encoder_seq_length: int | None = None
    pad_to_max_length: bool = False
    encoder_enable_byte_ngrams: bool = False
    encoder_enable_byte_group_hash: bool = False
    ngram_vocab_sizes: int | None = None

    # Cross attention configurations
    cross_attn_encoder: bool = False
    cross_attn_decoder: bool = False
    cross_attn_window_encoder: int | None = None
    cross_attn_window_decoder: int | None = None
    cross_attn_k: int | None = None
    cross_attn_nheads: int | None = None
    cross_attn_all_layers_decoder: bool = False
    cross_attn_all_layers_encoder: bool = False
    cross_attn_use_flex_attention: bool = True
    cross_attn_init_by_pooling: bool = False

    # Encoder hash configurations
    encoder_hash_byte_group_size: Any | None = None
    encoder_hash_byte_group_vocab: int = 30000
    encoder_hash_byte_group_nb_functions: int = 3

    # Model behavior and optimization
    log_patch_lengths: bool = False
    non_linearity: str = "swiglu"
    use_rope: bool = True
    recompute_fc1_out: bool = False
    recompute_fc3_out: bool = False
    recompute_attn: bool = True
    custom_bwd: bool = False
    layer_ckpt: str = "all"

    # Initialization and attention
    init_use_gaussian: bool = True
    init_use_depth: str = "current"
    attn_bias_type: str = "causal"
    alpha_depth: str = "disabled"
    max_length: int = 2048

    # Norm configuration
    norm_eps: float = 1e-5
    norm_affine: bool = True
    pre_norm: bool = True
    norm_type: str = "rmsnorm"

    # Additional configurations
    multiple_of: int = 256
    ffn_dim_multiplier: float = 1.0
    dropout: float = 0
    output_size: int = -1

    # Additional parameters from ModelArgs
    architecture: str = "vanilla"
    share_encoder_decoder_emb: bool = True
    global_local_decoder_residual_layer: str | None = None

    tokenize_with_bpe_delimiter: bool = False
    patching_thresholds_str: str | None = None
    tie_local_encoder_decoder: bool = False
    encoder_preds_low_entropy_toks: float | None = None
    encoder_preds_random_toks: float | None = None
    dim_token_emb: int | None = None
    dim_patch_emb: int | None = None

    encoder_ngram_table_dir: str | None = None
    encoder_ngram_to_size_str: str | None = None

    # Model architecture params
    entropy_model_checkpoint_dir: str | None = None
    entropy_model_is_ngram_model: bool = False
    downsampling_by_pooling: str | None = None
    n_heads_global: int = 8
    n_heads_local_decoder: int = 8
    n_heads_local_encoder: int = 8
    n_kv_heads: int | None = None
    n_kv_heads_global: int | None = None
    conv_kernel_size: int | None = None
    local_attention_window_len: int | None = None

    # Performance optimization
    sequence_parallel: bool = False
    loss_parallel: bool = False
    fuse_sequence_parallel: bool = False
    use_fsdp: bool = True
    attn_to_keep: str = "all"

    # RoPE parameters
    rope_theta: float = 10000.0
    rope_use_fp32_in_outer_product: bool = False

    # Parameter mixing
    pm_size: int = 0

    # Logging
    full_logging_n_layers: int = 4

    @model_validator(mode="after")
    def check_hash_byte_sizes(self) -> Self:
        if (
            self.encoder_hash_byte_group_size is not None
            and type(self.encoder_hash_byte_group_size) == str
        ):
            self.encoder_hash_byte_group_size = [
                int(x)
                for x in self.encoder_hash_byte_group_size.split(",")
                if len(x) > 0
            ]
        return self
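# Illustrative note (not part of the original source): the check_hash_byte_sizes validator
# lets encoder_hash_byte_group_size be given either as a list or as a comma-separated
# string; e.g. "3,4,5" is normalized to [3, 4, 5] after validation, while a list is left
# unchanged.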
class GlobalTransformerArgs(ByteLatentTransformerArgs):
    # Global encoder specific dimensions
    dim_token_emb: int | None = None
    dim_patch_emb: int | None = None

    def __post_init__(self):
        # Override base args with global encoder specific values
        self.dim = self.dim_global
        self.n_layers = self.n_layers_global
        self.n_heads = self.n_heads_global
        self.n_kv_heads = self.n_kv_heads_global
        self.local_attention_window_len = None
        self.cross_attn_encoder = False
        self.cross_attn_decoder = False


class LocalDecoderArgs(ByteLatentTransformerArgs):
    # Local decoder specific dimensions
    dim_token_emb: int | None = None
    dim_patch_emb: int | None = None

    def __post_init__(self):
        # Override base args with local decoder specific values
        self.dim = self.dim_local_decoder
        self.n_layers = self.n_layers_local_decoder
        self.n_heads = self.n_heads_local_decoder
        self.cross_attn_encoder = False
        self.cross_attn_init_by_pooling = False
        self.attn_bias_type = "local_block_causal"
def create_global_transformer(args: ByteLatentTransformerArgs) -> GlobalTransformer:
    global_args = args.model_copy(
        deep=True,
        update=dict(
            dim=args.dim_global,
            n_layers=args.n_layers_global,
            n_heads=args.n_heads_global,
            n_kv_heads=args.n_kv_heads_global,
            local_attention_window_len=None,
            dim_token_emb=get_global_dim_patch_emb(args),
            dim_patch_emb=None,
            cross_attn_encoder=False,
            cross_attn_decoder=False,
        ),
    )

    return GlobalTransformer(global_args)
def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
    local_encoder_args = LocalModelArgs(
        # Updated args
        dim=args.dim_local_encoder,
        n_layers=args.n_layers_local_encoder,
        n_heads=args.n_heads_local_encoder,
        dim_token_emb=get_encoder_dim_token_emb(args),
        dim_patch_emb=get_encoder_dim_patch_emb(args),
        cross_attn_encoder=args.cross_attn_encoder,
        cross_attn_decoder=False,
        cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None,
        cross_attn_init_by_pooling=args.cross_attn_init_by_pooling,
        # Defaults
        head_dim=args.head_dim,
        max_seqlen=args.max_encoder_seq_length,
        dropout=args.dropout,
        vocab_size=args.vocab_size + args.pm_size,
        norm_eps=args.norm_eps,
        patch_size=args.patch_size,
        sliding_window=args.local_attention_window_len,
        use_rope=args.use_rope,
        rope_theta=args.rope_theta,
        init_base_std=args.init_base_std,
        init_std_factor=args.init_std_factor,
        n_kv_heads=args.n_kv_heads,
        attn_impl=args.attn_impl,
        attn_bias_type="local_block_causal",
        multiple_of=args.multiple_of,
        ffn_dim_multiplier=args.ffn_dim_multiplier,
        patching_mode=args.patching_mode,
        use_local_encoder_transformer=args.use_local_encoder_transformer,
        downsampling_by_pooling=args.downsampling_by_pooling,
        encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
        cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
        cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
        cross_attn_nheads=args.cross_attn_nheads,
        eos_id=args.eos_id,
    )

    return LocalEncoder(local_encoder_args)
def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
    # First deep copy the original args
    local_decoder_args = LocalModelArgs(
        dim=args.dim_local_decoder,
        n_layers=args.n_layers_local_decoder,
        n_heads=args.n_heads_local_decoder,
        dim_token_emb=get_decoder_dim_token_emb(args),
        dim_patch_emb=args.dim_global,
        cross_attn_encoder=False,
        cross_attn_decoder=args.cross_attn_decoder,
        cross_attn_init_by_pooling=False,  # states are already defined
        cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None,
        # Defaults
        head_dim=args.head_dim,
        max_seqlen=args.max_encoder_seq_length,
        dropout=args.dropout,
        vocab_size=args.vocab_size + args.pm_size,
        norm_eps=args.norm_eps,
        patch_size=args.patch_size,
        sliding_window=args.local_attention_window_len,
        use_rope=args.use_rope,
        rope_theta=args.rope_theta,
        init_base_std=args.init_base_std,
        init_std_factor=args.init_std_factor,
        n_kv_heads=args.n_kv_heads,
        attn_impl=args.attn_impl,
        attn_bias_type="local_block_causal",
        multiple_of=args.multiple_of,
        ffn_dim_multiplier=args.ffn_dim_multiplier,
        patching_mode=args.patching_mode,
        use_local_encoder_transformer=args.use_local_encoder_transformer,
        downsampling_by_pooling=args.downsampling_by_pooling,
        encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
        cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
        cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
        cross_attn_nheads=args.cross_attn_nheads,
        eos_id=args.eos_id,
    )

    return LocalDecoder(local_decoder_args)
class EmbeddingType(Enum):
    HASH_TOK = auto()
    NGRAM = auto()
def init_embeddings(
    args,
    embedding_type: EmbeddingType,
    local_encoder_dim: int,
    encoder_hash_byte_group_size: list = None,
):
    if (
        embedding_type == EmbeddingType.HASH_TOK
        and args.encoder_hash_byte_group_size is None
    ):
        return None
    if embedding_type == EmbeddingType.NGRAM and args.encoder_ngram_to_size_str is None:
        return None

    embeddings = []

    if embedding_type == EmbeddingType.HASH_TOK:
        emb_dim = local_encoder_dim
        encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab
        for _ in range(args.encoder_hash_byte_group_nb_functions):
            for _ in encoder_hash_byte_group_size:
                embeddings.append(
                    nn.Embedding(
                        encoder_hash_byte_group_vocab,
                        emb_dim,
                    )
                )

    elif embedding_type == EmbeddingType.NGRAM:
        encoder_ngram_to_size = parse_ngram_to_size(args.encoder_ngram_to_size_str)
        emb_dim = local_encoder_dim
        OFFSET = 4  # This should be passed as parameter if it's variable
        for ngram_vocab_size in encoder_ngram_to_size.values():
            embeddings.append(nn.Embedding(ngram_vocab_size + OFFSET, emb_dim))

    return nn.ModuleList(embeddings)
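# Illustrative example (not part of the original source): for HASH_TOK embeddings, one
# nn.Embedding table is created per (hash function, byte-group size) pair. With, say,
# encoder_hash_byte_group_nb_functions=3 and encoder_hash_byte_group_size=[3, 4, 5], the
# returned nn.ModuleList holds 3 * 3 = 9 tables, each of shape
# (encoder_hash_byte_group_vocab, local_encoder_dim).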
def compute_hash_embeddings(
    local_encoder_tokens: torch.Tensor,
    local_encoder,
    encoder_hash_tok_embedding: nn.ModuleList,
    encoder_hash_byte_group_nb_functions: int,
    encoder_hash_byte_group_size: list,
    encoder_hash_byte_group_vocab: int,
) -> torch.Tensor | None:
    """
    Compute embeddings using hash token embeddings.

    Args:
        local_encoder_tokens: Input tokens tensor
        local_encoder: Encoder object with tok_embeddings method
        encoder_hash_tok_embedding: ModuleList of hash token embeddings
        encoder_hash_byte_group_nb_functions: Number of hash functions
        encoder_hash_byte_group_size: List of byte group sizes
        encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings

    Returns:
        torch.Tensor | None: Combined embeddings, or None if hash embeddings are disabled
    """
    if encoder_hash_tok_embedding is None:
        return None

    local_encoder_embeds = local_encoder.tok_embeddings(local_encoder_tokens)

    i = 0
    for func_nb in range(encoder_hash_byte_group_nb_functions):
        for byte_group_size in encoder_hash_byte_group_size:
            hash_ids = byte_group_hash_function(
                local_encoder_tokens,
                byte_group_size,
                hash_func_nb=func_nb,
                max_hash=encoder_hash_byte_group_vocab,
            )
            hash_tok_embedding = encoder_hash_tok_embedding[i]
            local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids)
            i += 1

    assert i == len(encoder_hash_tok_embedding)
    return local_encoder_embeds
class ByteLatentTransformer(nn.Module):
    """
    The ByteLatentTransformer (BLT) is a byte-level language model architecture that
    processes byte sequences by dynamically segmenting them into patches. It uses a
    combination of local encoders, global transformers, and local decoders to
    efficiently encode and decode byte sequences, leveraging patch-based processing
    for improved performance and inference efficiency.
    """

    def __init__(self, args: ByteLatentTransformerArgs):
        super().__init__()

        # General configuration
        self.weight_tying = args.weight_tying
        self.patch_size = args.patch_size
        self.patching_mode = args.patching_mode
        self.boe_id, self.bos_id, self.pad_id, self.eos_id = (
            BOE_ID,
            BOS_ID,
            PAD_ID,
            EOS_ID,
        )
        self.downsampling_by_pooling = args.downsampling_by_pooling
        self.patching_threshold = args.patching_threshold
        self.dim = args.dim
        self.init_base_std = args.init_base_std
        self.init_std_factor = InitStdFactor(args.init_std_factor)
        self.max_seqlen = args.max_seqlen

        # Cross attention configuration
        self.cross_attn_encoder = args.cross_attn_encoder
        self.cross_attn_decoder = args.cross_attn_decoder
        self.cross_attn_k = args.cross_attn_k
        self.cross_attn_window_encoder = args.cross_attn_window_encoder
        self.cross_attn_window_decoder = args.cross_attn_window_decoder
        self.cross_attn_use_flex_attention = args.cross_attn_use_flex_attention

        # Encoder hash configuration
        self.encoder_hash_byte_group_size = args.encoder_hash_byte_group_size
        self.encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab
        self.encoder_hash_byte_group_nb_functions = (
            args.encoder_hash_byte_group_nb_functions
        )

        # ByteLatent modules
        self.local_encoder = create_local_encoder(args)
        self.global_transformer = create_global_transformer(args)
        self.local_decoder = create_local_decoder(args)
        self.encoder_hash_tok_embedding = init_embeddings(
            args,
            EmbeddingType.HASH_TOK,
            local_encoder_dim=self.local_encoder.dim,
            encoder_hash_byte_group_size=self.encoder_hash_byte_group_size,
        )
        self.encoder_ngram_embedding = init_embeddings(
            args,
            EmbeddingType.NGRAM,
            local_encoder_dim=self.local_encoder.dim,
            encoder_hash_byte_group_size=None,
        )
        self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)

        # Transformer layers
        self.layers = nn.ModuleList(
            [TransformerBlock(args) for _ in range(args.n_layers)]
        )

        # Encoder ngram embedding tables
        self.encoder_ngram_embedding = None
        if args.encoder_enable_byte_ngrams:
            self.encoder_ngram_embedding = nn.ModuleList()
            assert args.ngram_vocab_sizes is not None
            self.encoder_ngram_to_size = parse_ngram_to_size(
                args.encoder_ngram_to_size_str
            )
            ngram_emb_dim = self.local_encoder.dim
            for ngram_vocab_size in self.encoder_ngram_to_size.values():
                self.encoder_ngram_embedding.append(
                    nn.Embedding(ngram_vocab_size + OFFSET, ngram_emb_dim)
                )

        # Output layer
        assert args.vocab_size > 0, "vocab_size must be greater than 0"
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
        if args.weight_tying:
            self.output.weight = self.tok_embeddings.weight

        # Patcher module
        if not args.data_loader_patching:
            self.patcher = Patcher(
                PatcherArgs(
                    patch_size=args.patch_size,
                    patching_mode=args.patching_mode,
                    patching_threshold=args.patching_threshold,
                    patching_threshold_add=args.patching_threshold_add,
                    monotonicity=args.monotonicity,
                    max_patch_length=args.max_patch_length,
                )
            )
    def forward(
        self,
        tokens: torch.Tensor,
        patch_lengths: Optional[torch.Tensor] = None,
        ngram_ids: Optional[torch.Tensor] = None,
    ):
        # Ensure ngram_ids is either a tensor or None
        assert (
            isinstance(ngram_ids, torch.Tensor) or ngram_ids is None
        ), f"ngram_ids must be a tensor or None, but was: {type(ngram_ids)}"

        bs, N = tokens.shape  # Batch size and sequence length

        # Get megabyte inputs
        nb_boe = int(0 if self.patching_mode != "" else self.patch_size - 1)
        local_encoder_tokens, _, local_decoder_tokens = get_blt_input(
            tokens=tokens,
            enforce_patch_size_multiple=False,
            nb_boe=nb_boe,
            patch_size=self.patch_size,
            boe_id=self.boe_id,
        )

        # Patching
        if patch_lengths is None:
            assert (
                getattr(self, "patcher", None) is not None
            ), "Patcher not defined and no patch_lengths passed."
            patch_lengths, tok_scores = self.patcher.patch(
                local_encoder_tokens,
                include_next_token=True,
                threshold=self.patcher.threshold,
            )
        else:
            if nb_boe > 0:
                patch_lengths[:, 0] += nb_boe

        assert torch.min(patch_lengths) >= 0

        # Generate patch IDs from patch_lengths
        patch_ids = patch_ids_from_lengths(
            patch_lengths, local_encoder_tokens.shape[-1]
        )
        assert torch.max(patch_ids) + 1 <= torch.max(
            (patch_lengths != 0).sum(dim=-1)
        ), f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}"

        cross_attn_mask_enc = None
        # Cross-attention encoder
        if self.cross_attn_encoder:
            cross_attn_mask_enc = cross_attn_mask(
                patch_ids,
                patch_lengths,
                N,
                patches_as_queries=True,
                cross_attn_k=self.cross_attn_k,
                window=self.cross_attn_window_encoder,
                block_mask=self.cross_attn_use_flex_attention,
            )

        # Hashing and embedding
        local_encoder_embeds = compute_hash_embeddings(
            local_encoder_tokens=local_encoder_tokens,
            local_encoder=self.local_encoder,
            encoder_hash_tok_embedding=self.encoder_hash_tok_embedding,
            encoder_hash_byte_group_nb_functions=self.encoder_hash_byte_group_nb_functions,
            encoder_hash_byte_group_size=self.encoder_hash_byte_group_size,
            encoder_hash_byte_group_vocab=self.encoder_hash_byte_group_vocab,
        )

        # N-gram table embeddings
        if self.encoder_ngram_embedding is not None:
            assert ngram_ids is not None, "ngram_ids must be provided"
            if local_encoder_embeds is None:
                local_encoder_embeds = self.local_encoder.tok_embeddings(
                    local_encoder_tokens
                )
            assert len(ngram_ids) == len(
                self.encoder_ngram_embedding
            ), f"ngram_ids.shape[0]={ngram_ids.shape[0]} versus len(encoder_ngram_embedding)={len(self.encoder_ngram_embedding)}, ngram_ids.shape={ngram_ids.shape}"
            for i in range(ngram_ids.shape[0]):
                ngram_embedding = self.encoder_ngram_embedding[i]
                ngram_embeds = ngram_embedding(ngram_ids[i])
                assert (
                    local_encoder_embeds.shape == ngram_embeds.shape
                ), f"Shape mismatch: {local_encoder_embeds.shape} vs {ngram_embeds.shape}, ngram_ids.shape={ngram_ids.shape}"
                local_encoder_embeds = local_encoder_embeds + ngram_embeds

        # Local encoder
        h_cross = None
        (h_encoder, h_cross), cache_encoder = self.local_encoder(
            tokens=local_encoder_tokens,
            embeds=local_encoder_embeds,
            patch_embeds=h_cross if self.cross_attn_encoder else None,
            cross_mask=cross_attn_mask_enc,
            num_patches=patch_lengths.shape[1],
            patch_ids=patch_ids,
        )

        # Downsampling
        if not self.cross_attn_encoder:
            assert (
                patch_ids.shape[1] == h_encoder.shape[1]
            ), f"{patch_ids.shape[1]} != {h_encoder.shape[1]}"
            h = downsample(
                h_encoder,
                patch_lengths.shape[1],
                patch_lengths,
                patch_ids,
                downsampling_by_pooling=self.downsampling_by_pooling,
                patch_size=self.patch_size,
            )
        else:
            # Reshape h_cross
            h = h_cross.view(bs, patch_lengths.shape[1], -1)

        # Global transformer
        global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(self.boe_id)
        rows, cols = torch.where(local_encoder_tokens == self.eos_id)
        eos_patch_ids = patch_ids[rows, cols]
        global_tokens[rows, eos_patch_ids] = self.eos_id

        h, _ = self.global_transformer(
            embeds=h,
            tokens=global_tokens,
        )

        # Unpatching
        dec_embeds = h_encoder[:, nb_boe : nb_boe + N, :]

        # Generate decoder patch IDs
        decoder_patch_ids = decoder_patch_ids_from_lengths(
            patch_lengths, nb_boe, local_decoder_tokens.shape[-1]
        )
        assert (
            torch.max(decoder_patch_ids) + 1 <= h.shape[1]
        ), f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}"
        assert (
            decoder_patch_ids.shape[1] == dec_embeds.shape[1]
        ), f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}"

        # Cross-attention decoder
        if not self.cross_attn_decoder:
            h = torch.gather(
                h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])
            )
            cross_attn_mask_dec = None
            assert local_decoder_tokens.shape == h.shape[:-1]
        else:
            cross_attn_mask_dec = cross_attn_mask(
                decoder_patch_ids,
                patch_lengths,
                N,
                patches_as_queries=False,
                cross_attn_k=self.cross_attn_k,
                window=self.cross_attn_window_decoder,
                block_mask=self.cross_attn_use_flex_attention,
            )

        # Local decoder
        output, _ = self.local_decoder(
            embeds=dec_embeds,
            patch_embeds=h,
            tokens=local_decoder_tokens,
            cross_mask=cross_attn_mask_dec,
        )
        return output
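    # Illustrative summary of the forward pass above (not part of the original source):
    #   1. Byte tokens are embedded by the local encoder, optionally augmented with
    #      hash and n-gram embeddings.
    #   2. Byte states are pooled (or cross-attended) into one representation per patch.
    #   3. The global transformer runs over the patch representations.
    #   4. Patch states are routed back to byte positions (gather or cross-attention) and
    #      the local decoder produces the per-byte output; this is expected to be logits
    #      of shape (batch_size, seq_len, vocab_size), though the exact shape depends on
    #      the LocalDecoder implementation.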
    def reset_parameters(self, init_std=None):
        # Either use the provided std or default to dim ** -0.5
        init_std = init_std or (self.dim ** (-0.5))
        nn.init.trunc_normal_(
            self.tok_embeddings.weight,
            mean=0.0,
            std=init_std,
            a=-3 * init_std,
            b=3 * init_std,
        )
        if not self.weight_tying:
            nn.init.trunc_normal_(
                self.output.weight,
                mean=0.0,
                std=init_std,
                a=-3 * init_std,
                b=3 * init_std,
            )

    def init_weights(self):
        self.reset_parameters()
        self.init_base_std = self.init_base_std or (self.dim ** (-0.5))
        for depth, layer in enumerate(self.layers):
            factor = {
                InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
                InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
                InitStdFactor.DIM_RATIO: self.dim / 4096,
                InitStdFactor.DISABLED: 1.0,
            }[self.init_std_factor]

            layer.init_weights(self.init_base_std, factor)

        self.local_decoder.init_weights(self.init_base_std)
        self.global_transformer.init_weights(self.init_base_std)
        self.local_encoder.init_weights(self.init_base_std)

        # Hash embeddings are optional (init_embeddings returns None when the hash byte
        # group sizes are not configured), so skip their init in that case.
        if self.encoder_hash_tok_embedding is not None:
            for emb in self.encoder_hash_tok_embedding:
                nn.init.trunc_normal_(
                    emb.weight,
                    mean=0.0,
                    std=self.init_base_std,
                    a=-3 * self.init_base_std,
                    b=3 * self.init_base_std,
                )