Mirror of https://github.com/facebookresearch/blt.git (synced 2025-09-17 09:39:44 +00:00)

Add on-device initialization

parent 4ae7a62594
commit 4c5e51e4de

8 changed files with 232 additions and 66 deletions
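
The change threads two new config fields, init_device and init_dtype, through BaseTransformerArgs and every submodule constructor (embeddings, linear layers, norms, rotary embeddings), so parameters are allocated on the target device and in the target dtype at construction time instead of being created on CPU in float32 and moved afterwards. Below is a minimal usage sketch, assuming the bytelatent package from this repo; everything other than init_device / init_dtype (dims, vocab size, etc.) is a hypothetical placeholder that must match your own config.

import torch

from bytelatent.model.utils import check_param_device
from bytelatent.transformer import LMTransformer, LMTransformerArgs

# Placeholder hyperparameters; init_device / init_dtype are the fields added by this commit.
args = LMTransformerArgs(
    dim=512,
    n_layers=8,
    vocab_size=260,
    init_device="cuda",         # allocate parameters directly on the GPU
    init_dtype=torch.bfloat16,  # ... and in the intended compute dtype
)

model = LMTransformer(args)

# The constructors now end with the same sanity check introduced in this commit.
check_param_device(model, "cuda")
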
@@ -6,18 +6,18 @@ from enum import Enum
from typing import Optional, Tuple, Union

import torch

from bytelatent.tokenizers.constants import EOS_ID
from pydantic import BaseModel, ConfigDict
from torch import nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import (
    BlockMask,
    _mask_mod_signature,
    BlockMask,
    flex_attention,
)
from xformers.ops import AttentionBias, fmha

from bytelatent.tokenizers.constants import EOS_ID

logger = logging.getLogger()

try:

@@ -42,7 +42,7 @@ class InitStdFactor(str, Enum):


class BaseTransformerArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
    dim: int = 512
    n_layers: int = 8
    head_dim: int | None = None

@@ -68,6 +68,9 @@ class BaseTransformerArgs(BaseModel):
    # Special token config
    eos_id: int | None = EOS_ID

    init_device: str = "cpu"
    init_dtype: torch.dtype = torch.float32


def cross_entropy(pred, target, **kwargs):
    return F.nll_loss(

@@ -95,6 +98,7 @@ def precompute_freqs_cis(
    end: int,
    theta: float = 10000.0,
    rope_use_fp32_in_outer_product: bool = False,
    device: str | torch.device = torch.device("cpu"),
):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

@@ -111,7 +115,9 @@ def precompute_freqs_cis(
    Returns:
        torch.Tensor: Precomputed frequency tensor with complex exponentials.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim)
    )
    t = torch.arange(end, device=freqs.device)
    if rope_use_fp32_in_outer_product:
        t = t.to(torch.float32)

@@ -258,6 +264,8 @@ class RotaryEmbedding(torch.nn.Module):
        head_dim: int,
        max_seqlen: int = 1024,
        rope_use_fp32_in_outer_product: bool = False,
        device: str | torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()

@@ -273,7 +281,8 @@ class RotaryEmbedding(torch.nn.Module):
                end=max_seqlen,
                theta=theta,
                rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
            ),
                device=device,
            ).to(dtype=dtype),
            persistent=False,
        )

@@ -325,6 +334,8 @@ class Attention(nn.Module):
        n_heads: int,
        n_kv_heads: int,
        rope_theta: float,
        device: str | torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()

@@ -340,22 +351,30 @@ class Attention(nn.Module):
            dim,
            n_heads * head_dim,
            bias=False,
            device=device,
            dtype=dtype,
        )
        self.wk = nn.Linear(
            dim,
            n_kv_heads * head_dim,
            bias=False,
            device=device,
            dtype=dtype,
        )
        self.wv = nn.Linear(
            dim,
            n_kv_heads * head_dim,
            bias=False,
            device=device,
            dtype=dtype,
        )

        self.wo = nn.Linear(
            n_heads * head_dim,
            dim,
            bias=False,
            device=device,
            dtype=dtype,
        )

    def forward(

@@ -368,6 +387,7 @@ class Attention(nn.Module):
    ) -> torch.Tensor:
        # B S D
        bsz, seq_len, dim = x.shape

        xq = self.wq(x.view_as(x))
        xk = self.wk(x.view_as(x))
        xv = self.wv(x.view_as(x))

@@ -453,6 +473,8 @@ class FeedForward(nn.Module):
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
        mp_size: int = 1,
        device: str | torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()

@@ -469,16 +491,22 @@ class FeedForward(nn.Module):
            dim,
            hidden_dim,
            bias=False,
            device=device,
            dtype=dtype,
        )
        self.w3 = nn.Linear(
            dim,
            hidden_dim,
            bias=False,
            device=device,
            dtype=dtype,
        )
        self.w2 = nn.Linear(
            hidden_dim,
            dim,
            bias=False,
            device=device,
            dtype=dtype,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:

@@ -535,15 +563,24 @@ class TransformerBlock(nn.Module):
            n_heads=self.n_heads,
            n_kv_heads=self.n_kv_heads,
            rope_theta=args.rope_theta,
            device=args.init_device,
            dtype=args.init_dtype,
        )
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
            device=args.init_device,
            dtype=args.init_dtype,
        )
        # Norms stay in full precision
        self.attention_norm = RMSNorm(
            args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
        )
        self.ffn_norm = RMSNorm(
            args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
        )
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,

@@ -593,6 +630,8 @@ class BaseTransformer(nn.Module, SequenceModelWithOutput):
            head_dim=args.head_dim or args.dim // args.n_heads,
            max_seqlen=args.max_seqlen,
            rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
            device=args.init_device,
            dtype=args.init_dtype,
        )
        self.eos_id = args.eos_id

@@ -10,11 +10,13 @@ from bytelatent.transformer import LMTransformer, LMTransformerArgs
logger = logging.getLogger()


def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):
def load_entropy_model(
    entropy_model_checkpoint_dir, state_dict_path, device="cpu", dtype=torch.bfloat16
):
    with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
        reloaded = json.loads(fr.read())

    torch.set_default_dtype(torch.bfloat16)
    # torch.set_default_dtype(dtype)
    model_params = reloaded["entropy_model"]
    logger.warning(
        "Update checkpoint to load attn and sliding window args from checkpoint"

@@ -29,6 +31,8 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
        attn_bias_type="local_block_causal",
        attn_impl="xformers",
        sliding_window=512,
        init_device=device,
        init_dtype=dtype,
    )
    entropy_model = LMTransformer(entropy_model_args)

@@ -38,6 +42,7 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
    entropy_model.to(device)
    entropy_model = entropy_model.eval()
    # no grads for the model:
    for param in entropy_model.parameters():
    for n, param in entropy_model.named_parameters():
        param.requires_grad = False

    return entropy_model, entropy_model_args

@@ -4,11 +4,6 @@ import os
import time

import torch
from omegaconf import OmegaConf
from torch import nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import create_block_mask
from tqdm import tqdm

from bytelatent.args import EvalArgs, PackedCausalTransformerGeneratorArgs, TrainArgs
from bytelatent.base_transformer import (

@@ -19,9 +14,9 @@ from bytelatent.base_transformer import (
    lengths_to_start_ids,
)
from bytelatent.checkpoint import (
    consolidate_checkpoints,
    CONSOLIDATE_FOLDER,
    CONSOLIDATE_NAME,
    consolidate_checkpoints,
)
from bytelatent.config_parser import parse_args_to_pydantic_model
from bytelatent.data.file_util import get_fs

@@ -33,6 +28,11 @@ from bytelatent.distributed import (
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
from bytelatent.transformer import LMTransformer
from omegaconf import OmegaConf
from torch import nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import create_block_mask
from tqdm import tqdm


def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:

@@ -400,25 +400,33 @@ def load_consolidated_model_and_tokenizer(consolidated_path, init_distributed=Fa
        setup_torch_distributed(distributed_args)
    train_args_path = os.path.join(consolidated_path, "params.json")
    fs = get_fs(train_args_path)
    train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))

    if train_args.train_entropy_model:
        model_args = train_args.entropy_model
        model = LMTransformer(model_args)
    else:
        model_args = train_args.model
        model = ByteLatentTransformer(model_args)
    train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))

    param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
        train_args.distributed.model_dtype
    ]

    if train_args.train_entropy_model:
        model_args = train_args.entropy_model
        model_args.init_device = "cuda"
        model_args.init_dtype = param_dtype
        model = LMTransformer(model_args)
    else:
        model_args = train_args.model
        model_args.init_device = "cuda"
        model_args.init_dtype = param_dtype
        model = ByteLatentTransformer(args=model_args)

    model = model.eval()

    tokenizer = train_args.data.tokenizer_args.build()

    with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
        st_dict = torch.load(f, weights_only=True)

    model.load_state_dict(st_dict["model"])
    model = model.cuda().eval()
    for param in model.parameters():
        param.data = param.data.to(dtype=param_dtype)

    return model, tokenizer, train_args

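With these changes, load_consolidated_model_and_tokenizer sets model_args.init_device = "cuda" and model_args.init_dtype from the checkpoint's training dtype before building the model, so the weights are created directly where the state dict will be loaded. A rough usage sketch follows, assuming this function lives in bytelatent.generate and that the checkpoint path below is a placeholder.

import torch

from bytelatent.generate import load_consolidated_model_and_tokenizer

# Placeholder path to a consolidated checkpoint directory (params.json + weights).
ckpt_dir = "checkpoints/0000100000/consolidated"

model, tokenizer, train_args = load_consolidated_model_and_tokenizer(ckpt_dir)

# Parameters should already sit on CUDA in the dtype recorded in train_args,
# so no extra .cuda() call or per-parameter dtype cast should be needed.
param = next(model.parameters())
print(param.device, param.dtype)
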
@@ -1,14 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from enum import Enum, auto
from enum import auto, Enum
from typing import Any, Optional

import torch
from huggingface_hub import PyTorchModelHubMixin
from pydantic import model_validator
from torch import nn
from torch.nn.attention.flex_attention import create_block_mask
from typing_extensions import Self

from bytelatent.base_transformer import (
    BaseTransformerArgs,

@@ -18,8 +13,15 @@ from bytelatent.base_transformer import (
from bytelatent.data.patcher import Patcher, PatcherArgs
from bytelatent.model.latent_transformer import GlobalTransformer
from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs
from bytelatent.model.utils import downsample
from bytelatent.model.utils import check_param_device, downsample
from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
from huggingface_hub import PyTorchModelHubMixin

from numpy.random import f
from pydantic import model_validator
from torch import nn
from torch.nn.attention.flex_attention import create_block_mask
from typing_extensions import Self


def attention_flops_per_token(n_layers, seq_len, dim, causal):

@@ -155,6 +157,9 @@ primes = [


def rolling_polynomial_hash(t, hash_func_nb: int = 0):
    if hash_func_nb >= len(primes):
        print(f"len(primes): {len(primes)}, hash_func_nb: {hash_func_nb}")

    prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device)
    prime_powers = torch.stack([prime**i for i in range(t.shape[-1])])
    return torch.sum(t * prime_powers, dim=-1)

@@ -239,6 +244,9 @@ def create_patch_mask_from_ids(
    return mask


GLOBAL = set()


def cross_attn_mask(
    patch_ids,
    patch_lengths,

@@ -265,9 +273,36 @@ def cross_attn_mask(
        kv_len,
    ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}"
    if block_mask:
        # This appears to resolve occasional nondeterministic RuntimeErrors
        # in the create_block_mask call. I have no idea why.
        cross_mask_copy = cross_mask.clone()

        def patch_mask(b, h, q_idx, kv_idx):
            return cross_mask[b, q_idx, kv_idx]
            return cross_mask_copy[b, q_idx, kv_idx]

        # print(f"cross_mask: {cross_mask.shape}")
        # print(f"bs: {bs}, q_len: {q_len}, kv_len: {kv_len}")
        # print(cross_mask[0, 0, 0])
        # for i in range(bs):
        #     for j in range(q_len):
        #         for k in range(kv_len):
        #             y = cross_mask[i, j, k]

        # import pickle

        # with open("cross_mask.pkl", "wb") as f:
        #     pickle.dump(cross_mask, f)

        # global GLOBAL
        # s = f"bs_{bs}_q_len_{q_len}_kv_len_{kv_len}"
        # if s not in GLOBAL:
        #     GLOBAL.add(s)
        #     print(f"bs_{bs}_q_len_{q_len}_kv_len_{kv_len}")
        # else:
        #     print(f"bs_{bs}_q_len_{q_len}_kv_len_{kv_len} (skipped)")

        # if q_len >= 51 and kv_len >= 96:
        #     breakpoint()

        block_mask = create_block_mask(
            patch_mask,

@@ -277,6 +312,9 @@ def cross_attn_mask(
            KV_LEN=kv_len,
            _compile=True,
        )

        # print(f"block_mask_shape: {block_mask.shape}")

        return block_mask
    else:
        return torch.where(

@@ -632,6 +670,8 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
        cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
        cross_attn_nheads=args.cross_attn_nheads,
        eos_id=args.eos_id,
        init_device=args.init_device,
        init_dtype=args.init_dtype,
    )

    return LocalEncoder(local_encoder_args)

@@ -675,6 +715,8 @@ def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
        cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
        cross_attn_nheads=args.cross_attn_nheads,
        eos_id=args.eos_id,
        init_device=args.init_device,
        init_dtype=args.init_dtype,
    )

    return LocalDecoder(local_decoder_args)

@@ -710,6 +752,8 @@ def init_embeddings(
            nn.Embedding(
                encoder_hash_byte_group_vocab,
                emb_dim,
                device=args.init_device,
                dtype=args.init_dtype,
            )
        )

@@ -718,7 +762,14 @@ def init_embeddings(
        emb_dim = local_encoder_dim
        OFFSET = 4  # This should be passed as parameter if it's variable
        for ngram_vocab_size in encoder_ngram_to_size.values():
            embeddings.append(nn.Embedding(ngram_vocab_size + OFFSET, emb_dim))
            embeddings.append(
                nn.Embedding(
                    ngram_vocab_size + OFFSET,
                    emb_dim,
                    device=args.init_device,
                    dtype=args.init_dtype,
                )
            )

    return nn.ModuleList(embeddings)

@@ -792,7 +843,7 @@ class ByteLatentTransformer(
    """

    def __init__(self, args: ByteLatentTransformerArgs):
        super().__init__()
        super(ByteLatentTransformer, self).__init__()

        # General configuration
        self.weight_tying = args.weight_tying

@@ -854,7 +905,12 @@ class ByteLatentTransformer(
            ngram_emb_dim = self.local_encoder.dim
            for ngram_vocab_size in self.encoder_ngram_to_size.values():
                self.encoder_ngram_embedding.append(
                    nn.Embedding(ngram_vocab_size + OFFSET, ngram_emb_dim)
                    nn.Embedding(
                        ngram_vocab_size + OFFSET,
                        ngram_emb_dim,
                        device=args.init_device,
                        dtype=args.init_dtype,
                    )
                )

        # Output layer

@@ -873,6 +929,9 @@ class ByteLatentTransformer(
            )
        )

        # Sanity check
        check_param_device(self, args.init_device)

    def push_to_hub(self, *args, **kwargs):
        raise ValueError(
            "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct."

@@ -5,9 +5,6 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import BlockMask
from xformers.ops import AttentionBias

from bytelatent.base_transformer import (
    BaseTransformer,

@@ -16,6 +13,9 @@ from bytelatent.base_transformer import (
    repeat_kv,
)
from bytelatent.model.utils import create_causal_mask
from torch.nn import functional as F
from torch.nn.attention.flex_attention import BlockMask
from xformers.ops import AttentionBias

logger = logging.getLogger()
try:

@@ -40,6 +40,8 @@ class CrossAttention(nn.Module):
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
        device: str | torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()

@@ -50,29 +52,39 @@ class CrossAttention(nn.Module):
        self.n_kv_heads = n_kv_heads
        self.heads_per_group = self.n_heads // self.n_kv_heads

        self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps)
        self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)
        self.cross_attn_norm_q = nn.RMSNorm(
            dim, eps=norm_eps, device=device, dtype=dtype
        )
        self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps, device=device, dtype=dtype)

        self.wq = nn.Linear(
            dim,
            n_heads * head_dim,
            bias=False,
            device=device,
            dtype=dtype,
        )
        self.wk = nn.Linear(
            dim,
            n_kv_heads * head_dim,
            bias=False,
            device=device,
            dtype=dtype,
        )
        self.wv = nn.Linear(
            dim,
            n_kv_heads * head_dim,
            bias=False,
            device=device,
            dtype=dtype,
        )

        self.wo = nn.Linear(
            n_heads * head_dim,
            dim,
            bias=False,
            device=device,
            dtype=dtype,
        )

    def forward(

@@ -160,6 +172,8 @@ class GlobalTransformer(BaseTransformer):
            args.dim_token_emb,
            args.dim,
            bias=False,
            device=args.init_device,
            dtype=args.init_dtype,
        )

    def forward(

@@ -6,10 +6,6 @@ from typing import Any, List, Optional, Tuple, Union
import torch
import torch.nn
import torch.nn as nn
from pydantic import ConfigDict
from torch.nn import functional as F
from torch.nn.attention.flex_attention import BlockMask
from xformers.ops import AttentionBias

from bytelatent.base_transformer import (
    BaseTransformerArgs,

@@ -20,6 +16,10 @@ from bytelatent.base_transformer import (
from bytelatent.model.latent_transformer import CrossAttention
from bytelatent.model.utils import create_causal_mask, downsample
from bytelatent.tokenizers.blt_tokenizer import BOE_ID
from pydantic import ConfigDict
from torch.nn import functional as F
from torch.nn.attention.flex_attention import BlockMask
from xformers.ops import AttentionBias

logger = logging.getLogger()
try:

@@ -85,18 +85,31 @@ class LocalModelBase(nn.Module):
        )

        if not self.use_rope:
            self.pos_embeddings = nn.Embedding(args.max_length, args.dim)
            self.pos_embeddings = nn.Embedding(
                args.max_length,
                args.dim,
                device=args.init_device,
                dtype=args.init_dtype,
            )
        else:
            self.rope = RotaryEmbedding(
                theta=args.rope_theta,
                head_dim=args.head_dim or args.dim // args.n_heads,
                max_seqlen=args.max_seqlen,
                rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
                device=args.init_device,
                dtype=args.init_dtype,
            )
            self.pos_embeddings = None

        self.token_embedding_projection = (
            nn.Linear(args.dim_token_emb, args.dim, bias=False)
            nn.Linear(
                args.dim_token_emb,
                args.dim,
                bias=False,
                device=args.init_device,
                dtype=args.init_dtype,
            )
            if hasattr(args, "dim_token_emb") and args.dim_token_emb != self.dim
            else None
        )

@@ -125,6 +138,8 @@ class LocalModelBase(nn.Module):
            in_features=args.dim_patch_emb,
            out_features=output_dim,
            bias=False,
            device=args.init_device,
            dtype=args.init_dtype,
        )

    def apply_embedding(self, tokens, embeds):

@@ -218,7 +233,9 @@ class LocalEncoder(LocalModelBase):
        self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
        self.cross_attn_nheads = args.cross_attn_nheads

        self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)
        self.tok_embeddings = nn.Embedding(
            self.vocab_size, args.dim, device=args.init_device, dtype=args.init_dtype
        )

        if self.cross_attn_encoder:
            self.cross_attn_layers = torch.nn.ModuleList()

@@ -231,6 +248,8 @@ class LocalEncoder(LocalModelBase):
                    n_heads=self.cross_attn_nheads,
                    n_kv_heads=self.cross_attn_nheads,
                    norm_eps=args.norm_eps,
                    device=args.init_device,
                    dtype=args.init_dtype,
                )
            )

@@ -321,7 +340,9 @@ class LocalDecoder(LocalModelBase):
        self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
        self.cross_attn_nheads = args.cross_attn_nheads

        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.norm = RMSNorm(
            args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
        )

        if self.cross_attn_decoder:
            self.cross_attn_layers = torch.nn.ModuleList()

@@ -334,6 +355,8 @@ class LocalDecoder(LocalModelBase):
                    n_heads=self.cross_attn_nheads,
                    n_kv_heads=self.cross_attn_nheads,
                    norm_eps=args.norm_eps,
                    device=args.init_device,
                    dtype=args.init_dtype,
                )
            )

@@ -341,6 +364,8 @@ class LocalDecoder(LocalModelBase):
            self.dim,
            args.vocab_size,
            bias=False,
            device=args.init_device,
            dtype=args.init_dtype,
        )

    def forward(

@@ -175,3 +175,10 @@ def create_causal_mask(
        raise NotImplementedError(
            f"Attention {attn_impl} with {sliding_window} sliding window not implemented"
        )


def check_param_device(model, device_type: str = "cpu"):
    for name, param in model.named_parameters():
        assert (
            param.device.type == device_type
        ), f"Parameter {name} is on {param.device.type}, not on {device_type}"

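check_param_device is the sanity check that the BLT and LM transformer constructors in this commit call at the end of __init__: it walks named_parameters() and asserts that every parameter landed on the requested device type. A small standalone illustration with a plain nn.Module (not a BLT model):

from torch import nn

from bytelatent.model.utils import check_param_device

tiny = nn.Linear(4, 4)            # built on CPU in the default dtype

check_param_device(tiny, "cpu")   # passes: all parameters are CPU tensors

try:
    check_param_device(tiny, "cuda")
except AssertionError as err:
    # e.g. "Parameter weight is on cpu, not on cuda"
    print(err)
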
@@ -4,25 +4,25 @@ import logging
from typing import Optional, Tuple, Union

import torch
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
from xformers.ops import AttentionBias

from bytelatent.base_transformer import (
    BaseTransformer,
    BaseTransformerArgs,
    cross_entropy,
)
from bytelatent.model.utils import create_causal_mask
from bytelatent.model.utils import check_param_device, create_causal_mask
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
)
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
from xformers.ops import AttentionBias

logger = logging.getLogger()

@@ -84,19 +84,28 @@ class LMTransformer(

        assert args.vocab_size > 0

        self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)
        self.tok_embeddings = torch.nn.Embedding(
            args.vocab_size, args.dim, device=args.init_device, dtype=args.init_dtype
        )

        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.norm = RMSNorm(
            args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
        )

        self.output = nn.Linear(
            args.dim,
            args.vocab_size,
            bias=False,
            device=args.init_device,
            dtype=args.init_dtype,
        )

        if args.weight_tying:
            self.output.weight = self.embeddings.tok_embeddings.weight

        # Sanity check
        check_param_device(self, args.init_device)

    def push_to_hub(self, *args, **kwargs):
        raise ValueError(
            "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct."