Add on-device initialization

This commit is contained in:
Gustaf Ahdritz 2025-06-05 15:19:43 -07:00
parent 4ae7a62594
commit 4c5e51e4de
8 changed files with 232 additions and 66 deletions

View file

@ -6,18 +6,18 @@ from enum import Enum
from typing import Optional, Tuple, Union
import torch
from bytelatent.tokenizers.constants import EOS_ID
from pydantic import BaseModel, ConfigDict
from torch import nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import (
BlockMask,
_mask_mod_signature,
BlockMask,
flex_attention,
)
from xformers.ops import AttentionBias, fmha
from bytelatent.tokenizers.constants import EOS_ID
logger = logging.getLogger()
try:
@ -42,7 +42,7 @@ class InitStdFactor(str, Enum):
class BaseTransformerArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
dim: int = 512
n_layers: int = 8
head_dim: int | None = None
@ -68,6 +68,9 @@ class BaseTransformerArgs(BaseModel):
# Special token config
eos_id: int | None = EOS_ID
init_device: str = "cpu"
init_dtype: torch.dtype = torch.float32
def cross_entropy(pred, target, **kwargs):
return F.nll_loss(
@ -95,6 +98,7 @@ def precompute_freqs_cis(
end: int,
theta: float = 10000.0,
rope_use_fp32_in_outer_product: bool = False,
device: str | torch.device = torch.device("cpu"),
):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@ -111,7 +115,9 @@ def precompute_freqs_cis(
Returns:
torch.Tensor: Precomputed frequency tensor with complex exponentials.
"""
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
freqs = 1.0 / (
theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim)
)
t = torch.arange(end, device=freqs.device)
if rope_use_fp32_in_outer_product:
t = t.to(torch.float32)
@ -258,6 +264,8 @@ class RotaryEmbedding(torch.nn.Module):
head_dim: int,
max_seqlen: int = 1024,
rope_use_fp32_in_outer_product: bool = False,
device: str | torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.float32,
):
super().__init__()
@ -273,7 +281,8 @@ class RotaryEmbedding(torch.nn.Module):
end=max_seqlen,
theta=theta,
rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
),
device=device,
).to(dtype=dtype),
persistent=False,
)
@ -325,6 +334,8 @@ class Attention(nn.Module):
n_heads: int,
n_kv_heads: int,
rope_theta: float,
device: str | torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.float32,
):
super().__init__()
@ -340,22 +351,30 @@ class Attention(nn.Module):
dim,
n_heads * head_dim,
bias=False,
device=device,
dtype=dtype,
)
self.wk = nn.Linear(
dim,
n_kv_heads * head_dim,
bias=False,
device=device,
dtype=dtype,
)
self.wv = nn.Linear(
dim,
n_kv_heads * head_dim,
bias=False,
device=device,
dtype=dtype,
)
self.wo = nn.Linear(
n_heads * head_dim,
dim,
bias=False,
device=device,
dtype=dtype,
)
def forward(
@ -368,6 +387,7 @@ class Attention(nn.Module):
) -> torch.Tensor:
# B S D
bsz, seq_len, dim = x.shape
xq = self.wq(x.view_as(x))
xk = self.wk(x.view_as(x))
xv = self.wv(x.view_as(x))
@ -453,6 +473,8 @@ class FeedForward(nn.Module):
multiple_of: int,
ffn_dim_multiplier: Optional[float],
mp_size: int = 1,
device: str | torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.float32,
):
super().__init__()
@ -469,16 +491,22 @@ class FeedForward(nn.Module):
dim,
hidden_dim,
bias=False,
device=device,
dtype=dtype,
)
self.w3 = nn.Linear(
dim,
hidden_dim,
bias=False,
device=device,
dtype=dtype,
)
self.w2 = nn.Linear(
hidden_dim,
dim,
bias=False,
device=device,
dtype=dtype,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -535,15 +563,24 @@ class TransformerBlock(nn.Module):
n_heads=self.n_heads,
n_kv_heads=self.n_kv_heads,
rope_theta=args.rope_theta,
device=args.init_device,
dtype=args.init_dtype,
)
self.feed_forward = FeedForward(
dim=args.dim,
hidden_dim=4 * args.dim,
multiple_of=args.multiple_of,
ffn_dim_multiplier=args.ffn_dim_multiplier,
device=args.init_device,
dtype=args.init_dtype,
)
# Norms stay in full precision
self.attention_norm = RMSNorm(
args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
)
self.ffn_norm = RMSNorm(
args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
)
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
def forward(
self,
@ -593,6 +630,8 @@ class BaseTransformer(nn.Module, SequenceModelWithOutput):
head_dim=args.head_dim or args.dim // args.n_heads,
max_seqlen=args.max_seqlen,
rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
device=args.init_device,
dtype=args.init_dtype,
)
self.eos_id = args.eos_id

View file

@ -10,11 +10,13 @@ from bytelatent.transformer import LMTransformer, LMTransformerArgs
logger = logging.getLogger()
def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):
def load_entropy_model(
entropy_model_checkpoint_dir, state_dict_path, device="cpu", dtype=torch.bfloat16
):
with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
reloaded = json.loads(fr.read())
torch.set_default_dtype(torch.bfloat16)
# torch.set_default_dtype(dtype)
model_params = reloaded["entropy_model"]
logger.warning(
"Update checkpoint to load attn and sliding window args from checkpoint"
@ -29,6 +31,8 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
attn_bias_type="local_block_causal",
attn_impl="xformers",
sliding_window=512,
init_device=device,
init_dtype=dtype,
)
entropy_model = LMTransformer(entropy_model_args)
@ -38,6 +42,7 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
entropy_model.to(device)
entropy_model = entropy_model.eval()
# no grads for the model:
for param in entropy_model.parameters():
for n, param in entropy_model.named_parameters():
param.requires_grad = False
return entropy_model, entropy_model_args

View file

@ -4,11 +4,6 @@ import os
import time
import torch
from omegaconf import OmegaConf
from torch import nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import create_block_mask
from tqdm import tqdm
from bytelatent.args import EvalArgs, PackedCausalTransformerGeneratorArgs, TrainArgs
from bytelatent.base_transformer import (
@ -19,9 +14,9 @@ from bytelatent.base_transformer import (
lengths_to_start_ids,
)
from bytelatent.checkpoint import (
consolidate_checkpoints,
CONSOLIDATE_FOLDER,
CONSOLIDATE_NAME,
consolidate_checkpoints,
)
from bytelatent.config_parser import parse_args_to_pydantic_model
from bytelatent.data.file_util import get_fs
@ -33,6 +28,11 @@ from bytelatent.distributed import (
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
from bytelatent.transformer import LMTransformer
from omegaconf import OmegaConf
from torch import nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import create_block_mask
from tqdm import tqdm
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
@ -400,25 +400,33 @@ def load_consolidated_model_and_tokenizer(consolidated_path, init_distributed=Fa
setup_torch_distributed(distributed_args)
train_args_path = os.path.join(consolidated_path, "params.json")
fs = get_fs(train_args_path)
train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
if train_args.train_entropy_model:
model_args = train_args.entropy_model
model = LMTransformer(model_args)
else:
model_args = train_args.model
model = ByteLatentTransformer(model_args)
train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
train_args.distributed.model_dtype
]
if train_args.train_entropy_model:
model_args = train_args.entropy_model
model_args.init_device = "cuda"
model_args.init_dtype = param_dtype
model = LMTransformer(model_args)
else:
model_args = train_args.model
model_args.init_device = "cuda"
model_args.init_dtype = param_dtype
model = ByteLatentTransformer(args=model_args)
model = model.eval()
tokenizer = train_args.data.tokenizer_args.build()
with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
st_dict = torch.load(f, weights_only=True)
model.load_state_dict(st_dict["model"])
model = model.cuda().eval()
for param in model.parameters():
param.data = param.data.to(dtype=param_dtype)
return model, tokenizer, train_args

View file

@ -1,14 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
from enum import Enum, auto
from enum import auto, Enum
from typing import Any, Optional
import torch
from huggingface_hub import PyTorchModelHubMixin
from pydantic import model_validator
from torch import nn
from torch.nn.attention.flex_attention import create_block_mask
from typing_extensions import Self
from bytelatent.base_transformer import (
BaseTransformerArgs,
@ -18,8 +13,15 @@ from bytelatent.base_transformer import (
from bytelatent.data.patcher import Patcher, PatcherArgs
from bytelatent.model.latent_transformer import GlobalTransformer
from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs
from bytelatent.model.utils import downsample
from bytelatent.model.utils import check_param_device, downsample
from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
from huggingface_hub import PyTorchModelHubMixin
from numpy.random import f
from pydantic import model_validator
from torch import nn
from torch.nn.attention.flex_attention import create_block_mask
from typing_extensions import Self
def attention_flops_per_token(n_layers, seq_len, dim, causal):
@ -155,6 +157,9 @@ primes = [
def rolling_polynomial_hash(t, hash_func_nb: int = 0):
if hash_func_nb >= len(primes):
print(f"len(primes): {len(primes)}, hash_func_nb: {hash_func_nb}")
prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device)
prime_powers = torch.stack([prime**i for i in range(t.shape[-1])])
return torch.sum(t * prime_powers, dim=-1)
@ -239,6 +244,9 @@ def create_patch_mask_from_ids(
return mask
GLOBAL = set()
def cross_attn_mask(
patch_ids,
patch_lengths,
@ -265,9 +273,36 @@ def cross_attn_mask(
kv_len,
), f"{cross_mask.shape} != {(bs, q_len, kv_len)}"
if block_mask:
# This appears to resolve occasional nondeterministic RuntimeErrors
# in the create_block_mask call. I have no idea why.
cross_mask_copy = cross_mask.clone()
def patch_mask(b, h, q_idx, kv_idx):
return cross_mask[b, q_idx, kv_idx]
return cross_mask_copy[b, q_idx, kv_idx]
# print(f"cross_mask: {cross_mask.shape}")
# print(f"bs: {bs}, q_len: {q_len}, kv_len: {kv_len}")
# print(cross_mask[0, 0, 0])
# for i in range(bs):
# for j in range(q_len):
# for k in range(kv_len):
# y = cross_mask[i, j, k]
# import pickle
# with open("cross_mask.pkl", "wb") as f:
# pickle.dump(cross_mask, f)
# global GLOBAL
# s = f"bs_{bs}_q_len_{q_len}_kv_len_{kv_len}"
# if s not in GLOBAL:
# GLOBAL.add(s)
# print(f"bs_{bs}_q_len_{q_len}_kv_len_{kv_len}")
# else:
# print(f"bs_{bs}_q_len_{q_len}_kv_len_{kv_len} (skipped)")
# if q_len >= 51 and kv_len >= 96:
# breakpoint()
block_mask = create_block_mask(
patch_mask,
@ -277,6 +312,9 @@ def cross_attn_mask(
KV_LEN=kv_len,
_compile=True,
)
# print(f"block_mask_shape: {block_mask.shape}")
return block_mask
else:
return torch.where(
@ -632,6 +670,8 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
cross_attn_nheads=args.cross_attn_nheads,
eos_id=args.eos_id,
init_device=args.init_device,
init_dtype=args.init_dtype,
)
return LocalEncoder(local_encoder_args)
@ -675,6 +715,8 @@ def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
cross_attn_nheads=args.cross_attn_nheads,
eos_id=args.eos_id,
init_device=args.init_device,
init_dtype=args.init_dtype,
)
return LocalDecoder(local_decoder_args)
@ -710,6 +752,8 @@ def init_embeddings(
nn.Embedding(
encoder_hash_byte_group_vocab,
emb_dim,
device=args.init_device,
dtype=args.init_dtype,
)
)
@ -718,7 +762,14 @@ def init_embeddings(
emb_dim = local_encoder_dim
OFFSET = 4 # This should be passed as parameter if it's variable
for ngram_vocab_size in encoder_ngram_to_size.values():
embeddings.append(nn.Embedding(ngram_vocab_size + OFFSET, emb_dim))
embeddings.append(
nn.Embedding(
ngram_vocab_size + OFFSET,
emb_dim,
device=args.init_device,
dtype=args.init_dtype,
)
)
return nn.ModuleList(embeddings)
@ -792,7 +843,7 @@ class ByteLatentTransformer(
"""
def __init__(self, args: ByteLatentTransformerArgs):
super().__init__()
super(ByteLatentTransformer, self).__init__()
# General configuration
self.weight_tying = args.weight_tying
@ -854,7 +905,12 @@ class ByteLatentTransformer(
ngram_emb_dim = self.local_encoder.dim
for ngram_vocab_size in self.encoder_ngram_to_size.values():
self.encoder_ngram_embedding.append(
nn.Embedding(ngram_vocab_size + OFFSET, ngram_emb_dim)
nn.Embedding(
ngram_vocab_size + OFFSET,
ngram_emb_dim,
device=args.init_device,
dtype=args.init_dtype,
)
)
# Output layer
@ -873,6 +929,9 @@ class ByteLatentTransformer(
)
)
# Sanity check
check_param_device(self, args.init_device)
def push_to_hub(self, *args, **kwargs):
raise ValueError(
"For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct."

View file

@ -5,9 +5,6 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import BlockMask
from xformers.ops import AttentionBias
from bytelatent.base_transformer import (
BaseTransformer,
@ -16,6 +13,9 @@ from bytelatent.base_transformer import (
repeat_kv,
)
from bytelatent.model.utils import create_causal_mask
from torch.nn import functional as F
from torch.nn.attention.flex_attention import BlockMask
from xformers.ops import AttentionBias
logger = logging.getLogger()
try:
@ -40,6 +40,8 @@ class CrossAttention(nn.Module):
n_heads: int,
n_kv_heads: int,
norm_eps: float,
device: str | torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.float32,
):
super().__init__()
@ -50,29 +52,39 @@ class CrossAttention(nn.Module):
self.n_kv_heads = n_kv_heads
self.heads_per_group = self.n_heads // self.n_kv_heads
self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps)
self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)
self.cross_attn_norm_q = nn.RMSNorm(
dim, eps=norm_eps, device=device, dtype=dtype
)
self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps, device=device, dtype=dtype)
self.wq = nn.Linear(
dim,
n_heads * head_dim,
bias=False,
device=device,
dtype=dtype,
)
self.wk = nn.Linear(
dim,
n_kv_heads * head_dim,
bias=False,
device=device,
dtype=dtype,
)
self.wv = nn.Linear(
dim,
n_kv_heads * head_dim,
bias=False,
device=device,
dtype=dtype,
)
self.wo = nn.Linear(
n_heads * head_dim,
dim,
bias=False,
device=device,
dtype=dtype,
)
def forward(
@ -160,6 +172,8 @@ class GlobalTransformer(BaseTransformer):
args.dim_token_emb,
args.dim,
bias=False,
device=args.init_device,
dtype=args.init_dtype,
)
def forward(

View file

@ -6,10 +6,6 @@ from typing import Any, List, Optional, Tuple, Union
import torch
import torch.nn
import torch.nn as nn
from pydantic import ConfigDict
from torch.nn import functional as F
from torch.nn.attention.flex_attention import BlockMask
from xformers.ops import AttentionBias
from bytelatent.base_transformer import (
BaseTransformerArgs,
@ -20,6 +16,10 @@ from bytelatent.base_transformer import (
from bytelatent.model.latent_transformer import CrossAttention
from bytelatent.model.utils import create_causal_mask, downsample
from bytelatent.tokenizers.blt_tokenizer import BOE_ID
from pydantic import ConfigDict
from torch.nn import functional as F
from torch.nn.attention.flex_attention import BlockMask
from xformers.ops import AttentionBias
logger = logging.getLogger()
try:
@ -85,18 +85,31 @@ class LocalModelBase(nn.Module):
)
if not self.use_rope:
self.pos_embeddings = nn.Embedding(args.max_length, args.dim)
self.pos_embeddings = nn.Embedding(
args.max_length,
args.dim,
device=args.init_device,
dtype=args.init_dtype,
)
else:
self.rope = RotaryEmbedding(
theta=args.rope_theta,
head_dim=args.head_dim or args.dim // args.n_heads,
max_seqlen=args.max_seqlen,
rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
device=args.init_device,
dtype=args.init_dtype,
)
self.pos_embeddings = None
self.token_embedding_projection = (
nn.Linear(args.dim_token_emb, args.dim, bias=False)
nn.Linear(
args.dim_token_emb,
args.dim,
bias=False,
device=args.init_device,
dtype=args.init_dtype,
)
if hasattr(args, "dim_token_emb") and args.dim_token_emb != self.dim
else None
)
@ -125,6 +138,8 @@ class LocalModelBase(nn.Module):
in_features=args.dim_patch_emb,
out_features=output_dim,
bias=False,
device=args.init_device,
dtype=args.init_dtype,
)
def apply_embedding(self, tokens, embeds):
@ -218,7 +233,9 @@ class LocalEncoder(LocalModelBase):
self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
self.cross_attn_nheads = args.cross_attn_nheads
self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)
self.tok_embeddings = nn.Embedding(
self.vocab_size, args.dim, device=args.init_device, dtype=args.init_dtype
)
if self.cross_attn_encoder:
self.cross_attn_layers = torch.nn.ModuleList()
@ -231,6 +248,8 @@ class LocalEncoder(LocalModelBase):
n_heads=self.cross_attn_nheads,
n_kv_heads=self.cross_attn_nheads,
norm_eps=args.norm_eps,
device=args.init_device,
dtype=args.init_dtype,
)
)
@ -321,7 +340,9 @@ class LocalDecoder(LocalModelBase):
self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
self.cross_attn_nheads = args.cross_attn_nheads
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
self.norm = RMSNorm(
args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
)
if self.cross_attn_decoder:
self.cross_attn_layers = torch.nn.ModuleList()
@ -334,6 +355,8 @@ class LocalDecoder(LocalModelBase):
n_heads=self.cross_attn_nheads,
n_kv_heads=self.cross_attn_nheads,
norm_eps=args.norm_eps,
device=args.init_device,
dtype=args.init_dtype,
)
)
@ -341,6 +364,8 @@ class LocalDecoder(LocalModelBase):
self.dim,
args.vocab_size,
bias=False,
device=args.init_device,
dtype=args.init_dtype,
)
def forward(

View file

@ -175,3 +175,10 @@ def create_causal_mask(
raise NotImplementedError(
f"Attention {attn_impl} with {sliding_window} sliding window not implemented"
)
def check_param_device(model, device_type: str = "cpu"):
for name, param in model.named_parameters():
assert (
param.device.type == device_type
), f"Parameter {name} is on {param.device.type}, not on {device_type}"

View file

@ -4,25 +4,25 @@ import logging
from typing import Optional, Tuple, Union
import torch
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
PrepareModuleInput,
RowwiseParallel,
SequenceParallel,
parallelize_module,
)
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
from xformers.ops import AttentionBias
from bytelatent.base_transformer import (
BaseTransformer,
BaseTransformerArgs,
cross_entropy,
)
from bytelatent.model.utils import create_causal_mask
from bytelatent.model.utils import check_param_device, create_causal_mask
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
PrepareModuleInput,
RowwiseParallel,
SequenceParallel,
)
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
from xformers.ops import AttentionBias
logger = logging.getLogger()
@ -84,19 +84,28 @@ class LMTransformer(
assert args.vocab_size > 0
self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)
self.tok_embeddings = torch.nn.Embedding(
args.vocab_size, args.dim, device=args.init_device, dtype=args.init_dtype
)
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
self.norm = RMSNorm(
args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
)
self.output = nn.Linear(
args.dim,
args.vocab_size,
bias=False,
device=args.init_device,
dtype=args.init_dtype,
)
if args.weight_tying:
self.output.weight = self.embeddings.tok_embeddings.weight
# Sanity check
check_param_device(self, args.init_device)
def push_to_hub(self, *args, **kwargs):
raise ValueError(
"For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct."