Remove non-serializable type from model config

This commit is contained in:
Gustaf Ahdritz 2025-06-06 11:14:55 -07:00
parent 4c5e51e4de
commit 48256248b5
8 changed files with 57 additions and 35 deletions

View file

@ -7,6 +7,7 @@ from typing import Optional, Tuple, Union
import torch
from bytelatent.model.utils import DTYPE_MAP
from bytelatent.tokenizers.constants import EOS_ID
from pydantic import BaseModel, ConfigDict
from torch import nn
@ -69,7 +70,7 @@ class BaseTransformerArgs(BaseModel):
eos_id: int | None = EOS_ID
init_device: str = "cpu"
init_dtype: torch.dtype = torch.float32
init_dtype: str = "float32"
def cross_entropy(pred, target, **kwargs):
@ -564,7 +565,7 @@ class TransformerBlock(nn.Module):
n_kv_heads=self.n_kv_heads,
rope_theta=args.rope_theta,
device=args.init_device,
dtype=args.init_dtype,
dtype=DTYPE_MAP[args.init_dtype],
)
self.feed_forward = FeedForward(
dim=args.dim,
@ -572,14 +573,20 @@ class TransformerBlock(nn.Module):
multiple_of=args.multiple_of,
ffn_dim_multiplier=args.ffn_dim_multiplier,
device=args.init_device,
dtype=args.init_dtype,
dtype=DTYPE_MAP[args.init_dtype],
)
# Norms stay in full precision
self.attention_norm = RMSNorm(
args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
args.dim,
eps=args.norm_eps,
device=args.init_device,
dtype=DTYPE_MAP[args.init_dtype],
)
self.ffn_norm = RMSNorm(
args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
args.dim,
eps=args.norm_eps,
device=args.init_device,
dtype=DTYPE_MAP[args.init_dtype],
)
def forward(
@ -631,7 +638,7 @@ class BaseTransformer(nn.Module, SequenceModelWithOutput):
max_seqlen=args.max_seqlen,
rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
device=args.init_device,
dtype=args.init_dtype,
dtype=DTYPE_MAP[args.init_dtype],
)
self.eos_id = args.eos_id

View file

@ -11,7 +11,7 @@ logger = logging.getLogger()
def load_entropy_model(
entropy_model_checkpoint_dir, state_dict_path, device="cpu", dtype=torch.bfloat16
entropy_model_checkpoint_dir, state_dict_path, device="cpu", dtype="bf16"
):
with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
reloaded = json.loads(fr.read())

View file

@ -403,27 +403,23 @@ def load_consolidated_model_and_tokenizer(consolidated_path, init_distributed=Fa
train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
train_args.distributed.model_dtype
]
if train_args.train_entropy_model:
model_args = train_args.entropy_model
model_args.init_device = "cuda"
model_args.init_dtype = param_dtype
model_args.init_dtype = train_args.distributed.model_dtype
model = LMTransformer(model_args)
else:
model_args = train_args.model
model_args.init_device = "cuda"
model_args.init_dtype = param_dtype
model_args.init_dtype = train_args.distributed.model_dtype
model = ByteLatentTransformer(args=model_args)
model = model.eval()
tokenizer = train_args.data.tokenizer_args.build()
with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
st_dict = torch.load(f, weights_only=True)
with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as fp:
st_dict = torch.load(fp, weights_only=True)
model.load_state_dict(st_dict["model"])

View file

@ -13,7 +13,7 @@ from bytelatent.base_transformer import (
from bytelatent.data.patcher import Patcher, PatcherArgs
from bytelatent.model.latent_transformer import GlobalTransformer
from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs
from bytelatent.model.utils import check_param_device, downsample
from bytelatent.model.utils import check_param_device, downsample, DTYPE_MAP
from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
from huggingface_hub import PyTorchModelHubMixin
@ -753,7 +753,7 @@ def init_embeddings(
encoder_hash_byte_group_vocab,
emb_dim,
device=args.init_device,
dtype=args.init_dtype,
dtype=DTYPE_MAP[args.init_dtype],
)
)
@ -767,7 +767,7 @@ def init_embeddings(
ngram_vocab_size + OFFSET,
emb_dim,
device=args.init_device,
dtype=args.init_dtype,
dtype=DTYPE_MAP[args.init_dtype],
)
)
@ -909,7 +909,7 @@ class ByteLatentTransformer(
ngram_vocab_size + OFFSET,
ngram_emb_dim,
device=args.init_device,
dtype=args.init_dtype,
dtype=dtype_map[args.init_dtype],
)
)

View file

@ -12,7 +12,7 @@ from bytelatent.base_transformer import (
flex_attention_comp,
repeat_kv,
)
from bytelatent.model.utils import create_causal_mask
from bytelatent.model.utils import create_causal_mask, DTYPE_MAP
from torch.nn import functional as F
from torch.nn.attention.flex_attention import BlockMask
from xformers.ops import AttentionBias
@ -173,7 +173,7 @@ class GlobalTransformer(BaseTransformer):
args.dim,
bias=False,
device=args.init_device,
dtype=args.init_dtype,
dtype=DTYPE_MAP[args.init_dtype],
)
def forward(

View file

@ -14,7 +14,7 @@ from bytelatent.base_transformer import (
TransformerBlock,
)
from bytelatent.model.latent_transformer import CrossAttention
from bytelatent.model.utils import create_causal_mask, downsample
from bytelatent.model.utils import create_causal_mask, downsample, DTYPE_MAP
from bytelatent.tokenizers.blt_tokenizer import BOE_ID
from pydantic import ConfigDict
from torch.nn import functional as F
@ -89,7 +89,7 @@ class LocalModelBase(nn.Module):
args.max_length,
args.dim,
device=args.init_device,
dtype=args.init_dtype,
dtype=DTYPE_MAP[args.init_dtype],
)
else:
self.rope = RotaryEmbedding(
@ -98,7 +98,7 @@ class LocalModelBase(nn.Module):
max_seqlen=args.max_seqlen,
rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
device=args.init_device,
dtype=args.init_dtype,
dtype=DTYPE_MAP[args.init_dtype],
)
self.pos_embeddings = None
@ -108,7 +108,7 @@ class LocalModelBase(nn.Module):
args.dim,
bias=False,
device=args.init_device,
dtype=args.init_dtype,
dtype=DTYPE_MAP[args.init_dtype],
)
if hasattr(args, "dim_token_emb") and args.dim_token_emb != self.dim
else None
@ -139,7 +139,7 @@ class LocalModelBase(nn.Module):
out_features=output_dim,
bias=False,
device=args.init_device,
dtype=args.init_dtype,
dtype=DTYPE_MAP[args.init_dtype],
)
def apply_embedding(self, tokens, embeds):
@ -234,7 +234,10 @@ class LocalEncoder(LocalModelBase):
self.cross_attn_nheads = args.cross_attn_nheads
self.tok_embeddings = nn.Embedding(
self.vocab_size, args.dim, device=args.init_device, dtype=args.init_dtype
self.vocab_size,
args.dim,
device=args.init_device,
dtype=DTYPE_MAP[args.init_dtype],
)
if self.cross_attn_encoder:
@ -249,7 +252,7 @@ class LocalEncoder(LocalModelBase):
n_kv_heads=self.cross_attn_nheads,
norm_eps=args.norm_eps,
device=args.init_device,
dtype=args.init_dtype,
dtype=DTYPE_MAP[args.init_dtype],
)
)
@ -341,7 +344,10 @@ class LocalDecoder(LocalModelBase):
self.cross_attn_nheads = args.cross_attn_nheads
self.norm = RMSNorm(
args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
args.dim,
eps=args.norm_eps,
device=args.init_device,
dtype=DTYPE_MAP[args.init_dtype],
)
if self.cross_attn_decoder:
@ -356,7 +362,7 @@ class LocalDecoder(LocalModelBase):
n_kv_heads=self.cross_attn_nheads,
norm_eps=args.norm_eps,
device=args.init_device,
dtype=args.init_dtype,
dtype=DTYPE_MAP[args.init_dtype],
)
)
@ -365,7 +371,7 @@ class LocalDecoder(LocalModelBase):
args.vocab_size,
bias=False,
device=args.init_device,
dtype=args.init_dtype,
dtype=DTYPE_MAP[args.init_dtype],
)
def forward(

View file

@ -8,6 +8,13 @@ from xformers.ops import fmha
logger = logging.getLogger()
DTYPE_MAP = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
"fp64": torch.float64,
}
def patch_reduce(h, max_num_patches, reduction, patch_ids):
"""

View file

@ -10,7 +10,7 @@ from bytelatent.base_transformer import (
BaseTransformerArgs,
cross_entropy,
)
from bytelatent.model.utils import check_param_device, create_causal_mask
from bytelatent.model.utils import check_param_device, create_causal_mask, DTYPE_MAP
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from torch.distributed._tensor import Replicate, Shard
@ -85,11 +85,17 @@ class LMTransformer(
assert args.vocab_size > 0
self.tok_embeddings = torch.nn.Embedding(
args.vocab_size, args.dim, device=args.init_device, dtype=args.init_dtype
args.vocab_size,
args.dim,
device=args.init_device,
dtype=DTYPE_MAP[args.init_dtype],
)
self.norm = RMSNorm(
args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
args.dim,
eps=args.norm_eps,
device=args.init_device,
dtype=DTYPE_MAP[args.init_dtype],
)
self.output = nn.Linear(
@ -97,7 +103,7 @@ class LMTransformer(
args.vocab_size,
bias=False,
device=args.init_device,
dtype=args.init_dtype,
dtype=DTYPE_MAP[args.init_dtype],
)
if args.weight_tying: