Mirror of https://github.com/facebookresearch/blt.git (synced 2025-09-04 19:49:06 +00:00)
Improve HF compatibility

Summary:

Test Plan:

parent 1b67cbe022
commit cbe26bdf6b

3 changed files with 249 additions and 5 deletions
bytelatent/hf.py (new file, 217 lines)
@@ -0,0 +1,217 @@
import os
from typing import Union, Optional, Dict, Any
from pathlib import Path
import json
from dataclasses import asdict, is_dataclass

from huggingface_hub.hub_mixin import ModelHubMixin, DataclassInstance
from huggingface_hub import snapshot_download, constants
import typer
import torch

from bytelatent.distributed import DistributedArgs, setup_torch_distributed
from bytelatent.generate_blt import generate_nocache
from bytelatent.entropy_model import load_entropy_model
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.transformer import LMTransformer
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer

from bytelatent.generate import load_consolidated_model_and_tokenizer


app = typer.Typer()


class BltModelWrapper(ModelHubMixin):
    def __init__(self, checkpoint_dir: str):
        self.model, self.tokenizer, self.train_cfg = (
            load_consolidated_model_and_tokenizer(checkpoint_dir)
        )
        assert isinstance(self.model, ByteLatentTransformer)
        assert isinstance(self.tokenizer, BltTokenizer)
        self.patcher_args = self.train_cfg.data.patcher_args.model_copy(deep=True)
        self.patcher_args.realtime_patching = True
        self.patcher_args.entropy_model_checkpoint_dir = os.path.join(
            checkpoint_dir, "entropy_model"
        )
        self.patcher = self.patcher_args.build()

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: str | None = None,
        cache_dir: str | Path | None = None,
        force_download: bool = False,
        proxies: dict | None = None,
        resume_download: bool | None = None,
        local_files_only: bool = False,
        token: str | bool | None = None,
    ):
        if os.path.isdir(model_id):
            model = cls(model_id)
        else:
            checkpoint_dir = snapshot_download(
                model_id,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
            )
            model = cls(checkpoint_dir)
        return model

    # Copied from https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/hub_mixin.py#L76
    # so that we can remove behavior we don't want, specifically:
    # - Overwriting the model card should not be allowed; any changes to the facebook/blt
    #   and related model cards should be done by hand and verified.
    # - Push to hub should be disabled; this also should be done by hand and verified.
    def save_pretrained(
        self,
        save_directory: Union[str, Path],
        *,
        config: Optional[Union[dict, DataclassInstance]] = None,
        repo_id: Optional[str] = None,
        push_to_hub: bool = False,
        model_card_kwargs: Optional[Dict[str, Any]] = None,
        **push_to_hub_kwargs,
    ) -> Optional[str]:
        """
        Save weights in a local directory.

        Args:
            save_directory (`str` or `Path`):
                Path to directory in which the model weights and configuration will be saved.
            config (`dict` or `DataclassInstance`, *optional*):
                Model configuration specified as a key/value dictionary or a dataclass instance.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Huggingface Hub after saving it.
            repo_id (`str`, *optional*):
                ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
                not provided.
            model_card_kwargs (`Dict[str, Any]`, *optional*):
                Additional arguments passed to the model card template to customize the model card.
            push_to_hub_kwargs:
                Additional keyword arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.

        Returns:
            `str` or `None`: url of the commit on the Hub if `push_to_hub=True`, `None` otherwise.
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        # Remove config.json if it already exists. After `_save_pretrained` we don't want to overwrite
        # config.json as it might have been saved by the custom `_save_pretrained` already. However we
        # do want to overwrite an existing config.json if it was not saved by `_save_pretrained`.
        config_path = save_directory / constants.CONFIG_NAME
        config_path.unlink(missing_ok=True)

        # Save model weights/files (framework-specific).
        self._save_pretrained(save_directory)

        # Save config (if provided and if not serialized yet in `_save_pretrained`).
        if config is None:
            config = self._hub_mixin_config
        if config is not None:
            if is_dataclass(config):
                config = asdict(config)  # type: ignore[arg-type]
            if not config_path.exists():
                config_str = json.dumps(config, sort_keys=True, indent=2)
                config_path.write_text(config_str)

        return None

    def push_to_hub(self, *args, **kwargs):
        raise ValueError(
            "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct."
        )

    def _save_pretrained(self, save_directory: Path) -> None:
        raise ValueError(
            "Not needed for loading pre-trained weights, but nice to have later"
        )


@app.command()
def convert_to_transformers(blt_weights_dir: str, output_dir: str):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(blt_weights_dir)
    blt_dir = os.path.join(output_dir, "blt")
    entropy_dir = os.path.join(output_dir, "entropy")
    model.save_pretrained(blt_dir)
    blt_readme_file = os.path.join(blt_dir, "README.md")
    if os.path.exists(blt_readme_file):
        os.remove(blt_readme_file)

    patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
    patcher_args.realtime_patching = False
    print("Loading entropy model and patcher")
    patcher_args.entropy_model_checkpoint_dir = os.path.join(
        blt_weights_dir, "entropy_model"
    )
    patcher = patcher_args.build()
    state_path = os.path.join(
        patcher_args.entropy_model_checkpoint_dir, "consolidated.pth"
    )
    entropy_model = load_entropy_model(
        patcher_args.entropy_model_checkpoint_dir, state_path
    )
    entropy_model.save_pretrained(entropy_dir)
    entropy_readme_file = os.path.join(entropy_dir, "README.md")
    if os.path.exists(entropy_readme_file):
        os.remove(entropy_readme_file)

    # TODO: Persist tokenizer in HF compatible way


@app.command()
def load_custom(
    blt_repo: str = "facebook/blt-1b",
):
    distributed_args = DistributedArgs()
    distributed_args.configure_world()
    if not torch.distributed.is_initialized():
        setup_torch_distributed(distributed_args)
    blt = BltModelWrapper.from_pretrained(blt_repo)
    prompts = ["The answer is"]
    outputs = generate_nocache(
        prompts, model=blt.model, tokenizer=blt.tokenizer, patcher=blt.patcher
    )
    text_outputs = [blt.tokenizer.decode(t) for t in outputs]
    for p, t in zip(prompts, text_outputs):
        print(f'Prompt: "{p}" Completion: "{t}"')
        print()


@app.command()
def load_transformers(
    source: str,
    entropy_repo: str = "facebook/blt-entropy",
    blt_repo: str = "facebook/blt-1b",
    entropy_dir: str | None = None,
    blt_dir: str | None = None,
):
    if source == "local":
        assert entropy_dir is not None
        assert blt_dir is not None
        entropy_model = LMTransformer.from_pretrained(
            entropy_dir, local_files_only=True
        )
        blt_model = ByteLatentTransformer.from_pretrained(
            blt_dir, local_files_only=True
        )
    elif source == "hub":
        entropy_model = LMTransformer.from_pretrained(entropy_repo)
        blt_model = ByteLatentTransformer.from_pretrained(blt_repo)
    else:
        raise ValueError(f"Unknown source: {source}")

    # TODO: Need a way to get tokenizer
    # TODO: Need a way to get patching settings
    # TODO: Insert test inference call


if __name__ == "__main__":
    app()
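Since Typer commands are plain functions, the same flow can also be scripted from Python. A minimal sketch, assuming hypothetical local checkpoint paths (not part of this diff):

# Hypothetical paths; convert_to_transformers writes <output_dir>/blt and <output_dir>/entropy.
from bytelatent.hf import BltModelWrapper, convert_to_transformers

convert_to_transformers("/checkpoints/blt-1b", "/checkpoints/blt-1b-hf")

# from_pretrained() accepts either a local directory (used as-is) or a Hub repo id
# (fetched via snapshot_download); __init__ then rebuilds the entropy-model patcher
# with realtime_patching=True.
blt = BltModelWrapper.from_pretrained("facebook/blt-1b")

Note that the wrapper's push_to_hub() always raises and _save_pretrained() is left as a stub, so publishing weights remains a manual, hand-verified step by design.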
bytelatent/model/blt.py
@@ -768,10 +768,17 @@ def compute_hash_embeddings(
     return local_encoder_embeds


-class ByteLatentTransformer(nn.Module, SequenceModelWithOutput, PyTorchModelHubMixin,
-                            repo_url="https://github.com/facebookresearch/blt",
-                            pipeline_tag="text-generation",
-                            license="other"):
+class ByteLatentTransformer(
+    nn.Module,
+    SequenceModelWithOutput,
+    PyTorchModelHubMixin,
+    repo_url="https://github.com/facebookresearch/blt",
+    paper_url="https://arxiv.org/abs/2412.09871",
+    pipeline_tag="text-generation",
+    license="other",
+    license_name="fair-noncommercial-research-license",
+    license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE",
+):
     """
     The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences
     by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,
@@ -861,6 +868,11 @@ class ByteLatentTransformer(nn.Module, SequenceModelWithOutput, PyTorchModelHubM
             )
         )

+    def push_to_hub(self, *args, **kwargs):
+        raise ValueError(
+            "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct."
+        )
+
     def get_output_seq_len(self):
         return self.max_seqlen
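With PyTorchModelHubMixin in the base list, ByteLatentTransformer gains from_pretrained()/save_pretrained() directly, and the class-level kwargs (repo_url, paper_url, pipeline_tag, license metadata) feed the auto-generated model card. A hedged sketch of the resulting surface, with hypothetical paths:

from bytelatent.model.blt import ByteLatentTransformer

# Load weights previously exported by convert_to_transformers (path hypothetical).
model = ByteLatentTransformer.from_pretrained("/checkpoints/blt-1b-hf/blt")

# Re-serialize locally. With recent huggingface_hub versions this writes
# model.safetensors and config.json, plus an auto-generated README.md model card
# built from the class-level kwargs (which is why convert_to_transformers deletes
# that README and the Hub model card is curated by hand instead).
model.save_pretrained("/tmp/blt-export")

# model.push_to_hub("facebook/blt-1b")  # raises ValueError: pushing is deliberately disabled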
|
@ -15,6 +15,7 @@ from torch.distributed.tensor.parallel import (
|
|||
)
|
||||
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
|
||||
from xformers.ops import AttentionBias
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
|
||||
from bytelatent.base_transformer import (
|
||||
BaseTransformer,
|
||||
|
@@ -60,7 +61,16 @@ class LMTransformerArgs(BaseTransformerArgs):
     sliding_window: int | None = None


-class LMTransformer(BaseTransformer):
+class LMTransformer(
+    BaseTransformer,
+    PyTorchModelHubMixin,
+    repo_url="https://github.com/facebookresearch/blt",
+    paper_url="https://arxiv.org/abs/2412.09871",
+    pipeline_tag="text-generation",
+    license="other",
+    license_name="fair-noncommercial-research-license",
+    license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE",
+):
     def __init__(self, args: LMTransformerArgs):
         super().__init__(args)
         self.weight_tying = args.weight_tying
@@ -81,6 +91,11 @@ class LMTransformer(BaseTransformer):
         if args.weight_tying:
             self.output.weight = self.embeddings.tok_embeddings.weight

+    def push_to_hub(self, *args, **kwargs):
+        raise ValueError(
+            "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct."
+        )
+
     def forward(
         self,
         token_values: torch.Tensor,
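The entropy model gets the same mixin treatment, which is what lets load_transformers in hf.py pull both halves from the Hub. A small sketch using that command's default repo ids (untested, and assuming the published snapshots carry mixin-style config.json files):

from bytelatent.transformer import LMTransformer
from bytelatent.model.blt import ByteLatentTransformer

entropy_model = LMTransformer.from_pretrained("facebook/blt-entropy")
blt_model = ByteLatentTransformer.from_pretrained("facebook/blt-1b")

# As the TODOs in load_transformers note, the tokenizer and patching settings
# still have to come from elsewhere before these models can generate text.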