From 96d51b59d29fdd98248cd87db3592365aad60725 Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez
Date: Thu, 17 Apr 2025 09:38:56 -0700
Subject: [PATCH] Open source weights! (#97)

Summary:
Add code to download weights and a demo for running the model.

Weights at:
- https://huggingface.co/collections/facebook/blt-6801263d4ac1704702a192a6
- https://huggingface.co/facebook/blt
- https://huggingface.co/facebook/blt-1b
- https://huggingface.co/facebook/blt-7b

Test Plan:
---
 .gitignore                 |  2 +-
 README.md                  | 20 +++++++++++++++++-
 bytelatent/data/patcher.py | 15 +++++++++----
 bytelatent/eval.py         |  4 +++-
 bytelatent/generate.py     | 11 +++++-----
 demo.py                    | 43 ++++++++++++++++++++++++++++++++++++++
 download_blt_weights.py    | 15 +++++++++++++
 7 files changed, 98 insertions(+), 12 deletions(-)
 create mode 100644 demo.py
 create mode 100644 download_blt_weights.py

diff --git a/.gitignore b/.gitignore
index cef4d53..10fc469 100644
--- a/.gitignore
+++ b/.gitignore
@@ -169,4 +169,4 @@ internal/
 jobs_parallel-copy/
 wandb/
 *.ipynb
-
+hf-weights/
diff --git a/README.md b/README.md
index 90c2f64..e80b0a0 100644
--- a/README.md
+++ b/README.md
@@ -46,7 +46,25 @@ Once that is done you can activate the environment
 conda activate blt_
 ```
 
-use the provided script to download and prepare data from huggingface (among `fineweb_edu`, `fineweb_edu_10bt`, or `dclm_baseline_1.0`).
+## Downloading HF Model Weights and Generating Text
+
+We have released weights on HF for the [BLT 1B Model](https://huggingface.co/facebook/blt-1b) and the [BLT 7B Model](https://huggingface.co/facebook/blt-7b).
+We are actively working with HF to make BLT available in [Transformers](https://huggingface.co/docs/transformers/en/index) and will update this README when it is.
+In the meantime, you can follow these instructions to load the model weights, initialize a model, and generate text.
+These instructions have been tested on H100 GPUs; on other hardware we can only offer suggestions.
+
+1. Create a HuggingFace account, request access to the weights on the model's HF page, and wait for approval.
+2. Log in with the HuggingFace CLI: `huggingface-cli login`
+3. Download the model weights with `python download_blt_weights.py`, which saves them to the `hf-weights` directory.
+4. Run the generation demo: `python demo.py "A BLT has"`.
+
+The demo generates text, but it is also a good starting point for loading BLT in your own code.
+
+## Downloading Training Data
+
+Note: the following instructions are not well tested in the BLT codebase, since they are based on the lingua codebase, from which we have diverged.
+
+Use the provided script to download and prepare data from huggingface (among `fineweb_edu`, `fineweb_edu_10bt`, or `dclm_baseline_1.0`).
 This command will download the `fineweb_edu` and prepare it for training in the `./data` directory, specifying the amount of memory `terashuf` (the tool used to shuffle samples) will be allocated. By default, the number of chunks (`nchunks`) is 32. If you are running on fewer than 32 GPUs, it is recommended to set `nchunks` to 1 or to match `nchunks` with the number of GPUs (`nchunks` = NGPUs). See [here](https://github.com/facebookresearch/lingua/issues/55#issuecomment-2483643076) for more details.
 
 ```bash
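Taken together, the README section added above boils down to the following shell session (assuming access to the gated weights has already been granted on HF):

```bash
# One-time setup: authenticate so snapshot_download can fetch the gated weights.
huggingface-cli login

# Fetch blt-1b and blt-7b into ./hf-weights (see download_blt_weights.py below).
python download_blt_weights.py

# Generate a completion with the default blt-1b checkpoint.
python demo.py "A BLT has"
```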
diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py
index 3e6fe12..e517b7a 100644
--- a/bytelatent/data/patcher.py
+++ b/bytelatent/data/patcher.py
@@ -476,12 +476,19 @@ class Patcher:
             assert (
                 patcher_args.entropy_model_checkpoint_dir is not None
             ), "Cannot require realtime patching without an entropy model checkpoint"
+            maybe_consolidated = os.path.join(
+                patcher_args.entropy_model_checkpoint_dir,
+                "consolidated/consolidated.pth",
+            )
+            if os.path.exists(maybe_consolidated):
+                state_path = maybe_consolidated
+            else:
+                state_path = os.path.join(
+                    patcher_args.entropy_model_checkpoint_dir, "consolidated.pth"
+                )
             entropy_model = load_entropy_model(
                 patcher_args.entropy_model_checkpoint_dir,
-                os.path.join(
-                    patcher_args.entropy_model_checkpoint_dir,
-                    "consolidated/consolidated.pth",
-                ),
+                state_path,
             )
             entropy_model, _ = to_device(entropy_model, patcher_args.patching_device)
             self.entropy_model = entropy_model
diff --git a/bytelatent/eval.py b/bytelatent/eval.py
index a917f48..e69403c 100644
--- a/bytelatent/eval.py
+++ b/bytelatent/eval.py
@@ -206,7 +206,9 @@ def eval_ppl_on_path(
                 pred = model(x, patch_lengths=patch_lengths)
             else:
                 pred = model(x)
-            loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum", ignore_index=0)
+            loss = F.cross_entropy(
+                pred.flatten(0, 1), y.flatten(0, 1), reduction="sum", ignore_index=0
+            )
             total_loss += loss.item()
         else:
             raise NotImplementedError()
diff --git a/bytelatent/generate.py b/bytelatent/generate.py
index b0280dd..97434dc 100644
--- a/bytelatent/generate.py
+++ b/bytelatent/generate.py
@@ -25,7 +25,11 @@ from bytelatent.checkpoint import (
 )
 from bytelatent.config_parser import parse_args_to_pydantic_model
 from bytelatent.data.file_util import get_fs
-from bytelatent.distributed import get_global_rank, setup_torch_distributed, DistributedArgs
+from bytelatent.distributed import (
+    DistributedArgs,
+    get_global_rank,
+    setup_torch_distributed,
+)
 from bytelatent.model.blt import ByteLatentTransformer
 from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
 from bytelatent.transformer import LMTransformer
@@ -388,10 +392,7 @@ class PackedCausalTransformerGenerator:
         return generation, loglikelihood, greedy
 
 
-def load_consolidated_model_and_tokenizer(
-    consolidated_path,
-    init_distributed=False
-):
+def load_consolidated_model_and_tokenizer(consolidated_path, init_distributed=False):
     if init_distributed:
         distributed_args = DistributedArgs()
         distributed_args.configure_world()
diff --git a/demo.py b/demo.py
new file mode 100644
index 0000000..f79c896
--- /dev/null
+++ b/demo.py
@@ -0,0 +1,43 @@
+import os
+
+import torch
+import typer
+
+from bytelatent.distributed import DistributedArgs, setup_torch_distributed
+from bytelatent.generate import load_consolidated_model_and_tokenizer
+from bytelatent.generate_blt import generate_nocache
+from bytelatent.model.blt import ByteLatentTransformer
+from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
+
+
+def main(prompt: str, model_name: str = "blt-1b"):
+    distributed_args = DistributedArgs()
+    distributed_args.configure_world()
+    if not torch.distributed.is_initialized():
+        setup_torch_distributed(distributed_args)
+    checkpoint_path = os.path.join("hf-weights", model_name)
+    print(f"Loading BLT model: {model_name}")
+    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
+        checkpoint_path,
+    )
+    assert isinstance(model, ByteLatentTransformer)
+    assert isinstance(tokenizer, BltTokenizer)
+    patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
+    patcher_args.realtime_patching = True
+    print("Loading entropy model and patcher")
+    patcher_args.entropy_model_checkpoint_dir = os.path.join(
+        checkpoint_path, "entropy_model"
+    )
+    patcher = patcher_args.build()
+    prompts = [prompt]
+    outputs = generate_nocache(
+        prompts, model=model, tokenizer=tokenizer, patcher=patcher
+    )
+    text_outputs = [tokenizer.decode(t) for t in outputs]
+    for p, t in zip(prompts, text_outputs):
+        print(f'Prompt: "{p}" Completion: "{t}"')
+        print()
+
+
+if __name__ == "__main__":
+    typer.run(main)
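The demo.py added above is the reference for loading BLT in your own code. As a minimal sketch of just the loading path, condensed from demo.py: the `load_blt` helper below is illustrative, not part of the diff, and it assumes the `hf-weights` layout produced by `download_blt_weights.py`:

```python
import os

import torch

from bytelatent.distributed import DistributedArgs, setup_torch_distributed
from bytelatent.generate import load_consolidated_model_and_tokenizer
from bytelatent.generate_blt import generate_nocache


def load_blt(model_name: str = "blt-1b"):
    # demo.py performs this one-time distributed setup before loading weights.
    distributed_args = DistributedArgs()
    distributed_args.configure_world()
    if not torch.distributed.is_initialized():
        setup_torch_distributed(distributed_args)

    checkpoint_path = os.path.join("hf-weights", model_name)
    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(checkpoint_path)

    # Real-time patching uses the entropy model bundled with the checkpoint.
    patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
    patcher_args.realtime_patching = True
    patcher_args.entropy_model_checkpoint_dir = os.path.join(
        checkpoint_path, "entropy_model"
    )
    return model, tokenizer, patcher_args.build()


if __name__ == "__main__":
    model, tokenizer, patcher = load_blt()
    outputs = generate_nocache(
        ["A BLT has"], model=model, tokenizer=tokenizer, patcher=patcher
    )
    print(tokenizer.decode(outputs[0]))
```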
diff --git a/download_blt_weights.py b/download_blt_weights.py
new file mode 100644
index 0000000..0035494
--- /dev/null
+++ b/download_blt_weights.py
@@ -0,0 +1,15 @@
+import os
+
+import typer
+from huggingface_hub import snapshot_download
+
+
+def main(models: list[str] = ["blt-1b", "blt-7b"]):
+    if not os.path.exists("hf-weights"):
+        os.makedirs("hf-weights")
+    for model in models:
+        snapshot_download(f"facebook/{model}", local_dir=f"hf-weights/{model}")
+
+
+if __name__ == "__main__":
+    typer.run(main)
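One usage note on the demo CLI: download_blt_weights.py fetches both released checkpoints by default, and typer derives command-line flags from `main`'s parameters, so `model_name` should be exposed as `--model-name` (assuming typer's default flag naming). Running the 7B checkpoint instead of the default 1B one would then be:

```bash
python demo.py "A BLT has" --model-name blt-7b
```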