Improve HF compatibility
Summary:
Test Plan:
parent 1b67cbe022
commit cfb99b07d4
3 changed files with 107 additions and 5 deletions
bytelatent/hf.py (new file, 75 lines)
@@ -0,0 +1,75 @@
import os
from bytelatent.entropy_model import load_entropy_model
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.transformer import LMTransformer
import typer

from bytelatent.generate import load_consolidated_model_and_tokenizer


app = typer.Typer()


@app.command()
def convert_to_transformers(blt_weights_dir: str, output_dir: str):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(blt_weights_dir)
    blt_dir = os.path.join(output_dir, "blt")
    entropy_dir = os.path.join(output_dir, "entropy")
    model.save_pretrained(blt_dir)
    blt_readme_file = os.path.join(blt_dir, "README.md")
    if os.path.exists(blt_readme_file):
        os.remove(blt_readme_file)

    patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
    patcher_args.realtime_patching = False
    print("Loading entropy model and patcher")
    patcher_args.entropy_model_checkpoint_dir = os.path.join(
        blt_weights_dir, "entropy_model"
    )
    patcher = patcher_args.build()
    state_path = os.path.join(
        patcher_args.entropy_model_checkpoint_dir, "consolidated.pth"
    )
    entropy_model = load_entropy_model(
        patcher_args.entropy_model_checkpoint_dir, state_path
    )
    entropy_model.save_pretrained(entropy_dir)
    entropy_readme_file = os.path.join(entropy_dir, "README.md")
    if os.path.exists(entropy_readme_file):
        os.remove(entropy_readme_file)

    # TODO: Persist tokenizer in HF compatible way


@app.command()
def load_transformers(
    source: str,
    entropy_repo: str = "facebook/blt-entropy",
    blt_repo: str = "facebook/blt-1b",
    entropy_dir: str | None = None,
    blt_dir: str | None = None,
):
    if source == "local":
        assert entropy_dir is not None
        assert blt_dir is not None
        entropy_model = LMTransformer.from_pretrained(
            entropy_dir, local_files_only=True
        )
        blt_model = ByteLatentTransformer.from_pretrained(
            blt_dir, local_files_only=True
        )
    elif source == "hub":
        entropy_model = LMTransformer.from_pretrained(entropy_repo)
        blt_model = ByteLatentTransformer.from_pretrained(blt_repo)
    else:
        raise ValueError(f"Unknown source: {source}")

    # TODO: Need a way to get tokenizer
    # TODO: Need a way to get patching settings
    # TODO: Insert test inference call


if __name__ == "__main__":
    app()
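A brief usage sketch for the two new commands, not part of the commit itself: the invocation via python -m bytelatent.hf and the local paths are assumptions, and Typer's default underscore-to-hyphen command naming is assumed to apply.

# Hypothetical CLI invocations:
#   python -m bytelatent.hf convert-to-transformers /path/to/blt_checkpoint /path/to/hf_export
#   python -m bytelatent.hf load-transformers hub
# Equivalent programmatic use; the paths below are placeholders:
from bytelatent.hf import convert_to_transformers, load_transformers

convert_to_transformers("/path/to/blt_checkpoint", "/path/to/hf_export")  # writes hf_export/blt and hf_export/entropy
load_transformers("hub")  # pulls facebook/blt-entropy and facebook/blt-1b from the Hub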
bytelatent/model/blt.py
@@ -768,10 +768,17 @@ def compute_hash_embeddings(
     return local_encoder_embeds


-class ByteLatentTransformer(nn.Module, SequenceModelWithOutput, PyTorchModelHubMixin,
-    repo_url="https://github.com/facebookresearch/blt",
-    pipeline_tag="text-generation",
-    license="other"):
+class ByteLatentTransformer(
+    nn.Module,
+    SequenceModelWithOutput,
+    PyTorchModelHubMixin,
+    repo_url="https://github.com/facebookresearch/blt",
+    paper_url="https://arxiv.org/abs/2412.09871",
+    pipeline_tag="text-generation",
+    license="other",
+    license_name="fair-noncommercial-research-license",
+    license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE",
+):
     """
     The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences
     by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,
@@ -861,6 +868,11 @@ class ByteLatentTransformer(nn.Module, SequenceModelWithOutput, PyTorchModelHubM
             )
         )

+    def push_to_hub(self, *args, **kwargs):
+        raise ValueError(
+            "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct."
+        )
+
     def get_output_seq_len(self):
         return self.max_seqlen

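Because push_to_hub() now raises, the manual flow the error message points to would look roughly like the sketch below; the local export directory is hypothetical, and the repo id is taken from the defaults in hf.py.

from huggingface_hub import HfApi

api = HfApi()
# Upload the directory written by save_pretrained(); the repo's model card and
# license metadata can then be curated on the Hub independently of the weights.
api.upload_folder(
    folder_path="/path/to/hf_export/blt",  # hypothetical save_pretrained() output
    repo_id="facebook/blt-1b",
    repo_type="model",
)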
bytelatent/transformer.py
@@ -15,6 +15,7 @@ from torch.distributed.tensor.parallel import (
 )
 from torch.nn.attention.flex_attention import BlockMask, create_block_mask
 from xformers.ops import AttentionBias
+from huggingface_hub import PyTorchModelHubMixin

 from bytelatent.base_transformer import (
     BaseTransformer,
@@ -60,7 +61,16 @@ class LMTransformerArgs(BaseTransformerArgs):
     sliding_window: int | None = None


-class LMTransformer(BaseTransformer):
+class LMTransformer(
+    BaseTransformer,
+    PyTorchModelHubMixin,
+    repo_url="https://github.com/facebookresearch/blt",
+    paper_url="https://arxiv.org/abs/2412.09871",
+    pipeline_tag="text-generation",
+    license="other",
+    license_name="fair-noncommercial-research-license",
+    license_link="https://huggingface.co/facebook/blt/blob/main/LICENSE",
+):
     def __init__(self, args: LMTransformerArgs):
         super().__init__(args)
         self.weight_tying = args.weight_tying
@@ -81,6 +91,11 @@ class LMTransformer(BaseTransformer):
         if args.weight_tying:
             self.output.weight = self.embeddings.tok_embeddings.weight

+    def push_to_hub(self, *args, **kwargs):
+        raise ValueError(
+            "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct."
+        )
+
     def forward(
         self,
         token_values: torch.Tensor,