Improve HF compatibility

Summary:

Test Plan:
This commit is contained in:
Pedro Rodriguez 2025-04-18 22:58:06 +00:00
parent 1b67cbe022
commit cfb99b07d4
3 changed files with 107 additions and 5 deletions

75
bytelatent/hf.py Normal file
View file

@ -0,0 +1,75 @@
import os
from bytelatent.entropy_model import load_entropy_model
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.transformer import LMTransformer
import typer
from bytelatent.generate import load_consolidated_model_and_tokenizer
app = typer.Typer()
@app.command()
def convert_to_transformers(blt_weights_dir: str, output_dir: str):
if not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(blt_weights_dir)
blt_dir = os.path.join(output_dir, "blt")
entropy_dir = os.path.join(output_dir, "entropy")
model.save_pretrained(blt_dir)
blt_readme_file = os.path.join(blt_dir, "README.md")
if os.path.exists(blt_readme_file):
os.remove(blt_readme_file)
patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
patcher_args.realtime_patching = False
print("Loading entropy model and patcher")
patcher_args.entropy_model_checkpoint_dir = os.path.join(
blt_weights_dir, "entropy_model"
)
patcher = patcher_args.build()
state_path = os.path.join(
patcher_args.entropy_model_checkpoint_dir, "consolidated.pth"
)
entropy_model = load_entropy_model(
patcher_args.entropy_model_checkpoint_dir, state_path
)
entropy_model.save_pretrained(entropy_dir)
entropy_readme_file = os.path.join(entropy_dir, "README.md")
if os.path.exists(entropy_readme_file):
os.remove(entropy_readme_file)
# TODO: Persist tokenizer in HF compatible way
@app.command()
def load_transformers(
source: str,
entropy_repo: str = "facebook/blt-entropy",
blt_repo: str = "facebook/blt-1b",
entropy_dir: str | None = None,
blt_dir: str | None = None,
):
if source == "local":
assert entropy_dir is not None
assert blt_dir is not None
entropy_model = LMTransformer.from_pretrained(
entropy_dir, local_files_only=True
)
blt_model = ByteLatentTransformer.from_pretrained(
blt_dir, local_files_only=True
)
elif source == "hub":
entropy_model = LMTransformer.from_pretrained(entropy_repo)
blt_model = ByteLatentTransformer.from_pretrained(blt_repo)
else:
raise ValueError(f"Unknown source: {source}")
# TODO: Need a way to get tokenizer
# TODO: Need a way to get patching settings
# TODO: Insert test inference call
if __name__ == "__main__":
app()

View file

@ -768,10 +768,17 @@ def compute_hash_embeddings(
return local_encoder_embeds
class ByteLatentTransformer(nn.Module, SequenceModelWithOutput, PyTorchModelHubMixin,
repo_url="https://github.com/facebookresearch/blt",
pipeline_tag="text-generation",
license="other"):
class ByteLatentTransformer(
nn.Module,
SequenceModelWithOutput,
PyTorchModelHubMixin,
repo_url="https://github.com/facebookresearch/blt",
paper_url="https://arxiv.org/abs/2412.09871",
pipeline_tag="text-generation",
license="other",
license_name="fair-noncommercial-research-license",
license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE",
):
"""
The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences
by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,
@ -861,6 +868,11 @@ class ByteLatentTransformer(nn.Module, SequenceModelWithOutput, PyTorchModelHubM
)
)
def push_to_hub(self, *args, **kwargs):
raise ValueError(
"For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct."
)
def get_output_seq_len(self):
return self.max_seqlen

View file

@ -15,6 +15,7 @@ from torch.distributed.tensor.parallel import (
)
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
from xformers.ops import AttentionBias
from huggingface_hub import PyTorchModelHubMixin
from bytelatent.base_transformer import (
BaseTransformer,
@ -60,7 +61,16 @@ class LMTransformerArgs(BaseTransformerArgs):
sliding_window: int | None = None
class LMTransformer(BaseTransformer):
class LMTransformer(
BaseTransformer,
PyTorchModelHubMixin,
repo_url="https://github.com/facebookresearch/blt",
paper_url="https://arxiv.org/abs/2412.09871",
pipeline_tag="text-generation",
license="other",
license_name="fair-noncommercial-research-license",
license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE",
):
def __init__(self, args: LMTransformerArgs):
super().__init__(args)
self.weight_tying = args.weight_tying
@ -81,6 +91,11 @@ class LMTransformer(BaseTransformer):
if args.weight_tying:
self.output.weight = self.embeddings.tok_embeddings.weight
def push_to_hub(self, *args, **kwargs):
raise ValueError(
"For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct."
)
def forward(
self,
token_values: torch.Tensor,