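"""CLI utilities for converting consolidated BLT checkpoints into Hugging Face
Hub format and for loading them back, locally or from the hub, for generation.

Sketch of the intended invocations (assuming this module lives at
bytelatent/hf.py; typer turns underscores in command names into dashes):

    python -m bytelatent.hf convert-to-transformers <blt_weights_dir> <output_dir>
    python -m bytelatent.hf load-transformers hub --prompt "Hello world"
"""
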
import json
import os
import shutil
from pathlib import Path
from typing import Dict, Optional, Union

import torch
import typer
from huggingface_hub import hf_hub_download
from huggingface_hub.hub_mixin import ModelHubMixin

from bytelatent.args import TrainArgs
from bytelatent.data.patcher import PatcherArgs, to_device
from bytelatent.distributed import DistributedArgs, setup_torch_distributed
from bytelatent.entropy_model import load_entropy_model
from bytelatent.generate import load_consolidated_model_and_tokenizer
from bytelatent.generate_blt import generate_nocache
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
from bytelatent.transformer import LMTransformer

app = typer.Typer()


class BltTokenizerAndPatcher(ModelHubMixin):
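    """Load-only ModelHubMixin wrapper exposing the patcher, tokenizer, and
    distributed args stored in a BLT checkpoint's train_args.json.

    A minimal usage sketch (repo id taken from this file's defaults):

        tp = BltTokenizerAndPatcher.from_pretrained("facebook/blt-1b")
        tokenizer = tp.tokenizer_args.build()
        patcher = tp.patcher_args.build()

    Saving is intentionally disabled: these args are serialized alongside the
    BLT weights by convert_to_transformers below.
    """
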
    def __init__(
        self,
        *,
        patcher_args: PatcherArgs,
        tokenizer_args: TokenizerArgs,
        distributed_args: DistributedArgs,
    ):
        self.patcher_args = patcher_args
        self.tokenizer_args = tokenizer_args
        self.distributed_args = distributed_args

    def push_to_hub(self, *args, **kwargs):
        raise ValueError(
            "For Meta authors: do not push BLT weights with this; save the weights with save_pretrained(), then push them manually to the HF hub to ensure the repository metadata is correct."
        )

    def save_pretrained(self, *args, **kwargs):
        raise ValueError(
            "Tokenizer and Patcher are saved by BLT; this class is just for loading."
        )

    def _save_pretrained(self, *args, **kwargs):
        raise ValueError(
            "Tokenizer and Patcher are saved by BLT; this class is just for loading."
        )

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: Optional[bool],
        local_files_only: bool,
        token: Optional[Union[str, bool]],
        **model_kwargs,
    ):
        # model_id may be a local checkpoint directory or a hub repo id; either
        # way, the args of interest live in train_args.json.
        if os.path.isdir(model_id):
            train_args_file = os.path.join(model_id, "train_args.json")
        else:
            train_args_file = hf_hub_download(
                repo_id=model_id,
                filename="train_args.json",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                token=token,
            )

        with open(train_args_file) as f:
            train_args = TrainArgs(**json.load(f))
        return cls(
            patcher_args=train_args.data.patcher_args,
            tokenizer_args=train_args.data.tokenizer_args,
            distributed_args=train_args.distributed,
        )


@app.command()
def convert_to_transformers(blt_weights_dir: str, output_dir: str):
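    """Split a consolidated BLT checkpoint into two Hugging Face model
    directories, output_dir/blt and output_dir/entropy, copying params.json
    along as train_args.json so the tokenizer/patcher config travels with
    the BLT weights."""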
    # exist_ok=True already handles a pre-existing directory, so no check is needed.
    os.makedirs(output_dir, exist_ok=True)
    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(blt_weights_dir)
    blt_dir = os.path.join(output_dir, "blt")
    entropy_dir = os.path.join(output_dir, "entropy")
    model.save_pretrained(blt_dir, config={"args": train_cfg.model.model_dump()})
    # Ship the training config alongside the weights so BltTokenizerAndPatcher
    # can later reconstruct the tokenizer/patcher args from train_args.json.
    shutil.copyfile(
        os.path.join(blt_weights_dir, "params.json"),
        os.path.join(blt_dir, "train_args.json"),
    )
    # Remove the auto-generated model card, if any; per push_to_hub above, the
    # repository metadata is curated manually when the weights are pushed.
    blt_readme_file = os.path.join(blt_dir, "README.md")
    if os.path.exists(blt_readme_file):
        os.remove(blt_readme_file)

    # Save the entropy model separately, with realtime patching disabled in the
    # copied patcher args since the hub copy has no local checkpoint dir.
    patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
    patcher_args.realtime_patching = False
    print("Loading entropy model and patcher")
    patcher_args.entropy_model_checkpoint_dir = os.path.join(
        blt_weights_dir, "entropy_model"
    )
    state_path = os.path.join(
        patcher_args.entropy_model_checkpoint_dir, "consolidated.pth"
    )
    entropy_model, entropy_model_args = load_entropy_model(
        patcher_args.entropy_model_checkpoint_dir, state_path
    )
    entropy_model.save_pretrained(
        entropy_dir, config={"args": entropy_model_args.model_dump()}
    )
    entropy_readme_file = os.path.join(entropy_dir, "README.md")
    if os.path.exists(entropy_readme_file):
        os.remove(entropy_readme_file)


@app.command()
def load_transformers(
    source: str,
    entropy_repo: str = "facebook/blt-entropy",
    blt_repo: str = "facebook/blt-1b",
    entropy_dir: str | None = None,
    blt_dir: str | None = None,
    prompt: str | None = None,
):
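    """Load the entropy model, BLT model, and tokenizer/patcher args either
    from local directories (source="local", requires --entropy-dir and
    --blt-dir) or from the Hugging Face Hub (source="hub"); if --prompt is
    given, run a single generation on GPU."""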
    if source == "local":
        assert entropy_dir is not None
        assert blt_dir is not None
        entropy_model = LMTransformer.from_pretrained(
            entropy_dir, local_files_only=True
        )
        blt_model = ByteLatentTransformer.from_pretrained(
            blt_dir, local_files_only=True
        )
        tok_and_patcher = BltTokenizerAndPatcher.from_pretrained(
            blt_dir, local_files_only=True
        )
        print("Loaded all local")
    elif source == "hub":
        entropy_model = LMTransformer.from_pretrained(entropy_repo)
        blt_model = ByteLatentTransformer.from_pretrained(blt_repo)
        tok_and_patcher = BltTokenizerAndPatcher.from_pretrained(blt_repo)
        print("Loaded all remote")
    else:
        raise ValueError(f"Unknown source: {source}")

    # Building the tokenizer and patcher is identical for both sources.
    tokenizer = tok_and_patcher.tokenizer_args.build()
    patcher = tok_and_patcher.patcher_args.build()
    print(entropy_model)
    print(blt_model)
    print(tok_and_patcher)

    if prompt is not None:
        assert isinstance(tokenizer, BltTokenizer)
        # Move the model to GPU and cast its parameters to the training dtype.
        param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
            tok_and_patcher.distributed_args.model_dtype
        ]
        blt_model = blt_model.cuda().eval()
        for param in blt_model.parameters():
            param.data = param.data.to(dtype=param_dtype)

        # Enable realtime patching: the patcher runs the entropy model on the
        # patching device to segment bytes into patches during generation.
        patcher.realtime_patching = True
        patcher.entropy_model, _ = to_device(
            entropy_model, tok_and_patcher.patcher_args.patching_device
        )

        # Set up torch.distributed if it is not already initialized.
        distributed_args = DistributedArgs()
        distributed_args.configure_world()
        if not torch.distributed.is_initialized():
            setup_torch_distributed(distributed_args)

        prompts = [prompt]
        outputs = generate_nocache(
            prompts, model=blt_model, tokenizer=tokenizer, patcher=patcher
        )
        text_outputs = [tokenizer.decode(t) for t in outputs]
        for p, t in zip(prompts, text_outputs):
            print(f'Prompt: "{p}"\nCompletion: "{t}"')
            print()


if __name__ == "__main__":
    app()