mirror of
https://github.com/facebookresearch/blt.git
synced 2025-01-18 08:27:45 +00:00
Changes for training entropy model and correcting attention in local models (#25)
Summary: - Refactor local model configs to be separate and clearer - Add attention arguments and correct which attention is used in local models - Preparation for being able to have an entropy train script - Fix failing unit tests Test Plan:
This commit is contained in:
parent
caec8d2621
commit
6ffeb66b53
|
@ -30,6 +30,7 @@ from bytelatent.model.blt import ByteLatentTransformerArgs
|
||||||
from bytelatent.optim import OptimArgs
|
from bytelatent.optim import OptimArgs
|
||||||
from bytelatent.profiling import ProfilerArgs
|
from bytelatent.profiling import ProfilerArgs
|
||||||
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
|
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
|
||||||
|
from bytelatent.transformer import LMTransformerArgs
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
@ -163,6 +164,8 @@ class TrainArgs(BaseModel):
|
||||||
|
|
||||||
seed: int = 42
|
seed: int = 42
|
||||||
|
|
||||||
|
debug_dynamo: bool = False
|
||||||
|
|
||||||
# Number of gradient accumulation steps
|
# Number of gradient accumulation steps
|
||||||
# Total batch size is batch_size*grad_acc_steps
|
# Total batch size is batch_size*grad_acc_steps
|
||||||
grad_acc_steps: int = 1
|
grad_acc_steps: int = 1
|
||||||
|
@ -176,6 +179,10 @@ class TrainArgs(BaseModel):
|
||||||
data: DataloaderArgs = DataloaderArgs()
|
data: DataloaderArgs = DataloaderArgs()
|
||||||
optim: OptimArgs = OptimArgs()
|
optim: OptimArgs = OptimArgs()
|
||||||
model: ByteLatentTransformerArgs = ByteLatentTransformerArgs()
|
model: ByteLatentTransformerArgs = ByteLatentTransformerArgs()
|
||||||
|
# This is only needed for training the entropy model
|
||||||
|
entropy_model: LMTransformerArgs | None = None
|
||||||
|
# Instead of training main model, train entropy model
|
||||||
|
train_entropy_model: bool = False
|
||||||
distributed: DistributedArgs = DistributedArgs()
|
distributed: DistributedArgs = DistributedArgs()
|
||||||
env: EnvironmentArgs = EnvironmentArgs()
|
env: EnvironmentArgs = EnvironmentArgs()
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ from enum import Enum
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, ConfigDict
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from torch.nn.attention.flex_attention import (
|
from torch.nn.attention.flex_attention import (
|
||||||
|
@ -15,6 +15,7 @@ from torch.nn.attention.flex_attention import (
|
||||||
from xformers.ops import AttentionBias, fmha
|
from xformers.ops import AttentionBias, fmha
|
||||||
|
|
||||||
from bytelatent import probe
|
from bytelatent import probe
|
||||||
|
from bytelatent.tokenizers.constants import EOS_ID
|
||||||
|
|
||||||
if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
|
if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
|
||||||
flex_attention_comp = torch.compile(flex_attention)
|
flex_attention_comp = torch.compile(flex_attention)
|
||||||
|
@ -30,13 +31,14 @@ class InitStdFactor(Enum):
|
||||||
|
|
||||||
|
|
||||||
class BaseTransformerArgs(BaseModel):
|
class BaseTransformerArgs(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
dim: int = 512
|
dim: int = 512
|
||||||
n_layers: int = 8
|
n_layers: int = 8
|
||||||
head_dim: Optional[int] = None
|
head_dim: int | None = None
|
||||||
n_heads: Optional[int] = None
|
n_heads: int | None = None
|
||||||
n_kv_heads: Optional[int] = None
|
n_kv_heads: int | None = None
|
||||||
|
|
||||||
ffn_dim_multiplier: Optional[float] = None
|
ffn_dim_multiplier: float | None = None
|
||||||
|
|
||||||
multiple_of: int = 256
|
multiple_of: int = 256
|
||||||
|
|
||||||
|
@ -44,11 +46,16 @@ class BaseTransformerArgs(BaseModel):
|
||||||
|
|
||||||
rope_theta: float = 10000.0
|
rope_theta: float = 10000.0
|
||||||
|
|
||||||
init_base_std: Optional[float] = None
|
init_base_std: float | None = None
|
||||||
init_std_factor: InitStdFactor = InitStdFactor.DISABLED
|
init_std_factor: InitStdFactor = InitStdFactor.DISABLED
|
||||||
|
|
||||||
max_seqlen: int = 1024
|
max_seqlen: int = 1024
|
||||||
|
|
||||||
|
attn_impl: str | None = "sdpa"
|
||||||
|
attn_bias_type: str | None = None
|
||||||
|
# Special token config
|
||||||
|
eos_id: int | None = EOS_ID
|
||||||
|
|
||||||
|
|
||||||
def cross_entropy(pred, target, **kwargs):
|
def cross_entropy(pred, target, **kwargs):
|
||||||
return F.nll_loss(
|
return F.nll_loss(
|
||||||
|
@ -294,6 +301,18 @@ class RMSNorm(nn.Module):
|
||||||
torch.nn.init.ones_(self.weight) # type: ignore
|
torch.nn.init.ones_(self.weight) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def _reshape_for_attn_bias(
|
||||||
|
attn_bias: AttentionBias | None,
|
||||||
|
*tensors: torch.Tensor,
|
||||||
|
) -> list[torch.Tensor]:
|
||||||
|
to_transform = list(tensors)
|
||||||
|
if isinstance(attn_bias, fmha.attn_bias.BlockDiagonalCausalMask):
|
||||||
|
# could be `view` instead of reshape during training, but for inference
|
||||||
|
# have to reshape due to strides mismatch
|
||||||
|
to_transform = [t.reshape(1, -1, *t.shape[2:]) for t in to_transform]
|
||||||
|
return to_transform
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
class Attention(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -371,9 +390,12 @@ class Attention(nn.Module):
|
||||||
output = flex_attention_comp(xq, xk, xv, block_mask=mask)
|
output = flex_attention_comp(xq, xk, xv, block_mask=mask)
|
||||||
output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
|
output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
|
||||||
|
|
||||||
elif attn_impl == "fmha":
|
elif attn_impl == "xformers":
|
||||||
assert mask is None or isinstance(mask, AttentionBias)
|
assert mask is None or isinstance(mask, AttentionBias)
|
||||||
|
query_shape = xq.shape
|
||||||
|
xq, xk, xv = _reshape_for_attn_bias(mask, xq, xk, xv)
|
||||||
output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask)
|
output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask)
|
||||||
|
output = output.view(query_shape)
|
||||||
# This uses B S H D instead of B H S D of pytorch
|
# This uses B S H D instead of B H S D of pytorch
|
||||||
|
|
||||||
elif attn_impl == "sdpa":
|
elif attn_impl == "sdpa":
|
||||||
|
@ -522,14 +544,16 @@ class TransformerBlock(nn.Module):
|
||||||
mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
|
mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
|
||||||
attn_impl: str = "sdpa",
|
attn_impl: str = "sdpa",
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
h = x + self.attention(
|
attn_out = self.attention(
|
||||||
self.attention_norm(x),
|
self.attention_norm(x),
|
||||||
freq_cis,
|
freq_cis,
|
||||||
tok_idx=tok_idx,
|
tok_idx=tok_idx,
|
||||||
mask=mask,
|
mask=mask,
|
||||||
attn_impl=attn_impl,
|
attn_impl=attn_impl,
|
||||||
)
|
)
|
||||||
out = h + self.feed_forward(self.ffn_norm(h))
|
h = x + attn_out
|
||||||
|
h_norm = self.ffn_norm(h)
|
||||||
|
out = h + self.feed_forward(h_norm)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def init_weights(self, init_std=None, factor=1.0):
|
def init_weights(self, init_std=None, factor=1.0):
|
||||||
|
@ -545,6 +569,8 @@ class BaseTransformer(nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = args.dim
|
self.dim = args.dim
|
||||||
self.init_base_std = args.init_base_std
|
self.init_base_std = args.init_base_std
|
||||||
|
self.attn_impl = args.attn_impl
|
||||||
|
self.attn_bias_type = args.attn_bias_type
|
||||||
self.init_std_factor = InitStdFactor(args.init_std_factor)
|
self.init_std_factor = InitStdFactor(args.init_std_factor)
|
||||||
self.max_seqlen = args.max_seqlen
|
self.max_seqlen = args.max_seqlen
|
||||||
self.rope_embeddings = RotaryEmbedding(
|
self.rope_embeddings = RotaryEmbedding(
|
||||||
|
@ -552,6 +578,7 @@ class BaseTransformer(nn.Module):
|
||||||
head_dim=args.head_dim or args.dim // args.n_heads,
|
head_dim=args.head_dim or args.dim // args.n_heads,
|
||||||
max_seqlen=args.max_seqlen,
|
max_seqlen=args.max_seqlen,
|
||||||
)
|
)
|
||||||
|
self.eos_id = args.eos_id
|
||||||
|
|
||||||
self.layers = nn.ModuleList()
|
self.layers = nn.ModuleList()
|
||||||
for _ in range(args.n_layers):
|
for _ in range(args.n_layers):
|
||||||
|
|
|
@ -15,7 +15,6 @@ optim:
|
||||||
|
|
||||||
distributed:
|
distributed:
|
||||||
fsdp_type: full_shard
|
fsdp_type: full_shard
|
||||||
compile: true
|
|
||||||
model_dtype: bf16
|
model_dtype: bf16
|
||||||
matmul_allow_tf32: false
|
matmul_allow_tf32: false
|
||||||
selective_activation_checkpointing: false
|
selective_activation_checkpointing: false
|
||||||
|
@ -58,13 +57,13 @@ model:
|
||||||
recompute_attn: false
|
recompute_attn: false
|
||||||
custom_bwd: false
|
custom_bwd: false
|
||||||
layer_ckpt: "none"
|
layer_ckpt: "none"
|
||||||
efficient_attn: "sdpa"
|
|
||||||
patch_only_encoder: false
|
patch_only_encoder: false
|
||||||
patch_only_decoder: false
|
patch_only_decoder: false
|
||||||
use_local_encoder_transformer: true
|
use_local_encoder_transformer: true
|
||||||
init_use_gaussian: true
|
init_use_gaussian: true
|
||||||
init_use_depth: "current"
|
init_use_depth: "current"
|
||||||
attn_bias_type: "block_causal"
|
attn_bias_type: "block_causal"
|
||||||
|
attn_impl: "xformers"
|
||||||
alpha_depth: "disabled"
|
alpha_depth: "disabled"
|
||||||
max_length: 256
|
max_length: 256
|
||||||
local_attention_window_len: 512
|
local_attention_window_len: 512
|
||||||
|
|
|
@ -27,6 +27,7 @@ def test_basic_arrow_file():
|
||||||
dataset_files=[ARROW_TEST_DATA_1],
|
dataset_files=[ARROW_TEST_DATA_1],
|
||||||
row_num=0,
|
row_num=0,
|
||||||
arrow_batch_size=100,
|
arrow_batch_size=100,
|
||||||
|
s3_profile=None,
|
||||||
)
|
)
|
||||||
arrow_file = initial_state.build()
|
arrow_file = initial_state.build()
|
||||||
start_state = arrow_file.get_state()
|
start_state = arrow_file.get_state()
|
||||||
|
@ -55,6 +56,7 @@ def test_basic_arrow_file():
|
||||||
dataset_files=[ARROW_TEST_DATA_1],
|
dataset_files=[ARROW_TEST_DATA_1],
|
||||||
row_num=251,
|
row_num=251,
|
||||||
arrow_batch_size=100,
|
arrow_batch_size=100,
|
||||||
|
s3_profile=None,
|
||||||
)
|
)
|
||||||
arrow_file = resumed_state.build()
|
arrow_file = resumed_state.build()
|
||||||
for example in arrow_file.create_iter():
|
for example in arrow_file.create_iter():
|
||||||
|
@ -74,6 +76,7 @@ def test_basic_arrow_file():
|
||||||
dataset_files=[ARROW_TEST_DATA_1],
|
dataset_files=[ARROW_TEST_DATA_1],
|
||||||
row_num=0,
|
row_num=0,
|
||||||
arrow_batch_size=100,
|
arrow_batch_size=100,
|
||||||
|
s3_profile=None,
|
||||||
)
|
)
|
||||||
arrow_file = rank_state.build()
|
arrow_file = rank_state.build()
|
||||||
expected_ids = []
|
expected_ids = []
|
||||||
|
|
|
@ -11,7 +11,6 @@ import socket
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
from dataclasses import asdict, dataclass
|
|
||||||
from functools import lru_cache, partial, reduce
|
from functools import lru_cache, partial, reduce
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
|
@ -1,12 +1,14 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from bytelatent.transformer import LMTransformer, LMTransformerArgs
|
from bytelatent.transformer import LMTransformer, LMTransformerArgs
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):
|
def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):
|
||||||
with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
|
with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
|
||||||
|
@ -14,6 +16,9 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
|
||||||
|
|
||||||
torch.set_default_dtype(torch.bfloat16)
|
torch.set_default_dtype(torch.bfloat16)
|
||||||
model_params = reloaded["model"]
|
model_params = reloaded["model"]
|
||||||
|
logger.warning(
|
||||||
|
"Update checkpoint to load attn and sliding window args from checkpoint"
|
||||||
|
)
|
||||||
entropy_model = LMTransformer(
|
entropy_model = LMTransformer(
|
||||||
LMTransformerArgs(
|
LMTransformerArgs(
|
||||||
dim=model_params["dim"],
|
dim=model_params["dim"],
|
||||||
|
@ -22,6 +27,9 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
|
||||||
max_seqlen=model_params["max_length"],
|
max_seqlen=model_params["max_length"],
|
||||||
ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
|
ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
|
||||||
vocab_size=model_params["vocab_size"],
|
vocab_size=model_params["vocab_size"],
|
||||||
|
attn_bias_type="local_block_causal",
|
||||||
|
attn_impl="xformers",
|
||||||
|
sliding_window=512,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -15,8 +15,8 @@ from bytelatent.base_transformer import (
|
||||||
TransformerBlock,
|
TransformerBlock,
|
||||||
)
|
)
|
||||||
from bytelatent.data.patcher import Patcher, PatcherArgs
|
from bytelatent.data.patcher import Patcher, PatcherArgs
|
||||||
from bytelatent.model.local_models import LocalDecoder, LocalEncoder
|
from bytelatent.model.latent_transformer import GlobalTransformer
|
||||||
from bytelatent.model.transformer import GlobalTransformer
|
from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs
|
||||||
from bytelatent.model.utils import downsample
|
from bytelatent.model.utils import downsample
|
||||||
from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
|
from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
|
||||||
|
|
||||||
|
@ -403,7 +403,6 @@ def patch_ids_from_lengths(patch_lengths, seq_len):
|
||||||
|
|
||||||
|
|
||||||
class ByteLatentTransformerArgs(BaseTransformerArgs):
|
class ByteLatentTransformerArgs(BaseTransformerArgs):
|
||||||
model_config = ConfigDict(extra="forbid")
|
|
||||||
# Basic model configuration
|
# Basic model configuration
|
||||||
seed: int = 42
|
seed: int = 42
|
||||||
vocab_size: int = -1
|
vocab_size: int = -1
|
||||||
|
@ -412,7 +411,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
|
||||||
n_heads: int = 8
|
n_heads: int = 8
|
||||||
# TODO: What is the purpose of this parameter?
|
# TODO: What is the purpose of this parameter?
|
||||||
weight_tying: bool = False
|
weight_tying: bool = False
|
||||||
sliding_window: Optional[int] = None
|
|
||||||
|
|
||||||
# Architecture and dimensions
|
# Architecture and dimensions
|
||||||
dim_token: int = 256
|
dim_token: int = 256
|
||||||
|
@ -471,11 +469,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
|
||||||
recompute_attn: bool = True
|
recompute_attn: bool = True
|
||||||
custom_bwd: bool = False
|
custom_bwd: bool = False
|
||||||
layer_ckpt: str = "all"
|
layer_ckpt: str = "all"
|
||||||
efficient_attn: str | None = None
|
|
||||||
|
|
||||||
# Architecture options
|
|
||||||
patch_only_encoder: bool = False
|
|
||||||
patch_only_decoder: bool = False
|
|
||||||
|
|
||||||
# Initialization and attention
|
# Initialization and attention
|
||||||
init_use_gaussian: bool = True
|
init_use_gaussian: bool = True
|
||||||
|
@ -541,9 +534,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
|
||||||
# Logging
|
# Logging
|
||||||
full_logging_n_layers: int = 4
|
full_logging_n_layers: int = 4
|
||||||
|
|
||||||
# Special token config
|
|
||||||
eos_id: int | None = None
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_hash_byte_sizes(self) -> Self:
|
def check_hash_byte_sizes(self) -> Self:
|
||||||
if (
|
if (
|
||||||
|
@ -558,22 +548,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
class LocalEncoderArgs(ByteLatentTransformerArgs):
|
|
||||||
# Local encoder specific dimensions
|
|
||||||
n_heads_local_encoder: int = 8
|
|
||||||
dim_token_emb: int | None = None
|
|
||||||
dim_patch_emb: int | None = None
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
# Override base args with local encoder specific values
|
|
||||||
self.dim = self.dim_local_encoder
|
|
||||||
self.n_layers = self.n_layers_local_encoder
|
|
||||||
self.n_heads = self.n_heads_local_encoder
|
|
||||||
self.cross_attn_decoder = False
|
|
||||||
self.cross_attn_k = self.cross_attn_k if self.cross_attn_encoder else None
|
|
||||||
self.attn_bias_type = "local_block_causal"
|
|
||||||
|
|
||||||
|
|
||||||
class GlobalTransformerArgs(ByteLatentTransformerArgs):
|
class GlobalTransformerArgs(ByteLatentTransformerArgs):
|
||||||
# Global encoder specific dimensions
|
# Global encoder specific dimensions
|
||||||
dim_token_emb: int | None = None
|
dim_token_emb: int | None = None
|
||||||
|
@ -625,20 +599,42 @@ def create_global_transformer(args: ByteLatentTransformerArgs) -> GlobalTransfor
|
||||||
|
|
||||||
|
|
||||||
def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
|
def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
|
||||||
# First deep copy the original args
|
local_encoder_args = LocalModelArgs(
|
||||||
# Replace with local encoder specific values
|
# Updated args
|
||||||
local_encoder_args = args.model_copy(
|
dim=args.dim_local_encoder,
|
||||||
deep=True,
|
n_layers=args.n_layers_local_encoder,
|
||||||
update=dict(
|
n_heads=args.n_heads_local_encoder,
|
||||||
dim=args.dim_local_encoder,
|
dim_token_emb=get_encoder_dim_token_emb(args),
|
||||||
n_layers=args.n_layers_local_encoder,
|
dim_patch_emb=get_encoder_dim_patch_emb(args),
|
||||||
n_heads=args.n_heads_local_encoder,
|
cross_attn_encoder=args.cross_attn_encoder,
|
||||||
dim_token_emb=get_encoder_dim_token_emb(args),
|
cross_attn_decoder=False,
|
||||||
dim_patch_emb=get_encoder_dim_patch_emb(args),
|
cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None,
|
||||||
cross_attn_decoder=False,
|
cross_attn_init_by_pooling=args.cross_attn_init_by_pooling,
|
||||||
cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None,
|
# Defaults
|
||||||
attn_bias_type="local_block_causal",
|
head_dim=args.head_dim,
|
||||||
),
|
max_seqlen=args.max_encoder_seq_length,
|
||||||
|
dropout=args.dropout,
|
||||||
|
vocab_size=args.vocab_size + args.pm_size,
|
||||||
|
norm_eps=args.norm_eps,
|
||||||
|
patch_size=args.patch_size,
|
||||||
|
sliding_window=args.local_attention_window_len,
|
||||||
|
use_rope=args.use_rope,
|
||||||
|
rope_theta=args.rope_theta,
|
||||||
|
init_base_std=args.init_base_std,
|
||||||
|
init_std_factor=args.init_std_factor,
|
||||||
|
n_kv_heads=args.n_kv_heads,
|
||||||
|
attn_impl=args.attn_impl,
|
||||||
|
attn_bias_type="local_block_causal",
|
||||||
|
multiple_of=args.multiple_of,
|
||||||
|
ffn_dim_multiplier=args.ffn_dim_multiplier,
|
||||||
|
patching_mode=args.patching_mode,
|
||||||
|
use_local_encoder_transformer=args.use_local_encoder_transformer,
|
||||||
|
downsampling_by_pooling=args.downsampling_by_pooling,
|
||||||
|
encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
|
||||||
|
cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
|
||||||
|
cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
|
||||||
|
cross_attn_nheads=args.cross_attn_nheads,
|
||||||
|
eos_id=args.eos_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return LocalEncoder(local_encoder_args)
|
return LocalEncoder(local_encoder_args)
|
||||||
|
@ -646,18 +642,41 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
|
||||||
|
|
||||||
def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
|
def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
|
||||||
# First deep copy the original args
|
# First deep copy the original args
|
||||||
local_decoder_args = args.model_copy(
|
local_decoder_args = LocalModelArgs(
|
||||||
deep=True,
|
dim=args.dim_local_decoder,
|
||||||
update=dict(
|
n_layers=args.n_layers_local_decoder,
|
||||||
dim=args.dim_local_decoder,
|
n_heads=args.n_heads_local_decoder,
|
||||||
n_layers=args.n_layers_local_decoder,
|
dim_token_emb=get_decoder_dim_token_emb(args),
|
||||||
n_heads=args.n_heads_local_decoder,
|
dim_patch_emb=args.dim_global,
|
||||||
cross_attn_encoder=False,
|
cross_attn_encoder=False,
|
||||||
cross_attn_init_by_pooling=False, # states are already defined
|
cross_attn_decoder=args.cross_attn_decoder,
|
||||||
dim_token_emb=get_decoder_dim_token_emb(args),
|
cross_attn_init_by_pooling=False, # states are already defined
|
||||||
dim_patch_emb=args.dim_global,
|
cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None,
|
||||||
cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None,
|
# Defaults
|
||||||
),
|
head_dim=args.head_dim,
|
||||||
|
max_seqlen=args.max_encoder_seq_length,
|
||||||
|
dropout=args.dropout,
|
||||||
|
vocab_size=args.vocab_size + args.pm_size,
|
||||||
|
norm_eps=args.norm_eps,
|
||||||
|
patch_size=args.patch_size,
|
||||||
|
sliding_window=args.local_attention_window_len,
|
||||||
|
use_rope=args.use_rope,
|
||||||
|
rope_theta=args.rope_theta,
|
||||||
|
init_base_std=args.init_base_std,
|
||||||
|
init_std_factor=args.init_std_factor,
|
||||||
|
n_kv_heads=args.n_kv_heads,
|
||||||
|
attn_impl=args.attn_impl,
|
||||||
|
attn_bias_type="local_block_causal",
|
||||||
|
multiple_of=args.multiple_of,
|
||||||
|
ffn_dim_multiplier=args.ffn_dim_multiplier,
|
||||||
|
patching_mode=args.patching_mode,
|
||||||
|
use_local_encoder_transformer=args.use_local_encoder_transformer,
|
||||||
|
downsampling_by_pooling=args.downsampling_by_pooling,
|
||||||
|
encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
|
||||||
|
cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
|
||||||
|
cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
|
||||||
|
cross_attn_nheads=args.cross_attn_nheads,
|
||||||
|
eos_id=args.eos_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return LocalDecoder(local_decoder_args)
|
return LocalDecoder(local_decoder_args)
|
||||||
|
@ -763,7 +782,6 @@ class ByteLatentTransformer(nn.Module):
|
||||||
|
|
||||||
# General configuration
|
# General configuration
|
||||||
self.weight_tying = args.weight_tying
|
self.weight_tying = args.weight_tying
|
||||||
self.sliding_window = args.sliding_window
|
|
||||||
self.patch_size = args.patch_size
|
self.patch_size = args.patch_size
|
||||||
self.patching_mode = args.patching_mode
|
self.patching_mode = args.patching_mode
|
||||||
self.boe_id, self.bos_id, self.pad_id, self.eos_id = (
|
self.boe_id, self.bos_id, self.pad_id, self.eos_id = (
|
||||||
|
|
|
@ -11,6 +11,7 @@ from xformers.ops import AttentionBias
|
||||||
|
|
||||||
from bytelatent.base_transformer import (
|
from bytelatent.base_transformer import (
|
||||||
BaseTransformer,
|
BaseTransformer,
|
||||||
|
BaseTransformerArgs,
|
||||||
RMSNorm,
|
RMSNorm,
|
||||||
flex_attention_comp,
|
flex_attention_comp,
|
||||||
repeat_kv,
|
repeat_kv,
|
||||||
|
@ -142,11 +143,10 @@ class CrossAttention(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class GlobalTransformer(BaseTransformer):
|
class GlobalTransformer(BaseTransformer):
|
||||||
def __init__(self, args):
|
def __init__(self, args: BaseTransformerArgs):
|
||||||
super().__init__(args)
|
super().__init__(args)
|
||||||
self.dropout = args.dropout
|
self.dropout = args.dropout
|
||||||
self.sliding_window = args.sliding_window
|
self.eos_id = args.eos_id
|
||||||
self.efficient_attn = args.efficient_attn
|
|
||||||
|
|
||||||
self.token_embedding_projection = None
|
self.token_embedding_projection = None
|
||||||
if args.dim_token_emb is not None and args.dim_token_emb != self.dim:
|
if args.dim_token_emb is not None and args.dim_token_emb != self.dim:
|
||||||
|
@ -169,14 +169,19 @@ class GlobalTransformer(BaseTransformer):
|
||||||
and projection to the token space.
|
and projection to the token space.
|
||||||
"""
|
"""
|
||||||
bs, seqlen = tokens.shape
|
bs, seqlen = tokens.shape
|
||||||
attn_impl = self.efficient_attn
|
|
||||||
|
|
||||||
h = embeds
|
h = embeds
|
||||||
|
|
||||||
mask = (
|
mask = (
|
||||||
mask
|
mask
|
||||||
if mask is not None
|
if mask is not None
|
||||||
else create_causal_mask(seqlen, attn_impl, self.sliding_window)
|
else create_causal_mask(
|
||||||
|
seqlen,
|
||||||
|
self.attn_impl,
|
||||||
|
self.attn_bias_type,
|
||||||
|
tokens=tokens,
|
||||||
|
eos_id=self.eos_id,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.token_embedding_projection is not None and h.shape[-1] != self.dim:
|
if self.token_embedding_projection is not None and h.shape[-1] != self.dim:
|
||||||
|
@ -184,7 +189,7 @@ class GlobalTransformer(BaseTransformer):
|
||||||
|
|
||||||
h = F.dropout(h, p=self.dropout, training=self.training)
|
h = F.dropout(h, p=self.dropout, training=self.training)
|
||||||
|
|
||||||
h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
|
h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl)
|
||||||
return h, cache
|
return h, cache
|
||||||
|
|
||||||
def init_weights(self, init_base_std: float):
|
def init_weights(self, init_base_std: float):
|
|
@ -1,44 +1,75 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import Any, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn
|
import torch.nn
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from torch.nn.attention.flex_attention import BlockMask
|
from torch.nn.attention.flex_attention import BlockMask
|
||||||
from xformers.ops import AttentionBias
|
from xformers.ops import AttentionBias
|
||||||
|
|
||||||
from bytelatent.base_transformer import (
|
from bytelatent.base_transformer import (
|
||||||
|
BaseTransformerArgs,
|
||||||
InitStdFactor,
|
InitStdFactor,
|
||||||
RMSNorm,
|
RMSNorm,
|
||||||
RotaryEmbedding,
|
RotaryEmbedding,
|
||||||
TransformerBlock,
|
TransformerBlock,
|
||||||
)
|
)
|
||||||
from bytelatent.model.transformer import CrossAttention
|
from bytelatent.model.latent_transformer import CrossAttention
|
||||||
from bytelatent.model.utils import create_causal_mask, downsample
|
from bytelatent.model.utils import create_causal_mask, downsample
|
||||||
from bytelatent.tokenizers.blt_tokenizer import BOE_ID
|
from bytelatent.tokenizers.blt_tokenizer import BOE_ID
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
class LocalModelArgs(BaseTransformerArgs):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
# Override defaults
|
||||||
|
attn_impl: str | None = "xformers"
|
||||||
|
attn_bias_type: str | None = "local_block_causal"
|
||||||
|
|
||||||
|
# Local encoder specific dimensions
|
||||||
|
dropout: float
|
||||||
|
vocab_size: int
|
||||||
|
patch_size: int
|
||||||
|
sliding_window: int | None
|
||||||
|
use_rope: bool
|
||||||
|
cross_attn_encoder: bool | None
|
||||||
|
cross_attn_decoder: bool | None
|
||||||
|
cross_attn_k: int | None
|
||||||
|
cross_attn_init_by_pooling: bool
|
||||||
|
patching_mode: str
|
||||||
|
use_local_encoder_transformer: bool
|
||||||
|
downsampling_by_pooling: str | None
|
||||||
|
encoder_hash_byte_group_size: Any | None = None
|
||||||
|
cross_attn_all_layers_encoder: bool = False
|
||||||
|
cross_attn_all_layers_decoder: bool = False
|
||||||
|
cross_attn_nheads: int | None
|
||||||
|
|
||||||
|
dim_token_emb: int
|
||||||
|
dim_patch_emb: int | None
|
||||||
|
|
||||||
|
|
||||||
class LocalModelBase(nn.Module):
|
class LocalModelBase(nn.Module):
|
||||||
def __init__(self, args):
|
def __init__(self, args: LocalModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.dim = args.dim
|
self.dim = args.dim
|
||||||
self.dropout = args.dropout
|
self.dropout = args.dropout
|
||||||
self.vocab_size = args.vocab_size + args.pm_size
|
self.vocab_size = args.vocab_size
|
||||||
self.patch_size = args.patch_size
|
self.patch_size = args.patch_size
|
||||||
|
|
||||||
self.efficient_attn = args.efficient_attn
|
self.attn_impl = args.attn_impl
|
||||||
self.sliding_window = args.sliding_window
|
self.sliding_window = args.sliding_window
|
||||||
self.use_rope = args.use_rope
|
self.use_rope = args.use_rope
|
||||||
self.init_std_factor = args.init_std_factor
|
self.init_std_factor = args.init_std_factor
|
||||||
self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None)
|
self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None)
|
||||||
self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None)
|
self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None)
|
||||||
self.cross_attn_k = getattr(args, "cross_attn_k", None)
|
self.cross_attn_k = getattr(args, "cross_attn_k", None)
|
||||||
|
self.eos_id = args.eos_id
|
||||||
|
|
||||||
self.boe_id = BOE_ID
|
self.boe_id = BOE_ID
|
||||||
|
|
||||||
|
@ -54,7 +85,7 @@ class LocalModelBase(nn.Module):
|
||||||
self.rope = RotaryEmbedding(
|
self.rope = RotaryEmbedding(
|
||||||
theta=args.rope_theta,
|
theta=args.rope_theta,
|
||||||
head_dim=args.head_dim or args.dim // args.n_heads,
|
head_dim=args.head_dim or args.dim // args.n_heads,
|
||||||
max_seqlen=getattr(args, "max_encoder_seq_length", args.max_length),
|
max_seqlen=args.max_seqlen,
|
||||||
)
|
)
|
||||||
self.pos_embeddings = None
|
self.pos_embeddings = None
|
||||||
|
|
||||||
|
@ -66,21 +97,15 @@ class LocalModelBase(nn.Module):
|
||||||
|
|
||||||
self.patch_embedding_projection = self._create_patch_projection(args)
|
self.patch_embedding_projection = self._create_patch_projection(args)
|
||||||
|
|
||||||
def _should_create_patch_projection(self, args):
|
def _should_create_patch_projection(self, args: LocalModelArgs):
|
||||||
dimension_mismatch = (
|
dimension_mismatch = (
|
||||||
getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim
|
getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check cross attention conditions
|
# Check cross attention conditions
|
||||||
cross_attn_conditions = (
|
cross_attn_conditions = (
|
||||||
hasattr(args, "cross_attn_encoder")
|
args.cross_attn_encoder and args.cross_attn_init_by_pooling
|
||||||
and args.cross_attn_encoder
|
) or (args.cross_attn_decoder and args.cross_attn_init_by_pooling)
|
||||||
and getattr(args, "cross_attn_init_by_pooling")
|
|
||||||
) or (
|
|
||||||
hasattr(args, "cross_attn_decoder")
|
|
||||||
and args.cross_attn_decoder
|
|
||||||
and getattr(args, "cross_attn_init_by_pooling")
|
|
||||||
)
|
|
||||||
|
|
||||||
return dimension_mismatch or cross_attn_conditions
|
return dimension_mismatch or cross_attn_conditions
|
||||||
|
|
||||||
|
@ -172,7 +197,7 @@ class LocalModelBase(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class LocalEncoder(LocalModelBase):
|
class LocalEncoder(LocalModelBase):
|
||||||
def __init__(self, args):
|
def __init__(self, args: LocalModelArgs):
|
||||||
super().__init__(args)
|
super().__init__(args)
|
||||||
self.output_proj = (
|
self.output_proj = (
|
||||||
args.patching_mode in ["entropy", "probmax"]
|
args.patching_mode in ["entropy", "probmax"]
|
||||||
|
@ -180,7 +205,6 @@ class LocalEncoder(LocalModelBase):
|
||||||
|
|
||||||
self.apply_transformer = args.use_local_encoder_transformer
|
self.apply_transformer = args.use_local_encoder_transformer
|
||||||
self.downsampling_by_pooling = args.downsampling_by_pooling
|
self.downsampling_by_pooling = args.downsampling_by_pooling
|
||||||
self.patch_only = args.patch_only_encoder
|
|
||||||
self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None
|
self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None
|
||||||
self.cross_attn_encoder = args.cross_attn_encoder
|
self.cross_attn_encoder = args.cross_attn_encoder
|
||||||
self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder
|
self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder
|
||||||
|
@ -224,7 +248,14 @@ class LocalEncoder(LocalModelBase):
|
||||||
""" """
|
""" """
|
||||||
bs, seqlen = tokens.shape
|
bs, seqlen = tokens.shape
|
||||||
if mask is None:
|
if mask is None:
|
||||||
mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window)
|
mask = create_causal_mask(
|
||||||
|
seqlen,
|
||||||
|
self.attn_impl,
|
||||||
|
"local_block_causal",
|
||||||
|
sliding_window=self.sliding_window,
|
||||||
|
tokens=tokens,
|
||||||
|
eos_id=self.eos_id,
|
||||||
|
)
|
||||||
|
|
||||||
h = self.apply_embedding(tokens, embeds)
|
h = self.apply_embedding(tokens, embeds)
|
||||||
freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
|
freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
|
||||||
|
@ -232,7 +263,7 @@ class LocalEncoder(LocalModelBase):
|
||||||
h = F.dropout(h, p=self.dropout, training=self.training)
|
h = F.dropout(h, p=self.dropout, training=self.training)
|
||||||
|
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn)
|
h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl)
|
||||||
# check if cross attention should be applied to either all layer or only the last layer
|
# check if cross attention should be applied to either all layer or only the last layer
|
||||||
if self.cross_attn_encoder and (
|
if self.cross_attn_encoder and (
|
||||||
i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder
|
i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder
|
||||||
|
@ -273,12 +304,10 @@ class LocalEncoder(LocalModelBase):
|
||||||
|
|
||||||
|
|
||||||
class LocalDecoder(LocalModelBase):
|
class LocalDecoder(LocalModelBase):
|
||||||
def __init__(self, args):
|
def __init__(self, args: LocalModelArgs):
|
||||||
super().__init__(args)
|
super().__init__(args)
|
||||||
|
|
||||||
# Model configuration flags
|
# Model configuration flags
|
||||||
self.patch_only = args.patch_only_decoder
|
|
||||||
self.expects_embeddings = args.share_encoder_decoder_emb
|
|
||||||
self.cross_attn_decoder = args.cross_attn_decoder
|
self.cross_attn_decoder = args.cross_attn_decoder
|
||||||
self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder
|
self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder
|
||||||
self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
|
self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
|
||||||
|
@ -317,7 +346,14 @@ class LocalDecoder(LocalModelBase):
|
||||||
assert embeds is not None, "Embeddings must be provided"
|
assert embeds is not None, "Embeddings must be provided"
|
||||||
|
|
||||||
if mask is None:
|
if mask is None:
|
||||||
mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window)
|
mask = create_causal_mask(
|
||||||
|
seqlen,
|
||||||
|
self.attn_impl,
|
||||||
|
"local_block_causal",
|
||||||
|
sliding_window=self.sliding_window,
|
||||||
|
tokens=tokens,
|
||||||
|
eos_id=self.eos_id,
|
||||||
|
)
|
||||||
|
|
||||||
h = embeds
|
h = embeds
|
||||||
|
|
||||||
|
@ -347,7 +383,7 @@ class LocalDecoder(LocalModelBase):
|
||||||
)
|
)
|
||||||
h = h + h_cross
|
h = h + h_cross
|
||||||
|
|
||||||
h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn)
|
h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl)
|
||||||
|
|
||||||
h_preds = self.norm(h)
|
h_preds = self.norm(h)
|
||||||
h_preds = F.dropout(h_preds, p=self.dropout, training=self.training)
|
h_preds = F.dropout(h_preds, p=self.dropout, training=self.training)
|
||||||
|
|
|
@ -1,8 +1,13 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.attention.flex_attention import create_block_mask
|
from torch.nn.attention.flex_attention import create_block_mask
|
||||||
from xformers.ops import fmha
|
from xformers.ops import fmha
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
def patch_reduce(h, max_num_patches, reduction, patch_ids):
|
def patch_reduce(h, max_num_patches, reduction, patch_ids):
|
||||||
"""
|
"""
|
||||||
|
@ -97,15 +102,74 @@ def causal_mask(b, h, q_idx, kv_idx):
|
||||||
return q_idx >= kv_idx
|
return q_idx >= kv_idx
|
||||||
|
|
||||||
|
|
||||||
def create_causal_mask(seqlen, attn_impl, sliding_window):
|
def tokens_to_seqlen(batch: torch.Tensor, eos_id: int):
|
||||||
if sliding_window is not None and attn_impl == "xformers":
|
"""
|
||||||
return fmha.attn_bias.LocalAttentionFromBottomRightMask(
|
0 0 0 1 0 0 0 1 0 0 0
|
||||||
window_left=sliding_window - 1, window_right=0
|
0 1 0 0 0 1 0 0 0 0 0
|
||||||
)
|
-> 4 4 3 2 4 5
|
||||||
elif attn_impl == "xformers":
|
"""
|
||||||
return fmha.attn_bias.LowerTriangularMask()
|
mask = batch == eos_id
|
||||||
|
mask[:, -1] = True # virtual eos at the end of each row
|
||||||
|
|
||||||
|
# 0 0 0 1 0 0 0 1 0 0 X
|
||||||
|
# 0 1 0 0 0 1 0 0 0 0 X
|
||||||
|
row, col = torch.where(mask)
|
||||||
|
|
||||||
|
# row = 0, 0, 0, 1, 1, 1
|
||||||
|
# col = 3, 7, 10, 1, 5, 10
|
||||||
|
seqlens = (col[1:] - col[:-1]) + (row[1:] - row[:-1]) * mask.shape[1]
|
||||||
|
# seqlens = (4, 3, -9, 4, 5) + (0, 0, 11, 0, 0) = (4, 3, 2, 4, 5)
|
||||||
|
return [int(col[0].item() + 1)] + seqlens.tolist()
|
||||||
|
|
||||||
|
|
||||||
|
def create_causal_mask(
|
||||||
|
seqlen,
|
||||||
|
attn_impl: str,
|
||||||
|
attn_bias_type: str | None,
|
||||||
|
*,
|
||||||
|
eos_id: int | None = None,
|
||||||
|
tokens: torch.Tensor | None = None,
|
||||||
|
sliding_window: int | None = None,
|
||||||
|
):
|
||||||
|
if attn_impl == "xformers":
|
||||||
|
if attn_bias_type is None:
|
||||||
|
return fmha.attn_bias.LowerTriangularMask()
|
||||||
|
elif attn_bias_type == "causal":
|
||||||
|
assert sliding_window is None
|
||||||
|
return fmha.attn_bias.LowerTriangularMask()
|
||||||
|
elif attn_bias_type == "block_causal":
|
||||||
|
assert sliding_window is None
|
||||||
|
assert eos_id is not None
|
||||||
|
assert tokens is not None
|
||||||
|
return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
|
||||||
|
q_seqlen=tokens_to_seqlen(tokens, eos_id)
|
||||||
|
)
|
||||||
|
elif attn_bias_type == "local_block_causal":
|
||||||
|
assert sliding_window is not None
|
||||||
|
assert eos_id is not None
|
||||||
|
assert tokens is not None
|
||||||
|
return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
|
||||||
|
q_seqlen=tokens_to_seqlen(tokens, eos_id)
|
||||||
|
).make_local_attention(sliding_window)
|
||||||
|
else:
|
||||||
|
return fmha.attn_bias.LocalAttentionFromBottomRightMask(
|
||||||
|
window_left=sliding_window - 1, window_right=0
|
||||||
|
)
|
||||||
elif attn_impl == "sdpa":
|
elif attn_impl == "sdpa":
|
||||||
return "causal"
|
BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0))
|
||||||
|
|
||||||
|
if attn_bias_type == "causal":
|
||||||
|
return "causal"
|
||||||
|
|
||||||
|
if BLT_SUPPRESS_ATTN_ERROR == 1:
|
||||||
|
logging.warning(
|
||||||
|
"SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. Allowing model to run since BLT_SUPPRESS_ATTN_ERROR=1"
|
||||||
|
)
|
||||||
|
return "causal"
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1"
|
||||||
|
)
|
||||||
elif attn_impl == "flex_attention":
|
elif attn_impl == "flex_attention":
|
||||||
return create_block_mask(causal_mask, None, None, seqlen, seqlen)
|
return create_block_mask(causal_mask, None, None, seqlen, seqlen)
|
||||||
elif attn_impl == "fmha":
|
elif attn_impl == "fmha":
|
||||||
|
|
38
bytelatent/preprocess/fsspec_target.py
Normal file
38
bytelatent/preprocess/fsspec_target.py
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
import fsspec
|
||||||
|
from luigi.target import FileSystem, FileSystemTarget
|
||||||
|
|
||||||
|
|
||||||
|
class FSSpecFileSystem(FileSystem):
|
||||||
|
def __init__(self, fs: fsspec.AbstractFileSystem):
|
||||||
|
self.fs = fs
|
||||||
|
|
||||||
|
def exists(self, path):
|
||||||
|
return self.fs.exists()
|
||||||
|
|
||||||
|
def remove(self, path, recursive=True, skip_trash=True):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def isdir(self, path):
|
||||||
|
return self.fs.isdir(path)
|
||||||
|
|
||||||
|
def listdir(self, path):
|
||||||
|
return self.fs.ls(path)
|
||||||
|
|
||||||
|
|
||||||
|
class FSSpecTarget(FileSystemTarget):
|
||||||
|
def __init__(self, path, fs: fsspec.AbstractFileSystem | None = None):
|
||||||
|
self.path = path
|
||||||
|
if fs is None:
|
||||||
|
self.fsspec_fs = fsspec.filesystem("file")
|
||||||
|
else:
|
||||||
|
self.fsspec_fs = fs
|
||||||
|
self._fs = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fs(self):
|
||||||
|
if self._fs is None:
|
||||||
|
self._fs = FSSpecFileSystem(self.fsspec_fs)
|
||||||
|
return self._fs
|
||||||
|
|
||||||
|
def open(self, mode):
|
||||||
|
return self.fs.open(self.path, mode=mode)
|
|
@ -23,9 +23,10 @@ from bytelatent.model.blt import (
|
||||||
init_embeddings,
|
init_embeddings,
|
||||||
patch_ids_from_lengths,
|
patch_ids_from_lengths,
|
||||||
)
|
)
|
||||||
from bytelatent.model.transformer import CrossAttention
|
from bytelatent.model.latent_transformer import CrossAttention
|
||||||
from bytelatent.model.utils import create_causal_mask
|
from bytelatent.model.utils import create_causal_mask
|
||||||
from bytelatent.optim import OptimArgs, build_optimizer
|
from bytelatent.optim import OptimArgs, build_optimizer
|
||||||
|
from bytelatent.tokenizers.constants import EOS_ID
|
||||||
from bytelatent.train import compute_loss
|
from bytelatent.train import compute_loss
|
||||||
|
|
||||||
|
|
||||||
|
@ -51,7 +52,7 @@ def batch_to_tensors_and_gpu(batch):
|
||||||
|
|
||||||
|
|
||||||
def fake_batch():
|
def fake_batch():
|
||||||
batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"))
|
batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"), weights_only=False)
|
||||||
del batch_dict["x2"]
|
del batch_dict["x2"]
|
||||||
del batch_dict["y2"]
|
del batch_dict["y2"]
|
||||||
del batch_dict["src_names"]
|
del batch_dict["src_names"]
|
||||||
|
@ -98,18 +99,17 @@ def create_args(cross_attention=False):
|
||||||
recompute_attn=False,
|
recompute_attn=False,
|
||||||
custom_bwd=False,
|
custom_bwd=False,
|
||||||
layer_ckpt="none",
|
layer_ckpt="none",
|
||||||
efficient_attn="sdpa",
|
|
||||||
patch_only_encoder=False,
|
|
||||||
patch_only_decoder=False,
|
|
||||||
use_local_encoder_transformer=True,
|
use_local_encoder_transformer=True,
|
||||||
init_use_gaussian=True,
|
init_use_gaussian=True,
|
||||||
init_use_depth="current",
|
init_use_depth="current",
|
||||||
attn_bias_type="block_causal",
|
attn_bias_type="block_causal",
|
||||||
|
attn_impl="xformers",
|
||||||
alpha_depth="disabled",
|
alpha_depth="disabled",
|
||||||
max_length=256,
|
max_length=256,
|
||||||
local_attention_window_len=512,
|
local_attention_window_len=512,
|
||||||
max_seqlen=12288,
|
max_seqlen=12288,
|
||||||
downsampling_by_pooling="max",
|
downsampling_by_pooling="max",
|
||||||
|
eos_id=EOS_ID,
|
||||||
)
|
)
|
||||||
return transformer_args
|
return transformer_args
|
||||||
|
|
||||||
|
@ -341,10 +341,15 @@ class TestByteLatentTransformer:
|
||||||
model = ByteLatentTransformer(args)
|
model = ByteLatentTransformer(args)
|
||||||
assert model is not None
|
assert model is not None
|
||||||
|
|
||||||
@pytest.mark.parametrize("attn_type", ["fmha", "sdpa"])
|
@pytest.mark.parametrize("attn_impl", ["sdpa", "xformers"])
|
||||||
def test_blt_transformer_forward(self, attn_type):
|
def test_blt_transformer_forward(self, attn_impl):
|
||||||
args = create_args()
|
args = create_args()
|
||||||
args = args.model_copy(update=dict(efficient_attn=attn_type))
|
if attn_impl == "sdpa":
|
||||||
|
os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "1"
|
||||||
|
else:
|
||||||
|
os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "0"
|
||||||
|
|
||||||
|
args = args.model_copy(update=dict(attn_impl=attn_impl))
|
||||||
model = ByteLatentTransformer(args)
|
model = ByteLatentTransformer(args)
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
batch = fake_batch()
|
batch = fake_batch()
|
||||||
|
@ -393,7 +398,9 @@ class TestByteLatentTransformer:
|
||||||
n_kv_heads=4,
|
n_kv_heads=4,
|
||||||
norm_eps=1e-6,
|
norm_eps=1e-6,
|
||||||
).to("cuda")
|
).to("cuda")
|
||||||
mask = create_causal_mask(x.shape[1], "flex_attention", sliding_window=None)
|
mask = create_causal_mask(
|
||||||
|
x.shape[1], "flex_attention", None, sliding_window=None
|
||||||
|
)
|
||||||
output = cross_attention(x, kv, mask)
|
output = cross_attention(x, kv, mask)
|
||||||
assert output is not None
|
assert output is not None
|
||||||
assert output.shape == (2, 256, 512)
|
assert output.shape == (2, 256, 512)
|
||||||
|
@ -440,7 +447,7 @@ class TestByteLatentTransformer:
|
||||||
|
|
||||||
def test_loss_backward(self):
|
def test_loss_backward(self):
|
||||||
args = create_args()
|
args = create_args()
|
||||||
args = args.model_copy(update=dict(efficient_attn="sdpa"))
|
args = args.model_copy(update=dict(attn_impl="xformers"))
|
||||||
batch = fake_batch()
|
batch = fake_batch()
|
||||||
model = ByteLatentTransformer(args)
|
model = ByteLatentTransformer(args)
|
||||||
steps = 10
|
steps = 10
|
||||||
|
|
|
@ -24,6 +24,7 @@ def test_entropy_model():
|
||||||
dataset_files=[ARROW_TEST_DATA],
|
dataset_files=[ARROW_TEST_DATA],
|
||||||
row_num=0,
|
row_num=0,
|
||||||
arrow_batch_size=100,
|
arrow_batch_size=100,
|
||||||
|
s3_profile=None,
|
||||||
)
|
)
|
||||||
arrow_file = initial_state.build()
|
arrow_file = initial_state.build()
|
||||||
tokenizer_args = TokenizerArgs(
|
tokenizer_args = TokenizerArgs(
|
||||||
|
@ -38,7 +39,7 @@ def test_entropy_model():
|
||||||
BLT_DATA,
|
BLT_DATA,
|
||||||
"entropy_model.pth",
|
"entropy_model.pth",
|
||||||
),
|
),
|
||||||
)
|
).cuda()
|
||||||
preprocess_iter = PreprocessIterator(
|
preprocess_iter = PreprocessIterator(
|
||||||
arrow_file,
|
arrow_file,
|
||||||
tokenizer_args=tokenizer_args,
|
tokenizer_args=tokenizer_args,
|
||||||
|
@ -48,8 +49,10 @@ def test_entropy_model():
|
||||||
for example in preprocess_iter.create_iter():
|
for example in preprocess_iter.create_iter():
|
||||||
tokens = torch.tensor(example.tokens).unsqueeze(0)
|
tokens = torch.tensor(example.tokens).unsqueeze(0)
|
||||||
expected_entropies = torch.tensor(example.entropies).unsqueeze(0)
|
expected_entropies = torch.tensor(example.entropies).unsqueeze(0)
|
||||||
preds = entropy_model(tokens)
|
preds = entropy_model(tokens.cuda())
|
||||||
pred_entropies = entropy(preds)
|
pred_entropies = entropy(preds)
|
||||||
assert pred_entropies.shape == expected_entropies.shape
|
assert pred_entropies.shape == expected_entropies.shape
|
||||||
assert torch.allclose(pred_entropies, expected_entropies, rtol=1.0, atol=3.5)
|
assert torch.allclose(
|
||||||
|
pred_entropies.cpu(), expected_entropies, rtol=1.0, atol=3.5
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
|
|
@ -644,6 +644,10 @@ def main():
|
||||||
cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
|
cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
|
||||||
cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
|
cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
|
||||||
train_args = TrainArgs.model_validate(cfg)
|
train_args = TrainArgs.model_validate(cfg)
|
||||||
|
if train_args.debug_dynamo:
|
||||||
|
import torch._dynamo
|
||||||
|
|
||||||
|
torch._dynamo.config.suppress_errors = True
|
||||||
train(train_args)
|
train(train_args)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -22,23 +22,7 @@ from bytelatent.base_transformer import (
|
||||||
RMSNorm,
|
RMSNorm,
|
||||||
cross_entropy,
|
cross_entropy,
|
||||||
)
|
)
|
||||||
|
from bytelatent.model.utils import create_causal_mask
|
||||||
|
|
||||||
def create_causal_mask(seqlen, attn_impl, sliding_window):
|
|
||||||
if sliding_window is not None and attn_impl == "xformers":
|
|
||||||
return fmha.attn_bias.LocalAttentionFromBottomRightMask(
|
|
||||||
window_left=sliding_window - 1, window_right=0
|
|
||||||
)
|
|
||||||
elif attn_impl == "xformers":
|
|
||||||
return fmha.attn_bias.LowerTriangularMask()
|
|
||||||
elif attn_impl == "sdpa":
|
|
||||||
return "causal"
|
|
||||||
elif attn_impl == "flex_attention":
|
|
||||||
return create_block_mask(causal_mask, None, None, seqlen, seqlen)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"Attention {attn_impl} with {sliding_window} sliding window not implemented"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def attention_flops_per_token(n_layers, seq_len, dim, causal):
|
def attention_flops_per_token(n_layers, seq_len, dim, causal):
|
||||||
|
@ -94,8 +78,10 @@ class LMTransformer(BaseTransformer):
|
||||||
target: Optional[torch.Tensor] = None,
|
target: Optional[torch.Tensor] = None,
|
||||||
tok_idx: Optional[torch.Tensor] = None,
|
tok_idx: Optional[torch.Tensor] = None,
|
||||||
mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None,
|
mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None,
|
||||||
attn_impl: str = "sdpa",
|
attn_impl: str | None = None,
|
||||||
):
|
):
|
||||||
|
if attn_impl is None:
|
||||||
|
attn_impl = self.attn_impl
|
||||||
bsz, seqlen = token_values.shape
|
bsz, seqlen = token_values.shape
|
||||||
|
|
||||||
h = self.tok_embeddings(token_values)
|
h = self.tok_embeddings(token_values)
|
||||||
|
@ -103,7 +89,14 @@ class LMTransformer(BaseTransformer):
|
||||||
mask = (
|
mask = (
|
||||||
mask
|
mask
|
||||||
if mask is not None
|
if mask is not None
|
||||||
else create_causal_mask(seqlen, attn_impl, self.sliding_window)
|
else create_causal_mask(
|
||||||
|
seqlen,
|
||||||
|
attn_impl,
|
||||||
|
self.attn_bias_type,
|
||||||
|
sliding_window=self.sliding_window,
|
||||||
|
tokens=token_values,
|
||||||
|
eos_id=self.eos_id,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
|
h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue