From 374409fa3b90e856f3b4c306c16fff21ac210ca7 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 17 Jan 2025 01:01:29 +0000 Subject: [PATCH] [WIP] Changes for training entropy model and correcting attention in local models Summary: - Refactor local model configs to be separate and clearer - Add attention arguments and correct which attention is used in local models - Preparation for being able to have an entropy train script - Fix failing unit tests Test Plan: --- bytelatent/args.py | 7 + bytelatent/base_transformer.py | 45 ++++-- bytelatent/configs/debug.yaml | 3 +- .../data/iterators/test_arrow_iterator.py | 3 + bytelatent/distributed.py | 1 - bytelatent/model/blt.py | 128 ++++++++++-------- .../{transformer.py => global_transformer.py} | 17 ++- bytelatent/model/local_models.py | 94 +++++++++---- bytelatent/model/utils.py | 73 +++++++++- bytelatent/preprocess/fsspec_target.py | 38 ++++++ bytelatent/test_blt.py | 22 +-- bytelatent/test_entropy_model.py | 1 + bytelatent/train.py | 4 + bytelatent/transformer.py | 31 ++--- 14 files changed, 334 insertions(+), 133 deletions(-) rename bytelatent/model/{transformer.py => global_transformer.py} (93%) create mode 100644 bytelatent/preprocess/fsspec_target.py diff --git a/bytelatent/args.py b/bytelatent/args.py index cfba3bf..b9144c6 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -30,6 +30,7 @@ from bytelatent.model.blt import ByteLatentTransformerArgs from bytelatent.optim import OptimArgs from bytelatent.profiling import ProfilerArgs from bytelatent.tokenizers.build_tokenizer import TokenizerArgs +from bytelatent.transformer import LMTransformerArgs logger = logging.getLogger() @@ -163,6 +164,8 @@ class TrainArgs(BaseModel): seed: int = 42 + debug_dynamo: bool = False + # Number of gradient accumulation steps # Total batch size is batch_size*grad_acc_steps grad_acc_steps: int = 1 @@ -176,6 +179,10 @@ class TrainArgs(BaseModel): data: DataloaderArgs = DataloaderArgs() optim: OptimArgs = OptimArgs() model: ByteLatentTransformerArgs = ByteLatentTransformerArgs() + # This is only needed for training the entropy model + entropy_model: LMTransformerArgs | None = None + # Instead of training main model, train entropy model + train_entropy_model: bool = False distributed: DistributedArgs = DistributedArgs() env: EnvironmentArgs = EnvironmentArgs() diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py index 45cb7c5..dd0cce6 100644 --- a/bytelatent/base_transformer.py +++ b/bytelatent/base_transformer.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Optional, Tuple, Union import torch -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from torch import nn from torch.nn import functional as F from torch.nn.attention.flex_attention import ( @@ -15,6 +15,7 @@ from torch.nn.attention.flex_attention import ( from xformers.ops import AttentionBias, fmha from bytelatent import probe +from bytelatent.tokenizers.constants import EOS_ID if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0: flex_attention_comp = torch.compile(flex_attention) @@ -30,13 +31,14 @@ class InitStdFactor(Enum): class BaseTransformerArgs(BaseModel): + model_config = ConfigDict(extra="forbid") dim: int = 512 n_layers: int = 8 - head_dim: Optional[int] = None - n_heads: Optional[int] = None - n_kv_heads: Optional[int] = None + head_dim: int | None = None + n_heads: int | None = None + n_kv_heads: int | None = None - ffn_dim_multiplier: Optional[float] = None + ffn_dim_multiplier: float | None = None multiple_of: int = 256 @@ -44,11 +46,16 @@ class BaseTransformerArgs(BaseModel): rope_theta: float = 10000.0 - init_base_std: Optional[float] = None + init_base_std: float | None = None init_std_factor: InitStdFactor = InitStdFactor.DISABLED max_seqlen: int = 1024 + attn_impl: str | None = "sdpa" + attn_bias_type: str | None = None + # Special token config + eos_id: int | None = EOS_ID + def cross_entropy(pred, target, **kwargs): return F.nll_loss( @@ -294,6 +301,18 @@ class RMSNorm(nn.Module): torch.nn.init.ones_(self.weight) # type: ignore +def _reshape_for_attn_bias( + attn_bias: AttentionBias | None, + *tensors: torch.Tensor, +) -> list[torch.Tensor]: + to_transform = list(tensors) + if isinstance(attn_bias, fmha.attn_bias.BlockDiagonalCausalMask): + # could be `view` instead of reshape during training, but for inference + # have to reshape due to strides mismatch + to_transform = [t.reshape(1, -1, *t.shape[2:]) for t in to_transform] + return to_transform + + class Attention(nn.Module): def __init__( self, @@ -371,9 +390,12 @@ class Attention(nn.Module): output = flex_attention_comp(xq, xk, xv, block_mask=mask) output = output.transpose(1, 2).contiguous() # B H S D -> B S H D - elif attn_impl == "fmha": + elif attn_impl == "xformers": assert mask is None or isinstance(mask, AttentionBias) + query_shape = xq.shape + xq, xk, xv = _reshape_for_attn_bias(mask, xq, xk, xv) output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask) + output = output.view(query_shape) # This uses B S H D instead of B H S D of pytorch elif attn_impl == "sdpa": @@ -522,14 +544,16 @@ class TransformerBlock(nn.Module): mask: Optional[Union[BlockMask, AttentionBias, str]] = None, attn_impl: str = "sdpa", ) -> torch.Tensor: - h = x + self.attention( + attn_out = self.attention( self.attention_norm(x), freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl, ) - out = h + self.feed_forward(self.ffn_norm(h)) + h = x + attn_out + h_norm = self.ffn_norm(h) + out = h + self.feed_forward(h_norm) return out def init_weights(self, init_std=None, factor=1.0): @@ -545,6 +569,8 @@ class BaseTransformer(nn.Module): super().__init__() self.dim = args.dim self.init_base_std = args.init_base_std + self.attn_impl = args.attn_impl + self.attn_bias_type = args.attn_bias_type self.init_std_factor = InitStdFactor(args.init_std_factor) self.max_seqlen = args.max_seqlen self.rope_embeddings = RotaryEmbedding( @@ -552,6 +578,7 @@ class BaseTransformer(nn.Module): head_dim=args.head_dim or args.dim // args.n_heads, max_seqlen=args.max_seqlen, ) + self.eos_id = args.eos_id self.layers = nn.ModuleList() for _ in range(args.n_layers): diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml index 5f6debb..4ae4459 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -15,7 +15,6 @@ optim: distributed: fsdp_type: full_shard - compile: true model_dtype: bf16 matmul_allow_tf32: false selective_activation_checkpointing: false @@ -58,13 +57,13 @@ model: recompute_attn: false custom_bwd: false layer_ckpt: "none" - efficient_attn: "sdpa" patch_only_encoder: false patch_only_decoder: false use_local_encoder_transformer: true init_use_gaussian: true init_use_depth: "current" attn_bias_type: "block_causal" + attn_impl: "xformers" alpha_depth: "disabled" max_length: 256 local_attention_window_len: 512 diff --git a/bytelatent/data/iterators/test_arrow_iterator.py b/bytelatent/data/iterators/test_arrow_iterator.py index 4266427..fd448eb 100644 --- a/bytelatent/data/iterators/test_arrow_iterator.py +++ b/bytelatent/data/iterators/test_arrow_iterator.py @@ -27,6 +27,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = initial_state.build() start_state = arrow_file.get_state() @@ -55,6 +56,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=251, arrow_batch_size=100, + s3_profile=None, ) arrow_file = resumed_state.build() for example in arrow_file.create_iter(): @@ -74,6 +76,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = rank_state.build() expected_ids = [] diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py index b211858..168cb7c 100644 --- a/bytelatent/distributed.py +++ b/bytelatent/distributed.py @@ -11,7 +11,6 @@ import socket import subprocess import sys import tempfile -from dataclasses import asdict, dataclass from functools import lru_cache, partial, reduce from itertools import chain from typing import List, Optional, Tuple, Union diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py index 9332d19..1d20cfa 100644 --- a/bytelatent/model/blt.py +++ b/bytelatent/model/blt.py @@ -15,8 +15,8 @@ from bytelatent.base_transformer import ( TransformerBlock, ) from bytelatent.data.patcher import Patcher, PatcherArgs -from bytelatent.model.local_models import LocalDecoder, LocalEncoder -from bytelatent.model.transformer import GlobalTransformer +from bytelatent.model.global_transformer import GlobalTransformer +from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs from bytelatent.model.utils import downsample from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID @@ -403,7 +403,6 @@ def patch_ids_from_lengths(patch_lengths, seq_len): class ByteLatentTransformerArgs(BaseTransformerArgs): - model_config = ConfigDict(extra="forbid") # Basic model configuration seed: int = 42 vocab_size: int = -1 @@ -412,7 +411,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): n_heads: int = 8 # TODO: What is the purpose of this parameter? weight_tying: bool = False - sliding_window: Optional[int] = None # Architecture and dimensions dim_token: int = 256 @@ -471,11 +469,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): recompute_attn: bool = True custom_bwd: bool = False layer_ckpt: str = "all" - efficient_attn: str | None = None - - # Architecture options - patch_only_encoder: bool = False - patch_only_decoder: bool = False # Initialization and attention init_use_gaussian: bool = True @@ -541,9 +534,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): # Logging full_logging_n_layers: int = 4 - # Special token config - eos_id: int | None = None - @model_validator(mode="after") def check_hash_byte_sizes(self) -> Self: if ( @@ -558,22 +548,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): return self -class LocalEncoderArgs(ByteLatentTransformerArgs): - # Local encoder specific dimensions - n_heads_local_encoder: int = 8 - dim_token_emb: int | None = None - dim_patch_emb: int | None = None - - def __post_init__(self): - # Override base args with local encoder specific values - self.dim = self.dim_local_encoder - self.n_layers = self.n_layers_local_encoder - self.n_heads = self.n_heads_local_encoder - self.cross_attn_decoder = False - self.cross_attn_k = self.cross_attn_k if self.cross_attn_encoder else None - self.attn_bias_type = "local_block_causal" - - class GlobalTransformerArgs(ByteLatentTransformerArgs): # Global encoder specific dimensions dim_token_emb: int | None = None @@ -625,20 +599,42 @@ def create_global_transformer(args: ByteLatentTransformerArgs) -> GlobalTransfor def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder: - # First deep copy the original args - # Replace with local encoder specific values - local_encoder_args = args.model_copy( - deep=True, - update=dict( - dim=args.dim_local_encoder, - n_layers=args.n_layers_local_encoder, - n_heads=args.n_heads_local_encoder, - dim_token_emb=get_encoder_dim_token_emb(args), - dim_patch_emb=get_encoder_dim_patch_emb(args), - cross_attn_decoder=False, - cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None, - attn_bias_type="local_block_causal", - ), + local_encoder_args = LocalModelArgs( + # Updated args + dim=args.dim_local_encoder, + n_layers=args.n_layers_local_encoder, + n_heads=args.n_heads_local_encoder, + dim_token_emb=get_encoder_dim_token_emb(args), + dim_patch_emb=get_encoder_dim_patch_emb(args), + cross_attn_encoder=args.cross_attn_encoder, + cross_attn_decoder=False, + cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None, + cross_attn_init_by_pooling=args.cross_attn_init_by_pooling, + # Defaults + head_dim=args.head_dim, + max_seqlen=args.max_encoder_seq_length, + dropout=args.dropout, + vocab_size=args.vocab_size + args.pm_size, + norm_eps=args.norm_eps, + patch_size=args.patch_size, + sliding_window=args.local_attention_window_len, + use_rope=args.use_rope, + rope_theta=args.rope_theta, + init_base_std=args.init_base_std, + init_std_factor=args.init_std_factor, + n_kv_heads=args.n_kv_heads, + attn_impl=args.attn_impl, + attn_bias_type="local_block_causal", + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + patching_mode=args.patching_mode, + use_local_encoder_transformer=args.use_local_encoder_transformer, + downsampling_by_pooling=args.downsampling_by_pooling, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, + cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, + cross_attn_nheads=args.cross_attn_nheads, + eos_id=args.eos_id, ) return LocalEncoder(local_encoder_args) @@ -646,18 +642,41 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder: def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder: # First deep copy the original args - local_decoder_args = args.model_copy( - deep=True, - update=dict( - dim=args.dim_local_decoder, - n_layers=args.n_layers_local_decoder, - n_heads=args.n_heads_local_decoder, - cross_attn_encoder=False, - cross_attn_init_by_pooling=False, # states are already defined - dim_token_emb=get_decoder_dim_token_emb(args), - dim_patch_emb=args.dim_global, - cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, - ), + local_decoder_args = LocalModelArgs( + dim=args.dim_local_decoder, + n_layers=args.n_layers_local_decoder, + n_heads=args.n_heads_local_decoder, + dim_token_emb=get_decoder_dim_token_emb(args), + dim_patch_emb=args.dim_global, + cross_attn_encoder=False, + cross_attn_decoder=args.cross_attn_decoder, + cross_attn_init_by_pooling=False, # states are already defined + cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, + # Defaults + head_dim=args.head_dim, + max_seqlen=args.max_encoder_seq_length, + dropout=args.dropout, + vocab_size=args.vocab_size + args.pm_size, + norm_eps=args.norm_eps, + patch_size=args.patch_size, + sliding_window=args.local_attention_window_len, + use_rope=args.use_rope, + rope_theta=args.rope_theta, + init_base_std=args.init_base_std, + init_std_factor=args.init_std_factor, + n_kv_heads=args.n_kv_heads, + attn_impl=args.attn_impl, + attn_bias_type="local_block_causal", + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + patching_mode=args.patching_mode, + use_local_encoder_transformer=args.use_local_encoder_transformer, + downsampling_by_pooling=args.downsampling_by_pooling, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, + cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, + cross_attn_nheads=args.cross_attn_nheads, + eos_id=args.eos_id, ) return LocalDecoder(local_decoder_args) @@ -763,7 +782,6 @@ class ByteLatentTransformer(nn.Module): # General configuration self.weight_tying = args.weight_tying - self.sliding_window = args.sliding_window self.patch_size = args.patch_size self.patching_mode = args.patching_mode self.boe_id, self.bos_id, self.pad_id, self.eos_id = ( diff --git a/bytelatent/model/transformer.py b/bytelatent/model/global_transformer.py similarity index 93% rename from bytelatent/model/transformer.py rename to bytelatent/model/global_transformer.py index 24dc057..21c3f0c 100644 --- a/bytelatent/model/transformer.py +++ b/bytelatent/model/global_transformer.py @@ -11,6 +11,7 @@ from xformers.ops import AttentionBias from bytelatent.base_transformer import ( BaseTransformer, + BaseTransformerArgs, RMSNorm, flex_attention_comp, repeat_kv, @@ -142,11 +143,10 @@ class CrossAttention(nn.Module): class GlobalTransformer(BaseTransformer): - def __init__(self, args): + def __init__(self, args: BaseTransformerArgs): super().__init__(args) self.dropout = args.dropout - self.sliding_window = args.sliding_window - self.efficient_attn = args.efficient_attn + self.eos_id = args.eos_id self.token_embedding_projection = None if args.dim_token_emb is not None and args.dim_token_emb != self.dim: @@ -169,14 +169,19 @@ class GlobalTransformer(BaseTransformer): and projection to the token space. """ bs, seqlen = tokens.shape - attn_impl = self.efficient_attn h = embeds mask = ( mask if mask is not None - else create_causal_mask(seqlen, attn_impl, self.sliding_window) + else create_causal_mask( + seqlen, + self.attn_impl, + self.attn_bias_type, + tokens=tokens, + eos_id=self.eos_id, + ) ) if self.token_embedding_projection is not None and h.shape[-1] != self.dim: @@ -184,7 +189,7 @@ class GlobalTransformer(BaseTransformer): h = F.dropout(h, p=self.dropout, training=self.training) - h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) + h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl) return h, cache def init_weights(self, init_base_std: float): diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py index 8255504..f182780 100644 --- a/bytelatent/model/local_models.py +++ b/bytelatent/model/local_models.py @@ -1,11 +1,12 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. import logging -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import torch import torch.nn import torch.nn as nn +from pydantic import BaseModel, ConfigDict from torch.nn import functional as F from torch.nn.attention.flex_attention import BlockMask from xformers.ops import AttentionBias @@ -16,29 +17,69 @@ from bytelatent.base_transformer import ( RotaryEmbedding, TransformerBlock, ) -from bytelatent.model.transformer import CrossAttention +from bytelatent.model.global_transformer import CrossAttention from bytelatent.model.utils import create_causal_mask, downsample from bytelatent.tokenizers.blt_tokenizer import BOE_ID logger = logging.getLogger() +class LocalModelArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + # Local encoder specific dimensions + head_dim: int | None + dim: int + dropout: float + vocab_size: int + patch_size: int + sliding_window: int | None + use_rope: bool + init_base_std: float | None = None + init_std_factor: InitStdFactor + cross_attn_encoder: bool | None + cross_attn_decoder: bool | None + cross_attn_k: int | None + cross_attn_init_by_pooling: bool + norm_eps: float + rope_theta: float + max_seqlen: int + ffn_dim_multiplier: float | None = None + patching_mode: str + use_local_encoder_transformer: bool + downsampling_by_pooling: str | None + encoder_hash_byte_group_size: Any | None = None + cross_attn_all_layers_encoder: bool = False + cross_attn_all_layers_decoder: bool = False + cross_attn_nheads: int | None + + n_layers: int + n_heads: int + n_kv_heads: int | None = None + dim_token_emb: int + dim_patch_emb: int | None + attn_impl: str | None = "xformers" + attn_bias_type: str | None = "local_block_causal" + multiple_of: int = 256 + eos_id: int | None = None + + class LocalModelBase(nn.Module): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__() self.dim = args.dim self.dropout = args.dropout - self.vocab_size = args.vocab_size + args.pm_size + self.vocab_size = args.vocab_size self.patch_size = args.patch_size - self.efficient_attn = args.efficient_attn + self.attn_impl = args.attn_impl self.sliding_window = args.sliding_window self.use_rope = args.use_rope self.init_std_factor = args.init_std_factor self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None) self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None) self.cross_attn_k = getattr(args, "cross_attn_k", None) + self.eos_id = args.eos_id self.boe_id = BOE_ID @@ -54,7 +95,7 @@ class LocalModelBase(nn.Module): self.rope = RotaryEmbedding( theta=args.rope_theta, head_dim=args.head_dim or args.dim // args.n_heads, - max_seqlen=getattr(args, "max_encoder_seq_length", args.max_length), + max_seqlen=args.max_seqlen, ) self.pos_embeddings = None @@ -66,21 +107,15 @@ class LocalModelBase(nn.Module): self.patch_embedding_projection = self._create_patch_projection(args) - def _should_create_patch_projection(self, args): + def _should_create_patch_projection(self, args: LocalModelArgs): dimension_mismatch = ( getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim ) # Check cross attention conditions cross_attn_conditions = ( - hasattr(args, "cross_attn_encoder") - and args.cross_attn_encoder - and getattr(args, "cross_attn_init_by_pooling") - ) or ( - hasattr(args, "cross_attn_decoder") - and args.cross_attn_decoder - and getattr(args, "cross_attn_init_by_pooling") - ) + args.cross_attn_encoder and args.cross_attn_init_by_pooling + ) or (args.cross_attn_decoder and args.cross_attn_init_by_pooling) return dimension_mismatch or cross_attn_conditions @@ -172,7 +207,7 @@ class LocalModelBase(nn.Module): class LocalEncoder(LocalModelBase): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__(args) self.output_proj = ( args.patching_mode in ["entropy", "probmax"] @@ -180,7 +215,6 @@ class LocalEncoder(LocalModelBase): self.apply_transformer = args.use_local_encoder_transformer self.downsampling_by_pooling = args.downsampling_by_pooling - self.patch_only = args.patch_only_encoder self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None self.cross_attn_encoder = args.cross_attn_encoder self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder @@ -224,7 +258,14 @@ class LocalEncoder(LocalModelBase): """ """ bs, seqlen = tokens.shape if mask is None: - mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window) + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) h = self.apply_embedding(tokens, embeds) freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None @@ -232,7 +273,7 @@ class LocalEncoder(LocalModelBase): h = F.dropout(h, p=self.dropout, training=self.training) for i, layer in enumerate(self.layers): - h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn) + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) # check if cross attention should be applied to either all layer or only the last layer if self.cross_attn_encoder and ( i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder @@ -273,12 +314,10 @@ class LocalEncoder(LocalModelBase): class LocalDecoder(LocalModelBase): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__(args) # Model configuration flags - self.patch_only = args.patch_only_decoder - self.expects_embeddings = args.share_encoder_decoder_emb self.cross_attn_decoder = args.cross_attn_decoder self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling @@ -317,7 +356,14 @@ class LocalDecoder(LocalModelBase): assert embeds is not None, "Embeddings must be provided" if mask is None: - mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window) + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) h = embeds @@ -347,7 +393,7 @@ class LocalDecoder(LocalModelBase): ) h = h + h_cross - h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn) + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) h_preds = self.norm(h) h_preds = F.dropout(h_preds, p=self.dropout, training=self.training) diff --git a/bytelatent/model/utils.py b/bytelatent/model/utils.py index ce52a30..42eb185 100644 --- a/bytelatent/model/utils.py +++ b/bytelatent/model/utils.py @@ -1,8 +1,12 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import logging + import torch from torch.nn.attention.flex_attention import create_block_mask from xformers.ops import fmha +logger = logging.getLogger() + def patch_reduce(h, max_num_patches, reduction, patch_ids): """ @@ -97,14 +101,69 @@ def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx -def create_causal_mask(seqlen, attn_impl, sliding_window): - if sliding_window is not None and attn_impl == "xformers": - return fmha.attn_bias.LocalAttentionFromBottomRightMask( - window_left=sliding_window - 1, window_right=0 - ) - elif attn_impl == "xformers": - return fmha.attn_bias.LowerTriangularMask() +def tokens_to_seqlen(batch: torch.Tensor, eos_id: int): + """ + 0 0 0 1 0 0 0 1 0 0 0 + 0 1 0 0 0 1 0 0 0 0 0 + -> 4 4 3 2 4 5 + """ + mask = batch == eos_id + mask[:, -1] = True # virtual eos at the end of each row + + # 0 0 0 1 0 0 0 1 0 0 X + # 0 1 0 0 0 1 0 0 0 0 X + row, col = torch.where(mask) + + # row = 0, 0, 0, 1, 1, 1 + # col = 3, 7, 10, 1, 5, 10 + seqlens = (col[1:] - col[:-1]) + (row[1:] - row[:-1]) * mask.shape[1] + # seqlens = (4, 3, -9, 4, 5) + (0, 0, 11, 0, 0) = (4, 3, 2, 4, 5) + return [int(col[0].item() + 1)] + seqlens.tolist() + + +WARNED_SDPA = False + + +def create_causal_mask( + seqlen, + attn_impl: str, + attn_bias_type: str | None, + *, + eos_id: int | None = None, + tokens: torch.Tensor | None = None, + sliding_window: int | None = None, +): + if attn_impl == "xformers": + if attn_bias_type is None: + return fmha.attn_bias.LowerTriangularMask() + elif attn_bias_type == "causal": + assert sliding_window is None + return fmha.attn_bias.LowerTriangularMask() + elif attn_bias_type == "block_causal": + assert sliding_window is None + assert eos_id is not None + assert tokens is not None + return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( + q_seqlen=tokens_to_seqlen(tokens, eos_id) + ) + elif attn_bias_type == "local_block_causal": + assert sliding_window is not None + assert eos_id is not None + assert tokens is not None + return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( + q_seqlen=tokens_to_seqlen(tokens, eos_id) + ).make_local_attention(sliding_window) + else: + return fmha.attn_bias.LocalAttentionFromBottomRightMask( + window_left=sliding_window - 1, window_right=0 + ) elif attn_impl == "sdpa": + global WARNED_SDPA + if not WARNED_SDPA: + logging.warning( + "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention." + ) + WARNED_SDPA = True return "causal" elif attn_impl == "flex_attention": return create_block_mask(causal_mask, None, None, seqlen, seqlen) diff --git a/bytelatent/preprocess/fsspec_target.py b/bytelatent/preprocess/fsspec_target.py new file mode 100644 index 0000000..eacb101 --- /dev/null +++ b/bytelatent/preprocess/fsspec_target.py @@ -0,0 +1,38 @@ +import fsspec +from luigi.target import FileSystem, FileSystemTarget + + +class FSSpecFileSystem(FileSystem): + def __init__(self, fs: fsspec.AbstractFileSystem): + self.fs = fs + + def exists(self, path): + return self.fs.exists() + + def remove(self, path, recursive=True, skip_trash=True): + raise NotImplementedError() + + def isdir(self, path): + return self.fs.isdir(path) + + def listdir(self, path): + return self.fs.ls(path) + + +class FSSpecTarget(FileSystemTarget): + def __init__(self, path, fs: fsspec.AbstractFileSystem | None = None): + self.path = path + if fs is None: + self.fsspec_fs = fsspec.filesystem("file") + else: + self.fsspec_fs = fs + self._fs = None + + @property + def fs(self): + if self._fs is None: + self._fs = FSSpecFileSystem(self.fsspec_fs) + return self._fs + + def open(self, mode): + return self.fs.open(self.path, mode=mode) diff --git a/bytelatent/test_blt.py b/bytelatent/test_blt.py index 73ad9f7..4d8e9c7 100644 --- a/bytelatent/test_blt.py +++ b/bytelatent/test_blt.py @@ -23,9 +23,10 @@ from bytelatent.model.blt import ( init_embeddings, patch_ids_from_lengths, ) -from bytelatent.model.transformer import CrossAttention +from bytelatent.model.global_transformer import CrossAttention from bytelatent.model.utils import create_causal_mask from bytelatent.optim import OptimArgs, build_optimizer +from bytelatent.tokenizers.constants import EOS_ID from bytelatent.train import compute_loss @@ -51,7 +52,7 @@ def batch_to_tensors_and_gpu(batch): def fake_batch(): - batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt")) + batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"), weights_only=False) del batch_dict["x2"] del batch_dict["y2"] del batch_dict["src_names"] @@ -98,18 +99,17 @@ def create_args(cross_attention=False): recompute_attn=False, custom_bwd=False, layer_ckpt="none", - efficient_attn="sdpa", - patch_only_encoder=False, - patch_only_decoder=False, use_local_encoder_transformer=True, init_use_gaussian=True, init_use_depth="current", attn_bias_type="block_causal", + attn_impl="xformers", alpha_depth="disabled", max_length=256, local_attention_window_len=512, max_seqlen=12288, downsampling_by_pooling="max", + eos_id=EOS_ID, ) return transformer_args @@ -341,10 +341,10 @@ class TestByteLatentTransformer: model = ByteLatentTransformer(args) assert model is not None - @pytest.mark.parametrize("attn_type", ["fmha", "sdpa"]) - def test_blt_transformer_forward(self, attn_type): + @pytest.mark.parametrize("attn_impl", ["sdpa", "xformers"]) + def test_blt_transformer_forward(self, attn_impl): args = create_args() - args = args.model_copy(update=dict(efficient_attn=attn_type)) + args = args.model_copy(update=dict(attn_impl=attn_impl)) model = ByteLatentTransformer(args) model = model.cuda() batch = fake_batch() @@ -393,7 +393,9 @@ class TestByteLatentTransformer: n_kv_heads=4, norm_eps=1e-6, ).to("cuda") - mask = create_causal_mask(x.shape[1], "flex_attention", sliding_window=None) + mask = create_causal_mask( + x.shape[1], "flex_attention", None, sliding_window=None + ) output = cross_attention(x, kv, mask) assert output is not None assert output.shape == (2, 256, 512) @@ -440,7 +442,7 @@ class TestByteLatentTransformer: def test_loss_backward(self): args = create_args() - args = args.model_copy(update=dict(efficient_attn="sdpa")) + args = args.model_copy(update=dict(attn_impl="sdpa")) batch = fake_batch() model = ByteLatentTransformer(args) steps = 10 diff --git a/bytelatent/test_entropy_model.py b/bytelatent/test_entropy_model.py index 3acc42d..af81638 100644 --- a/bytelatent/test_entropy_model.py +++ b/bytelatent/test_entropy_model.py @@ -24,6 +24,7 @@ def test_entropy_model(): dataset_files=[ARROW_TEST_DATA], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = initial_state.build() tokenizer_args = TokenizerArgs( diff --git a/bytelatent/train.py b/bytelatent/train.py index 6cb13b9..80bd393 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -644,6 +644,10 @@ def main(): cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) train_args = TrainArgs.model_validate(cfg) + if train_args.debug_dynamo: + import torch._dynamo + + torch._dynamo.config.suppress_errors = True train(train_args) diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py index 432f7df..92c5ff5 100644 --- a/bytelatent/transformer.py +++ b/bytelatent/transformer.py @@ -22,23 +22,7 @@ from bytelatent.base_transformer import ( RMSNorm, cross_entropy, ) - - -def create_causal_mask(seqlen, attn_impl, sliding_window): - if sliding_window is not None and attn_impl == "xformers": - return fmha.attn_bias.LocalAttentionFromBottomRightMask( - window_left=sliding_window - 1, window_right=0 - ) - elif attn_impl == "xformers": - return fmha.attn_bias.LowerTriangularMask() - elif attn_impl == "sdpa": - return "causal" - elif attn_impl == "flex_attention": - return create_block_mask(causal_mask, None, None, seqlen, seqlen) - else: - raise NotImplementedError( - f"Attention {attn_impl} with {sliding_window} sliding window not implemented" - ) +from bytelatent.model.utils import create_causal_mask def attention_flops_per_token(n_layers, seq_len, dim, causal): @@ -94,8 +78,10 @@ class LMTransformer(BaseTransformer): target: Optional[torch.Tensor] = None, tok_idx: Optional[torch.Tensor] = None, mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None, - attn_impl: str = "sdpa", + attn_impl: str | None = None, ): + if attn_impl is None: + attn_impl = self.attn_impl bsz, seqlen = token_values.shape h = self.tok_embeddings(token_values) @@ -103,7 +89,14 @@ class LMTransformer(BaseTransformer): mask = ( mask if mask is not None - else create_causal_mask(seqlen, attn_impl, self.sliding_window) + else create_causal_mask( + seqlen, + attn_impl, + self.attn_bias_type, + sliding_window=self.sliding_window, + tokens=token_values, + eos_id=self.eos_id, + ) ) h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)