From cfb99b07d406aa3d315a23f9a6e673673825a125 Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez
Date: Fri, 18 Apr 2025 22:58:06 +0000
Subject: [PATCH] Improve HF compatibility

Summary:

Test Plan:
---
 bytelatent/hf.py          | 75 +++++++++++++++++++++++++++++++++++++++
 bytelatent/model/blt.py   | 20 ++++++++---
 bytelatent/transformer.py | 17 ++++++++-
 3 files changed, 107 insertions(+), 5 deletions(-)
 create mode 100644 bytelatent/hf.py

diff --git a/bytelatent/hf.py b/bytelatent/hf.py
new file mode 100644
index 0000000..e1b5a51
--- /dev/null
+++ b/bytelatent/hf.py
@@ -0,0 +1,75 @@
+import os
+from bytelatent.entropy_model import load_entropy_model
+from bytelatent.model.blt import ByteLatentTransformer
+from bytelatent.transformer import LMTransformer
+import typer
+
+from bytelatent.generate import load_consolidated_model_and_tokenizer
+
+
+app = typer.Typer()
+
+
+@app.command()
+def convert_to_transformers(blt_weights_dir: str, output_dir: str):
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir, exist_ok=True)
+    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(blt_weights_dir)
+    blt_dir = os.path.join(output_dir, "blt")
+    entropy_dir = os.path.join(output_dir, "entropy")
+    model.save_pretrained(blt_dir)
+    blt_readme_file = os.path.join(blt_dir, "README.md")
+    if os.path.exists(blt_readme_file):
+        os.remove(blt_readme_file)
+
+    patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
+    patcher_args.realtime_patching = False
+    print("Loading entropy model and patcher")
+    patcher_args.entropy_model_checkpoint_dir = os.path.join(
+        blt_weights_dir, "entropy_model"
+    )
+    patcher = patcher_args.build()
+    state_path = os.path.join(
+        patcher_args.entropy_model_checkpoint_dir, "consolidated.pth"
+    )
+    entropy_model = load_entropy_model(
+        patcher_args.entropy_model_checkpoint_dir, state_path
+    )
+    entropy_model.save_pretrained(entropy_dir)
+    entropy_readme_file = os.path.join(entropy_dir, "README.md")
+    if os.path.exists(entropy_readme_file):
+        os.remove(entropy_readme_file)
+
+    # TODO: Persist tokenizer in HF compatible way
+
+
+@app.command()
+def load_transformers(
+    source: str,
+    entropy_repo: str = "facebook/blt-entropy",
+    blt_repo: str = "facebook/blt-1b",
+    entropy_dir: str | None = None,
+    blt_dir: str | None = None,
+):
+    if source == "local":
+        assert entropy_dir is not None
+        assert blt_dir is not None
+        entropy_model = LMTransformer.from_pretrained(
+            entropy_dir, local_files_only=True
+        )
+        blt_model = ByteLatentTransformer.from_pretrained(
+            blt_dir, local_files_only=True
+        )
+    elif source == "hub":
+        entropy_model = LMTransformer.from_pretrained(entropy_repo)
+        blt_model = ByteLatentTransformer.from_pretrained(blt_repo)
+    else:
+        raise ValueError(f"Unknown source: {source}")
+
+    # TODO: Need a way to get tokenizer
+    # TODO: Need a way to get patching settings
+    # TODO: Insert test inference call
+
+
+if __name__ == "__main__":
+    app()
diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py
index ccfe395..437fa5d 100644
--- a/bytelatent/model/blt.py
+++ b/bytelatent/model/blt.py
@@ -768,10 +768,17 @@ def compute_hash_embeddings(
     return local_encoder_embeds
 
 
-class ByteLatentTransformer(nn.Module, SequenceModelWithOutput, PyTorchModelHubMixin,
-    repo_url="https://github.com/facebookresearch/blt",
-    pipeline_tag="text-generation",
-    license="other"):
+class ByteLatentTransformer(
+    nn.Module,
+    SequenceModelWithOutput,
+    PyTorchModelHubMixin,
+    repo_url="https://github.com/facebookresearch/blt",
+    paper_url="https://arxiv.org/abs/2412.09871",
pipeline_tag="text-generation", + license="other", + license_name="fair-noncommercial-research-license", + license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE", +): """ The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers, @@ -861,6 +868,11 @@ class ByteLatentTransformer(nn.Module, SequenceModelWithOutput, PyTorchModelHubM ) ) + def push_to_hub(self, *args, **kwargs): + raise ValueError( + "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct." + ) + def get_output_seq_len(self): return self.max_seqlen diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py index 906a54b..af55f40 100644 --- a/bytelatent/transformer.py +++ b/bytelatent/transformer.py @@ -15,6 +15,7 @@ from torch.distributed.tensor.parallel import ( ) from torch.nn.attention.flex_attention import BlockMask, create_block_mask from xformers.ops import AttentionBias +from huggingface_hub import PyTorchModelHubMixin from bytelatent.base_transformer import ( BaseTransformer, @@ -60,7 +61,16 @@ class LMTransformerArgs(BaseTransformerArgs): sliding_window: int | None = None -class LMTransformer(BaseTransformer): +class LMTransformer( + BaseTransformer, + PyTorchModelHubMixin, + repo_url="https://github.com/facebookresearch/blt", + paper_url="https://arxiv.org/abs/2412.09871", + pipeline_tag="text-generation", + license="other", + license_name="fair-noncommercial-research-license", + license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE", +): def __init__(self, args: LMTransformerArgs): super().__init__(args) self.weight_tying = args.weight_tying @@ -81,6 +91,11 @@ class LMTransformer(BaseTransformer): if args.weight_tying: self.output.weight = self.embeddings.tok_embeddings.weight + def push_to_hub(self, *args, **kwargs): + raise ValueError( + "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct." + ) + def forward( self, token_values: torch.Tensor,