diff --git a/blt_one_file.py b/blt_one_file.py new file mode 100644 index 0000000..38219c7 --- /dev/null +++ b/blt_one_file.py @@ -0,0 +1,2618 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +from enum import Enum, auto +from typing import Any, List, Optional, Tuple, Union + +import torch +from huggingface_hub import PyTorchModelHubMixin +from pydantic import model_validator +from torch import nn +from torch.nn.attention.flex_attention import create_block_mask, BlockMask, flex_attention +from typing_extensions import Self +import json +import logging + +import torch +import torch.nn +import torch.nn as nn +from pydantic import ConfigDict +from torch.nn import functional as F +from xformers.ops import AttentionBias, fmha + +import abc + +import os +import time +from collections import defaultdict + +from pydantic import BaseModel +from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID + +RMSNorm = nn.RMSNorm + +from bytelatent.distributed import get_local_rank + +logger = logging.getLogger() + + +if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0: + flex_attention_comp = torch.compile(flex_attention) +else: + flex_attention_comp = None + + +def patch_reduce(h, max_num_patches, reduction, patch_ids): + """ + Reduce variable length patches to single embedding per patch + Note: this works with variable number of patches for different sequences in the batch + It handles variable length patches by assuming that patch_lengths will be 0 for any + extra patches on the *right*. Since there can be a variable number of patches + this function also return the number of patches for each sequence in the batch. + Any embeddings on the right that are not allocated to a patch + (i.e. if the sum(patch_lengths[i]) < seq_len for any i) + will be sent to a dummy patch, which is trimmed before returning. + """ + bs, seq_len, emb_dim = h.shape + + patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) + + reduced_embs = torch.zeros( + (bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device + ) + reduced_embs = reduced_embs.scatter_reduce( + src=h, + dim=1, + index=patch_ids, + reduce=reduction, + include_self=False, + ) + reduced_embs = reduced_embs[:, :max_num_patches, :] + + return reduced_embs + + +def concat_downsample(h, patch_lengths, patch_size): + # The assumption in this function is that seq_len = patch_size * num_patches. + bs, seq_len, emb_dim = h.shape + patch_end_ids = torch.cumsum(patch_lengths, dim=1) + patch_ids = patch_end_ids.unsqueeze(-1) - torch.arange(patch_size, 0, -1).to( + patch_end_ids.device + ) + # Is clamp ok here? + patch_ids = patch_ids.clamp(min=0).unsqueeze(-1).expand(-1, -1, -1, h.shape[-1]) + patch_ids = patch_ids.view(bs, -1, emb_dim) + # after gather h.shape = [batch_size, seq_len, dim] + h = torch.gather(h, 1, patch_ids) + h = h.reshape(bs, patch_lengths.shape[1], patch_size * h.size(-1)) + return h + + +def pooling_downsample(h, max_num_patches, pooling_mode, patch_ids): + cat = [] + if "avg" in pooling_mode or "mean" in pooling_mode: + cat.append(patch_reduce(h, max_num_patches, "mean", patch_ids)) + if "min" in pooling_mode: + cat.append(patch_reduce(h, max_num_patches, "amin", patch_ids)) + if "max" in pooling_mode: + cat.append(patch_reduce(h, max_num_patches, "amax", patch_ids)) + assert len(cat) > 0 + h = torch.cat(cat, dim=-1) + return h + + +def downsample( + h, + num_patches, + patch_lengths=None, + patch_ids=None, + downsampling_by_pooling=None, + patch_size=4, +): + """ + Downsampling: + a. 
concatenating embeddings in the patch + Note: with dynamic patching, patch the last patch_size tokens. + b. pooling embeddings in the patch + """ + # input: h.shape = [batch_size, seq_len, dim] + # input: pool h.shape = [batch_size, seq_len / patch_size, dim] + # if we don't use the cros_attn, we pool so that we convert bytes rep to patch rep + if downsampling_by_pooling is not None and len(downsampling_by_pooling) > 0: + # By pooling + max_num_patches = num_patches + assert patch_ids is not None + h = pooling_downsample(h, max_num_patches, downsampling_by_pooling, patch_ids) + else: + # TODO: remove this condition + # By concatenating (fixed lengths patching) + assert patch_lengths is not None + h = concat_downsample(h, patch_lengths, patch_size) + return h + + +def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + +def tokens_to_seqlen(batch: torch.Tensor, eos_id: int): + """ + 0 0 0 1 0 0 0 1 0 0 0 + 0 1 0 0 0 1 0 0 0 0 0 + -> 4 4 3 2 4 5 + """ + mask = batch == eos_id + mask[:, -1] = True # virtual eos at the end of each row + + # 0 0 0 1 0 0 0 1 0 0 X + # 0 1 0 0 0 1 0 0 0 0 X + row, col = torch.where(mask) + + # row = 0, 0, 0, 1, 1, 1 + # col = 3, 7, 10, 1, 5, 10 + seqlens = (col[1:] - col[:-1]) + (row[1:] - row[:-1]) * mask.shape[1] + # seqlens = (4, 3, -9, 4, 5) + (0, 0, 11, 0, 0) = (4, 3, 2, 4, 5) + return [int(col[0].item() + 1)] + seqlens.tolist() + + +def create_causal_mask( + seqlen, + attn_impl: str, + attn_bias_type: str | None, + *, + eos_id: int | None = None, + tokens: torch.Tensor | None = None, + sliding_window: int | None = None, +): + if attn_impl == "xformers": + if attn_bias_type is None: + return fmha.attn_bias.LowerTriangularMask() + elif attn_bias_type == "causal": + assert sliding_window is None + return fmha.attn_bias.LowerTriangularMask() + elif attn_bias_type == "block_causal": + assert sliding_window is None + assert eos_id is not None + assert tokens is not None + return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( + q_seqlen=tokens_to_seqlen(tokens, eos_id) + ) + elif attn_bias_type == "local_block_causal": + assert sliding_window is not None + assert eos_id is not None + assert tokens is not None + return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( + q_seqlen=tokens_to_seqlen(tokens, eos_id) + ).make_local_attention(sliding_window) + else: + return fmha.attn_bias.LocalAttentionFromBottomRightMask( + window_left=sliding_window - 1, window_right=0 + ) + elif attn_impl == "sdpa": + BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0)) + + if attn_bias_type == "causal": + return "causal" + + if BLT_SUPPRESS_ATTN_ERROR == 1: + return "causal" + else: + raise ValueError( + "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. 
To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1" + ) + elif attn_impl == "flex_attention": + return create_block_mask(causal_mask, None, None, seqlen, seqlen) + elif attn_impl == "fmha": + return None + else: + raise NotImplementedError( + f"Attention {attn_impl} with {sliding_window} sliding window not implemented" + ) + + +class InitStdFactor(str, Enum): + DISABLED = "disabled" # Init std is divided by 1.0 + GLOBAL_DEPTH = "global_depth" # Init std is divided by sqrt(2*n_layers) + CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth) + DIM_RATIO = "dim_ratio" # Init std is divided by model_dim/4096 + + +class BaseTransformerArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + dim: int = 512 + n_layers: int = 8 + head_dim: int | None = None + n_heads: int | None = None + n_kv_heads: int | None = None + + ffn_dim_multiplier: float | None = None + + multiple_of: int = 256 + + norm_eps: float = 1e-5 + + rope_theta: float = 10000.0 + rope_use_fp32_in_outer_product: bool = False + + init_base_std: float | None = None + init_std_factor: InitStdFactor = InitStdFactor.DISABLED + + max_seqlen: int = 1024 + + attn_impl: str | None = "sdpa" + attn_bias_type: str | None = None + # Special token config + eos_id: int | None = EOS_ID + + +def cross_entropy(pred, target, **kwargs): + return F.nll_loss( + F.log_softmax(pred.flatten(end_dim=-2).float(), -1), + target.flatten(end_dim=-1), + **kwargs, + ) + + +def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims." + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +def precompute_freqs_cis( + dim: int, + end: int, + theta: float = 10000.0, + rope_use_fp32_in_outer_product: bool = False, +): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + if rope_use_fp32_in_outer_product: + t = t.to(torch.float32) + + freqs = torch.outer(t, freqs).float() + + cos, sin = freqs.cos(), freqs.sin() + + return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int): + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. 
+ seq_dim (int): Sequence dimension index. + + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert 0 <= seq_dim < ndim + assert freqs_cis.shape == ( + x.shape[seq_dim], + x.shape[-3], + 2, + 2, + ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}" + shape = [ + d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2]) + ] + [2, 2] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + seq_dim: int, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 + xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 + freqs_cis = reshape_for_broadcast( + freqs_cis, xq_, seq_dim + ).float() # S D/2 2 2 -> 1 S 1 D/2 2 2 + xq_out = (xq_ * freqs_cis).sum(5).flatten(3) + xk_out = (xk_ * freqs_cis).sum(5).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +# Rotary embedding as in xformer, see if torchtrain implementation is not better. Also might be usefull to make it work with batch*seqlen collapsed. +class RotaryEmbedding(torch.nn.Module): + """ + RotaryEmbedding Module + """ + + def __init__( + self, + theta: float, + head_dim: int, + max_seqlen: int = 1024, + rope_use_fp32_in_outer_product: bool = False, + ): + super().__init__() + + self.theta = theta + self.head_dim = head_dim + self.max_seqlen = max_seqlen + self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product + + self.register_buffer( + "freqs_cis", + precompute_freqs_cis( + dim=head_dim, + end=max_seqlen, + theta=theta, + rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product, + ), + persistent=False, + ) + + def reset_parameters(self): + self.freqs_cis[...] = precompute_freqs_cis( + dim=self.head_dim, + end=self.max_seqlen, + theta=self.theta, + rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product, + ) + + def forward( + self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None + ): + """ + Return freqs_cis corresponding to consecutive seqlen positions or the corresponding tok_idx positions + Args: + seqlen (int): Contiguous sequence length + tok_idx (torch.Tensor[int]): Position indices of each token this overrides seqlen + + Returns: + Tuple(torch.Tensor, torch.Tensor): Embedded input tensor and freqs_cis + """ + test = (seqlen is not None) or (tok_idx is not None) + assert test, "Should provide atleast seqlen or tok_idx" + if tok_idx is not None: + return self.freqs_cis[tok_idx] + elif seqlen is not None: + return self.freqs_cis[0:seqlen] + + +def _reshape_for_attn_bias( + attn_bias: AttentionBias | None, + *tensors: torch.Tensor, +) -> list[torch.Tensor]: + to_transform = list(tensors) + if isinstance(attn_bias, fmha.attn_bias.BlockDiagonalCausalMask): + # could be `view` instead of reshape during training, but for inference + # have to reshape due to strides mismatch + to_transform = [t.reshape(1, -1, *t.shape[2:]) for t in to_transform] + return to_transform + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + head_dim: int, + n_heads: int, + n_kv_heads: int, + rope_theta: float, + ): + super().__init__() + + self.dim = dim + self.head_dim = head_dim + self.rope_theta = rope_theta + + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.heads_per_group = self.n_heads // self.n_kv_heads + + self.wq = nn.Linear( + dim, + n_heads * head_dim, + bias=False, + ) + self.wk = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + self.wv = 
nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + + self.wo = nn.Linear( + n_heads * head_dim, + dim, + bias=False, + ) + + def forward( + self, + x: torch.Tensor, + freq_cis: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, AttentionBias, str]] = None, + attn_impl: str = "sdpa", + ) -> torch.Tensor: + # B S D + bsz, seq_len, dim = x.shape + xq = self.wq(x.view_as(x)) + xk = self.wk(x.view_as(x)) + xv = self.wv(x.view_as(x)) + + output_shape = xq.shape + # B S D -> B S H D + xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) + xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim) + xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len]) + + # This condition helps us be easily compatible + # with inference by adding a pluggable KVCache + if hasattr(self, "kv_cache"): + xk, xv = self.kv_cache.update(xk, xv, tok_idx) + + xk = repeat_kv(xk, self.heads_per_group, dim=2) + xv = repeat_kv(xv, self.heads_per_group, dim=2) + + if attn_impl == "flex_attention": + assert mask is None or isinstance(mask, BlockMask) + xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) + output = flex_attention_comp(xq, xk, xv, block_mask=mask) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + + elif attn_impl == "xformers": + assert mask is None or isinstance(mask, AttentionBias) + query_shape = xq.shape + xq, xk, xv = _reshape_for_attn_bias(mask, xq, xk, xv) + output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask) + output = output.view(query_shape) + # This uses B S H D instead of B H S D of pytorch + + elif attn_impl == "sdpa": + xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) + assert mask is None or isinstance(mask, (str, torch.Tensor)) + is_causal = (mask == "causal") if isinstance(mask, str) else False + mask = mask if isinstance(mask, torch.Tensor) else None + output = F.scaled_dot_product_attention( + xq, + xk, + xv, + is_causal=is_causal, + attn_mask=mask, + ) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + else: + raise NotImplementedError( + f"Attention implementation {attn_impl} not supported" + ) + + output_reshaped = output.reshape(output_shape) + + output = self.wo(output_reshaped) + + return output + + def reset_parameters(self, init_std=None, factor=1.0): + init_std = init_std or (self.dim ** (-0.5)) / factor + + for w in [self.wq, self.wk, self.wv]: + nn.init.trunc_normal_( + w.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + nn.init.trunc_normal_( + self.wo.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + mp_size: int = 1, + ): + super().__init__() + + hidden_dim = int(2 * hidden_dim / 3) + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + assert hidden_dim % mp_size == 0 + + self.dim = dim + self.hidden_dim = hidden_dim + + self.w1 = nn.Linear( + dim, + hidden_dim, + bias=False, + ) + self.w3 = nn.Linear( + dim, + hidden_dim, + bias=False, + ) + self.w2 = nn.Linear( + hidden_dim, + dim, + bias=False, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # B S D + x1 = self.w1(x.view_as(x)) + x3 = self.w3(x.view_as(x)) + output = self.w2(F.silu(x1) * x3) + return 
output + + def reset_parameters(self, init_std=None, factor=1.0): + in_init_std = init_std or (self.dim ** (-0.5)) / factor + out_init_std = init_std or (self.hidden_dim ** (-0.5)) / factor + + nn.init.trunc_normal_( + self.w1.weight, + mean=0.0, + std=in_init_std, + a=-3 * in_init_std, + b=3 * in_init_std, + ) + nn.init.trunc_normal_( + self.w2.weight, + mean=0.0, + std=out_init_std, + a=-3 * out_init_std, + b=3 * out_init_std, + ) + nn.init.trunc_normal_( + self.w3.weight, + mean=0.0, + std=in_init_std, + a=-3 * in_init_std, + b=3 * in_init_std, + ) + + +class TransformerBlock(nn.Module): + def __init__(self, args: BaseTransformerArgs): + super().__init__() + + assert (args.head_dim is not None) or ( + args.n_heads is not None + ), "Should specify at least head_dim or n_heads" + self.head_dim = args.head_dim or args.dim // args.n_heads + self.n_heads = args.n_heads or args.dim // args.head_dim + self.n_kv_heads = args.n_kv_heads or self.n_heads + + assert args.n_heads % self.n_kv_heads == 0 + assert args.dim % args.n_heads == 0 + + self.attention = Attention( + dim=args.dim, + head_dim=self.head_dim, + n_heads=self.n_heads, + n_kv_heads=self.n_kv_heads, + rope_theta=args.rope_theta, + ) + self.feed_forward = FeedForward( + dim=args.dim, + hidden_dim=4 * args.dim, + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + ) + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + + def forward( + self, + x: torch.Tensor, + freq_cis: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, AttentionBias, str]] = None, + attn_impl: str = "sdpa", + ) -> torch.Tensor: + norm_x = self.attention_norm(x) + attn_out = self.attention( + norm_x, + freq_cis, + tok_idx=tok_idx, + mask=mask, + attn_impl=attn_impl, + ) + h = x + attn_out + h_norm = self.ffn_norm(h) + out = h + self.feed_forward(h_norm) + return out + + def init_weights(self, init_std=None, factor=1.0): + self.attention.reset_parameters(init_std, factor) + self.attention_norm.reset_parameters() + + self.feed_forward.reset_parameters(init_std, factor) + self.ffn_norm.reset_parameters() + + +class SequenceModelWithOutput(abc.ABC): + @abc.abstractmethod + def get_output_seq_len(self) -> int: + pass + + +class BaseTransformer(nn.Module, SequenceModelWithOutput): + def __init__(self, args: BaseTransformerArgs): + super().__init__() + self.dim = args.dim + self.init_base_std = args.init_base_std + self.attn_impl = args.attn_impl + self.attn_bias_type = args.attn_bias_type + self.init_std_factor = InitStdFactor(args.init_std_factor) + self.max_seqlen = args.max_seqlen + self.rope_embeddings = RotaryEmbedding( + theta=args.rope_theta, + head_dim=args.head_dim or args.dim // args.n_heads, + max_seqlen=args.max_seqlen, + rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, + ) + self.eos_id = args.eos_id + + self.layers = nn.ModuleList() + for _ in range(args.n_layers): + self.layers.append(TransformerBlock(args)) + + def get_output_seq_len(self): + return self.max_seqlen + + def forward( + self, + h, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, AttentionBias, str]] = None, + attn_impl: str = "sdpa", + ): + + freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx) + + for i, layer in enumerate(self.layers): + h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) + return h + + def init_weights(self): + self.rope_embeddings.reset_parameters() + for depth, 
layer in enumerate(self.layers): + factor = { + InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, + InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, + InitStdFactor.DIM_RATIO: self.dim / 4096, + InitStdFactor.DISABLED: 1.0, + }[self.init_std_factor] + + layer.init_weights(self.init_base_std, factor) + + +class LMTransformerArgs(BaseTransformerArgs): + seed: int = 42 + + vocab_size: int = -1 + weight_tying: bool = False + + sliding_window: int | None = None + + +class LMTransformer( + BaseTransformer, + PyTorchModelHubMixin, + repo_url="https://github.com/facebookresearch/blt", + # paper_url="https://arxiv.org/abs/2412.09871", + pipeline_tag="text-generation", + license="other", + license_name="fair-noncommercial-research-license", + license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE", + coders={ + LMTransformerArgs: ( + lambda x: {"args": x.model_dump()}, + lambda data: LMTransformerArgs(**data), + ) + }, +): + def __init__(self, args: LMTransformerArgs): + super().__init__(args) + self.weight_tying = args.weight_tying + self.sliding_window = args.sliding_window + + assert args.vocab_size > 0 + + self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim) + + self.norm = RMSNorm(args.dim, eps=args.norm_eps) + + self.output = nn.Linear( + args.dim, + args.vocab_size, + bias=False, + ) + + if args.weight_tying: + self.output.weight = self.embeddings.tok_embeddings.weight + + def push_to_hub(self, *args, **kwargs): + raise ValueError( + "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct." + ) + + def forward( + self, + token_values: torch.Tensor, + target: Optional[torch.Tensor] = None, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None, + attn_impl: str | None = None, + ): + if attn_impl is None: + attn_impl = self.attn_impl + bsz, seqlen = token_values.shape + + h = self.tok_embeddings(token_values) + + mask = ( + mask + if mask is not None + else create_causal_mask( + seqlen, + attn_impl, + self.attn_bias_type, + sliding_window=self.sliding_window, + tokens=token_values, + eos_id=self.eos_id, + ) + ) + h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) + + logits = self.output(self.norm(h)) + if target is not None: + return cross_entropy(logits, target) + else: + return logits + + def reset_parameters(self, init_std=None): + self.norm.reset_parameters() + + def init_weights(self): + self.reset_parameters() + init_std = self.dim ** (-0.5) + nn.init.trunc_normal_( + self.tok_embeddings.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + super().init_weights() + + if not self.weight_tying: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + + + + + +class PatchingModeEnum(str, Enum): + entropy = "entropy" + bpe = "bpe" + bpe_patcher = "bpe_patcher" + space = "space" + static = "static" + byte = "byte" + + +class PatcherArgs(BaseModel): + patching_mode: PatchingModeEnum = PatchingModeEnum.entropy + patching_device: str = "cuda" + entropy_model_checkpoint_dir: str | None = None + realtime_patching: bool = False + threshold: float = 1.335442066192627 + threshold_add: float | None = None + max_patch_length: int | None = None + patch_size: float = 4.5 + patching_batch_size: int = 1 + device: str = "cuda" + monotonicity: bool = False + log_time: bool = False 
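+    # Illustrative usage sketch: in this standalone file, Patcher.patch only
+    # implements PatchingModeEnum.byte, so a minimal offline configuration
+    # could look like
+    #     args = PatcherArgs(patching_mode=PatchingModeEnum.byte,
+    #                        realtime_patching=False, patching_device="cpu")
+    #     patch_lengths, scores = args.build().patch(tokens, include_next_token=True)
+    # where `tokens` is a [batch_size, seq_len] LongTensor of byte ids; the call
+    # sequence is a sketch, not a guaranteed API.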
+ + def build(self) -> "Patcher": + return Patcher(self) + +def rightpad(seq, pad_id, max_len): + return seq + [pad_id] * (max_len - len(seq)) + + +def check_non_zero_after_zero(tensor): + zero_mask = tensor == 0 + shifted_mask = torch.cat( + [ + torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device), + zero_mask[:, :-1], + ], + dim=1, + ) + non_zero_after_zero = (tensor != 0) & shifted_mask + return non_zero_after_zero.any() + + +def to_device(entropy_model, device=None): + if device == "cuda": + rank = get_local_rank() + device = f"cuda:{rank}" + entropy_model = entropy_model.to(device) + return entropy_model, device + + +def split_large_numbers(lst, m): + new_lst = [] + for i in lst: + if i > m: + while i > m: + new_lst.append(m) + i -= m + new_lst.append(i) + else: + new_lst.append(i) + assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}" + return new_lst + + +class Patcher: + def __init__(self, patcher_args: PatcherArgs): + self.patcher_args = patcher_args + self.patching_mode = patcher_args.patching_mode + self.realtime_patching = patcher_args.realtime_patching + if self.realtime_patching: + assert ( + patcher_args.entropy_model_checkpoint_dir is not None + ), "Cannot require realtime patching without an entropy model checkpoint" + maybe_consolidated = os.path.join( + patcher_args.entropy_model_checkpoint_dir, + "consolidated/consolidated.pth", + ) + if os.path.exists(maybe_consolidated): + state_path = maybe_consolidated + else: + state_path = os.path.join( + patcher_args.entropy_model_checkpoint_dir, "consolidated.pth" + ) + entropy_model, _ = load_entropy_model( + patcher_args.entropy_model_checkpoint_dir, + state_path, + ) + entropy_model, _ = to_device(entropy_model, patcher_args.patching_device) + self.entropy_model = entropy_model + else: + self.entropy_model = None + self.threshold = patcher_args.threshold + self.threshold_add = patcher_args.threshold_add + self.max_patch_length = patcher_args.max_patch_length + self.patch_size = patcher_args.patch_size + self.patching_batch_size = patcher_args.patching_batch_size + self.device = patcher_args.device + self.monotonicity = patcher_args.monotonicity + self.log_time = patcher_args.log_time + if self.log_time: + self.log = defaultdict(float) + + def patch( + self, + tokens: torch.Tensor, + include_next_token: bool = False, + preds: torch.Tensor | None = None, + entropies: torch.Tensor | None = None, + threshold: float = None, + ) -> torch.Tensor: + """ + tokens: 2D tensor of shape [batch_size, seq_len] that needs to be patched + Returns patch lengths and optionally scores associated with the tokens (i.e. entropies, logprobs etc.) + -> output tensor: [batch_size, max_num_patches] + each tensor is processed independently and gets right padded with zeros. + + Patching with the following modes: + 1. patching_mode = None: static patch size + 2. patching_mode = "entropy": + calculate entropy of each token, allocate patches so that the total + number of patches is the same as static patching but choose to begin + patches on tokens where the model is most uncertain (highest entropy). + + When threshold is provided, it uses the threshold to decide when to + start a new patch. + 3. patching_mode = "space": + use space like tokens to define the patches. + 4. patching_mode = "bpe": + use bpe delim tokens to define the patches. + + To correctly patch the last token, it may be necessary to include the next token in the patch + lengths calculations. This is controlled by the include_next_token argument. 
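+
+        Illustrative example (byte patching, the only mode implemented in this
+        standalone file): for `tokens` of shape [2, 5] and include_next_token=True,
+        the returned patch_lengths is a [2, 6] tensor of ones (one byte per patch)
+        and scores is None.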
+ """ + bs, seq_len = tokens.shape + seq_len_next_tok = seq_len + 1 if include_next_token else seq_len + scores = None + # STATIC + if self.patching_mode == PatchingModeEnum.byte: + patch_lengths = torch.ones( + (bs, seq_len_next_tok), dtype=tokens.dtype, device=tokens.device + ) + else: + raise NotImplementedError(f"self.patching_mode {self.patching_mode}") + + # Apply any processing to patch lengths + if self.max_patch_length is not None: + # TODO: avoid going back to a list here. + patch_lengths = [ + split_large_numbers(pl, self.max_patch_length) + for pl in patch_lengths.tolist() + ] + max_len = max([len(pl) for pl in patch_lengths]) + patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths] + patch_lengths = torch.tensor( + patch_lengths, dtype=tokens.dtype, device=tokens.device + ) + assert not check_non_zero_after_zero(patch_lengths) + # Find the last non-zero column index using argmax on a reversed version of the tensor + last_non_zero_col_reversed = ( + (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min() + ) + # Slice the tensor up to the last non-zero column + patch_lengths = patch_lengths[ + :, : patch_lengths.shape[1] - last_non_zero_col_reversed + ] + assert ( + torch.sum(patch_lengths) + == tokens.numel() + include_next_token * tokens.shape[0] + ), f"{torch.sum(patch_lengths)} != {tokens.numel() + include_next_token * tokens.shape[0]}" + if self.log_time: + self.log["postprocessing_patch_lengths"] += time.time() - s + self.log["tokens"] += patch_lengths.sum().item() + return patch_lengths, scores + + + +def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"): + with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr: + reloaded = json.loads(fr.read()) + + torch.set_default_dtype(torch.bfloat16) + model_params = reloaded["entropy_model"] + logger.warning( + "Update checkpoint to load attn and sliding window args from checkpoint" + ) + entropy_model_args = LMTransformerArgs( + dim=model_params["dim"], + n_layers=model_params["n_layers"], + n_heads=model_params["n_heads"], + max_seqlen=model_params["max_seqlen"], + ffn_dim_multiplier=model_params["ffn_dim_multiplier"], + vocab_size=model_params["vocab_size"], + attn_bias_type="local_block_causal", + attn_impl="xformers", + sliding_window=512, + ) + entropy_model = LMTransformer(entropy_model_args) + + entropy_model.load_state_dict( + torch.load(state_dict_path, map_location=device)["model"], strict=False + ) + entropy_model.to(device) + entropy_model = entropy_model.eval() + # no grads for the model: + for param in entropy_model.parameters(): + param.requires_grad = False + return entropy_model, entropy_model_args + + +def get_encoder_dim_token_emb(args): + if args.dim_token is not None: + dim_token_emb = args.dim_token + elif args.use_local_encoder_transformer: + dim_token_emb = args.dim_local_encoder + else: + dim_token_emb = args.dim_global // args.patch_size + return dim_token_emb + + +def get_encoder_dim_patch_emb(args): + dim_patch_emb = None + if args.cross_attn_encoder: + if args.cross_attn_init_by_pooling: + dim_patch_emb = args.dim_local_encoder + else: + dim_patch_emb = args.dim_global + return dim_patch_emb + + +def get_global_dim_patch_emb(args): + dim_token_emb = get_encoder_dim_token_emb(args) + if args.cross_attn_encoder: + dim_patch_emb = dim_token_emb * args.cross_attn_k + elif ( + args.downsampling_by_pooling is None + or not args.downsampling_by_pooling + or len(args.downsampling_by_pooling) == 0 + ): + dim_patch_emb = dim_token_emb * 
args.patch_size + else: + dim_patch_emb = dim_token_emb * sum( + [ + pooling in args.downsampling_by_pooling + for pooling in ["avg", "min", "max"] + ] + ) + return dim_patch_emb + + +def get_decoder_dim_token_emb(args): + if args.share_encoder_decoder_emb: + dim_token_emb = get_encoder_dim_token_emb(args) + elif args.dim_token is not None: + dim_token_emb = args.dim_token + else: + dim_token_emb = args.dim_local_decoder + return dim_token_emb + + +def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int]: + if ngram_to_size_str is None: + return None + ngram_to_size = {} + for entry in ngram_to_size_str.split(","): + ngram, size = entry.split(":") + ngram = int(ngram) + size = int(size) + ngram_to_size[ngram] = size + return ngram_to_size + + +def fill_tokens(tokens, patch_size, fill_id): + batch_size, seq_len = tokens.shape + if seq_len % patch_size == 0: + return tokens + else: + remaining = patch_size - seq_len % patch_size + final_padding = tokens.new(batch_size, remaining).fill_(fill_id) + return torch.cat((tokens, final_padding), dim=1) + + +def decoder_patch_ids_from_lengths(patch_lengths, nb_boe, seq_len): + first_patch_length = patch_lengths[0, 0] + assert torch.all( + first_patch_length == patch_lengths[:, 0] + ), "first patch should always be the same size (1 for dynamic, patch_size for static)." + assert ( + first_patch_length - nb_boe == 1 + ), f"First patch (patch length: {first_patch_length}) should have one non-boe token (boe toks: {nb_boe})" + # Remove first patch from patch_ids for local decoder inputs and shift the last patch. + # decoder_patch_lengths = patch_lengths[:, 1:].clone() + # decoder_patch_lengths = add_to_last_nonzero_patch(decoder_patch_lengths, 1) + decoder_patch_lengths = patch_lengths[:, 1:] + assert ( + decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0] + == patch_lengths.sum() + ), f"{decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0]} != {patch_lengths.sum()}" + assert torch.all(decoder_patch_lengths >= 0), f"{decoder_patch_lengths}" + decoder_patch_ids = patch_ids_from_lengths( + patch_lengths=decoder_patch_lengths, seq_len=seq_len + ) + return decoder_patch_ids + + +primes = [ + 1000000007, + 5915587277, + 1500450271, + 3267000013, + 5754853343, + 4093082899, + 9576890767, + 3628273133, + 2860486313, + 5463458053, + 3367900313, +] + + +def rolling_polynomial_hash(t, hash_func_nb: int = 0): + prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device) + prime_powers = torch.stack([prime**i for i in range(t.shape[-1])]) + return torch.sum(t * prime_powers, dim=-1) + +def byte_group_hash_function( + x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000 +): + """ + Returns a hash of the input x and maps it to a value in the range [0, max_hash]. + + expects: x of shape (batch_size, seq_len) with values as ids in the token vocab. + returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash]. + + Note: max hash can make a big difference on the number of collisions. 
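+
+    Illustrative walk-through of the implementation below: with group_size=2 the
+    input is left-padded with one zero id and unfolded into overlapping windows
+    (0, x[0]), (x[0], x[1]), ...; each window is hashed with
+    rolling_polynomial_hash and reduced modulo max_hash, so the value at position
+    j depends only on the group_size byte ids ending at j.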
+ """ + with torch.no_grad(): + bs, seq_len = x.shape + # x_numpy = x.numpy() + # hash_values = torch.zeros(bs, seq_len, dtype=torch.int64, requires_grad=False) + # for i in range(bs): + # for j in range(seq_len): + # start = max(j, j-group_size+1) + # end = j+1 + # hash_values[i, j] = hash_array(x_numpy[i, start:end], max_hash) + + prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device) + x = torch.cat([prefix, x], dim=1) + windows = x.unfold(1, group_size, 1) + # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows) + hashes = rolling_polynomial_hash(windows, hash_func_nb) + hash_values_range = hashes % max_hash + hash_values_range.requires_grad = False + return hash_values_range + + +def create_patch_mask_from_ids( + patch_ids, num_patches, window=None, patches_as_queries=False +): + """ + Creates a tensor of shape [bs, seq_len, num_patches] where each element at position (i, j, k) + is True if the patch id at position (i, j) is less than or equal to k. + Args: + patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids. + num_patches (int): Total number of patches. + window (int): If not None, only considers patches within a window of size window. + patches_as_queries (bool): If True, the patches are used as queries + Returns: + torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask. + """ + bs, seq_len = patch_ids.shape + if not patches_as_queries: + q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches) + kv_ids = ( + torch.arange(num_patches, device=patch_ids.device) + .unsqueeze(0) + .unsqueeze(0) + .expand(bs, seq_len, num_patches) + ) + else: + kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len) + q_ids = ( + torch.arange(num_patches, device=patch_ids.device) + .unsqueeze(0) + .unsqueeze(-1) + .expand(bs, num_patches, seq_len) + ) + if window is None: + mask = q_ids == kv_ids + else: + mask = (kv_ids <= q_ids) & (q_ids < kv_ids + window) + return mask + + +def cross_attn_mask( + patch_ids, + patch_lengths, + N, + patches_as_queries=False, + cross_attn_k=1, + window=None, + block_mask=True, +): + bs = patch_ids.shape[0] + with torch.no_grad(): + # Create the patch mask + cross_mask = create_patch_mask_from_ids( + patch_ids, + patch_lengths.shape[1], + window=window, + patches_as_queries=patches_as_queries, + ).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1) + q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N + kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k + assert cross_mask.shape == ( + bs, + q_len, + kv_len, + ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}" + if block_mask: + + def patch_mask(b, h, q_idx, kv_idx): + return cross_mask[b, q_idx, kv_idx] + + block_mask = create_block_mask( + patch_mask, + B=bs, + H=None, + Q_LEN=q_len, + KV_LEN=kv_len, + _compile=True, + ) + return block_mask + else: + return torch.where( + cross_mask, torch.tensor(0.0), torch.tensor(float("-inf")) + ).unsqueeze( + 1 + ) # [bs, 1, q_len, kv_len] + + +def get_blt_input( + tokens: torch.Tensor, + enforce_patch_size_multiple: bool, + nb_boe: torch.Tensor, + patch_size: int, + boe_id: int, +): + """ + This function returns X_et, X_gt and X_dt, the encoder, global, and decoder + tokens respectively. 
+ + Consider the input and target sequences: + X=[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13] + Y=[4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13,14] + with patch_size=4 + + Note 1: that there will be no special tokens introduced at the patch level. + Note 2: X_e needs to be trimmed to be passed to Global + + Current without boe: + X_et = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] + X_g = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] # remove last glob patch + X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]] + Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]] + + --> lag fix: + X_et = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11] [12,13,pad,pad]] + X_g = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11]] + X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]] + Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]] + + Dynamic (current): + X = [3,4,5,6,7,eos,bos,8,9,10,eos,bos] + Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11] + + entropy patching: + input: 7, bos, 9, 10 + pred (high entropy): eos, 8, 10, eos + + X_et = [[boe,3,4,5,6,7,eos,bos,8,9,10,eos,bos] + X_g = [[boe], [3,4,5,6], [7,eos],[bos,8],[9], [10,eos]] + X_dt = [[3,4,5,6], [7,eos], [bos,8],[9], [10,eos],[bos]] + Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11] + + --> lag fix no boe (force single byte first patch): + X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] + X_g = [[3], [4,5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch + X_dt = [[3,4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]] + Y = [4,5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13] + + input: 4, 7, bos, 9, 10 + pred (high entropy): 5, eos, 8, 10, eos + + X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] + X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch + X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]] + Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13] + + Handle the last byte properly. + patch_lengths = [1, 1, 3, 2, 2 1 2 2 1] + X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] + X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # do not remove last global patch + X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11] [12]] + Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12, 13]] + + + bpe delim + X_et = [[3,4,5,6,7,,eos,bos,,8,9,,10,,eos,bos,11,12] + X_g = [[3], [4,5,6,7,], [eos,bos,], .. + X_dt = [[3,4,5,6,7], [,eos,bos], [,bos,8], .. + Y = [4,5,6,7,, eos,bos, 8,9,, .. + + + Note 1: that there will be no special tokens introduced at the patch level. + Note 2: X_e needs to be trimmed to be passed to Global + """ + batch_size, seq_len = tokens.shape + local_encoder_tokens = tokens + local_decoder_tokens = tokens + + if nb_boe > 0: + padded_patch = tokens.new(batch_size, nb_boe).fill_(boe_id) + local_encoder_tokens = torch.cat((padded_patch, local_encoder_tokens), dim=1) + # global_tokens = tokens.new(batch_size, ((seq_len-1) // patch_size)+1).fill_(boe_id) + + # create global tokens, contains boe tokens and eos + # padded_local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id) + # patches = padded_local_encoder_tokens.view(batch_size, -1, patch_size) + # global_tokens = (patches.eq(eos_id).any(dim=2).int() * eos_id)[:, 1:] + # global_tokens += global_tokens.eq(0).int() * boe_id + # TODO: fix this when we want to use block causal in the global. 
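+    # As written, the global-token construction above remains commented out: this
+    # function only prepares the encoder and decoder byte streams and returns the
+    # global stream as None (see the return below); patch representations for the
+    # global transformer are instead derived downstream from the encoder hidden
+    # states (e.g. by pooling).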
+ + if enforce_patch_size_multiple and local_encoder_tokens.shape[-1] % patch_size != 0: + local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id) + + return local_encoder_tokens, None, local_decoder_tokens + + +def patch_ids_from_lengths(patch_lengths, seq_len): + bs, num_patches = patch_lengths.shape + # Create a tensor of cumulative sums of the patch lengths + cum_d = torch.cat( + [ + torch.zeros(bs, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), + patch_lengths.cumsum(dim=-1), + ], + dim=-1, + ) + patch_ids = (cum_d.unsqueeze(-1) <= torch.arange(seq_len, device=cum_d.device)).sum( + dim=-2 + ) - 1 + assert not ( + torch.max(patch_ids) > patch_lengths.shape[-1] or torch.min(patch_ids) < 0 + ), f"{torch.max(patch_ids)} > {patch_lengths.shape[-1]} or {torch.min(patch_ids)} < 0" + return patch_ids + + +class ByteLatentTransformerArgs(BaseTransformerArgs): + # Basic model configuration + seed: int = 42 + vocab_size: int = -1 + dim: int = 512 + n_layers: int = 8 + n_heads: int = 8 + # TODO: What is the purpose of this parameter? + weight_tying: bool = False + patch_in_forward: bool = False + + # Architecture and dimensions + dim_token: int | None = None + dim_global: int = 512 + dim_local_decoder: int = 512 + dim_local_encoder: int = 512 + n_layers_global: int = 8 + n_layers_local_decoder: int = 8 + n_layers_local_encoder: int = 8 + + # Tokenization and patching + patch_size: float | None = None + patching_mode: str | None = None + patching_threshold: float | None = None + patching_threshold_add: float | None = None + monotonicity: bool = False + patching_batch_size: int = 1 + patching_device: str = "cuda" + max_patch_length: int | None = None + + # Encoder/Decoder configuration + tie_local_encoder_decoder_logits: bool = False + use_local_encoder_transformer: bool = False + encoder_lm_loss: bool = False + max_encoder_seq_length: int | None = None + pad_to_max_length: bool = False + encoder_enable_byte_ngrams: bool = False + encoder_enable_byte_group_hash: bool = False + ngram_vocab_sizes: int | None = None + + # Cross attention configurations + cross_attn_encoder: bool = False + cross_attn_decoder: bool = False + cross_attn_window_encoder: int | None = None + cross_attn_window_decoder: int | None = None + cross_attn_k: int | None = None + cross_attn_nheads: int | None = None + cross_attn_all_layers_decoder: bool = False + cross_attn_all_layers_encoder: bool = False + cross_attn_use_flex_attention: bool = True + cross_attn_init_by_pooling: bool = False + + # Encoder hash configurations + encoder_hash_byte_group_size: Any | None = None + encoder_hash_byte_group_vocab: int = 30000 + encoder_hash_byte_group_nb_functions: int = 3 + + # Model behavior and optimization + log_patch_lengths: bool = False + non_linearity: str = "swiglu" + use_rope: bool = True + recompute_fc1_out: bool = False + recompute_fc3_out: bool = False + recompute_attn: bool = True + custom_bwd: bool = False + layer_ckpt: str = "all" + + # Initialization and attention + init_use_gaussian: bool = True + init_use_depth: str = "current" + attn_bias_type: str = "causal" + alpha_depth: str = "disabled" + max_length: int = 2048 + + # Norm configuration + norm_eps: float = 1e-5 + norm_affine: bool = True + pre_norm: bool = True + norm_type: str = "rmsnorm" + + # Additional configurations + multiple_of: int = 256 + ffn_dim_multiplier: float = 1.0 + dropout: float = 0 + output_size: int = -1 + + # Additional parameters from ModelArgs + architecture: str = "vanilla" + share_encoder_decoder_emb: bool = 
True + global_local_decoder_residual_layer: str | None = None + + tokenize_with_bpe_delimiter: bool = False + patching_thresholds_str: str | None = None + tie_local_encoder_decoder: bool = False + encoder_preds_low_entropy_toks: float | None = None + encoder_preds_random_toks: float | None = None + dim_token_emb: int | None = None + dim_patch_emb: int | None = None + + encoder_ngram_table_dir: str | None = None + encoder_ngram_to_size_str: str | None = None + + # Model architecture params + entropy_model_checkpoint_dir: str | None = None + entropy_model_is_ngram_model: bool = False + downsampling_by_pooling: str | None = None + n_heads_global: int = 8 + n_heads_local_decoder: int = 8 + n_heads_local_encoder: int = 8 + n_kv_heads: int | None = None + n_kv_heads_global: int | None = None + conv_kernel_size: int | None = None + local_attention_window_len: int | None = None + + # Performance optimization + sequence_parallel: bool = False + loss_parallel: bool = False + fuse_sequence_parallel: bool = False + use_fsdp: bool = True + attn_to_keep: str = "all" + + # Parameter mixing + pm_size: int = 0 + + # Logging + full_logging_n_layers: int = 4 + + @model_validator(mode="after") + def check_hash_byte_sizes(self) -> Self: + if ( + self.encoder_hash_byte_group_size is not None + and type(self.encoder_hash_byte_group_size) == str + ): + self.encoder_hash_byte_group_size = [ + int(x) + for x in self.encoder_hash_byte_group_size.split(",") + if len(x) > 0 + ] + return self + + +class GlobalTransformerArgs(ByteLatentTransformerArgs): + # Global encoder specific dimensions + dim_token_emb: int | None = None + dim_patch_emb: int | None = None + + def __post_init__(self): + # Override base args with global encoder specific values + self.dim = self.dim_global + self.n_layers = self.n_layers_global + self.n_heads = self.n_heads_global + self.n_kv_heads = self.n_kv_heads_global + self.local_attention_window_len = None + self.cross_attn_encoder = False + self.cross_attn_decoder = False + + +class LocalDecoderArgs(ByteLatentTransformerArgs): + # Local decoder specific dimensions + dim_token_emb: int | None = None + dim_patch_emb: int | None = None + + def __post_init__(self): + # Override base args with local decoder specific values + self.dim = self.dim_local_decoder + self.n_layers = self.n_layers_local_decoder + self.n_heads = self.n_heads_local_decoder + self.cross_attn_encoder = False + self.cross_attn_init_by_pooling = False + self.attn_bias_type = "local_block_causal" + + +def create_global_transformer(args: ByteLatentTransformerArgs): + global_args = args.model_copy( + deep=True, + update=dict( + dim=args.dim_global, + n_layers=args.n_layers_global, + n_heads=args.n_heads_global, + n_kv_heads=args.n_kv_heads_global, + local_attention_window_len=None, + dim_token_emb=get_global_dim_patch_emb(args), + dim_patch_emb=None, + cross_attn_encoder=False, + cross_attn_decoder=False, + ), + ) + + return GlobalTransformer(global_args) + + +class LocalModelArgs(BaseTransformerArgs): + model_config = ConfigDict(extra="forbid") + # Override defaults + attn_impl: str | None = "xformers" + attn_bias_type: str | None = "local_block_causal" + + # Local encoder specific dimensions + dropout: float + vocab_size: int + patch_size: float + sliding_window: int | None + use_rope: bool + cross_attn_encoder: bool | None + cross_attn_decoder: bool | None + cross_attn_k: int | None + cross_attn_init_by_pooling: bool + patching_mode: str + use_local_encoder_transformer: bool + downsampling_by_pooling: str | None + 
encoder_hash_byte_group_size: Any | None = None + cross_attn_all_layers_encoder: bool = False + cross_attn_all_layers_decoder: bool = False + cross_attn_nheads: int | None + + dim_token_emb: int + dim_patch_emb: int | None + + +class LocalModelBase(nn.Module): + def __init__(self, args: LocalModelArgs): + super().__init__() + + self.dim = args.dim + self.dropout = args.dropout + self.vocab_size = args.vocab_size + self.patch_size = args.patch_size + self.dim_patch_emb = args.dim_patch_emb + + self.attn_impl = args.attn_impl + self.sliding_window = args.sliding_window + self.use_rope = args.use_rope + self.init_std_factor = args.init_std_factor + self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None) + self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None) + self.cross_attn_k = getattr(args, "cross_attn_k", None) + self.eos_id = args.eos_id + + self.boe_id = BOE_ID + + self.layers = nn.ModuleList( + [TransformerBlock(args) for _ in range(args.n_layers)] + ) + + if not self.use_rope: + self.pos_embeddings = nn.Embedding(args.max_length, args.dim) + else: + self.rope = RotaryEmbedding( + theta=args.rope_theta, + head_dim=args.head_dim or args.dim // args.n_heads, + max_seqlen=args.max_seqlen, + rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, + ) + self.pos_embeddings = None + + self.token_embedding_projection = ( + nn.Linear(args.dim_token_emb, args.dim, bias=False) + if hasattr(args, "dim_token_emb") and args.dim_token_emb != self.dim + else None + ) + + self.patch_embedding_projection = self._create_patch_projection(args) + + def _should_create_patch_projection(self, args: LocalModelArgs): + dimension_mismatch = ( + getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim + ) + + # Check cross attention conditions + cross_attn_conditions = ( + args.cross_attn_encoder and args.cross_attn_init_by_pooling + ) or (args.cross_attn_decoder and args.cross_attn_init_by_pooling) + + return dimension_mismatch or cross_attn_conditions + + def _create_patch_projection(self, args): + if not self._should_create_patch_projection(args): + return None + + output_dim = args.dim_token_emb * (self.cross_attn_k or 1) + + return nn.Linear( + in_features=args.dim_patch_emb, + out_features=output_dim, + bias=False, + ) + + def apply_embedding(self, tokens, embeds): + if embeds is not None: + return embeds + else: + return self.tok_embeddings(tokens) + + def init_weights(self, init_std=None): + self.rope.reset_parameters() + if hasattr(self, "norm"): + self.norm.reset_parameters() + + init_std = init_std or (self.dim ** (-0.5)) + if hasattr(self, "tok_embeddings"): + nn.init.trunc_normal_( + self.tok_embeddings.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + if self.pos_embeddings is not None: + nn.init.trunc_normal_( + self.pos_embeddings.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + for depth, layer in enumerate(self.layers): + factor = { + InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, + InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, + InitStdFactor.DIM_RATIO: self.dim / 4096, + InitStdFactor.DISABLED: 1.0, + }[self.init_std_factor] + + layer.init_weights(None, factor) + + if hasattr(self, "output"): + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + if self.token_embedding_projection is not None: + nn.init.trunc_normal_( + self.token_embedding_projection.weight, + mean=0.0, + std=init_std, + a=-3 * 
init_std, + b=3 * init_std, + ) + + if self.patch_embedding_projection is not None: + patch_emb_std = self.dim_patch_emb ** (-0.5) + nn.init.trunc_normal_( + self.patch_embedding_projection.weight, + mean=0.0, + std=patch_emb_std, + a=-3 * patch_emb_std, + b=3 * patch_emb_std, + ) + + if self.cross_attn_layers is not None: + for depth, layer in enumerate(self.cross_attn_layers): + factor = { + InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, + InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, + InitStdFactor.DIM_RATIO: self.dim / 4096, + InitStdFactor.DISABLED: 1.0, + }[self.init_std_factor] + + layer.init_weights(None, factor) + + +class LocalEncoder(LocalModelBase): + def __init__(self, args: LocalModelArgs): + super().__init__(args) + + self.apply_transformer = args.use_local_encoder_transformer + self.downsampling_by_pooling = args.downsampling_by_pooling + self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None + self.cross_attn_encoder = args.cross_attn_encoder + self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder + self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling + self.cross_attn_nheads = args.cross_attn_nheads + + self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim) + + if self.cross_attn_encoder: + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = args.n_layers if self.cross_attn_all_layers_encoder else 1 + for _ in range(layers_to_add): + self.cross_attn_layers.append( + CrossAttention( + dim=self.dim, + head_dim=self.dim // self.cross_attn_nheads, + n_heads=self.cross_attn_nheads, + n_kv_heads=self.cross_attn_nheads, + norm_eps=args.norm_eps, + ) + ) + + def apply_embedding(self, tokens, embeds): + if embeds is not None: + assert ( + self.expects_hash_embeddings + ), "Not expecting embeddings to be passed." 
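+            # `embeds`, when provided, are byte embeddings precomputed upstream
+            # (e.g. token embeddings augmented with hash n-gram embeddings), so the
+            # encoder uses them directly instead of re-embedding the raw bytes.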
+ return embeds + else: + return self.tok_embeddings(tokens) + + def forward( + self, + tokens: torch.Tensor, + embeds: Optional[torch.Tensor] = None, + patch_embeds: Optional[torch.Tensor] = None, + mask: Optional[Union["BlockMask", "AttentionBias", torch.Tensor, str]] = None, + cross_mask: Optional[torch.Tensor] = None, + num_patches: Optional[int] = None, + patch_ids: Optional[torch.Tensor] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + """ """ + bs, seqlen = tokens.shape + if mask is None: + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) + + h = self.apply_embedding(tokens, embeds) + freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None + + h = F.dropout(h, p=self.dropout, training=self.training) + + for i, layer in enumerate(self.layers): + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) + # check if cross attention should be applied to either all layer or only the last layer + if self.cross_attn_encoder and ( + i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder + ): + patch_embeds = self.apply_cross_attention( + h, patch_embeds, i, bs, num_patches, patch_ids, cross_mask + ) + + h_residual = patch_embeds if self.cross_attn_encoder else None + return (h, h_residual), cache + + def apply_cross_attention( + self, h, patch_embeds, layer_idx, bs, num_patches, patch_ids, cross_mask + ): + # apply pooling and project + if self.cross_attn_init_by_pooling and patch_embeds is None: + patch_embeds = downsample( + h, + num_patches, + patch_ids=patch_ids, + downsampling_by_pooling=self.downsampling_by_pooling, + patch_size=self.patch_size, + ) + if self.patch_embedding_projection is not None: + patch_embeds = self.patch_embedding_projection(patch_embeds) + patch_embeds = patch_embeds.reshape( + bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim + ) + + layer_idx = layer_idx if self.cross_attn_all_layers_encoder else 0 + patch_embeds_cross = self.cross_attn_layers[layer_idx]( + x=patch_embeds, + kv=h, + mask=cross_mask, + ) + return patch_embeds + patch_embeds_cross + + +class LocalDecoder(LocalModelBase): + def __init__(self, args: LocalModelArgs): + super().__init__(args) + + # Model configuration flags + self.cross_attn_decoder = args.cross_attn_decoder + self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder + self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling + self.cross_attn_nheads = args.cross_attn_nheads + + self.norm = RMSNorm(args.dim, eps=args.norm_eps) + + if self.cross_attn_decoder: + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = args.n_layers if self.cross_attn_all_layers_decoder else 1 + for _ in range(layers_to_add): + self.cross_attn_layers.append( + CrossAttention( + dim=self.dim, + head_dim=self.dim // self.cross_attn_nheads, + n_heads=self.cross_attn_nheads, + n_kv_heads=self.cross_attn_nheads, + norm_eps=args.norm_eps, + ) + ) + + self.output = nn.Linear( + self.dim, + args.vocab_size, + bias=False, + ) + + def forward( + self, + tokens: torch.Tensor, + embeds: Optional[torch.Tensor], + patch_embeds: Optional[torch.Tensor] = None, + mask: Optional[Union["BlockMask", "AttentionBias", torch.Tensor, str]] = None, + cross_mask: Optional[torch.Tensor] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + bs, seqlen = tokens.shape + assert embeds is not None, "Embeddings must be provided" + + if mask 
is None: + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) + + h = embeds + + if self.patch_embedding_projection is not None: + assert patch_embeds is not None, "Patch embeddings must be passed." + patch_embeds = self.patch_embedding_projection(patch_embeds) + if self.cross_attn_k is not None: + patch_embeds = patch_embeds.reshape( + bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim + ) + + if patch_embeds is not None and not self.cross_attn_decoder: + h = h + patch_embeds + + freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None + + h = F.dropout(h, p=self.dropout, training=self.training) + for i, layer in enumerate(self.layers): + if self.cross_attn_decoder and ( + i == 0 or self.cross_attn_all_layers_decoder + ): + # Use cross attention to extract info from patch_embeds into h + h_cross = self.cross_attn_layers[i]( + x=h, + kv=patch_embeds, + mask=cross_mask, + ) + h = h + h_cross + + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) + + h_preds = self.norm(h) + h_preds = F.dropout(h_preds, p=self.dropout, training=self.training) + h_preds = self.output(h_preds) + h_preds = h_preds.float() + return h_preds, cache + + +class CrossAttention(nn.Module): + """ + CrossAttention block to attend to the encoder states from the decoder. + Rope is not supported. + """ + + def __init__( + self, + dim: int, + head_dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + ): + super().__init__() + + self.dim = dim + self.head_dim = head_dim + + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.heads_per_group = self.n_heads // self.n_kv_heads + + self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps) + self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps) + + self.wq = nn.Linear( + dim, + n_heads * head_dim, + bias=False, + ) + self.wk = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + self.wv = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + + self.wo = nn.Linear( + n_heads * head_dim, + dim, + bias=False, + ) + + def forward( + self, + x: torch.Tensor, + kv: torch.Tensor, + mask: Optional[Union[BlockMask, AttentionBias, str]] = None, + ) -> torch.Tensor: + # B S D + bsz, seq_len, _ = x.shape + _, slen_kv, _ = kv.shape + x_norm = self.cross_attn_norm_q(x) + kv = self.cross_attn_norm_kv(kv) + + xq = self.wq(x_norm) + xk = self.wk(kv) + xv = self.wv(kv) + + output_shape = xq.shape + # B S D -> B S H D + xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) + xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) + xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) + + xk = repeat_kv(xk, self.heads_per_group, dim=2) + xv = repeat_kv(xv, self.heads_per_group, dim=2) + + assert mask is None or isinstance(mask, BlockMask) + xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) + output = flex_attention_comp(xq, xk, xv, block_mask=mask) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + + output = self.wo(output.reshape(output_shape)) + + return x + output + + def init_weights(self, base_std: float, factor: float = 1.0): + std = base_std or (self.dim ** (-0.5)) / factor + + nn.init.trunc_normal_( + self.wq.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + nn.init.trunc_normal_( + self.wk.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + nn.init.trunc_normal_( + self.wv.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + 
nn.init.trunc_normal_( + self.wo.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + self.cross_attn_norm_q.reset_parameters() + self.cross_attn_norm_kv.reset_parameters() + + +class GlobalTransformer(BaseTransformer): + def __init__(self, args: BaseTransformerArgs): + super().__init__(args) + self.dropout = args.dropout + self.eos_id = args.eos_id + self.dim_token_emb = args.dim_token_emb + + self.token_embedding_projection = None + if args.dim_token_emb is not None and args.dim_token_emb != self.dim: + self.token_embedding_projection = nn.Linear( + args.dim_token_emb, + args.dim, + bias=False, + ) + + def forward( + self, + tokens: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + embeds: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + """ + Similar to BaseTransformer.forward, but with an additional embeds argument + and projection to the token space. + """ + bs, seqlen = tokens.shape + + h = embeds + + mask = ( + mask + if mask is not None + else create_causal_mask( + seqlen, + self.attn_impl, + self.attn_bias_type, + tokens=tokens, + eos_id=self.eos_id, + ) + ) + + if self.token_embedding_projection is not None and h.shape[-1] != self.dim: + h = self.token_embedding_projection(h) + + h = F.dropout(h, p=self.dropout, training=self.training) + + h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl) + return h, cache + + def init_weights(self): + super().init_weights() + std = self.dim_token_emb ** (-0.5) + if self.token_embedding_projection is not None: + nn.init.trunc_normal_( + self.token_embedding_projection.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + + +def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder: + local_encoder_args = LocalModelArgs( + # Updated args + dim=args.dim_local_encoder, + n_layers=args.n_layers_local_encoder, + n_heads=args.n_heads_local_encoder, + dim_token_emb=get_encoder_dim_token_emb(args), + dim_patch_emb=get_encoder_dim_patch_emb(args), + cross_attn_encoder=args.cross_attn_encoder, + cross_attn_decoder=False, + cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None, + cross_attn_init_by_pooling=args.cross_attn_init_by_pooling, + # Defaults + head_dim=args.head_dim, + max_seqlen=args.max_encoder_seq_length, + dropout=args.dropout, + vocab_size=args.vocab_size + args.pm_size, + norm_eps=args.norm_eps, + patch_size=args.patch_size, + sliding_window=args.local_attention_window_len, + use_rope=args.use_rope, + rope_theta=args.rope_theta, + rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, + init_base_std=args.init_base_std, + init_std_factor=args.init_std_factor, + n_kv_heads=args.n_kv_heads, + attn_impl=args.attn_impl, + attn_bias_type="local_block_causal", + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + patching_mode=args.patching_mode, + use_local_encoder_transformer=args.use_local_encoder_transformer, + downsampling_by_pooling=args.downsampling_by_pooling, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, + cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, + cross_attn_nheads=args.cross_attn_nheads, + eos_id=args.eos_id, + ) + + return LocalEncoder(local_encoder_args) + + +def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder: + # First deep copy the 
original args + local_decoder_args = LocalModelArgs( + dim=args.dim_local_decoder, + n_layers=args.n_layers_local_decoder, + n_heads=args.n_heads_local_decoder, + dim_token_emb=get_decoder_dim_token_emb(args), + dim_patch_emb=args.dim_global, + cross_attn_encoder=False, + cross_attn_decoder=args.cross_attn_decoder, + cross_attn_init_by_pooling=False, # states are already defined + cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, + # Defaults + head_dim=args.head_dim, + max_seqlen=args.max_encoder_seq_length, + dropout=args.dropout, + vocab_size=args.vocab_size + args.pm_size, + norm_eps=args.norm_eps, + patch_size=args.patch_size, + sliding_window=args.local_attention_window_len, + use_rope=args.use_rope, + rope_theta=args.rope_theta, + rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, + init_base_std=args.init_base_std, + init_std_factor=args.init_std_factor, + n_kv_heads=args.n_kv_heads, + attn_impl=args.attn_impl, + attn_bias_type="local_block_causal", + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + patching_mode=args.patching_mode, + use_local_encoder_transformer=args.use_local_encoder_transformer, + downsampling_by_pooling=args.downsampling_by_pooling, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, + cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, + cross_attn_nheads=args.cross_attn_nheads, + eos_id=args.eos_id, + ) + + return LocalDecoder(local_decoder_args) + + +class EmbeddingType(Enum): + HASH_TOK = auto() + NGRAM = auto() + + +def init_embeddings( + args, + embedding_type: EmbeddingType, + local_encoder_dim: int, + encoder_hash_byte_group_size: list = None, +): + if ( + embedding_type == EmbeddingType.HASH_TOK + and args.encoder_hash_byte_group_size is None + ): + return None + if embedding_type == EmbeddingType.NGRAM and args.encoder_ngram_to_size_str is None: + return None + + embeddings = [] + + if embedding_type == EmbeddingType.HASH_TOK: + emb_dim = local_encoder_dim + encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab + for _ in range(args.encoder_hash_byte_group_nb_functions): + for _ in encoder_hash_byte_group_size: + embeddings.append( + nn.Embedding( + encoder_hash_byte_group_vocab, + emb_dim, + ) + ) + + elif embedding_type == EmbeddingType.NGRAM: + encoder_ngram_to_size = parse_ngram_to_size(args.encoder_ngram_to_size_str) + emb_dim = local_encoder_dim + OFFSET = 4 # This should be passed as parameter if it's variable + for ngram_vocab_size in encoder_ngram_to_size.values(): + embeddings.append(nn.Embedding(ngram_vocab_size + OFFSET, emb_dim)) + + return nn.ModuleList(embeddings) + + +def compute_hash_embeddings( + local_encoder_tokens: torch.Tensor, + local_encoder, + encoder_hash_tok_embedding: nn.ModuleList, + encoder_hash_byte_group_nb_functions: int, + encoder_hash_byte_group_size: list, + encoder_hash_byte_group_vocab: int, +) -> torch.Tensor: + """ + Compute embeddings using hash token embeddings. 
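+    The byte-level token embeddings from `local_encoder.tok_embeddings` are summed
+    with one hash-embedding lookup per (hash function, byte-group size) pair, so
+    `encoder_hash_tok_embedding` must be ordered exactly as it was built in
+    `init_embeddings` (hash functions in the outer loop, group sizes in the inner loop).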
+ + Args: + local_encoder_tokens: Input tokens tensor + local_encoder: Encoder object with tok_embeddings method + encoder_hash_tok_embedding: ModuleList of hash token embeddings + encoder_hash_byte_group_nb_functions: Number of hash functions + encoder_hash_byte_group_size: List of byte group sizes + encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings + + Returns: + torch.Tensor: Combined embeddings + """ + if encoder_hash_tok_embedding is None: + return None + + local_encoder_embeds = local_encoder.tok_embeddings(local_encoder_tokens) + + i = 0 + for func_nb in range(encoder_hash_byte_group_nb_functions): + for byte_group_size in encoder_hash_byte_group_size: + hash_ids = byte_group_hash_function( + local_encoder_tokens, + byte_group_size, + hash_func_nb=func_nb, + max_hash=encoder_hash_byte_group_vocab, + ) + hash_tok_embedding = encoder_hash_tok_embedding[i] + local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids) + i += 1 + + assert i == len(encoder_hash_tok_embedding) + return local_encoder_embeds + + +class ByteLatentTransformer( + nn.Module, + SequenceModelWithOutput, + PyTorchModelHubMixin, + repo_url="https://github.com/facebookresearch/blt", + # paper_url="https://arxiv.org/abs/2412.09871", + pipeline_tag="text-generation", + license="other", + license_name="fair-noncommercial-research-license", + license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE", + coders={ + ByteLatentTransformerArgs: ( + lambda x: {"args": x.model_dump()}, + lambda data: ByteLatentTransformerArgs(**data), + ) + }, +): + """ + The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences + by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers, + and local decoders to efficiently encode and decode byte sequences, leveraging patch-based processing for + improved performance and inference efficiency. 
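+
+    At a high level, `forward` (1) embeds the byte tokens, optionally adding hash
+    n-gram embeddings, (2) runs the local encoder and pools (or cross-attends) the
+    byte states into one representation per patch, (3) runs the global transformer
+    over these patch representations, and (4) maps the resulting patch states back
+    to the byte level through the local decoder to produce per-byte logits.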
+    """
+
+    def __init__(self, args: ByteLatentTransformerArgs):
+        super().__init__()
+
+        # General configuration
+        self.weight_tying = args.weight_tying
+        self.patch_size = args.patch_size
+        self.patching_mode = args.patching_mode
+        self.boe_id, self.bos_id, self.pad_id, self.eos_id = (
+            BOE_ID,
+            BOS_ID,
+            PAD_ID,
+            EOS_ID,
+        )
+        self.downsampling_by_pooling = args.downsampling_by_pooling
+        self.patching_threshold = args.patching_threshold
+        self.dim = args.dim
+        self.init_base_std = args.init_base_std
+        self.init_std_factor = InitStdFactor(args.init_std_factor)
+        self.max_seqlen = args.max_seqlen
+
+        # Cross attention configuration
+        self.cross_attn_encoder = args.cross_attn_encoder
+        self.cross_attn_decoder = args.cross_attn_decoder
+        self.cross_attn_k = args.cross_attn_k
+        self.cross_attn_window_encoder = args.cross_attn_window_encoder
+        self.cross_attn_window_decoder = args.cross_attn_window_decoder
+        self.cross_attn_use_flex_attention = args.cross_attn_use_flex_attention
+
+        # Encoder hash configuration
+        self.encoder_hash_byte_group_size = args.encoder_hash_byte_group_size
+        self.encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab
+        self.encoder_hash_byte_group_nb_functions = (
+            args.encoder_hash_byte_group_nb_functions
+        )
+
+        # ByteLatent modules
+        self.local_encoder = create_local_encoder(args)
+        self.global_transformer = create_global_transformer(args)
+        self.local_decoder = create_local_decoder(args)
+        self.encoder_hash_tok_embedding = init_embeddings(
+            args,
+            EmbeddingType.HASH_TOK,
+            local_encoder_dim=self.local_encoder.dim,
+            encoder_hash_byte_group_size=self.encoder_hash_byte_group_size,
+        )
+
+        # Encoder ngram embedding tables (built only when byte n-grams are enabled)
+        self.encoder_ngram_embedding = None
+        if args.encoder_enable_byte_ngrams:
+            self.encoder_ngram_embedding = nn.ModuleList()
+            assert args.ngram_vocab_sizes is not None
+            self.encoder_ngram_to_size = parse_ngram_to_size(
+                args.encoder_ngram_to_size_str
+            )
+            ngram_emb_dim = self.local_encoder.dim
+            for ngram_vocab_size in self.encoder_ngram_to_size.values():
+                self.encoder_ngram_embedding.append(
+                    nn.Embedding(ngram_vocab_size + OFFSET, ngram_emb_dim)
+                )
+
+        # Output vocabulary sanity check (the output projection lives in LocalDecoder)
+        assert args.vocab_size > 0, "vocab_size must be greater than 0"
+
+        # Patcher module
+        if args.patch_in_forward:
+            self.patcher = Patcher(
+                PatcherArgs(
+                    patch_size=args.patch_size,
+                    patching_mode=args.patching_mode,
+                    patching_threshold=args.patching_threshold,
+                    patching_threshold_add=args.patching_threshold_add,
+                    monotonicity=args.monotonicity,
+                    max_patch_length=args.max_patch_length,
+                )
+            )
+
+    def push_to_hub(self, *args, **kwargs):
+        raise ValueError(
+            "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct."
+ ) + + def get_output_seq_len(self): + return self.max_seqlen + + def forward( + self, + tokens: torch.Tensor, + patch_lengths: Optional[torch.Tensor] = None, + ngram_ids: Optional[torch.Tensor] = None, + ): + # Ensure ngram_ids is either a tensor or None + assert ( + isinstance(ngram_ids, torch.Tensor) or ngram_ids is None + ), f"ngram_ids must be a tensor or None, but was: {type(ngram_ids)}" + + bs, N = tokens.shape # Batch size and sequence length + + # Get megabyte inputs + nb_boe = int(0 if self.patching_mode != "" else self.patch_size - 1) + local_encoder_tokens, _, local_decoder_tokens = get_blt_input( + tokens=tokens, + enforce_patch_size_multiple=False, + nb_boe=nb_boe, + patch_size=self.patch_size, + boe_id=self.boe_id, + ) + + # Patching + if patch_lengths is None: + assert ( + getattr(self, "patcher", None) is not None + ), "Patcher not defined and no patch_lengths passed." + patch_lengths, tok_scores = self.patcher.patch( + local_encoder_tokens, + include_next_token=True, + threshold=self.patcher.threshold, + ) + else: + if nb_boe > 0: + patch_lengths[:, 0] += nb_boe + + assert torch.min(patch_lengths) >= 0 + + # Generate patch IDs from patch_lengths + patch_ids = patch_ids_from_lengths( + patch_lengths, local_encoder_tokens.shape[-1] + ) + assert torch.max(patch_ids) + 1 <= torch.max( + (patch_lengths != 0).sum(dim=-1) + ), f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}" + + cross_attn_mask_enc = None + # Cross-attention encoder + if self.cross_attn_encoder: + cross_attn_mask_enc = cross_attn_mask( + patch_ids, + patch_lengths, + N, + patches_as_queries=True, + cross_attn_k=self.cross_attn_k, + window=self.cross_attn_window_encoder, + block_mask=self.cross_attn_use_flex_attention, + ) + + # Hashing and embedding + local_encoder_embeds = compute_hash_embeddings( + local_encoder_tokens=local_encoder_tokens, + local_encoder=self.local_encoder, + encoder_hash_tok_embedding=self.encoder_hash_tok_embedding, + encoder_hash_byte_group_nb_functions=self.encoder_hash_byte_group_nb_functions, + encoder_hash_byte_group_size=self.encoder_hash_byte_group_size, + encoder_hash_byte_group_vocab=self.encoder_hash_byte_group_vocab, + ) + + # N-gram table embeddings + if self.encoder_ngram_embedding is not None: + assert ngram_ids is not None, "ngram_ids must be provided" + if local_encoder_embeds is None: + local_encoder_embeds = self.local_encoder.tok_embeddings( + local_encoder_tokens + ) + assert len(ngram_ids) == len( + self.encoder_ngram_embedding + ), f"ngram_ids.shape[0]={ngram_ids.shape[0]} versus len(encoder_ngram_embedding)={len(self.encoder_ngram_embedding)}, ngram_ids.shape={ngram_ids.shape}" + for i in range(ngram_ids.shape[0]): + ngram_embedding = self.encoder_ngram_embedding[i] + ngram_embeds = ngram_embedding(ngram_ids[i]) + assert ( + local_encoder_embeds.shape == ngram_embeds.shape + ), f"Shape mismatch: {local_encoder_embeds.shape} vs {ngram_embeds.shape}, ngram_ids.shape={ngram_ids.shape}" + local_encoder_embeds = local_encoder_embeds + ngram_embeds + + # Local encoder + (h_encoder, h_cross), cache_encoder = self.local_encoder( + tokens=local_encoder_tokens, + embeds=local_encoder_embeds, + patch_embeds=None, + cross_mask=cross_attn_mask_enc, + num_patches=patch_lengths.shape[1], + patch_ids=patch_ids, + ) + + # Downsampling + if not self.cross_attn_encoder: + assert ( + patch_ids.shape[1] == h_encoder.shape[1] + ), f"{patch_ids.shape[1]} != {h_encoder.shape[1]}" + h = downsample( + h_encoder, + patch_lengths.shape[1], + patch_lengths, + 
patch_ids, + downsampling_by_pooling=self.downsampling_by_pooling, + patch_size=self.patch_size, + ) + else: + # Reshape h_cross + h = h_cross.view(bs, patch_lengths.shape[1], -1) + + # Global transformer + global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(self.boe_id) + rows, cols = torch.where(local_encoder_tokens == self.eos_id) + eos_patch_ids = patch_ids[rows, cols] + global_tokens[rows, eos_patch_ids] = self.eos_id + + h, _ = self.global_transformer( + embeds=h, + tokens=global_tokens, + ) + + # Unpatching + dec_embeds = h_encoder[:, nb_boe : nb_boe + N, :] + + # Generate decoder patch IDs + decoder_patch_ids = decoder_patch_ids_from_lengths( + patch_lengths, nb_boe, local_decoder_tokens.shape[-1] + ) + assert ( + torch.max(decoder_patch_ids) + 1 <= h.shape[1] + ), f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}" + assert ( + decoder_patch_ids.shape[1] == dec_embeds.shape[1] + ), f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}" + + # Cross-attention decoder + if not self.cross_attn_decoder: + h = torch.gather( + h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) + ) + cross_attn_mask_dec = None + assert local_decoder_tokens.shape == h.shape[:-1] + else: + cross_attn_mask_dec = cross_attn_mask( + decoder_patch_ids, + patch_lengths, + N, + patches_as_queries=False, + cross_attn_k=self.cross_attn_k, + window=self.cross_attn_window_decoder, + block_mask=self.cross_attn_use_flex_attention, + ) + + # Local decoder + output, _ = self.local_decoder( + embeds=dec_embeds, + patch_embeds=h, + tokens=local_decoder_tokens, + cross_mask=cross_attn_mask_dec, + ) + return output + + def init_weights(self): + self.local_encoder.init_weights() + self.global_transformer.init_weights() + self.local_decoder.init_weights() + + emb_std = self.local_encoder.dim ** (-0.5) + for emb in self.encoder_hash_tok_embedding: + nn.init.trunc_normal_( + emb.weight, + mean=0.0, + std=emb_std, + a=-3 * emb_std, + b=3 * emb_std, + ) diff --git a/demo_hf.py b/demo_hf.py new file mode 100644 index 0000000..40d5c7d --- /dev/null +++ b/demo_hf.py @@ -0,0 +1,211 @@ +#demo_hf.py + +import os + +import torch +import typer + +from blt_one_file import ByteLatentTransformer, ByteLatentTransformerArgs +from bytelatent.tokenizers.blt_tokenizer import BltTokenizer + +from huggingface_hub import hf_hub_download +import json +#generatel_blt_consolidated.py + +import logging +import os + +import torch + +from blt_one_file import Patcher +from bytelatent.distributed import ( + dist_max, + dist_min, +) +from blt_one_file import ByteLatentTransformer +from bytelatent.tokenizers.blt_tokenizer import BltTokenizer + +logger = logging.getLogger() + +def get_generation_range( + prompt_tokens: list[list[int]] | None, max_gen_len: int +) -> tuple[int, int]: + batch_min_prompt_length = min([len(t) for t in prompt_tokens]) + batch_max_prompt_length = max([len(t) for t in prompt_tokens]) + return batch_min_prompt_length, batch_max_prompt_length + max_gen_len + + +def sample_top_k(probs, k): + topk_value, _ = torch.topk(probs, k) # batch_sz x topk + min_value_top_k = topk_value[:, [-1]] + probs[probs < min_value_top_k] = 0.0 + probs.div_(probs.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs, num_samples=1) + return next_token + + +def sample_top_p(probs, p): + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token 
= torch.multinomial(probs_sort, num_samples=1)
+    next_token = torch.gather(probs_idx, -1, next_token)
+    return next_token
+
+
+@torch.inference_mode()
+def generate_nocache(
+    prompts: list[str] | None,
+    *,
+    model: ByteLatentTransformer,
+    tokenizer: BltTokenizer,
+    patcher: Patcher,
+    max_prompt_len: int = 256,
+    max_gen_len: int = 256,
+    use_sampling: bool = False,
+    temp: float = 1.0,
+    top_k: int = 0,
+    top_p: float = 0.0,
+    remove_prompts: bool = True,
+) -> list[list[int]]:
+    assert (
+        patcher.realtime_patching
+    ), "generate_nocache requires patcher.realtime_patching=True"
+    model.eval()
+    prompt_tokens = [tokenizer.encode(t, add_eos=False) for t in prompts]
+    # Truncation
+    prompt_tokens = [
+        t if len(t) < max_prompt_len else t[len(t) - max_prompt_len :]
+        for t in prompt_tokens
+    ]
+    start_pos, end_pos = get_generation_range(prompt_tokens, max_gen_len)
+    batch_size = len(prompt_tokens)
+    tokens = torch.full((batch_size, end_pos), tokenizer.pad_id).cuda().long()
+
+    # Copy inputs to tensor for generated tokens
+    for i, row_tokens in enumerate(prompt_tokens):
+        tokens[i, : len(row_tokens)] = torch.tensor(row_tokens).long()
+    input_text_mask = tokens != tokenizer.pad_id
+
+    for i, curr_pos in enumerate(range(start_pos, end_pos)):
+        current_tokens = tokens[:, :curr_pos]
+        patch_lengths, _ = patcher.patch(current_tokens, include_next_token=True)
+        logits = model(current_tokens, patch_lengths=patch_lengths)[:, -1]
+
+        if use_sampling:
+            probs = torch.softmax(logits / temp, dim=-1)
+            if top_p > 0.0:
+                next_token = sample_top_p(probs, top_p)
+            elif top_k > 0:
+                next_token = sample_top_k(probs, top_k)
+            else:
+                next_token = torch.multinomial(probs, num_samples=1)
+        else:
+            next_token = torch.argmax(logits, dim=-1)
+
+        next_token = torch.where(
+            input_text_mask[:, curr_pos], tokens[:, curr_pos], next_token
+        )
+        tokens[:, curr_pos] = next_token
+
+    if remove_prompts:
+        generated_tokens = [
+            t[len(prompt_tokens[i]) : len(prompt_tokens[i]) + max_gen_len].tolist()
+            for i, t in enumerate(tokens)
+        ]
+    else:
+        generated_tokens = [
+            t[: len(prompt_tokens[i]) + max_gen_len].tolist()
+            for i, t in enumerate(tokens)
+        ]
+    return generated_tokens
+
+
+
+def main(prompt: str = "my name is", model_name: str = "blt-1b"):
+    # distributed_args = DistributedArgs()
+    # distributed_args.configure_world()
+    # if not torch.distributed.is_initialized():
+    #     setup_torch_distributed(distributed_args)
+
+    # Set device and ensure CUDA is available
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is required but not available")
+    device = torch.device("cuda")
+    torch.cuda.empty_cache()  # Clear any existing CUDA memory
+
+    assert model_name in ["blt-1b", "blt-7b"]
+
+    # Hugging Face repo for the requested checkpoint (facebook/blt-1b or facebook/blt-7b)
+    blt_repo = f"facebook/{model_name}"
+
+    # Get the model's default configuration and entropy model params
+    print("Loading model configuration...")
+    config_path = hf_hub_download(repo_id=blt_repo, filename="config.json")
+    entropy_params_path = hf_hub_download(repo_id=blt_repo, filename="entropy_model/params.json")
+
+    with open(config_path, 'r') as f:
+        config = json.load(f)
+    with open(entropy_params_path, 'r') as f:
+        entropy_params = json.load(f)
+
+    # Create model args from config
+    model_args = ByteLatentTransformerArgs(**config["args"])
+
+    # Update patch parameters from entropy model params
+    patcher_args = entropy_params["data"]["patcher_args"]
+    model_args.patch_in_forward = True
+    model_args.patch_size = patcher_args["patch_size"]
+    model_args.patching_mode = 
patcher_args["patching_mode"] + model_args.patching_threshold = patcher_args["threshold"] + model_args.patching_threshold_add = patcher_args["threshold_add"] + model_args.max_patch_length = patcher_args["max_patch_length"] + model_args.patching_batch_size = patcher_args["patching_batch_size"] + model_args.patching_device = patcher_args["patching_device"] + model_args.monotonicity = patcher_args["monotonicity"] + + # Load the model with updated arguments + print("Loading model with updated arguments...") + model = ByteLatentTransformer.from_pretrained(blt_repo, args=model_args).to(device) + + # Configure model's patcher + model.patcher.realtime_patching = True + model.patcher.entropy_model_checkpoint_dir = os.path.join( + "hf-weights", "entropy_model" + ) + + # Create tokenizer + tokenizer = BltTokenizer( + vocab_size_unit_1=model_args.vocab_size, + add_bos=True, + add_eos=True + ) + + # Generate text + print("Generating text...") + prompts = [prompt] + outputs = generate_nocache( + prompts, + model=model, + tokenizer=tokenizer, + patcher=model.patcher, # Use the model's patcher + max_gen_len=100 + ) + + # Decode and print results + text_outputs = [tokenizer.decode(t) for t in outputs] + for p, t in zip(prompts, text_outputs): + print(f'Prompt: "{p}"') + print(f'Completion: "{t}"') + print() + + # Clean up + torch.cuda.empty_cache() + + +if __name__ == "__main__": + typer.run(main) +
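+
+
+# Example invocation (assumes a CUDA GPU, access to the facebook/blt-1b or
+# facebook/blt-7b checkpoints on the Hugging Face Hub, and the entropy-model
+# weights available locally under ./hf-weights/entropy_model as configured above):
+#
+#   python demo_hf.py --prompt "my name is" --model-name blt-1b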