From 1b67cbe02202d312eafa4153bf7d1442dac5ce49 Mon Sep 17 00:00:00 2001
From: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Date: Fri, 18 Apr 2025 23:27:53 +0200
Subject: [PATCH] Improve HF integration (#98)

* Add mixin

* Update license
---
 bytelatent/model/blt.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py
index 134a990..ccfe395 100644
--- a/bytelatent/model/blt.py
+++ b/bytelatent/model/blt.py
@@ -4,7 +4,7 @@ from enum import Enum, auto
 from typing import Any, Optional
 
 import torch
-from pydantic import ConfigDict, model_validator
+from pydantic import model_validator
 from torch import nn
 from torch.nn.attention.flex_attention import create_block_mask
 from typing_extensions import Self
@@ -13,7 +13,6 @@ from bytelatent.base_transformer import (
     BaseTransformerArgs,
     InitStdFactor,
     SequenceModelWithOutput,
-    TransformerBlock,
 )
 from bytelatent.data.patcher import Patcher, PatcherArgs
 from bytelatent.model.latent_transformer import GlobalTransformer
@@ -21,6 +20,8 @@ from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModel
 from bytelatent.model.utils import downsample
 from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
 
+from huggingface_hub import PyTorchModelHubMixin
+
 
 def attention_flops_per_token(n_layers, seq_len, dim, causal):
     # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30
@@ -767,7 +768,10 @@ def compute_hash_embeddings(
     return local_encoder_embeds
 
 
-class ByteLatentTransformer(nn.Module, SequenceModelWithOutput):
+class ByteLatentTransformer(nn.Module, SequenceModelWithOutput, PyTorchModelHubMixin,
+    repo_url="https://github.com/facebookresearch/blt",
+    pipeline_tag="text-generation",
+    license="other"):
     """
     The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences
     by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,
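
A minimal usage sketch of what the PyTorchModelHubMixin base added above enables: the class inherits from_pretrained(), save_pretrained(), and push_to_hub() from huggingface_hub. The repo ids below are illustrative placeholders, not taken from the patch, and from_pretrained() assumes the model's constructor configuration was serialized alongside the weights when the checkpoint was saved.

    from bytelatent.model.blt import ByteLatentTransformer

    # Rebuild the model from a Hub checkpoint (repo id is a placeholder).
    model = ByteLatentTransformer.from_pretrained("facebook/blt")

    # Serialize weights and config locally, or push them back to the Hub.
    model.save_pretrained("./blt-checkpoint")
    model.push_to_hub("my-username/blt-finetuned")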