From aeb95f12a1c31de0b62b2925d5c479b2a2dbf4d7 Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez <par@meta.com>
Date: Tue, 25 Feb 2025 11:10:59 -0800
Subject: [PATCH] Remove byte tokenizer and add config args to switch between
 byte/patch packing (#68)

Summary:

The standalone byte tokenizer was redundant with the BLT tokenizer, so this removes `ByteTokenizer` and the `bytes` tokenizer name. Whether batches are packed by raw bytes or by patch lengths is now controlled by a new `packing_mode` field on `PackingArgs` (a `PackingMode` enum), derived from `patcher_args.patching_mode` instead of the tokenizer name. The entropy-model config is updated to use the `blt` tokenizer accordingly.
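The selection logic added to `DataloaderArgs` maps the patcher setting onto the new enum; a minimal standalone sketch of the equivalent decision (the helper name is illustrative, not part of this patch):

```python
from bytelatent.data.iterators.packing_iterator import PackingMode
from bytelatent.data.patcher import PatchingModeEnum

# Illustrative helper (not in the patch): byte-level packing is used only when
# the patcher runs in byte mode; every other patching mode packs by patch lengths.
def select_packing_mode(patching_mode: PatchingModeEnum) -> PackingMode:
    if patching_mode == PatchingModeEnum.byte:
        return PackingMode.BYTES
    return PackingMode.PATCHING
```
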
Test Plan:

```
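# Train the entropy model with the updated config (byte patching, blt tokenizer)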
python -m bytelatent.train config=../internal-blt/configs/entropy_model.yaml logging.wandb=null checkpoint.dump.every=1000 checkpoint.eval.every=100000 eval=null

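# Unit tests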
pytest bytelatent/
```
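
A hypothetical spot-check (not included in this patch) for the new enum and the stricter dispatch in `PackingIterator.create_iter`:

```python
import pytest

from bytelatent.data.iterators.packing_iterator import PackingMode

def test_packing_mode_round_trip():
    # String values round-trip to enum members, as used when parsing configs.
    assert PackingMode("bytes") is PackingMode.BYTES
    assert PackingMode("patching") is PackingMode.PATCHING
    # Any other value is rejected, mirroring the ValueError raised in create_iter.
    with pytest.raises(ValueError):
        PackingMode("tokens")
```
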
---
 bytelatent/args.py                            | 14 ++++++--
 bytelatent/configs/entropy_model.yaml         |  2 +-
 bytelatent/data/iterators/packing_iterator.py | 14 ++++++--
 bytelatent/tokenizers/build_tokenizer.py      |  3 --
 bytelatent/tokenizers/byte_tokenizer.py       | 35 -------------------
 5 files changed, 23 insertions(+), 45 deletions(-)
 delete mode 100644 bytelatent/tokenizers/byte_tokenizer.py

diff --git a/bytelatent/args.py b/bytelatent/args.py
index 8ffa717..11c6548 100644
--- a/bytelatent/args.py
+++ b/bytelatent/args.py
@@ -14,14 +14,18 @@ from bytelatent.data.iterators.abstract_iterator import StatefulIterator
 from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
 from bytelatent.data.iterators.looping_iterator import LoopingIterator
 from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
-from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
+from bytelatent.data.iterators.packing_iterator import (
+    PackingArgs,
+    PackingIterator,
+    PackingMode,
+)
 from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
 from bytelatent.data.iterators.sampling_iterator import SamplingIterator
 from bytelatent.data.iterators.sequence_iterator import (
     SequenceIterator,
     SequencePackingArgs,
 )
-from bytelatent.data.patcher import PatcherArgs
+from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
 from bytelatent.distributed import DistributedArgs, EnvironmentArgs
 from bytelatent.metrics import LoggingArgs
 from bytelatent.model.blt import ByteLatentTransformerArgs
@@ -202,7 +206,11 @@ class DataloaderArgs(BaseModel):
             max_length=self.max_encoder_seq_length,
             pad_to_max_length=self.pad_to_max_length,
             enable_byte_ngrams=self.enable_byte_ngrams,
-            tokenizer_name=self.tokenizer_args.name,
+            packing_mode=(
+                PackingMode.BYTES
+                if self.patcher_args.patching_mode == PatchingModeEnum.byte
+                else PackingMode.PATCHING
+            ),
         )
         packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
         if self.load_async:
diff --git a/bytelatent/configs/entropy_model.yaml b/bytelatent/configs/entropy_model.yaml
index 79cc85b..f0e9ee7 100644
--- a/bytelatent/configs/entropy_model.yaml
+++ b/bytelatent/configs/entropy_model.yaml
@@ -55,7 +55,7 @@ data:
     # so pick the most efficient, so static
     patching_mode: byte
   tokenizer_args:
-    name: bytes
+    name: blt
 
 profiling:
   run: false
diff --git a/bytelatent/data/iterators/packing_iterator.py b/bytelatent/data/iterators/packing_iterator.py
index 5ed280d..dc34120 100644
--- a/bytelatent/data/iterators/packing_iterator.py
+++ b/bytelatent/data/iterators/packing_iterator.py
@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+from enum import Enum
 from typing import Any
 
 import numpy as np
@@ -12,6 +13,11 @@ from bytelatent.data.iterators.abstract_iterator import (
 from bytelatent.data.iterators.sampling_iterator import SamplingIteratorState
 
 
+class PackingMode(str, Enum):
+    BYTES = "bytes"
+    PATCHING = "patching"
+
+
 class PackingArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
     batch_size: int
@@ -20,7 +26,7 @@ class PackingArgs(BaseModel):
     max_length: int | None
     pad_to_max_length: bool
     enable_byte_ngrams: bool
-    tokenizer_name: str
+    packing_mode: PackingMode
 
 
 class PackingIteratorState(PydanticIteratorState):
@@ -155,10 +161,12 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
         )
 
     def create_iter(self):
-        if self.packing_args.tokenizer_name == "bytes":
+        if self.packing_args.packing_mode == PackingMode.BYTES:
             return self._create_iter_from_bytes()
-        else:
+        elif self.packing_args.packing_mode == PackingMode.PATCHING:
             return self._create_iter_from_patch_lengths()
+        else:
+            raise ValueError(f"Invalid patching mode: {self.packing_args.packing_mode}")
 
     def _create_iter_from_bytes(self):
         sequence_iter = self.sequence_iterator.create_iter()
diff --git a/bytelatent/tokenizers/build_tokenizer.py b/bytelatent/tokenizers/build_tokenizer.py
index 8aa434d..f60dfa4 100644
--- a/bytelatent/tokenizers/build_tokenizer.py
+++ b/bytelatent/tokenizers/build_tokenizer.py
@@ -5,7 +5,6 @@ from typing import Any
 from pydantic import BaseModel
 
 from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
-from bytelatent.tokenizers.byte_tokenizer import ByteTokenizer
 from bytelatent.tokenizers.tiktoken_tokenizer import TikTokenTokenizer
 
 try:
@@ -55,8 +54,6 @@ class TokenizerArgs(BaseModel):
             init_kwargs = self.init_kwargs
         if self.name == "blt":
             return BltTokenizer(**init_kwargs)
-        elif self.name == "bytes":
-            return ByteTokenizer(**init_kwargs)
         elif self.name == "mock":
             return MockTokenizer(**init_kwargs)
         elif self.name == "sp":
diff --git a/bytelatent/tokenizers/byte_tokenizer.py b/bytelatent/tokenizers/byte_tokenizer.py
deleted file mode 100644
index f85f4f7..0000000
--- a/bytelatent/tokenizers/byte_tokenizer.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
-
-
-class ByteTokenizer(Tokenizer):
-    def __init__(self):
-        self.bos_id = 256
-        self.eos_id = 257
-        self.n_words = 258
-
-    def encode(self, s: str, add_bos: bool = False, add_eos: bool = False):
-        tokens = [self.bos_id] * add_bos + list(s.encode()) + [self.eos_id] * add_eos
-        return tokens
-
-    def decode(self, tokens: list[int]):
-        byte_tokens = bytes([t for t in tokens if t < 256])
-        return byte_tokens.decode("utf-8", errors="backslashreplace")
-
-    def get_token_offsets(
-        self, text: str, tokens: list[int] | None = None
-    ) -> tuple[list[str], list[int]]:
-        if tokens is None:
-            tokens = self.encode(text)
-
-        decoded_chars, offsets = [], []
-        byte_pos = 0
-        for token in tokens:
-            if token < 256:
-                char = bytes([token]).decode("utf-8", errors="ignore")
-                if char:
-                    decoded_chars.append(char)
-                    offsets.append(byte_pos)
-                byte_pos += len(char.encode("utf-8"))
-
-        return decoded_chars, offsets