diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py
index d947040..0b78c92 100644
--- a/bytelatent/base_transformer.py
+++ b/bytelatent/base_transformer.py
@@ -6,18 +6,18 @@ from enum import Enum
 from typing import Optional, Tuple, Union
 
 import torch
+
+from bytelatent.tokenizers.constants import EOS_ID
 from pydantic import BaseModel, ConfigDict
 from torch import nn
 from torch.nn import functional as F
 from torch.nn.attention.flex_attention import (
-    BlockMask,
     _mask_mod_signature,
+    BlockMask,
     flex_attention,
 )
 from xformers.ops import AttentionBias, fmha
 
-from bytelatent.tokenizers.constants import EOS_ID
-
 logger = logging.getLogger()
 
 try:
@@ -42,7 +42,7 @@ class InitStdFactor(str, Enum):
 
 
 class BaseTransformerArgs(BaseModel):
-    model_config = ConfigDict(extra="forbid")
+    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
     dim: int = 512
     n_layers: int = 8
     head_dim: int | None = None
@@ -68,6 +68,9 @@ class BaseTransformerArgs(BaseModel):
     # Special token config
     eos_id: int | None = EOS_ID
 
+    init_device: str = "cpu"
+    init_dtype: torch.dtype = torch.float32
+
 
 def cross_entropy(pred, target, **kwargs):
     return F.nll_loss(
@@ -95,6 +98,7 @@ def precompute_freqs_cis(
     end: int,
     theta: float = 10000.0,
     rope_use_fp32_in_outer_product: bool = False,
+    device: str | torch.device = torch.device("cpu"),
 ):
     """
     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@@ -111,7 +115,9 @@ def precompute_freqs_cis(
     Returns:
         torch.Tensor: Precomputed frequency tensor with complex exponentials.
     """
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    freqs = 1.0 / (
+        theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim)
+    )
     t = torch.arange(end, device=freqs.device)
     if rope_use_fp32_in_outer_product:
         t = t.to(torch.float32)
@@ -258,6 +264,8 @@ class RotaryEmbedding(torch.nn.Module):
         head_dim: int,
         max_seqlen: int = 1024,
         rope_use_fp32_in_outer_product: bool = False,
+        device: str | torch.device = torch.device("cpu"),
+        dtype: torch.dtype = torch.float32,
     ):
         super().__init__()
 
@@ -273,7 +281,8 @@ class RotaryEmbedding(torch.nn.Module):
                 end=max_seqlen,
                 theta=theta,
                 rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
-            ),
+                device=device,
+            ).to(dtype=dtype),
             persistent=False,
         )
 
@@ -325,6 +334,8 @@ class Attention(nn.Module):
         n_heads: int,
         n_kv_heads: int,
         rope_theta: float,
+        device: str | torch.device = torch.device("cpu"),
+        dtype: torch.dtype = torch.float32,
     ):
         super().__init__()
 
@@ -340,22 +351,30 @@ class Attention(nn.Module):
             dim,
             n_heads * head_dim,
             bias=False,
+            device=device,
+            dtype=dtype,
         )
         self.wk = nn.Linear(
             dim,
             n_kv_heads * head_dim,
             bias=False,
+            device=device,
+            dtype=dtype,
         )
         self.wv = nn.Linear(
             dim,
             n_kv_heads * head_dim,
             bias=False,
+            device=device,
+            dtype=dtype,
         )
         self.wo = nn.Linear(
             n_heads * head_dim,
             dim,
             bias=False,
+            device=device,
+            dtype=dtype,
         )
 
     def forward(
@@ -368,6 +387,7 @@ class Attention(nn.Module):
     ) -> torch.Tensor:
         # B S D
         bsz, seq_len, dim = x.shape
+
         xq = self.wq(x.view_as(x))
         xk = self.wk(x.view_as(x))
         xv = self.wv(x.view_as(x))
@@ -453,6 +473,8 @@ class FeedForward(nn.Module):
         multiple_of: int,
         ffn_dim_multiplier: Optional[float],
         mp_size: int = 1,
+        device: str | torch.device = torch.device("cpu"),
+        dtype: torch.dtype = torch.float32,
     ):
         super().__init__()
 
@@ -469,16 +491,22 @@ class FeedForward(nn.Module):
             dim,
             hidden_dim,
             bias=False,
+            device=device,
+            dtype=dtype,
         )
         self.w3 = nn.Linear(
             dim,
             hidden_dim,
             bias=False,
+            device=device,
+            dtype=dtype,
         )
         self.w2 = nn.Linear(
             hidden_dim,
             dim,
             bias=False,
+            device=device,
+            dtype=dtype,
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -535,15 +563,24 @@ class TransformerBlock(nn.Module):
             n_heads=self.n_heads,
             n_kv_heads=self.n_kv_heads,
             rope_theta=args.rope_theta,
+            device=args.init_device,
+            dtype=args.init_dtype,
         )
         self.feed_forward = FeedForward(
             dim=args.dim,
             hidden_dim=4 * args.dim,
             multiple_of=args.multiple_of,
             ffn_dim_multiplier=args.ffn_dim_multiplier,
+            device=args.init_device,
+            dtype=args.init_dtype,
+        )
+        # Norms are created on the same init device and dtype as the rest of the block
+        self.attention_norm = RMSNorm(
+            args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
+        )
+        self.ffn_norm = RMSNorm(
+            args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
         )
-        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
-        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
     def forward(
         self,
@@ -593,6 +630,8 @@ class BaseTransformer(nn.Module, SequenceModelWithOutput):
             head_dim=args.head_dim or args.dim // args.n_heads,
             max_seqlen=args.max_seqlen,
             rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
+            device=args.init_device,
+            dtype=args.init_dtype,
         )
 
         self.eos_id = args.eos_id
diff --git a/bytelatent/entropy_model.py b/bytelatent/entropy_model.py
index 51973e2..e90e4f0 100644
--- a/bytelatent/entropy_model.py
+++ b/bytelatent/entropy_model.py
@@ -10,11 +10,13 @@ from bytelatent.transformer import LMTransformer, LMTransformerArgs
 logger = logging.getLogger()
 
 
-def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):
+def load_entropy_model(
+    entropy_model_checkpoint_dir, state_dict_path, device="cpu", dtype=torch.bfloat16
+):
     with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
         reloaded = json.loads(fr.read())
 
-    torch.set_default_dtype(torch.bfloat16)
+    # torch.set_default_dtype(dtype)
     model_params = reloaded["entropy_model"]
     logger.warning(
         "Update checkpoint to load attn and sliding window args from checkpoint"
@@ -29,6 +31,8 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):
         attn_bias_type="local_block_causal",
         attn_impl="xformers",
         sliding_window=512,
+        init_device=device,
+        init_dtype=dtype,
     )
 
     entropy_model = LMTransformer(entropy_model_args)
@@ -38,6 +42,7 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):
     entropy_model.to(device)
     entropy_model = entropy_model.eval()
     # no grads for the model:
-    for param in entropy_model.parameters():
+    for n, param in entropy_model.named_parameters():
         param.requires_grad = False
+
     return entropy_model, entropy_model_args
diff --git a/bytelatent/generate.py b/bytelatent/generate.py
index 97434dc..db73ba3 100644
--- a/bytelatent/generate.py
+++ b/bytelatent/generate.py
@@ -4,11 +4,6 @@ import os
 import time
 
 import torch
-from omegaconf import OmegaConf
-from torch import nn
-from torch.nn import functional as F
-from torch.nn.attention.flex_attention import create_block_mask
-from tqdm import tqdm
 
 from bytelatent.args import EvalArgs, PackedCausalTransformerGeneratorArgs, TrainArgs
 from bytelatent.base_transformer import (
@@ -19,9 +14,9 @@ from bytelatent.base_transformer import (
     lengths_to_start_ids,
 )
 from bytelatent.checkpoint import (
+    consolidate_checkpoints,
     CONSOLIDATE_FOLDER,
     CONSOLIDATE_NAME,
-    consolidate_checkpoints,
 )
 from bytelatent.config_parser import parse_args_to_pydantic_model
 from bytelatent.data.file_util import get_fs
@@ -33,6 +28,11 @@ from bytelatent.distributed import (
 from bytelatent.model.blt import ByteLatentTransformer
 from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
 from bytelatent.transformer import LMTransformer
+from omegaconf import OmegaConf
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.attention.flex_attention import create_block_mask
+from tqdm import tqdm
 
 
 def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
@@ -400,25 +400,33 @@ def load_consolidated_model_and_tokenizer(consolidated_path, init_distributed=False):
         setup_torch_distributed(distributed_args)
     train_args_path = os.path.join(consolidated_path, "params.json")
     fs = get_fs(train_args_path)
-    train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
-    if train_args.train_entropy_model:
-        model_args = train_args.entropy_model
-        model = LMTransformer(model_args)
-    else:
-        model_args = train_args.model
-        model = ByteLatentTransformer(model_args)
+    train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
     param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
         train_args.distributed.model_dtype
     ]
+
+    if train_args.train_entropy_model:
+        model_args = train_args.entropy_model
+        model_args.init_device = "cuda"
+        model_args.init_dtype = param_dtype
+        model = LMTransformer(model_args)
+    else:
+        model_args = train_args.model
+        model_args.init_device = "cuda"
+        model_args.init_dtype = param_dtype
+        model = ByteLatentTransformer(args=model_args)
+
+    model = model.eval()
+
     tokenizer = train_args.data.tokenizer_args.build()
+
     with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
         st_dict = torch.load(f, weights_only=True)
+
     model.load_state_dict(st_dict["model"])
-    model = model.cuda().eval()
-    for param in model.parameters():
-        param.data = param.data.to(dtype=param_dtype)
+
     return model, tokenizer, train_args
diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py
index 26934bb..f649433 100644
--- a/bytelatent/model/blt.py
+++ b/bytelatent/model/blt.py
@@ -1,14 +1,9 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-from enum import Enum, auto
+from enum import auto, Enum
 from typing import Any, Optional
 
 import torch
-from huggingface_hub import PyTorchModelHubMixin
-from pydantic import model_validator
-from torch import nn
-from torch.nn.attention.flex_attention import create_block_mask
-from typing_extensions import Self
 
 from bytelatent.base_transformer import (
     BaseTransformerArgs,
@@ -18,8 +13,15 @@ from bytelatent.base_transformer import (
 from bytelatent.data.patcher import Patcher, PatcherArgs
 from bytelatent.model.latent_transformer import GlobalTransformer
 from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs
-from bytelatent.model.utils import downsample
+from bytelatent.model.utils import check_param_device, downsample
 from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
+from huggingface_hub import PyTorchModelHubMixin
+
+from numpy.random import f
+from pydantic import model_validator
+from torch import nn
+from torch.nn.attention.flex_attention import create_block_mask
+from typing_extensions import Self
 
 
 def attention_flops_per_token(n_layers, seq_len, dim, causal):
@@ -155,6 +157,9 @@ primes = [
 
 
 def rolling_polynomial_hash(t, hash_func_nb: int = 0):
+    if hash_func_nb >= len(primes):
+        print(f"len(primes): {len(primes)}, hash_func_nb: {hash_func_nb}")
+
     prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device)
     prime_powers = torch.stack([prime**i for i in range(t.shape[-1])])
     return torch.sum(t * prime_powers, dim=-1)
@@ -239,6 +244,9 @@ def create_patch_mask_from_ids(
     return mask
 
 
+GLOBAL = set()
+
+
 def cross_attn_mask(
     patch_ids,
     patch_lengths,
@@ -265,9 +273,36 @@ def cross_attn_mask(
         kv_len,
     ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}"
     if block_mask:
+        # This appears to resolve occasional nondeterministic RuntimeErrors
+        # in the create_block_mask call; the root cause is still unclear.
+        cross_mask_copy = cross_mask.clone()
 
         def patch_mask(b, h, q_idx, kv_idx):
-            return cross_mask[b, q_idx, kv_idx]
+            return cross_mask_copy[b, q_idx, kv_idx]
+
+        # print(f"cross_mask: {cross_mask.shape}")
+        # print(f"bs: {bs}, q_len: {q_len}, kv_len: {kv_len}")
+        # print(cross_mask[0, 0, 0])
+        # for i in range(bs):
+        #     for j in range(q_len):
+        #         for k in range(kv_len):
+        #             y = cross_mask[i, j, k]
+
+        # import pickle
+
+        # with open("cross_mask.pkl", "wb") as f:
+        #     pickle.dump(cross_mask, f)
+
+        # global GLOBAL
+        # s = f"bs_{bs}_q_len_{q_len}_kv_len_{kv_len}"
+        # if s not in GLOBAL:
+        #     GLOBAL.add(s)
+        #     print(f"bs_{bs}_q_len_{q_len}_kv_len_{kv_len}")
+        # else:
+        #     print(f"bs_{bs}_q_len_{q_len}_kv_len_{kv_len} (skipped)")
+
+        # if q_len >= 51 and kv_len >= 96:
+        #     breakpoint()
 
         block_mask = create_block_mask(
             patch_mask,
@@ -277,6 +312,9 @@ def cross_attn_mask(
             B=bs,
             Q_LEN=q_len,
             KV_LEN=kv_len,
             _compile=True,
         )
+
+        # print(f"block_mask_shape: {block_mask.shape}")
+
         return block_mask
     else:
         return torch.where(
@@ -632,6 +670,8 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
         cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
         cross_attn_nheads=args.cross_attn_nheads,
         eos_id=args.eos_id,
+        init_device=args.init_device,
+        init_dtype=args.init_dtype,
     )
     return LocalEncoder(local_encoder_args)
@@ -675,6 +715,8 @@ def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
         cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
         cross_attn_nheads=args.cross_attn_nheads,
         eos_id=args.eos_id,
+        init_device=args.init_device,
+        init_dtype=args.init_dtype,
     )
     return LocalDecoder(local_decoder_args)
@@ -710,6 +752,8 @@ def init_embeddings(
             nn.Embedding(
                 encoder_hash_byte_group_vocab,
                 emb_dim,
+                device=args.init_device,
+                dtype=args.init_dtype,
             )
         )
@@ -718,7 +762,14 @@ def init_embeddings(
         emb_dim = local_encoder_dim
         OFFSET = 4  # This should be passed as parameter if it's variable
         for ngram_vocab_size in encoder_ngram_to_size.values():
-            embeddings.append(nn.Embedding(ngram_vocab_size + OFFSET, emb_dim))
+            embeddings.append(
+                nn.Embedding(
+                    ngram_vocab_size + OFFSET,
+                    emb_dim,
+                    device=args.init_device,
+                    dtype=args.init_dtype,
+                )
+            )
 
     return nn.ModuleList(embeddings)
@@ -792,7 +843,7 @@ class ByteLatentTransformer(
     """
 
     def __init__(self, args: ByteLatentTransformerArgs):
-        super().__init__()
+        super(ByteLatentTransformer, self).__init__()
 
         # General configuration
         self.weight_tying = args.weight_tying
@@ -854,7 +905,12 @@ class ByteLatentTransformer(
             ngram_emb_dim = self.local_encoder.dim
             for ngram_vocab_size in self.encoder_ngram_to_size.values():
                 self.encoder_ngram_embedding.append(
-                    nn.Embedding(ngram_vocab_size + OFFSET, ngram_emb_dim)
+                    nn.Embedding(
+                        ngram_vocab_size + OFFSET,
+                        ngram_emb_dim,
+                        device=args.init_device,
+                        dtype=args.init_dtype,
+                    )
                 )
 
         # Output layer
@@ -873,6 +929,9 @@ class ByteLatentTransformer(
             )
         )
 
+        # Sanity check
+        check_param_device(self, args.init_device)
+
     def push_to_hub(self, *args, **kwargs):
         raise ValueError(
             "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct."
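A standalone sketch of the masking pattern used in cross_attn_mask above: the boolean mask tensor is cloned and the clone is captured by the mask_mod closure that create_block_mask consumes. This is illustration only, not part of the patch; it assumes PyTorch's flex_attention API (torch >= 2.5), and the batch size and sequence lengths are made up.

import torch
from torch.nn.attention.flex_attention import create_block_mask

# Toy boolean cross-attention mask: batch of 2, 128 query positions, 128 kv positions.
cross_mask = torch.rand(2, 128, 128) > 0.5
# Clone captured by the closure, mirroring the defensive copy made in cross_attn_mask.
cross_mask_copy = cross_mask.clone()

def patch_mask(b, h, q_idx, kv_idx):
    # h is ignored because the mask is shared across heads (H=None below).
    return cross_mask_copy[b, q_idx, kv_idx]

block_mask = create_block_mask(
    patch_mask, B=2, H=None, Q_LEN=128, KV_LEN=128, device="cpu"
)
print(block_mask.shape)  # e.g. (2, 1, 128, 128)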
diff --git a/bytelatent/model/latent_transformer.py b/bytelatent/model/latent_transformer.py
index a6cabdc..5a975cd 100644
--- a/bytelatent/model/latent_transformer.py
+++ b/bytelatent/model/latent_transformer.py
@@ -5,9 +5,6 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn
 import torch.nn as nn
-from torch.nn import functional as F
-from torch.nn.attention.flex_attention import BlockMask
-from xformers.ops import AttentionBias
 
 from bytelatent.base_transformer import (
     BaseTransformer,
@@ -16,6 +13,9 @@ from bytelatent.base_transformer import (
     repeat_kv,
 )
 from bytelatent.model.utils import create_causal_mask
+from torch.nn import functional as F
+from torch.nn.attention.flex_attention import BlockMask
+from xformers.ops import AttentionBias
 
 logger = logging.getLogger()
 try:
@@ -40,6 +40,8 @@ class CrossAttention(nn.Module):
         n_heads: int,
         n_kv_heads: int,
         norm_eps: float,
+        device: str | torch.device = torch.device("cpu"),
+        dtype: torch.dtype = torch.float32,
     ):
         super().__init__()
 
@@ -50,29 +52,39 @@ class CrossAttention(nn.Module):
         self.n_kv_heads = n_kv_heads
         self.heads_per_group = self.n_heads // self.n_kv_heads
 
-        self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps)
-        self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)
+        self.cross_attn_norm_q = nn.RMSNorm(
+            dim, eps=norm_eps, device=device, dtype=dtype
+        )
+        self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps, device=device, dtype=dtype)
 
         self.wq = nn.Linear(
             dim,
             n_heads * head_dim,
             bias=False,
+            device=device,
+            dtype=dtype,
         )
         self.wk = nn.Linear(
             dim,
             n_kv_heads * head_dim,
             bias=False,
+            device=device,
+            dtype=dtype,
         )
         self.wv = nn.Linear(
             dim,
             n_kv_heads * head_dim,
             bias=False,
+            device=device,
+            dtype=dtype,
        )
         self.wo = nn.Linear(
             n_heads * head_dim,
             dim,
             bias=False,
+            device=device,
+            dtype=dtype,
         )
 
     def forward(
@@ -160,6 +172,8 @@ class GlobalTransformer(BaseTransformer):
                 args.dim_token_emb,
                 args.dim,
                 bias=False,
+                device=args.init_device,
+                dtype=args.init_dtype,
             )
 
     def forward(
diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py
index 7083ac4..a9d2272 100644
--- a/bytelatent/model/local_models.py
+++ b/bytelatent/model/local_models.py
@@ -6,10 +6,6 @@ from typing import Any, List, Optional, Tuple, Union
 import torch
 import torch.nn
 import torch.nn as nn
-from pydantic import ConfigDict
-from torch.nn import functional as F
-from torch.nn.attention.flex_attention import BlockMask
-from xformers.ops import AttentionBias
 
 from bytelatent.base_transformer import (
     BaseTransformerArgs,
@@ -20,6 +16,10 @@ from bytelatent.base_transformer import (
 from bytelatent.model.latent_transformer import CrossAttention
 from bytelatent.model.utils import create_causal_mask, downsample
 from bytelatent.tokenizers.blt_tokenizer import BOE_ID
+from pydantic import ConfigDict
+from torch.nn import functional as F
+from torch.nn.attention.flex_attention import BlockMask
+from xformers.ops import AttentionBias
 
 logger = logging.getLogger()
 try:
@@ -85,18 +85,31 @@ class LocalModelBase(nn.Module):
         )
 
         if not self.use_rope:
-            self.pos_embeddings = nn.Embedding(args.max_length, args.dim)
+            self.pos_embeddings = nn.Embedding(
+                args.max_length,
+                args.dim,
+                device=args.init_device,
+                dtype=args.init_dtype,
+            )
         else:
             self.rope = RotaryEmbedding(
                 theta=args.rope_theta,
                 head_dim=args.head_dim or args.dim // args.n_heads,
                 max_seqlen=args.max_seqlen,
                 rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
+                device=args.init_device,
+                dtype=args.init_dtype,
             )
             self.pos_embeddings = None
 
         self.token_embedding_projection = (
-            nn.Linear(args.dim_token_emb, args.dim, bias=False)
+            nn.Linear(
+                args.dim_token_emb,
+                args.dim,
+                bias=False,
+                device=args.init_device,
+                dtype=args.init_dtype,
+            )
             if hasattr(args, "dim_token_emb") and args.dim_token_emb != self.dim
             else None
         )
@@ -125,6 +138,8 @@ class LocalModelBase(nn.Module):
                 in_features=args.dim_patch_emb,
                 out_features=output_dim,
                 bias=False,
+                device=args.init_device,
+                dtype=args.init_dtype,
             )
 
     def apply_embedding(self, tokens, embeds):
@@ -218,7 +233,9 @@ class LocalEncoder(LocalModelBase):
         self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
         self.cross_attn_nheads = args.cross_attn_nheads
 
-        self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)
+        self.tok_embeddings = nn.Embedding(
+            self.vocab_size, args.dim, device=args.init_device, dtype=args.init_dtype
+        )
 
         if self.cross_attn_encoder:
             self.cross_attn_layers = torch.nn.ModuleList()
@@ -231,6 +248,8 @@ class LocalEncoder(LocalModelBase):
                     n_heads=self.cross_attn_nheads,
                     n_kv_heads=self.cross_attn_nheads,
                     norm_eps=args.norm_eps,
+                    device=args.init_device,
+                    dtype=args.init_dtype,
                 )
             )
@@ -321,7 +340,9 @@ class LocalDecoder(LocalModelBase):
         self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
         self.cross_attn_nheads = args.cross_attn_nheads
 
-        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.norm = RMSNorm(
+            args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
+        )
 
         if self.cross_attn_decoder:
             self.cross_attn_layers = torch.nn.ModuleList()
@@ -334,6 +355,8 @@ class LocalDecoder(LocalModelBase):
                    n_heads=self.cross_attn_nheads,
                    n_kv_heads=self.cross_attn_nheads,
                    norm_eps=args.norm_eps,
+                    device=args.init_device,
+                    dtype=args.init_dtype,
                 )
             )
@@ -341,6 +364,8 @@ class LocalDecoder(LocalModelBase):
         self.output = nn.Linear(
             self.dim,
             args.vocab_size,
             bias=False,
+            device=args.init_device,
+            dtype=args.init_dtype,
         )
 
     def forward(
diff --git a/bytelatent/model/utils.py b/bytelatent/model/utils.py
index e01672e..0d4bcfe 100644
--- a/bytelatent/model/utils.py
+++ b/bytelatent/model/utils.py
@@ -175,3 +175,10 @@ def create_causal_mask(
         raise NotImplementedError(
             f"Attention {attn_impl} with {sliding_window} sliding window not implemented"
         )
+
+
+def check_param_device(model, device_type: str = "cpu"):
+    for name, param in model.named_parameters():
+        assert (
+            param.device.type == device_type
+        ), f"Parameter {name} is on {param.device.type}, not on {device_type}"
diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py
index 32d63be..0e61da1 100644
--- a/bytelatent/transformer.py
+++ b/bytelatent/transformer.py
@@ -4,25 +4,25 @@ import logging
 from typing import Optional, Tuple, Union
 
 import torch
-from huggingface_hub import PyTorchModelHubMixin
-from torch import nn
-from torch.distributed._tensor import Replicate, Shard
-from torch.distributed.tensor.parallel import (
-    ColwiseParallel,
-    PrepareModuleInput,
-    RowwiseParallel,
-    SequenceParallel,
-    parallelize_module,
-)
-from torch.nn.attention.flex_attention import BlockMask, create_block_mask
-from xformers.ops import AttentionBias
 
 from bytelatent.base_transformer import (
     BaseTransformer,
     BaseTransformerArgs,
     cross_entropy,
 )
-from bytelatent.model.utils import create_causal_mask
+from bytelatent.model.utils import check_param_device, create_causal_mask
+from huggingface_hub import PyTorchModelHubMixin
+from torch import nn
+from torch.distributed._tensor import Replicate, Shard
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    PrepareModuleInput,
+    RowwiseParallel,
+    SequenceParallel,
+)
+from torch.nn.attention.flex_attention import BlockMask, create_block_mask
+from xformers.ops import AttentionBias
 
 logger = logging.getLogger()
 
@@ -84,19 +84,28 @@ class LMTransformer(
 
         assert args.vocab_size > 0
 
-        self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)
+        self.tok_embeddings = torch.nn.Embedding(
+            args.vocab_size, args.dim, device=args.init_device, dtype=args.init_dtype
+        )
 
-        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.norm = RMSNorm(
+            args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
+        )
 
         self.output = nn.Linear(
             args.dim,
             args.vocab_size,
             bias=False,
+            device=args.init_device,
+            dtype=args.init_dtype,
         )
 
         if args.weight_tying:
             self.output.weight = self.embeddings.tok_embeddings.weight
 
+        # Sanity check
+        check_param_device(self, args.init_device)
+
    def push_to_hub(self, *args, **kwargs):
         raise ValueError(
             "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct."
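For reference, a minimal usage sketch of the check_param_device helper added in bytelatent/model/utils.py and called at the end of ByteLatentTransformer.__init__ and LMTransformer.__init__. Together with the new init_device/init_dtype args, it replaces the old pattern of building the model on CPU and casting afterwards. The tiny nn.Linear probe below is hypothetical (any nn.Module works, since the helper only walks named_parameters()), and the snippet assumes the bytelatent package is importable.

import torch
from torch import nn

from bytelatent.model.utils import check_param_device

# Hypothetical stand-in module built directly on the target device/dtype.
probe = nn.Linear(4, 4, device="cpu", dtype=torch.bfloat16)

check_param_device(probe, "cpu")  # passes: weight and bias are both CPU tensors
try:
    check_param_device(probe, "cuda")  # fails and names the offending parameter
except AssertionError as err:
    print(err)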