diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py
index d44676d..25fbf71 100644
--- a/bytelatent/base_transformer.py
+++ b/bytelatent/base_transformer.py
@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import abc
 import logging
 import os
 from enum import Enum
@@ -572,7 +573,13 @@ class TransformerBlock(nn.Module):
         self.ffn_norm.reset_parameters()
 
 
-class BaseTransformer(nn.Module):
+class SequenceModelWithOutput(abc.ABC):
+    @abc.abstractmethod
+    def get_output_seq_len(self) -> int:
+        pass
+
+
+class BaseTransformer(nn.Module, SequenceModelWithOutput):
     def __init__(self, args: BaseTransformerArgs):
         super().__init__()
         self.dim = args.dim
@@ -593,6 +600,9 @@ class BaseTransformer(nn.Module):
         for _ in range(args.n_layers):
             self.layers.append(TransformerBlock(args))
 
+    def get_output_seq_len(self) -> int:
+        return self.max_seqlen
+
     def forward(
         self,
         h,
diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py
index b8586fe..134a990 100644
--- a/bytelatent/model/blt.py
+++ b/bytelatent/model/blt.py
@@ -12,6 +12,7 @@ from typing_extensions import Self
 from bytelatent.base_transformer import (
     BaseTransformerArgs,
     InitStdFactor,
+    SequenceModelWithOutput,
     TransformerBlock,
 )
 from bytelatent.data.patcher import Patcher, PatcherArgs
@@ -766,7 +767,7 @@ def compute_hash_embeddings(
     return local_encoder_embeds
 
 
-class ByteLatentTransformer(nn.Module):
+class ByteLatentTransformer(nn.Module, SequenceModelWithOutput):
     """
     The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences
     by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,
@@ -856,6 +857,9 @@ class ByteLatentTransformer(nn.Module):
                 )
             )
 
+    def get_output_seq_len(self) -> int:
+        return self.max_seqlen
+
     def forward(
         self,
         tokens: torch.Tensor,
diff --git a/bytelatent/tokenizers/abstract_tokenizer.py b/bytelatent/tokenizers/abstract_tokenizer.py
index fff26c3..f827302 100644
--- a/bytelatent/tokenizers/abstract_tokenizer.py
+++ b/bytelatent/tokenizers/abstract_tokenizer.py
@@ -17,3 +17,7 @@ class Tokenizer(abc.ABC):
     ) -> tuple[list[str], list[int]]:
         """Return the offsets of the tokens in the original text. Only used for evaluation."""
         pass
+
+    @abc.abstractmethod
+    def get_vocab_size(self) -> int:
+        pass
diff --git a/bytelatent/tokenizers/blt_tokenizer.py b/bytelatent/tokenizers/blt_tokenizer.py
index a3462e1..1e4ec33 100644
--- a/bytelatent/tokenizers/blt_tokenizer.py
+++ b/bytelatent/tokenizers/blt_tokenizer.py
@@ -101,6 +101,9 @@ class BltTokenizer(Tokenizer):
         self.vocab_size_unit_1 = vocab_size_unit_1
         self.n_words = vocab_size_unit_1 + self.offsetting_special_char
 
+    def get_vocab_size(self) -> int:
+        return self.n_words
+
     def encode(
         self, text: str, add_bos: bool | None = None, add_eos: bool | None = None
     ):
diff --git a/bytelatent/tokenizers/sentence_piece_tokenizer.py b/bytelatent/tokenizers/sentence_piece_tokenizer.py
index faeb997..1f41420 100644
--- a/bytelatent/tokenizers/sentence_piece_tokenizer.py
+++ b/bytelatent/tokenizers/sentence_piece_tokenizer.py
@@ -35,6 +35,9 @@ class SentencePieceTokenizer(Tokenizer):
         )
         assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
 
+    def get_vocab_size(self) -> int:
+        return self.n_words
+
     def encode(self, s: str, add_bos: bool | None = None, add_eos: bool | None = None):
         if add_bos is None:
             add_bos = self.add_bos
diff --git a/bytelatent/tokenizers/tiktoken_tokenizer.py b/bytelatent/tokenizers/tiktoken_tokenizer.py
index f498bf2..90d595d 100644
--- a/bytelatent/tokenizers/tiktoken_tokenizer.py
+++ b/bytelatent/tokenizers/tiktoken_tokenizer.py
@@ -53,6 +53,9 @@ class TikTokenTokenizer(Tokenizer):
             f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
         )
 
+    def get_vocab_size(self) -> int:
+        return self.n_words
+
     def encode(self, s: str, add_bos: bool, add_eos: bool):
         assert isinstance(s, str)
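
Illustrative note (not part of the patch): a minimal sketch of how downstream code might consume the two interfaces this diff introduces, assuming a caller that holds any model implementing SequenceModelWithOutput and any Tokenizer subclass. The helper name allocate_logit_buffer is hypothetical.

    import torch

    from bytelatent.base_transformer import SequenceModelWithOutput
    from bytelatent.tokenizers.abstract_tokenizer import Tokenizer


    def allocate_logit_buffer(
        model: SequenceModelWithOutput, tokenizer: Tokenizer, batch_size: int
    ) -> torch.Tensor:
        # Both BaseTransformer and ByteLatentTransformer now report their output
        # sequence length through the shared abstract interface, and every
        # tokenizer reports its vocabulary size, so callers can size buffers
        # without special-casing concrete model or tokenizer classes.
        seq_len = model.get_output_seq_len()
        vocab_size = tokenizer.get_vocab_size()
        return torch.empty(batch_size, seq_len, vocab_size)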