Mirror of https://github.com/facebookresearch/blt.git, synced 2025-09-10 06:14:35 +00:00
Open source weights! (#97)
Summary: Add code to download weights and a demo script for running the model.

Weights at:
- https://huggingface.co/collections/facebook/blt-6801263d4ac1704702a192a6
- https://huggingface.co/facebook/blt
- https://huggingface.co/facebook/blt-1b
- https://huggingface.co/facebook/blt-7b

Test Plan:
parent e299427ae4
commit 96d51b59d2
7 changed files with 98 additions and 12 deletions (only demo.py is shown below)
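The diff below adds only the demo script; the weight-download code mentioned in the summary lives in the other changed files and is not shown here. As a minimal sketch of fetching the weights with huggingface_hub, assuming the facebook/blt-1b repo ships the consolidated checkpoint plus its entropy_model/ subdirectory in the hf-weights/<model_name> layout that demo.py reads from:

# Sketch, not the commit's actual download code: pull BLT weights from
# Hugging Face into the directory layout demo.py expects.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="facebook/blt-1b", local_dir="hf-weights/blt-1b")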
demo.py · new file · 43 additions
@@ -0,0 +1,43 @@
import os

import torch
import typer

from bytelatent.distributed import DistributedArgs, setup_torch_distributed
from bytelatent.generate import load_consolidated_model_and_tokenizer
from bytelatent.generate_blt import generate_nocache
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer


def main(prompt: str, model_name: str = "blt-1b"):
    distributed_args = DistributedArgs()
    distributed_args.configure_world()
    if not torch.distributed.is_initialized():
        setup_torch_distributed(distributed_args)
    checkpoint_path = os.path.join("hf-weights", model_name)
    print(f"Loading BLT model: {model_name}")
    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
        checkpoint_path,
    )
    assert isinstance(model, ByteLatentTransformer)
    assert isinstance(tokenizer, BltTokenizer)
    patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
    patcher_args.realtime_patching = True
    print("Loading entropy model and patcher")
    patcher_args.entropy_model_checkpoint_dir = os.path.join(
        checkpoint_path, "entropy_model"
    )
    patcher = patcher_args.build()
    prompts = [prompt]
    outputs = generate_nocache(
        prompts, model=model, tokenizer=tokenizer, patcher=patcher
    )
    text_outputs = [tokenizer.decode(t) for t in outputs]
    for p, t in zip(prompts, text_outputs):
        print(f'Prompt: "{p}" Completion: "{t}"')
        print()


if __name__ == "__main__":
    typer.run(main)
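With the weights in place under hf-weights/blt-1b, the demo can be invoked directly. typer.run(main) turns the function signature into a CLI, so prompt is a positional argument and model_name becomes a --model-name option; the prompt string here is just an example:

python demo.py "A fun fact about bytes is" --model-name blt-1b

The script then prints each prompt alongside its generated completion.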