# Copyright (c) Meta Platforms, Inc. and affiliates. from enum import Enum, auto from typing import Any, List, Optional, Tuple, Union import torch from huggingface_hub import PyTorchModelHubMixin from pydantic import model_validator from torch import nn from torch.nn.attention.flex_attention import create_block_mask, BlockMask, flex_attention from typing_extensions import Self import json import logging import torch import torch.nn import torch.nn as nn from pydantic import ConfigDict from torch.nn import functional as F from xformers.ops import AttentionBias, fmha import abc import os import time from collections import defaultdict from pydantic import BaseModel from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID RMSNorm = nn.RMSNorm from bytelatent.distributed import get_local_rank logger = logging.getLogger() if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0: flex_attention_comp = torch.compile(flex_attention) else: flex_attention_comp = None def patch_reduce(h, max_num_patches, reduction, patch_ids): """ Reduce variable length patches to single embedding per patch Note: this works with variable number of patches for different sequences in the batch It handles variable length patches by assuming that patch_lengths will be 0 for any extra patches on the *right*. Since there can be a variable number of patches this function also return the number of patches for each sequence in the batch. Any embeddings on the right that are not allocated to a patch (i.e. if the sum(patch_lengths[i]) < seq_len for any i) will be sent to a dummy patch, which is trimmed before returning. """ bs, seq_len, emb_dim = h.shape patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) reduced_embs = torch.zeros( (bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device ) reduced_embs = reduced_embs.scatter_reduce( src=h, dim=1, index=patch_ids, reduce=reduction, include_self=False, ) reduced_embs = reduced_embs[:, :max_num_patches, :] return reduced_embs def concat_downsample(h, patch_lengths, patch_size): # The assumption in this function is that seq_len = patch_size * num_patches. bs, seq_len, emb_dim = h.shape patch_end_ids = torch.cumsum(patch_lengths, dim=1) patch_ids = patch_end_ids.unsqueeze(-1) - torch.arange(patch_size, 0, -1).to( patch_end_ids.device ) # Is clamp ok here? patch_ids = patch_ids.clamp(min=0).unsqueeze(-1).expand(-1, -1, -1, h.shape[-1]) patch_ids = patch_ids.view(bs, -1, emb_dim) # after gather h.shape = [batch_size, seq_len, dim] h = torch.gather(h, 1, patch_ids) h = h.reshape(bs, patch_lengths.shape[1], patch_size * h.size(-1)) return h def pooling_downsample(h, max_num_patches, pooling_mode, patch_ids): cat = [] if "avg" in pooling_mode or "mean" in pooling_mode: cat.append(patch_reduce(h, max_num_patches, "mean", patch_ids)) if "min" in pooling_mode: cat.append(patch_reduce(h, max_num_patches, "amin", patch_ids)) if "max" in pooling_mode: cat.append(patch_reduce(h, max_num_patches, "amax", patch_ids)) assert len(cat) > 0 h = torch.cat(cat, dim=-1) return h def downsample( h, num_patches, patch_lengths=None, patch_ids=None, downsampling_by_pooling=None, patch_size=4, ): """ Downsampling: a. concatenating embeddings in the patch Note: with dynamic patching, patch the last patch_size tokens. b. pooling embeddings in the patch """ # input: h.shape = [batch_size, seq_len, dim] # input: pool h.shape = [batch_size, seq_len / patch_size, dim] # if we don't use the cros_attn, we pool so that we convert bytes rep to patch rep if downsampling_by_pooling is not None and len(downsampling_by_pooling) > 0: # By pooling max_num_patches = num_patches assert patch_ids is not None h = pooling_downsample(h, max_num_patches, downsampling_by_pooling, patch_ids) else: # TODO: remove this condition # By concatenating (fixed lengths patching) assert patch_lengths is not None h = concat_downsample(h, patch_lengths, patch_size) return h def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx def tokens_to_seqlen(batch: torch.Tensor, eos_id: int): """ 0 0 0 1 0 0 0 1 0 0 0 0 1 0 0 0 1 0 0 0 0 0 -> 4 4 3 2 4 5 """ mask = batch == eos_id mask[:, -1] = True # virtual eos at the end of each row # 0 0 0 1 0 0 0 1 0 0 X # 0 1 0 0 0 1 0 0 0 0 X row, col = torch.where(mask) # row = 0, 0, 0, 1, 1, 1 # col = 3, 7, 10, 1, 5, 10 seqlens = (col[1:] - col[:-1]) + (row[1:] - row[:-1]) * mask.shape[1] # seqlens = (4, 3, -9, 4, 5) + (0, 0, 11, 0, 0) = (4, 3, 2, 4, 5) return [int(col[0].item() + 1)] + seqlens.tolist() def create_causal_mask( seqlen, attn_impl: str, attn_bias_type: str | None, *, eos_id: int | None = None, tokens: torch.Tensor | None = None, sliding_window: int | None = None, ): if attn_impl == "xformers": if attn_bias_type is None: return fmha.attn_bias.LowerTriangularMask() elif attn_bias_type == "causal": assert sliding_window is None return fmha.attn_bias.LowerTriangularMask() elif attn_bias_type == "block_causal": assert sliding_window is None assert eos_id is not None assert tokens is not None return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( q_seqlen=tokens_to_seqlen(tokens, eos_id) ) elif attn_bias_type == "local_block_causal": assert sliding_window is not None assert eos_id is not None assert tokens is not None return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( q_seqlen=tokens_to_seqlen(tokens, eos_id) ).make_local_attention(sliding_window) else: return fmha.attn_bias.LocalAttentionFromBottomRightMask( window_left=sliding_window - 1, window_right=0 ) elif attn_impl == "sdpa": BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0)) if attn_bias_type == "causal": return "causal" if BLT_SUPPRESS_ATTN_ERROR == 1: return "causal" else: raise ValueError( "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1" ) elif attn_impl == "flex_attention": return create_block_mask(causal_mask, None, None, seqlen, seqlen) elif attn_impl == "fmha": return None else: raise NotImplementedError( f"Attention {attn_impl} with {sliding_window} sliding window not implemented" ) class InitStdFactor(str, Enum): DISABLED = "disabled" # Init std is divided by 1.0 GLOBAL_DEPTH = "global_depth" # Init std is divided by sqrt(2*n_layers) CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth) DIM_RATIO = "dim_ratio" # Init std is divided by model_dim/4096 class BaseTransformerArgs(BaseModel): model_config = ConfigDict(extra="forbid") dim: int = 512 n_layers: int = 8 head_dim: int | None = None n_heads: int | None = None n_kv_heads: int | None = None ffn_dim_multiplier: float | None = None multiple_of: int = 256 norm_eps: float = 1e-5 rope_theta: float = 10000.0 rope_use_fp32_in_outer_product: bool = False init_base_std: float | None = None init_std_factor: InitStdFactor = InitStdFactor.DISABLED max_seqlen: int = 1024 attn_impl: str | None = "sdpa" attn_bias_type: str | None = None # Special token config eos_id: int | None = EOS_ID def cross_entropy(pred, target, **kwargs): return F.nll_loss( F.log_softmax(pred.flatten(end_dim=-2).float(), -1), target.flatten(end_dim=-1), **kwargs, ) def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims." bs, slen, n_kv_heads, head_dim = x.shape if n_rep == 1: return x return ( x[:, :, :, None, :] .expand(bs, slen, n_kv_heads, n_rep, head_dim) .reshape(bs, slen, n_kv_heads * n_rep, head_dim) ) def precompute_freqs_cis( dim: int, end: int, theta: float = 10000.0, rope_use_fp32_in_outer_product: bool = False, ): """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 data type. Args: dim (int): Dimension of the frequency tensor. end (int): End index for precomputing frequencies. theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. Returns: torch.Tensor: Precomputed frequency tensor with complex exponentials. """ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) t = torch.arange(end, device=freqs.device) if rope_use_fp32_in_outer_product: t = t.to(torch.float32) freqs = torch.outer(t, freqs).float() cos, sin = freqs.cos(), freqs.sin() return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2) def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int): """ Reshape frequency tensor for broadcasting it with another tensor. This function reshapes the frequency tensor to have the same shape as the target tensor 'x' for the purpose of broadcasting the frequency tensor during element-wise operations. Args: freqs_cis (torch.Tensor): Frequency tensor to be reshaped. x (torch.Tensor): Target tensor for broadcasting compatibility. seq_dim (int): Sequence dimension index. Returns: torch.Tensor: Reshaped frequency tensor. """ ndim = x.ndim assert 0 <= seq_dim < ndim assert freqs_cis.shape == ( x.shape[seq_dim], x.shape[-3], 2, 2, ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}" shape = [ d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2]) ] + [2, 2] return freqs_cis.view(*shape) def apply_rotary_emb( xq: torch.Tensor, xk: torch.Tensor, seq_dim: int, freqs_cis: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 freqs_cis = reshape_for_broadcast( freqs_cis, xq_, seq_dim ).float() # S D/2 2 2 -> 1 S 1 D/2 2 2 xq_out = (xq_ * freqs_cis).sum(5).flatten(3) xk_out = (xk_ * freqs_cis).sum(5).flatten(3) return xq_out.type_as(xq), xk_out.type_as(xk) # Rotary embedding as in xformer, see if torchtrain implementation is not better. Also might be usefull to make it work with batch*seqlen collapsed. class RotaryEmbedding(torch.nn.Module): """ RotaryEmbedding Module """ def __init__( self, theta: float, head_dim: int, max_seqlen: int = 1024, rope_use_fp32_in_outer_product: bool = False, ): super().__init__() self.theta = theta self.head_dim = head_dim self.max_seqlen = max_seqlen self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product self.register_buffer( "freqs_cis", precompute_freqs_cis( dim=head_dim, end=max_seqlen, theta=theta, rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product, ), persistent=False, ) def reset_parameters(self): self.freqs_cis[...] = precompute_freqs_cis( dim=self.head_dim, end=self.max_seqlen, theta=self.theta, rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product, ) def forward( self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None ): """ Return freqs_cis corresponding to consecutive seqlen positions or the corresponding tok_idx positions Args: seqlen (int): Contiguous sequence length tok_idx (torch.Tensor[int]): Position indices of each token this overrides seqlen Returns: Tuple(torch.Tensor, torch.Tensor): Embedded input tensor and freqs_cis """ test = (seqlen is not None) or (tok_idx is not None) assert test, "Should provide atleast seqlen or tok_idx" if tok_idx is not None: return self.freqs_cis[tok_idx] elif seqlen is not None: return self.freqs_cis[0:seqlen] def _reshape_for_attn_bias( attn_bias: AttentionBias | None, *tensors: torch.Tensor, ) -> list[torch.Tensor]: to_transform = list(tensors) if isinstance(attn_bias, fmha.attn_bias.BlockDiagonalCausalMask): # could be `view` instead of reshape during training, but for inference # have to reshape due to strides mismatch to_transform = [t.reshape(1, -1, *t.shape[2:]) for t in to_transform] return to_transform class Attention(nn.Module): def __init__( self, dim: int, head_dim: int, n_heads: int, n_kv_heads: int, rope_theta: float, ): super().__init__() self.dim = dim self.head_dim = head_dim self.rope_theta = rope_theta self.n_heads = n_heads self.n_kv_heads = n_kv_heads self.heads_per_group = self.n_heads // self.n_kv_heads self.wq = nn.Linear( dim, n_heads * head_dim, bias=False, ) self.wk = nn.Linear( dim, n_kv_heads * head_dim, bias=False, ) self.wv = nn.Linear( dim, n_kv_heads * head_dim, bias=False, ) self.wo = nn.Linear( n_heads * head_dim, dim, bias=False, ) def forward( self, x: torch.Tensor, freq_cis: torch.Tensor, tok_idx: Optional[torch.Tensor] = None, mask: Optional[Union[BlockMask, AttentionBias, str]] = None, attn_impl: str = "sdpa", ) -> torch.Tensor: # B S D bsz, seq_len, dim = x.shape xq = self.wq(x.view_as(x)) xk = self.wk(x.view_as(x)) xv = self.wv(x.view_as(x)) output_shape = xq.shape # B S D -> B S H D xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim) xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim) xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len]) # This condition helps us be easily compatible # with inference by adding a pluggable KVCache if hasattr(self, "kv_cache"): xk, xv = self.kv_cache.update(xk, xv, tok_idx) xk = repeat_kv(xk, self.heads_per_group, dim=2) xv = repeat_kv(xv, self.heads_per_group, dim=2) if attn_impl == "flex_attention": assert mask is None or isinstance(mask, BlockMask) xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) output = flex_attention_comp(xq, xk, xv, block_mask=mask) output = output.transpose(1, 2).contiguous() # B H S D -> B S H D elif attn_impl == "xformers": assert mask is None or isinstance(mask, AttentionBias) query_shape = xq.shape xq, xk, xv = _reshape_for_attn_bias(mask, xq, xk, xv) output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask) output = output.view(query_shape) # This uses B S H D instead of B H S D of pytorch elif attn_impl == "sdpa": xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) assert mask is None or isinstance(mask, (str, torch.Tensor)) is_causal = (mask == "causal") if isinstance(mask, str) else False mask = mask if isinstance(mask, torch.Tensor) else None output = F.scaled_dot_product_attention( xq, xk, xv, is_causal=is_causal, attn_mask=mask, ) output = output.transpose(1, 2).contiguous() # B H S D -> B S H D else: raise NotImplementedError( f"Attention implementation {attn_impl} not supported" ) output_reshaped = output.reshape(output_shape) output = self.wo(output_reshaped) return output def reset_parameters(self, init_std=None, factor=1.0): init_std = init_std or (self.dim ** (-0.5)) / factor for w in [self.wq, self.wk, self.wv]: nn.init.trunc_normal_( w.weight, mean=0.0, std=init_std, a=-3 * init_std, b=3 * init_std, ) nn.init.trunc_normal_( self.wo.weight, mean=0.0, std=init_std, a=-3 * init_std, b=3 * init_std, ) class FeedForward(nn.Module): def __init__( self, dim: int, hidden_dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float], mp_size: int = 1, ): super().__init__() hidden_dim = int(2 * hidden_dim / 3) if ffn_dim_multiplier is not None: hidden_dim = int(ffn_dim_multiplier * hidden_dim) hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) assert hidden_dim % mp_size == 0 self.dim = dim self.hidden_dim = hidden_dim self.w1 = nn.Linear( dim, hidden_dim, bias=False, ) self.w3 = nn.Linear( dim, hidden_dim, bias=False, ) self.w2 = nn.Linear( hidden_dim, dim, bias=False, ) def forward(self, x: torch.Tensor) -> torch.Tensor: # B S D x1 = self.w1(x.view_as(x)) x3 = self.w3(x.view_as(x)) output = self.w2(F.silu(x1) * x3) return output def reset_parameters(self, init_std=None, factor=1.0): in_init_std = init_std or (self.dim ** (-0.5)) / factor out_init_std = init_std or (self.hidden_dim ** (-0.5)) / factor nn.init.trunc_normal_( self.w1.weight, mean=0.0, std=in_init_std, a=-3 * in_init_std, b=3 * in_init_std, ) nn.init.trunc_normal_( self.w2.weight, mean=0.0, std=out_init_std, a=-3 * out_init_std, b=3 * out_init_std, ) nn.init.trunc_normal_( self.w3.weight, mean=0.0, std=in_init_std, a=-3 * in_init_std, b=3 * in_init_std, ) class TransformerBlock(nn.Module): def __init__(self, args: BaseTransformerArgs): super().__init__() assert (args.head_dim is not None) or ( args.n_heads is not None ), "Should specify at least head_dim or n_heads" self.head_dim = args.head_dim or args.dim // args.n_heads self.n_heads = args.n_heads or args.dim // args.head_dim self.n_kv_heads = args.n_kv_heads or self.n_heads assert args.n_heads % self.n_kv_heads == 0 assert args.dim % args.n_heads == 0 self.attention = Attention( dim=args.dim, head_dim=self.head_dim, n_heads=self.n_heads, n_kv_heads=self.n_kv_heads, rope_theta=args.rope_theta, ) self.feed_forward = FeedForward( dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of, ffn_dim_multiplier=args.ffn_dim_multiplier, ) self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) def forward( self, x: torch.Tensor, freq_cis: torch.Tensor, tok_idx: Optional[torch.Tensor] = None, mask: Optional[Union[BlockMask, AttentionBias, str]] = None, attn_impl: str = "sdpa", ) -> torch.Tensor: norm_x = self.attention_norm(x) attn_out = self.attention( norm_x, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl, ) h = x + attn_out h_norm = self.ffn_norm(h) out = h + self.feed_forward(h_norm) return out def init_weights(self, init_std=None, factor=1.0): self.attention.reset_parameters(init_std, factor) self.attention_norm.reset_parameters() self.feed_forward.reset_parameters(init_std, factor) self.ffn_norm.reset_parameters() class SequenceModelWithOutput(abc.ABC): @abc.abstractmethod def get_output_seq_len(self) -> int: pass class BaseTransformer(nn.Module, SequenceModelWithOutput): def __init__(self, args: BaseTransformerArgs): super().__init__() self.dim = args.dim self.init_base_std = args.init_base_std self.attn_impl = args.attn_impl self.attn_bias_type = args.attn_bias_type self.init_std_factor = InitStdFactor(args.init_std_factor) self.max_seqlen = args.max_seqlen self.rope_embeddings = RotaryEmbedding( theta=args.rope_theta, head_dim=args.head_dim or args.dim // args.n_heads, max_seqlen=args.max_seqlen, rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, ) self.eos_id = args.eos_id self.layers = nn.ModuleList() for _ in range(args.n_layers): self.layers.append(TransformerBlock(args)) def get_output_seq_len(self): return self.max_seqlen def forward( self, h, tok_idx: Optional[torch.Tensor] = None, mask: Optional[Union[BlockMask, AttentionBias, str]] = None, attn_impl: str = "sdpa", ): freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx) for i, layer in enumerate(self.layers): h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) return h def init_weights(self): self.rope_embeddings.reset_parameters() for depth, layer in enumerate(self.layers): factor = { InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, InitStdFactor.DIM_RATIO: self.dim / 4096, InitStdFactor.DISABLED: 1.0, }[self.init_std_factor] layer.init_weights(self.init_base_std, factor) class LMTransformerArgs(BaseTransformerArgs): seed: int = 42 vocab_size: int = -1 weight_tying: bool = False sliding_window: int | None = None class LMTransformer( BaseTransformer, PyTorchModelHubMixin, repo_url="https://github.com/facebookresearch/blt", # paper_url="https://arxiv.org/abs/2412.09871", pipeline_tag="text-generation", license="other", license_name="fair-noncommercial-research-license", license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE", coders={ LMTransformerArgs: ( lambda x: {"args": x.model_dump()}, lambda data: LMTransformerArgs(**data), ) }, ): def __init__(self, args: LMTransformerArgs): super().__init__(args) self.weight_tying = args.weight_tying self.sliding_window = args.sliding_window assert args.vocab_size > 0 self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim) self.norm = RMSNorm(args.dim, eps=args.norm_eps) self.output = nn.Linear( args.dim, args.vocab_size, bias=False, ) if args.weight_tying: self.output.weight = self.embeddings.tok_embeddings.weight def push_to_hub(self, *args, **kwargs): raise ValueError( "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct." ) def forward( self, token_values: torch.Tensor, target: Optional[torch.Tensor] = None, tok_idx: Optional[torch.Tensor] = None, mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None, attn_impl: str | None = None, ): if attn_impl is None: attn_impl = self.attn_impl bsz, seqlen = token_values.shape h = self.tok_embeddings(token_values) mask = ( mask if mask is not None else create_causal_mask( seqlen, attn_impl, self.attn_bias_type, sliding_window=self.sliding_window, tokens=token_values, eos_id=self.eos_id, ) ) h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) logits = self.output(self.norm(h)) if target is not None: return cross_entropy(logits, target) else: return logits def reset_parameters(self, init_std=None): self.norm.reset_parameters() def init_weights(self): self.reset_parameters() init_std = self.dim ** (-0.5) nn.init.trunc_normal_( self.tok_embeddings.weight, mean=0.0, std=init_std, a=-3 * init_std, b=3 * init_std, ) super().init_weights() if not self.weight_tying: nn.init.trunc_normal_( self.output.weight, mean=0.0, std=init_std, a=-3 * init_std, b=3 * init_std, ) class PatchingModeEnum(str, Enum): entropy = "entropy" bpe = "bpe" bpe_patcher = "bpe_patcher" space = "space" static = "static" byte = "byte" class PatcherArgs(BaseModel): patching_mode: PatchingModeEnum = PatchingModeEnum.entropy patching_device: str = "cuda" entropy_model_checkpoint_dir: str | None = None realtime_patching: bool = False threshold: float = 1.335442066192627 threshold_add: float | None = None max_patch_length: int | None = None patch_size: float = 4.5 patching_batch_size: int = 1 device: str = "cuda" monotonicity: bool = False log_time: bool = False def build(self) -> "Patcher": return Patcher(self) def rightpad(seq, pad_id, max_len): return seq + [pad_id] * (max_len - len(seq)) def check_non_zero_after_zero(tensor): zero_mask = tensor == 0 shifted_mask = torch.cat( [ torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device), zero_mask[:, :-1], ], dim=1, ) non_zero_after_zero = (tensor != 0) & shifted_mask return non_zero_after_zero.any() def to_device(entropy_model, device=None): if device == "cuda": rank = get_local_rank() device = f"cuda:{rank}" entropy_model = entropy_model.to(device) return entropy_model, device def split_large_numbers(lst, m): new_lst = [] for i in lst: if i > m: while i > m: new_lst.append(m) i -= m new_lst.append(i) else: new_lst.append(i) assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}" return new_lst class Patcher: def __init__(self, patcher_args: PatcherArgs): self.patcher_args = patcher_args self.patching_mode = patcher_args.patching_mode self.realtime_patching = patcher_args.realtime_patching if self.realtime_patching: assert ( patcher_args.entropy_model_checkpoint_dir is not None ), "Cannot require realtime patching without an entropy model checkpoint" maybe_consolidated = os.path.join( patcher_args.entropy_model_checkpoint_dir, "consolidated/consolidated.pth", ) if os.path.exists(maybe_consolidated): state_path = maybe_consolidated else: state_path = os.path.join( patcher_args.entropy_model_checkpoint_dir, "consolidated.pth" ) entropy_model, _ = load_entropy_model( patcher_args.entropy_model_checkpoint_dir, state_path, ) entropy_model, _ = to_device(entropy_model, patcher_args.patching_device) self.entropy_model = entropy_model else: self.entropy_model = None self.threshold = patcher_args.threshold self.threshold_add = patcher_args.threshold_add self.max_patch_length = patcher_args.max_patch_length self.patch_size = patcher_args.patch_size self.patching_batch_size = patcher_args.patching_batch_size self.device = patcher_args.device self.monotonicity = patcher_args.monotonicity self.log_time = patcher_args.log_time if self.log_time: self.log = defaultdict(float) def patch( self, tokens: torch.Tensor, include_next_token: bool = False, preds: torch.Tensor | None = None, entropies: torch.Tensor | None = None, threshold: float = None, ) -> torch.Tensor: """ tokens: 2D tensor of shape [batch_size, seq_len] that needs to be patched Returns patch lengths and optionally scores associated with the tokens (i.e. entropies, logprobs etc.) -> output tensor: [batch_size, max_num_patches] each tensor is processed independently and gets right padded with zeros. Patching with the following modes: 1. patching_mode = None: static patch size 2. patching_mode = "entropy": calculate entropy of each token, allocate patches so that the total number of patches is the same as static patching but choose to begin patches on tokens where the model is most uncertain (highest entropy). When threshold is provided, it uses the threshold to decide when to start a new patch. 3. patching_mode = "space": use space like tokens to define the patches. 4. patching_mode = "bpe": use bpe delim tokens to define the patches. To correctly patch the last token, it may be necessary to include the next token in the patch lengths calculations. This is controlled by the include_next_token argument. """ bs, seq_len = tokens.shape seq_len_next_tok = seq_len + 1 if include_next_token else seq_len scores = None # STATIC if self.patching_mode == PatchingModeEnum.byte: patch_lengths = torch.ones( (bs, seq_len_next_tok), dtype=tokens.dtype, device=tokens.device ) else: raise NotImplementedError(f"self.patching_mode {self.patching_mode}") # Apply any processing to patch lengths if self.max_patch_length is not None: # TODO: avoid going back to a list here. patch_lengths = [ split_large_numbers(pl, self.max_patch_length) for pl in patch_lengths.tolist() ] max_len = max([len(pl) for pl in patch_lengths]) patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths] patch_lengths = torch.tensor( patch_lengths, dtype=tokens.dtype, device=tokens.device ) assert not check_non_zero_after_zero(patch_lengths) # Find the last non-zero column index using argmax on a reversed version of the tensor last_non_zero_col_reversed = ( (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min() ) # Slice the tensor up to the last non-zero column patch_lengths = patch_lengths[ :, : patch_lengths.shape[1] - last_non_zero_col_reversed ] assert ( torch.sum(patch_lengths) == tokens.numel() + include_next_token * tokens.shape[0] ), f"{torch.sum(patch_lengths)} != {tokens.numel() + include_next_token * tokens.shape[0]}" if self.log_time: self.log["postprocessing_patch_lengths"] += time.time() - s self.log["tokens"] += patch_lengths.sum().item() return patch_lengths, scores def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"): with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr: reloaded = json.loads(fr.read()) torch.set_default_dtype(torch.bfloat16) model_params = reloaded["entropy_model"] logger.warning( "Update checkpoint to load attn and sliding window args from checkpoint" ) entropy_model_args = LMTransformerArgs( dim=model_params["dim"], n_layers=model_params["n_layers"], n_heads=model_params["n_heads"], max_seqlen=model_params["max_seqlen"], ffn_dim_multiplier=model_params["ffn_dim_multiplier"], vocab_size=model_params["vocab_size"], attn_bias_type="local_block_causal", attn_impl="xformers", sliding_window=512, ) entropy_model = LMTransformer(entropy_model_args) entropy_model.load_state_dict( torch.load(state_dict_path, map_location=device)["model"], strict=False ) entropy_model.to(device) entropy_model = entropy_model.eval() # no grads for the model: for param in entropy_model.parameters(): param.requires_grad = False return entropy_model, entropy_model_args def get_encoder_dim_token_emb(args): if args.dim_token is not None: dim_token_emb = args.dim_token elif args.use_local_encoder_transformer: dim_token_emb = args.dim_local_encoder else: dim_token_emb = args.dim_global // args.patch_size return dim_token_emb def get_encoder_dim_patch_emb(args): dim_patch_emb = None if args.cross_attn_encoder: if args.cross_attn_init_by_pooling: dim_patch_emb = args.dim_local_encoder else: dim_patch_emb = args.dim_global return dim_patch_emb def get_global_dim_patch_emb(args): dim_token_emb = get_encoder_dim_token_emb(args) if args.cross_attn_encoder: dim_patch_emb = dim_token_emb * args.cross_attn_k elif ( args.downsampling_by_pooling is None or not args.downsampling_by_pooling or len(args.downsampling_by_pooling) == 0 ): dim_patch_emb = dim_token_emb * args.patch_size else: dim_patch_emb = dim_token_emb * sum( [ pooling in args.downsampling_by_pooling for pooling in ["avg", "min", "max"] ] ) return dim_patch_emb def get_decoder_dim_token_emb(args): if args.share_encoder_decoder_emb: dim_token_emb = get_encoder_dim_token_emb(args) elif args.dim_token is not None: dim_token_emb = args.dim_token else: dim_token_emb = args.dim_local_decoder return dim_token_emb def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int]: if ngram_to_size_str is None: return None ngram_to_size = {} for entry in ngram_to_size_str.split(","): ngram, size = entry.split(":") ngram = int(ngram) size = int(size) ngram_to_size[ngram] = size return ngram_to_size def fill_tokens(tokens, patch_size, fill_id): batch_size, seq_len = tokens.shape if seq_len % patch_size == 0: return tokens else: remaining = patch_size - seq_len % patch_size final_padding = tokens.new(batch_size, remaining).fill_(fill_id) return torch.cat((tokens, final_padding), dim=1) def decoder_patch_ids_from_lengths(patch_lengths, nb_boe, seq_len): first_patch_length = patch_lengths[0, 0] assert torch.all( first_patch_length == patch_lengths[:, 0] ), "first patch should always be the same size (1 for dynamic, patch_size for static)." assert ( first_patch_length - nb_boe == 1 ), f"First patch (patch length: {first_patch_length}) should have one non-boe token (boe toks: {nb_boe})" # Remove first patch from patch_ids for local decoder inputs and shift the last patch. # decoder_patch_lengths = patch_lengths[:, 1:].clone() # decoder_patch_lengths = add_to_last_nonzero_patch(decoder_patch_lengths, 1) decoder_patch_lengths = patch_lengths[:, 1:] assert ( decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0] == patch_lengths.sum() ), f"{decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0]} != {patch_lengths.sum()}" assert torch.all(decoder_patch_lengths >= 0), f"{decoder_patch_lengths}" decoder_patch_ids = patch_ids_from_lengths( patch_lengths=decoder_patch_lengths, seq_len=seq_len ) return decoder_patch_ids primes = [ 1000000007, 5915587277, 1500450271, 3267000013, 5754853343, 4093082899, 9576890767, 3628273133, 2860486313, 5463458053, 3367900313, ] def rolling_polynomial_hash(t, hash_func_nb: int = 0): prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device) prime_powers = torch.stack([prime**i for i in range(t.shape[-1])]) return torch.sum(t * prime_powers, dim=-1) def byte_group_hash_function( x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000 ): """ Returns a hash of the input x and maps it to a value in the range [0, max_hash]. expects: x of shape (batch_size, seq_len) with values as ids in the token vocab. returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash]. Note: max hash can make a big difference on the number of collisions. """ with torch.no_grad(): bs, seq_len = x.shape # x_numpy = x.numpy() # hash_values = torch.zeros(bs, seq_len, dtype=torch.int64, requires_grad=False) # for i in range(bs): # for j in range(seq_len): # start = max(j, j-group_size+1) # end = j+1 # hash_values[i, j] = hash_array(x_numpy[i, start:end], max_hash) prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device) x = torch.cat([prefix, x], dim=1) windows = x.unfold(1, group_size, 1) # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows) hashes = rolling_polynomial_hash(windows, hash_func_nb) hash_values_range = hashes % max_hash hash_values_range.requires_grad = False return hash_values_range def create_patch_mask_from_ids( patch_ids, num_patches, window=None, patches_as_queries=False ): """ Creates a tensor of shape [bs, seq_len, num_patches] where each element at position (i, j, k) is True if the patch id at position (i, j) is less than or equal to k. Args: patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids. num_patches (int): Total number of patches. window (int): If not None, only considers patches within a window of size window. patches_as_queries (bool): If True, the patches are used as queries Returns: torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask. """ bs, seq_len = patch_ids.shape if not patches_as_queries: q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches) kv_ids = ( torch.arange(num_patches, device=patch_ids.device) .unsqueeze(0) .unsqueeze(0) .expand(bs, seq_len, num_patches) ) else: kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len) q_ids = ( torch.arange(num_patches, device=patch_ids.device) .unsqueeze(0) .unsqueeze(-1) .expand(bs, num_patches, seq_len) ) if window is None: mask = q_ids == kv_ids else: mask = (kv_ids <= q_ids) & (q_ids < kv_ids + window) return mask def cross_attn_mask( patch_ids, patch_lengths, N, patches_as_queries=False, cross_attn_k=1, window=None, block_mask=True, ): bs = patch_ids.shape[0] with torch.no_grad(): # Create the patch mask cross_mask = create_patch_mask_from_ids( patch_ids, patch_lengths.shape[1], window=window, patches_as_queries=patches_as_queries, ).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1) q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k assert cross_mask.shape == ( bs, q_len, kv_len, ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}" if block_mask: def patch_mask(b, h, q_idx, kv_idx): return cross_mask[b, q_idx, kv_idx] block_mask = create_block_mask( patch_mask, B=bs, H=None, Q_LEN=q_len, KV_LEN=kv_len, _compile=True, ) return block_mask else: return torch.where( cross_mask, torch.tensor(0.0), torch.tensor(float("-inf")) ).unsqueeze( 1 ) # [bs, 1, q_len, kv_len] def get_blt_input( tokens: torch.Tensor, enforce_patch_size_multiple: bool, nb_boe: torch.Tensor, patch_size: int, boe_id: int, ): """ This function returns X_et, X_gt and X_dt, the encoder, global, and decoder tokens respectively. Consider the input and target sequences: X=[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13] Y=[4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13,14] with patch_size=4 Note 1: that there will be no special tokens introduced at the patch level. Note 2: X_e needs to be trimmed to be passed to Global Current without boe: X_et = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] X_g = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] # remove last glob patch X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]] Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]] --> lag fix: X_et = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11] [12,13,pad,pad]] X_g = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11]] X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]] Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]] Dynamic (current): X = [3,4,5,6,7,eos,bos,8,9,10,eos,bos] Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11] entropy patching: input: 7, bos, 9, 10 pred (high entropy): eos, 8, 10, eos X_et = [[boe,3,4,5,6,7,eos,bos,8,9,10,eos,bos] X_g = [[boe], [3,4,5,6], [7,eos],[bos,8],[9], [10,eos]] X_dt = [[3,4,5,6], [7,eos], [bos,8],[9], [10,eos],[bos]] Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11] --> lag fix no boe (force single byte first patch): X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] X_g = [[3], [4,5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch X_dt = [[3,4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]] Y = [4,5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13] input: 4, 7, bos, 9, 10 pred (high entropy): 5, eos, 8, 10, eos X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]] Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13] Handle the last byte properly. patch_lengths = [1, 1, 3, 2, 2 1 2 2 1] X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # do not remove last global patch X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11] [12]] Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12, 13]] bpe delim X_et = [[3,4,5,6,7,,eos,bos,,8,9,,10,,eos,bos,11,12] X_g = [[3], [4,5,6,7,], [eos,bos,], .. X_dt = [[3,4,5,6,7], [,eos,bos], [,bos,8], .. Y = [4,5,6,7,, eos,bos, 8,9,, .. Note 1: that there will be no special tokens introduced at the patch level. Note 2: X_e needs to be trimmed to be passed to Global """ batch_size, seq_len = tokens.shape local_encoder_tokens = tokens local_decoder_tokens = tokens if nb_boe > 0: padded_patch = tokens.new(batch_size, nb_boe).fill_(boe_id) local_encoder_tokens = torch.cat((padded_patch, local_encoder_tokens), dim=1) # global_tokens = tokens.new(batch_size, ((seq_len-1) // patch_size)+1).fill_(boe_id) # create global tokens, contains boe tokens and eos # padded_local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id) # patches = padded_local_encoder_tokens.view(batch_size, -1, patch_size) # global_tokens = (patches.eq(eos_id).any(dim=2).int() * eos_id)[:, 1:] # global_tokens += global_tokens.eq(0).int() * boe_id # TODO: fix this when we want to use block causal in the global. if enforce_patch_size_multiple and local_encoder_tokens.shape[-1] % patch_size != 0: local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id) return local_encoder_tokens, None, local_decoder_tokens def patch_ids_from_lengths(patch_lengths, seq_len): bs, num_patches = patch_lengths.shape # Create a tensor of cumulative sums of the patch lengths cum_d = torch.cat( [ torch.zeros(bs, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), patch_lengths.cumsum(dim=-1), ], dim=-1, ) patch_ids = (cum_d.unsqueeze(-1) <= torch.arange(seq_len, device=cum_d.device)).sum( dim=-2 ) - 1 assert not ( torch.max(patch_ids) > patch_lengths.shape[-1] or torch.min(patch_ids) < 0 ), f"{torch.max(patch_ids)} > {patch_lengths.shape[-1]} or {torch.min(patch_ids)} < 0" return patch_ids class ByteLatentTransformerArgs(BaseTransformerArgs): # Basic model configuration seed: int = 42 vocab_size: int = -1 dim: int = 512 n_layers: int = 8 n_heads: int = 8 # TODO: What is the purpose of this parameter? weight_tying: bool = False patch_in_forward: bool = False # Architecture and dimensions dim_token: int | None = None dim_global: int = 512 dim_local_decoder: int = 512 dim_local_encoder: int = 512 n_layers_global: int = 8 n_layers_local_decoder: int = 8 n_layers_local_encoder: int = 8 # Tokenization and patching patch_size: float | None = None patching_mode: str | None = None patching_threshold: float | None = None patching_threshold_add: float | None = None monotonicity: bool = False patching_batch_size: int = 1 patching_device: str = "cuda" max_patch_length: int | None = None # Encoder/Decoder configuration tie_local_encoder_decoder_logits: bool = False use_local_encoder_transformer: bool = False encoder_lm_loss: bool = False max_encoder_seq_length: int | None = None pad_to_max_length: bool = False encoder_enable_byte_ngrams: bool = False encoder_enable_byte_group_hash: bool = False ngram_vocab_sizes: int | None = None # Cross attention configurations cross_attn_encoder: bool = False cross_attn_decoder: bool = False cross_attn_window_encoder: int | None = None cross_attn_window_decoder: int | None = None cross_attn_k: int | None = None cross_attn_nheads: int | None = None cross_attn_all_layers_decoder: bool = False cross_attn_all_layers_encoder: bool = False cross_attn_use_flex_attention: bool = True cross_attn_init_by_pooling: bool = False # Encoder hash configurations encoder_hash_byte_group_size: Any | None = None encoder_hash_byte_group_vocab: int = 30000 encoder_hash_byte_group_nb_functions: int = 3 # Model behavior and optimization log_patch_lengths: bool = False non_linearity: str = "swiglu" use_rope: bool = True recompute_fc1_out: bool = False recompute_fc3_out: bool = False recompute_attn: bool = True custom_bwd: bool = False layer_ckpt: str = "all" # Initialization and attention init_use_gaussian: bool = True init_use_depth: str = "current" attn_bias_type: str = "causal" alpha_depth: str = "disabled" max_length: int = 2048 # Norm configuration norm_eps: float = 1e-5 norm_affine: bool = True pre_norm: bool = True norm_type: str = "rmsnorm" # Additional configurations multiple_of: int = 256 ffn_dim_multiplier: float = 1.0 dropout: float = 0 output_size: int = -1 # Additional parameters from ModelArgs architecture: str = "vanilla" share_encoder_decoder_emb: bool = True global_local_decoder_residual_layer: str | None = None tokenize_with_bpe_delimiter: bool = False patching_thresholds_str: str | None = None tie_local_encoder_decoder: bool = False encoder_preds_low_entropy_toks: float | None = None encoder_preds_random_toks: float | None = None dim_token_emb: int | None = None dim_patch_emb: int | None = None encoder_ngram_table_dir: str | None = None encoder_ngram_to_size_str: str | None = None # Model architecture params entropy_model_checkpoint_dir: str | None = None entropy_model_is_ngram_model: bool = False downsampling_by_pooling: str | None = None n_heads_global: int = 8 n_heads_local_decoder: int = 8 n_heads_local_encoder: int = 8 n_kv_heads: int | None = None n_kv_heads_global: int | None = None conv_kernel_size: int | None = None local_attention_window_len: int | None = None # Performance optimization sequence_parallel: bool = False loss_parallel: bool = False fuse_sequence_parallel: bool = False use_fsdp: bool = True attn_to_keep: str = "all" # Parameter mixing pm_size: int = 0 # Logging full_logging_n_layers: int = 4 @model_validator(mode="after") def check_hash_byte_sizes(self) -> Self: if ( self.encoder_hash_byte_group_size is not None and type(self.encoder_hash_byte_group_size) == str ): self.encoder_hash_byte_group_size = [ int(x) for x in self.encoder_hash_byte_group_size.split(",") if len(x) > 0 ] return self class GlobalTransformerArgs(ByteLatentTransformerArgs): # Global encoder specific dimensions dim_token_emb: int | None = None dim_patch_emb: int | None = None def __post_init__(self): # Override base args with global encoder specific values self.dim = self.dim_global self.n_layers = self.n_layers_global self.n_heads = self.n_heads_global self.n_kv_heads = self.n_kv_heads_global self.local_attention_window_len = None self.cross_attn_encoder = False self.cross_attn_decoder = False class LocalDecoderArgs(ByteLatentTransformerArgs): # Local decoder specific dimensions dim_token_emb: int | None = None dim_patch_emb: int | None = None def __post_init__(self): # Override base args with local decoder specific values self.dim = self.dim_local_decoder self.n_layers = self.n_layers_local_decoder self.n_heads = self.n_heads_local_decoder self.cross_attn_encoder = False self.cross_attn_init_by_pooling = False self.attn_bias_type = "local_block_causal" def create_global_transformer(args: ByteLatentTransformerArgs): global_args = args.model_copy( deep=True, update=dict( dim=args.dim_global, n_layers=args.n_layers_global, n_heads=args.n_heads_global, n_kv_heads=args.n_kv_heads_global, local_attention_window_len=None, dim_token_emb=get_global_dim_patch_emb(args), dim_patch_emb=None, cross_attn_encoder=False, cross_attn_decoder=False, ), ) return GlobalTransformer(global_args) class LocalModelArgs(BaseTransformerArgs): model_config = ConfigDict(extra="forbid") # Override defaults attn_impl: str | None = "xformers" attn_bias_type: str | None = "local_block_causal" # Local encoder specific dimensions dropout: float vocab_size: int patch_size: float sliding_window: int | None use_rope: bool cross_attn_encoder: bool | None cross_attn_decoder: bool | None cross_attn_k: int | None cross_attn_init_by_pooling: bool patching_mode: str use_local_encoder_transformer: bool downsampling_by_pooling: str | None encoder_hash_byte_group_size: Any | None = None cross_attn_all_layers_encoder: bool = False cross_attn_all_layers_decoder: bool = False cross_attn_nheads: int | None dim_token_emb: int dim_patch_emb: int | None class LocalModelBase(nn.Module): def __init__(self, args: LocalModelArgs): super().__init__() self.dim = args.dim self.dropout = args.dropout self.vocab_size = args.vocab_size self.patch_size = args.patch_size self.dim_patch_emb = args.dim_patch_emb self.attn_impl = args.attn_impl self.sliding_window = args.sliding_window self.use_rope = args.use_rope self.init_std_factor = args.init_std_factor self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None) self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None) self.cross_attn_k = getattr(args, "cross_attn_k", None) self.eos_id = args.eos_id self.boe_id = BOE_ID self.layers = nn.ModuleList( [TransformerBlock(args) for _ in range(args.n_layers)] ) if not self.use_rope: self.pos_embeddings = nn.Embedding(args.max_length, args.dim) else: self.rope = RotaryEmbedding( theta=args.rope_theta, head_dim=args.head_dim or args.dim // args.n_heads, max_seqlen=args.max_seqlen, rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, ) self.pos_embeddings = None self.token_embedding_projection = ( nn.Linear(args.dim_token_emb, args.dim, bias=False) if hasattr(args, "dim_token_emb") and args.dim_token_emb != self.dim else None ) self.patch_embedding_projection = self._create_patch_projection(args) def _should_create_patch_projection(self, args: LocalModelArgs): dimension_mismatch = ( getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim ) # Check cross attention conditions cross_attn_conditions = ( args.cross_attn_encoder and args.cross_attn_init_by_pooling ) or (args.cross_attn_decoder and args.cross_attn_init_by_pooling) return dimension_mismatch or cross_attn_conditions def _create_patch_projection(self, args): if not self._should_create_patch_projection(args): return None output_dim = args.dim_token_emb * (self.cross_attn_k or 1) return nn.Linear( in_features=args.dim_patch_emb, out_features=output_dim, bias=False, ) def apply_embedding(self, tokens, embeds): if embeds is not None: return embeds else: return self.tok_embeddings(tokens) def init_weights(self, init_std=None): self.rope.reset_parameters() if hasattr(self, "norm"): self.norm.reset_parameters() init_std = init_std or (self.dim ** (-0.5)) if hasattr(self, "tok_embeddings"): nn.init.trunc_normal_( self.tok_embeddings.weight, mean=0.0, std=init_std, a=-3 * init_std, b=3 * init_std, ) if self.pos_embeddings is not None: nn.init.trunc_normal_( self.pos_embeddings.weight, mean=0.0, std=init_std, a=-3 * init_std, b=3 * init_std, ) for depth, layer in enumerate(self.layers): factor = { InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, InitStdFactor.DIM_RATIO: self.dim / 4096, InitStdFactor.DISABLED: 1.0, }[self.init_std_factor] layer.init_weights(None, factor) if hasattr(self, "output"): nn.init.trunc_normal_( self.output.weight, mean=0.0, std=init_std, a=-3 * init_std, b=3 * init_std, ) if self.token_embedding_projection is not None: nn.init.trunc_normal_( self.token_embedding_projection.weight, mean=0.0, std=init_std, a=-3 * init_std, b=3 * init_std, ) if self.patch_embedding_projection is not None: patch_emb_std = self.dim_patch_emb ** (-0.5) nn.init.trunc_normal_( self.patch_embedding_projection.weight, mean=0.0, std=patch_emb_std, a=-3 * patch_emb_std, b=3 * patch_emb_std, ) if self.cross_attn_layers is not None: for depth, layer in enumerate(self.cross_attn_layers): factor = { InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, InitStdFactor.DIM_RATIO: self.dim / 4096, InitStdFactor.DISABLED: 1.0, }[self.init_std_factor] layer.init_weights(None, factor) class LocalEncoder(LocalModelBase): def __init__(self, args: LocalModelArgs): super().__init__(args) self.apply_transformer = args.use_local_encoder_transformer self.downsampling_by_pooling = args.downsampling_by_pooling self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None self.cross_attn_encoder = args.cross_attn_encoder self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling self.cross_attn_nheads = args.cross_attn_nheads self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim) if self.cross_attn_encoder: self.cross_attn_layers = torch.nn.ModuleList() layers_to_add = args.n_layers if self.cross_attn_all_layers_encoder else 1 for _ in range(layers_to_add): self.cross_attn_layers.append( CrossAttention( dim=self.dim, head_dim=self.dim // self.cross_attn_nheads, n_heads=self.cross_attn_nheads, n_kv_heads=self.cross_attn_nheads, norm_eps=args.norm_eps, ) ) def apply_embedding(self, tokens, embeds): if embeds is not None: assert ( self.expects_hash_embeddings ), "Not expecting embeddings to be passed." return embeds else: return self.tok_embeddings(tokens) def forward( self, tokens: torch.Tensor, embeds: Optional[torch.Tensor] = None, patch_embeds: Optional[torch.Tensor] = None, mask: Optional[Union["BlockMask", "AttentionBias", torch.Tensor, str]] = None, cross_mask: Optional[torch.Tensor] = None, num_patches: Optional[int] = None, patch_ids: Optional[torch.Tensor] = None, cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, ): """ """ bs, seqlen = tokens.shape if mask is None: mask = create_causal_mask( seqlen, self.attn_impl, "local_block_causal", sliding_window=self.sliding_window, tokens=tokens, eos_id=self.eos_id, ) h = self.apply_embedding(tokens, embeds) freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None h = F.dropout(h, p=self.dropout, training=self.training) for i, layer in enumerate(self.layers): h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) # check if cross attention should be applied to either all layer or only the last layer if self.cross_attn_encoder and ( i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder ): patch_embeds = self.apply_cross_attention( h, patch_embeds, i, bs, num_patches, patch_ids, cross_mask ) h_residual = patch_embeds if self.cross_attn_encoder else None return (h, h_residual), cache def apply_cross_attention( self, h, patch_embeds, layer_idx, bs, num_patches, patch_ids, cross_mask ): # apply pooling and project if self.cross_attn_init_by_pooling and patch_embeds is None: patch_embeds = downsample( h, num_patches, patch_ids=patch_ids, downsampling_by_pooling=self.downsampling_by_pooling, patch_size=self.patch_size, ) if self.patch_embedding_projection is not None: patch_embeds = self.patch_embedding_projection(patch_embeds) patch_embeds = patch_embeds.reshape( bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim ) layer_idx = layer_idx if self.cross_attn_all_layers_encoder else 0 patch_embeds_cross = self.cross_attn_layers[layer_idx]( x=patch_embeds, kv=h, mask=cross_mask, ) return patch_embeds + patch_embeds_cross class LocalDecoder(LocalModelBase): def __init__(self, args: LocalModelArgs): super().__init__(args) # Model configuration flags self.cross_attn_decoder = args.cross_attn_decoder self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling self.cross_attn_nheads = args.cross_attn_nheads self.norm = RMSNorm(args.dim, eps=args.norm_eps) if self.cross_attn_decoder: self.cross_attn_layers = torch.nn.ModuleList() layers_to_add = args.n_layers if self.cross_attn_all_layers_decoder else 1 for _ in range(layers_to_add): self.cross_attn_layers.append( CrossAttention( dim=self.dim, head_dim=self.dim // self.cross_attn_nheads, n_heads=self.cross_attn_nheads, n_kv_heads=self.cross_attn_nheads, norm_eps=args.norm_eps, ) ) self.output = nn.Linear( self.dim, args.vocab_size, bias=False, ) def forward( self, tokens: torch.Tensor, embeds: Optional[torch.Tensor], patch_embeds: Optional[torch.Tensor] = None, mask: Optional[Union["BlockMask", "AttentionBias", torch.Tensor, str]] = None, cross_mask: Optional[torch.Tensor] = None, cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, ): bs, seqlen = tokens.shape assert embeds is not None, "Embeddings must be provided" if mask is None: mask = create_causal_mask( seqlen, self.attn_impl, "local_block_causal", sliding_window=self.sliding_window, tokens=tokens, eos_id=self.eos_id, ) h = embeds if self.patch_embedding_projection is not None: assert patch_embeds is not None, "Patch embeddings must be passed." patch_embeds = self.patch_embedding_projection(patch_embeds) if self.cross_attn_k is not None: patch_embeds = patch_embeds.reshape( bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim ) if patch_embeds is not None and not self.cross_attn_decoder: h = h + patch_embeds freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None h = F.dropout(h, p=self.dropout, training=self.training) for i, layer in enumerate(self.layers): if self.cross_attn_decoder and ( i == 0 or self.cross_attn_all_layers_decoder ): # Use cross attention to extract info from patch_embeds into h h_cross = self.cross_attn_layers[i]( x=h, kv=patch_embeds, mask=cross_mask, ) h = h + h_cross h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) h_preds = self.norm(h) h_preds = F.dropout(h_preds, p=self.dropout, training=self.training) h_preds = self.output(h_preds) h_preds = h_preds.float() return h_preds, cache class CrossAttention(nn.Module): """ CrossAttention block to attend to the encoder states from the decoder. Rope is not supported. """ def __init__( self, dim: int, head_dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, ): super().__init__() self.dim = dim self.head_dim = head_dim self.n_heads = n_heads self.n_kv_heads = n_kv_heads self.heads_per_group = self.n_heads // self.n_kv_heads self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps) self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps) self.wq = nn.Linear( dim, n_heads * head_dim, bias=False, ) self.wk = nn.Linear( dim, n_kv_heads * head_dim, bias=False, ) self.wv = nn.Linear( dim, n_kv_heads * head_dim, bias=False, ) self.wo = nn.Linear( n_heads * head_dim, dim, bias=False, ) def forward( self, x: torch.Tensor, kv: torch.Tensor, mask: Optional[Union[BlockMask, AttentionBias, str]] = None, ) -> torch.Tensor: # B S D bsz, seq_len, _ = x.shape _, slen_kv, _ = kv.shape x_norm = self.cross_attn_norm_q(x) kv = self.cross_attn_norm_kv(kv) xq = self.wq(x_norm) xk = self.wk(kv) xv = self.wv(kv) output_shape = xq.shape # B S D -> B S H D xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) xk = repeat_kv(xk, self.heads_per_group, dim=2) xv = repeat_kv(xv, self.heads_per_group, dim=2) assert mask is None or isinstance(mask, BlockMask) xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) output = flex_attention_comp(xq, xk, xv, block_mask=mask) output = output.transpose(1, 2).contiguous() # B H S D -> B S H D output = self.wo(output.reshape(output_shape)) return x + output def init_weights(self, base_std: float, factor: float = 1.0): std = base_std or (self.dim ** (-0.5)) / factor nn.init.trunc_normal_( self.wq.weight, mean=0.0, std=std, a=-3 * std, b=3 * std, ) nn.init.trunc_normal_( self.wk.weight, mean=0.0, std=std, a=-3 * std, b=3 * std, ) nn.init.trunc_normal_( self.wv.weight, mean=0.0, std=std, a=-3 * std, b=3 * std, ) nn.init.trunc_normal_( self.wo.weight, mean=0.0, std=std, a=-3 * std, b=3 * std, ) self.cross_attn_norm_q.reset_parameters() self.cross_attn_norm_kv.reset_parameters() class GlobalTransformer(BaseTransformer): def __init__(self, args: BaseTransformerArgs): super().__init__(args) self.dropout = args.dropout self.eos_id = args.eos_id self.dim_token_emb = args.dim_token_emb self.token_embedding_projection = None if args.dim_token_emb is not None and args.dim_token_emb != self.dim: self.token_embedding_projection = nn.Linear( args.dim_token_emb, args.dim, bias=False, ) def forward( self, tokens: torch.Tensor, tok_idx: Optional[torch.Tensor] = None, embeds: Optional[torch.Tensor] = None, mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None, cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, ): """ Similar to BaseTransformer.forward, but with an additional embeds argument and projection to the token space. """ bs, seqlen = tokens.shape h = embeds mask = ( mask if mask is not None else create_causal_mask( seqlen, self.attn_impl, self.attn_bias_type, tokens=tokens, eos_id=self.eos_id, ) ) if self.token_embedding_projection is not None and h.shape[-1] != self.dim: h = self.token_embedding_projection(h) h = F.dropout(h, p=self.dropout, training=self.training) h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl) return h, cache def init_weights(self): super().init_weights() std = self.dim_token_emb ** (-0.5) if self.token_embedding_projection is not None: nn.init.trunc_normal_( self.token_embedding_projection.weight, mean=0.0, std=std, a=-3 * std, b=3 * std, ) def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder: local_encoder_args = LocalModelArgs( # Updated args dim=args.dim_local_encoder, n_layers=args.n_layers_local_encoder, n_heads=args.n_heads_local_encoder, dim_token_emb=get_encoder_dim_token_emb(args), dim_patch_emb=get_encoder_dim_patch_emb(args), cross_attn_encoder=args.cross_attn_encoder, cross_attn_decoder=False, cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None, cross_attn_init_by_pooling=args.cross_attn_init_by_pooling, # Defaults head_dim=args.head_dim, max_seqlen=args.max_encoder_seq_length, dropout=args.dropout, vocab_size=args.vocab_size + args.pm_size, norm_eps=args.norm_eps, patch_size=args.patch_size, sliding_window=args.local_attention_window_len, use_rope=args.use_rope, rope_theta=args.rope_theta, rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, init_base_std=args.init_base_std, init_std_factor=args.init_std_factor, n_kv_heads=args.n_kv_heads, attn_impl=args.attn_impl, attn_bias_type="local_block_causal", multiple_of=args.multiple_of, ffn_dim_multiplier=args.ffn_dim_multiplier, patching_mode=args.patching_mode, use_local_encoder_transformer=args.use_local_encoder_transformer, downsampling_by_pooling=args.downsampling_by_pooling, encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, cross_attn_nheads=args.cross_attn_nheads, eos_id=args.eos_id, ) return LocalEncoder(local_encoder_args) def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder: # First deep copy the original args local_decoder_args = LocalModelArgs( dim=args.dim_local_decoder, n_layers=args.n_layers_local_decoder, n_heads=args.n_heads_local_decoder, dim_token_emb=get_decoder_dim_token_emb(args), dim_patch_emb=args.dim_global, cross_attn_encoder=False, cross_attn_decoder=args.cross_attn_decoder, cross_attn_init_by_pooling=False, # states are already defined cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, # Defaults head_dim=args.head_dim, max_seqlen=args.max_encoder_seq_length, dropout=args.dropout, vocab_size=args.vocab_size + args.pm_size, norm_eps=args.norm_eps, patch_size=args.patch_size, sliding_window=args.local_attention_window_len, use_rope=args.use_rope, rope_theta=args.rope_theta, rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, init_base_std=args.init_base_std, init_std_factor=args.init_std_factor, n_kv_heads=args.n_kv_heads, attn_impl=args.attn_impl, attn_bias_type="local_block_causal", multiple_of=args.multiple_of, ffn_dim_multiplier=args.ffn_dim_multiplier, patching_mode=args.patching_mode, use_local_encoder_transformer=args.use_local_encoder_transformer, downsampling_by_pooling=args.downsampling_by_pooling, encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, cross_attn_nheads=args.cross_attn_nheads, eos_id=args.eos_id, ) return LocalDecoder(local_decoder_args) class EmbeddingType(Enum): HASH_TOK = auto() NGRAM = auto() def init_embeddings( args, embedding_type: EmbeddingType, local_encoder_dim: int, encoder_hash_byte_group_size: list = None, ): if ( embedding_type == EmbeddingType.HASH_TOK and args.encoder_hash_byte_group_size is None ): return None if embedding_type == EmbeddingType.NGRAM and args.encoder_ngram_to_size_str is None: return None embeddings = [] if embedding_type == EmbeddingType.HASH_TOK: emb_dim = local_encoder_dim encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab for _ in range(args.encoder_hash_byte_group_nb_functions): for _ in encoder_hash_byte_group_size: embeddings.append( nn.Embedding( encoder_hash_byte_group_vocab, emb_dim, ) ) elif embedding_type == EmbeddingType.NGRAM: encoder_ngram_to_size = parse_ngram_to_size(args.encoder_ngram_to_size_str) emb_dim = local_encoder_dim OFFSET = 4 # This should be passed as parameter if it's variable for ngram_vocab_size in encoder_ngram_to_size.values(): embeddings.append(nn.Embedding(ngram_vocab_size + OFFSET, emb_dim)) return nn.ModuleList(embeddings) def compute_hash_embeddings( local_encoder_tokens: torch.Tensor, local_encoder, encoder_hash_tok_embedding: nn.ModuleList, encoder_hash_byte_group_nb_functions: int, encoder_hash_byte_group_size: list, encoder_hash_byte_group_vocab: int, ) -> torch.Tensor: """ Compute embeddings using hash token embeddings. Args: local_encoder_tokens: Input tokens tensor local_encoder: Encoder object with tok_embeddings method encoder_hash_tok_embedding: ModuleList of hash token embeddings encoder_hash_byte_group_nb_functions: Number of hash functions encoder_hash_byte_group_size: List of byte group sizes encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings Returns: torch.Tensor: Combined embeddings """ if encoder_hash_tok_embedding is None: return None local_encoder_embeds = local_encoder.tok_embeddings(local_encoder_tokens) i = 0 for func_nb in range(encoder_hash_byte_group_nb_functions): for byte_group_size in encoder_hash_byte_group_size: hash_ids = byte_group_hash_function( local_encoder_tokens, byte_group_size, hash_func_nb=func_nb, max_hash=encoder_hash_byte_group_vocab, ) hash_tok_embedding = encoder_hash_tok_embedding[i] local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids) i += 1 assert i == len(encoder_hash_tok_embedding) return local_encoder_embeds class ByteLatentTransformer( nn.Module, SequenceModelWithOutput, PyTorchModelHubMixin, repo_url="https://github.com/facebookresearch/blt", # paper_url="https://arxiv.org/abs/2412.09871", pipeline_tag="text-generation", license="other", license_name="fair-noncommercial-research-license", license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE", coders={ ByteLatentTransformerArgs: ( lambda x: {"args": x.model_dump()}, lambda data: ByteLatentTransformerArgs(**data), ) }, ): """ The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers, and local decoders to efficiently encode and decode byte sequences, leveraging patch-based processing for improved performance and inference efficiency. """ def __init__(self, args: ByteLatentTransformerArgs): super().__init__() # General configuration self.weight_tying = args.weight_tying self.patch_size = args.patch_size self.patching_mode = args.patching_mode self.boe_id, self.bos_id, self.pad_id, self.eos_id = ( BOE_ID, BOS_ID, PAD_ID, EOS_ID, ) self.downsampling_by_pooling = args.downsampling_by_pooling self.patching_threshold = args.patching_threshold self.dim = args.dim self.init_base_std = args.init_base_std self.init_std_factor = InitStdFactor(args.init_std_factor) self.max_seqlen = args.max_seqlen # Cross attention configuration self.cross_attn_encoder = args.cross_attn_encoder self.cross_attn_decoder = args.cross_attn_decoder self.cross_attn_k = args.cross_attn_k self.cross_attn_window_encoder = args.cross_attn_window_encoder self.cross_attn_window_decoder = args.cross_attn_window_decoder self.cross_attn_use_flex_attention = args.cross_attn_use_flex_attention # Encoder hash configuration self.encoder_hash_byte_group_size = args.encoder_hash_byte_group_size self.encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab self.encoder_hash_byte_group_nb_functions = ( args.encoder_hash_byte_group_nb_functions ) # ByteLatent modules self.local_encoder = create_local_encoder(args) self.global_transformer = create_global_transformer(args) self.local_decoder = create_local_decoder(args) self.encoder_hash_tok_embedding = init_embeddings( args, EmbeddingType.HASH_TOK, local_encoder_dim=self.local_encoder.dim, encoder_hash_byte_group_size=self.encoder_hash_byte_group_size, ) self.encoder_ngram_embedding = init_embeddings( args, EmbeddingType.NGRAM, local_encoder_dim=self.local_encoder.dim, encoder_hash_byte_group_size=None, ) # Encoder ngram embedding tables self.encoder_ngram_embedding = None if args.encoder_enable_byte_ngrams: self.encoder_ngram_embedding = nn.ModuleList() assert args.ngram_vocab_sizes is not None self.encoder_ngram_to_size = parse_ngram_to_size( args.encoder_ngram_to_size_str ) ngram_emb_dim = self.local_encoder.dim for ngram_vocab_size in self.encoder_ngram_to_size.values(): self.encoder_ngram_embedding.append( nn.Embedding(ngram_vocab_size + OFFSET, ngram_emb_dim) ) # Output layer assert args.vocab_size > 0, "vocab_size must be greater than 0" # Patcher module if args.patch_in_forward: self.patcher = Patcher( PatcherArgs( patch_size=args.patch_size, patching_mode=args.patching_mode, patching_threshold=args.patching_threshold, patching_threshold_add=args.patching_threshold_add, monotonicity=args.monotonicity, max_patch_length=args.max_patch_length, ) ) def push_to_hub(self, *args, **kwargs): raise ValueError( "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct." ) def get_output_seq_len(self): return self.max_seqlen def forward( self, tokens: torch.Tensor, patch_lengths: Optional[torch.Tensor] = None, ngram_ids: Optional[torch.Tensor] = None, ): # Ensure ngram_ids is either a tensor or None assert ( isinstance(ngram_ids, torch.Tensor) or ngram_ids is None ), f"ngram_ids must be a tensor or None, but was: {type(ngram_ids)}" bs, N = tokens.shape # Batch size and sequence length # Get megabyte inputs nb_boe = int(0 if self.patching_mode != "" else self.patch_size - 1) local_encoder_tokens, _, local_decoder_tokens = get_blt_input( tokens=tokens, enforce_patch_size_multiple=False, nb_boe=nb_boe, patch_size=self.patch_size, boe_id=self.boe_id, ) # Patching if patch_lengths is None: assert ( getattr(self, "patcher", None) is not None ), "Patcher not defined and no patch_lengths passed." patch_lengths, tok_scores = self.patcher.patch( local_encoder_tokens, include_next_token=True, threshold=self.patcher.threshold, ) else: if nb_boe > 0: patch_lengths[:, 0] += nb_boe assert torch.min(patch_lengths) >= 0 # Generate patch IDs from patch_lengths patch_ids = patch_ids_from_lengths( patch_lengths, local_encoder_tokens.shape[-1] ) assert torch.max(patch_ids) + 1 <= torch.max( (patch_lengths != 0).sum(dim=-1) ), f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}" cross_attn_mask_enc = None # Cross-attention encoder if self.cross_attn_encoder: cross_attn_mask_enc = cross_attn_mask( patch_ids, patch_lengths, N, patches_as_queries=True, cross_attn_k=self.cross_attn_k, window=self.cross_attn_window_encoder, block_mask=self.cross_attn_use_flex_attention, ) # Hashing and embedding local_encoder_embeds = compute_hash_embeddings( local_encoder_tokens=local_encoder_tokens, local_encoder=self.local_encoder, encoder_hash_tok_embedding=self.encoder_hash_tok_embedding, encoder_hash_byte_group_nb_functions=self.encoder_hash_byte_group_nb_functions, encoder_hash_byte_group_size=self.encoder_hash_byte_group_size, encoder_hash_byte_group_vocab=self.encoder_hash_byte_group_vocab, ) # N-gram table embeddings if self.encoder_ngram_embedding is not None: assert ngram_ids is not None, "ngram_ids must be provided" if local_encoder_embeds is None: local_encoder_embeds = self.local_encoder.tok_embeddings( local_encoder_tokens ) assert len(ngram_ids) == len( self.encoder_ngram_embedding ), f"ngram_ids.shape[0]={ngram_ids.shape[0]} versus len(encoder_ngram_embedding)={len(self.encoder_ngram_embedding)}, ngram_ids.shape={ngram_ids.shape}" for i in range(ngram_ids.shape[0]): ngram_embedding = self.encoder_ngram_embedding[i] ngram_embeds = ngram_embedding(ngram_ids[i]) assert ( local_encoder_embeds.shape == ngram_embeds.shape ), f"Shape mismatch: {local_encoder_embeds.shape} vs {ngram_embeds.shape}, ngram_ids.shape={ngram_ids.shape}" local_encoder_embeds = local_encoder_embeds + ngram_embeds # Local encoder (h_encoder, h_cross), cache_encoder = self.local_encoder( tokens=local_encoder_tokens, embeds=local_encoder_embeds, patch_embeds=None, cross_mask=cross_attn_mask_enc, num_patches=patch_lengths.shape[1], patch_ids=patch_ids, ) # Downsampling if not self.cross_attn_encoder: assert ( patch_ids.shape[1] == h_encoder.shape[1] ), f"{patch_ids.shape[1]} != {h_encoder.shape[1]}" h = downsample( h_encoder, patch_lengths.shape[1], patch_lengths, patch_ids, downsampling_by_pooling=self.downsampling_by_pooling, patch_size=self.patch_size, ) else: # Reshape h_cross h = h_cross.view(bs, patch_lengths.shape[1], -1) # Global transformer global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(self.boe_id) rows, cols = torch.where(local_encoder_tokens == self.eos_id) eos_patch_ids = patch_ids[rows, cols] global_tokens[rows, eos_patch_ids] = self.eos_id h, _ = self.global_transformer( embeds=h, tokens=global_tokens, ) # Unpatching dec_embeds = h_encoder[:, nb_boe : nb_boe + N, :] # Generate decoder patch IDs decoder_patch_ids = decoder_patch_ids_from_lengths( patch_lengths, nb_boe, local_decoder_tokens.shape[-1] ) assert ( torch.max(decoder_patch_ids) + 1 <= h.shape[1] ), f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}" assert ( decoder_patch_ids.shape[1] == dec_embeds.shape[1] ), f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}" # Cross-attention decoder if not self.cross_attn_decoder: h = torch.gather( h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) ) cross_attn_mask_dec = None assert local_decoder_tokens.shape == h.shape[:-1] else: cross_attn_mask_dec = cross_attn_mask( decoder_patch_ids, patch_lengths, N, patches_as_queries=False, cross_attn_k=self.cross_attn_k, window=self.cross_attn_window_decoder, block_mask=self.cross_attn_use_flex_attention, ) # Local decoder output, _ = self.local_decoder( embeds=dec_embeds, patch_embeds=h, tokens=local_decoder_tokens, cross_mask=cross_attn_mask_dec, ) return output def init_weights(self): self.local_encoder.init_weights() self.global_transformer.init_weights() self.local_decoder.init_weights() emb_std = self.local_encoder.dim ** (-0.5) for emb in self.encoder_hash_tok_embedding: nn.init.trunc_normal_( emb.weight, mean=0.0, std=emb_std, a=-3 * emb_std, b=3 * emb_std, )