Improve HF compatibility (#99)
Some checks failed
Lint with Black / lint (push) Has been cancelled
Lint with isort / lint (push) Has been cancelled

Summary:

Test Plan:
This commit is contained in:
Pedro Rodriguez 2025-05-01 11:44:50 -07:00 committed by GitHub
parent 1b67cbe022
commit bbc205c2b7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 291 additions and 26 deletions

View file

@ -55,8 +55,37 @@ These instructions have been tested on H100 GPUs, but we can only offer suggesti
1. On the model weights HF page, create a HuggingFace account, request access to weights, and wait for approval.
2. On the huggingface cli, login: `huggingface-cli login`
3. Download the model weights with: `python download_blt_weights.py`, which will load to `hf-weights`
4. Run the generate demo: `python demo.py "A BLT has"`.
From here there are two options: (1) load weights in our train script and (2) loading weights via HF hub to use for anything else.
## Load Weights via HF Hub
In your terminal:
```bash
python -m bytelatent.hf load-transformers --entropy-repo facebook/blt-entropy --blt-repo facebook/blt-1b hub --prompt "My test prompt"
```
In your own code:
```python
from bytelatent.transformer import LMTransformer
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.hf import BltTokenizerAndPatcher
entropy_repo = "facebook/blt-entropy"
blt_repo = "facebook/blt-1b"
entropy_model = LMTransformer.from_pretrained(entropy_repo)
blt_model = ByteLatentTransformer.from_pretrained(blt_repo)
tok_and_patcher = BltTokenizerAndPatcher.from_pretrained(blt_repo)
tokenizer = tok_and_patcher.tokenizer_args.build()
patcher = tok_and_patcher.patcher_args.build()
```
## Load Weights for Running BLT Train Script
1. Download the model weights with: `python download_blt_weights.py`, which will load to `hf-weights`
2. Run the generate demo: `python demo.py "A BLT has"`.
The demo generates text, but is also a good starting point for loading BLT in your own code.

View file

@ -34,7 +34,7 @@ else:
flex_attention_comp = None
class InitStdFactor(Enum):
class InitStdFactor(str, Enum):
DISABLED = "disabled" # Init std is divided by 1.0
GLOBAL_DEPTH = "global_depth" # Init std is divided by sqrt(2*n_layers)
CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth)

View file

@ -486,7 +486,7 @@ class Patcher:
state_path = os.path.join(
patcher_args.entropy_model_checkpoint_dir, "consolidated.pth"
)
entropy_model = load_entropy_model(
entropy_model, _ = load_entropy_model(
patcher_args.entropy_model_checkpoint_dir,
state_path,
)

View file

@ -19,19 +19,18 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
logger.warning(
"Update checkpoint to load attn and sliding window args from checkpoint"
)
entropy_model = LMTransformer(
LMTransformerArgs(
dim=model_params["dim"],
n_layers=model_params["n_layers"],
n_heads=model_params["n_heads"],
max_seqlen=model_params["max_seqlen"],
ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
vocab_size=model_params["vocab_size"],
attn_bias_type="local_block_causal",
attn_impl="xformers",
sliding_window=512,
)
entropy_model_args = LMTransformerArgs(
dim=model_params["dim"],
n_layers=model_params["n_layers"],
n_heads=model_params["n_heads"],
max_seqlen=model_params["max_seqlen"],
ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
vocab_size=model_params["vocab_size"],
attn_bias_type="local_block_causal",
attn_impl="xformers",
sliding_window=512,
)
entropy_model = LMTransformer(entropy_model_args)
entropy_model.load_state_dict(
torch.load(state_dict_path, map_location=device)["model"], strict=False
@ -41,4 +40,4 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
# no grads for the model:
for param in entropy_model.parameters():
param.requires_grad = False
return entropy_model
return entropy_model, entropy_model_args

199
bytelatent/hf.py Normal file
View file

@ -0,0 +1,199 @@
import json
import os
import shutil
from pathlib import Path
from typing import Dict, Optional, Union
import torch
import typer
from huggingface_hub import hf_hub_download
from huggingface_hub.hub_mixin import ModelHubMixin
from bytelatent.args import TrainArgs
from bytelatent.data.patcher import PatcherArgs, to_device
from bytelatent.distributed import DistributedArgs, setup_torch_distributed
from bytelatent.entropy_model import load_entropy_model
from bytelatent.generate import load_consolidated_model_and_tokenizer
from bytelatent.generate_blt import generate_nocache
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
from bytelatent.transformer import LMTransformer
app = typer.Typer()
class BltTokenizerAndPatcher(ModelHubMixin):
def __init__(
self,
*,
patcher_args: PatcherArgs,
tokenizer_args: TokenizerArgs,
distributed_args: DistributedArgs,
):
self.patcher_args = patcher_args
self.tokenizer_args = tokenizer_args
self.distributed_args = distributed_args
def push_to_hub(self, *args, **kwargs):
raise ValueError(
"For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct."
)
def save_pretrained(self, *args, **kwargs):
raise ValueError(
"Tokenizer and Patcher are saved by BLT, this class is just for loading"
)
def _save_pretrained(self, *args, **kwargs):
raise ValueError(
"Tokenizer and Patcher are saved by BLT, this class is just for loading"
)
@classmethod
def _from_pretrained(
cls,
*,
model_id: str,
revision: Optional[str],
cache_dir: Optional[Union[str, Path]],
force_download: bool,
proxies: Optional[Dict],
resume_download: Optional[bool],
local_files_only: bool,
token: Optional[Union[str, bool]],
**model_kwargs,
):
if os.path.isdir(model_id):
train_args_file = os.path.join(model_id, "train_args.json")
else:
train_args_file = hf_hub_download(
repo_id=model_id,
filename="train_args.json",
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
)
with open(train_args_file) as f:
train_args = TrainArgs(**json.load(f))
return cls(
patcher_args=train_args.data.patcher_args,
tokenizer_args=train_args.data.tokenizer_args,
distributed_args=train_args.distributed,
)
@app.command()
def convert_to_transformers(blt_weights_dir: str, output_dir: str):
if not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(blt_weights_dir)
blt_dir = os.path.join(output_dir, "blt")
entropy_dir = os.path.join(output_dir, "entropy")
model.save_pretrained(blt_dir, config={"args": train_cfg.model.model_dump()})
shutil.copyfile(
os.path.join(blt_weights_dir, "params.json"),
os.path.join(blt_dir, "train_args.json"),
)
blt_readme_file = os.path.join(blt_dir, "README.md")
if os.path.exists(blt_readme_file):
os.remove(blt_readme_file)
patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
patcher_args.realtime_patching = False
print("Loading entropy model and patcher")
patcher_args.entropy_model_checkpoint_dir = os.path.join(
blt_weights_dir, "entropy_model"
)
state_path = os.path.join(
patcher_args.entropy_model_checkpoint_dir, "consolidated.pth"
)
entropy_model, entropy_model_args = load_entropy_model(
patcher_args.entropy_model_checkpoint_dir, state_path
)
entropy_model.save_pretrained(
entropy_dir, config={"args": entropy_model_args.model_dump()}
)
entropy_readme_file = os.path.join(entropy_dir, "README.md")
if os.path.exists(entropy_readme_file):
os.remove(entropy_readme_file)
@app.command()
def load_transformers(
source: str,
entropy_repo: str = "facebook/blt-entropy",
blt_repo: str = "facebook/blt-1b",
entropy_dir: str | None = None,
blt_dir: str | None = None,
prompt: str | None = None,
):
if source == "local":
assert entropy_dir is not None
assert blt_dir is not None
entropy_model = LMTransformer.from_pretrained(
entropy_dir, local_files_only=True
)
blt_model = ByteLatentTransformer.from_pretrained(
blt_dir, local_files_only=True
)
tok_and_patcher = BltTokenizerAndPatcher.from_pretrained(
blt_dir, local_files_only=True
)
tokenizer = tok_and_patcher.tokenizer_args.build()
patcher = tok_and_patcher.patcher_args.build()
print("Loaded all local")
print(entropy_model)
print(blt_model)
print(tok_and_patcher)
elif source == "hub":
entropy_model = LMTransformer.from_pretrained(entropy_repo)
blt_model = ByteLatentTransformer.from_pretrained(blt_repo)
tok_and_patcher = BltTokenizerAndPatcher.from_pretrained(blt_repo)
tokenizer = tok_and_patcher.tokenizer_args.build()
patcher = tok_and_patcher.patcher_args.build()
print("Loaded all remote")
print(entropy_model)
print(blt_model)
print(tok_and_patcher)
else:
raise ValueError(f"Unknown source: {source}")
if prompt is not None:
assert isinstance(tokenizer, BltTokenizer)
# Move args to correct GPU
param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
tok_and_patcher.distributed_args.model_dtype
]
blt_model = blt_model.cuda().eval()
for param in blt_model.parameters():
param.data = param.data.to(dtype=param_dtype)
# Enable realtime patching
patcher.realtime_patching = True
patcher.entropy_model, _ = to_device(
entropy_model, tok_and_patcher.patcher_args.patching_device
)
# Setup distributed
distributed_args = DistributedArgs()
distributed_args.configure_world()
if not torch.distributed.is_initialized():
setup_torch_distributed(distributed_args)
prompts = [prompt]
outputs = generate_nocache(
prompts, model=blt_model, tokenizer=tokenizer, patcher=patcher
)
text_outputs = [tokenizer.decode(t) for t in outputs]
for p, t in zip(prompts, text_outputs):
print(f'Prompt: "{p}"\nCompletion: "{t}"')
print()
if __name__ == "__main__":
app()

View file

@ -4,6 +4,7 @@ from enum import Enum, auto
from typing import Any, Optional
import torch
from huggingface_hub import PyTorchModelHubMixin
from pydantic import model_validator
from torch import nn
from torch.nn.attention.flex_attention import create_block_mask
@ -20,8 +21,6 @@ from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModel
from bytelatent.model.utils import downsample
from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
from huggingface_hub import PyTorchModelHubMixin
def attention_flops_per_token(n_layers, seq_len, dim, causal):
# Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30
@ -768,10 +767,23 @@ def compute_hash_embeddings(
return local_encoder_embeds
class ByteLatentTransformer(nn.Module, SequenceModelWithOutput, PyTorchModelHubMixin,
repo_url="https://github.com/facebookresearch/blt",
pipeline_tag="text-generation",
license="other"):
class ByteLatentTransformer(
nn.Module,
SequenceModelWithOutput,
PyTorchModelHubMixin,
repo_url="https://github.com/facebookresearch/blt",
paper_url="https://arxiv.org/abs/2412.09871",
pipeline_tag="text-generation",
license="other",
license_name="fair-noncommercial-research-license",
license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE",
coders={
ByteLatentTransformerArgs: (
lambda x: {"args": x.model_dump()},
lambda data: ByteLatentTransformerArgs(**data),
)
},
):
"""
The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences
by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,
@ -861,6 +873,11 @@ class ByteLatentTransformer(nn.Module, SequenceModelWithOutput, PyTorchModelHubM
)
)
def push_to_hub(self, *args, **kwargs):
raise ValueError(
"For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct."
)
def get_output_seq_len(self):
return self.max_seqlen

View file

@ -82,7 +82,7 @@ def main(
if dry_run:
return
entropy_model = load_entropy_model(
entropy_model, _ = load_entropy_model(
entropy_model_checkpoint_dir,
entropy_model_state_dict_path,
device=patching_device,

View file

@ -34,7 +34,7 @@ def test_entropy_model():
"bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
},
)
entropy_model = load_entropy_model(
entropy_model, _ = load_entropy_model(
BLT_DATA / "checkpoint_0100000_consolidated",
os.path.join(
BLT_DATA,

View file

@ -4,6 +4,7 @@ import logging
from typing import Optional, Tuple, Union
import torch
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
@ -60,7 +61,22 @@ class LMTransformerArgs(BaseTransformerArgs):
sliding_window: int | None = None
class LMTransformer(BaseTransformer):
class LMTransformer(
BaseTransformer,
PyTorchModelHubMixin,
repo_url="https://github.com/facebookresearch/blt",
paper_url="https://arxiv.org/abs/2412.09871",
pipeline_tag="text-generation",
license="other",
license_name="fair-noncommercial-research-license",
license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE",
coders={
LMTransformerArgs: (
lambda x: {"args": x.model_dump()},
lambda data: LMTransformerArgs(**data),
)
},
):
def __init__(self, args: LMTransformerArgs):
super().__init__(args)
self.weight_tying = args.weight_tying
@ -81,6 +97,11 @@ class LMTransformer(BaseTransformer):
if args.weight_tying:
self.output.weight = self.embeddings.tok_embeddings.weight
def push_to_hub(self, *args, **kwargs):
raise ValueError(
"For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct."
)
def forward(
self,
token_values: torch.Tensor,