from enum import Enum, auto
from typing import Any, List, Optional, Tuple, Union

from pydantic import BaseModel, ConfigDict, model_validator
from typing_extensions import Self

from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID


class InitStdFactor(str, Enum):
    DISABLED = "disabled"  # Init std is divided by 1.0
    GLOBAL_DEPTH = "global_depth"  # Init std is divided by sqrt(2*n_layers)
    CURRENT_DEPTH = "current_depth"  # Init std is divided by sqrt(2*depth)
    DIM_RATIO = "dim_ratio"  # Init std is divided by model_dim/4096
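

# Illustrative sketch of how the divisors documented above might be applied when
# scaling an init std; `args` (carrying init_base_std / init_std_factor / n_layers /
# dim) and the layer `depth` are hypothetical names, not defined in this module:
#
#   import math
#   divisor = {
#       InitStdFactor.DISABLED: 1.0,
#       InitStdFactor.GLOBAL_DEPTH: math.sqrt(2 * args.n_layers),
#       InitStdFactor.CURRENT_DEPTH: math.sqrt(2 * depth),
#       InitStdFactor.DIM_RATIO: args.dim / 4096,
#   }[args.init_std_factor]
#   std = args.init_base_std / divisor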


class BaseTransformerArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    dim: int = 512
    n_layers: int = 8
    head_dim: int | None = None
    n_heads: int | None = None
    n_kv_heads: int | None = None

    ffn_dim_multiplier: float | None = None

    multiple_of: int = 256

    norm_eps: float = 1e-5

    rope_theta: float = 10000.0
    rope_use_fp32_in_outer_product: bool = False

    init_base_std: float | None = None
    init_std_factor: InitStdFactor = InitStdFactor.DISABLED

    max_seqlen: int = 1024

    attn_impl: str | None = "sdpa"
    attn_bias_type: str | None = None
    # Special token config
    eos_id: int | None = EOS_ID
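

# Illustrative sketch: these args are ordinary pydantic v2 models, so they can be built
# from keyword arguments or validated from a plain dict, and `extra="forbid"` turns
# unknown keys into a ValidationError. The field values below are arbitrary examples:
#
#   args = BaseTransformerArgs(dim=768, n_layers=12, n_heads=12)
#   same = BaseTransformerArgs.model_validate({"dim": 768, "n_layers": 12, "n_heads": 12})
#   BaseTransformerArgs(dim=768, not_a_field=1)  # raises pydantic.ValidationError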


class ByteLatentTransformerArgs(BaseTransformerArgs):
    # Basic model configuration
    seed: int = 42
    vocab_size: int = -1
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    # TODO: What is the purpose of this parameter?
    weight_tying: bool = False
    patch_in_forward: bool = False

    # Architecture and dimensions
    dim_token: int | None = None
    dim_global: int = 512
    dim_local_decoder: int = 512
    dim_local_encoder: int = 512
    n_layers_global: int = 8
    n_layers_local_decoder: int = 8
    n_layers_local_encoder: int = 8

    # Tokenization and patching
    patch_size: float | None = None
    patching_mode: str | None = None
    patching_threshold: float | None = None
    patching_threshold_add: float | None = None
    monotonicity: bool = False
    patching_batch_size: int = 1
    patching_device: str = "cuda"
    max_patch_length: int | None = None

    # Encoder/Decoder configuration
    tie_local_encoder_decoder_logits: bool = False
    use_local_encoder_transformer: bool = False
    encoder_lm_loss: bool = False
    max_encoder_seq_length: int | None = None
    pad_to_max_length: bool = False
    encoder_enable_byte_ngrams: bool = False
    encoder_enable_byte_group_hash: bool = False
    ngram_vocab_sizes: int | None = None

    # Cross attention configurations
    cross_attn_encoder: bool = False
    cross_attn_decoder: bool = False
    cross_attn_window_encoder: int | None = None
    cross_attn_window_decoder: int | None = None
    cross_attn_k: int | None = None
    cross_attn_nheads: int | None = None
    cross_attn_all_layers_decoder: bool = False
    cross_attn_all_layers_encoder: bool = False
    cross_attn_use_flex_attention: bool = True
    cross_attn_init_by_pooling: bool = False

    # Encoder hash configurations
    encoder_hash_byte_group_size: Any | None = None
    encoder_hash_byte_group_vocab: int = 30000
    encoder_hash_byte_group_nb_functions: int = 3

    # Model behavior and optimization
    log_patch_lengths: bool = False
    non_linearity: str = "swiglu"
    use_rope: bool = True
    recompute_fc1_out: bool = False
    recompute_fc3_out: bool = False
    recompute_attn: bool = True
    custom_bwd: bool = False
    layer_ckpt: str = "all"

    # Initialization and attention
    init_use_gaussian: bool = True
    init_use_depth: str = "current"
    attn_bias_type: str = "causal"
    alpha_depth: str = "disabled"
    max_length: int = 2048

    # Norm configuration
    norm_eps: float = 1e-5
    norm_affine: bool = True
    pre_norm: bool = True
    norm_type: str = "rmsnorm"

    # Additional configurations
    multiple_of: int = 256
    ffn_dim_multiplier: float = 1.0
    dropout: float = 0
    output_size: int = -1

    # Additional parameters from ModelArgs
    architecture: str = "vanilla"
    share_encoder_decoder_emb: bool = True
    global_local_decoder_residual_layer: str | None = None

    tokenize_with_bpe_delimiter: bool = False
    patching_thresholds_str: str | None = None
    tie_local_encoder_decoder: bool = False
    encoder_preds_low_entropy_toks: float | None = None
    encoder_preds_random_toks: float | None = None
    dim_token_emb: int | None = None
    dim_patch_emb: int | None = None

    encoder_ngram_table_dir: str | None = None
    encoder_ngram_to_size_str: str | None = None

    # Model architecture params
    entropy_model_checkpoint_dir: str | None = None
    entropy_model_is_ngram_model: bool = False
    downsampling_by_pooling: str | None = None
    n_heads_global: int = 8
    n_heads_local_decoder: int = 8
    n_heads_local_encoder: int = 8
    n_kv_heads: int | None = None
    n_kv_heads_global: int | None = None
    conv_kernel_size: int | None = None
    local_attention_window_len: int | None = None

    # Performance optimization
    sequence_parallel: bool = False
    loss_parallel: bool = False
    fuse_sequence_parallel: bool = False
    use_fsdp: bool = True
    attn_to_keep: str = "all"

    # Parameter mixing
    pm_size: int = 0

    # Logging
    full_logging_n_layers: int = 4

    @model_validator(mode="after")
    def check_hash_byte_sizes(self) -> Self:
        if (
            self.encoder_hash_byte_group_size is not None
            and type(self.encoder_hash_byte_group_size) == str
        ):
            self.encoder_hash_byte_group_size = [
                int(x)
                for x in self.encoder_hash_byte_group_size.split(",")
                if len(x) > 0
            ]
        return self
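

# Illustrative sketch: the model_validator above normalizes the hash byte-group sizes,
# so a comma-separated string and an explicit list should validate to the same value.
# The size values used here are arbitrary examples:
#
#   a = ByteLatentTransformerArgs(encoder_hash_byte_group_size="3,4,5")
#   b = ByteLatentTransformerArgs(encoder_hash_byte_group_size=[3, 4, 5])
#   assert a.encoder_hash_byte_group_size == b.encoder_hash_byte_group_size == [3, 4, 5]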


class GlobalTransformerArgs(ByteLatentTransformerArgs):
    # Global encoder specific dimensions
    dim_token_emb: int | None = None
    dim_patch_emb: int | None = None

    def __post_init__(self):
        # Override base args with global encoder specific values
        self.dim = self.dim_global
        self.n_layers = self.n_layers_global
        self.n_heads = self.n_heads_global
        self.n_kv_heads = self.n_kv_heads_global
        self.local_attention_window_len = None
        self.cross_attn_encoder = False
        self.cross_attn_decoder = False
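

# Note: pydantic's BaseModel does not invoke __post_init__ automatically (that is a
# dataclasses hook), so the overrides in GlobalTransformerArgs above and
# LocalDecoderArgs below only take effect if called explicitly.
# Illustrative sketch with arbitrary values:
#
#   g = GlobalTransformerArgs(dim_global=1024, n_layers_global=12, n_heads_global=16)
#   g.__post_init__()  # copies the *_global values onto dim / n_layers / n_heads / n_kv_heads
#   assert g.dim == 1024 and g.n_layers == 12 and g.n_heads == 16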


class LocalDecoderArgs(ByteLatentTransformerArgs):
    # Local decoder specific dimensions
    dim_token_emb: int | None = None
    dim_patch_emb: int | None = None

    def __post_init__(self):
        # Override base args with local decoder specific values
        self.dim = self.dim_local_decoder
        self.n_layers = self.n_layers_local_decoder
        self.n_heads = self.n_heads_local_decoder
        self.cross_attn_encoder = False
        self.cross_attn_init_by_pooling = False
        self.attn_bias_type = "local_block_causal"


class LocalModelArgs(BaseTransformerArgs):
    model_config = ConfigDict(extra="forbid")
    # Override defaults
    attn_impl: str | None = "xformers"
    attn_bias_type: str | None = "local_block_causal"

    # Local encoder specific dimensions
    dropout: float
    vocab_size: int
    patch_size: float
    sliding_window: int | None
    use_rope: bool
    cross_attn_encoder: bool | None
    cross_attn_decoder: bool | None
    cross_attn_k: int | None
    cross_attn_init_by_pooling: bool
    patching_mode: str
    use_local_encoder_transformer: bool
    downsampling_by_pooling: str | None
    encoder_hash_byte_group_size: Any | None = None
    cross_attn_all_layers_encoder: bool = False
    cross_attn_all_layers_decoder: bool = False
    cross_attn_nheads: int | None

    dim_token_emb: int
    dim_patch_emb: int | None
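

# Illustrative note: unlike the classes above, most LocalModelArgs fields are declared
# without defaults, so pydantic v2 treats them as required (including the Optional-typed
# ones); constructing it without those keyword arguments fails validation:
#
#   LocalModelArgs()  # raises pydantic.ValidationError listing the missing fields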


class LMTransformerArgs(BaseTransformerArgs):
    seed: int = 42

    vocab_size: int = -1
    weight_tying: bool = False

    sliding_window: int | None = None