Merge 38022ac06e into sapling-pr-archive-EntilZha

commit 020cf16c1b (committed via GitHub)
Author: Pedro Rodriguez
Date:   2025-01-16 13:51:17 -08:00
13 changed files with 341 additions and 133 deletions


@@ -30,6 +30,7 @@ from bytelatent.model.blt import ByteLatentTransformerArgs
 from bytelatent.optim import OptimArgs
 from bytelatent.profiling import ProfilerArgs
 from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
+from bytelatent.transformer import LMTransformerArgs

 logger = logging.getLogger()

@@ -176,6 +177,10 @@ class TrainArgs(BaseModel):
     data: DataloaderArgs = DataloaderArgs()
     optim: OptimArgs = OptimArgs()
     model: ByteLatentTransformerArgs = ByteLatentTransformerArgs()
+    # This is only needed for training the entropy model
+    entropy_model: LMTransformerArgs | None = None
+    # Instead of training main model, train entropy model
+    train_entropy_model: bool = False
     distributed: DistributedArgs = DistributedArgs()
     env: EnvironmentArgs = EnvironmentArgs()
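
The two new TrainArgs fields switch a run between the main BLT model and the entropy model. A minimal sketch of how they might be set; the field names come from the diff above, while the override dict and its merging into a full config are illustrative and omit the rest of the training settings:

    # Illustrative only: route training to the entropy model via the new fields.
    from bytelatent.transformer import LMTransformerArgs

    entropy_overrides = {
        # Train the entropy model instead of the main ByteLatentTransformer
        "train_entropy_model": True,
        # Architecture of the small byte-level LM used as the entropy model
        "entropy_model": LMTransformerArgs(dim=512, n_layers=8),
    }
    # Merge these into the rest of the training config (data, optim, checkpointing, ...)
    # before constructing TrainArgs.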


@@ -1,10 +1,10 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import os
 from enum import Enum
 from typing import Optional, Tuple, Union

 import torch
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 from torch import nn
 from torch.nn import functional as F
 from torch.nn.attention.flex_attention import (

@@ -15,8 +15,12 @@ from torch.nn.attention.flex_attention import (
 from xformers.ops import AttentionBias, fmha

 from bytelatent import probe
+from bytelatent.tokenizers.constants import EOS_ID

-flex_attention_comp = torch.compile(flex_attention)
+if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
+    flex_attention_comp = torch.compile(flex_attention)
+else:
+    flex_attention_comp = None

 class InitStdFactor(Enum):

@@ -27,13 +31,14 @@ class InitStdFactor(Enum):

 class BaseTransformerArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
     dim: int = 512
     n_layers: int = 8
-    head_dim: Optional[int] = None
-    n_heads: Optional[int] = None
-    n_kv_heads: Optional[int] = None
-    ffn_dim_multiplier: Optional[float] = None
+    head_dim: int | None = None
+    n_heads: int | None = None
+    n_kv_heads: int | None = None
+    ffn_dim_multiplier: float | None = None

     multiple_of: int = 256

@@ -41,11 +46,16 @@ class BaseTransformerArgs(BaseModel):
     rope_theta: float = 10000.0

-    init_base_std: Optional[float] = None
+    init_base_std: float | None = None
     init_std_factor: InitStdFactor = InitStdFactor.DISABLED

     max_seqlen: int = 1024

+    attn_impl: str | None = "sdpa"
+    attn_bias_type: str | None = None
+    # Special token config
+    eos_id: int | None = EOS_ID

 def cross_entropy(pred, target, **kwargs):
     return F.nll_loss(

@@ -291,6 +301,18 @@ class RMSNorm(nn.Module):
         torch.nn.init.ones_(self.weight)  # type: ignore

+def _reshape_for_attn_bias(
+    attn_bias: AttentionBias | None,
+    *tensors: torch.Tensor,
+) -> list[torch.Tensor]:
+    to_transform = list(tensors)
+    if isinstance(attn_bias, fmha.attn_bias.BlockDiagonalCausalMask):
+        # could be `view` instead of reshape during training, but for inference
+        # have to reshape due to strides mismatch
+        to_transform = [t.reshape(1, -1, *t.shape[2:]) for t in to_transform]
+    return to_transform

 class Attention(nn.Module):
     def __init__(
         self,

@@ -368,9 +390,17 @@ class Attention(nn.Module):
             output = flex_attention_comp(xq, xk, xv, block_mask=mask)
             output = output.transpose(1, 2).contiguous()  # B H S D -> B S H D

-        elif attn_impl == "fmha":
+        elif attn_impl == "xformers":
             assert mask is None or isinstance(mask, AttentionBias)
+            query_shape = xq.shape
+            print("Before reshape", "xq", xq.shape, "xk", xk.shape, "xv", xv.shape)
+            xq, xk, xv = _reshape_for_attn_bias(mask, xq, xk, xv)
+            print("Before reshape", "xq", xq.shape, "xk", xk.shape, "xv", xv.shape)
             output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask)
+            print("attn out", output.shape, "query_reshape", query_shape)
+            output_original_shape = output.view(query_shape)
+            print("Reshape success")
+            return output_original_shape
             # This uses B S H D instead of B H S D of pytorch

         elif attn_impl == "sdpa":

@@ -542,6 +572,8 @@ class BaseTransformer(nn.Module):
         super().__init__()
         self.dim = args.dim
         self.init_base_std = args.init_base_std
+        self.attn_impl = args.attn_impl
+        self.attn_bias_type = args.attn_bias_type
         self.init_std_factor = InitStdFactor(args.init_std_factor)
         self.max_seqlen = args.max_seqlen
         self.rope_embeddings = RotaryEmbedding(

@@ -549,6 +581,7 @@ class BaseTransformer(nn.Module):
             head_dim=args.head_dim or args.dim // args.n_heads,
             max_seqlen=args.max_seqlen,
         )
+        self.eos_id = args.eos_id

         self.layers = nn.ModuleList()
         for _ in range(args.n_layers):
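
The BLT_ALLOW_MISSING_FLEX_ATTENTION guard above makes the torch.compile of flex_attention optional. A usage sketch; the environment variable name comes from this diff, everything else is assumed:

    # Sketch: allow import on builds/GPUs where flex_attention cannot be compiled.
    # The variable must be set before bytelatent.base_transformer is imported.
    import os
    os.environ["BLT_ALLOW_MISSING_FLEX_ATTENTION"] = "1"

    from bytelatent import base_transformer

    # With the guard active, the compiled kernel is absent, so configs should
    # select another backend, e.g. attn_impl="sdpa" or attn_impl="xformers".
    assert base_transformer.flex_attention_comp is None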


@@ -58,13 +58,13 @@ model:
   recompute_attn: false
   custom_bwd: false
   layer_ckpt: "none"
-  efficient_attn: "sdpa"
   patch_only_encoder: false
   patch_only_decoder: false
   use_local_encoder_transformer: true
   init_use_gaussian: true
   init_use_depth: "current"
   attn_bias_type: "block_causal"
+  attn_impl: "xformers"
   alpha_depth: "disabled"
   max_length: 256
   local_attention_window_len: 512


@@ -27,6 +27,7 @@ def test_basic_arrow_file():
         dataset_files=[ARROW_TEST_DATA_1],
         row_num=0,
         arrow_batch_size=100,
+        s3_profile=None,
     )
     arrow_file = initial_state.build()
     start_state = arrow_file.get_state()

@@ -55,6 +56,7 @@ def test_basic_arrow_file():
         dataset_files=[ARROW_TEST_DATA_1],
         row_num=251,
         arrow_batch_size=100,
+        s3_profile=None,
     )
     arrow_file = resumed_state.build()
     for example in arrow_file.create_iter():

@@ -74,6 +76,7 @@ def test_basic_arrow_file():
         dataset_files=[ARROW_TEST_DATA_1],
         row_num=0,
         arrow_batch_size=100,
+        s3_profile=None,
     )
     arrow_file = rank_state.build()
     expected_ids = []


@@ -1,5 +1,4 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 import atexit
 import contextlib
 import logging

@@ -48,9 +47,13 @@ default_no_recompute_ops = {
     torch.ops.aten._scaled_dot_product_flash_attention.default,
     torch.ops.c10d_functional.reduce_scatter_tensor.default,
     torch.ops.xformers_flash.flash_fwd.default,
-    torch.ops.xformers.efficient_attention_forward_cutlass.default,
 }
+
+if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
+    default_no_recompute_ops.add(
+        torch.ops.xformers.efficient_attention_forward_cutlass.default
+    )

 class DistributedArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")


@@ -15,8 +15,8 @@ from bytelatent.base_transformer import (
     TransformerBlock,
 )
 from bytelatent.data.patcher import Patcher, PatcherArgs
-from bytelatent.model.local_models import LocalDecoder, LocalEncoder
-from bytelatent.model.transformer import GlobalTransformer
+from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs
+from bytelatent.model.global_transformer import GlobalTransformer
 from bytelatent.model.utils import downsample
 from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID

@@ -403,7 +403,6 @@ def patch_ids_from_lengths(patch_lengths, seq_len):
 class ByteLatentTransformerArgs(BaseTransformerArgs):
-    model_config = ConfigDict(extra="forbid")
     # Basic model configuration
     seed: int = 42
     vocab_size: int = -1

@@ -412,7 +411,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     n_heads: int = 8
     # TODO: What is the purpose of this parameter?
     weight_tying: bool = False
-    sliding_window: Optional[int] = None

     # Architecture and dimensions
     dim_token: int = 256

@@ -471,11 +469,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     recompute_attn: bool = True
     custom_bwd: bool = False
     layer_ckpt: str = "all"
-    efficient_attn: str | None = None
-    # Architecture options
-    patch_only_encoder: bool = False
-    patch_only_decoder: bool = False

     # Initialization and attention
     init_use_gaussian: bool = True

@@ -541,9 +534,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     # Logging
     full_logging_n_layers: int = 4

-    # Special token config
-    eos_id: int | None = None

     @model_validator(mode="after")
     def check_hash_byte_sizes(self) -> Self:
         if (

@@ -558,22 +548,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
         return self

-class LocalEncoderArgs(ByteLatentTransformerArgs):
-    # Local encoder specific dimensions
-    n_heads_local_encoder: int = 8
-    dim_token_emb: int | None = None
-    dim_patch_emb: int | None = None
-
-    def __post_init__(self):
-        # Override base args with local encoder specific values
-        self.dim = self.dim_local_encoder
-        self.n_layers = self.n_layers_local_encoder
-        self.n_heads = self.n_heads_local_encoder
-        self.cross_attn_decoder = False
-        self.cross_attn_k = self.cross_attn_k if self.cross_attn_encoder else None
-        self.attn_bias_type = "local_block_causal"

 class GlobalTransformerArgs(ByteLatentTransformerArgs):
     # Global encoder specific dimensions
     dim_token_emb: int | None = None

@@ -625,20 +599,42 @@ def create_global_transformer(args: ByteLatentTransformerArgs) -> GlobalTransformer:
 def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
-    # First deep copy the original args
-    # Replace with local encoder specific values
-    local_encoder_args = args.model_copy(
-        deep=True,
-        update=dict(
-            dim=args.dim_local_encoder,
-            n_layers=args.n_layers_local_encoder,
-            n_heads=args.n_heads_local_encoder,
-            dim_token_emb=get_encoder_dim_token_emb(args),
-            dim_patch_emb=get_encoder_dim_patch_emb(args),
-            cross_attn_decoder=False,
-            cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None,
-            attn_bias_type="local_block_causal",
-        ),
+    local_encoder_args = LocalModelArgs(
+        # Updated args
+        dim=args.dim_local_encoder,
+        n_layers=args.n_layers_local_encoder,
+        n_heads=args.n_heads_local_encoder,
+        dim_token_emb=get_encoder_dim_token_emb(args),
+        dim_patch_emb=get_encoder_dim_patch_emb(args),
+        cross_attn_encoder=args.cross_attn_encoder,
+        cross_attn_decoder=False,
+        cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None,
+        cross_attn_init_by_pooling=args.cross_attn_init_by_pooling,
+        # Defaults
+        head_dim=args.head_dim,
+        max_seqlen=args.max_encoder_seq_length,
+        dropout=args.dropout,
+        vocab_size=args.vocab_size + args.pm_size,
+        norm_eps=args.norm_eps,
+        patch_size=args.patch_size,
+        sliding_window=args.local_attention_window_len,
+        use_rope=args.use_rope,
+        rope_theta=args.rope_theta,
+        init_base_std=args.init_base_std,
+        init_std_factor=args.init_std_factor,
+        n_kv_heads=args.n_kv_heads,
+        attn_impl=args.attn_impl,
+        attn_bias_type="local_block_causal",
+        multiple_of=args.multiple_of,
+        ffn_dim_multiplier=args.ffn_dim_multiplier,
+        patching_mode=args.patching_mode,
+        use_local_encoder_transformer=args.use_local_encoder_transformer,
+        downsampling_by_pooling=args.downsampling_by_pooling,
+        encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
+        cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
+        cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
+        cross_attn_nheads=args.cross_attn_nheads,
+        eos_id=args.eos_id,
     )
     return LocalEncoder(local_encoder_args)

@@ -646,18 +642,41 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
 def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
     # First deep copy the original args
-    local_decoder_args = args.model_copy(
-        deep=True,
-        update=dict(
-            dim=args.dim_local_decoder,
-            n_layers=args.n_layers_local_decoder,
-            n_heads=args.n_heads_local_decoder,
-            cross_attn_encoder=False,
-            cross_attn_init_by_pooling=False,  # states are already defined
-            dim_token_emb=get_decoder_dim_token_emb(args),
-            dim_patch_emb=args.dim_global,
-            cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None,
-        ),
+    local_decoder_args = LocalModelArgs(
+        dim=args.dim_local_decoder,
+        n_layers=args.n_layers_local_decoder,
+        n_heads=args.n_heads_local_decoder,
+        dim_token_emb=get_decoder_dim_token_emb(args),
+        dim_patch_emb=args.dim_global,
+        cross_attn_encoder=False,
+        cross_attn_decoder=args.cross_attn_decoder,
+        cross_attn_init_by_pooling=False,  # states are already defined
+        cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None,
+        # Defaults
+        head_dim=args.head_dim,
+        max_seqlen=args.max_encoder_seq_length,
+        dropout=args.dropout,
+        vocab_size=args.vocab_size + args.pm_size,
+        norm_eps=args.norm_eps,
+        patch_size=args.patch_size,
+        sliding_window=args.local_attention_window_len,
+        use_rope=args.use_rope,
+        rope_theta=args.rope_theta,
+        init_base_std=args.init_base_std,
+        init_std_factor=args.init_std_factor,
+        n_kv_heads=args.n_kv_heads,
+        attn_impl=args.attn_impl,
+        attn_bias_type="local_block_causal",
+        multiple_of=args.multiple_of,
+        ffn_dim_multiplier=args.ffn_dim_multiplier,
+        patching_mode=args.patching_mode,
+        use_local_encoder_transformer=args.use_local_encoder_transformer,
+        downsampling_by_pooling=args.downsampling_by_pooling,
+        encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
+        cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
+        cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
+        cross_attn_nheads=args.cross_attn_nheads,
+        eos_id=args.eos_id,
     )
     return LocalDecoder(local_decoder_args)

@@ -763,7 +782,6 @@ class ByteLatentTransformer(nn.Module):
         # General configuration
         self.weight_tying = args.weight_tying
-        self.sliding_window = args.sliding_window
         self.patch_size = args.patch_size
         self.patching_mode = args.patching_mode
         self.boe_id, self.bos_id, self.pad_id, self.eos_id = (
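
With efficient_attn and sliding_window removed from ByteLatentTransformerArgs, the attention backend is now chosen through the attn_impl and attn_bias_type fields inherited from BaseTransformerArgs. A migration sketch, reusing the create_args helper from the updated tests later in this diff (the helper and surrounding setup are assumptions here, not part of the change):

    # Sketch: old field -> new field when configuring attention.
    from bytelatent.model.blt import ByteLatentTransformer

    args = create_args()  # test helper defined later in this diff
    # Before: args = args.model_copy(update=dict(efficient_attn="sdpa"))
    args = args.model_copy(update=dict(attn_impl="sdpa"))
    model = ByteLatentTransformer(args)
    # For the xformers backend, block-causal masking additionally relies on
    # attn_bias_type="block_causal" and eos_id (see create_args below).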


@@ -11,6 +11,7 @@ from xformers.ops import AttentionBias
 from bytelatent.base_transformer import (
     BaseTransformer,
+    BaseTransformerArgs,
     RMSNorm,
     flex_attention_comp,
     repeat_kv,

@@ -142,11 +143,10 @@ class CrossAttention(nn.Module):
 class GlobalTransformer(BaseTransformer):
-    def __init__(self, args):
+    def __init__(self, args: BaseTransformerArgs):
         super().__init__(args)
         self.dropout = args.dropout
-        self.sliding_window = args.sliding_window
-        self.efficient_attn = args.efficient_attn
+        self.eos_id = args.eos_id

         self.token_embedding_projection = None
         if args.dim_token_emb is not None and args.dim_token_emb != self.dim:

@@ -169,14 +169,19 @@ class GlobalTransformer(BaseTransformer):
         and projection to the token space.
         """
         bs, seqlen = tokens.shape
-        attn_impl = self.efficient_attn
         h = embeds

         mask = (
             mask
             if mask is not None
-            else create_causal_mask(seqlen, attn_impl, self.sliding_window)
+            else create_causal_mask(
+                seqlen,
+                self.attn_impl,
+                self.attn_bias_type,
+                tokens=tokens,
+                eos_id=self.eos_id,
+            )
         )

         if self.token_embedding_projection is not None and h.shape[-1] != self.dim:

@@ -184,7 +189,7 @@ class GlobalTransformer(BaseTransformer):
         h = F.dropout(h, p=self.dropout, training=self.training)

-        h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
+        h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl)

         return h, cache

     def init_weights(self, init_base_std: float):


@@ -1,8 +1,9 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 import logging
-from typing import List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union

+from pydantic import BaseModel, ConfigDict
 import torch
 import torch.nn
 import torch.nn as nn

@@ -16,29 +17,69 @@ from bytelatent.base_transformer import (
     RotaryEmbedding,
     TransformerBlock,
 )
-from bytelatent.model.transformer import CrossAttention
+from bytelatent.model.global_transformer import CrossAttention
 from bytelatent.model.utils import create_causal_mask, downsample
 from bytelatent.tokenizers.blt_tokenizer import BOE_ID

 logger = logging.getLogger()

+class LocalModelArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    # Local encoder specific dimensions
+    head_dim: int | None
+    dim: int
+    dropout: float
+    vocab_size: int
+    patch_size: int
+    sliding_window: int | None
+    use_rope: bool
+    init_base_std: float | None = None
+    init_std_factor: InitStdFactor
+    cross_attn_encoder: bool | None
+    cross_attn_decoder: bool | None
+    cross_attn_k: int | None
+    cross_attn_init_by_pooling: bool
+    norm_eps: float
+    rope_theta: float
+    max_seqlen: int
+    ffn_dim_multiplier: float | None = None
+    patching_mode: str
+    use_local_encoder_transformer: bool
+    downsampling_by_pooling: str | None
+    encoder_hash_byte_group_size: Any | None = None
+    cross_attn_all_layers_encoder: bool = False
+    cross_attn_all_layers_decoder: bool = False
+    cross_attn_nheads: int | None
+    n_layers: int
+    n_heads: int
+    n_kv_heads: int | None = None
+    dim_token_emb: int
+    dim_patch_emb: int | None
+    attn_impl: str | None = "xformers"
+    attn_bias_type: str | None = "local_block_causal"
+    multiple_of: int = 256
+    eos_id: int | None = None

 class LocalModelBase(nn.Module):
-    def __init__(self, args):
+    def __init__(self, args: LocalModelArgs):
         super().__init__()

         self.dim = args.dim
         self.dropout = args.dropout
-        self.vocab_size = args.vocab_size + args.pm_size
+        self.vocab_size = args.vocab_size
         self.patch_size = args.patch_size
-        self.efficient_attn = args.efficient_attn
+        self.attn_impl = args.attn_impl
         self.sliding_window = args.sliding_window
         self.use_rope = args.use_rope
         self.init_std_factor = args.init_std_factor
         self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None)
         self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None)
         self.cross_attn_k = getattr(args, "cross_attn_k", None)
+        self.eos_id = args.eos_id

         self.boe_id = BOE_ID

@@ -54,7 +95,7 @@ class LocalModelBase(nn.Module):
             self.rope = RotaryEmbedding(
                 theta=args.rope_theta,
                 head_dim=args.head_dim or args.dim // args.n_heads,
-                max_seqlen=getattr(args, "max_encoder_seq_length", args.max_length),
+                max_seqlen=args.max_seqlen,
             )
             self.pos_embeddings = None

@@ -66,21 +107,15 @@ class LocalModelBase(nn.Module):
         self.patch_embedding_projection = self._create_patch_projection(args)

-    def _should_create_patch_projection(self, args):
+    def _should_create_patch_projection(self, args: LocalModelArgs):
         dimension_mismatch = (
             getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim
         )

         # Check cross attention conditions
         cross_attn_conditions = (
-            hasattr(args, "cross_attn_encoder")
-            and args.cross_attn_encoder
-            and getattr(args, "cross_attn_init_by_pooling")
-        ) or (
-            hasattr(args, "cross_attn_decoder")
-            and args.cross_attn_decoder
-            and getattr(args, "cross_attn_init_by_pooling")
-        )
+            args.cross_attn_encoder and args.cross_attn_init_by_pooling
+        ) or (args.cross_attn_decoder and args.cross_attn_init_by_pooling)

         return dimension_mismatch or cross_attn_conditions

@@ -172,7 +207,7 @@ class LocalModelBase(nn.Module):
 class LocalEncoder(LocalModelBase):
-    def __init__(self, args):
+    def __init__(self, args: LocalModelArgs):
         super().__init__(args)
         self.output_proj = (
             args.patching_mode in ["entropy", "probmax"]

@@ -180,7 +215,6 @@ class LocalEncoder(LocalModelBase):
         self.apply_transformer = args.use_local_encoder_transformer
         self.downsampling_by_pooling = args.downsampling_by_pooling
-        self.patch_only = args.patch_only_encoder
         self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None
         self.cross_attn_encoder = args.cross_attn_encoder
         self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder

@@ -224,7 +258,14 @@ class LocalEncoder(LocalModelBase):
         """ """
         bs, seqlen = tokens.shape
         if mask is None:
-            mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window)
+            mask = create_causal_mask(
+                seqlen,
+                self.attn_impl,
+                "local_block_causal",
+                sliding_window=self.sliding_window,
+                tokens=tokens,
+                eos_id=self.eos_id,
+            )

         h = self.apply_embedding(tokens, embeds)
         freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None

@@ -232,7 +273,7 @@ class LocalEncoder(LocalModelBase):
         h = F.dropout(h, p=self.dropout, training=self.training)

         for i, layer in enumerate(self.layers):
-            h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn)
+            h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl)
             # check if cross attention should be applied to either all layer or only the last layer
             if self.cross_attn_encoder and (
                 i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder

@@ -273,12 +314,10 @@ class LocalEncoder(LocalModelBase):
 class LocalDecoder(LocalModelBase):
-    def __init__(self, args):
+    def __init__(self, args: LocalModelArgs):
         super().__init__(args)

         # Model configuration flags
-        self.patch_only = args.patch_only_decoder
-        self.expects_embeddings = args.share_encoder_decoder_emb
         self.cross_attn_decoder = args.cross_attn_decoder
         self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder
         self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling

@@ -317,7 +356,14 @@ class LocalDecoder(LocalModelBase):
         assert embeds is not None, "Embeddings must be provided"
         if mask is None:
-            mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window)
+            mask = create_causal_mask(
+                seqlen,
+                self.attn_impl,
+                "local_block_causal",
+                sliding_window=self.sliding_window,
+                tokens=tokens,
+                eos_id=self.eos_id,
+            )

         h = embeds

@@ -347,7 +393,7 @@ class LocalDecoder(LocalModelBase):
                 )
                 h = h + h_cross

-            h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn)
+            h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl)

         h_preds = self.norm(h)
         h_preds = F.dropout(h_preds, p=self.dropout, training=self.training)


@@ -1,8 +1,11 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import logging

 import torch
 from torch.nn.attention.flex_attention import create_block_mask
 from xformers.ops import fmha

+logger = logging.getLogger()

 def patch_reduce(h, max_num_patches, reduction, patch_ids):
     """

@@ -97,14 +100,72 @@ def causal_mask(b, h, q_idx, kv_idx):
     return q_idx >= kv_idx

-def create_causal_mask(seqlen, attn_impl, sliding_window):
-    if sliding_window is not None and attn_impl == "xformers":
-        return fmha.attn_bias.LocalAttentionFromBottomRightMask(
-            window_left=sliding_window - 1, window_right=0
-        )
-    elif attn_impl == "xformers":
-        return fmha.attn_bias.LowerTriangularMask()
+def tokens_to_seqlen(batch: torch.Tensor, eos_id: int):
+    """
+    0 0 0 1 0 0 0 1 0 0 0
+    0 1 0 0 0 1 0 0 0 0 0
+    -> 4 4 3 2 4 5
+    """
+    mask = batch == eos_id
+    mask[:, -1] = True  # virtual eos at the end of each row
+    # 0 0 0 1 0 0 0 1 0 0 X
+    # 0 1 0 0 0 1 0 0 0 0 X
+    row, col = torch.where(mask)
+    # row = 0, 0, 0, 1, 1, 1
+    # col = 3, 7, 10, 1, 5, 10
+    seqlens = (col[1:] - col[:-1]) + (row[1:] - row[:-1]) * mask.shape[1]
+    # seqlens = (4, 3, -9, 4, 5) + (0, 0, 11, 0, 0) = (4, 3, 2, 4, 5)
+    return [int(col[0].item() + 1)] + seqlens.tolist()
+
+
+WARNED_SDPA = False
+
+
+def create_causal_mask(
+    seqlen,
+    attn_impl: str,
+    attn_bias_type: str | None,
+    *,
+    eos_id: int | None = None,
+    tokens: torch.Tensor | None = None,
+    sliding_window: int | None = None,
+):
+    if attn_impl == "xformers":
+        if attn_bias_type is None:
+            return fmha.attn_bias.LowerTriangularMask()
+        elif attn_bias_type == "causal":
+            assert sliding_window is None
+            print("attn: causal")
+            return fmha.attn_bias.LowerTriangularMask()
+        elif attn_bias_type == "block_causal":
+            assert sliding_window is None
+            assert eos_id is not None
+            assert tokens is not None
+            print("attn: block_causal")
+            return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
+                q_seqlen=tokens_to_seqlen(tokens, eos_id)
+            )
+        elif attn_bias_type == "local_block_causal":
+            assert sliding_window is not None
+            assert eos_id is not None
+            assert tokens is not None
+            print("attn: local_block_causal")
+            return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
+                q_seqlen=tokens_to_seqlen(tokens, eos_id)
+            ).make_local_attention(sliding_window)
+        else:
+            return fmha.attn_bias.LocalAttentionFromBottomRightMask(
+                window_left=sliding_window - 1, window_right=0
+            )
     elif attn_impl == "sdpa":
+        global WARNED_SDPA
+        if not WARNED_SDPA:
+            logging.warning(
+                "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention."
+            )
+            WARNED_SDPA = True
         return "causal"
     elif attn_impl == "flex_attention":
         return create_block_mask(causal_mask, None, None, seqlen, seqlen)
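
tokens_to_seqlen converts EOS positions into the per-document lengths that BlockDiagonalCausalMask.from_seqlens expects, and create_causal_mask wires them into the xformers attention bias. A small worked sketch reproducing the docstring example; the import path follows the other files in this diff that pull create_causal_mask from bytelatent.model.utils, and the tensor values are the docstring's own:

    import torch
    from bytelatent.model.utils import create_causal_mask, tokens_to_seqlen

    batch = torch.tensor([
        [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
    ])
    # A virtual EOS is assumed at the end of each row, so the flattened
    # document lengths are [4, 4, 3, 2, 4, 5].
    assert tokens_to_seqlen(batch, eos_id=1) == [4, 4, 3, 2, 4, 5]

    # Block-diagonal causal bias over the two packed rows:
    mask = create_causal_mask(
        batch.shape[1], "xformers", "block_causal", tokens=batch, eos_id=1
    )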


@@ -0,0 +1,38 @@
+import fsspec
+from luigi.target import FileSystem, FileSystemTarget
+
+
+class FSSpecFileSystem(FileSystem):
+    def __init__(self, fs: fsspec.AbstractFileSystem):
+        self.fs = fs
+
+    def exists(self, path):
+        return self.fs.exists()
+
+    def remove(self, path, recursive=True, skip_trash=True):
+        raise NotImplementedError()
+
+    def isdir(self, path):
+        return self.fs.isdir(path)
+
+    def listdir(self, path):
+        return self.fs.ls(path)
+
+
+class FSSpecTarget(FileSystemTarget):
+    def __init__(self, path, fs: fsspec.AbstractFileSystem | None = None):
+        self.path = path
+        if fs is None:
+            self.fsspec_fs = fsspec.filesystem("file")
+        else:
+            self.fsspec_fs = fs
+        self._fs = None
+
+    @property
+    def fs(self):
+        if self._fs is None:
+            self._fs = FSSpecFileSystem(self.fsspec_fs)
+        return self._fs
+
+    def open(self, mode):
+        return self.fs.open(self.path, mode=mode)
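
A construction sketch for the new luigi target; the module path of FSSpecTarget is not shown in this diff, so the import is left implicit and the paths are placeholders:

    # Sketch: FSSpecTarget wraps any fsspec filesystem; with fs=None it falls
    # back to the local "file" filesystem.
    import fsspec

    local_target = FSSpecTarget("/tmp/blt/shard_000.arrow")
    s3_target = FSSpecTarget(
        "s3://example-bucket/shard_000.arrow",  # placeholder bucket/key
        fs=fsspec.filesystem("s3"),
    )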


@@ -23,9 +23,10 @@ from bytelatent.model.blt import (
     init_embeddings,
     patch_ids_from_lengths,
 )
-from bytelatent.model.transformer import CrossAttention
+from bytelatent.model.global_transformer import CrossAttention
 from bytelatent.model.utils import create_causal_mask
 from bytelatent.optim import OptimArgs, build_optimizer
+from bytelatent.tokenizers.constants import EOS_ID
 from bytelatent.train import compute_loss

@@ -51,7 +52,7 @@ def batch_to_tensors_and_gpu(batch):
 def fake_batch():
-    batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"))
+    batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"), weights_only=False)
     del batch_dict["x2"]
     del batch_dict["y2"]
     del batch_dict["src_names"]

@@ -98,18 +99,17 @@ def create_args(cross_attention=False):
         recompute_attn=False,
         custom_bwd=False,
         layer_ckpt="none",
-        efficient_attn="sdpa",
-        patch_only_encoder=False,
-        patch_only_decoder=False,
         use_local_encoder_transformer=True,
         init_use_gaussian=True,
         init_use_depth="current",
         attn_bias_type="block_causal",
+        attn_impl="xformers",
         alpha_depth="disabled",
         max_length=256,
         local_attention_window_len=512,
         max_seqlen=12288,
         downsampling_by_pooling="max",
+        eos_id=EOS_ID,
     )
     return transformer_args

@@ -341,10 +341,10 @@ class TestByteLatentTransformer:
         model = ByteLatentTransformer(args)
         assert model is not None

-    @pytest.mark.parametrize("attn_type", ["fmha", "sdpa"])
-    def test_blt_transformer_forward(self, attn_type):
+    @pytest.mark.parametrize("attn_impl", ["fmha", "sdpa", "xformers"])
+    def test_blt_transformer_forward(self, attn_impl):
         args = create_args()
-        args = args.model_copy(update=dict(efficient_attn=attn_type))
+        args = args.model_copy(update=dict(attn_impl=attn_impl))
         model = ByteLatentTransformer(args)
         model = model.cuda()
         batch = fake_batch()

@@ -393,7 +393,9 @@ class TestByteLatentTransformer:
             n_kv_heads=4,
             norm_eps=1e-6,
         ).to("cuda")
-        mask = create_causal_mask(x.shape[1], "flex_attention", sliding_window=None)
+        mask = create_causal_mask(
+            x.shape[1], "flex_attention", None, sliding_window=None
+        )
         output = cross_attention(x, kv, mask)
         assert output is not None
         assert output.shape == (2, 256, 512)

@@ -440,7 +442,7 @@ class TestByteLatentTransformer:
     def test_loss_backward(self):
         args = create_args()
-        args = args.model_copy(update=dict(efficient_attn="sdpa"))
+        args = args.model_copy(update=dict(attn_impl="sdpa"))
         batch = fake_batch()
         model = ByteLatentTransformer(args)
         steps = 10


@@ -24,6 +24,7 @@ def test_entropy_model():
         dataset_files=[ARROW_TEST_DATA],
         row_num=0,
         arrow_batch_size=100,
+        s3_profile=None,
     )
     arrow_file = initial_state.build()
     tokenizer_args = TokenizerArgs(


@@ -22,23 +22,7 @@ from bytelatent.base_transformer import (
     RMSNorm,
     cross_entropy,
 )
+from bytelatent.model.utils import create_causal_mask

-def create_causal_mask(seqlen, attn_impl, sliding_window):
-    if sliding_window is not None and attn_impl == "xformers":
-        return fmha.attn_bias.LocalAttentionFromBottomRightMask(
-            window_left=sliding_window - 1, window_right=0
-        )
-    elif attn_impl == "xformers":
-        return fmha.attn_bias.LowerTriangularMask()
-    elif attn_impl == "sdpa":
-        return "causal"
-    elif attn_impl == "flex_attention":
-        return create_block_mask(causal_mask, None, None, seqlen, seqlen)
-    else:
-        raise NotImplementedError(
-            f"Attention {attn_impl} with {sliding_window} sliding window not implemented"
-        )

 def attention_flops_per_token(n_layers, seq_len, dim, causal):

@@ -94,8 +78,10 @@ class LMTransformer(BaseTransformer):
         target: Optional[torch.Tensor] = None,
         tok_idx: Optional[torch.Tensor] = None,
         mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None,
-        attn_impl: str = "sdpa",
+        attn_impl: str | None = None,
     ):
+        if attn_impl is None:
+            attn_impl = self.attn_impl
         bsz, seqlen = token_values.shape

         h = self.tok_embeddings(token_values)

@@ -103,7 +89,14 @@ class LMTransformer(BaseTransformer):
         mask = (
             mask
             if mask is not None
-            else create_causal_mask(seqlen, attn_impl, self.sliding_window)
+            else create_causal_mask(
+                seqlen,
+                attn_impl,
+                self.attn_bias_type,
+                sliding_window=self.sliding_window,
+                tokens=token_values,
+                eos_id=self.eos_id,
+            )
         )

         h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)