Mirror of https://github.com/facebookresearch/blt.git, synced 2025-09-01 10:09:06 +00:00

parent 1b67cbe022
commit bbc205c2b7

9 changed files with 291 additions and 26 deletions
README.md (33 changed lines)

@@ -55,8 +55,37 @@ These instructions have been tested on H100 GPUs, but we can only offer suggesti
1. On the model weights HF page, create a HuggingFace account, request access to the weights, and wait for approval.
2. Log in via the HuggingFace CLI: `huggingface-cli login`
3. Download the model weights with `python download_blt_weights.py`; they are saved to `hf-weights`.
4. Run the generation demo: `python demo.py "A BLT has"`.

From here there are two options: (1) loading the weights in our train script, or (2) loading the weights via the HF hub for use in any other code. Both are described below.

## Load Weights via HF Hub

In your terminal:

```bash
python -m bytelatent.hf load-transformers --entropy-repo facebook/blt-entropy --blt-repo facebook/blt-1b hub --prompt "My test prompt"
```

In your own code:

```python
from bytelatent.transformer import LMTransformer
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.hf import BltTokenizerAndPatcher

entropy_repo = "facebook/blt-entropy"
blt_repo = "facebook/blt-1b"
entropy_model = LMTransformer.from_pretrained(entropy_repo)
blt_model = ByteLatentTransformer.from_pretrained(blt_repo)
tok_and_patcher = BltTokenizerAndPatcher.from_pretrained(blt_repo)
tokenizer = tok_and_patcher.tokenizer_args.build()
patcher = tok_and_patcher.patcher_args.build()
```
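
To actually generate text with the hub-loaded models, the `load-transformers` command in `bytelatent/hf.py` (added in this commit) enables realtime patching, attaches the entropy model to the patcher, and calls `generate_nocache`. Continuing from the snippet above, a minimal sketch of those steps (it omits the dtype casting and `torch.distributed` setup the command also performs):

```python
from bytelatent.data.patcher import to_device
from bytelatent.generate_blt import generate_nocache

# Patch prompts on the fly with the entropy model, as load-transformers does.
patcher.realtime_patching = True
patcher.entropy_model, _ = to_device(
    entropy_model, tok_and_patcher.patcher_args.patching_device
)

outputs = generate_nocache(
    ["My test prompt"], model=blt_model.cuda().eval(), tokenizer=tokenizer, patcher=patcher
)
print([tokenizer.decode(t) for t in outputs])
```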

## Load Weights for Running BLT Train Script

1. Download the model weights with `python download_blt_weights.py`; they are saved to `hf-weights`.
2. Run the generation demo: `python demo.py "A BLT has"`.

The demo generates text, but it is also a good starting point for loading BLT in your own code.
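
For example, a consolidated local checkpoint can be loaded with the same helper this commit uses in `convert_to_transformers` (a sketch; the `hf-weights/blt-1b` path is an assumption about where the download script puts this model):

```python
from bytelatent.generate import load_consolidated_model_and_tokenizer

# Hypothetical checkpoint path under hf-weights; adjust to your download location.
model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer("hf-weights/blt-1b")
```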

@@ -34,7 +34,7 @@ else:
     flex_attention_comp = None
 
 
-class InitStdFactor(Enum):
+class InitStdFactor(str, Enum):
     DISABLED = "disabled"  # Init std is divided by 1.0
     GLOBAL_DEPTH = "global_depth"  # Init std is divided by sqrt(2*n_layers)
     CURRENT_DEPTH = "current_depth"  # Init std is divided by sqrt(2*depth)
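
Mixing in `str` means each enum member now *is* its string value, so configs containing `InitStdFactor` serialize to JSON and validate under pydantic without custom encoders. Standard Python behavior, for illustration:

```python
from enum import Enum

class InitStdFactor(str, Enum):
    DISABLED = "disabled"

# A str-mixin enum compares equal to, and serializes as, its raw string value.
assert InitStdFactor.DISABLED == "disabled"
assert InitStdFactor("disabled") is InitStdFactor.DISABLED
```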

@@ -486,7 +486,7 @@ class Patcher:
         state_path = os.path.join(
             patcher_args.entropy_model_checkpoint_dir, "consolidated.pth"
         )
-        entropy_model = load_entropy_model(
+        entropy_model, _ = load_entropy_model(
             patcher_args.entropy_model_checkpoint_dir,
             state_path,
         )

bytelatent/entropy_model.py

@@ -19,19 +19,18 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
     logger.warning(
         "Update checkpoint to load attn and sliding window args from checkpoint"
     )
-    entropy_model = LMTransformer(
-        LMTransformerArgs(
-            dim=model_params["dim"],
-            n_layers=model_params["n_layers"],
-            n_heads=model_params["n_heads"],
-            max_seqlen=model_params["max_seqlen"],
-            ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
-            vocab_size=model_params["vocab_size"],
-            attn_bias_type="local_block_causal",
-            attn_impl="xformers",
-            sliding_window=512,
-        )
-    )
+    entropy_model_args = LMTransformerArgs(
+        dim=model_params["dim"],
+        n_layers=model_params["n_layers"],
+        n_heads=model_params["n_heads"],
+        max_seqlen=model_params["max_seqlen"],
+        ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
+        vocab_size=model_params["vocab_size"],
+        attn_bias_type="local_block_causal",
+        attn_impl="xformers",
+        sliding_window=512,
+    )
+    entropy_model = LMTransformer(entropy_model_args)
 
     entropy_model.load_state_dict(
         torch.load(state_dict_path, map_location=device)["model"], strict=False

@@ -41,4 +40,4 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
     # no grads for the model:
     for param in entropy_model.parameters():
         param.requires_grad = False
-    return entropy_model
+    return entropy_model, entropy_model_args
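
Returning the `(model, args)` pair lets callers that export weights (like `convert_to_transformers` below) serialize the reconstructed `LMTransformerArgs` into the hub config; callers that only need the model discard the second element. Both patterns appear in this commit (`checkpoint_dir`, `state_path`, and `out_dir` are placeholder names):

```python
# Call sites that only need the model:
entropy_model, _ = load_entropy_model(checkpoint_dir, state_path)

# Call sites that also export the config alongside the weights:
entropy_model, entropy_model_args = load_entropy_model(checkpoint_dir, state_path)
entropy_model.save_pretrained(out_dir, config={"args": entropy_model_args.model_dump()})
```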

bytelatent/hf.py — new file (199 lines)

@@ -0,0 +1,199 @@
import json
import os
import shutil
from pathlib import Path
from typing import Dict, Optional, Union

import torch
import typer
from huggingface_hub import hf_hub_download
from huggingface_hub.hub_mixin import ModelHubMixin

from bytelatent.args import TrainArgs
from bytelatent.data.patcher import PatcherArgs, to_device
from bytelatent.distributed import DistributedArgs, setup_torch_distributed
from bytelatent.entropy_model import load_entropy_model
from bytelatent.generate import load_consolidated_model_and_tokenizer
from bytelatent.generate_blt import generate_nocache
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
from bytelatent.transformer import LMTransformer

app = typer.Typer()


class BltTokenizerAndPatcher(ModelHubMixin):
    def __init__(
        self,
        *,
        patcher_args: PatcherArgs,
        tokenizer_args: TokenizerArgs,
        distributed_args: DistributedArgs,
    ):
        self.patcher_args = patcher_args
        self.tokenizer_args = tokenizer_args
        self.distributed_args = distributed_args

    def push_to_hub(self, *args, **kwargs):
        raise ValueError(
            "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct."
        )

    def save_pretrained(self, *args, **kwargs):
        raise ValueError(
            "Tokenizer and Patcher are saved by BLT, this class is just for loading"
        )

    def _save_pretrained(self, *args, **kwargs):
        raise ValueError(
            "Tokenizer and Patcher are saved by BLT, this class is just for loading"
        )

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: Optional[bool],
        local_files_only: bool,
        token: Optional[Union[str, bool]],
        **model_kwargs,
    ):
        if os.path.isdir(model_id):
            train_args_file = os.path.join(model_id, "train_args.json")
        else:
            train_args_file = hf_hub_download(
                repo_id=model_id,
                filename="train_args.json",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                token=token,
            )

        with open(train_args_file) as f:
            train_args = TrainArgs(**json.load(f))
        return cls(
            patcher_args=train_args.data.patcher_args,
            tokenizer_args=train_args.data.tokenizer_args,
            distributed_args=train_args.distributed,
        )


@app.command()
def convert_to_transformers(blt_weights_dir: str, output_dir: str):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(blt_weights_dir)
    blt_dir = os.path.join(output_dir, "blt")
    entropy_dir = os.path.join(output_dir, "entropy")
    model.save_pretrained(blt_dir, config={"args": train_cfg.model.model_dump()})
    shutil.copyfile(
        os.path.join(blt_weights_dir, "params.json"),
        os.path.join(blt_dir, "train_args.json"),
    )
    blt_readme_file = os.path.join(blt_dir, "README.md")
    if os.path.exists(blt_readme_file):
        os.remove(blt_readme_file)

    patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
    patcher_args.realtime_patching = False
    print("Loading entropy model and patcher")
    patcher_args.entropy_model_checkpoint_dir = os.path.join(
        blt_weights_dir, "entropy_model"
    )
    state_path = os.path.join(
        patcher_args.entropy_model_checkpoint_dir, "consolidated.pth"
    )
    entropy_model, entropy_model_args = load_entropy_model(
        patcher_args.entropy_model_checkpoint_dir, state_path
    )
    entropy_model.save_pretrained(
        entropy_dir, config={"args": entropy_model_args.model_dump()}
    )
    entropy_readme_file = os.path.join(entropy_dir, "README.md")
    if os.path.exists(entropy_readme_file):
        os.remove(entropy_readme_file)


@app.command()
def load_transformers(
    source: str,
    entropy_repo: str = "facebook/blt-entropy",
    blt_repo: str = "facebook/blt-1b",
    entropy_dir: str | None = None,
    blt_dir: str | None = None,
    prompt: str | None = None,
):
    if source == "local":
        assert entropy_dir is not None
        assert blt_dir is not None
        entropy_model = LMTransformer.from_pretrained(
            entropy_dir, local_files_only=True
        )
        blt_model = ByteLatentTransformer.from_pretrained(
            blt_dir, local_files_only=True
        )
        tok_and_patcher = BltTokenizerAndPatcher.from_pretrained(
            blt_dir, local_files_only=True
        )
        tokenizer = tok_and_patcher.tokenizer_args.build()
        patcher = tok_and_patcher.patcher_args.build()
        print("Loaded all local")
        print(entropy_model)
        print(blt_model)
        print(tok_and_patcher)
    elif source == "hub":
        entropy_model = LMTransformer.from_pretrained(entropy_repo)
        blt_model = ByteLatentTransformer.from_pretrained(blt_repo)
        tok_and_patcher = BltTokenizerAndPatcher.from_pretrained(blt_repo)
        tokenizer = tok_and_patcher.tokenizer_args.build()
        patcher = tok_and_patcher.patcher_args.build()
        print("Loaded all remote")
        print(entropy_model)
        print(blt_model)
        print(tok_and_patcher)
    else:
        raise ValueError(f"Unknown source: {source}")

    if prompt is not None:
        assert isinstance(tokenizer, BltTokenizer)
        # Move args to correct GPU
        param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
            tok_and_patcher.distributed_args.model_dtype
        ]
        blt_model = blt_model.cuda().eval()
        for param in blt_model.parameters():
            param.data = param.data.to(dtype=param_dtype)

        # Enable realtime patching
        patcher.realtime_patching = True
        patcher.entropy_model, _ = to_device(
            entropy_model, tok_and_patcher.patcher_args.patching_device
        )

        # Setup distributed
        distributed_args = DistributedArgs()
        distributed_args.configure_world()
        if not torch.distributed.is_initialized():
            setup_torch_distributed(distributed_args)
        prompts = [prompt]
        outputs = generate_nocache(
            prompts, model=blt_model, tokenizer=tokenizer, patcher=patcher
        )
        text_outputs = [tokenizer.decode(t) for t in outputs]
        for p, t in zip(prompts, text_outputs):
            print(f'Prompt: "{p}"\nCompletion: "{t}"')
            print()


if __name__ == "__main__":
    app()
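
Assuming typer's default kebab-case command naming (which the README's `load-transformers` invocation confirms) and with placeholder paths, the new conversion entry point would be used like this:

```bash
# Export a consolidated BLT checkpoint into HF-style "blt" and "entropy" directories
python -m bytelatent.hf convert-to-transformers path/to/blt_checkpoint hf-output

# Sanity-check the exported weights without touching the hub
python -m bytelatent.hf load-transformers local --entropy-dir hf-output/entropy --blt-dir hf-output/blt
```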

bytelatent/model/blt.py

@@ -4,6 +4,7 @@ from enum import Enum, auto
 from typing import Any, Optional
 
 import torch
+from huggingface_hub import PyTorchModelHubMixin
 from pydantic import model_validator
 from torch import nn
 from torch.nn.attention.flex_attention import create_block_mask

@@ -20,8 +21,6 @@ from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModel
 from bytelatent.model.utils import downsample
 from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
 
-from huggingface_hub import PyTorchModelHubMixin
-
 
 def attention_flops_per_token(n_layers, seq_len, dim, causal):
     # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30
@@ -768,10 +767,23 @@ def compute_hash_embeddings(
     return local_encoder_embeds
 
 
-class ByteLatentTransformer(nn.Module, SequenceModelWithOutput, PyTorchModelHubMixin,
-                            repo_url="https://github.com/facebookresearch/blt",
-                            pipeline_tag="text-generation",
-                            license="other"):
+class ByteLatentTransformer(
+    nn.Module,
+    SequenceModelWithOutput,
+    PyTorchModelHubMixin,
+    repo_url="https://github.com/facebookresearch/blt",
+    paper_url="https://arxiv.org/abs/2412.09871",
+    pipeline_tag="text-generation",
+    license="other",
+    license_name="fair-noncommercial-research-license",
+    license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE",
+    coders={
+        ByteLatentTransformerArgs: (
+            lambda x: {"args": x.model_dump()},
+            lambda data: ByteLatentTransformerArgs(**data),
+        )
+    },
+):
     """
     The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences
     by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,

@@ -861,6 +873,11 @@ class ByteLatentTransformer(nn.Module, SequenceModelWithOutput, PyTorchModelHubM
             )
         )
 
+    def push_to_hub(self, *args, **kwargs):
+        raise ValueError(
+            "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct."
+        )
+
     def get_output_seq_len(self):
         return self.max_seqlen
@@ -82,7 +82,7 @@ def main(
 
     if dry_run:
         return
-    entropy_model = load_entropy_model(
+    entropy_model, _ = load_entropy_model(
         entropy_model_checkpoint_dir,
         entropy_model_state_dict_path,
         device=patching_device,
@@ -34,7 +34,7 @@ def test_entropy_model():
             "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
         },
     )
-    entropy_model = load_entropy_model(
+    entropy_model, _ = load_entropy_model(
         BLT_DATA / "checkpoint_0100000_consolidated",
         os.path.join(
             BLT_DATA,
bytelatent/transformer.py

@@ -4,6 +4,7 @@ import logging
 from typing import Optional, Tuple, Union
 
 import torch
+from huggingface_hub import PyTorchModelHubMixin
 from torch import nn
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.tensor.parallel import (

@@ -60,7 +61,22 @@ class LMTransformerArgs(BaseTransformerArgs):
     sliding_window: int | None = None
 
 
-class LMTransformer(BaseTransformer):
+class LMTransformer(
+    BaseTransformer,
+    PyTorchModelHubMixin,
+    repo_url="https://github.com/facebookresearch/blt",
+    paper_url="https://arxiv.org/abs/2412.09871",
+    pipeline_tag="text-generation",
+    license="other",
+    license_name="fair-noncommercial-research-license",
+    license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE",
+    coders={
+        LMTransformerArgs: (
+            lambda x: {"args": x.model_dump()},
+            lambda data: LMTransformerArgs(**data),
+        )
+    },
+):
     def __init__(self, args: LMTransformerArgs):
         super().__init__(args)
         self.weight_tying = args.weight_tying

@@ -81,6 +97,11 @@ class LMTransformer(BaseTransformer):
         if args.weight_tying:
             self.output.weight = self.embeddings.tok_embeddings.weight
 
+    def push_to_hub(self, *args, **kwargs):
+        raise ValueError(
+            "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct."
+        )
+
     def forward(
         self,
         token_values: torch.Tensor,
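
Both classes use the same `coders` hook from `huggingface_hub`'s `ModelHubMixin`: because the args objects are pydantic models rather than plain dicts, the mixin needs an explicit encoder/decoder pair to round-trip them through `config.json`. A minimal standalone illustration of the pattern, with hypothetical `ToyArgs`/`Toy` names and assuming a recent `huggingface_hub` that supports `coders`:

```python
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from pydantic import BaseModel


class ToyArgs(BaseModel):
    dim: int = 16


class Toy(
    nn.Module,
    PyTorchModelHubMixin,
    # encoder: pydantic -> jsonable dict; decoder: dict -> pydantic
    coders={ToyArgs: (lambda x: x.model_dump(), lambda data: ToyArgs(**data))},
):
    def __init__(self, args: ToyArgs):  # the ToyArgs annotation triggers the coder
        super().__init__()
        self.proj = nn.Linear(args.dim, args.dim)


# save_pretrained() writes config.json via the encoder;
# from_pretrained() rebuilds ToyArgs via the decoder.
Toy(ToyArgs()).save_pretrained("toy-ckpt")
reloaded = Toy.from_pretrained("toy-ckpt")
```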