mirror of
https://github.com/facebookresearch/blt.git
synced 2025-09-01 10:09:06 +00:00
Open source weights! (#97)
Summary: Add code to download weights and demo code for running model. Weights at: - https://huggingface.co/collections/facebook/blt-6801263d4ac1704702a192a6 - https://huggingface.co/facebook/blt - https://huggingface.co/facebook/blt-1b - https://huggingface.co/facebook/blt-7b Test Plan:
This commit is contained in:
parent
e299427ae4
commit
96d51b59d2
7 changed files with 98 additions and 12 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -169,4 +169,4 @@ internal/
|
|||
jobs_parallel-copy/
|
||||
wandb/
|
||||
*.ipynb
|
||||
|
||||
hf-weights/
|
||||
|
|
20
README.md
20
README.md
|
@ -46,7 +46,25 @@ Once that is done you can activate the environment
|
|||
conda activate blt_<date>
|
||||
```
|
||||
|
||||
use the provided script to download and prepare data from huggingface (among `fineweb_edu`, `fineweb_edu_10bt`, or `dclm_baseline_1.0`).
|
||||
## Downloading HF Model Weights and Generating Text
|
||||
|
||||
We have released weights on HF for the [BLT 1B Model](https://huggingface.co/facebook/blt-1b) and [BLT 7B Model](https://huggingface.co/facebook/blt-7b).
|
||||
We are actively working with HF to make BLT available in [Transformers](https://huggingface.co/docs/transformers/en/index) and will update this when it is.
|
||||
In the meantime, you can follow these instructions to load model weights, initialize a model, and generate text.
|
||||
These instructions have been tested on H100 GPUs, but we can only offer suggestions on running on other hardware.
|
||||
|
||||
1. On the model weights HF page, create a HuggingFace account, request access to weights, and wait for approval.
|
||||
2. On the huggingface cli, login: `huggingface-cli login`
|
||||
3. Download the model weights with: `python download_blt_weights.py`, which will load to `hf-weights`
|
||||
4. Run the generate demo: `python demo.py "A BLT has"`.
|
||||
|
||||
The demo generates text, but is also a good starting point for loading BLT in your own code.
|
||||
|
||||
## Downloading Training Data
|
||||
|
||||
Note: The following instructions are not well tested in the BLT code as it is based on the lingua code, which we have diverged from.
|
||||
|
||||
Use the provided script to download and prepare data from huggingface (among `fineweb_edu`, `fineweb_edu_10bt`, or `dclm_baseline_1.0`).
|
||||
This command will download the `fineweb_edu` and prepare it for training in the `./data` directory, specifying the amount of memory `terashuf` (the tool used to shuffle samples) will be allocated. By default, the number of chunks (`nchunks`) is 32. If you are running on fewer than 32 GPUs, it is recommended to set `nchunks` to 1 or to match `nchunks` with the number of GPUs (`nchunks` = NGPUs). See [here](https://github.com/facebookresearch/lingua/issues/55#issuecomment-2483643076) for more details.
|
||||
|
||||
```bash
|
||||
|
|
|
@ -476,12 +476,19 @@ class Patcher:
|
|||
assert (
|
||||
patcher_args.entropy_model_checkpoint_dir is not None
|
||||
), "Cannot require realtime patching without an entropy model checkpoint"
|
||||
maybe_consolidated = os.path.join(
|
||||
patcher_args.entropy_model_checkpoint_dir,
|
||||
"consolidated/consolidated.pth",
|
||||
)
|
||||
if os.path.exists(maybe_consolidated):
|
||||
state_path = maybe_consolidated
|
||||
else:
|
||||
state_path = os.path.join(
|
||||
patcher_args.entropy_model_checkpoint_dir, "consolidated.pth"
|
||||
)
|
||||
entropy_model = load_entropy_model(
|
||||
patcher_args.entropy_model_checkpoint_dir,
|
||||
os.path.join(
|
||||
patcher_args.entropy_model_checkpoint_dir,
|
||||
"consolidated/consolidated.pth",
|
||||
),
|
||||
state_path,
|
||||
)
|
||||
entropy_model, _ = to_device(entropy_model, patcher_args.patching_device)
|
||||
self.entropy_model = entropy_model
|
||||
|
|
|
@ -206,7 +206,9 @@ def eval_ppl_on_path(
|
|||
pred = model(x, patch_lengths=patch_lengths)
|
||||
else:
|
||||
pred = model(x)
|
||||
loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum", ignore_index=0)
|
||||
loss = F.cross_entropy(
|
||||
pred.flatten(0, 1), y.flatten(0, 1), reduction="sum", ignore_index=0
|
||||
)
|
||||
total_loss += loss.item()
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -25,7 +25,11 @@ from bytelatent.checkpoint import (
|
|||
)
|
||||
from bytelatent.config_parser import parse_args_to_pydantic_model
|
||||
from bytelatent.data.file_util import get_fs
|
||||
from bytelatent.distributed import get_global_rank, setup_torch_distributed, DistributedArgs
|
||||
from bytelatent.distributed import (
|
||||
DistributedArgs,
|
||||
get_global_rank,
|
||||
setup_torch_distributed,
|
||||
)
|
||||
from bytelatent.model.blt import ByteLatentTransformer
|
||||
from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
|
||||
from bytelatent.transformer import LMTransformer
|
||||
|
@ -388,10 +392,7 @@ class PackedCausalTransformerGenerator:
|
|||
return generation, loglikelihood, greedy
|
||||
|
||||
|
||||
def load_consolidated_model_and_tokenizer(
|
||||
consolidated_path,
|
||||
init_distributed=False
|
||||
):
|
||||
def load_consolidated_model_and_tokenizer(consolidated_path, init_distributed=False):
|
||||
if init_distributed:
|
||||
distributed_args = DistributedArgs()
|
||||
distributed_args.configure_world()
|
||||
|
|
43
demo.py
Normal file
43
demo.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
import os
|
||||
|
||||
import torch
|
||||
import typer
|
||||
|
||||
from bytelatent.distributed import DistributedArgs, setup_torch_distributed
|
||||
from bytelatent.generate import load_consolidated_model_and_tokenizer
|
||||
from bytelatent.generate_blt import generate_nocache
|
||||
from bytelatent.model.blt import ByteLatentTransformer
|
||||
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
|
||||
|
||||
|
||||
def main(prompt: str, model_name: str = "blt-1b"):
|
||||
distributed_args = DistributedArgs()
|
||||
distributed_args.configure_world()
|
||||
if not torch.distributed.is_initialized():
|
||||
setup_torch_distributed(distributed_args)
|
||||
checkpoint_path = os.path.join("hf-weights", model_name)
|
||||
print(f"Loading BLT model: {model_name}")
|
||||
model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
|
||||
checkpoint_path,
|
||||
)
|
||||
assert isinstance(model, ByteLatentTransformer)
|
||||
assert isinstance(tokenizer, BltTokenizer)
|
||||
patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
|
||||
patcher_args.realtime_patching = True
|
||||
print("Loading entropy model and patcher")
|
||||
patcher_args.entropy_model_checkpoint_dir = os.path.join(
|
||||
checkpoint_path, "entropy_model"
|
||||
)
|
||||
patcher = patcher_args.build()
|
||||
prompts = [prompt]
|
||||
outputs = generate_nocache(
|
||||
prompts, model=model, tokenizer=tokenizer, patcher=patcher
|
||||
)
|
||||
text_outputs = [tokenizer.decode(t) for t in outputs]
|
||||
for p, t in zip(prompts, text_outputs):
|
||||
print(f'Prompt: "{p}" Completion: "{t}"')
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
typer.run(main)
|
15
download_blt_weights.py
Normal file
15
download_blt_weights.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
import os
|
||||
|
||||
import typer
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
|
||||
def main(models: list[str] = ["blt-1b", "blt-7b"]):
|
||||
if not os.path.exists("hf-weights"):
|
||||
os.makedirs("hf-weights")
|
||||
for model in models:
|
||||
snapshot_download(f"facebook/{model}", local_dir=f"hf-weights/{model}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
typer.run(main)
|
Loading…
Add table
Reference in a new issue