From cbe26bdf6bbc461b1de3238874230e0a04103625 Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez
Date: Fri, 18 Apr 2025 22:58:06 +0000
Subject: [PATCH] Improve HF compatibility

Summary:

Test Plan:
---
 bytelatent/hf.py          | 231 ++++++++++++++++++++++++++++++++++++++
 bytelatent/model/blt.py   |  20 +++-
 bytelatent/transformer.py |  17 ++-
 3 files changed, 263 insertions(+), 5 deletions(-)
 create mode 100644 bytelatent/hf.py

diff --git a/bytelatent/hf.py b/bytelatent/hf.py
new file mode 100644
index 0000000..82a2972
--- /dev/null
+++ b/bytelatent/hf.py
@@ -0,0 +1,231 @@
+import os
+from typing import Union, Optional, Dict, Any
+from pathlib import Path
+import json
+from dataclasses import asdict, is_dataclass
+
+from huggingface_hub.hub_mixin import ModelHubMixin, DataclassInstance
+from huggingface_hub import snapshot_download, constants
+import typer
+import torch
+
+from bytelatent.distributed import DistributedArgs, setup_torch_distributed
+from bytelatent.generate_blt import generate_nocache
+from bytelatent.entropy_model import load_entropy_model
+from bytelatent.model.blt import ByteLatentTransformer
+from bytelatent.transformer import LMTransformer
+from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
+
+from bytelatent.generate import load_consolidated_model_and_tokenizer
+
+
+app = typer.Typer()
+
+
+class BltModelWrapper(ModelHubMixin):
+    def __init__(self, checkpoint_dir: str):
+        self.model, self.tokenizer, self.train_cfg = (
+            load_consolidated_model_and_tokenizer(checkpoint_dir)
+        )
+        assert isinstance(self.model, ByteLatentTransformer)
+        assert isinstance(self.tokenizer, BltTokenizer)
+        self.patcher_args = self.train_cfg.data.patcher_args.model_copy(deep=True)
+        self.patcher_args.realtime_patching = True
+        self.patcher_args.entropy_model_checkpoint_dir = os.path.join(
+            checkpoint_dir, "entropy_model"
+        )
+        self.patcher = self.patcher_args.build()
+
+    @classmethod
+    def _from_pretrained(
+        cls,
+        *,
+        model_id: str,
+        revision: str | None = None,
+        cache_dir: str | Path | None = None,
+        force_download: bool = False,
+        proxies: dict | None = None,
+        resume_download: bool | None = None,
+        local_files_only: bool = False,
+        token: str | bool | None = None,
+    ):
+        if os.path.isdir(model_id):
+            model = cls(model_id)
+        else:
+            checkpoint_dir = snapshot_download(
+                model_id,
+                revision=revision,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
+            )
+            model = cls(checkpoint_dir)
+        return model
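+
+    # NOTE: Loading goes through the inherited public API, e.g.
+    #   blt = BltModelWrapper.from_pretrained("facebook/blt-1b")
+    # ModelHubMixin.from_pretrained resolves the checkpoint location and
+    # dispatches to _from_pretrained above, so local checkpoint directories
+    # and Hub repos both work.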
+
+    # Copied from https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/hub_mixin.py#L76
+    # so that we can remove behavior we don't want, specifically:
+    # - Overwriting the model card should not be allowed; any changes to the facebook/blt and related model cards should be done by hand and verified.
+    # - Push to hub should be disabled; this also should be done by hand and verified.
+    def save_pretrained(
+        self,
+        save_directory: Union[str, Path],
+        *,
+        config: Optional[Union[dict, DataclassInstance]] = None,
+        repo_id: Optional[str] = None,
+        push_to_hub: bool = False,
+        model_card_kwargs: Optional[Dict[str, Any]] = None,
+        **push_to_hub_kwargs,
+    ) -> Optional[str]:
+        """
+        Save weights in a local directory.
+
+        Args:
+            save_directory (`str` or `Path`):
+                Path to the directory in which the model weights and configuration will be saved.
+            config (`dict` or `DataclassInstance`, *optional*):
+                Model configuration specified as a key/value dictionary or a dataclass instance.
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face Hub after saving it.
+            repo_id (`str`, *optional*):
+                ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
+                not provided.
+            model_card_kwargs (`Dict[str, Any]`, *optional*):
+                Additional arguments passed to the model card template to customize the model card.
+            push_to_hub_kwargs:
+                Additional keyword arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
+        Returns:
+            `str` or `None`: URL of the commit on the Hub if `push_to_hub=True`, `None` otherwise.
+        """
+        save_directory = Path(save_directory)
+        save_directory.mkdir(parents=True, exist_ok=True)
+
+        # Remove config.json if it already exists. After `_save_pretrained` runs we don't want to overwrite
+        # config.json, as the custom `_save_pretrained` may already have saved it. However, we do want to
+        # overwrite an existing config.json that was not saved by `_save_pretrained`.
+        config_path = save_directory / constants.CONFIG_NAME
+        config_path.unlink(missing_ok=True)
+
+        # save model weights/files (framework-specific)
+        self._save_pretrained(save_directory)
+
+        # save config (if provided and if not serialized yet in `_save_pretrained`)
+        if config is None:
+            config = self._hub_mixin_config
+        if config is not None:
+            if is_dataclass(config):
+                config = asdict(config)  # type: ignore[arg-type]
+            if not config_path.exists():
+                config_str = json.dumps(config, sort_keys=True, indent=2)
+                config_path.write_text(config_str)
+
+        return None
+
+    def push_to_hub(self, *args, **kwargs):
+        raise ValueError(
+            "For Meta authors: do not push BLT weights with this method; save weights with save_pretrained(), then push them manually to the HF Hub to ensure the repository metadata is correct."
+        )
+
+    def _save_pretrained(self, save_directory: Path) -> None:
+        raise ValueError(
+            "Saving is not needed for loading pretrained weights; implement this if it becomes useful later."
+        )
+
+
+@app.command()
+def convert_to_transformers(blt_weights_dir: str, output_dir: str):
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir, exist_ok=True)
+    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(blt_weights_dir)
+    blt_dir = os.path.join(output_dir, "blt")
+    entropy_dir = os.path.join(output_dir, "entropy")
+    model.save_pretrained(blt_dir)
+    blt_readme_file = os.path.join(blt_dir, "README.md")
+    if os.path.exists(blt_readme_file):
+        os.remove(blt_readme_file)
+
+    patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
+    patcher_args.realtime_patching = False
+    print("Loading entropy model and patcher")
+    patcher_args.entropy_model_checkpoint_dir = os.path.join(
+        blt_weights_dir, "entropy_model"
+    )
+    patcher = patcher_args.build()
+    state_path = os.path.join(
+        patcher_args.entropy_model_checkpoint_dir, "consolidated.pth"
+    )
+    entropy_model = load_entropy_model(
+        patcher_args.entropy_model_checkpoint_dir, state_path
+    )
+    entropy_model.save_pretrained(entropy_dir)
+    entropy_readme_file = os.path.join(entropy_dir, "README.md")
+    if os.path.exists(entropy_readme_file):
+        os.remove(entropy_readme_file)
+
+    # TODO: Persist tokenizer in HF compatible way
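+
+
+# Example `convert-to-transformers` invocation (paths are hypothetical):
+#   python -m bytelatent.hf convert-to-transformers /checkpoints/blt-1b /exports/blt-hf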
+ ) + + def _save_pretrained(self, save_directory: Path) -> None: + raise ValueError( + "Not needed for loading pre-trained weights, but nice to have later" + ) + + +@app.command() +def convert_to_transformers(blt_weights_dir: str, output_dir: str): + if not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(blt_weights_dir) + blt_dir = os.path.join(output_dir, "blt") + entropy_dir = os.path.join(output_dir, "entropy") + model.save_pretrained(blt_dir) + blt_readme_file = os.path.join(blt_dir, "README.md") + if os.path.exists(blt_readme_file): + os.remove(blt_readme_file) + + patcher_args = train_cfg.data.patcher_args.model_copy(deep=True) + patcher_args.realtime_patching = False + print("Loading entropy model and patcher") + patcher_args.entropy_model_checkpoint_dir = os.path.join( + blt_weights_dir, "entropy_model" + ) + patcher = patcher_args.build() + state_path = os.path.join( + patcher_args.entropy_model_checkpoint_dir, "consolidated.pth" + ) + entropy_model = load_entropy_model( + patcher_args.entropy_model_checkpoint_dir, state_path + ) + entropy_model.save_pretrained(entropy_dir) + entropy_readme_file = os.path.join(entropy_dir, "README.md") + if os.path.exists(entropy_readme_file): + os.remove(entropy_readme_file) + + # TODO: Persist tokenizer in HF compatible way + + +@app.command() +def load_custom( + blt_repo: str = "facebook/blt-1b", +): + distributed_args = DistributedArgs() + distributed_args.configure_world() + if not torch.distributed.is_initialized(): + setup_torch_distributed(distributed_args) + blt = BltModelWrapper.from_pretrained(blt_repo) + prompts = ["The answer is"] + outputs = generate_nocache( + prompts, model=blt.model, tokenizer=blt.tokenizer, patcher=blt.patcher + ) + text_outputs = [blt.tokenizer.decode(t) for t in outputs] + for p, t in zip(prompts, text_outputs): + print(f'Prompt: "{p}" Completion: "{t}"') + print() + + +@app.command() +def load_transformers( + source: str, + entropy_repo: str = "facebook/blt-entropy", + blt_repo: str = "facebook/blt-1b", + entropy_dir: str | None = None, + blt_dir: str | None = None, +): + if source == "local": + assert entropy_dir is not None + assert blt_dir is not None + entropy_model = LMTransformer.from_pretrained( + entropy_dir, local_files_only=True + ) + blt_model = ByteLatentTransformer.from_pretrained( + blt_dir, local_files_only=True + ) + elif source == "hub": + entropy_model = LMTransformer.from_pretrained(entropy_repo) + blt_model = ByteLatentTransformer.from_pretrained(blt_repo) + else: + raise ValueError(f"Unknown source: {source}") + + # TODO: Need a way to get tokenizer + # TODO: Need a way to get patching settings + # TODO: Insert test inference call + + +if __name__ == "__main__": + app() diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py index ccfe395..437fa5d 100644 --- a/bytelatent/model/blt.py +++ b/bytelatent/model/blt.py @@ -768,10 +768,17 @@ def compute_hash_embeddings( return local_encoder_embeds -class ByteLatentTransformer(nn.Module, SequenceModelWithOutput, PyTorchModelHubMixin, - repo_url="https://github.com/facebookresearch/blt", - pipeline_tag="text-generation", - license="other"): +class ByteLatentTransformer( + nn.Module, + SequenceModelWithOutput, + PyTorchModelHubMixin, + repo_url="https://github.com/facebookresearch/blt", + paper_url="https://arxiv.org/abs/2412.09871", + pipeline_tag="text-generation", + license="other", + license_name="fair-noncommercial-research-license", + 
+
+
+@app.command()
+def load_transformers(
+    source: str,
+    entropy_repo: str = "facebook/blt-entropy",
+    blt_repo: str = "facebook/blt-1b",
+    entropy_dir: str | None = None,
+    blt_dir: str | None = None,
+):
+    if source == "local":
+        assert entropy_dir is not None
+        assert blt_dir is not None
+        entropy_model = LMTransformer.from_pretrained(
+            entropy_dir, local_files_only=True
+        )
+        blt_model = ByteLatentTransformer.from_pretrained(
+            blt_dir, local_files_only=True
+        )
+    elif source == "hub":
+        entropy_model = LMTransformer.from_pretrained(entropy_repo)
+        blt_model = ByteLatentTransformer.from_pretrained(blt_repo)
+    else:
+        raise ValueError(f"Unknown source: {source}")
+
+    # TODO: Need a way to get tokenizer
+    # TODO: Need a way to get patching settings
+    # TODO: Insert test inference call
+
+
+if __name__ == "__main__":
+    app()
diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py
index ccfe395..437fa5d 100644
--- a/bytelatent/model/blt.py
+++ b/bytelatent/model/blt.py
@@ -768,10 +768,17 @@ def compute_hash_embeddings(
     return local_encoder_embeds
 
 
-class ByteLatentTransformer(nn.Module, SequenceModelWithOutput, PyTorchModelHubMixin,
-                            repo_url="https://github.com/facebookresearch/blt",
-                            pipeline_tag="text-generation",
-                            license="other"):
+class ByteLatentTransformer(
+    nn.Module,
+    SequenceModelWithOutput,
+    PyTorchModelHubMixin,
+    repo_url="https://github.com/facebookresearch/blt",
+    paper_url="https://arxiv.org/abs/2412.09871",
+    pipeline_tag="text-generation",
+    license="other",
+    license_name="fair-noncommercial-research-license",
+    license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE",
+):
     """
     The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences
     by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,
@@ -861,6 +868,11 @@ class ByteLatentTransformer(nn.Module, SequenceModelWithOutput, PyTorchModelHubM
             )
         )
 
+    def push_to_hub(self, *args, **kwargs):
+        raise ValueError(
+            "For Meta authors: do not push BLT weights with this method; save weights with save_pretrained(), then push them manually to the HF Hub to ensure the repository metadata is correct."
+        )
+
     def get_output_seq_len(self):
         return self.max_seqlen
 
diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py
index 906a54b..af55f40 100644
--- a/bytelatent/transformer.py
+++ b/bytelatent/transformer.py
@@ -15,6 +15,7 @@ from torch.distributed.tensor.parallel import (
 )
 from torch.nn.attention.flex_attention import BlockMask, create_block_mask
 from xformers.ops import AttentionBias
+from huggingface_hub import PyTorchModelHubMixin
 
 from bytelatent.base_transformer import (
     BaseTransformer,
@@ -60,7 +61,16 @@ class LMTransformerArgs(BaseTransformerArgs):
     sliding_window: int | None = None
 
 
-class LMTransformer(BaseTransformer):
+class LMTransformer(
+    BaseTransformer,
+    PyTorchModelHubMixin,
+    repo_url="https://github.com/facebookresearch/blt",
+    paper_url="https://arxiv.org/abs/2412.09871",
+    pipeline_tag="text-generation",
+    license="other",
+    license_name="fair-noncommercial-research-license",
+    license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE",
+):
     def __init__(self, args: LMTransformerArgs):
         super().__init__(args)
         self.weight_tying = args.weight_tying
@@ -81,6 +91,11 @@ class LMTransformer(BaseTransformer):
         if args.weight_tying:
             self.output.weight = self.embeddings.tok_embeddings.weight
 
+    def push_to_hub(self, *args, **kwargs):
+        raise ValueError(
+            "For Meta authors: do not push BLT weights with this method; save weights with save_pretrained(), then push them manually to the HF Hub to ensure the repository metadata is correct."
+        )
+
     def forward(
         self,
         token_values: torch.Tensor,
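
For reference, a minimal end-to-end sketch of the loading path this patch adds.
It mirrors the load-custom command in bytelatent/hf.py; "facebook/blt-1b" is
that command's default repo id, and a single-process run with Hub access is
assumed.

    import torch

    from bytelatent.distributed import DistributedArgs, setup_torch_distributed
    from bytelatent.generate_blt import generate_nocache
    from bytelatent.hf import BltModelWrapper

    # Single-process distributed setup, as load_custom does before loading.
    distributed_args = DistributedArgs()
    distributed_args.configure_world()
    if not torch.distributed.is_initialized():
        setup_torch_distributed(distributed_args)

    # from_pretrained accepts a local checkpoint dir or downloads a Hub snapshot;
    # the wrapper also restores the tokenizer and a realtime patcher.
    blt = BltModelWrapper.from_pretrained("facebook/blt-1b")
    outputs = generate_nocache(
        ["The answer is"], model=blt.model, tokenizer=blt.tokenizer, patcher=blt.patcher
    )
    print(blt.tokenizer.decode(outputs[0]))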