[WIP] Changes for training the entropy model and correcting attention in local models

Summary:

- Refactor local model configs into a separate, clearer LocalModelArgs
- Add attention arguments (attn_impl, attn_bias_type) and correct which attention implementation the local models use
- Prepare for an entropy-model training script (new entropy_model and train_entropy_model fields on TrainArgs)
- Fix failing unit tests

Test Plan:
Pedro Rodriguez 2025-01-16 21:51:04 +00:00
parent caec8d2621
commit 38022ac06e
12 changed files with 331 additions and 129 deletions

@ -30,6 +30,7 @@ from bytelatent.model.blt import ByteLatentTransformerArgs
from bytelatent.optim import OptimArgs
from bytelatent.profiling import ProfilerArgs
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
from bytelatent.transformer import LMTransformerArgs
logger = logging.getLogger()
@ -176,6 +177,10 @@ class TrainArgs(BaseModel):
data: DataloaderArgs = DataloaderArgs()
optim: OptimArgs = OptimArgs()
model: ByteLatentTransformerArgs = ByteLatentTransformerArgs()
# This is only needed for training the entropy model
entropy_model: LMTransformerArgs | None = None
# Instead of training main model, train entropy model
train_entropy_model: bool = False
distributed: DistributedArgs = DistributedArgs()
env: EnvironmentArgs = EnvironmentArgs()
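
The new fields keep entropy-model training a pure configuration concern: the main ByteLatentTransformerArgs stays under model, while the optional LMTransformerArgs under entropy_model is only needed when train_entropy_model is set. A minimal sketch of how a train script might branch on these fields (the helper name is an assumption, not something in this commit):

def select_model_args(train_args):
    # Mirror the intent of the new flags: train either the main BLT model
    # or the byte-level entropy model, never both in one run.
    if train_args.train_entropy_model:
        assert train_args.entropy_model is not None, (
            "entropy_model (LMTransformerArgs) must be set when "
            "train_entropy_model=True"
        )
        return train_args.entropy_model
    return train_args.model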

@ -4,7 +4,7 @@ from enum import Enum
from typing import Optional, Tuple, Union
import torch
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from torch import nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import (
@ -15,6 +15,7 @@ from torch.nn.attention.flex_attention import (
from xformers.ops import AttentionBias, fmha
from bytelatent import probe
from bytelatent.tokenizers.constants import EOS_ID
if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
flex_attention_comp = torch.compile(flex_attention)
@ -30,13 +31,14 @@ class InitStdFactor(Enum):
class BaseTransformerArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
dim: int = 512
n_layers: int = 8
head_dim: Optional[int] = None
n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None
head_dim: int | None = None
n_heads: int | None = None
n_kv_heads: int | None = None
ffn_dim_multiplier: Optional[float] = None
ffn_dim_multiplier: float | None = None
multiple_of: int = 256
@ -44,11 +46,16 @@ class BaseTransformerArgs(BaseModel):
rope_theta: float = 10000.0
init_base_std: Optional[float] = None
init_base_std: float | None = None
init_std_factor: InitStdFactor = InitStdFactor.DISABLED
max_seqlen: int = 1024
attn_impl: str | None = "sdpa"
attn_bias_type: str | None = None
# Special token config
eos_id: int | None = EOS_ID
def cross_entropy(pred, target, **kwargs):
return F.nll_loss(
@ -294,6 +301,18 @@ class RMSNorm(nn.Module):
torch.nn.init.ones_(self.weight) # type: ignore
def _reshape_for_attn_bias(
attn_bias: AttentionBias | None,
*tensors: torch.Tensor,
) -> list[torch.Tensor]:
to_transform = list(tensors)
if isinstance(attn_bias, fmha.attn_bias.BlockDiagonalCausalMask):
# could be `view` instead of reshape during training, but for inference
# have to reshape due to strides mismatch
to_transform = [t.reshape(1, -1, *t.shape[2:]) for t in to_transform]
return to_transform
class Attention(nn.Module):
def __init__(
self,
@ -371,9 +390,17 @@ class Attention(nn.Module):
output = flex_attention_comp(xq, xk, xv, block_mask=mask)
output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
elif attn_impl == "fmha":
elif attn_impl == "xformers":
assert mask is None or isinstance(mask, AttentionBias)
query_shape = xq.shape
print("Before reshape", "xq", xq.shape, "xk", xk.shape, "xv", xv.shape)
xq, xk, xv = _reshape_for_attn_bias(mask, xq, xk, xv)
print("Before reshape", "xq", xq.shape, "xk", xk.shape, "xv", xv.shape)
output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask)
print("attn out", output.shape, "query_reshape", query_shape)
output_original_shape = output.view(query_shape)
print("Reshape success")
return output_original_shape
# This uses B S H D instead of B H S D of pytorch
elif attn_impl == "sdpa":
@ -545,6 +572,8 @@ class BaseTransformer(nn.Module):
super().__init__()
self.dim = args.dim
self.init_base_std = args.init_base_std
self.attn_impl = args.attn_impl
self.attn_bias_type = args.attn_bias_type
self.init_std_factor = InitStdFactor(args.init_std_factor)
self.max_seqlen = args.max_seqlen
self.rope_embeddings = RotaryEmbedding(
@ -552,6 +581,7 @@ class BaseTransformer(nn.Module):
head_dim=args.head_dim or args.dim // args.n_heads,
max_seqlen=args.max_seqlen,
)
self.eos_id = args.eos_id
self.layers = nn.ModuleList()
for _ in range(args.n_layers):
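
The renamed xformers branch flattens the batch into a single long sequence before calling fmha.memory_efficient_attention, because BlockDiagonalCausalMask encodes document boundaries along the sequence axis and expects a batch dimension of 1. A shape-only sketch of that round trip (pure torch, illustrative sizes):

import torch

B, S, H, D = 2, 6, 4, 16                 # batch, seq, heads, head_dim
xq = torch.randn(B, S, H, D)

query_shape = xq.shape
xq_flat = xq.reshape(1, B * S, H, D)     # what _reshape_for_attn_bias does
assert xq_flat.shape == (1, 12, 4, 16)

out = xq_flat                            # stand-in for the attention output
assert out.view(query_shape).shape == (B, S, H, D)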

@ -58,13 +58,13 @@ model:
recompute_attn: false
custom_bwd: false
layer_ckpt: "none"
efficient_attn: "sdpa"
patch_only_encoder: false
patch_only_decoder: false
use_local_encoder_transformer: true
init_use_gaussian: true
init_use_depth: "current"
attn_bias_type: "block_causal"
attn_impl: "xformers"
alpha_depth: "disabled"
max_length: 256
local_attention_window_len: 512

@ -27,6 +27,7 @@ def test_basic_arrow_file():
dataset_files=[ARROW_TEST_DATA_1],
row_num=0,
arrow_batch_size=100,
s3_profile=None,
)
arrow_file = initial_state.build()
start_state = arrow_file.get_state()
@ -55,6 +56,7 @@ def test_basic_arrow_file():
dataset_files=[ARROW_TEST_DATA_1],
row_num=251,
arrow_batch_size=100,
s3_profile=None,
)
arrow_file = resumed_state.build()
for example in arrow_file.create_iter():
@ -74,6 +76,7 @@ def test_basic_arrow_file():
dataset_files=[ARROW_TEST_DATA_1],
row_num=0,
arrow_batch_size=100,
s3_profile=None,
)
arrow_file = rank_state.build()
expected_ids = []

@ -15,8 +15,8 @@ from bytelatent.base_transformer import (
TransformerBlock,
)
from bytelatent.data.patcher import Patcher, PatcherArgs
from bytelatent.model.local_models import LocalDecoder, LocalEncoder
from bytelatent.model.transformer import GlobalTransformer
from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs
from bytelatent.model.global_transformer import GlobalTransformer
from bytelatent.model.utils import downsample
from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
@ -403,7 +403,6 @@ def patch_ids_from_lengths(patch_lengths, seq_len):
class ByteLatentTransformerArgs(BaseTransformerArgs):
model_config = ConfigDict(extra="forbid")
# Basic model configuration
seed: int = 42
vocab_size: int = -1
@ -412,7 +411,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
n_heads: int = 8
# TODO: What is the purpose of this parameter?
weight_tying: bool = False
sliding_window: Optional[int] = None
# Architecture and dimensions
dim_token: int = 256
@ -471,11 +469,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
recompute_attn: bool = True
custom_bwd: bool = False
layer_ckpt: str = "all"
efficient_attn: str | None = None
# Architecture options
patch_only_encoder: bool = False
patch_only_decoder: bool = False
# Initialization and attention
init_use_gaussian: bool = True
@ -541,9 +534,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
# Logging
full_logging_n_layers: int = 4
# Special token config
eos_id: int | None = None
@model_validator(mode="after")
def check_hash_byte_sizes(self) -> Self:
if (
@ -558,22 +548,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
return self
class LocalEncoderArgs(ByteLatentTransformerArgs):
# Local encoder specific dimensions
n_heads_local_encoder: int = 8
dim_token_emb: int | None = None
dim_patch_emb: int | None = None
def __post_init__(self):
# Override base args with local encoder specific values
self.dim = self.dim_local_encoder
self.n_layers = self.n_layers_local_encoder
self.n_heads = self.n_heads_local_encoder
self.cross_attn_decoder = False
self.cross_attn_k = self.cross_attn_k if self.cross_attn_encoder else None
self.attn_bias_type = "local_block_causal"
class GlobalTransformerArgs(ByteLatentTransformerArgs):
# Global encoder specific dimensions
dim_token_emb: int | None = None
@ -625,20 +599,42 @@ def create_global_transformer(args: ByteLatentTransformerArgs) -> GlobalTransfor
def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
# First deep copy the original args
# Replace with local encoder specific values
local_encoder_args = args.model_copy(
deep=True,
update=dict(
dim=args.dim_local_encoder,
n_layers=args.n_layers_local_encoder,
n_heads=args.n_heads_local_encoder,
dim_token_emb=get_encoder_dim_token_emb(args),
dim_patch_emb=get_encoder_dim_patch_emb(args),
cross_attn_decoder=False,
cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None,
attn_bias_type="local_block_causal",
),
local_encoder_args = LocalModelArgs(
# Updated args
dim=args.dim_local_encoder,
n_layers=args.n_layers_local_encoder,
n_heads=args.n_heads_local_encoder,
dim_token_emb=get_encoder_dim_token_emb(args),
dim_patch_emb=get_encoder_dim_patch_emb(args),
cross_attn_encoder=args.cross_attn_encoder,
cross_attn_decoder=False,
cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None,
cross_attn_init_by_pooling=args.cross_attn_init_by_pooling,
# Defaults
head_dim=args.head_dim,
max_seqlen=args.max_encoder_seq_length,
dropout=args.dropout,
vocab_size=args.vocab_size + args.pm_size,
norm_eps=args.norm_eps,
patch_size=args.patch_size,
sliding_window=args.local_attention_window_len,
use_rope=args.use_rope,
rope_theta=args.rope_theta,
init_base_std=args.init_base_std,
init_std_factor=args.init_std_factor,
n_kv_heads=args.n_kv_heads,
attn_impl=args.attn_impl,
attn_bias_type="local_block_causal",
multiple_of=args.multiple_of,
ffn_dim_multiplier=args.ffn_dim_multiplier,
patching_mode=args.patching_mode,
use_local_encoder_transformer=args.use_local_encoder_transformer,
downsampling_by_pooling=args.downsampling_by_pooling,
encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
cross_attn_nheads=args.cross_attn_nheads,
eos_id=args.eos_id,
)
return LocalEncoder(local_encoder_args)
@ -646,18 +642,41 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
# First deep copy the original args
local_decoder_args = args.model_copy(
deep=True,
update=dict(
dim=args.dim_local_decoder,
n_layers=args.n_layers_local_decoder,
n_heads=args.n_heads_local_decoder,
cross_attn_encoder=False,
cross_attn_init_by_pooling=False, # states are already defined
dim_token_emb=get_decoder_dim_token_emb(args),
dim_patch_emb=args.dim_global,
cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None,
),
local_decoder_args = LocalModelArgs(
dim=args.dim_local_decoder,
n_layers=args.n_layers_local_decoder,
n_heads=args.n_heads_local_decoder,
dim_token_emb=get_decoder_dim_token_emb(args),
dim_patch_emb=args.dim_global,
cross_attn_encoder=False,
cross_attn_decoder=args.cross_attn_decoder,
cross_attn_init_by_pooling=False, # states are already defined
cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None,
# Defaults
head_dim=args.head_dim,
max_seqlen=args.max_encoder_seq_length,
dropout=args.dropout,
vocab_size=args.vocab_size + args.pm_size,
norm_eps=args.norm_eps,
patch_size=args.patch_size,
sliding_window=args.local_attention_window_len,
use_rope=args.use_rope,
rope_theta=args.rope_theta,
init_base_std=args.init_base_std,
init_std_factor=args.init_std_factor,
n_kv_heads=args.n_kv_heads,
attn_impl=args.attn_impl,
attn_bias_type="local_block_causal",
multiple_of=args.multiple_of,
ffn_dim_multiplier=args.ffn_dim_multiplier,
patching_mode=args.patching_mode,
use_local_encoder_transformer=args.use_local_encoder_transformer,
downsampling_by_pooling=args.downsampling_by_pooling,
encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
cross_attn_nheads=args.cross_attn_nheads,
eos_id=args.eos_id,
)
return LocalDecoder(local_decoder_args)
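
Building LocalModelArgs field-by-field, instead of model_copy(update=...) on ByteLatentTransformerArgs, means stale or misspelled fields fail loudly because the config uses extra="forbid". A self-contained sketch of the failure mode this guards against (toy classes, not the real ones):

from pydantic import BaseModel, ConfigDict, ValidationError

class ToyLocalArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    dim: int
    n_layers: int

ok = ToyLocalArgs(dim=512, n_layers=1)

# A leftover field from an older config (e.g. the removed efficient_attn)
# is rejected instead of being silently carried along.
try:
    ToyLocalArgs(dim=512, n_layers=1, efficient_attn="sdpa")
except ValidationError as err:
    print("rejected:", err.errors()[0]["loc"])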
@ -763,7 +782,6 @@ class ByteLatentTransformer(nn.Module):
# General configuration
self.weight_tying = args.weight_tying
self.sliding_window = args.sliding_window
self.patch_size = args.patch_size
self.patching_mode = args.patching_mode
self.boe_id, self.bos_id, self.pad_id, self.eos_id = (

@ -11,6 +11,7 @@ from xformers.ops import AttentionBias
from bytelatent.base_transformer import (
BaseTransformer,
BaseTransformerArgs,
RMSNorm,
flex_attention_comp,
repeat_kv,
@ -142,11 +143,10 @@ class CrossAttention(nn.Module):
class GlobalTransformer(BaseTransformer):
def __init__(self, args):
def __init__(self, args: BaseTransformerArgs):
super().__init__(args)
self.dropout = args.dropout
self.sliding_window = args.sliding_window
self.efficient_attn = args.efficient_attn
self.eos_id = args.eos_id
self.token_embedding_projection = None
if args.dim_token_emb is not None and args.dim_token_emb != self.dim:
@ -169,14 +169,19 @@ class GlobalTransformer(BaseTransformer):
and projection to the token space.
"""
bs, seqlen = tokens.shape
attn_impl = self.efficient_attn
h = embeds
mask = (
mask
if mask is not None
else create_causal_mask(seqlen, attn_impl, self.sliding_window)
else create_causal_mask(
seqlen,
self.attn_impl,
self.attn_bias_type,
tokens=tokens,
eos_id=self.eos_id,
)
)
if self.token_embedding_projection is not None and h.shape[-1] != self.dim:
@ -184,7 +189,7 @@ class GlobalTransformer(BaseTransformer):
h = F.dropout(h, p=self.dropout, training=self.training)
h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl)
return h, cache
def init_weights(self, init_base_std: float):

@ -1,8 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
from typing import List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union
from pydantic import BaseModel, ConfigDict
import torch
import torch.nn
import torch.nn as nn
@ -16,29 +17,69 @@ from bytelatent.base_transformer import (
RotaryEmbedding,
TransformerBlock,
)
from bytelatent.model.transformer import CrossAttention
from bytelatent.model.global_transformer import CrossAttention
from bytelatent.model.utils import create_causal_mask, downsample
from bytelatent.tokenizers.blt_tokenizer import BOE_ID
logger = logging.getLogger()
class LocalModelArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
# Local encoder specific dimensions
head_dim: int | None
dim: int
dropout: float
vocab_size: int
patch_size: int
sliding_window: int | None
use_rope: bool
init_base_std: float | None = None
init_std_factor: InitStdFactor
cross_attn_encoder: bool | None
cross_attn_decoder: bool | None
cross_attn_k: int | None
cross_attn_init_by_pooling: bool
norm_eps: float
rope_theta: float
max_seqlen: int
ffn_dim_multiplier: float | None = None
patching_mode: str
use_local_encoder_transformer: bool
downsampling_by_pooling: str | None
encoder_hash_byte_group_size: Any | None = None
cross_attn_all_layers_encoder: bool = False
cross_attn_all_layers_decoder: bool = False
cross_attn_nheads: int | None
n_layers: int
n_heads: int
n_kv_heads: int | None = None
dim_token_emb: int
dim_patch_emb: int | None
attn_impl: str | None = "xformers"
attn_bias_type: str | None = "local_block_causal"
multiple_of: int = 256
eos_id: int | None = None
class LocalModelBase(nn.Module):
def __init__(self, args):
def __init__(self, args: LocalModelArgs):
super().__init__()
self.dim = args.dim
self.dropout = args.dropout
self.vocab_size = args.vocab_size + args.pm_size
self.vocab_size = args.vocab_size
self.patch_size = args.patch_size
self.efficient_attn = args.efficient_attn
self.attn_impl = args.attn_impl
self.sliding_window = args.sliding_window
self.use_rope = args.use_rope
self.init_std_factor = args.init_std_factor
self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None)
self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None)
self.cross_attn_k = getattr(args, "cross_attn_k", None)
self.eos_id = args.eos_id
self.boe_id = BOE_ID
@ -54,7 +95,7 @@ class LocalModelBase(nn.Module):
self.rope = RotaryEmbedding(
theta=args.rope_theta,
head_dim=args.head_dim or args.dim // args.n_heads,
max_seqlen=getattr(args, "max_encoder_seq_length", args.max_length),
max_seqlen=args.max_seqlen,
)
self.pos_embeddings = None
@ -66,21 +107,15 @@ class LocalModelBase(nn.Module):
self.patch_embedding_projection = self._create_patch_projection(args)
def _should_create_patch_projection(self, args):
def _should_create_patch_projection(self, args: LocalModelArgs):
dimension_mismatch = (
getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim
)
# Check cross attention conditions
cross_attn_conditions = (
hasattr(args, "cross_attn_encoder")
and args.cross_attn_encoder
and getattr(args, "cross_attn_init_by_pooling")
) or (
hasattr(args, "cross_attn_decoder")
and args.cross_attn_decoder
and getattr(args, "cross_attn_init_by_pooling")
)
args.cross_attn_encoder and args.cross_attn_init_by_pooling
) or (args.cross_attn_decoder and args.cross_attn_init_by_pooling)
return dimension_mismatch or cross_attn_conditions
@ -172,7 +207,7 @@ class LocalModelBase(nn.Module):
class LocalEncoder(LocalModelBase):
def __init__(self, args):
def __init__(self, args: LocalModelArgs):
super().__init__(args)
self.output_proj = (
args.patching_mode in ["entropy", "probmax"]
@ -180,7 +215,6 @@ class LocalEncoder(LocalModelBase):
self.apply_transformer = args.use_local_encoder_transformer
self.downsampling_by_pooling = args.downsampling_by_pooling
self.patch_only = args.patch_only_encoder
self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None
self.cross_attn_encoder = args.cross_attn_encoder
self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder
@ -224,7 +258,14 @@ class LocalEncoder(LocalModelBase):
""" """
bs, seqlen = tokens.shape
if mask is None:
mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window)
mask = create_causal_mask(
seqlen,
self.attn_impl,
"local_block_causal",
sliding_window=self.sliding_window,
tokens=tokens,
eos_id=self.eos_id,
)
h = self.apply_embedding(tokens, embeds)
freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
@ -232,7 +273,7 @@ class LocalEncoder(LocalModelBase):
h = F.dropout(h, p=self.dropout, training=self.training)
for i, layer in enumerate(self.layers):
h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn)
h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl)
# check if cross attention should be applied to all layers or only the last layer
if self.cross_attn_encoder and (
i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder
@ -273,12 +314,10 @@ class LocalEncoder(LocalModelBase):
class LocalDecoder(LocalModelBase):
def __init__(self, args):
def __init__(self, args: LocalModelArgs):
super().__init__(args)
# Model configuration flags
self.patch_only = args.patch_only_decoder
self.expects_embeddings = args.share_encoder_decoder_emb
self.cross_attn_decoder = args.cross_attn_decoder
self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder
self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
@ -317,7 +356,14 @@ class LocalDecoder(LocalModelBase):
assert embeds is not None, "Embeddings must be provided"
if mask is None:
mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window)
mask = create_causal_mask(
seqlen,
self.attn_impl,
"local_block_causal",
sliding_window=self.sliding_window,
tokens=tokens,
eos_id=self.eos_id,
)
h = embeds
@ -347,7 +393,7 @@ class LocalDecoder(LocalModelBase):
)
h = h + h_cross
h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn)
h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl)
h_preds = self.norm(h)
h_preds = F.dropout(h_preds, p=self.dropout, training=self.training)
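
Since LocalModelArgs is now a standalone BaseModel rather than a subclass of ByteLatentTransformerArgs, it can be constructed directly; only the fields without defaults need to be supplied. A sketch with illustrative values (not taken from a real run):

from bytelatent.base_transformer import InitStdFactor
from bytelatent.model.local_models import LocalModelArgs

encoder_args = LocalModelArgs(
    dim=512,
    n_layers=1,
    n_heads=8,
    head_dim=None,                 # derived as dim // n_heads when None
    dim_token_emb=512,
    dim_patch_emb=None,
    max_seqlen=1024,
    dropout=0.0,
    vocab_size=260,
    norm_eps=1e-5,
    rope_theta=10000.0,
    use_rope=True,
    init_std_factor=InitStdFactor.DISABLED,
    patch_size=4,
    sliding_window=512,
    patching_mode="entropy",
    use_local_encoder_transformer=True,
    downsampling_by_pooling="max",
    cross_attn_encoder=False,
    cross_attn_decoder=False,
    cross_attn_k=None,
    cross_attn_init_by_pooling=False,
    cross_attn_nheads=None,
    attn_impl="xformers",
    attn_bias_type="local_block_causal",
)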

@ -1,8 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
import torch
from torch.nn.attention.flex_attention import create_block_mask
from xformers.ops import fmha
logger = logging.getLogger()
def patch_reduce(h, max_num_patches, reduction, patch_ids):
"""
@ -97,14 +100,72 @@ def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
def create_causal_mask(seqlen, attn_impl, sliding_window):
if sliding_window is not None and attn_impl == "xformers":
return fmha.attn_bias.LocalAttentionFromBottomRightMask(
window_left=sliding_window - 1, window_right=0
)
elif attn_impl == "xformers":
return fmha.attn_bias.LowerTriangularMask()
def tokens_to_seqlen(batch: torch.Tensor, eos_id: int):
"""
0 0 0 1 0 0 0 1 0 0 0
0 1 0 0 0 1 0 0 0 0 0
-> 4 4 3 2 4 5
"""
mask = batch == eos_id
mask[:, -1] = True # virtual eos at the end of each row
# 0 0 0 1 0 0 0 1 0 0 X
# 0 1 0 0 0 1 0 0 0 0 X
row, col = torch.where(mask)
# row = 0, 0, 0, 1, 1, 1
# col = 3, 7, 10, 1, 5, 10
seqlens = (col[1:] - col[:-1]) + (row[1:] - row[:-1]) * mask.shape[1]
# seqlens = (4, 3, -9, 4, 5) + (0, 0, 11, 0, 0) = (4, 3, 2, 4, 5)
return [int(col[0].item() + 1)] + seqlens.tolist()
WARNED_SDPA = False
def create_causal_mask(
seqlen,
attn_impl: str,
attn_bias_type: str | None,
*,
eos_id: int | None = None,
tokens: torch.Tensor | None = None,
sliding_window: int | None = None,
):
if attn_impl == "xformers":
if attn_bias_type is None:
return fmha.attn_bias.LowerTriangularMask()
elif attn_bias_type == "causal":
assert sliding_window is None
print("attn: causal")
return fmha.attn_bias.LowerTriangularMask()
elif attn_bias_type == "block_causal":
assert sliding_window is None
assert eos_id is not None
assert tokens is not None
print("attn: block_causal")
return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
q_seqlen=tokens_to_seqlen(tokens, eos_id)
)
elif attn_bias_type == "local_block_causal":
assert sliding_window is not None
assert eos_id is not None
assert tokens is not None
print("attn: local_block_causal")
return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
q_seqlen=tokens_to_seqlen(tokens, eos_id)
).make_local_attention(sliding_window)
else:
return fmha.attn_bias.LocalAttentionFromBottomRightMask(
window_left=sliding_window - 1, window_right=0
)
elif attn_impl == "sdpa":
global WARNED_SDPA
if not WARNED_SDPA:
logging.warning(
"SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention."
)
WARNED_SDPA = True
return "causal"
elif attn_impl == "flex_attention":
return create_block_mask(causal_mask, None, None, seqlen, seqlen)
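
tokens_to_seqlen turns a batch of EOS-delimited byte streams into the per-document lengths that BlockDiagonalCausalMask.from_seqlens expects, counting the end of each row as a virtual EOS. A runnable check of the docstring example (assumes xformers is importable, since this module imports fmha at the top):

import torch

from bytelatent.model.utils import create_causal_mask, tokens_to_seqlen

batch = torch.tensor(
    [
        [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
    ]
)
seqlens = tokens_to_seqlen(batch, eos_id=1)
assert seqlens == [4, 4, 3, 2, 4, 5]
assert sum(seqlens) == batch.numel()   # every token lands in exactly one block

# The sdpa path has no specialized kernel and returns the sentinel string instead
assert create_causal_mask(12, "sdpa", None) == "causal"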

@ -0,0 +1,38 @@
import fsspec
from luigi.target import FileSystem, FileSystemTarget
class FSSpecFileSystem(FileSystem):
def __init__(self, fs: fsspec.AbstractFileSystem):
self.fs = fs
def exists(self, path):
return self.fs.exists(path)
def remove(self, path, recursive=True, skip_trash=True):
raise NotImplementedError()
def isdir(self, path):
return self.fs.isdir(path)
def listdir(self, path):
return self.fs.ls(path)
class FSSpecTarget(FileSystemTarget):
def __init__(self, path, fs: fsspec.AbstractFileSystem | None = None):
self.path = path
if fs is None:
self.fsspec_fs = fsspec.filesystem("file")
else:
self.fsspec_fs = fs
self._fs = None
@property
def fs(self):
if self._fs is None:
self._fs = FSSpecFileSystem(self.fsspec_fs)
return self._fs
def open(self, mode):
return self.fsspec_fs.open(self.path, mode=mode)
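
A quick usage sketch of the new target, assuming luigi and fsspec are installed; the module path for the import is not shown in this diff, and the paths below are made up:

import fsspec

# from <new module above> import FSSpecTarget   # module path not shown in this diff

memfs = fsspec.filesystem("memory")
with memfs.open("/entropy/model.json", "w") as f:
    f.write("{}")

target = FSSpecTarget("/entropy/model.json", fs=memfs)
print(target.fs.exists("/entropy/model.json"))   # True, via the luigi FileSystem wrapper
print(target.fs.listdir("/entropy"))             # delegates to fsspec ls()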

@ -23,9 +23,10 @@ from bytelatent.model.blt import (
init_embeddings,
patch_ids_from_lengths,
)
from bytelatent.model.transformer import CrossAttention
from bytelatent.model.global_transformer import CrossAttention
from bytelatent.model.utils import create_causal_mask
from bytelatent.optim import OptimArgs, build_optimizer
from bytelatent.tokenizers.constants import EOS_ID
from bytelatent.train import compute_loss
@ -51,7 +52,7 @@ def batch_to_tensors_and_gpu(batch):
def fake_batch():
batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"))
batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"), weights_only=False)
del batch_dict["x2"]
del batch_dict["y2"]
del batch_dict["src_names"]
@ -98,18 +99,17 @@ def create_args(cross_attention=False):
recompute_attn=False,
custom_bwd=False,
layer_ckpt="none",
efficient_attn="sdpa",
patch_only_encoder=False,
patch_only_decoder=False,
use_local_encoder_transformer=True,
init_use_gaussian=True,
init_use_depth="current",
attn_bias_type="block_causal",
attn_impl="xformers",
alpha_depth="disabled",
max_length=256,
local_attention_window_len=512,
max_seqlen=12288,
downsampling_by_pooling="max",
eos_id=EOS_ID,
)
return transformer_args
@ -341,10 +341,10 @@ class TestByteLatentTransformer:
model = ByteLatentTransformer(args)
assert model is not None
@pytest.mark.parametrize("attn_type", ["fmha", "sdpa"])
def test_blt_transformer_forward(self, attn_type):
@pytest.mark.parametrize("attn_impl", ["fmha", "sdpa", "xformers"])
def test_blt_transformer_forward(self, attn_impl):
args = create_args()
args = args.model_copy(update=dict(efficient_attn=attn_type))
args = args.model_copy(update=dict(attn_impl=attn_impl))
model = ByteLatentTransformer(args)
model = model.cuda()
batch = fake_batch()
@ -393,7 +393,9 @@ class TestByteLatentTransformer:
n_kv_heads=4,
norm_eps=1e-6,
).to("cuda")
mask = create_causal_mask(x.shape[1], "flex_attention", sliding_window=None)
mask = create_causal_mask(
x.shape[1], "flex_attention", None, sliding_window=None
)
output = cross_attention(x, kv, mask)
assert output is not None
assert output.shape == (2, 256, 512)
@ -440,7 +442,7 @@ class TestByteLatentTransformer:
def test_loss_backward(self):
args = create_args()
args = args.model_copy(update=dict(efficient_attn="sdpa"))
args = args.model_copy(update=dict(attn_impl="sdpa"))
batch = fake_batch()
model = ByteLatentTransformer(args)
steps = 10

@ -24,6 +24,7 @@ def test_entropy_model():
dataset_files=[ARROW_TEST_DATA],
row_num=0,
arrow_batch_size=100,
s3_profile=None,
)
arrow_file = initial_state.build()
tokenizer_args = TokenizerArgs(

@ -22,23 +22,7 @@ from bytelatent.base_transformer import (
RMSNorm,
cross_entropy,
)
def create_causal_mask(seqlen, attn_impl, sliding_window):
if sliding_window is not None and attn_impl == "xformers":
return fmha.attn_bias.LocalAttentionFromBottomRightMask(
window_left=sliding_window - 1, window_right=0
)
elif attn_impl == "xformers":
return fmha.attn_bias.LowerTriangularMask()
elif attn_impl == "sdpa":
return "causal"
elif attn_impl == "flex_attention":
return create_block_mask(causal_mask, None, None, seqlen, seqlen)
else:
raise NotImplementedError(
f"Attention {attn_impl} with {sliding_window} sliding window not implemented"
)
from bytelatent.model.utils import create_causal_mask
def attention_flops_per_token(n_layers, seq_len, dim, causal):
@ -94,8 +78,10 @@ class LMTransformer(BaseTransformer):
target: Optional[torch.Tensor] = None,
tok_idx: Optional[torch.Tensor] = None,
mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None,
attn_impl: str = "sdpa",
attn_impl: str | None = None,
):
if attn_impl is None:
attn_impl = self.attn_impl
bsz, seqlen = token_values.shape
h = self.tok_embeddings(token_values)
@ -103,7 +89,14 @@ class LMTransformer(BaseTransformer):
mask = (
mask
if mask is not None
else create_causal_mask(seqlen, attn_impl, self.sliding_window)
else create_causal_mask(
seqlen,
attn_impl,
self.attn_bias_type,
sliding_window=self.sliding_window,
tokens=token_values,
eos_id=self.eos_id,
)
)
h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
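
With attn_impl now defaulting to the value stored on the module, the same mask construction backs the entropy-model (LMTransformer) path. A brief sketch of the xformers masks the new attn_bias_type values produce, calling create_causal_mask directly (requires xformers; the EOS placement is illustrative):

import torch
from xformers.ops import fmha

from bytelatent.model.utils import create_causal_mask
from bytelatent.tokenizers.constants import EOS_ID

tokens = torch.randint(4, 200, (2, 64))
tokens[:, 31] = EOS_ID   # pretend each row holds two documents

# block_causal: attention is confined to each EOS-delimited document
mask = create_causal_mask(
    64, "xformers", "block_causal", tokens=tokens, eos_id=EOS_ID
)
assert isinstance(mask, fmha.attn_bias.BlockDiagonalCausalMask)

# local_block_causal: the same blocks, further limited to a sliding window
local_mask = create_causal_mask(
    64, "xformers", "local_block_causal",
    tokens=tokens, eos_id=EOS_ID, sliding_window=16,
)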