Open source weights! (#97)

Summary:

Add code to download weights and a demo script for running the model.

Weights at:
- https://huggingface.co/collections/facebook/blt-6801263d4ac1704702a192a6
- https://huggingface.co/facebook/blt
- https://huggingface.co/facebook/blt-1b
- https://huggingface.co/facebook/blt-7b

Test Plan:

Pedro Rodriguez committed 2025-04-17 09:38:56 -07:00 (commit 96d51b59d2, parent e299427ae4)
7 changed files with 98 additions and 12 deletions

.gitignore

@@ -169,4 +169,4 @@ internal/
jobs_parallel-copy/
wandb/
*.ipynb
hf-weights/

README.md

@@ -46,7 +46,25 @@ Once that is done you can activate the environment
```
conda activate blt_<date>
```
## Downloading HF Model Weights and Generating Text
We have released weights on HF for the [BLT 1B Model](https://huggingface.co/facebook/blt-1b) and [BLT 7B Model](https://huggingface.co/facebook/blt-7b).
We are actively working with HF to make BLT available in [Transformers](https://huggingface.co/docs/transformers/en/index) and will update this when it is.
In the meantime, you can follow these instructions to load model weights, initialize a model, and generate text.
These instructions have been tested on H100 GPUs, but we can only offer suggestions on running on other hardware.
1. Create a HuggingFace account, request access to the weights on the model pages linked above, and wait for approval.
2. Log in with the HuggingFace CLI: `huggingface-cli login`
3. Download the model weights with `python download_blt_weights.py`, which saves them to `hf-weights/`.
4. Run the generate demo: `python demo.py "A BLT has"`.
The demo generates text, but is also a good starting point for loading BLT in your own code.
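Condensed from the new `demo.py` (shown in full in this diff), loading BLT in your own code looks roughly like the sketch below. It assumes the default `hf-weights/blt-1b` download location and skips the `torch.distributed` setup that the demo performs first:

```python
# Minimal sketch distilled from demo.py; not a drop-in replacement for it.
# Assumes weights were downloaded to hf-weights/blt-1b and that distributed
# setup (done in demo.py) has already been handled.
import os

from bytelatent.generate import load_consolidated_model_and_tokenizer
from bytelatent.generate_blt import generate_nocache

checkpoint_path = os.path.join("hf-weights", "blt-1b")
model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(checkpoint_path)

# Real-time patching uses the entropy model shipped alongside the released weights.
patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
patcher_args.realtime_patching = True
patcher_args.entropy_model_checkpoint_dir = os.path.join(checkpoint_path, "entropy_model")
patcher = patcher_args.build()

outputs = generate_nocache(
    ["A BLT has"], model=model, tokenizer=tokenizer, patcher=patcher
)
print(tokenizer.decode(outputs[0]))
```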
## Downloading Training Data
Note: the following instructions are not well tested in the BLT codebase, since they are inherited from the lingua codebase, which we have diverged from.
Use the provided script to download and prepare data from huggingface (among `fineweb_edu`, `fineweb_edu_10bt`, or `dclm_baseline_1.0`).
This command will download the `fineweb_edu` dataset and prepare it for training in the `./data` directory, specifying the amount of memory `terashuf` (the tool used to shuffle samples) will be allocated. By default, the number of chunks (`nchunks`) is 32. If you are running on fewer than 32 GPUs, it is recommended to set `nchunks` to 1 or to match `nchunks` with the number of GPUs (`nchunks` = NGPUs). See [here](https://github.com/facebookresearch/lingua/issues/55#issuecomment-2483643076) for more details.

bytelatent/data/patcher.py

@@ -476,12 +476,19 @@ class Patcher:
            assert (
                patcher_args.entropy_model_checkpoint_dir is not None
            ), "Cannot require realtime patching without an entropy model checkpoint"
            maybe_consolidated = os.path.join(
                patcher_args.entropy_model_checkpoint_dir,
                "consolidated/consolidated.pth",
            )
            if os.path.exists(maybe_consolidated):
                state_path = maybe_consolidated
            else:
                state_path = os.path.join(
                    patcher_args.entropy_model_checkpoint_dir, "consolidated.pth"
                )
            entropy_model = load_entropy_model(
                patcher_args.entropy_model_checkpoint_dir,
                state_path,
            )
            entropy_model, _ = to_device(entropy_model, patcher_args.patching_device)
            self.entropy_model = entropy_model


@@ -206,7 +206,9 @@ def eval_ppl_on_path(
                pred = model(x, patch_lengths=patch_lengths)
            else:
                pred = model(x)
            loss = F.cross_entropy(
                pred.flatten(0, 1), y.flatten(0, 1), reduction="sum", ignore_index=0
            )
            total_loss += loss.item()
        else:
            raise NotImplementedError()

bytelatent/generate.py

@@ -25,7 +25,11 @@ from bytelatent.checkpoint import (
)
from bytelatent.config_parser import parse_args_to_pydantic_model
from bytelatent.data.file_util import get_fs
from bytelatent.distributed import (
    DistributedArgs,
    get_global_rank,
    setup_torch_distributed,
)
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
from bytelatent.transformer import LMTransformer
@@ -388,10 +392,7 @@ class PackedCausalTransformerGenerator:
        return generation, loglikelihood, greedy


def load_consolidated_model_and_tokenizer(consolidated_path, init_distributed=False):
    if init_distributed:
        distributed_args = DistributedArgs()
        distributed_args.configure_world()

demo.py (new file)

@@ -0,0 +1,43 @@
import os

import torch
import typer

from bytelatent.distributed import DistributedArgs, setup_torch_distributed
from bytelatent.generate import load_consolidated_model_and_tokenizer
from bytelatent.generate_blt import generate_nocache
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer


def main(prompt: str, model_name: str = "blt-1b"):
    distributed_args = DistributedArgs()
    distributed_args.configure_world()
    if not torch.distributed.is_initialized():
        setup_torch_distributed(distributed_args)
    checkpoint_path = os.path.join("hf-weights", model_name)
    print(f"Loading BLT model: {model_name}")
    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
        checkpoint_path,
    )
    assert isinstance(model, ByteLatentTransformer)
    assert isinstance(tokenizer, BltTokenizer)
    patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
    patcher_args.realtime_patching = True
    print("Loading entropy model and patcher")
    patcher_args.entropy_model_checkpoint_dir = os.path.join(
        checkpoint_path, "entropy_model"
    )
    patcher = patcher_args.build()
    prompts = [prompt]
    outputs = generate_nocache(
        prompts, model=model, tokenizer=tokenizer, patcher=patcher
    )
    text_outputs = [tokenizer.decode(t) for t in outputs]
    for p, t in zip(prompts, text_outputs):
        print(f'Prompt: "{p}" Completion: "{t}"')
        print()


if __name__ == "__main__":
    typer.run(main)
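Usage note: because `main` is wrapped with typer, `prompt` is a positional argument and `model_name` becomes a `--model-name` option, so `python demo.py "A BLT has" --model-name blt-7b` should run the 7B model, assuming its weights have been downloaded.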

download_blt_weights.py (new file)

@@ -0,0 +1,15 @@
import os

import typer
from huggingface_hub import snapshot_download


def main(models: list[str] = ["blt-1b", "blt-7b"]):
    if not os.path.exists("hf-weights"):
        os.makedirs("hf-weights")
    for model in models:
        snapshot_download(f"facebook/{model}", local_dir=f"hf-weights/{model}")


if __name__ == "__main__":
    typer.run(main)
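Usage note: typer exposes the `models` list as a repeatable option, so a bare `python download_blt_weights.py` fetches both released checkpoints into `hf-weights/`, while `python download_blt_weights.py --models blt-1b` should fetch only the 1B model (repeat `--models` to select several).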