diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py
index 0b78c92..19953e9 100644
--- a/bytelatent/base_transformer.py
+++ b/bytelatent/base_transformer.py
@@ -7,6 +7,7 @@
 from typing import Optional, Tuple, Union

 import torch
+from bytelatent.model.utils import DTYPE_MAP
 from bytelatent.tokenizers.constants import EOS_ID
 from pydantic import BaseModel, ConfigDict
 from torch import nn
@@ -69,7 +70,7 @@ class BaseTransformerArgs(BaseModel):

     eos_id: int | None = EOS_ID
     init_device: str = "cpu"
-    init_dtype: torch.dtype = torch.float32
+    init_dtype: str = "fp32"


 def cross_entropy(pred, target, **kwargs):
@@ -564,7 +565,7 @@ class TransformerBlock(nn.Module):
             n_kv_heads=self.n_kv_heads,
             rope_theta=args.rope_theta,
             device=args.init_device,
-            dtype=args.init_dtype,
+            dtype=DTYPE_MAP[args.init_dtype],
         )
         self.feed_forward = FeedForward(
             dim=args.dim,
@@ -572,14 +573,20 @@ class TransformerBlock(nn.Module):
             multiple_of=args.multiple_of,
             ffn_dim_multiplier=args.ffn_dim_multiplier,
             device=args.init_device,
-            dtype=args.init_dtype,
+            dtype=DTYPE_MAP[args.init_dtype],
         )
         # Norms stay in full precision
         self.attention_norm = RMSNorm(
-            args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
+            args.dim,
+            eps=args.norm_eps,
+            device=args.init_device,
+            dtype=DTYPE_MAP[args.init_dtype],
         )
         self.ffn_norm = RMSNorm(
-            args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
+            args.dim,
+            eps=args.norm_eps,
+            device=args.init_device,
+            dtype=DTYPE_MAP[args.init_dtype],
         )

     def forward(
@@ -631,7 +638,7 @@ class BaseTransformer(nn.Module, SequenceModelWithOutput):
             max_seqlen=args.max_seqlen,
             rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
             device=args.init_device,
-            dtype=args.init_dtype,
+            dtype=DTYPE_MAP[args.init_dtype],
         )

         self.eos_id = args.eos_id
diff --git a/bytelatent/entropy_model.py b/bytelatent/entropy_model.py
index e90e4f0..36726c9 100644
--- a/bytelatent/entropy_model.py
+++ b/bytelatent/entropy_model.py
@@ -11,7 +11,7 @@ logger = logging.getLogger()


 def load_entropy_model(
-    entropy_model_checkpoint_dir, state_dict_path, device="cpu", dtype=torch.bfloat16
+    entropy_model_checkpoint_dir, state_dict_path, device="cpu", dtype="bf16"
 ):
     with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
         reloaded = json.loads(fr.read())
diff --git a/bytelatent/generate.py b/bytelatent/generate.py
index db73ba3..71ddc52 100644
--- a/bytelatent/generate.py
+++ b/bytelatent/generate.py
@@ -403,27 +403,23 @@ def load_consolidated_model_and_tokenizer(consolidated_path, init_distributed=Fa
     train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))

-    param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
-        train_args.distributed.model_dtype
-    ]
-
     if train_args.train_entropy_model:
         model_args = train_args.entropy_model
         model_args.init_device = "cuda"
-        model_args.init_dtype = param_dtype
+        model_args.init_dtype = train_args.distributed.model_dtype
         model = LMTransformer(model_args)
     else:
         model_args = train_args.model
         model_args.init_device = "cuda"
-        model_args.init_dtype = param_dtype
+        model_args.init_dtype = train_args.distributed.model_dtype
         model = ByteLatentTransformer(args=model_args)
     model = model.eval()

     tokenizer = train_args.data.tokenizer_args.build()

-    with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
-        st_dict = torch.load(f, weights_only=True)
+    with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as fp:
+        st_dict = torch.load(fp, weights_only=True)

     model.load_state_dict(st_dict["model"])
diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py
index f649433..199c88b 100644
--- a/bytelatent/model/blt.py
+++ b/bytelatent/model/blt.py
@@ -13,7 +13,7 @@ from bytelatent.base_transformer import (
 from bytelatent.data.patcher import Patcher, PatcherArgs
 from bytelatent.model.latent_transformer import GlobalTransformer
 from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs
-from bytelatent.model.utils import check_param_device, downsample
+from bytelatent.model.utils import check_param_device, downsample, DTYPE_MAP
 from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
 from huggingface_hub import PyTorchModelHubMixin

@@ -753,7 +753,7 @@ def init_embeddings(
                     encoder_hash_byte_group_vocab,
                     emb_dim,
                     device=args.init_device,
-                    dtype=args.init_dtype,
+                    dtype=DTYPE_MAP[args.init_dtype],
                 )
             )

@@ -767,7 +767,7 @@ def init_embeddings(
                     ngram_vocab_size + OFFSET,
                     emb_dim,
                     device=args.init_device,
-                    dtype=args.init_dtype,
+                    dtype=DTYPE_MAP[args.init_dtype],
                 )
             )

@@ -909,7 +909,7 @@ class ByteLatentTransformer(
                         ngram_vocab_size + OFFSET,
                         ngram_emb_dim,
                         device=args.init_device,
-                        dtype=args.init_dtype,
+                        dtype=DTYPE_MAP[args.init_dtype],
                     )
                 )

diff --git a/bytelatent/model/latent_transformer.py b/bytelatent/model/latent_transformer.py
index 5a975cd..55507c1 100644
--- a/bytelatent/model/latent_transformer.py
+++ b/bytelatent/model/latent_transformer.py
@@ -12,7 +12,7 @@ from bytelatent.base_transformer import (
     flex_attention_comp,
     repeat_kv,
 )
-from bytelatent.model.utils import create_causal_mask
+from bytelatent.model.utils import create_causal_mask, DTYPE_MAP
 from torch.nn import functional as F
 from torch.nn.attention.flex_attention import BlockMask
 from xformers.ops import AttentionBias
@@ -173,7 +173,7 @@ class GlobalTransformer(BaseTransformer):
             args.dim,
             bias=False,
             device=args.init_device,
-            dtype=args.init_dtype,
+            dtype=DTYPE_MAP[args.init_dtype],
         )

     def forward(
diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py
index a9d2272..5b81bdc 100644
--- a/bytelatent/model/local_models.py
+++ b/bytelatent/model/local_models.py
@@ -14,7 +14,7 @@ from bytelatent.base_transformer import (
     TransformerBlock,
 )
 from bytelatent.model.latent_transformer import CrossAttention
-from bytelatent.model.utils import create_causal_mask, downsample
+from bytelatent.model.utils import create_causal_mask, downsample, DTYPE_MAP
 from bytelatent.tokenizers.blt_tokenizer import BOE_ID
 from pydantic import ConfigDict
 from torch.nn import functional as F
@@ -89,7 +89,7 @@ class LocalModelBase(nn.Module):
                 args.max_length,
                 args.dim,
                 device=args.init_device,
-                dtype=args.init_dtype,
+                dtype=DTYPE_MAP[args.init_dtype],
             )
         else:
             self.rope = RotaryEmbedding(
@@ -98,7 +98,7 @@ class LocalModelBase(nn.Module):
                 max_seqlen=args.max_seqlen,
                 rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
                 device=args.init_device,
-                dtype=args.init_dtype,
+                dtype=DTYPE_MAP[args.init_dtype],
             )
             self.pos_embeddings = None

@@ -108,7 +108,7 @@ class LocalModelBase(nn.Module):
                 args.dim,
                 bias=False,
                 device=args.init_device,
-                dtype=args.init_dtype,
+                dtype=DTYPE_MAP[args.init_dtype],
             )
             if hasattr(args, "dim_token_emb") and args.dim_token_emb != self.dim
             else None
@@ -139,7 +139,7 @@ class LocalModelBase(nn.Module):
             out_features=output_dim,
             bias=False,
             device=args.init_device,
-            dtype=args.init_dtype,
+            dtype=DTYPE_MAP[args.init_dtype],
         )

     def apply_embedding(self, tokens, embeds):
@@ -234,7 +234,10 @@ class LocalEncoder(LocalModelBase):
         self.cross_attn_nheads = args.cross_attn_nheads

         self.tok_embeddings = nn.Embedding(
-            self.vocab_size, args.dim, device=args.init_device, dtype=args.init_dtype
+            self.vocab_size,
+            args.dim,
+            device=args.init_device,
+            dtype=DTYPE_MAP[args.init_dtype],
         )

         if self.cross_attn_encoder:
@@ -249,7 +252,7 @@ class LocalEncoder(LocalModelBase):
                         n_kv_heads=self.cross_attn_nheads,
                         norm_eps=args.norm_eps,
                         device=args.init_device,
-                        dtype=args.init_dtype,
+                        dtype=DTYPE_MAP[args.init_dtype],
                     )
                 )

@@ -341,7 +344,10 @@ class LocalDecoder(LocalModelBase):
         self.cross_attn_nheads = args.cross_attn_nheads

         self.norm = RMSNorm(
-            args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
+            args.dim,
+            eps=args.norm_eps,
+            device=args.init_device,
+            dtype=DTYPE_MAP[args.init_dtype],
         )

         if self.cross_attn_decoder:
@@ -356,7 +362,7 @@ class LocalDecoder(LocalModelBase):
                         n_kv_heads=self.cross_attn_nheads,
                         norm_eps=args.norm_eps,
                         device=args.init_device,
-                        dtype=args.init_dtype,
+                        dtype=DTYPE_MAP[args.init_dtype],
                     )
                 )

@@ -365,7 +371,7 @@ class LocalDecoder(LocalModelBase):
             args.vocab_size,
             bias=False,
             device=args.init_device,
-            dtype=args.init_dtype,
+            dtype=DTYPE_MAP[args.init_dtype],
         )

     def forward(
diff --git a/bytelatent/model/utils.py b/bytelatent/model/utils.py
index 0d4bcfe..ebb4edf 100644
--- a/bytelatent/model/utils.py
+++ b/bytelatent/model/utils.py
@@ -8,6 +8,13 @@ from xformers.ops import fmha

 logger = logging.getLogger()

+DTYPE_MAP = {
+    "bf16": torch.bfloat16,
+    "fp16": torch.float16,
+    "fp32": torch.float32,
+    "fp64": torch.float64,
+}
+

 def patch_reduce(h, max_num_patches, reduction, patch_ids):
     """
diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py
index 0e61da1..86ec17c 100644
--- a/bytelatent/transformer.py
+++ b/bytelatent/transformer.py
@@ -10,7 +10,7 @@ from bytelatent.base_transformer import (
     BaseTransformerArgs,
     cross_entropy,
 )
-from bytelatent.model.utils import check_param_device, create_causal_mask
+from bytelatent.model.utils import check_param_device, create_causal_mask, DTYPE_MAP
 from huggingface_hub import PyTorchModelHubMixin
 from torch import nn
 from torch.distributed._tensor import Replicate, Shard
@@ -85,11 +85,17 @@ class LMTransformer(
         assert args.vocab_size > 0

         self.tok_embeddings = torch.nn.Embedding(
-            args.vocab_size, args.dim, device=args.init_device, dtype=args.init_dtype
+            args.vocab_size,
+            args.dim,
+            device=args.init_device,
+            dtype=DTYPE_MAP[args.init_dtype],
         )

         self.norm = RMSNorm(
-            args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
+            args.dim,
+            eps=args.norm_eps,
+            device=args.init_device,
+            dtype=DTYPE_MAP[args.init_dtype],
         )

         self.output = nn.Linear(
@@ -97,7 +103,7 @@ class LMTransformer(
             args.vocab_size,
             bias=False,
             device=args.init_device,
-            dtype=args.init_dtype,
+            dtype=DTYPE_MAP[args.init_dtype],
         )

         if args.weight_tying:
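
Usage note: a minimal sketch, assuming only what the patch above shows, of how the new string-valued init_dtype is resolved through DTYPE_MAP when a layer is constructed. The 512-dim Linear layer and the local variable names here are illustrative stand-ins, not code from the repository.

    import torch

    # Same mapping the patch adds to bytelatent/model/utils.py.
    DTYPE_MAP = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
        "fp64": torch.float64,
    }

    # Configs now carry the dtype as a plain string (e.g. copied from
    # train_args.distributed.model_dtype), which stays JSON/pydantic
    # serializable; modules convert it to a torch.dtype only at build time.
    init_dtype = "bf16"
    proj = torch.nn.Linear(512, 512, bias=False, dtype=DTYPE_MAP[init_dtype])
    assert proj.weight.dtype == torch.bfloat16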