diff --git a/README.md b/README.md
index e80b0a0..fc34531 100644
--- a/README.md
+++ b/README.md
@@ -55,8 +55,37 @@ These instructions have been tested on H100 GPUs, but we can only offer suggesti
 1. On the model weights HF page, create a HuggingFace account, request access to weights, and wait for approval.
 2. On the huggingface cli, login: `huggingface-cli login`
-3. Download the model weights with: `python download_blt_weights.py`, which will load to `hf-weights`
-4. Run the generate demo: `python demo.py "A BLT has"`.
+
+From here there are two options: (1) loading the weights via the HF Hub for use in your own code, or (2) loading the weights for running the BLT train script.
+
+## Load Weights via HF Hub
+
+In your terminal:
+
+```bash
+python -m bytelatent.hf load-transformers --entropy-repo facebook/blt-entropy --blt-repo facebook/blt-1b hub --prompt "My test prompt"
+```
+
+In your own code:
+
+```python
+from bytelatent.transformer import LMTransformer
+from bytelatent.model.blt import ByteLatentTransformer
+from bytelatent.hf import BltTokenizerAndPatcher
+
+entropy_repo = "facebook/blt-entropy"
+blt_repo = "facebook/blt-1b"
+entropy_model = LMTransformer.from_pretrained(entropy_repo)
+blt_model = ByteLatentTransformer.from_pretrained(blt_repo)
+tok_and_patcher = BltTokenizerAndPatcher.from_pretrained(blt_repo)
+tokenizer = tok_and_patcher.tokenizer_args.build()
+patcher = tok_and_patcher.patcher_args.build()
+```
+
+## Load Weights for Running BLT Train Script
+
+1. Download the model weights with: `python download_blt_weights.py`, which will download them to `hf-weights`.
+2. Run the generate demo: `python demo.py "A BLT has"`.
 
 The demo generates text, but is also a good starting point for loading BLT in your own code.
diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py
index 19b1b33..d947040 100644
--- a/bytelatent/base_transformer.py
+++ b/bytelatent/base_transformer.py
@@ -34,7 +34,7 @@ else:
     flex_attention_comp = None
 
 
-class InitStdFactor(Enum):
+class InitStdFactor(str, Enum):
     DISABLED = "disabled"  # Init std is divided by 1.0
     GLOBAL_DEPTH = "global_depth"  # Init std is divided by sqrt(2*n_layers)
     CURRENT_DEPTH = "current_depth"  # Init std is divided by sqrt(2*depth)
diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py
index e517b7a..93be370 100644
--- a/bytelatent/data/patcher.py
+++ b/bytelatent/data/patcher.py
@@ -486,7 +486,7 @@ class Patcher:
             state_path = os.path.join(
                 patcher_args.entropy_model_checkpoint_dir, "consolidated.pth"
             )
-            entropy_model = load_entropy_model(
+            entropy_model, _ = load_entropy_model(
                 patcher_args.entropy_model_checkpoint_dir,
                 state_path,
             )
diff --git a/bytelatent/entropy_model.py b/bytelatent/entropy_model.py
index 0e11a60..51973e2 100644
--- a/bytelatent/entropy_model.py
+++ b/bytelatent/entropy_model.py
@@ -19,19 +19,18 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
         logger.warning(
             "Update checkpoint to load attn and sliding window args from checkpoint"
         )
-    entropy_model = LMTransformer(
-        LMTransformerArgs(
-            dim=model_params["dim"],
-            n_layers=model_params["n_layers"],
-            n_heads=model_params["n_heads"],
-            max_seqlen=model_params["max_seqlen"],
-            ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
-            vocab_size=model_params["vocab_size"],
-            attn_bias_type="local_block_causal",
-            attn_impl="xformers",
-            sliding_window=512,
-        )
+    entropy_model_args = LMTransformerArgs(
+        dim=model_params["dim"],
+        n_layers=model_params["n_layers"],
+        n_heads=model_params["n_heads"],
+        max_seqlen=model_params["max_seqlen"],
+        ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
+        vocab_size=model_params["vocab_size"],
+        attn_bias_type="local_block_causal",
+        attn_impl="xformers",
+        sliding_window=512,
     )
+    entropy_model = LMTransformer(entropy_model_args)
 
     entropy_model.load_state_dict(
         torch.load(state_dict_path, map_location=device)["model"], strict=False
@@ -41,4 +40,4 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
     # no grads for the model:
     for param in entropy_model.parameters():
         param.requires_grad = False
-    return entropy_model
+    return entropy_model, entropy_model_args
diff --git a/bytelatent/hf.py b/bytelatent/hf.py
new file mode 100644
index 0000000..f2fd038
--- /dev/null
+++ b/bytelatent/hf.py
@@ -0,0 +1,199 @@
+import json
+import os
+import shutil
+from pathlib import Path
+from typing import Dict, Optional, Union
+
+import torch
+import typer
+from huggingface_hub import hf_hub_download
+from huggingface_hub.hub_mixin import ModelHubMixin
+
+from bytelatent.args import TrainArgs
+from bytelatent.data.patcher import PatcherArgs, to_device
+from bytelatent.distributed import DistributedArgs, setup_torch_distributed
+from bytelatent.entropy_model import load_entropy_model
+from bytelatent.generate import load_consolidated_model_and_tokenizer
+from bytelatent.generate_blt import generate_nocache
+from bytelatent.model.blt import ByteLatentTransformer
+from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
+from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
+from bytelatent.transformer import LMTransformer
+
+app = typer.Typer()
+
+
+class BltTokenizerAndPatcher(ModelHubMixin):
+    def __init__(
+        self,
+        *,
+        patcher_args: PatcherArgs,
+        tokenizer_args: TokenizerArgs,
+        distributed_args: DistributedArgs,
+    ):
+        self.patcher_args = patcher_args
+        self.tokenizer_args = tokenizer_args
+        self.distributed_args = distributed_args
+
+    def push_to_hub(self, *args, **kwargs):
+        raise ValueError(
+            "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct."
+ ) + + def save_pretrained(self, *args, **kwargs): + raise ValueError( + "Tokenizer and Patcher are saved by BLT, this class is just for loading" + ) + + def _save_pretrained(self, *args, **kwargs): + raise ValueError( + "Tokenizer and Patcher are saved by BLT, this class is just for loading" + ) + + @classmethod + def _from_pretrained( + cls, + *, + model_id: str, + revision: Optional[str], + cache_dir: Optional[Union[str, Path]], + force_download: bool, + proxies: Optional[Dict], + resume_download: Optional[bool], + local_files_only: bool, + token: Optional[Union[str, bool]], + **model_kwargs, + ): + if os.path.isdir(model_id): + train_args_file = os.path.join(model_id, "train_args.json") + else: + train_args_file = hf_hub_download( + repo_id=model_id, + filename="train_args.json", + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + ) + + with open(train_args_file) as f: + train_args = TrainArgs(**json.load(f)) + return cls( + patcher_args=train_args.data.patcher_args, + tokenizer_args=train_args.data.tokenizer_args, + distributed_args=train_args.distributed, + ) + + +@app.command() +def convert_to_transformers(blt_weights_dir: str, output_dir: str): + if not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(blt_weights_dir) + blt_dir = os.path.join(output_dir, "blt") + entropy_dir = os.path.join(output_dir, "entropy") + model.save_pretrained(blt_dir, config={"args": train_cfg.model.model_dump()}) + shutil.copyfile( + os.path.join(blt_weights_dir, "params.json"), + os.path.join(blt_dir, "train_args.json"), + ) + blt_readme_file = os.path.join(blt_dir, "README.md") + if os.path.exists(blt_readme_file): + os.remove(blt_readme_file) + + patcher_args = train_cfg.data.patcher_args.model_copy(deep=True) + patcher_args.realtime_patching = False + print("Loading entropy model and patcher") + patcher_args.entropy_model_checkpoint_dir = os.path.join( + blt_weights_dir, "entropy_model" + ) + state_path = os.path.join( + patcher_args.entropy_model_checkpoint_dir, "consolidated.pth" + ) + entropy_model, entropy_model_args = load_entropy_model( + patcher_args.entropy_model_checkpoint_dir, state_path + ) + entropy_model.save_pretrained( + entropy_dir, config={"args": entropy_model_args.model_dump()} + ) + entropy_readme_file = os.path.join(entropy_dir, "README.md") + if os.path.exists(entropy_readme_file): + os.remove(entropy_readme_file) + + +@app.command() +def load_transformers( + source: str, + entropy_repo: str = "facebook/blt-entropy", + blt_repo: str = "facebook/blt-1b", + entropy_dir: str | None = None, + blt_dir: str | None = None, + prompt: str | None = None, +): + if source == "local": + assert entropy_dir is not None + assert blt_dir is not None + entropy_model = LMTransformer.from_pretrained( + entropy_dir, local_files_only=True + ) + blt_model = ByteLatentTransformer.from_pretrained( + blt_dir, local_files_only=True + ) + tok_and_patcher = BltTokenizerAndPatcher.from_pretrained( + blt_dir, local_files_only=True + ) + tokenizer = tok_and_patcher.tokenizer_args.build() + patcher = tok_and_patcher.patcher_args.build() + print("Loaded all local") + print(entropy_model) + print(blt_model) + print(tok_and_patcher) + elif source == "hub": + entropy_model = LMTransformer.from_pretrained(entropy_repo) + blt_model = ByteLatentTransformer.from_pretrained(blt_repo) + 
tok_and_patcher = BltTokenizerAndPatcher.from_pretrained(blt_repo) + tokenizer = tok_and_patcher.tokenizer_args.build() + patcher = tok_and_patcher.patcher_args.build() + print("Loaded all remote") + print(entropy_model) + print(blt_model) + print(tok_and_patcher) + else: + raise ValueError(f"Unknown source: {source}") + + if prompt is not None: + assert isinstance(tokenizer, BltTokenizer) + # Move args to correct GPU + param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[ + tok_and_patcher.distributed_args.model_dtype + ] + blt_model = blt_model.cuda().eval() + for param in blt_model.parameters(): + param.data = param.data.to(dtype=param_dtype) + + # Enable realtime patching + patcher.realtime_patching = True + patcher.entropy_model, _ = to_device( + entropy_model, tok_and_patcher.patcher_args.patching_device + ) + + # Setup distributed + distributed_args = DistributedArgs() + distributed_args.configure_world() + if not torch.distributed.is_initialized(): + setup_torch_distributed(distributed_args) + prompts = [prompt] + outputs = generate_nocache( + prompts, model=blt_model, tokenizer=tokenizer, patcher=patcher + ) + text_outputs = [tokenizer.decode(t) for t in outputs] + for p, t in zip(prompts, text_outputs): + print(f'Prompt: "{p}"\nCompletion: "{t}"') + print() + + +if __name__ == "__main__": + app() diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py index ccfe395..26934bb 100644 --- a/bytelatent/model/blt.py +++ b/bytelatent/model/blt.py @@ -4,6 +4,7 @@ from enum import Enum, auto from typing import Any, Optional import torch +from huggingface_hub import PyTorchModelHubMixin from pydantic import model_validator from torch import nn from torch.nn.attention.flex_attention import create_block_mask @@ -20,8 +21,6 @@ from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModel from bytelatent.model.utils import downsample from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID -from huggingface_hub import PyTorchModelHubMixin - def attention_flops_per_token(n_layers, seq_len, dim, causal): # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30 @@ -768,10 +767,23 @@ def compute_hash_embeddings( return local_encoder_embeds -class ByteLatentTransformer(nn.Module, SequenceModelWithOutput, PyTorchModelHubMixin, - repo_url="https://github.com/facebookresearch/blt", - pipeline_tag="text-generation", - license="other"): +class ByteLatentTransformer( + nn.Module, + SequenceModelWithOutput, + PyTorchModelHubMixin, + repo_url="https://github.com/facebookresearch/blt", + paper_url="https://arxiv.org/abs/2412.09871", + pipeline_tag="text-generation", + license="other", + license_name="fair-noncommercial-research-license", + license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE", + coders={ + ByteLatentTransformerArgs: ( + lambda x: {"args": x.model_dump()}, + lambda data: ByteLatentTransformerArgs(**data), + ) + }, +): """ The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences by dynamically segmenting them into patches. 
It uses a combination of local encoders, global transformers, @@ -861,6 +873,11 @@ class ByteLatentTransformer(nn.Module, SequenceModelWithOutput, PyTorchModelHubM ) ) + def push_to_hub(self, *args, **kwargs): + raise ValueError( + "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct." + ) + def get_output_seq_len(self): return self.max_seqlen diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py index 31a4802..8001103 100644 --- a/bytelatent/preprocess/preprocess_entropies.py +++ b/bytelatent/preprocess/preprocess_entropies.py @@ -82,7 +82,7 @@ def main( if dry_run: return - entropy_model = load_entropy_model( + entropy_model, _ = load_entropy_model( entropy_model_checkpoint_dir, entropy_model_state_dict_path, device=patching_device, diff --git a/bytelatent/test_entropy_model.py b/bytelatent/test_entropy_model.py index 8623eb1..c7a26f6 100644 --- a/bytelatent/test_entropy_model.py +++ b/bytelatent/test_entropy_model.py @@ -34,7 +34,7 @@ def test_entropy_model(): "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model" }, ) - entropy_model = load_entropy_model( + entropy_model, _ = load_entropy_model( BLT_DATA / "checkpoint_0100000_consolidated", os.path.join( BLT_DATA, diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py index 906a54b..32d63be 100644 --- a/bytelatent/transformer.py +++ b/bytelatent/transformer.py @@ -4,6 +4,7 @@ import logging from typing import Optional, Tuple, Union import torch +from huggingface_hub import PyTorchModelHubMixin from torch import nn from torch.distributed._tensor import Replicate, Shard from torch.distributed.tensor.parallel import ( @@ -60,7 +61,22 @@ class LMTransformerArgs(BaseTransformerArgs): sliding_window: int | None = None -class LMTransformer(BaseTransformer): +class LMTransformer( + BaseTransformer, + PyTorchModelHubMixin, + repo_url="https://github.com/facebookresearch/blt", + paper_url="https://arxiv.org/abs/2412.09871", + pipeline_tag="text-generation", + license="other", + license_name="fair-noncommercial-research-license", + license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE", + coders={ + LMTransformerArgs: ( + lambda x: {"args": x.model_dump()}, + lambda data: LMTransformerArgs(**data), + ) + }, +): def __init__(self, args: LMTransformerArgs): super().__init__(args) self.weight_tying = args.weight_tying @@ -81,6 +97,11 @@ class LMTransformer(BaseTransformer): if args.weight_tying: self.output.weight = self.embeddings.tok_embeddings.weight + def push_to_hub(self, *args, **kwargs): + raise ValueError( + "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct." + ) + def forward( self, token_values: torch.Tensor,
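For readers who want to drive the new Hub loading path from their own code rather than through the `load-transformers` CLI, the steps below mirror what `bytelatent.hf.load_transformers` does when `--prompt` is supplied. This is a minimal sketch, not part of the patch: it assumes a single CUDA device is available and uses only loaders and helpers that the diff itself introduces or imports (`from_pretrained`, the tokenizer/patcher builders, `to_device`, and `generate_nocache`).

```python
import torch

from bytelatent.data.patcher import to_device
from bytelatent.distributed import DistributedArgs, setup_torch_distributed
from bytelatent.generate_blt import generate_nocache
from bytelatent.hf import BltTokenizerAndPatcher
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.transformer import LMTransformer

# Pull both checkpoints plus the tokenizer/patcher config from the Hub.
entropy_model = LMTransformer.from_pretrained("facebook/blt-entropy")
blt_model = ByteLatentTransformer.from_pretrained("facebook/blt-1b")
tok_and_patcher = BltTokenizerAndPatcher.from_pretrained("facebook/blt-1b")
tokenizer = tok_and_patcher.tokenizer_args.build()
patcher = tok_and_patcher.patcher_args.build()

# Move the BLT model to the GPU in the dtype recorded in the training config.
param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
    tok_and_patcher.distributed_args.model_dtype
]
blt_model = blt_model.cuda().eval()
for param in blt_model.parameters():
    param.data = param.data.to(dtype=param_dtype)

# Real-time patching needs the entropy model on the patching device.
patcher.realtime_patching = True
patcher.entropy_model, _ = to_device(
    entropy_model, tok_and_patcher.patcher_args.patching_device
)

# generate_nocache expects torch.distributed to be initialized.
distributed_args = DistributedArgs()
distributed_args.configure_world()
if not torch.distributed.is_initialized():
    setup_torch_distributed(distributed_args)

prompts = ["A BLT has"]
outputs = generate_nocache(prompts, model=blt_model, tokenizer=tokenizer, patcher=patcher)
for prompt, completion in zip(prompts, (tokenizer.decode(t) for t in outputs)):
    print(f'Prompt: "{prompt}"\nCompletion: "{completion}"')
```

The README command added above (`python -m bytelatent.hf load-transformers ... hub --prompt "My test prompt"`) exercises exactly this path from the command line.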