Improve HF integration (#98)

* Add mixin

* Update license
NielsRogge 2025-04-18 23:27:53 +02:00 committed by GitHub
parent 96d51b59d2
commit 1b67cbe022


@@ -4,7 +4,7 @@ from enum import Enum, auto
 from typing import Any, Optional
 
 import torch
-from pydantic import ConfigDict, model_validator
+from pydantic import model_validator
 from torch import nn
 from torch.nn.attention.flex_attention import create_block_mask
 from typing_extensions import Self
@@ -13,7 +13,6 @@ from bytelatent.base_transformer import (
     BaseTransformerArgs,
     InitStdFactor,
     SequenceModelWithOutput,
-    TransformerBlock,
 )
 from bytelatent.data.patcher import Patcher, PatcherArgs
 from bytelatent.model.latent_transformer import GlobalTransformer
@@ -21,6 +20,8 @@ from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModel
 from bytelatent.model.utils import downsample
 from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
+from huggingface_hub import PyTorchModelHubMixin
+
 
 
 def attention_flops_per_token(n_layers, seq_len, dim, causal):
     # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30
@@ -767,7 +768,10 @@ def compute_hash_embeddings(
     return local_encoder_embeds
 
 
-class ByteLatentTransformer(nn.Module, SequenceModelWithOutput):
+class ByteLatentTransformer(nn.Module, SequenceModelWithOutput, PyTorchModelHubMixin,
+    repo_url="https://github.com/facebookresearch/blt",
+    pipeline_tag="text-generation",
+    license="other"):
     """
     The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences
     by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,
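
Subclassing PyTorchModelHubMixin is what gives ByteLatentTransformer its save_pretrained, from_pretrained, and push_to_hub methods; the class-level keywords (repo_url, pipeline_tag, license) seed the metadata of the auto-generated model card on the Hub. A minimal sketch of the pattern, using a hypothetical TinyModel in place of the BLT class and a hypothetical repo id:

import torch
from torch import nn
from huggingface_hub import PyTorchModelHubMixin

# Hypothetical stand-in for ByteLatentTransformer. The mixin serializes the
# JSON-serializable __init__ kwargs to config.json and the weights to
# model.safetensors when saving.
class TinyModel(nn.Module, PyTorchModelHubMixin,
                pipeline_tag="text-generation",
                license="other"):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

model = TinyModel(dim=16)
model.save_pretrained("./tiny-model")                 # writes config.json + model.safetensors
reloaded = TinyModel.from_pretrained("./tiny-model")  # re-instantiated from the saved kwargs
# model.push_to_hub("your-username/tiny-model")       # hypothetical repo id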