Remove non-serializable type from model config

This commit is contained in:
Gustaf Ahdritz 2025-06-06 11:14:55 -07:00
parent 4c5e51e4de
commit 48256248b5
8 changed files with 57 additions and 35 deletions

View file

@ -7,6 +7,7 @@ from typing import Optional, Tuple, Union
import torch import torch
from bytelatent.model.utils import DTYPE_MAP
from bytelatent.tokenizers.constants import EOS_ID from bytelatent.tokenizers.constants import EOS_ID
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from torch import nn from torch import nn
@ -69,7 +70,7 @@ class BaseTransformerArgs(BaseModel):
eos_id: int | None = EOS_ID eos_id: int | None = EOS_ID
init_device: str = "cpu" init_device: str = "cpu"
init_dtype: torch.dtype = torch.float32 init_dtype: str = "float32"
def cross_entropy(pred, target, **kwargs): def cross_entropy(pred, target, **kwargs):
@ -564,7 +565,7 @@ class TransformerBlock(nn.Module):
n_kv_heads=self.n_kv_heads, n_kv_heads=self.n_kv_heads,
rope_theta=args.rope_theta, rope_theta=args.rope_theta,
device=args.init_device, device=args.init_device,
dtype=args.init_dtype, dtype=DTYPE_MAP[args.init_dtype],
) )
self.feed_forward = FeedForward( self.feed_forward = FeedForward(
dim=args.dim, dim=args.dim,
@ -572,14 +573,20 @@ class TransformerBlock(nn.Module):
multiple_of=args.multiple_of, multiple_of=args.multiple_of,
ffn_dim_multiplier=args.ffn_dim_multiplier, ffn_dim_multiplier=args.ffn_dim_multiplier,
device=args.init_device, device=args.init_device,
dtype=args.init_dtype, dtype=DTYPE_MAP[args.init_dtype],
) )
# Norms stay in full precision # Norms stay in full precision
self.attention_norm = RMSNorm( self.attention_norm = RMSNorm(
args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype args.dim,
eps=args.norm_eps,
device=args.init_device,
dtype=DTYPE_MAP[args.init_dtype],
) )
self.ffn_norm = RMSNorm( self.ffn_norm = RMSNorm(
args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype args.dim,
eps=args.norm_eps,
device=args.init_device,
dtype=DTYPE_MAP[args.init_dtype],
) )
def forward( def forward(
@ -631,7 +638,7 @@ class BaseTransformer(nn.Module, SequenceModelWithOutput):
max_seqlen=args.max_seqlen, max_seqlen=args.max_seqlen,
rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
device=args.init_device, device=args.init_device,
dtype=args.init_dtype, dtype=DTYPE_MAP[args.init_dtype],
) )
self.eos_id = args.eos_id self.eos_id = args.eos_id

View file

@ -11,7 +11,7 @@ logger = logging.getLogger()
def load_entropy_model( def load_entropy_model(
entropy_model_checkpoint_dir, state_dict_path, device="cpu", dtype=torch.bfloat16 entropy_model_checkpoint_dir, state_dict_path, device="cpu", dtype="bf16"
): ):
with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr: with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
reloaded = json.loads(fr.read()) reloaded = json.loads(fr.read())

View file

@ -403,27 +403,23 @@ def load_consolidated_model_and_tokenizer(consolidated_path, init_distributed=Fa
train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path)) train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
train_args.distributed.model_dtype
]
if train_args.train_entropy_model: if train_args.train_entropy_model:
model_args = train_args.entropy_model model_args = train_args.entropy_model
model_args.init_device = "cuda" model_args.init_device = "cuda"
model_args.init_dtype = param_dtype model_args.init_dtype = train_args.distributed.model_dtype
model = LMTransformer(model_args) model = LMTransformer(model_args)
else: else:
model_args = train_args.model model_args = train_args.model
model_args.init_device = "cuda" model_args.init_device = "cuda"
model_args.init_dtype = param_dtype model_args.init_dtype = train_args.distributed.model_dtype
model = ByteLatentTransformer(args=model_args) model = ByteLatentTransformer(args=model_args)
model = model.eval() model = model.eval()
tokenizer = train_args.data.tokenizer_args.build() tokenizer = train_args.data.tokenizer_args.build()
with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f: with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as fp:
st_dict = torch.load(f, weights_only=True) st_dict = torch.load(fp, weights_only=True)
model.load_state_dict(st_dict["model"]) model.load_state_dict(st_dict["model"])

View file

@ -13,7 +13,7 @@ from bytelatent.base_transformer import (
from bytelatent.data.patcher import Patcher, PatcherArgs from bytelatent.data.patcher import Patcher, PatcherArgs
from bytelatent.model.latent_transformer import GlobalTransformer from bytelatent.model.latent_transformer import GlobalTransformer
from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs
from bytelatent.model.utils import check_param_device, downsample from bytelatent.model.utils import check_param_device, downsample, DTYPE_MAP
from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
from huggingface_hub import PyTorchModelHubMixin from huggingface_hub import PyTorchModelHubMixin
@ -753,7 +753,7 @@ def init_embeddings(
encoder_hash_byte_group_vocab, encoder_hash_byte_group_vocab,
emb_dim, emb_dim,
device=args.init_device, device=args.init_device,
dtype=args.init_dtype, dtype=DTYPE_MAP[args.init_dtype],
) )
) )
@ -767,7 +767,7 @@ def init_embeddings(
ngram_vocab_size + OFFSET, ngram_vocab_size + OFFSET,
emb_dim, emb_dim,
device=args.init_device, device=args.init_device,
dtype=args.init_dtype, dtype=DTYPE_MAP[args.init_dtype],
) )
) )
@ -909,7 +909,7 @@ class ByteLatentTransformer(
ngram_vocab_size + OFFSET, ngram_vocab_size + OFFSET,
ngram_emb_dim, ngram_emb_dim,
device=args.init_device, device=args.init_device,
dtype=args.init_dtype, dtype=dtype_map[args.init_dtype],
) )
) )

View file

@ -12,7 +12,7 @@ from bytelatent.base_transformer import (
flex_attention_comp, flex_attention_comp,
repeat_kv, repeat_kv,
) )
from bytelatent.model.utils import create_causal_mask from bytelatent.model.utils import create_causal_mask, DTYPE_MAP
from torch.nn import functional as F from torch.nn import functional as F
from torch.nn.attention.flex_attention import BlockMask from torch.nn.attention.flex_attention import BlockMask
from xformers.ops import AttentionBias from xformers.ops import AttentionBias
@ -173,7 +173,7 @@ class GlobalTransformer(BaseTransformer):
args.dim, args.dim,
bias=False, bias=False,
device=args.init_device, device=args.init_device,
dtype=args.init_dtype, dtype=DTYPE_MAP[args.init_dtype],
) )
def forward( def forward(

View file

@ -14,7 +14,7 @@ from bytelatent.base_transformer import (
TransformerBlock, TransformerBlock,
) )
from bytelatent.model.latent_transformer import CrossAttention from bytelatent.model.latent_transformer import CrossAttention
from bytelatent.model.utils import create_causal_mask, downsample from bytelatent.model.utils import create_causal_mask, downsample, DTYPE_MAP
from bytelatent.tokenizers.blt_tokenizer import BOE_ID from bytelatent.tokenizers.blt_tokenizer import BOE_ID
from pydantic import ConfigDict from pydantic import ConfigDict
from torch.nn import functional as F from torch.nn import functional as F
@ -89,7 +89,7 @@ class LocalModelBase(nn.Module):
args.max_length, args.max_length,
args.dim, args.dim,
device=args.init_device, device=args.init_device,
dtype=args.init_dtype, dtype=DTYPE_MAP[args.init_dtype],
) )
else: else:
self.rope = RotaryEmbedding( self.rope = RotaryEmbedding(
@ -98,7 +98,7 @@ class LocalModelBase(nn.Module):
max_seqlen=args.max_seqlen, max_seqlen=args.max_seqlen,
rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
device=args.init_device, device=args.init_device,
dtype=args.init_dtype, dtype=DTYPE_MAP[args.init_dtype],
) )
self.pos_embeddings = None self.pos_embeddings = None
@ -108,7 +108,7 @@ class LocalModelBase(nn.Module):
args.dim, args.dim,
bias=False, bias=False,
device=args.init_device, device=args.init_device,
dtype=args.init_dtype, dtype=DTYPE_MAP[args.init_dtype],
) )
if hasattr(args, "dim_token_emb") and args.dim_token_emb != self.dim if hasattr(args, "dim_token_emb") and args.dim_token_emb != self.dim
else None else None
@ -139,7 +139,7 @@ class LocalModelBase(nn.Module):
out_features=output_dim, out_features=output_dim,
bias=False, bias=False,
device=args.init_device, device=args.init_device,
dtype=args.init_dtype, dtype=DTYPE_MAP[args.init_dtype],
) )
def apply_embedding(self, tokens, embeds): def apply_embedding(self, tokens, embeds):
@ -234,7 +234,10 @@ class LocalEncoder(LocalModelBase):
self.cross_attn_nheads = args.cross_attn_nheads self.cross_attn_nheads = args.cross_attn_nheads
self.tok_embeddings = nn.Embedding( self.tok_embeddings = nn.Embedding(
self.vocab_size, args.dim, device=args.init_device, dtype=args.init_dtype self.vocab_size,
args.dim,
device=args.init_device,
dtype=DTYPE_MAP[args.init_dtype],
) )
if self.cross_attn_encoder: if self.cross_attn_encoder:
@ -249,7 +252,7 @@ class LocalEncoder(LocalModelBase):
n_kv_heads=self.cross_attn_nheads, n_kv_heads=self.cross_attn_nheads,
norm_eps=args.norm_eps, norm_eps=args.norm_eps,
device=args.init_device, device=args.init_device,
dtype=args.init_dtype, dtype=DTYPE_MAP[args.init_dtype],
) )
) )
@ -341,7 +344,10 @@ class LocalDecoder(LocalModelBase):
self.cross_attn_nheads = args.cross_attn_nheads self.cross_attn_nheads = args.cross_attn_nheads
self.norm = RMSNorm( self.norm = RMSNorm(
args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype args.dim,
eps=args.norm_eps,
device=args.init_device,
dtype=DTYPE_MAP[args.init_dtype],
) )
if self.cross_attn_decoder: if self.cross_attn_decoder:
@ -356,7 +362,7 @@ class LocalDecoder(LocalModelBase):
n_kv_heads=self.cross_attn_nheads, n_kv_heads=self.cross_attn_nheads,
norm_eps=args.norm_eps, norm_eps=args.norm_eps,
device=args.init_device, device=args.init_device,
dtype=args.init_dtype, dtype=DTYPE_MAP[args.init_dtype],
) )
) )
@ -365,7 +371,7 @@ class LocalDecoder(LocalModelBase):
args.vocab_size, args.vocab_size,
bias=False, bias=False,
device=args.init_device, device=args.init_device,
dtype=args.init_dtype, dtype=DTYPE_MAP[args.init_dtype],
) )
def forward( def forward(

View file

@ -8,6 +8,13 @@ from xformers.ops import fmha
logger = logging.getLogger() logger = logging.getLogger()
DTYPE_MAP = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
"fp64": torch.float64,
}
def patch_reduce(h, max_num_patches, reduction, patch_ids): def patch_reduce(h, max_num_patches, reduction, patch_ids):
""" """

View file

@ -10,7 +10,7 @@ from bytelatent.base_transformer import (
BaseTransformerArgs, BaseTransformerArgs,
cross_entropy, cross_entropy,
) )
from bytelatent.model.utils import check_param_device, create_causal_mask from bytelatent.model.utils import check_param_device, create_causal_mask, DTYPE_MAP
from huggingface_hub import PyTorchModelHubMixin from huggingface_hub import PyTorchModelHubMixin
from torch import nn from torch import nn
from torch.distributed._tensor import Replicate, Shard from torch.distributed._tensor import Replicate, Shard
@ -85,11 +85,17 @@ class LMTransformer(
assert args.vocab_size > 0 assert args.vocab_size > 0
self.tok_embeddings = torch.nn.Embedding( self.tok_embeddings = torch.nn.Embedding(
args.vocab_size, args.dim, device=args.init_device, dtype=args.init_dtype args.vocab_size,
args.dim,
device=args.init_device,
dtype=DTYPE_MAP[args.init_dtype],
) )
self.norm = RMSNorm( self.norm = RMSNorm(
args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype args.dim,
eps=args.norm_eps,
device=args.init_device,
dtype=DTYPE_MAP[args.init_dtype],
) )
self.output = nn.Linear( self.output = nn.Linear(
@ -97,7 +103,7 @@ class LMTransformer(
args.vocab_size, args.vocab_size,
bias=False, bias=False,
device=args.init_device, device=args.init_device,
dtype=args.init_dtype, dtype=DTYPE_MAP[args.init_dtype],
) )
if args.weight_tying: if args.weight_tying: