diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py
index d44676d..25fbf71 100644
--- a/bytelatent/base_transformer.py
+++ b/bytelatent/base_transformer.py
@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import abc
 import logging
 import os
 from enum import Enum
@@ -572,7 +573,13 @@ class TransformerBlock(nn.Module):
         self.ffn_norm.reset_parameters()
 
 
-class BaseTransformer(nn.Module):
+class SequenceModelWithOutput(abc.ABC):
+    @abc.abstractmethod
+    def get_output_seq_len(self) -> int:
+        pass
+
+
+class BaseTransformer(nn.Module, SequenceModelWithOutput):
     def __init__(self, args: BaseTransformerArgs):
         super().__init__()
         self.dim = args.dim
@@ -593,6 +600,9 @@ class BaseTransformer(nn.Module):
         for _ in range(args.n_layers):
             self.layers.append(TransformerBlock(args))
 
+    def get_output_seq_len(self):
+        return self.max_seqlen
+
     def forward(
         self,
         h,
diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py
index b8586fe..134a990 100644
--- a/bytelatent/model/blt.py
+++ b/bytelatent/model/blt.py
@@ -12,6 +12,7 @@ from typing_extensions import Self
 from bytelatent.base_transformer import (
     BaseTransformerArgs,
     InitStdFactor,
+    SequenceModelWithOutput,
     TransformerBlock,
 )
 from bytelatent.data.patcher import Patcher, PatcherArgs
@@ -766,7 +767,7 @@ def compute_hash_embeddings(
     return local_encoder_embeds
 
 
-class ByteLatentTransformer(nn.Module):
+class ByteLatentTransformer(nn.Module, SequenceModelWithOutput):
     """
     The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences
     by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,
@@ -856,6 +857,9 @@ class ByteLatentTransformer(nn.Module):
             )
         )
 
+    def get_output_seq_len(self):
+        return self.max_seqlen
+
     def forward(
         self,
         tokens: torch.Tensor,
diff --git a/bytelatent/tokenizers/abstract_tokenizer.py b/bytelatent/tokenizers/abstract_tokenizer.py
index fff26c3..f827302 100644
--- a/bytelatent/tokenizers/abstract_tokenizer.py
+++ b/bytelatent/tokenizers/abstract_tokenizer.py
@@ -17,3 +17,7 @@ class Tokenizer(abc.ABC):
     ) -> tuple[list[str], list[int]]:
         """Return the offsets of the tokens in the original text. Only used for evaluation."""
         pass
+
+    @abc.abstractmethod
+    def get_vocab_size(self) -> int:
+        pass
diff --git a/bytelatent/tokenizers/blt_tokenizer.py b/bytelatent/tokenizers/blt_tokenizer.py
index a3462e1..1e4ec33 100644
--- a/bytelatent/tokenizers/blt_tokenizer.py
+++ b/bytelatent/tokenizers/blt_tokenizer.py
@@ -101,6 +101,9 @@ class BltTokenizer(Tokenizer):
         self.vocab_size_unit_1 = vocab_size_unit_1
         self.n_words = vocab_size_unit_1 + self.offsetting_special_char
 
+    def get_vocab_size(self) -> int:
+        return self.n_words
+
     def encode(
         self, text: str, add_bos: bool | None = None, add_eos: bool | None = None
     ):
diff --git a/bytelatent/tokenizers/byte_tokenizer.py b/bytelatent/tokenizers/byte_tokenizer.py
index f85f4f7..39ad75c 100644
--- a/bytelatent/tokenizers/byte_tokenizer.py
+++ b/bytelatent/tokenizers/byte_tokenizer.py
@@ -8,6 +8,9 @@ class ByteTokenizer(Tokenizer):
         self.eos_id = 257
         self.n_words = 258
 
+    def get_vocab_size(self) -> int:
+        return self.n_words
+
     def encode(self, s: str, add_bos: bool = False, add_eos: bool = False):
         tokens = [self.bos_id] * add_bos + list(s.encode()) + [self.eos_id] * add_eos
         return tokens
diff --git a/bytelatent/tokenizers/sentence_piece_tokenizer.py b/bytelatent/tokenizers/sentence_piece_tokenizer.py
index faeb997..1f41420 100644
--- a/bytelatent/tokenizers/sentence_piece_tokenizer.py
+++ b/bytelatent/tokenizers/sentence_piece_tokenizer.py
@@ -35,6 +35,9 @@ class SentencePieceTokenizer(Tokenizer):
         )
         assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
 
+    def get_vocab_size(self) -> int:
+        return self.n_words
+
     def encode(self, s: str, add_bos: bool | None = None, add_eos: bool | None = None):
         if add_bos is None:
             add_bos = self.add_bos
diff --git a/bytelatent/tokenizers/tiktoken_tokenizer.py b/bytelatent/tokenizers/tiktoken_tokenizer.py
index f498bf2..90d595d 100644
--- a/bytelatent/tokenizers/tiktoken_tokenizer.py
+++ b/bytelatent/tokenizers/tiktoken_tokenizer.py
@@ -53,6 +53,9 @@ class TikTokenTokenizer(Tokenizer):
             f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
         )
 
+    def get_vocab_size(self) -> int:
+        return self.n_words
+
     def encode(self, s: str, add_bos: bool, add_eos: bool):
         assert isinstance(s, str)
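For context, here is a minimal sketch (not part of this diff) of how downstream code might consume the two new abstract hooks, `SequenceModelWithOutput.get_output_seq_len()` and `Tokenizer.get_vocab_size()`. The names `describe_generation_budget`, `ToyModel`, and `ToyTokenizer` are hypothetical and used only for illustration:

```python
import abc


# Simplified mirrors of the interfaces introduced in this diff.
class SequenceModelWithOutput(abc.ABC):
    @abc.abstractmethod
    def get_output_seq_len(self) -> int:
        pass


class Tokenizer(abc.ABC):
    @abc.abstractmethod
    def get_vocab_size(self) -> int:
        pass


# Hypothetical caller: it can size generation buffers and logit tensors
# without knowing which concrete model or tokenizer it was handed.
def describe_generation_budget(
    model: SequenceModelWithOutput, tokenizer: Tokenizer
) -> str:
    return (
        f"max output length: {model.get_output_seq_len()}, "
        f"vocab size: {tokenizer.get_vocab_size()}"
    )


# Toy implementations standing in for the real model / tokenizer classes.
class ToyModel(SequenceModelWithOutput):
    def get_output_seq_len(self) -> int:
        return 4096


class ToyTokenizer(Tokenizer):
    def get_vocab_size(self) -> int:
        return 258  # same value ByteTokenizer.n_words reports (256 bytes + BOS + EOS)


if __name__ == "__main__":
    print(describe_generation_budget(ToyModel(), ToyTokenizer()))
    # -> max output length: 4096, vocab size: 258
```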