Changes for training entropy model and correcting attention in local models (#25)
Some checks are pending
Lint with Black / lint (push) Waiting to run
Lint with isort / lint (push) Waiting to run

Summary:

- Refactor local model configs to be separate and clearer
- Add attention arguments and correct which attention is used in local models
- Preparation for being able to have an entropy train script
- Fix failing unit tests

Test Plan:
This commit is contained in:
Pedro Rodriguez 2025-01-17 14:23:01 -08:00 committed by GitHub
parent caec8d2621
commit 6ffeb66b53
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 349 additions and 138 deletions

View file

@ -30,6 +30,7 @@ from bytelatent.model.blt import ByteLatentTransformerArgs
from bytelatent.optim import OptimArgs
from bytelatent.profiling import ProfilerArgs
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
from bytelatent.transformer import LMTransformerArgs
logger = logging.getLogger()
@ -163,6 +164,8 @@ class TrainArgs(BaseModel):
seed: int = 42
debug_dynamo: bool = False
# Number of gradient accumulation steps
# Total batch size is batch_size*grad_acc_steps
grad_acc_steps: int = 1
@ -176,6 +179,10 @@ class TrainArgs(BaseModel):
data: DataloaderArgs = DataloaderArgs()
optim: OptimArgs = OptimArgs()
model: ByteLatentTransformerArgs = ByteLatentTransformerArgs()
# This is only needed for training the entropy model
entropy_model: LMTransformerArgs | None = None
# Instead of training main model, train entropy model
train_entropy_model: bool = False
distributed: DistributedArgs = DistributedArgs()
env: EnvironmentArgs = EnvironmentArgs()

View file

@ -4,7 +4,7 @@ from enum import Enum
from typing import Optional, Tuple, Union
import torch
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from torch import nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import (
@ -15,6 +15,7 @@ from torch.nn.attention.flex_attention import (
from xformers.ops import AttentionBias, fmha
from bytelatent import probe
from bytelatent.tokenizers.constants import EOS_ID
if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
flex_attention_comp = torch.compile(flex_attention)
@ -30,13 +31,14 @@ class InitStdFactor(Enum):
class BaseTransformerArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
dim: int = 512
n_layers: int = 8
head_dim: Optional[int] = None
n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None
head_dim: int | None = None
n_heads: int | None = None
n_kv_heads: int | None = None
ffn_dim_multiplier: Optional[float] = None
ffn_dim_multiplier: float | None = None
multiple_of: int = 256
@ -44,11 +46,16 @@ class BaseTransformerArgs(BaseModel):
rope_theta: float = 10000.0
init_base_std: Optional[float] = None
init_base_std: float | None = None
init_std_factor: InitStdFactor = InitStdFactor.DISABLED
max_seqlen: int = 1024
attn_impl: str | None = "sdpa"
attn_bias_type: str | None = None
# Special token config
eos_id: int | None = EOS_ID
def cross_entropy(pred, target, **kwargs):
return F.nll_loss(
@ -294,6 +301,18 @@ class RMSNorm(nn.Module):
torch.nn.init.ones_(self.weight) # type: ignore
def _reshape_for_attn_bias(
    attn_bias: AttentionBias | None,
    *tensors: torch.Tensor,
) -> list[torch.Tensor]:
    """Adapt tensors to the layout used with a block-diagonal causal bias.

    When ``attn_bias`` is a ``BlockDiagonalCausalMask`` every tensor has its
    first two dimensions merged into one, behind a new leading dimension of
    size 1; the remaining dimensions are kept. For any other bias (including
    ``None``) the tensors are returned unchanged.

    Args:
        attn_bias: The attention bias that will be used for the attention call.
        *tensors: Tensors to adapt (the caller passes query, key and value).

    Returns:
        A new list holding one (possibly reshaped) tensor per input tensor.
    """
    if not isinstance(attn_bias, fmha.attn_bias.BlockDiagonalCausalMask):
        return list(tensors)

    flattened = []
    for tensor in tensors:
        trailing_dims = tensor.shape[2:]
        # `view` would be enough during training, but at inference the strides
        # may not match, so `reshape` is required.
        flattened.append(tensor.reshape(1, -1, *trailing_dims))
    return flattened
class Attention(nn.Module):
def __init__(
self,
@ -371,9 +390,12 @@ class Attention(nn.Module):
output = flex_attention_comp(xq, xk, xv, block_mask=mask)
output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
elif attn_impl == "fmha":
elif attn_impl == "xformers":
assert mask is None or isinstance(mask, AttentionBias)
query_shape = xq.shape
xq, xk, xv = _reshape_for_attn_bias(mask, xq, xk, xv)
output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask)
output = output.view(query_shape)
# This uses B S H D instead of B H S D of pytorch
elif attn_impl == "sdpa":
@ -522,14 +544,16 @@ class TransformerBlock(nn.Module):
mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
attn_impl: str = "sdpa",
) -> torch.Tensor:
h = x + self.attention(
attn_out = self.attention(
self.attention_norm(x),
freq_cis,
tok_idx=tok_idx,
mask=mask,
attn_impl=attn_impl,
)
out = h + self.feed_forward(self.ffn_norm(h))
h = x + attn_out
h_norm = self.ffn_norm(h)
out = h + self.feed_forward(h_norm)
return out
def init_weights(self, init_std=None, factor=1.0):
@ -545,6 +569,8 @@ class BaseTransformer(nn.Module):
super().__init__()
self.dim = args.dim
self.init_base_std = args.init_base_std
self.attn_impl = args.attn_impl
self.attn_bias_type = args.attn_bias_type
self.init_std_factor = InitStdFactor(args.init_std_factor)
self.max_seqlen = args.max_seqlen
self.rope_embeddings = RotaryEmbedding(
@ -552,6 +578,7 @@ class BaseTransformer(nn.Module):
head_dim=args.head_dim or args.dim // args.n_heads,
max_seqlen=args.max_seqlen,
)
self.eos_id = args.eos_id
self.layers = nn.ModuleList()
for _ in range(args.n_layers):

View file

@ -15,7 +15,6 @@ optim:
distributed:
fsdp_type: full_shard
compile: true
model_dtype: bf16
matmul_allow_tf32: false
selective_activation_checkpointing: false
@ -58,13 +57,13 @@ model:
recompute_attn: false
custom_bwd: false
layer_ckpt: "none"
efficient_attn: "sdpa"
patch_only_encoder: false
patch_only_decoder: false
use_local_encoder_transformer: true
init_use_gaussian: true
init_use_depth: "current"
attn_bias_type: "block_causal"
attn_impl: "xformers"
alpha_depth: "disabled"
max_length: 256
local_attention_window_len: 512

View file

@ -27,6 +27,7 @@ def test_basic_arrow_file():
dataset_files=[ARROW_TEST_DATA_1],
row_num=0,
arrow_batch_size=100,
s3_profile=None,
)
arrow_file = initial_state.build()
start_state = arrow_file.get_state()
@ -55,6 +56,7 @@ def test_basic_arrow_file():
dataset_files=[ARROW_TEST_DATA_1],
row_num=251,
arrow_batch_size=100,
s3_profile=None,
)
arrow_file = resumed_state.build()
for example in arrow_file.create_iter():
@ -74,6 +76,7 @@ def test_basic_arrow_file():
dataset_files=[ARROW_TEST_DATA_1],
row_num=0,
arrow_batch_size=100,
s3_profile=None,
)
arrow_file = rank_state.build()
expected_ids = []

View file

@ -11,7 +11,6 @@ import socket
import subprocess
import sys
import tempfile
from dataclasses import asdict, dataclass
from functools import lru_cache, partial, reduce
from itertools import chain
from typing import List, Optional, Tuple, Union

View file

@ -1,12 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
import logging
import os
import re
import torch
from bytelatent.transformer import LMTransformer, LMTransformerArgs
logger = logging.getLogger()
def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):
with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
@ -14,6 +16,9 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
torch.set_default_dtype(torch.bfloat16)
model_params = reloaded["model"]
logger.warning(
"Update checkpoint to load attn and sliding window args from checkpoint"
)
entropy_model = LMTransformer(
LMTransformerArgs(
dim=model_params["dim"],
@ -22,6 +27,9 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
max_seqlen=model_params["max_length"],
ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
vocab_size=model_params["vocab_size"],
attn_bias_type="local_block_causal",
attn_impl="xformers",
sliding_window=512,
)
)

View file

@ -15,8 +15,8 @@ from bytelatent.base_transformer import (
TransformerBlock,
)
from bytelatent.data.patcher import Patcher, PatcherArgs
from bytelatent.model.local_models import LocalDecoder, LocalEncoder
from bytelatent.model.transformer import GlobalTransformer
from bytelatent.model.latent_transformer import GlobalTransformer
from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs
from bytelatent.model.utils import downsample
from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
@ -403,7 +403,6 @@ def patch_ids_from_lengths(patch_lengths, seq_len):
class ByteLatentTransformerArgs(BaseTransformerArgs):
model_config = ConfigDict(extra="forbid")
# Basic model configuration
seed: int = 42
vocab_size: int = -1
@ -412,7 +411,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
n_heads: int = 8
# TODO: What is the purpose of this parameter?
weight_tying: bool = False
sliding_window: Optional[int] = None
# Architecture and dimensions
dim_token: int = 256
@ -471,11 +469,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
recompute_attn: bool = True
custom_bwd: bool = False
layer_ckpt: str = "all"
efficient_attn: str | None = None
# Architecture options
patch_only_encoder: bool = False
patch_only_decoder: bool = False
# Initialization and attention
init_use_gaussian: bool = True
@ -541,9 +534,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
# Logging
full_logging_n_layers: int = 4
# Special token config
eos_id: int | None = None
@model_validator(mode="after")
def check_hash_byte_sizes(self) -> Self:
if (
@ -558,22 +548,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
return self
class LocalEncoderArgs(ByteLatentTransformerArgs):
    # Local encoder specific dimensions
    n_heads_local_encoder: int = 8
    dim_token_emb: int | None = None
    dim_patch_emb: int | None = None

    def __post_init__(self):
        # NOTE(review): `__post_init__` is the dataclass hook; pydantic
        # `BaseModel` subclasses are not expected to call it automatically.
        # If this class is a `BaseModel`, these overrides presumably never
        # run -- confirm before relying on them.
        # Override base args with local encoder specific values
        self.dim = self.dim_local_encoder
        self.n_layers = self.n_layers_local_encoder
        self.n_heads = self.n_heads_local_encoder
        # The encoder never uses decoder-side cross attention.
        self.cross_attn_decoder = False
        # Keep cross_attn_k only when encoder cross attention is enabled.
        self.cross_attn_k = self.cross_attn_k if self.cross_attn_encoder else None
        self.attn_bias_type = "local_block_causal"
class GlobalTransformerArgs(ByteLatentTransformerArgs):
# Global encoder specific dimensions
dim_token_emb: int | None = None
@ -625,20 +599,42 @@ def create_global_transformer(args: ByteLatentTransformerArgs) -> GlobalTransfor
def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
# First deep copy the original args
# Replace with local encoder specific values
local_encoder_args = args.model_copy(
deep=True,
update=dict(
local_encoder_args = LocalModelArgs(
# Updated args
dim=args.dim_local_encoder,
n_layers=args.n_layers_local_encoder,
n_heads=args.n_heads_local_encoder,
dim_token_emb=get_encoder_dim_token_emb(args),
dim_patch_emb=get_encoder_dim_patch_emb(args),
cross_attn_encoder=args.cross_attn_encoder,
cross_attn_decoder=False,
cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None,
cross_attn_init_by_pooling=args.cross_attn_init_by_pooling,
# Defaults
head_dim=args.head_dim,
max_seqlen=args.max_encoder_seq_length,
dropout=args.dropout,
vocab_size=args.vocab_size + args.pm_size,
norm_eps=args.norm_eps,
patch_size=args.patch_size,
sliding_window=args.local_attention_window_len,
use_rope=args.use_rope,
rope_theta=args.rope_theta,
init_base_std=args.init_base_std,
init_std_factor=args.init_std_factor,
n_kv_heads=args.n_kv_heads,
attn_impl=args.attn_impl,
attn_bias_type="local_block_causal",
),
multiple_of=args.multiple_of,
ffn_dim_multiplier=args.ffn_dim_multiplier,
patching_mode=args.patching_mode,
use_local_encoder_transformer=args.use_local_encoder_transformer,
downsampling_by_pooling=args.downsampling_by_pooling,
encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
cross_attn_nheads=args.cross_attn_nheads,
eos_id=args.eos_id,
)
return LocalEncoder(local_encoder_args)
@ -646,18 +642,41 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
# First deep copy the original args
local_decoder_args = args.model_copy(
deep=True,
update=dict(
local_decoder_args = LocalModelArgs(
dim=args.dim_local_decoder,
n_layers=args.n_layers_local_decoder,
n_heads=args.n_heads_local_decoder,
cross_attn_encoder=False,
cross_attn_init_by_pooling=False, # states are already defined
dim_token_emb=get_decoder_dim_token_emb(args),
dim_patch_emb=args.dim_global,
cross_attn_encoder=False,
cross_attn_decoder=args.cross_attn_decoder,
cross_attn_init_by_pooling=False, # states are already defined
cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None,
),
# Defaults
head_dim=args.head_dim,
max_seqlen=args.max_encoder_seq_length,
dropout=args.dropout,
vocab_size=args.vocab_size + args.pm_size,
norm_eps=args.norm_eps,
patch_size=args.patch_size,
sliding_window=args.local_attention_window_len,
use_rope=args.use_rope,
rope_theta=args.rope_theta,
init_base_std=args.init_base_std,
init_std_factor=args.init_std_factor,
n_kv_heads=args.n_kv_heads,
attn_impl=args.attn_impl,
attn_bias_type="local_block_causal",
multiple_of=args.multiple_of,
ffn_dim_multiplier=args.ffn_dim_multiplier,
patching_mode=args.patching_mode,
use_local_encoder_transformer=args.use_local_encoder_transformer,
downsampling_by_pooling=args.downsampling_by_pooling,
encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
cross_attn_nheads=args.cross_attn_nheads,
eos_id=args.eos_id,
)
return LocalDecoder(local_decoder_args)
@ -763,7 +782,6 @@ class ByteLatentTransformer(nn.Module):
# General configuration
self.weight_tying = args.weight_tying
self.sliding_window = args.sliding_window
self.patch_size = args.patch_size
self.patching_mode = args.patching_mode
self.boe_id, self.bos_id, self.pad_id, self.eos_id = (

View file

@ -11,6 +11,7 @@ from xformers.ops import AttentionBias
from bytelatent.base_transformer import (
BaseTransformer,
BaseTransformerArgs,
RMSNorm,
flex_attention_comp,
repeat_kv,
@ -142,11 +143,10 @@ class CrossAttention(nn.Module):
class GlobalTransformer(BaseTransformer):
def __init__(self, args):
def __init__(self, args: BaseTransformerArgs):
super().__init__(args)
self.dropout = args.dropout
self.sliding_window = args.sliding_window
self.efficient_attn = args.efficient_attn
self.eos_id = args.eos_id
self.token_embedding_projection = None
if args.dim_token_emb is not None and args.dim_token_emb != self.dim:
@ -169,14 +169,19 @@ class GlobalTransformer(BaseTransformer):
and projection to the token space.
"""
bs, seqlen = tokens.shape
attn_impl = self.efficient_attn
h = embeds
mask = (
mask
if mask is not None
else create_causal_mask(seqlen, attn_impl, self.sliding_window)
else create_causal_mask(
seqlen,
self.attn_impl,
self.attn_bias_type,
tokens=tokens,
eos_id=self.eos_id,
)
)
if self.token_embedding_projection is not None and h.shape[-1] != self.dim:
@ -184,7 +189,7 @@ class GlobalTransformer(BaseTransformer):
h = F.dropout(h, p=self.dropout, training=self.training)
h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl)
return h, cache
def init_weights(self, init_base_std: float):

View file

@ -1,44 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
from typing import List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union
import torch
import torch.nn
import torch.nn as nn
from pydantic import BaseModel, ConfigDict
from torch.nn import functional as F
from torch.nn.attention.flex_attention import BlockMask
from xformers.ops import AttentionBias
from bytelatent.base_transformer import (
BaseTransformerArgs,
InitStdFactor,
RMSNorm,
RotaryEmbedding,
TransformerBlock,
)
from bytelatent.model.transformer import CrossAttention
from bytelatent.model.latent_transformer import CrossAttention
from bytelatent.model.utils import create_causal_mask, downsample
from bytelatent.tokenizers.blt_tokenizer import BOE_ID
logger = logging.getLogger()
class LocalModelArgs(BaseTransformerArgs):
    """Arguments for the local models built on ``LocalModelBase``.

    Extends ``BaseTransformerArgs`` with the extra settings read by the local
    encoder and local decoder. Fields declared without a default must be
    supplied explicitly by whoever constructs the args.
    """

    # Reject unknown keyword arguments instead of accepting them silently.
    model_config = ConfigDict(extra="forbid")
    # Override defaults
    attn_impl: str | None = "xformers"
    attn_bias_type: str | None = "local_block_causal"

    # Local encoder specific dimensions
    dropout: float
    vocab_size: int
    patch_size: int
    # Forwarded to create_causal_mask as `sliding_window` by the local models.
    sliding_window: int | None
    use_rope: bool

    # Cross-attention switches for the encoder and decoder sides.
    cross_attn_encoder: bool | None
    cross_attn_decoder: bool | None
    cross_attn_k: int | None
    cross_attn_init_by_pooling: bool
    patching_mode: str
    use_local_encoder_transformer: bool
    downsampling_by_pooling: str | None
    # LocalEncoder only checks whether this is None (to decide if hash
    # embeddings are expected); the value's structure is not validated here.
    encoder_hash_byte_group_size: Any | None = None
    cross_attn_all_layers_encoder: bool = False
    cross_attn_all_layers_decoder: bool = False
    cross_attn_nheads: int | None

    dim_token_emb: int
    # Compared against `dim` to decide whether a patch projection is created.
    dim_patch_emb: int | None
class LocalModelBase(nn.Module):
def __init__(self, args):
def __init__(self, args: LocalModelArgs):
super().__init__()
self.dim = args.dim
self.dropout = args.dropout
self.vocab_size = args.vocab_size + args.pm_size
self.vocab_size = args.vocab_size
self.patch_size = args.patch_size
self.efficient_attn = args.efficient_attn
self.attn_impl = args.attn_impl
self.sliding_window = args.sliding_window
self.use_rope = args.use_rope
self.init_std_factor = args.init_std_factor
self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None)
self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None)
self.cross_attn_k = getattr(args, "cross_attn_k", None)
self.eos_id = args.eos_id
self.boe_id = BOE_ID
@ -54,7 +85,7 @@ class LocalModelBase(nn.Module):
self.rope = RotaryEmbedding(
theta=args.rope_theta,
head_dim=args.head_dim or args.dim // args.n_heads,
max_seqlen=getattr(args, "max_encoder_seq_length", args.max_length),
max_seqlen=args.max_seqlen,
)
self.pos_embeddings = None
@ -66,21 +97,15 @@ class LocalModelBase(nn.Module):
self.patch_embedding_projection = self._create_patch_projection(args)
def _should_create_patch_projection(self, args):
def _should_create_patch_projection(self, args: LocalModelArgs):
dimension_mismatch = (
getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim
)
# Check cross attention conditions
cross_attn_conditions = (
hasattr(args, "cross_attn_encoder")
and args.cross_attn_encoder
and getattr(args, "cross_attn_init_by_pooling")
) or (
hasattr(args, "cross_attn_decoder")
and args.cross_attn_decoder
and getattr(args, "cross_attn_init_by_pooling")
)
args.cross_attn_encoder and args.cross_attn_init_by_pooling
) or (args.cross_attn_decoder and args.cross_attn_init_by_pooling)
return dimension_mismatch or cross_attn_conditions
@ -172,7 +197,7 @@ class LocalModelBase(nn.Module):
class LocalEncoder(LocalModelBase):
def __init__(self, args):
def __init__(self, args: LocalModelArgs):
super().__init__(args)
self.output_proj = (
args.patching_mode in ["entropy", "probmax"]
@ -180,7 +205,6 @@ class LocalEncoder(LocalModelBase):
self.apply_transformer = args.use_local_encoder_transformer
self.downsampling_by_pooling = args.downsampling_by_pooling
self.patch_only = args.patch_only_encoder
self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None
self.cross_attn_encoder = args.cross_attn_encoder
self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder
@ -224,7 +248,14 @@ class LocalEncoder(LocalModelBase):
""" """
bs, seqlen = tokens.shape
if mask is None:
mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window)
mask = create_causal_mask(
seqlen,
self.attn_impl,
"local_block_causal",
sliding_window=self.sliding_window,
tokens=tokens,
eos_id=self.eos_id,
)
h = self.apply_embedding(tokens, embeds)
freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
@ -232,7 +263,7 @@ class LocalEncoder(LocalModelBase):
h = F.dropout(h, p=self.dropout, training=self.training)
for i, layer in enumerate(self.layers):
h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn)
h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl)
# check if cross attention should be applied to either all layer or only the last layer
if self.cross_attn_encoder and (
i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder
@ -273,12 +304,10 @@ class LocalEncoder(LocalModelBase):
class LocalDecoder(LocalModelBase):
def __init__(self, args):
def __init__(self, args: LocalModelArgs):
super().__init__(args)
# Model configuration flags
self.patch_only = args.patch_only_decoder
self.expects_embeddings = args.share_encoder_decoder_emb
self.cross_attn_decoder = args.cross_attn_decoder
self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder
self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
@ -317,7 +346,14 @@ class LocalDecoder(LocalModelBase):
assert embeds is not None, "Embeddings must be provided"
if mask is None:
mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window)
mask = create_causal_mask(
seqlen,
self.attn_impl,
"local_block_causal",
sliding_window=self.sliding_window,
tokens=tokens,
eos_id=self.eos_id,
)
h = embeds
@ -347,7 +383,7 @@ class LocalDecoder(LocalModelBase):
)
h = h + h_cross
h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn)
h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl)
h_preds = self.norm(h)
h_preds = F.dropout(h_preds, p=self.dropout, training=self.training)

View file

@ -1,8 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
import os
import torch
from torch.nn.attention.flex_attention import create_block_mask
from xformers.ops import fmha
logger = logging.getLogger()
def patch_reduce(h, max_num_patches, reduction, patch_ids):
"""
@ -97,15 +102,74 @@ def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
def create_causal_mask(seqlen, attn_impl, sliding_window):
if sliding_window is not None and attn_impl == "xformers":
def tokens_to_seqlen(batch: torch.Tensor, eos_id: int):
    """
    Convert a 2-D batch of token ids into a flat list of segment lengths.

    A segment ends at every ``eos_id`` and, additionally, at the last column
    of every row (a virtual eos), so segments never span two rows.

    Example with eos_id = 1:
    0 0 0 1 0 0 0 1 0 0 0
    0 1 0 0 0 1 0 0 0 0 0
    -> 4 4 3 2 4 5
    """
    is_boundary = batch == eos_id
    # The comparison produced a fresh tensor, so `batch` itself is untouched.
    is_boundary[:, -1] = True  # virtual eos at the end of each row
    width = is_boundary.shape[1]
    # For the example: row_idx = 0, 0, 0, 1, 1, 1 and col_idx = 3, 7, 10, 1, 5, 10
    row_idx, col_idx = is_boundary.nonzero(as_tuple=True)
    # Position of every boundary in the row-major flattened batch:
    # 3, 7, 10, 12, 16, 21
    flat_pos = row_idx * width + col_idx
    # Distance between consecutive boundaries: 4, 3, 2, 4, 5
    gaps = flat_pos[1:] - flat_pos[:-1]
    # The first segment starts at column 0 of the first row.
    first_len = int(col_idx[0].item() + 1)
    return [first_len, *gaps.tolist()]
def create_causal_mask(
seqlen,
attn_impl: str,
attn_bias_type: str | None,
*,
eos_id: int | None = None,
tokens: torch.Tensor | None = None,
sliding_window: int | None = None,
):
if attn_impl == "xformers":
if attn_bias_type is None:
return fmha.attn_bias.LowerTriangularMask()
elif attn_bias_type == "causal":
assert sliding_window is None
return fmha.attn_bias.LowerTriangularMask()
elif attn_bias_type == "block_causal":
assert sliding_window is None
assert eos_id is not None
assert tokens is not None
return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
q_seqlen=tokens_to_seqlen(tokens, eos_id)
)
elif attn_bias_type == "local_block_causal":
assert sliding_window is not None
assert eos_id is not None
assert tokens is not None
return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
q_seqlen=tokens_to_seqlen(tokens, eos_id)
).make_local_attention(sliding_window)
else:
return fmha.attn_bias.LocalAttentionFromBottomRightMask(
window_left=sliding_window - 1, window_right=0
)
elif attn_impl == "xformers":
return fmha.attn_bias.LowerTriangularMask()
elif attn_impl == "sdpa":
BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0))
if attn_bias_type == "causal":
return "causal"
if BLT_SUPPRESS_ATTN_ERROR == 1:
logging.warning(
"SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. Allowing model to run since BLT_SUPPRESS_ATTN_ERROR=1"
)
return "causal"
else:
raise ValueError(
"SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1"
)
elif attn_impl == "flex_attention":
return create_block_mask(causal_mask, None, None, seqlen, seqlen)
elif attn_impl == "fmha":

View file

@ -0,0 +1,38 @@
import fsspec
from luigi.target import FileSystem, FileSystemTarget
class FSSpecFileSystem(FileSystem):
    """Expose an fsspec filesystem through luigi's ``FileSystem`` interface.

    Every supported operation is delegated to the wrapped fsspec filesystem.
    """

    def __init__(self, fs: fsspec.AbstractFileSystem):
        """Wrap ``fs``, the fsspec filesystem that performs the real work."""
        self.fs = fs

    def exists(self, path):
        """Return whether ``path`` exists on the wrapped filesystem."""
        # The path must be forwarded: calling ``exists()`` with no argument
        # raises TypeError instead of answering the question.
        return self.fs.exists(path)

    def remove(self, path, recursive=True, skip_trash=True):
        """Removal is intentionally unsupported; always raises NotImplementedError."""
        raise NotImplementedError()

    def isdir(self, path):
        """Return whether ``path`` is a directory on the wrapped filesystem."""
        return self.fs.isdir(path)

    def listdir(self, path):
        """Return the wrapped filesystem's ``ls`` listing for ``path``."""
        return self.fs.ls(path)
class FSSpecTarget(FileSystemTarget):
    """luigi file-system target whose storage is an fsspec filesystem."""

    def __init__(self, path, fs: fsspec.AbstractFileSystem | None = None):
        """Create a target for ``path``.

        Args:
            path: Location of the target on the filesystem.
            fs: fsspec filesystem to use; when ``None`` the local ``"file"``
                filesystem is used.
        """
        self.path = path
        if fs is None:
            self.fsspec_fs = fsspec.filesystem("file")
        else:
            self.fsspec_fs = fs
        # The luigi-facing adapter is built lazily on first access to `fs`.
        self._fs = None

    @property
    def fs(self):
        """Return the (cached) ``FSSpecFileSystem`` adapter for this target."""
        if self._fs is None:
            self._fs = FSSpecFileSystem(self.fsspec_fs)
        return self._fs

    def open(self, mode):
        """Open the target's path with ``mode`` and return the file object."""
        # Open through the raw fsspec filesystem: the `fs` property returns the
        # FSSpecFileSystem adapter, which does not define `open`.
        return self.fsspec_fs.open(self.path, mode=mode)

View file

@ -23,9 +23,10 @@ from bytelatent.model.blt import (
init_embeddings,
patch_ids_from_lengths,
)
from bytelatent.model.transformer import CrossAttention
from bytelatent.model.latent_transformer import CrossAttention
from bytelatent.model.utils import create_causal_mask
from bytelatent.optim import OptimArgs, build_optimizer
from bytelatent.tokenizers.constants import EOS_ID
from bytelatent.train import compute_loss
@ -51,7 +52,7 @@ def batch_to_tensors_and_gpu(batch):
def fake_batch():
batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"))
batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"), weights_only=False)
del batch_dict["x2"]
del batch_dict["y2"]
del batch_dict["src_names"]
@ -98,18 +99,17 @@ def create_args(cross_attention=False):
recompute_attn=False,
custom_bwd=False,
layer_ckpt="none",
efficient_attn="sdpa",
patch_only_encoder=False,
patch_only_decoder=False,
use_local_encoder_transformer=True,
init_use_gaussian=True,
init_use_depth="current",
attn_bias_type="block_causal",
attn_impl="xformers",
alpha_depth="disabled",
max_length=256,
local_attention_window_len=512,
max_seqlen=12288,
downsampling_by_pooling="max",
eos_id=EOS_ID,
)
return transformer_args
@ -341,10 +341,15 @@ class TestByteLatentTransformer:
model = ByteLatentTransformer(args)
assert model is not None
@pytest.mark.parametrize("attn_type", ["fmha", "sdpa"])
def test_blt_transformer_forward(self, attn_type):
@pytest.mark.parametrize("attn_impl", ["sdpa", "xformers"])
def test_blt_transformer_forward(self, attn_impl):
args = create_args()
args = args.model_copy(update=dict(efficient_attn=attn_type))
if attn_impl == "sdpa":
os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "1"
else:
os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "0"
args = args.model_copy(update=dict(attn_impl=attn_impl))
model = ByteLatentTransformer(args)
model = model.cuda()
batch = fake_batch()
@ -393,7 +398,9 @@ class TestByteLatentTransformer:
n_kv_heads=4,
norm_eps=1e-6,
).to("cuda")
mask = create_causal_mask(x.shape[1], "flex_attention", sliding_window=None)
mask = create_causal_mask(
x.shape[1], "flex_attention", None, sliding_window=None
)
output = cross_attention(x, kv, mask)
assert output is not None
assert output.shape == (2, 256, 512)
@ -440,7 +447,7 @@ class TestByteLatentTransformer:
def test_loss_backward(self):
args = create_args()
args = args.model_copy(update=dict(efficient_attn="sdpa"))
args = args.model_copy(update=dict(attn_impl="xformers"))
batch = fake_batch()
model = ByteLatentTransformer(args)
steps = 10

View file

@ -24,6 +24,7 @@ def test_entropy_model():
dataset_files=[ARROW_TEST_DATA],
row_num=0,
arrow_batch_size=100,
s3_profile=None,
)
arrow_file = initial_state.build()
tokenizer_args = TokenizerArgs(
@ -38,7 +39,7 @@ def test_entropy_model():
BLT_DATA,
"entropy_model.pth",
),
)
).cuda()
preprocess_iter = PreprocessIterator(
arrow_file,
tokenizer_args=tokenizer_args,
@ -48,8 +49,10 @@ def test_entropy_model():
for example in preprocess_iter.create_iter():
tokens = torch.tensor(example.tokens).unsqueeze(0)
expected_entropies = torch.tensor(example.entropies).unsqueeze(0)
preds = entropy_model(tokens)
preds = entropy_model(tokens.cuda())
pred_entropies = entropy(preds)
assert pred_entropies.shape == expected_entropies.shape
assert torch.allclose(pred_entropies, expected_entropies, rtol=1.0, atol=3.5)
assert torch.allclose(
pred_entropies.cpu(), expected_entropies, rtol=1.0, atol=3.5
)
break

View file

@ -644,6 +644,10 @@ def main():
cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
train_args = TrainArgs.model_validate(cfg)
if train_args.debug_dynamo:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
train(train_args)

View file

@ -22,23 +22,7 @@ from bytelatent.base_transformer import (
RMSNorm,
cross_entropy,
)
def create_causal_mask(seqlen, attn_impl, sliding_window):
    """Build the causal mask object matching the attention implementation.

    Args:
        seqlen: Sequence length; only used for the ``"flex_attention"`` mask.
        attn_impl: One of ``"xformers"``, ``"sdpa"`` or ``"flex_attention"``.
        sliding_window: Optional window size; only honoured for ``"xformers"``.

    Returns:
        An xformers attention bias, the string ``"causal"`` for sdpa, or the
        result of ``create_block_mask`` for flex attention.

    Raises:
        NotImplementedError: If ``attn_impl`` is not one of the supported values.
    """
    if attn_impl == "xformers":
        if sliding_window is None:
            return fmha.attn_bias.LowerTriangularMask()
        return fmha.attn_bias.LocalAttentionFromBottomRightMask(
            window_left=sliding_window - 1, window_right=0
        )
    if attn_impl == "sdpa":
        return "causal"
    if attn_impl == "flex_attention":
        return create_block_mask(causal_mask, None, None, seqlen, seqlen)
    raise NotImplementedError(
        f"Attention {attn_impl} with {sliding_window} sliding window not implemented"
    )
from bytelatent.model.utils import create_causal_mask
def attention_flops_per_token(n_layers, seq_len, dim, causal):
@ -94,8 +78,10 @@ class LMTransformer(BaseTransformer):
target: Optional[torch.Tensor] = None,
tok_idx: Optional[torch.Tensor] = None,
mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None,
attn_impl: str = "sdpa",
attn_impl: str | None = None,
):
if attn_impl is None:
attn_impl = self.attn_impl
bsz, seqlen = token_values.shape
h = self.tok_embeddings(token_values)
@ -103,7 +89,14 @@ class LMTransformer(BaseTransformer):
mask = (
mask
if mask is not None
else create_causal_mask(seqlen, attn_impl, self.sliding_window)
else create_causal_mask(
seqlen,
attn_impl,
self.attn_bias_type,
sliding_window=self.sliding_window,
tokens=token_values,
eos_id=self.eos_id,
)
)
h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)