2024-12-12 23:32:30 +00:00
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
|
|
|
|
|
|
import logging
|
2025-01-17 22:23:01 +00:00
|
|
|
from typing import Any, List, Optional, Tuple, Union
|
2024-12-12 23:32:30 +00:00
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn
|
|
|
|
import torch.nn as nn
|
2025-01-17 22:23:01 +00:00
|
|
|
from pydantic import BaseModel, ConfigDict
|
2024-12-12 23:32:30 +00:00
|
|
|
from torch.nn import functional as F
|
|
|
|
from torch.nn.attention.flex_attention import BlockMask
|
|
|
|
from xformers.ops import AttentionBias
|
|
|
|
|
|
|
|
from bytelatent.base_transformer import (
|
2025-01-17 22:23:01 +00:00
|
|
|
BaseTransformerArgs,
|
2024-12-12 23:32:30 +00:00
|
|
|
InitStdFactor,
|
|
|
|
RMSNorm,
|
|
|
|
RotaryEmbedding,
|
|
|
|
TransformerBlock,
|
|
|
|
)
|
2025-01-17 22:23:01 +00:00
|
|
|
from bytelatent.model.latent_transformer import CrossAttention
|
2024-12-12 23:32:30 +00:00
|
|
|
from bytelatent.model.utils import create_causal_mask, downsample
|
|
|
|
from bytelatent.tokenizers.blt_tokenizer import BOE_ID
|
|
|
|
|
|
|
|
# Module-wide logger. Passing name=None explicitly selects the root logger.
logger = logging.getLogger(name=None)
|
|
|
|
|
|
|
|
|
2025-01-17 22:23:01 +00:00
|
|
|
class LocalModelArgs(BaseTransformerArgs):
    """Configuration shared by the byte-level local encoder and decoder.

    Extends ``BaseTransformerArgs`` with the settings the local models read.
    Unknown keys are rejected (``extra="forbid"``). Fields declared below
    without a default are required when the model is constructed, even when
    their type allows ``None``.
    """

    model_config = ConfigDict(extra="forbid")

    # Defaults overridden from BaseTransformerArgs.
    attn_impl: Optional[str] = "xformers"
    attn_bias_type: Optional[str] = "local_block_causal"

    # Local encoder/decoder settings (required).
    dropout: float
    vocab_size: int
    patch_size: int
    sliding_window: Optional[int]
    use_rope: bool
    cross_attn_encoder: Optional[bool]
    cross_attn_decoder: Optional[bool]
    cross_attn_k: Optional[int]
    cross_attn_init_by_pooling: bool
    patching_mode: str
    use_local_encoder_transformer: bool
    downsampling_by_pooling: Optional[str]

    # Settings with defaults.
    encoder_hash_byte_group_size: Optional[Any] = None
    cross_attn_all_layers_encoder: bool = False
    cross_attn_all_layers_decoder: bool = False

    # Head count used for the cross-attention layers (required).
    cross_attn_nheads: Optional[int]

    # Embedding widths for token and patch embeddings (both required).
    dim_token_emb: int
    dim_patch_emb: Optional[int]
|
|
|
|
|
|
|
|
|
2024-12-12 23:32:30 +00:00
|
|
|
class LocalModelBase(nn.Module):
    """Common machinery for the byte-level local encoder and decoder.

    Owns the byte embedding table, a stack of ``TransformerBlock`` layers, a
    final ``RMSNorm``, and either rotary (``use_rope=True``) or learned
    positional embeddings. It also builds optional token/patch embedding
    projections. Subclasses may add ``self.cross_attn_layers`` and an
    ``output`` head; ``init_weights`` initializes those when present.
    """

    def __init__(self, args: LocalModelArgs):
        super().__init__()

        self.dim = args.dim
        self.dropout = args.dropout
        self.vocab_size = args.vocab_size
        self.patch_size = args.patch_size

        self.attn_impl = args.attn_impl
        self.sliding_window = args.sliding_window
        self.use_rope = args.use_rope
        self.init_std_factor = args.init_std_factor
        self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None)
        self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None)
        self.cross_attn_k = getattr(args, "cross_attn_k", None)
        self.eos_id = args.eos_id

        self.boe_id = BOE_ID

        # Submodule creation order is kept as-is: it fixes parameter
        # registration order and the RNG draws made by default initializers.
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.layers = nn.ModuleList(
            [TransformerBlock(args) for _ in range(args.n_layers)]
        )

        self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)
        if not self.use_rope:
            # NOTE(review): `max_length` is not declared on LocalModelArgs in
            # this file (the rope branch uses `max_seqlen`); confirm the base
            # args provide it. The encoder/decoder forwards in this file also
            # never add pos_embeddings to the hidden states.
            self.pos_embeddings = nn.Embedding(args.max_length, args.dim)
        else:
            self.rope = RotaryEmbedding(
                theta=args.rope_theta,
                head_dim=args.head_dim or args.dim // args.n_heads,
                max_seqlen=args.max_seqlen,
            )
            self.pos_embeddings = None

        # Linear map dim_token_emb -> dim, built only when the widths differ.
        self.token_embedding_projection = (
            nn.Linear(args.dim_token_emb, args.dim, bias=False)
            if hasattr(args, "dim_token_emb") and args.dim_token_emb != self.dim
            else None
        )

        self.patch_embedding_projection = self._create_patch_projection(args)

    def _should_create_patch_projection(self, args: LocalModelArgs):
        """Return truthy when a patch-embedding projection layer is needed.

        That is the case when ``dim_patch_emb`` is set and differs from
        ``dim``, or when cross-attention (encoder or decoder side) is
        initialized by pooling.
        """
        dimension_mismatch = args.dim_patch_emb and args.dim_patch_emb != self.dim

        # Pooling-initialized cross attention on either side needs the projection.
        cross_attn_conditions = args.cross_attn_init_by_pooling and (
            args.cross_attn_encoder or args.cross_attn_decoder
        )

        return dimension_mismatch or cross_attn_conditions

    def _create_patch_projection(self, args):
        """Build the ``dim_patch_emb -> dim_token_emb * k`` projection, or None.

        ``k`` is ``cross_attn_k`` (1 when unset), so each patch is projected
        to ``k`` consecutive embeddings of width ``dim_token_emb``.
        """
        if not self._should_create_patch_projection(args):
            return None

        output_dim = args.dim_token_emb * (self.cross_attn_k or 1)

        return nn.Linear(
            in_features=args.dim_patch_emb,
            out_features=output_dim,
            bias=False,
        )

    def apply_embedding(self, tokens, embeds):
        """Return ``embeds`` if given, else look ``tokens`` up in the embedding table."""
        if embeds is not None:
            return embeds
        else:
            return self.tok_embeddings(tokens)

    def _depth_factor(self, depth: int) -> float:
        """Std scale factor for the layer at index ``depth`` per ``init_std_factor``.

        Raises KeyError for an unrecognized ``init_std_factor``.
        """
        return {
            InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
            InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
            InitStdFactor.DIM_RATIO: self.dim / 4096,
            InitStdFactor.DISABLED: 1.0,
        }[self.init_std_factor]

    @staticmethod
    def _trunc_normal_init(weight: torch.Tensor, std: float) -> None:
        """Fill ``weight`` in place from N(0, std^2) truncated to +/- 3 std."""
        nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-3 * std, b=3 * std)

    def init_weights(self, init_std=None):
        """(Re)initialize every parameter owned by this module.

        Args:
            init_std: std for the truncated-normal init of embeddings,
                projections and the optional ``output`` head; falls back to
                ``dim ** -0.5`` when falsy. Also forwarded to each layer's own
                ``init_weights`` along with a depth-dependent factor.
        """
        # `self.rope` only exists when rotary embeddings are enabled; calling it
        # unconditionally crashed models that use learned positions.
        if self.use_rope:
            self.rope.reset_parameters()

        init_std = init_std or (self.dim ** (-0.5))

        # Initialization order is preserved so seeded runs stay reproducible.
        self._trunc_normal_init(self.tok_embeddings.weight, init_std)
        if self.pos_embeddings is not None:
            self._trunc_normal_init(self.pos_embeddings.weight, init_std)

        for depth, layer in enumerate(self.layers):
            layer.init_weights(init_std, self._depth_factor(depth))

        if self.token_embedding_projection is not None:
            self._trunc_normal_init(self.token_embedding_projection.weight, init_std)

        if self.patch_embedding_projection is not None:
            self._trunc_normal_init(self.patch_embedding_projection.weight, init_std)

        if hasattr(self, "output"):
            self._trunc_normal_init(self.output.weight, init_std)

        # Subclasses only create `cross_attn_layers` when cross attention is
        # enabled, so a missing attribute means "no cross-attention layers".
        cross_attn_layers = getattr(self, "cross_attn_layers", None)
        if cross_attn_layers is not None:
            for depth, layer in enumerate(cross_attn_layers):
                layer.init_weights(init_std, self._depth_factor(depth))
|
|
|
|
|
|
|
|
|
|
|
|
class LocalEncoder(LocalModelBase):
    """Byte-level local encoder.

    Embeds bytes (or uses caller-provided hash embeddings), runs them through
    the local transformer layers and, when ``cross_attn_encoder`` is enabled,
    refines patch embeddings by cross-attending from patches to byte states.
    """

    def __init__(self, args: LocalModelArgs):
        super().__init__(args)
        # Not used by the code in this class. `entropy_model_checkpoint_dir` is
        # not declared on LocalModelArgs in this file, so read it defensively
        # rather than raising AttributeError for entropy/probmax patching.
        self.output_proj = (
            args.patching_mode in ["entropy", "probmax"]
        ) and getattr(args, "entropy_model_checkpoint_dir", None) is None

        self.apply_transformer = args.use_local_encoder_transformer
        self.downsampling_by_pooling = args.downsampling_by_pooling
        self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None
        self.cross_attn_encoder = args.cross_attn_encoder
        self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder
        self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
        self.cross_attn_nheads = args.cross_attn_nheads

        # Defaulted so LocalModelBase.init_weights(), which tests
        # `self.cross_attn_layers is not None`, works with cross attention off.
        self.cross_attn_layers = None
        if self.cross_attn_encoder:
            self.cross_attn_layers = torch.nn.ModuleList()
            # One layer shared at the last block, or one per block.
            layers_to_add = args.n_layers if self.cross_attn_all_layers_encoder else 1
            for _ in range(layers_to_add):
                self.cross_attn_layers.append(
                    CrossAttention(
                        dim=self.dim,
                        head_dim=self.dim // self.cross_attn_nheads,
                        n_heads=self.cross_attn_nheads,
                        n_kv_heads=self.cross_attn_nheads,
                        norm_eps=args.norm_eps,
                    )
                )

    def apply_embedding(self, tokens, embeds):
        """Return precomputed ``embeds`` (hash-embedding mode only) or look up ``tokens``."""
        if embeds is not None:
            assert (
                self.expects_hash_embeddings
            ), "Not expecting embeddings to be passed."
            return embeds
        else:
            return self.tok_embeddings(tokens)

    def forward(
        self,
        tokens: torch.Tensor,
        embeds: Optional[torch.Tensor] = None,
        patch_embeds: Optional[torch.Tensor] = None,
        mask: Optional[Union["BlockMask", "AttentionBias", torch.Tensor, str]] = None,
        cross_mask: Optional[torch.Tensor] = None,
        num_patches: Optional[int] = None,
        patch_ids: Optional[torch.Tensor] = None,
        cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
    ):
        """Encode a batch of byte tokens.

        Args:
            tokens: 2-D tensor unpacked as ``(bs, seqlen)``.
            embeds: optional precomputed embeddings used instead of the token
                lookup (only allowed when hash embeddings are expected).
            patch_embeds: optional initial patch embeddings for cross attention;
                when None and ``cross_attn_init_by_pooling`` is set they are
                pooled from the byte states.
            mask: self-attention mask; built with ``create_causal_mask`` in
                "local_block_causal" mode when None.
            cross_mask: mask passed to the cross-attention layers.
            num_patches, patch_ids: forwarded to ``downsample`` for pooling.
            cache: returned unchanged.

        Returns:
            ``((h, h_residual), cache)`` where ``h`` is the final byte hidden
            state and ``h_residual`` is the cross-attended patch embeddings, or
            None when ``cross_attn_encoder`` is disabled.
        """
        bs, seqlen = tokens.shape
        if mask is None:
            mask = create_causal_mask(
                seqlen,
                self.attn_impl,
                "local_block_causal",
                sliding_window=self.sliding_window,
                tokens=tokens,
                eos_id=self.eos_id,
            )

        h = self.apply_embedding(tokens, embeds)
        freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None

        h = F.dropout(h, p=self.dropout, training=self.training)

        for i, layer in enumerate(self.layers):
            h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl)
            # Cross attention runs after every layer or only after the last one.
            if self.cross_attn_encoder and (
                i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder
            ):
                patch_embeds = self.apply_cross_attention(
                    h, patch_embeds, i, bs, num_patches, patch_ids, cross_mask
                )

        h_residual = patch_embeds if self.cross_attn_encoder else None
        return (h, h_residual), cache

    def apply_cross_attention(
        self, h, patch_embeds, layer_idx, bs, num_patches, patch_ids, cross_mask
    ):
        """Update patch embeddings by cross-attending from patches (queries) to ``h``.

        On the first call with ``patch_embeds=None`` and pooling init enabled,
        patch embeddings are pooled from ``h`` via ``downsample`` and passed
        through the patch projection when one exists.

        Returns:
            ``patch_embeds + cross_attention(patch_embeds, h)`` as a new tensor.

        Raises:
            ValueError: if no patch embeddings are available for cross attention.
        """
        # Pool byte states into patches and project them (first call only).
        if self.cross_attn_init_by_pooling and patch_embeds is None:
            patch_embeds = downsample(
                h,
                num_patches,
                patch_ids=patch_ids,
                downsampling_by_pooling=self.downsampling_by_pooling,
                patch_size=self.patch_size,
            )
            if self.patch_embedding_projection is not None:
                patch_embeds = self.patch_embedding_projection(patch_embeds)
                # The projection emits cross_attn_k embeddings per patch; unfold
                # them along the sequence axis. Skipped when cross_attn_k is
                # unset (the projection then emits one per patch), matching the
                # decoder's guard.
                if self.cross_attn_k is not None:
                    patch_embeds = patch_embeds.reshape(
                        bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim
                    )

        if patch_embeds is None:
            raise ValueError(
                "patch_embeds must be provided when cross_attn_init_by_pooling is disabled."
            )

        layer_idx = layer_idx if self.cross_attn_all_layers_encoder else 0
        patch_embeds_cross = self.cross_attn_layers[layer_idx](
            x=patch_embeds,
            kv=h,
            mask=cross_mask,
        )
        # Out-of-place add: an in-place `+=` would mutate a caller-supplied
        # tensor and the tensor just fed to the cross-attention layer, which
        # autograd may have saved for the backward pass.
        return patch_embeds + patch_embeds_cross
|
|
|
|
|
|
|
|
|
|
|
|
class LocalDecoder(LocalModelBase):
    """Byte-level local decoder producing per-byte vocabulary logits.

    Consumes byte-level hidden states (``embeds``) and patch embeddings. The
    patch information is either added directly to the byte states or, when
    ``cross_attn_decoder`` is enabled, injected via cross attention.
    """

    def __init__(self, args: LocalModelArgs):
        super().__init__(args)

        # Model configuration flags
        self.cross_attn_decoder = args.cross_attn_decoder
        self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder
        self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
        self.cross_attn_nheads = args.cross_attn_nheads

        # Defaulted so LocalModelBase.init_weights(), which tests
        # `self.cross_attn_layers is not None`, works with cross attention off.
        self.cross_attn_layers = None
        if self.cross_attn_decoder:
            self.cross_attn_layers = torch.nn.ModuleList()
            # One layer before the first block, or one before every block.
            layers_to_add = args.n_layers if self.cross_attn_all_layers_decoder else 1
            for _ in range(layers_to_add):
                self.cross_attn_layers.append(
                    CrossAttention(
                        dim=self.dim,
                        head_dim=self.dim // self.cross_attn_nheads,
                        n_heads=self.cross_attn_nheads,
                        n_kv_heads=self.cross_attn_nheads,
                        norm_eps=args.norm_eps,
                    )
                )

        self.output = nn.Linear(
            self.dim,
            args.vocab_size,
            bias=False,
        )

    def forward(
        self,
        tokens: torch.Tensor,
        embeds: Optional[torch.Tensor],
        patch_embeds: Optional[torch.Tensor] = None,
        mask: Optional[Union["BlockMask", "AttentionBias", torch.Tensor, str]] = None,
        cross_mask: Optional[torch.Tensor] = None,
        cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
    ):
        """Decode byte hidden states into float32 vocabulary logits.

        Args:
            tokens: 2-D tensor unpacked as ``(bs, seqlen)``; used for the mask.
            embeds: byte-level hidden states fed to the layers (required).
            patch_embeds: patch embeddings; projected when a patch projection
                exists, then added to ``embeds`` (no cross attention) or used
                as cross-attention keys/values.
            mask: self-attention mask; built with ``create_causal_mask`` in
                "local_block_causal" mode when None.
            cross_mask: mask passed to the cross-attention layers.
            cache: returned unchanged.

        Returns:
            ``(logits, cache)`` with logits cast to float32.

        Raises:
            ValueError: if cross attention is enabled but ``patch_embeds`` is None.
        """
        bs, seqlen = tokens.shape
        assert embeds is not None, "Embeddings must be provided"

        if mask is None:
            mask = create_causal_mask(
                seqlen,
                self.attn_impl,
                "local_block_causal",
                sliding_window=self.sliding_window,
                tokens=tokens,
                eos_id=self.eos_id,
            )

        h = embeds

        if self.patch_embedding_projection is not None:
            assert patch_embeds is not None, "Patch embeddings must be passed."
            patch_embeds = self.patch_embedding_projection(patch_embeds)
            # Unfold the cross_attn_k embeddings produced per patch.
            if self.cross_attn_k is not None:
                patch_embeds = patch_embeds.reshape(
                    bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim
                )

        # Without cross attention, patch information is simply added in.
        if patch_embeds is not None and not self.cross_attn_decoder:
            h = h + patch_embeds

        freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None

        h = F.dropout(h, p=self.dropout, training=self.training)
        for i, layer in enumerate(self.layers):
            if self.cross_attn_decoder and (
                i == 0 or self.cross_attn_all_layers_decoder
            ):
                if patch_embeds is None:
                    raise ValueError(
                        "patch_embeds must be provided when cross_attn_decoder is enabled."
                    )
                # Only one cross-attention layer exists unless all layers use it.
                layer_idx = i if self.cross_attn_all_layers_decoder else 0
                # Use cross attention to extract info from patch_embeds into h
                h_cross = self.cross_attn_layers[layer_idx](
                    x=h,
                    kv=patch_embeds,
                    mask=cross_mask,
                )
                h = h + h_cross

            h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl)

        h_preds = self.norm(h)
        h_preds = F.dropout(h_preds, p=self.dropout, training=self.training)
        h_preds = self.output(h_preds)
        h_preds = h_preds.float()
        return h_preds, cache
|