Mirror of https://github.com/facebookresearch/blt.git, synced 2025-09-04 03:29:05 +00:00
Remove non-serializable type from model config
This commit is contained in:
parent 4c5e51e4de
commit 48256248b5
8 changed files with 57 additions and 35 deletions
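The gist of the change: model configs previously carried a raw torch.dtype in init_dtype, which pydantic cannot serialize to JSON. The field is now a plain string ("fp32", "bf16", ...) and each module resolves it to a torch.dtype at construction time through a shared DTYPE_MAP. A minimal, self-contained sketch of that pattern (ToyArgs and ToyModule are made-up names; DTYPE_MAP mirrors the table added in bytelatent/model/utils.py):

import torch
from torch import nn
from pydantic import BaseModel

DTYPE_MAP = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
    "fp64": torch.float64,
}

class ToyArgs(BaseModel):
    # Only JSON-friendly types in the config.
    dim: int = 16
    init_device: str = "cpu"
    init_dtype: str = "fp32"

class ToyModule(nn.Module):
    def __init__(self, args: ToyArgs):
        super().__init__()
        # The string becomes a real torch.dtype only here, at build time.
        self.proj = nn.Linear(
            args.dim,
            args.dim,
            device=args.init_device,
            dtype=DTYPE_MAP[args.init_dtype],
        )

module = ToyModule(ToyArgs(init_dtype="bf16"))
print(module.proj.weight.dtype)  # torch.bfloat16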
@@ -7,6 +7,7 @@ from typing import Optional, Tuple, Union

import torch

+from bytelatent.model.utils import DTYPE_MAP
from bytelatent.tokenizers.constants import EOS_ID
from pydantic import BaseModel, ConfigDict
from torch import nn
@@ -69,7 +70,7 @@ class BaseTransformerArgs(BaseModel):
    eos_id: int | None = EOS_ID

    init_device: str = "cpu"
-    init_dtype: torch.dtype = torch.float32
+    init_dtype: str = "fp32"


def cross_entropy(pred, target, **kwargs):
@@ -564,7 +565,7 @@ class TransformerBlock(nn.Module):
            n_kv_heads=self.n_kv_heads,
            rope_theta=args.rope_theta,
            device=args.init_device,
-            dtype=args.init_dtype,
+            dtype=DTYPE_MAP[args.init_dtype],
        )
        self.feed_forward = FeedForward(
            dim=args.dim,
@@ -572,14 +573,20 @@ class TransformerBlock(nn.Module):
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
            device=args.init_device,
-            dtype=args.init_dtype,
+            dtype=DTYPE_MAP[args.init_dtype],
        )
        # Norms stay in full precision
        self.attention_norm = RMSNorm(
-            args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
+            args.dim,
+            eps=args.norm_eps,
+            device=args.init_device,
+            dtype=DTYPE_MAP[args.init_dtype],
        )
        self.ffn_norm = RMSNorm(
-            args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
+            args.dim,
+            eps=args.norm_eps,
+            device=args.init_device,
+            dtype=DTYPE_MAP[args.init_dtype],
        )

    def forward(
@@ -631,7 +638,7 @@ class BaseTransformer(nn.Module, SequenceModelWithOutput):
            max_seqlen=args.max_seqlen,
            rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
            device=args.init_device,
-            dtype=args.init_dtype,
+            dtype=DTYPE_MAP[args.init_dtype],
        )
        self.eos_id = args.eos_id
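Why the field type had to change: pydantic accepts a torch.dtype field only when arbitrary types are enabled, and even then the model cannot be dumped to JSON, while a plain string field round-trips cleanly. A small illustration (OldCfg and NewCfg are hypothetical, not classes from the repo):

import torch
from pydantic import BaseModel, ConfigDict

class OldCfg(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    init_dtype: torch.dtype = torch.float32  # validates, but is not JSON-serializable

class NewCfg(BaseModel):
    init_dtype: str = "fp32"

try:
    OldCfg().model_dump_json()
except Exception as err:
    # pydantic is expected to raise a serialization error for torch.dtype
    print("old config cannot be dumped:", type(err).__name__)

roundtrip = NewCfg.model_validate_json(NewCfg().model_dump_json())
print(roundtrip.init_dtype)  # fp32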
@@ -11,7 +11,7 @@ logger = logging.getLogger()


def load_entropy_model(
-    entropy_model_checkpoint_dir, state_dict_path, device="cpu", dtype=torch.bfloat16
+    entropy_model_checkpoint_dir, state_dict_path, device="cpu", dtype="bf16"
):
    with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
        reloaded = json.loads(fr.read())
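With the new signature the entropy-model loader takes the dtype as a short string instead of a torch.dtype object. A hedged call-site sketch (the paths are placeholders and the return value is assumed to be the loaded model):

entropy_model = load_entropy_model(
    entropy_model_checkpoint_dir="checkpoints/entropy",
    state_dict_path="checkpoints/entropy/consolidated.pth",
    device="cuda",
    dtype="bf16",  # previously torch.bfloat16
)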
@@ -403,27 +403,23 @@ def load_consolidated_model_and_tokenizer(consolidated_path, init_distributed=Fa
    train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))

-    param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
-        train_args.distributed.model_dtype
-    ]

    if train_args.train_entropy_model:
        model_args = train_args.entropy_model
        model_args.init_device = "cuda"
-        model_args.init_dtype = param_dtype
+        model_args.init_dtype = train_args.distributed.model_dtype
        model = LMTransformer(model_args)
    else:
        model_args = train_args.model
        model_args.init_device = "cuda"
-        model_args.init_dtype = param_dtype
+        model_args.init_dtype = train_args.distributed.model_dtype
        model = ByteLatentTransformer(args=model_args)

    model = model.eval()

    tokenizer = train_args.data.tokenizer_args.build()

-    with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
-        st_dict = torch.load(f, weights_only=True)
+    with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as fp:
+        st_dict = torch.load(fp, weights_only=True)

    model.load_state_dict(st_dict["model"])
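After this simplification the loader no longer keeps its own string-to-dtype table: it forwards train_args.distributed.model_dtype (e.g. "bf16") into the model config, and the modules resolve it through DTYPE_MAP when they are built. A hedged call-site sketch (the path is a placeholder, and the (model, tokenizer) return shape is assumed from the function name):

model, tokenizer = load_consolidated_model_and_tokenizer("checkpoints/blt/consolidated")
print(next(model.parameters()).dtype)  # e.g. torch.bfloat16 when model_dtype is "bf16"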
@@ -13,7 +13,7 @@ from bytelatent.base_transformer import (
from bytelatent.data.patcher import Patcher, PatcherArgs
from bytelatent.model.latent_transformer import GlobalTransformer
from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs
-from bytelatent.model.utils import check_param_device, downsample
+from bytelatent.model.utils import check_param_device, downsample, DTYPE_MAP
from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
from huggingface_hub import PyTorchModelHubMixin

@@ -753,7 +753,7 @@ def init_embeddings(
                encoder_hash_byte_group_vocab,
                emb_dim,
                device=args.init_device,
-                dtype=args.init_dtype,
+                dtype=DTYPE_MAP[args.init_dtype],
            )
        )

@@ -767,7 +767,7 @@ def init_embeddings(
                ngram_vocab_size + OFFSET,
                emb_dim,
                device=args.init_device,
-                dtype=args.init_dtype,
+                dtype=DTYPE_MAP[args.init_dtype],
            )
        )

@@ -909,7 +909,7 @@ class ByteLatentTransformer(
                ngram_vocab_size + OFFSET,
                ngram_emb_dim,
                device=args.init_device,
-                dtype=args.init_dtype,
+                dtype=DTYPE_MAP[args.init_dtype],
            )
        )

@@ -12,7 +12,7 @@ from bytelatent.base_transformer import (
    flex_attention_comp,
    repeat_kv,
)
-from bytelatent.model.utils import create_causal_mask
+from bytelatent.model.utils import create_causal_mask, DTYPE_MAP
from torch.nn import functional as F
from torch.nn.attention.flex_attention import BlockMask
from xformers.ops import AttentionBias
@@ -173,7 +173,7 @@ class GlobalTransformer(BaseTransformer):
            args.dim,
            bias=False,
            device=args.init_device,
-            dtype=args.init_dtype,
+            dtype=DTYPE_MAP[args.init_dtype],
        )

    def forward(

@@ -14,7 +14,7 @@ from bytelatent.base_transformer import (
    TransformerBlock,
)
from bytelatent.model.latent_transformer import CrossAttention
-from bytelatent.model.utils import create_causal_mask, downsample
+from bytelatent.model.utils import create_causal_mask, downsample, DTYPE_MAP
from bytelatent.tokenizers.blt_tokenizer import BOE_ID
from pydantic import ConfigDict
from torch.nn import functional as F
@@ -89,7 +89,7 @@ class LocalModelBase(nn.Module):
                args.max_length,
                args.dim,
                device=args.init_device,
-                dtype=args.init_dtype,
+                dtype=DTYPE_MAP[args.init_dtype],
            )
        else:
            self.rope = RotaryEmbedding(
@@ -98,7 +98,7 @@ class LocalModelBase(nn.Module):
                max_seqlen=args.max_seqlen,
                rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
                device=args.init_device,
-                dtype=args.init_dtype,
+                dtype=DTYPE_MAP[args.init_dtype],
            )
            self.pos_embeddings = None

@@ -108,7 +108,7 @@ class LocalModelBase(nn.Module):
                args.dim,
                bias=False,
                device=args.init_device,
-                dtype=args.init_dtype,
+                dtype=DTYPE_MAP[args.init_dtype],
            )
            if hasattr(args, "dim_token_emb") and args.dim_token_emb != self.dim
            else None
@@ -139,7 +139,7 @@ class LocalModelBase(nn.Module):
            out_features=output_dim,
            bias=False,
            device=args.init_device,
-            dtype=args.init_dtype,
+            dtype=DTYPE_MAP[args.init_dtype],
        )

    def apply_embedding(self, tokens, embeds):
@@ -234,7 +234,10 @@ class LocalEncoder(LocalModelBase):
        self.cross_attn_nheads = args.cross_attn_nheads

        self.tok_embeddings = nn.Embedding(
-            self.vocab_size, args.dim, device=args.init_device, dtype=args.init_dtype
+            self.vocab_size,
+            args.dim,
+            device=args.init_device,
+            dtype=DTYPE_MAP[args.init_dtype],
        )

        if self.cross_attn_encoder:
@@ -249,7 +252,7 @@ class LocalEncoder(LocalModelBase):
                    n_kv_heads=self.cross_attn_nheads,
                    norm_eps=args.norm_eps,
                    device=args.init_device,
-                    dtype=args.init_dtype,
+                    dtype=DTYPE_MAP[args.init_dtype],
                )
            )

@@ -341,7 +344,10 @@ class LocalDecoder(LocalModelBase):
        self.cross_attn_nheads = args.cross_attn_nheads

        self.norm = RMSNorm(
-            args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
+            args.dim,
+            eps=args.norm_eps,
+            device=args.init_device,
+            dtype=DTYPE_MAP[args.init_dtype],
        )

        if self.cross_attn_decoder:
@@ -356,7 +362,7 @@ class LocalDecoder(LocalModelBase):
                    n_kv_heads=self.cross_attn_nheads,
                    norm_eps=args.norm_eps,
                    device=args.init_device,
-                    dtype=args.init_dtype,
+                    dtype=DTYPE_MAP[args.init_dtype],
                )
            )
@@ -365,7 +371,7 @@ class LocalDecoder(LocalModelBase):
            args.vocab_size,
            bias=False,
            device=args.init_device,
-            dtype=args.init_dtype,
+            dtype=DTYPE_MAP[args.init_dtype],
        )

    def forward(

@@ -8,6 +8,13 @@ from xformers.ops import fmha

logger = logging.getLogger()

+DTYPE_MAP = {
+    "bf16": torch.bfloat16,
+    "fp16": torch.float16,
+    "fp32": torch.float32,
+    "fp64": torch.float64,
+}
+

def patch_reduce(h, max_num_patches, reduction, patch_ids):
    """
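Because every init path now indexes DTYPE_MAP directly, an unrecognized dtype string surfaces as a bare KeyError at model-construction time. A small guard like the one below could give a clearer message; it is an illustration built on the mapping above, not part of the commit (assumes torch is imported and DTYPE_MAP is in scope):

def resolve_dtype(name: str) -> torch.dtype:
    # Hypothetical helper: map "bf16"/"fp16"/"fp32"/"fp64" to a torch.dtype.
    try:
        return DTYPE_MAP[name]
    except KeyError:
        raise ValueError(
            f"Unknown dtype {name!r}; expected one of {sorted(DTYPE_MAP)}"
        ) from None

assert resolve_dtype("bf16") is torch.bfloat16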
@@ -10,7 +10,7 @@ from bytelatent.base_transformer import (
    BaseTransformerArgs,
    cross_entropy,
)
-from bytelatent.model.utils import check_param_device, create_causal_mask
+from bytelatent.model.utils import check_param_device, create_causal_mask, DTYPE_MAP
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from torch.distributed._tensor import Replicate, Shard
@@ -85,11 +85,17 @@ class LMTransformer(
        assert args.vocab_size > 0

        self.tok_embeddings = torch.nn.Embedding(
-            args.vocab_size, args.dim, device=args.init_device, dtype=args.init_dtype
+            args.vocab_size,
+            args.dim,
+            device=args.init_device,
+            dtype=DTYPE_MAP[args.init_dtype],
        )

        self.norm = RMSNorm(
-            args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
+            args.dim,
+            eps=args.norm_eps,
+            device=args.init_device,
+            dtype=DTYPE_MAP[args.init_dtype],
        )

        self.output = nn.Linear(
@@ -97,7 +103,7 @@ class LMTransformer(
            args.vocab_size,
            bias=False,
            device=args.init_device,
-            dtype=args.init_dtype,
+            dtype=DTYPE_MAP[args.init_dtype],
        )

        if args.weight_tying: