[WIP] Changes for training entropy model and correcting attention in local models

Summary:

- Refactor local model configs into a separate, clearer LocalModelArgs class
- Add attention arguments and correct which attention implementation is used in the local models
- Prepare for an entropy-model training script (config sketch below)
- Fix failing unit tests
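
A rough sketch of how the new knobs fit together. The field names (train_entropy_model, entropy_model, attn_impl, attn_bias_type, sliding_window) come from this diff; the values and the exact wiring are illustrative assumptions, not the canonical configuration:

    # Hypothetical wiring of the new TrainArgs fields for entropy-model training.
    from bytelatent.transformer import LMTransformerArgs

    entropy_model_args = LMTransformerArgs(
        dim=512,                              # placeholder width
        n_layers=8,                           # placeholder depth
        attn_impl="xformers",                 # which attention backend to dispatch to
        attn_bias_type="local_block_causal",  # which mask create_causal_mask should build
        sliding_window=512,                   # local window, mirroring load_entropy_model below
    )
    # A training config would then set train_entropy_model=True and
    # entropy_model=entropy_model_args instead of training the main BLT model.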

Test Plan:
Pedro Rodriguez 2025-01-17 22:21:50 +00:00
parent caec8d2621
commit 7f305b3871
15 changed files with 349 additions and 138 deletions

View file

@@ -30,6 +30,7 @@ from bytelatent.model.blt import ByteLatentTransformerArgs
 from bytelatent.optim import OptimArgs
 from bytelatent.profiling import ProfilerArgs
 from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
+from bytelatent.transformer import LMTransformerArgs

 logger = logging.getLogger()

@@ -163,6 +164,8 @@ class TrainArgs(BaseModel):
     seed: int = 42

+    debug_dynamo: bool = False
+
     # Number of gradient accumulation steps
     # Total batch size is batch_size*grad_acc_steps
     grad_acc_steps: int = 1
@@ -176,6 +179,10 @@ class TrainArgs(BaseModel):
     data: DataloaderArgs = DataloaderArgs()
     optim: OptimArgs = OptimArgs()
     model: ByteLatentTransformerArgs = ByteLatentTransformerArgs()
+    # This is only needed for training the entropy model
+    entropy_model: LMTransformerArgs | None = None
+    # Instead of training main model, train entropy model
+    train_entropy_model: bool = False
     distributed: DistributedArgs = DistributedArgs()
     env: EnvironmentArgs = EnvironmentArgs()

View file

@@ -4,7 +4,7 @@ from enum import Enum
 from typing import Optional, Tuple, Union

 import torch
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 from torch import nn
 from torch.nn import functional as F
 from torch.nn.attention.flex_attention import (
@@ -15,6 +15,7 @@ from torch.nn.attention.flex_attention import (
 from xformers.ops import AttentionBias, fmha

 from bytelatent import probe
+from bytelatent.tokenizers.constants import EOS_ID

 if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
     flex_attention_comp = torch.compile(flex_attention)
@@ -30,13 +31,14 @@ class InitStdFactor(Enum):

 class BaseTransformerArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
     dim: int = 512
     n_layers: int = 8
-    head_dim: Optional[int] = None
-    n_heads: Optional[int] = None
-    n_kv_heads: Optional[int] = None
-    ffn_dim_multiplier: Optional[float] = None
+    head_dim: int | None = None
+    n_heads: int | None = None
+    n_kv_heads: int | None = None
+    ffn_dim_multiplier: float | None = None

     multiple_of: int = 256
@@ -44,11 +46,16 @@ class BaseTransformerArgs(BaseModel):
     rope_theta: float = 10000.0

-    init_base_std: Optional[float] = None
+    init_base_std: float | None = None
     init_std_factor: InitStdFactor = InitStdFactor.DISABLED

     max_seqlen: int = 1024

+    attn_impl: str | None = "sdpa"
+    attn_bias_type: str | None = None
+    # Special token config
+    eos_id: int | None = EOS_ID
+

 def cross_entropy(pred, target, **kwargs):
     return F.nll_loss(
@@ -294,6 +301,18 @@ class RMSNorm(nn.Module):
         torch.nn.init.ones_(self.weight)  # type: ignore


+def _reshape_for_attn_bias(
+    attn_bias: AttentionBias | None,
+    *tensors: torch.Tensor,
+) -> list[torch.Tensor]:
+    to_transform = list(tensors)
+    if isinstance(attn_bias, fmha.attn_bias.BlockDiagonalCausalMask):
+        # could be `view` instead of reshape during training, but for inference
+        # have to reshape due to strides mismatch
+        to_transform = [t.reshape(1, -1, *t.shape[2:]) for t in to_transform]
+    return to_transform
+
+
 class Attention(nn.Module):
     def __init__(
         self,
@@ -371,9 +390,12 @@ class Attention(nn.Module):
             output = flex_attention_comp(xq, xk, xv, block_mask=mask)
             output = output.transpose(1, 2).contiguous()  # B H S D -> B S H D

-        elif attn_impl == "fmha":
+        elif attn_impl == "xformers":
             assert mask is None or isinstance(mask, AttentionBias)
+            query_shape = xq.shape
+            xq, xk, xv = _reshape_for_attn_bias(mask, xq, xk, xv)
             output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask)
+            output = output.view(query_shape)
             # This uses B S H D instead of B H S D of pytorch

         elif attn_impl == "sdpa":
@@ -522,14 +544,16 @@ class TransformerBlock(nn.Module):
         mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
         attn_impl: str = "sdpa",
     ) -> torch.Tensor:
-        h = x + self.attention(
+        attn_out = self.attention(
             self.attention_norm(x),
             freq_cis,
             tok_idx=tok_idx,
             mask=mask,
             attn_impl=attn_impl,
         )
-        out = h + self.feed_forward(self.ffn_norm(h))
+        h = x + attn_out
+        h_norm = self.ffn_norm(h)
+        out = h + self.feed_forward(h_norm)
         return out

     def init_weights(self, init_std=None, factor=1.0):
@@ -545,6 +569,8 @@ class BaseTransformer(nn.Module):
         super().__init__()
         self.dim = args.dim
         self.init_base_std = args.init_base_std
+        self.attn_impl = args.attn_impl
+        self.attn_bias_type = args.attn_bias_type
         self.init_std_factor = InitStdFactor(args.init_std_factor)
         self.max_seqlen = args.max_seqlen
         self.rope_embeddings = RotaryEmbedding(
@@ -552,6 +578,7 @@ class BaseTransformer(nn.Module):
             head_dim=args.head_dim or args.dim // args.n_heads,
             max_seqlen=args.max_seqlen,
         )
+        self.eos_id = args.eos_id
         self.layers = nn.ModuleList()
         for _ in range(args.n_layers):

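For context on the attention change above, a minimal sketch of what the new _reshape_for_attn_bias helper does on the xformers path; shapes and sequence lengths are made up, and it assumes torch and xformers are importable:

    import torch
    from xformers.ops import fmha

    from bytelatent.base_transformer import _reshape_for_attn_bias  # helper added in this diff

    bs, seqlen, n_heads, head_dim = 2, 8, 4, 16
    xq = torch.randn(bs, seqlen, n_heads, head_dim)

    # BlockDiagonalCausalMask treats the whole batch as one concatenated sequence,
    # so q/k/v are flattened to (1, bs * seqlen, n_heads, head_dim) before calling fmha,
    # then the attention output is viewed back to the original query shape.
    mask = fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(q_seqlen=[seqlen] * bs)
    (xq_flat,) = _reshape_for_attn_bias(mask, xq)
    assert xq_flat.shape == (1, bs * seqlen, n_heads, head_dim)
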
View file

@@ -15,7 +15,6 @@ optim:
 distributed:
   fsdp_type: full_shard
-  compile: true
   model_dtype: bf16
   matmul_allow_tf32: false
   selective_activation_checkpointing: false
@@ -58,13 +57,13 @@ model:
   recompute_attn: false
   custom_bwd: false
   layer_ckpt: "none"
-  efficient_attn: "sdpa"
   patch_only_encoder: false
   patch_only_decoder: false
   use_local_encoder_transformer: true
   init_use_gaussian: true
   init_use_depth: "current"
   attn_bias_type: "block_causal"
+  attn_impl: "xformers"
   alpha_depth: "disabled"
   max_length: 256
   local_attention_window_len: 512

View file

@@ -27,6 +27,7 @@ def test_basic_arrow_file():
         dataset_files=[ARROW_TEST_DATA_1],
         row_num=0,
         arrow_batch_size=100,
+        s3_profile=None,
     )
     arrow_file = initial_state.build()
     start_state = arrow_file.get_state()
@@ -55,6 +56,7 @@ def test_basic_arrow_file():
         dataset_files=[ARROW_TEST_DATA_1],
         row_num=251,
         arrow_batch_size=100,
+        s3_profile=None,
     )
     arrow_file = resumed_state.build()
     for example in arrow_file.create_iter():
@@ -74,6 +76,7 @@ def test_basic_arrow_file():
         dataset_files=[ARROW_TEST_DATA_1],
         row_num=0,
         arrow_batch_size=100,
+        s3_profile=None,
     )
     arrow_file = rank_state.build()
     expected_ids = []

View file

@@ -11,7 +11,6 @@ import socket
 import subprocess
 import sys
 import tempfile
-from dataclasses import asdict, dataclass
 from functools import lru_cache, partial, reduce
 from itertools import chain
 from typing import List, Optional, Tuple, Union

View file

@@ -1,12 +1,14 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 import json
+import logging
 import os
+import re

 import torch

 from bytelatent.transformer import LMTransformer, LMTransformerArgs

+logger = logging.getLogger()
+

 def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):
     with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
@@ -14,6 +16,9 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
     torch.set_default_dtype(torch.bfloat16)
     model_params = reloaded["model"]
+    logger.warning(
+        "Update checkpoint to load attn and sliding window args from checkpoint"
+    )
     entropy_model = LMTransformer(
         LMTransformerArgs(
             dim=model_params["dim"],
@@ -22,6 +27,9 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
             max_seqlen=model_params["max_length"],
             ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
             vocab_size=model_params["vocab_size"],
+            attn_bias_type="local_block_causal",
+            attn_impl="xformers",
+            sliding_window=512,
         )
     )

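A hypothetical call pattern for the patched loader, mirroring the updated entropy-model test further down; the module path and checkpoint paths are placeholders, only the function signature comes from the hunk above:

    from bytelatent.entropy_model import load_entropy_model  # module path assumed

    entropy_model = load_entropy_model(
        "checkpoints/entropy_model",                   # dir containing params.json (placeholder)
        "checkpoints/entropy_model/consolidated.pth",  # state-dict path (placeholder)
    ).cuda()  # the updated test moves the model to the GPU after loading
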
View file

@@ -15,8 +15,8 @@ from bytelatent.base_transformer import (
     TransformerBlock,
 )
 from bytelatent.data.patcher import Patcher, PatcherArgs
-from bytelatent.model.local_models import LocalDecoder, LocalEncoder
-from bytelatent.model.transformer import GlobalTransformer
+from bytelatent.model.latent_transformer import GlobalTransformer
+from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs
 from bytelatent.model.utils import downsample
 from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
@@ -403,7 +403,6 @@ def patch_ids_from_lengths(patch_lengths, seq_len):

 class ByteLatentTransformerArgs(BaseTransformerArgs):
-    model_config = ConfigDict(extra="forbid")
     # Basic model configuration
     seed: int = 42
     vocab_size: int = -1
@@ -412,7 +411,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     n_heads: int = 8
     # TODO: What is the purpose of this parameter?
     weight_tying: bool = False
-    sliding_window: Optional[int] = None

     # Architecture and dimensions
     dim_token: int = 256
@@ -471,11 +469,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     recompute_attn: bool = True
     custom_bwd: bool = False
     layer_ckpt: str = "all"
-    efficient_attn: str | None = None
-
-    # Architecture options
-    patch_only_encoder: bool = False
-    patch_only_decoder: bool = False

     # Initialization and attention
     init_use_gaussian: bool = True
@@ -541,9 +534,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     # Logging
     full_logging_n_layers: int = 4

-    # Special token config
-    eos_id: int | None = None
-
     @model_validator(mode="after")
     def check_hash_byte_sizes(self) -> Self:
         if (
@@ -558,22 +548,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
         return self


-class LocalEncoderArgs(ByteLatentTransformerArgs):
-    # Local encoder specific dimensions
-    n_heads_local_encoder: int = 8
-    dim_token_emb: int | None = None
-    dim_patch_emb: int | None = None
-
-    def __post_init__(self):
-        # Override base args with local encoder specific values
-        self.dim = self.dim_local_encoder
-        self.n_layers = self.n_layers_local_encoder
-        self.n_heads = self.n_heads_local_encoder
-        self.cross_attn_decoder = False
-        self.cross_attn_k = self.cross_attn_k if self.cross_attn_encoder else None
-        self.attn_bias_type = "local_block_causal"
-
-
 class GlobalTransformerArgs(ByteLatentTransformerArgs):
     # Global encoder specific dimensions
     dim_token_emb: int | None = None
@@ -625,20 +599,42 @@ def create_global_transformer(args: ByteLatentTransformerArgs) -> GlobalTransfor

 def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
-    # First deep copy the original args
-    # Replace with local encoder specific values
-    local_encoder_args = args.model_copy(
-        deep=True,
-        update=dict(
-            dim=args.dim_local_encoder,
-            n_layers=args.n_layers_local_encoder,
-            n_heads=args.n_heads_local_encoder,
-            dim_token_emb=get_encoder_dim_token_emb(args),
-            dim_patch_emb=get_encoder_dim_patch_emb(args),
-            cross_attn_decoder=False,
-            cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None,
-            attn_bias_type="local_block_causal",
-        ),
+    local_encoder_args = LocalModelArgs(
+        # Updated args
+        dim=args.dim_local_encoder,
+        n_layers=args.n_layers_local_encoder,
+        n_heads=args.n_heads_local_encoder,
+        dim_token_emb=get_encoder_dim_token_emb(args),
+        dim_patch_emb=get_encoder_dim_patch_emb(args),
+        cross_attn_encoder=args.cross_attn_encoder,
+        cross_attn_decoder=False,
+        cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None,
+        cross_attn_init_by_pooling=args.cross_attn_init_by_pooling,
+        # Defaults
+        head_dim=args.head_dim,
+        max_seqlen=args.max_encoder_seq_length,
+        dropout=args.dropout,
+        vocab_size=args.vocab_size + args.pm_size,
+        norm_eps=args.norm_eps,
+        patch_size=args.patch_size,
+        sliding_window=args.local_attention_window_len,
+        use_rope=args.use_rope,
+        rope_theta=args.rope_theta,
+        init_base_std=args.init_base_std,
+        init_std_factor=args.init_std_factor,
+        n_kv_heads=args.n_kv_heads,
+        attn_impl=args.attn_impl,
+        attn_bias_type="local_block_causal",
+        multiple_of=args.multiple_of,
+        ffn_dim_multiplier=args.ffn_dim_multiplier,
+        patching_mode=args.patching_mode,
+        use_local_encoder_transformer=args.use_local_encoder_transformer,
+        downsampling_by_pooling=args.downsampling_by_pooling,
+        encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
+        cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
+        cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
+        cross_attn_nheads=args.cross_attn_nheads,
+        eos_id=args.eos_id,
     )

     return LocalEncoder(local_encoder_args)
@@ -646,18 +642,41 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:

 def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
     # First deep copy the original args
-    local_decoder_args = args.model_copy(
-        deep=True,
-        update=dict(
-            dim=args.dim_local_decoder,
-            n_layers=args.n_layers_local_decoder,
-            n_heads=args.n_heads_local_decoder,
-            cross_attn_encoder=False,
-            cross_attn_init_by_pooling=False,  # states are already defined
-            dim_token_emb=get_decoder_dim_token_emb(args),
-            dim_patch_emb=args.dim_global,
-            cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None,
-        ),
+    local_decoder_args = LocalModelArgs(
+        dim=args.dim_local_decoder,
+        n_layers=args.n_layers_local_decoder,
+        n_heads=args.n_heads_local_decoder,
+        dim_token_emb=get_decoder_dim_token_emb(args),
+        dim_patch_emb=args.dim_global,
+        cross_attn_encoder=False,
+        cross_attn_decoder=args.cross_attn_decoder,
+        cross_attn_init_by_pooling=False,  # states are already defined
+        cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None,
+        # Defaults
+        head_dim=args.head_dim,
+        max_seqlen=args.max_encoder_seq_length,
+        dropout=args.dropout,
+        vocab_size=args.vocab_size + args.pm_size,
+        norm_eps=args.norm_eps,
+        patch_size=args.patch_size,
+        sliding_window=args.local_attention_window_len,
+        use_rope=args.use_rope,
+        rope_theta=args.rope_theta,
+        init_base_std=args.init_base_std,
+        init_std_factor=args.init_std_factor,
+        n_kv_heads=args.n_kv_heads,
+        attn_impl=args.attn_impl,
+        attn_bias_type="local_block_causal",
+        multiple_of=args.multiple_of,
+        ffn_dim_multiplier=args.ffn_dim_multiplier,
+        patching_mode=args.patching_mode,
+        use_local_encoder_transformer=args.use_local_encoder_transformer,
+        downsampling_by_pooling=args.downsampling_by_pooling,
+        encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
+        cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
+        cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
+        cross_attn_nheads=args.cross_attn_nheads,
+        eos_id=args.eos_id,
     )

     return LocalDecoder(local_decoder_args)
@@ -763,7 +782,6 @@ class ByteLatentTransformer(nn.Module):
         # General configuration
         self.weight_tying = args.weight_tying
-        self.sliding_window = args.sliding_window
         self.patch_size = args.patch_size
         self.patching_mode = args.patching_mode
         self.boe_id, self.bos_id, self.pad_id, self.eos_id = (

View file

@@ -11,6 +11,7 @@ from xformers.ops import AttentionBias

 from bytelatent.base_transformer import (
     BaseTransformer,
+    BaseTransformerArgs,
     RMSNorm,
     flex_attention_comp,
     repeat_kv,
@@ -142,11 +143,10 @@ class CrossAttention(nn.Module):

 class GlobalTransformer(BaseTransformer):
-    def __init__(self, args):
+    def __init__(self, args: BaseTransformerArgs):
         super().__init__(args)
         self.dropout = args.dropout
-        self.sliding_window = args.sliding_window
-        self.efficient_attn = args.efficient_attn
+        self.eos_id = args.eos_id

         self.token_embedding_projection = None
         if args.dim_token_emb is not None and args.dim_token_emb != self.dim:
@@ -169,14 +169,19 @@ class GlobalTransformer(BaseTransformer):
         and projection to the token space.
         """
         bs, seqlen = tokens.shape
-        attn_impl = self.efficient_attn

         h = embeds

         mask = (
             mask
             if mask is not None
-            else create_causal_mask(seqlen, attn_impl, self.sliding_window)
+            else create_causal_mask(
+                seqlen,
+                self.attn_impl,
+                self.attn_bias_type,
+                tokens=tokens,
+                eos_id=self.eos_id,
+            )
         )

         if self.token_embedding_projection is not None and h.shape[-1] != self.dim:
@@ -184,7 +189,7 @@ class GlobalTransformer(BaseTransformer):

         h = F.dropout(h, p=self.dropout, training=self.training)

-        h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
+        h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl)

         return h, cache

     def init_weights(self, init_base_std: float):

View file

@@ -1,44 +1,75 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 import logging
-from typing import List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union

 import torch
 import torch.nn
 import torch.nn as nn
+from pydantic import BaseModel, ConfigDict
 from torch.nn import functional as F
 from torch.nn.attention.flex_attention import BlockMask
 from xformers.ops import AttentionBias

 from bytelatent.base_transformer import (
+    BaseTransformerArgs,
     InitStdFactor,
     RMSNorm,
     RotaryEmbedding,
     TransformerBlock,
 )
-from bytelatent.model.transformer import CrossAttention
+from bytelatent.model.latent_transformer import CrossAttention
 from bytelatent.model.utils import create_causal_mask, downsample
 from bytelatent.tokenizers.blt_tokenizer import BOE_ID

 logger = logging.getLogger()


+class LocalModelArgs(BaseTransformerArgs):
+    model_config = ConfigDict(extra="forbid")
+    # Override defaults
+    attn_impl: str | None = "xformers"
+    attn_bias_type: str | None = "local_block_causal"
+
+    # Local encoder specific dimensions
+    dropout: float
+    vocab_size: int
+    patch_size: int
+    sliding_window: int | None
+    use_rope: bool
+    cross_attn_encoder: bool | None
+    cross_attn_decoder: bool | None
+    cross_attn_k: int | None
+    cross_attn_init_by_pooling: bool
+    patching_mode: str
+    use_local_encoder_transformer: bool
+    downsampling_by_pooling: str | None
+    encoder_hash_byte_group_size: Any | None = None
+    cross_attn_all_layers_encoder: bool = False
+    cross_attn_all_layers_decoder: bool = False
+    cross_attn_nheads: int | None
+    dim_token_emb: int
+    dim_patch_emb: int | None
+
+
 class LocalModelBase(nn.Module):
-    def __init__(self, args):
+    def __init__(self, args: LocalModelArgs):
         super().__init__()

         self.dim = args.dim
         self.dropout = args.dropout
-        self.vocab_size = args.vocab_size + args.pm_size
+        self.vocab_size = args.vocab_size
         self.patch_size = args.patch_size
-        self.efficient_attn = args.efficient_attn
+        self.attn_impl = args.attn_impl
         self.sliding_window = args.sliding_window
         self.use_rope = args.use_rope
         self.init_std_factor = args.init_std_factor
         self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None)
         self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None)
         self.cross_attn_k = getattr(args, "cross_attn_k", None)
+        self.eos_id = args.eos_id

         self.boe_id = BOE_ID
@@ -54,7 +85,7 @@ class LocalModelBase(nn.Module):
         self.rope = RotaryEmbedding(
             theta=args.rope_theta,
             head_dim=args.head_dim or args.dim // args.n_heads,
-            max_seqlen=getattr(args, "max_encoder_seq_length", args.max_length),
+            max_seqlen=args.max_seqlen,
         )
         self.pos_embeddings = None
@@ -66,21 +97,15 @@ class LocalModelBase(nn.Module):
             self.patch_embedding_projection = self._create_patch_projection(args)

-    def _should_create_patch_projection(self, args):
+    def _should_create_patch_projection(self, args: LocalModelArgs):
         dimension_mismatch = (
             getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim
         )

         # Check cross attention conditions
         cross_attn_conditions = (
-            hasattr(args, "cross_attn_encoder")
-            and args.cross_attn_encoder
-            and getattr(args, "cross_attn_init_by_pooling")
-        ) or (
-            hasattr(args, "cross_attn_decoder")
-            and args.cross_attn_decoder
-            and getattr(args, "cross_attn_init_by_pooling")
-        )
+            args.cross_attn_encoder and args.cross_attn_init_by_pooling
+        ) or (args.cross_attn_decoder and args.cross_attn_init_by_pooling)

         return dimension_mismatch or cross_attn_conditions
@@ -172,7 +197,7 @@ class LocalModelBase(nn.Module):

 class LocalEncoder(LocalModelBase):
-    def __init__(self, args):
+    def __init__(self, args: LocalModelArgs):
         super().__init__(args)

         self.output_proj = (
             args.patching_mode in ["entropy", "probmax"]
@@ -180,7 +205,6 @@ class LocalEncoder(LocalModelBase):

         self.apply_transformer = args.use_local_encoder_transformer
         self.downsampling_by_pooling = args.downsampling_by_pooling
-        self.patch_only = args.patch_only_encoder
         self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None
         self.cross_attn_encoder = args.cross_attn_encoder
         self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder
@@ -224,7 +248,14 @@ class LocalEncoder(LocalModelBase):
         """ """
         bs, seqlen = tokens.shape
         if mask is None:
-            mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window)
+            mask = create_causal_mask(
+                seqlen,
+                self.attn_impl,
+                "local_block_causal",
+                sliding_window=self.sliding_window,
+                tokens=tokens,
+                eos_id=self.eos_id,
+            )

         h = self.apply_embedding(tokens, embeds)
         freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
@@ -232,7 +263,7 @@ class LocalEncoder(LocalModelBase):
         h = F.dropout(h, p=self.dropout, training=self.training)

         for i, layer in enumerate(self.layers):
-            h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn)
+            h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl)
             # check if cross attention should be applied to either all layer or only the last layer
             if self.cross_attn_encoder and (
                 i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder
@@ -273,12 +304,10 @@ class LocalEncoder(LocalModelBase):

 class LocalDecoder(LocalModelBase):
-    def __init__(self, args):
+    def __init__(self, args: LocalModelArgs):
         super().__init__(args)

         # Model configuration flags
-        self.patch_only = args.patch_only_decoder
-        self.expects_embeddings = args.share_encoder_decoder_emb
         self.cross_attn_decoder = args.cross_attn_decoder
         self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder
         self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
@@ -317,7 +346,14 @@ class LocalDecoder(LocalModelBase):
         assert embeds is not None, "Embeddings must be provided"

         if mask is None:
-            mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window)
+            mask = create_causal_mask(
+                seqlen,
+                self.attn_impl,
+                "local_block_causal",
+                sliding_window=self.sliding_window,
+                tokens=tokens,
+                eos_id=self.eos_id,
+            )

         h = embeds
@@ -347,7 +383,7 @@ class LocalDecoder(LocalModelBase):
                 )
                 h = h + h_cross

-            h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn)
+            h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl)

         h_preds = self.norm(h)
         h_preds = F.dropout(h_preds, p=self.dropout, training=self.training)

View file

@@ -1,8 +1,13 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import logging
+import os
+
 import torch
 from torch.nn.attention.flex_attention import create_block_mask
 from xformers.ops import fmha

+logger = logging.getLogger()
+

 def patch_reduce(h, max_num_patches, reduction, patch_ids):
     """
@@ -97,15 +102,74 @@ def causal_mask(b, h, q_idx, kv_idx):
     return q_idx >= kv_idx


-def create_causal_mask(seqlen, attn_impl, sliding_window):
-    if sliding_window is not None and attn_impl == "xformers":
-        return fmha.attn_bias.LocalAttentionFromBottomRightMask(
-            window_left=sliding_window - 1, window_right=0
-        )
-    elif attn_impl == "xformers":
-        return fmha.attn_bias.LowerTriangularMask()
+def tokens_to_seqlen(batch: torch.Tensor, eos_id: int):
+    """
+    0 0 0 1 0 0 0 1 0 0 0
+    0 1 0 0 0 1 0 0 0 0 0
+    -> 4 4 3 2 4 5
+    """
+    mask = batch == eos_id
+    mask[:, -1] = True  # virtual eos at the end of each row
+    # 0 0 0 1 0 0 0 1 0 0 X
+    # 0 1 0 0 0 1 0 0 0 0 X
+    row, col = torch.where(mask)
+    # row = 0, 0, 0, 1, 1, 1
+    # col = 3, 7, 10, 1, 5, 10
+    seqlens = (col[1:] - col[:-1]) + (row[1:] - row[:-1]) * mask.shape[1]
+    # seqlens = (4, 3, -9, 4, 5) + (0, 0, 11, 0, 0) = (4, 3, 2, 4, 5)
+    return [int(col[0].item() + 1)] + seqlens.tolist()
+
+
+def create_causal_mask(
+    seqlen,
+    attn_impl: str,
+    attn_bias_type: str | None,
+    *,
+    eos_id: int | None = None,
+    tokens: torch.Tensor | None = None,
+    sliding_window: int | None = None,
+):
+    if attn_impl == "xformers":
+        if attn_bias_type is None:
+            return fmha.attn_bias.LowerTriangularMask()
+        elif attn_bias_type == "causal":
+            assert sliding_window is None
+            return fmha.attn_bias.LowerTriangularMask()
+        elif attn_bias_type == "block_causal":
+            assert sliding_window is None
+            assert eos_id is not None
+            assert tokens is not None
+            return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
+                q_seqlen=tokens_to_seqlen(tokens, eos_id)
+            )
+        elif attn_bias_type == "local_block_causal":
+            assert sliding_window is not None
+            assert eos_id is not None
+            assert tokens is not None
+            return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
+                q_seqlen=tokens_to_seqlen(tokens, eos_id)
+            ).make_local_attention(sliding_window)
+        else:
+            return fmha.attn_bias.LocalAttentionFromBottomRightMask(
+                window_left=sliding_window - 1, window_right=0
+            )
     elif attn_impl == "sdpa":
-        return "causal"
+        BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0))
+
+        if attn_bias_type == "causal":
+            return "causal"
+
+        if BLT_SUPPRESS_ATTN_ERROR == 1:
+            logging.warning(
+                "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. Allowing model to run since BLT_SUPPRESS_ATTN_ERROR=1"
+            )
+            return "causal"
+        else:
+            raise ValueError(
+                "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1"
+            )
     elif attn_impl == "flex_attention":
         return create_block_mask(causal_mask, None, None, seqlen, seqlen)
     elif attn_impl == "fmha":

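To make the new mask construction concrete, here is a small self-contained check of tokens_to_seqlen on the example from its docstring; the eos_id value is a stand-in for illustration, not the repo's real EOS_ID:

    import torch


    def tokens_to_seqlen(batch: torch.Tensor, eos_id: int):
        # same logic as the helper added above
        mask = batch == eos_id
        mask[:, -1] = True  # virtual eos at the end of each row
        row, col = torch.where(mask)
        seqlens = (col[1:] - col[:-1]) + (row[1:] - row[:-1]) * mask.shape[1]
        return [int(col[0].item() + 1)] + seqlens.tolist()


    eos_id = 1  # stand-in value
    batch = torch.tensor(
        [
            [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
            [0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
        ]
    )
    print(tokens_to_seqlen(batch, eos_id))  # [4, 4, 3, 2, 4, 5]
    # These per-sequence lengths feed BlockDiagonalCausalMask.from_seqlens(q_seqlen=...),
    # optionally followed by .make_local_attention(sliding_window) for the local models.
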
View file

@@ -0,0 +1,38 @@
+import fsspec
+from luigi.target import FileSystem, FileSystemTarget
+
+
+class FSSpecFileSystem(FileSystem):
+    def __init__(self, fs: fsspec.AbstractFileSystem):
+        self.fs = fs
+
+    def exists(self, path):
+        return self.fs.exists()
+
+    def remove(self, path, recursive=True, skip_trash=True):
+        raise NotImplementedError()
+
+    def isdir(self, path):
+        return self.fs.isdir(path)
+
+    def listdir(self, path):
+        return self.fs.ls(path)
+
+
+class FSSpecTarget(FileSystemTarget):
+    def __init__(self, path, fs: fsspec.AbstractFileSystem | None = None):
+        self.path = path
+        if fs is None:
+            self.fsspec_fs = fsspec.filesystem("file")
+        else:
+            self.fsspec_fs = fs
+        self._fs = None
+
+    @property
+    def fs(self):
+        if self._fs is None:
+            self._fs = FSSpecFileSystem(self.fsspec_fs)
+        return self._fs
+
+    def open(self, mode):
+        return self.fs.open(self.path, mode=mode)
View file
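A rough usage sketch for the new luigi target; the module path and file paths are placeholders, and the s3 example additionally assumes s3fs is installed:

    import fsspec

    from bytelatent.fsspec_target import FSSpecTarget  # module path is a guess

    local_target = FSSpecTarget("/tmp/blt/entropies.arrow")  # defaults to the local filesystem
    s3_target = FSSpecTarget("s3://my-bucket/blt/entropies.arrow", fs=fsspec.filesystem("s3"))

    # luigi tasks can return these from output(); existence and listing checks are
    # delegated to the wrapped fsspec filesystem via the FSSpecFileSystem adapter.
    print(s3_target.fs.isdir("s3://my-bucket/blt"))
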

@@ -23,9 +23,10 @@ from bytelatent.model.blt import (
     init_embeddings,
     patch_ids_from_lengths,
 )
-from bytelatent.model.transformer import CrossAttention
+from bytelatent.model.latent_transformer import CrossAttention
 from bytelatent.model.utils import create_causal_mask
 from bytelatent.optim import OptimArgs, build_optimizer
+from bytelatent.tokenizers.constants import EOS_ID
 from bytelatent.train import compute_loss

@@ -51,7 +52,7 @@ def batch_to_tensors_and_gpu(batch):

 def fake_batch():
-    batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"))
+    batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"), weights_only=False)
     del batch_dict["x2"]
     del batch_dict["y2"]
     del batch_dict["src_names"]
@@ -98,18 +99,17 @@ def create_args(cross_attention=False):
         recompute_attn=False,
         custom_bwd=False,
         layer_ckpt="none",
-        efficient_attn="sdpa",
-        patch_only_encoder=False,
-        patch_only_decoder=False,
         use_local_encoder_transformer=True,
         init_use_gaussian=True,
         init_use_depth="current",
         attn_bias_type="block_causal",
+        attn_impl="xformers",
         alpha_depth="disabled",
         max_length=256,
         local_attention_window_len=512,
         max_seqlen=12288,
         downsampling_by_pooling="max",
+        eos_id=EOS_ID,
     )
     return transformer_args
@@ -341,10 +341,15 @@ class TestByteLatentTransformer:
         model = ByteLatentTransformer(args)
         assert model is not None

-    @pytest.mark.parametrize("attn_type", ["fmha", "sdpa"])
-    def test_blt_transformer_forward(self, attn_type):
+    @pytest.mark.parametrize("attn_impl", ["sdpa", "xformers"])
+    def test_blt_transformer_forward(self, attn_impl):
         args = create_args()
-        args = args.model_copy(update=dict(efficient_attn=attn_type))
+        if attn_impl == "sdpa":
+            os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "1"
+        else:
+            os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "0"
+
+        args = args.model_copy(update=dict(attn_impl=attn_impl))
         model = ByteLatentTransformer(args)
         model = model.cuda()
         batch = fake_batch()
@@ -393,7 +398,9 @@ class TestByteLatentTransformer:
             n_kv_heads=4,
             norm_eps=1e-6,
         ).to("cuda")
-        mask = create_causal_mask(x.shape[1], "flex_attention", sliding_window=None)
+        mask = create_causal_mask(
+            x.shape[1], "flex_attention", None, sliding_window=None
+        )
         output = cross_attention(x, kv, mask)
         assert output is not None
         assert output.shape == (2, 256, 512)
@@ -440,7 +447,7 @@ class TestByteLatentTransformer:
     def test_loss_backward(self):
         args = create_args()
-        args = args.model_copy(update=dict(efficient_attn="sdpa"))
+        args = args.model_copy(update=dict(attn_impl="xformers"))
         batch = fake_batch()
         model = ByteLatentTransformer(args)
         steps = 10

View file

@@ -24,6 +24,7 @@ def test_entropy_model():
         dataset_files=[ARROW_TEST_DATA],
         row_num=0,
         arrow_batch_size=100,
+        s3_profile=None,
     )
     arrow_file = initial_state.build()
     tokenizer_args = TokenizerArgs(
@@ -38,7 +39,7 @@ def test_entropy_model():
             BLT_DATA,
             "entropy_model.pth",
         ),
-    )
+    ).cuda()
     preprocess_iter = PreprocessIterator(
         arrow_file,
         tokenizer_args=tokenizer_args,
@@ -48,8 +49,10 @@ def test_entropy_model():
     for example in preprocess_iter.create_iter():
         tokens = torch.tensor(example.tokens).unsqueeze(0)
         expected_entropies = torch.tensor(example.entropies).unsqueeze(0)
-        preds = entropy_model(tokens)
+        preds = entropy_model(tokens.cuda())
         pred_entropies = entropy(preds)
         assert pred_entropies.shape == expected_entropies.shape
-        assert torch.allclose(pred_entropies, expected_entropies, rtol=1.0, atol=3.5)
+        assert torch.allclose(
+            pred_entropies.cpu(), expected_entropies, rtol=1.0, atol=3.5
+        )
         break

View file

@@ -644,6 +644,10 @@ def main():
     cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
     cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
     train_args = TrainArgs.model_validate(cfg)
+    if train_args.debug_dynamo:
+        import torch._dynamo
+
+        torch._dynamo.config.suppress_errors = True
     train(train_args)

View file

@@ -22,23 +22,7 @@ from bytelatent.base_transformer import (
     RMSNorm,
     cross_entropy,
 )
-
-
-def create_causal_mask(seqlen, attn_impl, sliding_window):
-    if sliding_window is not None and attn_impl == "xformers":
-        return fmha.attn_bias.LocalAttentionFromBottomRightMask(
-            window_left=sliding_window - 1, window_right=0
-        )
-    elif attn_impl == "xformers":
-        return fmha.attn_bias.LowerTriangularMask()
-    elif attn_impl == "sdpa":
-        return "causal"
-    elif attn_impl == "flex_attention":
-        return create_block_mask(causal_mask, None, None, seqlen, seqlen)
-    else:
-        raise NotImplementedError(
-            f"Attention {attn_impl} with {sliding_window} sliding window not implemented"
-        )
+from bytelatent.model.utils import create_causal_mask


 def attention_flops_per_token(n_layers, seq_len, dim, causal):
@@ -94,8 +78,10 @@ class LMTransformer(BaseTransformer):
         target: Optional[torch.Tensor] = None,
         tok_idx: Optional[torch.Tensor] = None,
         mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None,
-        attn_impl: str = "sdpa",
+        attn_impl: str | None = None,
     ):
+        if attn_impl is None:
+            attn_impl = self.attn_impl
         bsz, seqlen = token_values.shape

         h = self.tok_embeddings(token_values)
@@ -103,7 +89,14 @@ class LMTransformer(BaseTransformer):
         mask = (
             mask
             if mask is not None
-            else create_causal_mask(seqlen, attn_impl, self.sliding_window)
+            else create_causal_mask(
+                seqlen,
+                attn_impl,
+                self.attn_bias_type,
+                sliding_window=self.sliding_window,
+                tokens=token_values,
+                eos_id=self.eos_id,
+            )
         )
         h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)