Merge 45456fa6d8 into sapling-pr-archive-EntilZha

Pedro Rodriguez, 2025-02-20 12:16:25 -08:00, committed by GitHub
commit 06a17a0ddc
7 changed files with 32 additions and 2 deletions

View file

@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import abc
 import logging
 import os
 from enum import Enum
@@ -572,7 +573,13 @@ class TransformerBlock(nn.Module):
         self.ffn_norm.reset_parameters()
 
 
-class BaseTransformer(nn.Module):
+class SequenceModelWithOutput(abc.ABC):
+    @abc.abstractmethod
+    def get_output_seq_len(self) -> int:
+        pass
+
+
+class BaseTransformer(nn.Module, SequenceModelWithOutput):
     def __init__(self, args: BaseTransformerArgs):
         super().__init__()
         self.dim = args.dim
@@ -593,6 +600,9 @@ class BaseTransformer(nn.Module):
         for _ in range(args.n_layers):
             self.layers.append(TransformerBlock(args))
 
+    def get_output_seq_len(self):
+        return self.max_seqlen
+
     def forward(
         self,
         h,

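Note: the hunk above introduces SequenceModelWithOutput as an abc.ABC interface and has BaseTransformer implement it. Below is a minimal, self-contained sketch of that pattern; ToyArgs and ToyTransformer are illustrative stand-ins (not from the repo) for BaseTransformerArgs and BaseTransformer.

import abc

import torch.nn as nn


class SequenceModelWithOutput(abc.ABC):
    @abc.abstractmethod
    def get_output_seq_len(self) -> int:
        pass


class ToyArgs:
    # Illustrative stand-in for BaseTransformerArgs; the real class has many more fields.
    def __init__(self, dim: int = 512, max_seqlen: int = 1024):
        self.dim = dim
        self.max_seqlen = max_seqlen


class ToyTransformer(nn.Module, SequenceModelWithOutput):
    # Mirrors how BaseTransformer now implements the interface: it simply
    # reports the sequence length it was configured with.
    def __init__(self, args: ToyArgs):
        super().__init__()
        self.dim = args.dim
        self.max_seqlen = args.max_seqlen

    def get_output_seq_len(self) -> int:
        return self.max_seqlen


print(ToyTransformer(ToyArgs()).get_output_seq_len())  # -> 1024
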
View file

@@ -12,6 +12,7 @@ from typing_extensions import Self
 from bytelatent.base_transformer import (
     BaseTransformerArgs,
     InitStdFactor,
+    SequenceModelWithOutput,
     TransformerBlock,
 )
 from bytelatent.data.patcher import Patcher, PatcherArgs
@@ -766,7 +767,7 @@ def compute_hash_embeddings(
     return local_encoder_embeds
 
 
-class ByteLatentTransformer(nn.Module):
+class ByteLatentTransformer(nn.Module, SequenceModelWithOutput):
     """
     The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences
     by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,
@@ -856,6 +857,9 @@ class ByteLatentTransformer(nn.Module):
            )
        )
 
+    def get_output_seq_len(self):
+        return self.max_seqlen
+
     def forward(
         self,
         tokens: torch.Tensor,

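Note: with ByteLatentTransformer also implementing the interface, code that only needs the output sequence length can be written against SequenceModelWithOutput. The helper below is a hypothetical caller, not part of this diff; only the import path is taken from the hunk above.

from bytelatent.base_transformer import SequenceModelWithOutput


def max_prompt_len(model: SequenceModelWithOutput, reserve_for_generation: int) -> int:
    # BaseTransformer returns its max_seqlen and ByteLatentTransformer does the
    # same, so callers need no isinstance checks on the concrete class.
    return model.get_output_seq_len() - reserve_for_generation
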
View file

@@ -17,3 +17,7 @@ class Tokenizer(abc.ABC):
     ) -> tuple[list[str], list[int]]:
         """Return the offsets of the tokens in the original text. Only used for evaluation."""
         pass
+
+    @abc.abstractmethod
+    def get_vocab_size(self) -> int:
+        pass

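Note: marking get_vocab_size with @abc.abstractmethod makes it part of the Tokenizer contract, so a concrete subclass that omits it can no longer be instantiated. A minimal stdlib-only sketch (toy classes, not from the repo):

import abc


class Tokenizer(abc.ABC):  # trimmed to the one method relevant here
    @abc.abstractmethod
    def get_vocab_size(self) -> int:
        pass


class IncompleteTokenizer(Tokenizer):
    pass  # does not implement get_vocab_size


class MinimalTokenizer(Tokenizer):
    def get_vocab_size(self) -> int:
        return 258


try:
    IncompleteTokenizer()
except TypeError as err:
    print(err)  # e.g. "Can't instantiate abstract class IncompleteTokenizer ..."

print(MinimalTokenizer().get_vocab_size())  # -> 258
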
View file

@@ -101,6 +101,9 @@ class BltTokenizer(Tokenizer):
         self.vocab_size_unit_1 = vocab_size_unit_1
         self.n_words = vocab_size_unit_1 + self.offsetting_special_char
 
+    def get_vocab_size(self) -> int:
+        return self.n_words
+
     def encode(
         self, text: str, add_bos: bool | None = None, add_eos: bool | None = None
     ):

View file

@@ -8,6 +8,9 @@ class ByteTokenizer(Tokenizer):
         self.eos_id = 257
         self.n_words = 258
 
+    def get_vocab_size(self) -> int:
+        return self.n_words
+
     def encode(self, s: str, add_bos: bool = False, add_eos: bool = False):
         tokens = [self.bos_id] * add_bos + list(s.encode()) + [self.eos_id] * add_eos
         return tokens

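Note: ByteTokenizer covers the 256 possible byte values plus BOS and EOS specials, hence n_words = 258, which get_vocab_size now exposes. A usage sketch follows; the module path and the no-argument constructor are assumptions not shown in this diff.

# Module path and constructor arguments are assumed, not taken from the hunk above.
from bytelatent.tokenizers.byte_tokenizer import ByteTokenizer

tok = ByteTokenizer()
print(tok.get_vocab_size())  # -> 258 (256 byte values + BOS + EOS)

ids = tok.encode("hi", add_bos=True, add_eos=True)
print(ids)  # -> [tok.bos_id, 104, 105, 257], since b"hi" == bytes([104, 105])
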
View file

@@ -35,6 +35,9 @@ class SentencePieceTokenizer(Tokenizer):
        )
         assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
 
+    def get_vocab_size(self) -> int:
+        return self.n_words
+
     def encode(self, s: str, add_bos: bool | None = None, add_eos: bool | None = None):
         if add_bos is None:
             add_bos = self.add_bos

View file

@@ -53,6 +53,9 @@ class TikTokenTokenizer(Tokenizer):
             f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
         )
 
+    def get_vocab_size(self) -> int:
+        return self.n_words
+
     def encode(self, s: str, add_bos: bool, add_eos: bool):
         assert isinstance(s, str)
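
Note: taken together, the two new accessors let generic evaluation or generation code query sequence length and vocabulary size through shared interfaces instead of concrete classes. A hypothetical consumer (not part of this diff; the Tokenizer import path is an assumption):

from bytelatent.base_transformer import SequenceModelWithOutput
from bytelatent.tokenizers.abstract_tokenizer import Tokenizer  # path assumed


def describe(model: SequenceModelWithOutput, tokenizer: Tokenizer) -> tuple[int, int]:
    seq_len = model.get_output_seq_len()  # max_seqlen for both transformer classes
    vocab = tokenizer.get_vocab_size()    # n_words for every tokenizer backend
    return seq_len, vocab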