mirror of https://github.com/facebookresearch/blt.git (synced 2025-09-07 13:09:09 +00:00)
move args out

commit c2108e7256 (parent 4f86b6e7ab)
3 changed files with 351 additions and 414 deletions
blt_args.py (new file, 248 lines)
@@ -0,0 +1,248 @@

from enum import Enum, auto
from typing import Any, List, Optional, Tuple, Union

from pydantic import BaseModel, ConfigDict, model_validator
from typing_extensions import Self

from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID


class InitStdFactor(str, Enum):
    DISABLED = "disabled"  # Init std is divided by 1.0
    GLOBAL_DEPTH = "global_depth"  # Init std is divided by sqrt(2*n_layers)
    CURRENT_DEPTH = "current_depth"  # Init std is divided by sqrt(2*depth)
    DIM_RATIO = "dim_ratio"  # Init std is divided by model_dim/4096


class BaseTransformerArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    dim: int = 512
    n_layers: int = 8
    head_dim: int | None = None
    n_heads: int | None = None
    n_kv_heads: int | None = None

    ffn_dim_multiplier: float | None = None

    multiple_of: int = 256

    norm_eps: float = 1e-5

    rope_theta: float = 10000.0
    rope_use_fp32_in_outer_product: bool = False

    init_base_std: float | None = None
    init_std_factor: InitStdFactor = InitStdFactor.DISABLED

    max_seqlen: int = 1024

    attn_impl: str | None = "sdpa"
    attn_bias_type: str | None = None
    # Special token config
    eos_id: int | None = EOS_ID


class ByteLatentTransformerArgs(BaseTransformerArgs):
    # Basic model configuration
    seed: int = 42
    vocab_size: int = -1
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    # TODO: What is the purpose of this parameter?
    weight_tying: bool = False
    patch_in_forward: bool = False

    # Architecture and dimensions
    dim_token: int | None = None
    dim_global: int = 512
    dim_local_decoder: int = 512
    dim_local_encoder: int = 512
    n_layers_global: int = 8
    n_layers_local_decoder: int = 8
    n_layers_local_encoder: int = 8

    # Tokenization and patching
    patch_size: float | None = None
    patching_mode: str | None = None
    patching_threshold: float | None = None
    patching_threshold_add: float | None = None
    monotonicity: bool = False
    patching_batch_size: int = 1
    patching_device: str = "cuda"
    max_patch_length: int | None = None

    # Encoder/Decoder configuration
    tie_local_encoder_decoder_logits: bool = False
    use_local_encoder_transformer: bool = False
    encoder_lm_loss: bool = False
    max_encoder_seq_length: int | None = None
    pad_to_max_length: bool = False
    encoder_enable_byte_ngrams: bool = False
    encoder_enable_byte_group_hash: bool = False
    ngram_vocab_sizes: int | None = None

    # Cross attention configurations
    cross_attn_encoder: bool = False
    cross_attn_decoder: bool = False
    cross_attn_window_encoder: int | None = None
    cross_attn_window_decoder: int | None = None
    cross_attn_k: int | None = None
    cross_attn_nheads: int | None = None
    cross_attn_all_layers_decoder: bool = False
    cross_attn_all_layers_encoder: bool = False
    cross_attn_use_flex_attention: bool = True
    cross_attn_init_by_pooling: bool = False

    # Encoder hash configurations
    encoder_hash_byte_group_size: Any | None = None
    encoder_hash_byte_group_vocab: int = 30000
    encoder_hash_byte_group_nb_functions: int = 3

    # Model behavior and optimization
    log_patch_lengths: bool = False
    non_linearity: str = "swiglu"
    use_rope: bool = True
    recompute_fc1_out: bool = False
    recompute_fc3_out: bool = False
    recompute_attn: bool = True
    custom_bwd: bool = False
    layer_ckpt: str = "all"

    # Initialization and attention
    init_use_gaussian: bool = True
    init_use_depth: str = "current"
    attn_bias_type: str = "causal"
    alpha_depth: str = "disabled"
    max_length: int = 2048

    # Norm configuration
    norm_eps: float = 1e-5
    norm_affine: bool = True
    pre_norm: bool = True
    norm_type: str = "rmsnorm"

    # Additional configurations
    multiple_of: int = 256
    ffn_dim_multiplier: float = 1.0
    dropout: float = 0
    output_size: int = -1

    # Additional parameters from ModelArgs
    architecture: str = "vanilla"
    share_encoder_decoder_emb: bool = True
    global_local_decoder_residual_layer: str | None = None

    tokenize_with_bpe_delimiter: bool = False
    patching_thresholds_str: str | None = None
    tie_local_encoder_decoder: bool = False
    encoder_preds_low_entropy_toks: float | None = None
    encoder_preds_random_toks: float | None = None
    dim_token_emb: int | None = None
    dim_patch_emb: int | None = None

    encoder_ngram_table_dir: str | None = None
    encoder_ngram_to_size_str: str | None = None

    # Model architecture params
    entropy_model_checkpoint_dir: str | None = None
    entropy_model_is_ngram_model: bool = False
    downsampling_by_pooling: str | None = None
    n_heads_global: int = 8
    n_heads_local_decoder: int = 8
    n_heads_local_encoder: int = 8
    n_kv_heads: int | None = None
    n_kv_heads_global: int | None = None
    conv_kernel_size: int | None = None
    local_attention_window_len: int | None = None

    # Performance optimization
    sequence_parallel: bool = False
    loss_parallel: bool = False
    fuse_sequence_parallel: bool = False
    use_fsdp: bool = True
    attn_to_keep: str = "all"

    # Parameter mixing
    pm_size: int = 0

    # Logging
    full_logging_n_layers: int = 4

    @model_validator(mode="after")
    def check_hash_byte_sizes(self) -> Self:
        if (
            self.encoder_hash_byte_group_size is not None
            and type(self.encoder_hash_byte_group_size) == str
        ):
            self.encoder_hash_byte_group_size = [
                int(x)
                for x in self.encoder_hash_byte_group_size.split(",")
                if len(x) > 0
            ]
        return self


class GlobalTransformerArgs(ByteLatentTransformerArgs):
    # Global encoder specific dimensions
    dim_token_emb: int | None = None
    dim_patch_emb: int | None = None

    def __post_init__(self):
        # Override base args with global encoder specific values
        self.dim = self.dim_global
        self.n_layers = self.n_layers_global
        self.n_heads = self.n_heads_global
        self.n_kv_heads = self.n_kv_heads_global
        self.local_attention_window_len = None
        self.cross_attn_encoder = False
        self.cross_attn_decoder = False


class LocalDecoderArgs(ByteLatentTransformerArgs):
    # Local decoder specific dimensions
    dim_token_emb: int | None = None
    dim_patch_emb: int | None = None

    def __post_init__(self):
        # Override base args with local decoder specific values
        self.dim = self.dim_local_decoder
        self.n_layers = self.n_layers_local_decoder
        self.n_heads = self.n_heads_local_decoder
        self.cross_attn_encoder = False
        self.cross_attn_init_by_pooling = False
        self.attn_bias_type = "local_block_causal"


class LocalModelArgs(BaseTransformerArgs):
    model_config = ConfigDict(extra="forbid")
    # Override defaults
    attn_impl: str | None = "xformers"
    attn_bias_type: str | None = "local_block_causal"

    # Local encoder specific dimensions
    dropout: float
    vocab_size: int
    patch_size: float
    sliding_window: int | None
    use_rope: bool
    cross_attn_encoder: bool | None
    cross_attn_decoder: bool | None
    cross_attn_k: int | None
    cross_attn_init_by_pooling: bool
    patching_mode: str
    use_local_encoder_transformer: bool
    downsampling_by_pooling: str | None
    encoder_hash_byte_group_size: Any | None = None
    cross_attn_all_layers_encoder: bool = False
    cross_attn_all_layers_decoder: bool = False
    cross_attn_nheads: int | None

    dim_token_emb: int
    dim_patch_emb: int | None


class LMTransformerArgs(BaseTransformerArgs):
    seed: int = 42

    vocab_size: int = -1
    weight_tying: bool = False

    sliding_window: int | None = None
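Note on the validator above: check_hash_byte_sizes lets encoder_hash_byte_group_size arrive either as a list or as a comma-separated string, and normalizes the string form into a list of ints. Below is a minimal standalone sketch of the same pydantic pattern; the field name is taken from the diff, but the model name and the example values are illustrative and not part of this commit.

from typing import Any

from pydantic import BaseModel, model_validator


class HashGroupArgsSketch(BaseModel):
    # Stand-in for the field validated on ByteLatentTransformerArgs above.
    encoder_hash_byte_group_size: Any | None = None

    @model_validator(mode="after")
    def check_hash_byte_sizes(self):
        # "3,4,5" becomes [3, 4, 5]; lists and None pass through untouched.
        if isinstance(self.encoder_hash_byte_group_size, str):
            self.encoder_hash_byte_group_size = [
                int(x)
                for x in self.encoder_hash_byte_group_size.split(",")
                if len(x) > 0
            ]
        return self


print(HashGroupArgsSketch(encoder_hash_byte_group_size="3,4,5").encoder_hash_byte_group_size)
# -> [3, 4, 5]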
blt_one_file.py (508 lines changed)
@@ -34,6 +34,17 @@ from bytelatent.distributed import get_local_rank
 
 logger = logging.getLogger()
 
+from blt_args import (
+    BaseTransformerArgs,
+    ByteLatentTransformerArgs,
+    GlobalTransformerArgs,
+    LocalDecoderArgs,
+    LocalModelArgs,
+    LMTransformerArgs,
+)
+
+
 
 if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
     flex_attention_comp = torch.compile(flex_attention)
@@ -71,65 +82,6 @@ def patch_reduce(h, max_num_patches, reduction, patch_ids):
     return reduced_embs
 
 
-def concat_downsample(h, patch_lengths, patch_size):
-    # The assumption in this function is that seq_len = patch_size * num_patches.
-    bs, seq_len, emb_dim = h.shape
-    patch_end_ids = torch.cumsum(patch_lengths, dim=1)
-    patch_ids = patch_end_ids.unsqueeze(-1) - torch.arange(patch_size, 0, -1).to(
-        patch_end_ids.device
-    )
-    # Is clamp ok here?
-    patch_ids = patch_ids.clamp(min=0).unsqueeze(-1).expand(-1, -1, -1, h.shape[-1])
-    patch_ids = patch_ids.view(bs, -1, emb_dim)
-    # after gather h.shape = [batch_size, seq_len, dim]
-    h = torch.gather(h, 1, patch_ids)
-    h = h.reshape(bs, patch_lengths.shape[1], patch_size * h.size(-1))
-    return h
-
-
-def pooling_downsample(h, max_num_patches, pooling_mode, patch_ids):
-    cat = []
-    if "avg" in pooling_mode or "mean" in pooling_mode:
-        cat.append(patch_reduce(h, max_num_patches, "mean", patch_ids))
-    if "min" in pooling_mode:
-        cat.append(patch_reduce(h, max_num_patches, "amin", patch_ids))
-    if "max" in pooling_mode:
-        cat.append(patch_reduce(h, max_num_patches, "amax", patch_ids))
-    assert len(cat) > 0
-    h = torch.cat(cat, dim=-1)
-    return h
-
-
-def downsample(
-    h,
-    num_patches,
-    patch_lengths=None,
-    patch_ids=None,
-    downsampling_by_pooling=None,
-    patch_size=4,
-):
-    """
-    Downsampling:
-        a. concatenating embeddings in the patch
-            Note: with dynamic patching, patch the last patch_size tokens.
-        b. pooling embeddings in the patch
-    """
-    # input: h.shape = [batch_size, seq_len, dim]
-    # input: pool h.shape = [batch_size, seq_len / patch_size, dim]
-    # if we don't use the cros_attn, we pool so that we convert bytes rep to patch rep
-    if downsampling_by_pooling is not None and len(downsampling_by_pooling) > 0:
-        # By pooling
-        max_num_patches = num_patches
-        assert patch_ids is not None
-        h = pooling_downsample(h, max_num_patches, downsampling_by_pooling, patch_ids)
-    else:
-        # TODO: remove this condition
-        # By concatenating (fixed lengths patching)
-        assert patch_lengths is not None
-        h = concat_downsample(h, patch_lengths, patch_size)
-    return h
-
-
 def causal_mask(b, h, q_idx, kv_idx):
     return q_idx >= kv_idx
@@ -215,35 +167,6 @@ class InitStdFactor(str, Enum):
     CURRENT_DEPTH = "current_depth"  # Init std is divided by sqrt(2*depth)
     DIM_RATIO = "dim_ratio"  # Init std is divided by model_dim/4096
 
 
-class BaseTransformerArgs(BaseModel):
-    model_config = ConfigDict(extra="forbid")
-    dim: int = 512
-    n_layers: int = 8
-    head_dim: int | None = None
-    n_heads: int | None = None
-    n_kv_heads: int | None = None
-
-    ffn_dim_multiplier: float | None = None
-
-    multiple_of: int = 256
-
-    norm_eps: float = 1e-5
-
-    rope_theta: float = 10000.0
-    rope_use_fp32_in_outer_product: bool = False
-
-    init_base_std: float | None = None
-    init_std_factor: InitStdFactor = InitStdFactor.DISABLED
-
-    max_seqlen: int = 1024
-
-    attn_impl: str | None = "sdpa"
-    attn_bias_type: str | None = None
-    # Special token config
-    eos_id: int | None = EOS_ID
-
-
 def cross_entropy(pred, target, **kwargs):
     return F.nll_loss(
         F.log_softmax(pred.flatten(end_dim=-2).float(), -1),
@@ -730,15 +653,6 @@ class BaseTransformer(nn.Module, SequenceModelWithOutput):
             layer.init_weights(self.init_base_std, factor)
 
 
-class LMTransformerArgs(BaseTransformerArgs):
-    seed: int = 42
-
-    vocab_size: int = -1
-    weight_tying: bool = False
-
-    sliding_window: int | None = None
-
-
 class LMTransformer(
     BaseTransformer,
     PyTorchModelHubMixin,
@@ -1401,177 +1315,6 @@ def patch_ids_from_lengths(patch_lengths, seq_len):
     return patch_ids
 
 
-class ByteLatentTransformerArgs(BaseTransformerArgs):
-    # Basic model configuration
-    seed: int = 42
-    vocab_size: int = -1
-    dim: int = 512
-    n_layers: int = 8
-    n_heads: int = 8
-    # TODO: What is the purpose of this parameter?
-    weight_tying: bool = False
-    patch_in_forward: bool = False
-
-    # Architecture and dimensions
-    dim_token: int | None = None
-    dim_global: int = 512
-    dim_local_decoder: int = 512
-    dim_local_encoder: int = 512
-    n_layers_global: int = 8
-    n_layers_local_decoder: int = 8
-    n_layers_local_encoder: int = 8
-
-    # Tokenization and patching
-    patch_size: float | None = None
-    patching_mode: str | None = None
-    patching_threshold: float | None = None
-    patching_threshold_add: float | None = None
-    monotonicity: bool = False
-    patching_batch_size: int = 1
-    patching_device: str = "cuda"
-    max_patch_length: int | None = None
-
-    # Encoder/Decoder configuration
-    tie_local_encoder_decoder_logits: bool = False
-    use_local_encoder_transformer: bool = False
-    encoder_lm_loss: bool = False
-    max_encoder_seq_length: int | None = None
-    pad_to_max_length: bool = False
-    encoder_enable_byte_ngrams: bool = False
-    encoder_enable_byte_group_hash: bool = False
-    ngram_vocab_sizes: int | None = None
-
-    # Cross attention configurations
-    cross_attn_encoder: bool = False
-    cross_attn_decoder: bool = False
-    cross_attn_window_encoder: int | None = None
-    cross_attn_window_decoder: int | None = None
-    cross_attn_k: int | None = None
-    cross_attn_nheads: int | None = None
-    cross_attn_all_layers_decoder: bool = False
-    cross_attn_all_layers_encoder: bool = False
-    cross_attn_use_flex_attention: bool = True
-    cross_attn_init_by_pooling: bool = False
-
-    # Encoder hash configurations
-    encoder_hash_byte_group_size: Any | None = None
-    encoder_hash_byte_group_vocab: int = 30000
-    encoder_hash_byte_group_nb_functions: int = 3
-
-    # Model behavior and optimization
-    log_patch_lengths: bool = False
-    non_linearity: str = "swiglu"
-    use_rope: bool = True
-    recompute_fc1_out: bool = False
-    recompute_fc3_out: bool = False
-    recompute_attn: bool = True
-    custom_bwd: bool = False
-    layer_ckpt: str = "all"
-
-    # Initialization and attention
-    init_use_gaussian: bool = True
-    init_use_depth: str = "current"
-    attn_bias_type: str = "causal"
-    alpha_depth: str = "disabled"
-    max_length: int = 2048
-
-    # Norm configuration
-    norm_eps: float = 1e-5
-    norm_affine: bool = True
-    pre_norm: bool = True
-    norm_type: str = "rmsnorm"
-
-    # Additional configurations
-    multiple_of: int = 256
-    ffn_dim_multiplier: float = 1.0
-    dropout: float = 0
-    output_size: int = -1
-
-    # Additional parameters from ModelArgs
-    architecture: str = "vanilla"
-    share_encoder_decoder_emb: bool = True
-    global_local_decoder_residual_layer: str | None = None
-
-    tokenize_with_bpe_delimiter: bool = False
-    patching_thresholds_str: str | None = None
-    tie_local_encoder_decoder: bool = False
-    encoder_preds_low_entropy_toks: float | None = None
-    encoder_preds_random_toks: float | None = None
-    dim_token_emb: int | None = None
-    dim_patch_emb: int | None = None
-
-    encoder_ngram_table_dir: str | None = None
-    encoder_ngram_to_size_str: str | None = None
-
-    # Model architecture params
-    entropy_model_checkpoint_dir: str | None = None
-    entropy_model_is_ngram_model: bool = False
-    downsampling_by_pooling: str | None = None
-    n_heads_global: int = 8
-    n_heads_local_decoder: int = 8
-    n_heads_local_encoder: int = 8
-    n_kv_heads: int | None = None
-    n_kv_heads_global: int | None = None
-    conv_kernel_size: int | None = None
-    local_attention_window_len: int | None = None
-
-    # Performance optimization
-    sequence_parallel: bool = False
-    loss_parallel: bool = False
-    fuse_sequence_parallel: bool = False
-    use_fsdp: bool = True
-    attn_to_keep: str = "all"
-
-    # Parameter mixing
-    pm_size: int = 0
-
-    # Logging
-    full_logging_n_layers: int = 4
-
-    @model_validator(mode="after")
-    def check_hash_byte_sizes(self) -> Self:
-        if (
-            self.encoder_hash_byte_group_size is not None
-            and type(self.encoder_hash_byte_group_size) == str
-        ):
-            self.encoder_hash_byte_group_size = [
-                int(x)
-                for x in self.encoder_hash_byte_group_size.split(",")
-                if len(x) > 0
-            ]
-        return self
-
-
-class GlobalTransformerArgs(ByteLatentTransformerArgs):
-    # Global encoder specific dimensions
-    dim_token_emb: int | None = None
-    dim_patch_emb: int | None = None
-
-    def __post_init__(self):
-        # Override base args with global encoder specific values
-        self.dim = self.dim_global
-        self.n_layers = self.n_layers_global
-        self.n_heads = self.n_heads_global
-        self.n_kv_heads = self.n_kv_heads_global
-        self.local_attention_window_len = None
-        self.cross_attn_encoder = False
-        self.cross_attn_decoder = False
-
-
-class LocalDecoderArgs(ByteLatentTransformerArgs):
-    # Local decoder specific dimensions
-    dim_token_emb: int | None = None
-    dim_patch_emb: int | None = None
-
-    def __post_init__(self):
-        # Override base args with local decoder specific values
-        self.dim = self.dim_local_decoder
-        self.n_layers = self.n_layers_local_decoder
-        self.n_heads = self.n_heads_local_decoder
-        self.cross_attn_encoder = False
-        self.cross_attn_init_by_pooling = False
-        self.attn_bias_type = "local_block_causal"
-
-
 def create_global_transformer(args: ByteLatentTransformerArgs):
     global_args = args.model_copy(

@@ -1592,34 +1335,6 @@ def create_global_transformer(args: ByteLatentTransformerArgs):
     return GlobalTransformer(global_args)
 
 
-class LocalModelArgs(BaseTransformerArgs):
-    model_config = ConfigDict(extra="forbid")
-    # Override defaults
-    attn_impl: str | None = "xformers"
-    attn_bias_type: str | None = "local_block_causal"
-
-    # Local encoder specific dimensions
-    dropout: float
-    vocab_size: int
-    patch_size: float
-    sliding_window: int | None
-    use_rope: bool
-    cross_attn_encoder: bool | None
-    cross_attn_decoder: bool | None
-    cross_attn_k: int | None
-    cross_attn_init_by_pooling: bool
-    patching_mode: str
-    use_local_encoder_transformer: bool
-    downsampling_by_pooling: str | None
-    encoder_hash_byte_group_size: Any | None = None
-    cross_attn_all_layers_encoder: bool = False
-    cross_attn_all_layers_decoder: bool = False
-    cross_attn_nheads: int | None
-
-    dim_token_emb: int
-    dim_patch_emb: int | None
-
-
 class LocalModelBase(nn.Module):
     def __init__(self, args: LocalModelArgs):
         super().__init__()
@@ -1850,13 +1565,14 @@ class LocalEncoder(LocalModelBase):
     ):
        # apply pooling and project
        if self.cross_attn_init_by_pooling and patch_embeds is None:
-            patch_embeds = downsample(
-                h,
-                num_patches,
-                patch_ids=patch_ids,
-                downsampling_by_pooling=self.downsampling_by_pooling,
-                patch_size=self.patch_size,
-            )
+            # patch_embeds = downsample(
+            #     h,
+            #     num_patches,
+            #     patch_ids=patch_ids,
+            #     downsampling_by_pooling=self.downsampling_by_pooling,
+            #     patch_size=self.patch_size,
+            # )
+            patch_embeds = patch_reduce(h, num_patches, "amax", patch_ids)
            if self.patch_embedding_projection is not None:
                patch_embeds = self.patch_embedding_projection(patch_embeds)
                patch_embeds = patch_embeds.reshape(
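For context on the new patch_reduce call above: the body of patch_reduce is outside this hunk, so the following is only a rough standalone sketch of an "amax" pooling over patch_ids, assuming the function scatter-reduces byte embeddings into per-patch slots the way the removed pooling_downsample wrapper suggests; the _sketch name and the scatter_reduce call are assumptions, not code from the repository.

import torch


def patch_reduce_sketch(h, max_num_patches, reduction, patch_ids):
    # h: [bs, seq_len, dim]; patch_ids: [bs, seq_len], mapping each byte position to its patch.
    bs, seq_len, dim = h.shape
    index = patch_ids.unsqueeze(-1).expand(bs, seq_len, dim)
    out = torch.zeros(bs, max_num_patches, dim, dtype=h.dtype, device=h.device)
    # reduction is one of "mean", "amin", "amax" -- the same strings pooling_downsample used.
    return out.scatter_reduce(1, index, h, reduce=reduction, include_self=False)


# Bytes 0-1 form patch 0 and bytes 2-3 form patch 1, so each patch keeps the
# element-wise max of its two byte embeddings.
h = torch.arange(8, dtype=torch.float32).reshape(1, 4, 2)
patch_ids = torch.tensor([[0, 0, 1, 1]])
print(patch_reduce_sketch(h, 2, "amax", patch_ids))  # tensor([[[2., 3.], [6., 7.]]])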
@@ -2146,94 +1862,6 @@ class GlobalTransformer(BaseTransformer):
            b=3 * std,
        )
 
 
-def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
-    local_encoder_args = LocalModelArgs(
-        # Updated args
-        dim=args.dim_local_encoder,
-        n_layers=args.n_layers_local_encoder,
-        n_heads=args.n_heads_local_encoder,
-        dim_token_emb=get_encoder_dim_token_emb(args),
-        dim_patch_emb=get_encoder_dim_patch_emb(args),
-        cross_attn_encoder=args.cross_attn_encoder,
-        cross_attn_decoder=False,
-        cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None,
-        cross_attn_init_by_pooling=args.cross_attn_init_by_pooling,
-        # Defaults
-        head_dim=args.head_dim,
-        max_seqlen=args.max_encoder_seq_length,
-        dropout=args.dropout,
-        vocab_size=args.vocab_size + args.pm_size,
-        norm_eps=args.norm_eps,
-        patch_size=args.patch_size,
-        sliding_window=args.local_attention_window_len,
-        use_rope=args.use_rope,
-        rope_theta=args.rope_theta,
-        rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
-        init_base_std=args.init_base_std,
-        init_std_factor=args.init_std_factor,
-        n_kv_heads=args.n_kv_heads,
-        attn_impl=args.attn_impl,
-        attn_bias_type="local_block_causal",
-        multiple_of=args.multiple_of,
-        ffn_dim_multiplier=args.ffn_dim_multiplier,
-        patching_mode=args.patching_mode,
-        use_local_encoder_transformer=args.use_local_encoder_transformer,
-        downsampling_by_pooling=args.downsampling_by_pooling,
-        encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
-        cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
-        cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
-        cross_attn_nheads=args.cross_attn_nheads,
-        eos_id=args.eos_id,
-    )
-
-    return LocalEncoder(local_encoder_args)
-
-
-def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
-    # First deep copy the original args
-    local_decoder_args = LocalModelArgs(
-        dim=args.dim_local_decoder,
-        n_layers=args.n_layers_local_decoder,
-        n_heads=args.n_heads_local_decoder,
-        dim_token_emb=get_decoder_dim_token_emb(args),
-        dim_patch_emb=args.dim_global,
-        cross_attn_encoder=False,
-        cross_attn_decoder=args.cross_attn_decoder,
-        cross_attn_init_by_pooling=False,  # states are already defined
-        cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None,
-        # Defaults
-        head_dim=args.head_dim,
-        max_seqlen=args.max_encoder_seq_length,
-        dropout=args.dropout,
-        vocab_size=args.vocab_size + args.pm_size,
-        norm_eps=args.norm_eps,
-        patch_size=args.patch_size,
-        sliding_window=args.local_attention_window_len,
-        use_rope=args.use_rope,
-        rope_theta=args.rope_theta,
-        rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
-        init_base_std=args.init_base_std,
-        init_std_factor=args.init_std_factor,
-        n_kv_heads=args.n_kv_heads,
-        attn_impl=args.attn_impl,
-        attn_bias_type="local_block_causal",
-        multiple_of=args.multiple_of,
-        ffn_dim_multiplier=args.ffn_dim_multiplier,
-        patching_mode=args.patching_mode,
-        use_local_encoder_transformer=args.use_local_encoder_transformer,
-        downsampling_by_pooling=args.downsampling_by_pooling,
-        encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
-        cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
-        cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
-        cross_attn_nheads=args.cross_attn_nheads,
-        eos_id=args.eos_id,
-    )
-
-    return LocalDecoder(local_decoder_args)
-
-
 class EmbeddingType(Enum):
     HASH_TOK = auto()
     NGRAM = auto()
@@ -2381,9 +2009,85 @@ class ByteLatentTransformer(
        )
 
        # ByteLatent modules
-        self.local_encoder = create_local_encoder(args)
+        local_encoder_args = LocalModelArgs(
+            # Updated args
+            dim=args.dim_local_encoder,
+            n_layers=args.n_layers_local_encoder,
+            n_heads=args.n_heads_local_encoder,
+            dim_token_emb=get_encoder_dim_token_emb(args),
+            dim_patch_emb=get_encoder_dim_patch_emb(args),
+            cross_attn_encoder=args.cross_attn_encoder,
+            cross_attn_decoder=False,
+            cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None,
+            cross_attn_init_by_pooling=args.cross_attn_init_by_pooling,
+            # Defaults
+            head_dim=args.head_dim,
+            max_seqlen=args.max_encoder_seq_length,
+            dropout=args.dropout,
+            vocab_size=args.vocab_size + args.pm_size,
+            norm_eps=args.norm_eps,
+            patch_size=args.patch_size,
+            sliding_window=args.local_attention_window_len,
+            use_rope=args.use_rope,
+            rope_theta=args.rope_theta,
+            rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
+            init_base_std=args.init_base_std,
+            init_std_factor=args.init_std_factor,
+            n_kv_heads=args.n_kv_heads,
+            attn_impl=args.attn_impl,
+            attn_bias_type="local_block_causal",
+            multiple_of=args.multiple_of,
+            ffn_dim_multiplier=args.ffn_dim_multiplier,
+            patching_mode=args.patching_mode,
+            use_local_encoder_transformer=args.use_local_encoder_transformer,
+            downsampling_by_pooling=args.downsampling_by_pooling,
+            encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
+            cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
+            cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
+            cross_attn_nheads=args.cross_attn_nheads,
+            eos_id=args.eos_id,
+        )
+        self.local_encoder = LocalEncoder(local_encoder_args)
+        local_decoder_args = LocalModelArgs(
+            dim=args.dim_local_decoder,
+            n_layers=args.n_layers_local_decoder,
+            n_heads=args.n_heads_local_decoder,
+            dim_token_emb=get_decoder_dim_token_emb(args),
+            dim_patch_emb=args.dim_global,
+            cross_attn_encoder=False,
+            cross_attn_decoder=args.cross_attn_decoder,
+            cross_attn_init_by_pooling=False,  # states are already defined
+            cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None,
+            # Defaults
+            head_dim=args.head_dim,
+            max_seqlen=args.max_encoder_seq_length,
+            dropout=args.dropout,
+            vocab_size=args.vocab_size + args.pm_size,
+            norm_eps=args.norm_eps,
+            patch_size=args.patch_size,
+            sliding_window=args.local_attention_window_len,
+            use_rope=args.use_rope,
+            rope_theta=args.rope_theta,
+            rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
+            init_base_std=args.init_base_std,
+            init_std_factor=args.init_std_factor,
+            n_kv_heads=args.n_kv_heads,
+            attn_impl=args.attn_impl,
+            attn_bias_type="local_block_causal",
+            multiple_of=args.multiple_of,
+            ffn_dim_multiplier=args.ffn_dim_multiplier,
+            patching_mode=args.patching_mode,
+            use_local_encoder_transformer=args.use_local_encoder_transformer,
+            downsampling_by_pooling=args.downsampling_by_pooling,
+            encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
+            cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
+            cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
+            cross_attn_nheads=args.cross_attn_nheads,
+            eos_id=args.eos_id,
+        )
+
        self.global_transformer = create_global_transformer(args)
-        self.local_decoder = create_local_decoder(args)
+        self.local_decoder = LocalDecoder(local_decoder_args)
        self.encoder_hash_tok_embedding = init_embeddings(
            args,
            EmbeddingType.HASH_TOK,
@@ -2534,21 +2238,7 @@ class ByteLatentTransformer(
        )
 
        # Downsampling
-        if not self.cross_attn_encoder:
-            assert (
-                patch_ids.shape[1] == h_encoder.shape[1]
-            ), f"{patch_ids.shape[1]} != {h_encoder.shape[1]}"
-            h = downsample(
-                h_encoder,
-                patch_lengths.shape[1],
-                patch_lengths,
-                patch_ids,
-                downsampling_by_pooling=self.downsampling_by_pooling,
-                patch_size=self.patch_size,
-            )
-        else:
-            # Reshape h_cross
-            h = h_cross.view(bs, patch_lengths.shape[1], -1)
+        h = h_cross.view(bs, patch_lengths.shape[1], -1)
 
        # Global transformer
        global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(self.boe_id)
@@ -2615,4 +2305,4 @@ class ByteLatentTransformer(
            std=emb_std,
            a=-3 * emb_std,
            b=3 * emb_std,
        )
(third changed file; file name not captured in this view)
@@ -56,7 +56,7 @@ def sample_top_p(probs, p):
 
 
 @torch.inference_mode()
-def generate_nocache(
+def generate(
     prompts: list[str] | None,
     *,
     model: ByteLatentTransformer,

@@ -186,9 +186,8 @@ def main(prompt: str = "my name is", model_name: str = "blt-1b"):
 
     # Generate text
     print("Generating text...")
-    prompts = [prompt]
-    outputs = generate_nocache(
-        prompts,
+    outputs = generate(
+        [prompt],
         model=model,
         tokenizer=tokenizer,
         patcher=model.patcher,  # Use the model's patcher

@@ -197,7 +196,7 @@ def main(prompt: str = "my name is", model_name: str = "blt-1b"):
 
     # Decode and print results
     text_outputs = [tokenizer.decode(t) for t in outputs]
-    for p, t in zip(prompts, text_outputs):
+    for p, t in zip([prompt], text_outputs):
         print(f'Prompt: "{p}"')
         print(f'Completion: "{t}"')
         print()