Mirror of https://github.com/facebookresearch/blt.git, synced 2025-09-06 12:39:04 +00:00
Remove non-serializable type from model config
commit 48256248b5 (parent 4c5e51e4de)
8 changed files with 57 additions and 35 deletions
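The change replaces the torch.dtype field on the model config with a plain string that modules resolve through a small DTYPE_MAP lookup at construction time. As a minimal standalone sketch of why this matters (the class names below are illustrative, not the repo's), a pydantic model holding a raw torch.dtype generally cannot be dumped to JSON, while a string field validates and round-trips cleanly:

    import torch
    from pydantic import BaseModel, ConfigDict

    # Hypothetical mapping, mirroring the idea of the DTYPE_MAP added in this commit.
    DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

    class OldArgs(BaseModel):
        # Needs arbitrary_types_allowed and typically cannot be dumped to JSON.
        model_config = ConfigDict(arbitrary_types_allowed=True)
        init_dtype: torch.dtype = torch.float32

    class NewArgs(BaseModel):
        # Plain string: serializes and validates like any other field.
        init_dtype: str = "fp32"

    args = NewArgs.model_validate_json(NewArgs().model_dump_json())  # round-trips
    real_dtype = DTYPE_MAP[args.init_dtype]                          # torch.float32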
@@ -7,6 +7,7 @@ from typing import Optional, Tuple, Union

 import torch

+from bytelatent.model.utils import DTYPE_MAP
 from bytelatent.tokenizers.constants import EOS_ID
 from pydantic import BaseModel, ConfigDict
 from torch import nn
@@ -69,7 +70,7 @@ class BaseTransformerArgs(BaseModel):
     eos_id: int | None = EOS_ID

     init_device: str = "cpu"
-    init_dtype: torch.dtype = torch.float32
+    init_dtype: str = "float32"


 def cross_entropy(pred, target, **kwargs):
@@ -564,7 +565,7 @@ class TransformerBlock(nn.Module):
             n_kv_heads=self.n_kv_heads,
             rope_theta=args.rope_theta,
             device=args.init_device,
-            dtype=args.init_dtype,
+            dtype=DTYPE_MAP[args.init_dtype],
         )
         self.feed_forward = FeedForward(
             dim=args.dim,
@@ -572,14 +573,20 @@ class TransformerBlock(nn.Module):
             multiple_of=args.multiple_of,
             ffn_dim_multiplier=args.ffn_dim_multiplier,
             device=args.init_device,
-            dtype=args.init_dtype,
+            dtype=DTYPE_MAP[args.init_dtype],
         )
         # Norms stay in full precision
         self.attention_norm = RMSNorm(
-            args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
+            args.dim,
+            eps=args.norm_eps,
+            device=args.init_device,
+            dtype=DTYPE_MAP[args.init_dtype],
         )
         self.ffn_norm = RMSNorm(
-            args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
+            args.dim,
+            eps=args.norm_eps,
+            device=args.init_device,
+            dtype=DTYPE_MAP[args.init_dtype],
         )

     def forward(
@@ -631,7 +638,7 @@ class BaseTransformer(nn.Module, SequenceModelWithOutput):
             max_seqlen=args.max_seqlen,
             rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
             device=args.init_device,
-            dtype=args.init_dtype,
+            dtype=DTYPE_MAP[args.init_dtype],
         )
         self.eos_id = args.eos_id

@@ -11,7 +11,7 @@ logger = logging.getLogger()


 def load_entropy_model(
-    entropy_model_checkpoint_dir, state_dict_path, device="cpu", dtype=torch.bfloat16
+    entropy_model_checkpoint_dir, state_dict_path, device="cpu", dtype="bf16"
 ):
     with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
         reloaded = json.loads(fr.read())
@@ -403,27 +403,23 @@ def load_consolidated_model_and_tokenizer(consolidated_path, init_distributed=Fa

     train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))

-    param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
-        train_args.distributed.model_dtype
-    ]
-
     if train_args.train_entropy_model:
         model_args = train_args.entropy_model
         model_args.init_device = "cuda"
-        model_args.init_dtype = param_dtype
+        model_args.init_dtype = train_args.distributed.model_dtype
         model = LMTransformer(model_args)
     else:
         model_args = train_args.model
         model_args.init_device = "cuda"
-        model_args.init_dtype = param_dtype
+        model_args.init_dtype = train_args.distributed.model_dtype
         model = ByteLatentTransformer(args=model_args)

     model = model.eval()

     tokenizer = train_args.data.tokenizer_args.build()

-    with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
-        st_dict = torch.load(f, weights_only=True)
+    with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as fp:
+        st_dict = torch.load(fp, weights_only=True)

     model.load_state_dict(st_dict["model"])

@@ -13,7 +13,7 @@ from bytelatent.base_transformer import (
 from bytelatent.data.patcher import Patcher, PatcherArgs
 from bytelatent.model.latent_transformer import GlobalTransformer
 from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs
-from bytelatent.model.utils import check_param_device, downsample
+from bytelatent.model.utils import check_param_device, downsample, DTYPE_MAP
 from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
 from huggingface_hub import PyTorchModelHubMixin

@@ -753,7 +753,7 @@ def init_embeddings(
                 encoder_hash_byte_group_vocab,
                 emb_dim,
                 device=args.init_device,
-                dtype=args.init_dtype,
+                dtype=DTYPE_MAP[args.init_dtype],
             )
         )

@@ -767,7 +767,7 @@ def init_embeddings(
                 ngram_vocab_size + OFFSET,
                 emb_dim,
                 device=args.init_device,
-                dtype=args.init_dtype,
+                dtype=DTYPE_MAP[args.init_dtype],
             )
         )

@@ -909,7 +909,7 @@ class ByteLatentTransformer(
                 ngram_vocab_size + OFFSET,
                 ngram_emb_dim,
                 device=args.init_device,
-                dtype=args.init_dtype,
+                dtype=dtype_map[args.init_dtype],
             )
         )

@@ -12,7 +12,7 @@ from bytelatent.base_transformer import (
     flex_attention_comp,
     repeat_kv,
 )
-from bytelatent.model.utils import create_causal_mask
+from bytelatent.model.utils import create_causal_mask, DTYPE_MAP
 from torch.nn import functional as F
 from torch.nn.attention.flex_attention import BlockMask
 from xformers.ops import AttentionBias

@@ -173,7 +173,7 @@ class GlobalTransformer(BaseTransformer):
                 args.dim,
                 bias=False,
                 device=args.init_device,
-                dtype=args.init_dtype,
+                dtype=DTYPE_MAP[args.init_dtype],
             )

     def forward(
@@ -14,7 +14,7 @@ from bytelatent.base_transformer import (
     TransformerBlock,
 )
 from bytelatent.model.latent_transformer import CrossAttention
-from bytelatent.model.utils import create_causal_mask, downsample
+from bytelatent.model.utils import create_causal_mask, downsample, DTYPE_MAP
 from bytelatent.tokenizers.blt_tokenizer import BOE_ID
 from pydantic import ConfigDict
 from torch.nn import functional as F

@@ -89,7 +89,7 @@ class LocalModelBase(nn.Module):
                 args.max_length,
                 args.dim,
                 device=args.init_device,
-                dtype=args.init_dtype,
+                dtype=DTYPE_MAP[args.init_dtype],
             )
         else:
             self.rope = RotaryEmbedding(

@@ -98,7 +98,7 @@ class LocalModelBase(nn.Module):
                 max_seqlen=args.max_seqlen,
                 rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
                 device=args.init_device,
-                dtype=args.init_dtype,
+                dtype=DTYPE_MAP[args.init_dtype],
             )
             self.pos_embeddings = None

@@ -108,7 +108,7 @@ class LocalModelBase(nn.Module):
                 args.dim,
                 bias=False,
                 device=args.init_device,
-                dtype=args.init_dtype,
+                dtype=DTYPE_MAP[args.init_dtype],
             )
             if hasattr(args, "dim_token_emb") and args.dim_token_emb != self.dim
             else None

@@ -139,7 +139,7 @@ class LocalModelBase(nn.Module):
             out_features=output_dim,
             bias=False,
             device=args.init_device,
-            dtype=args.init_dtype,
+            dtype=DTYPE_MAP[args.init_dtype],
         )

     def apply_embedding(self, tokens, embeds):

@@ -234,7 +234,10 @@ class LocalEncoder(LocalModelBase):
         self.cross_attn_nheads = args.cross_attn_nheads

         self.tok_embeddings = nn.Embedding(
-            self.vocab_size, args.dim, device=args.init_device, dtype=args.init_dtype
+            self.vocab_size,
+            args.dim,
+            device=args.init_device,
+            dtype=DTYPE_MAP[args.init_dtype],
         )

         if self.cross_attn_encoder:

@@ -249,7 +252,7 @@ class LocalEncoder(LocalModelBase):
                     n_kv_heads=self.cross_attn_nheads,
                     norm_eps=args.norm_eps,
                     device=args.init_device,
-                    dtype=args.init_dtype,
+                    dtype=DTYPE_MAP[args.init_dtype],
                 )
             )

@@ -341,7 +344,10 @@ class LocalDecoder(LocalModelBase):
         self.cross_attn_nheads = args.cross_attn_nheads

         self.norm = RMSNorm(
-            args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
+            args.dim,
+            eps=args.norm_eps,
+            device=args.init_device,
+            dtype=DTYPE_MAP[args.init_dtype],
         )

         if self.cross_attn_decoder:

@@ -356,7 +362,7 @@ class LocalDecoder(LocalModelBase):
                     n_kv_heads=self.cross_attn_nheads,
                     norm_eps=args.norm_eps,
                     device=args.init_device,
-                    dtype=args.init_dtype,
+                    dtype=DTYPE_MAP[args.init_dtype],
                 )
             )

@@ -365,7 +371,7 @@ class LocalDecoder(LocalModelBase):
             args.vocab_size,
             bias=False,
             device=args.init_device,
-            dtype=args.init_dtype,
+            dtype=DTYPE_MAP[args.init_dtype],
         )

     def forward(
@@ -8,6 +8,13 @@ from xformers.ops import fmha

 logger = logging.getLogger()

+DTYPE_MAP = {
+    "bf16": torch.bfloat16,
+    "fp16": torch.float16,
+    "fp32": torch.float32,
+    "fp64": torch.float64,
+}
+

 def patch_reduce(h, max_num_patches, reduction, patch_ids):
     """
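A short, hedged illustration of how the mapping added above is consumed (the layer and sizes are arbitrary, not taken from the repo): the config carries only the string key, and the actual torch.dtype is looked up when a module is constructed.

    import torch
    from torch import nn

    DTYPE_MAP = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
        "fp64": torch.float64,
    }

    init_dtype = "bf16"  # string stored in the (now JSON-serializable) config
    layer = nn.Linear(512, 512, device="cpu", dtype=DTYPE_MAP[init_dtype])
    assert layer.weight.dtype is torch.bfloat16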
@@ -10,7 +10,7 @@ from bytelatent.base_transformer import (
     BaseTransformerArgs,
     cross_entropy,
 )
-from bytelatent.model.utils import check_param_device, create_causal_mask
+from bytelatent.model.utils import check_param_device, create_causal_mask, DTYPE_MAP
 from huggingface_hub import PyTorchModelHubMixin
 from torch import nn
 from torch.distributed._tensor import Replicate, Shard

@@ -85,11 +85,17 @@ class LMTransformer(
         assert args.vocab_size > 0

         self.tok_embeddings = torch.nn.Embedding(
-            args.vocab_size, args.dim, device=args.init_device, dtype=args.init_dtype
+            args.vocab_size,
+            args.dim,
+            device=args.init_device,
+            dtype=DTYPE_MAP[args.init_dtype],
         )

         self.norm = RMSNorm(
-            args.dim, eps=args.norm_eps, device=args.init_device, dtype=args.init_dtype
+            args.dim,
+            eps=args.norm_eps,
+            device=args.init_device,
+            dtype=DTYPE_MAP[args.init_dtype],
         )

         self.output = nn.Linear(

@@ -97,7 +103,7 @@ class LMTransformer(
             args.vocab_size,
             bias=False,
             device=args.init_device,
-            dtype=args.init_dtype,
+            dtype=DTYPE_MAP[args.init_dtype],
         )

         if args.weight_tying: