Remove byte tokenizer and add config args to switch between byte/patch packing ()

Summary:

Test Plan:

```
python -m bytelatent.train config=../internal-blt/configs/entropy_model.yaml logging.wandb=null checkpoint.dump.every=1000 checkpoint.eval.every=100000 eval=null

pytest bytelatent/
```
Pedro Rodriguez 2025-02-25 11:10:59 -08:00 committed by GitHub
parent ff36aa8642
commit aeb95f12a1
5 changed files with 23 additions and 45 deletions

View file

@@ -14,14 +14,18 @@ from bytelatent.data.iterators.abstract_iterator import StatefulIterator
 from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
 from bytelatent.data.iterators.looping_iterator import LoopingIterator
 from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
-from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
+from bytelatent.data.iterators.packing_iterator import (
+    PackingArgs,
+    PackingIterator,
+    PackingMode,
+)
 from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
 from bytelatent.data.iterators.sampling_iterator import SamplingIterator
 from bytelatent.data.iterators.sequence_iterator import (
     SequenceIterator,
     SequencePackingArgs,
 )
-from bytelatent.data.patcher import PatcherArgs
+from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
 from bytelatent.distributed import DistributedArgs, EnvironmentArgs
 from bytelatent.metrics import LoggingArgs
 from bytelatent.model.blt import ByteLatentTransformerArgs
@@ -202,7 +206,11 @@ class DataloaderArgs(BaseModel):
             max_length=self.max_encoder_seq_length,
             pad_to_max_length=self.pad_to_max_length,
             enable_byte_ngrams=self.enable_byte_ngrams,
-            tokenizer_name=self.tokenizer_args.name,
+            packing_mode=(
+                PackingMode.BYTES
+                if self.patcher_args.patching_mode == PatchingModeEnum.byte
+                else PackingMode.PATCHING
+            ),
         )
         packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
         if self.load_async:

View file

@@ -55,7 +55,7 @@ data:
   # so pick the most efficient, so static
   patching_mode: byte
   tokenizer_args:
-    name: bytes
+    name: blt
 
 profiling:
   run: false

View file

@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+from enum import Enum
 from typing import Any
 
 import numpy as np
@@ -12,6 +13,11 @@ from bytelatent.data.iterators.abstract_iterator import (
 from bytelatent.data.iterators.sampling_iterator import SamplingIteratorState
 
 
+class PackingMode(str, Enum):
+    BYTES = "bytes"
+    PATCHING = "patching"
+
+
 class PackingArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
     batch_size: int
@@ -20,7 +26,7 @@ class PackingArgs(BaseModel):
     max_length: int | None
     pad_to_max_length: bool
     enable_byte_ngrams: bool
-    tokenizer_name: str
+    packing_mode: PackingMode
 
 
 class PackingIteratorState(PydanticIteratorState):
@@ -155,10 +161,12 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
         )
 
     def create_iter(self):
-        if self.packing_args.tokenizer_name == "bytes":
+        if self.packing_args.packing_mode == PackingMode.BYTES:
             return self._create_iter_from_bytes()
-        else:
+        elif self.packing_args.packing_mode == PackingMode.PATCHING:
             return self._create_iter_from_patch_lengths()
+        else:
+            raise ValueError(f"Invalid patching mode: {self.packing_args.packing_mode}")
 
     def _create_iter_from_bytes(self):
         sequence_iter = self.sequence_iterator.create_iter()
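The dispatch itself is plain enum matching with an explicit failure path for unrecognized values. A self-contained sketch of that control flow, using a hypothetical `TinyPackingIterator` in place of the real `PackingIterator` (which also manages sequence iterators, padding, and batch assembly):

```
from enum import Enum
from typing import Iterator


class PackingMode(str, Enum):
    BYTES = "bytes"
    PATCHING = "patching"


class TinyPackingIterator:
    """Hypothetical, simplified stand-in for PackingIterator: only the mode dispatch."""

    def __init__(self, packing_mode: PackingMode):
        self.packing_mode = packing_mode

    def create_iter(self) -> Iterator[str]:
        if self.packing_mode == PackingMode.BYTES:
            return self._create_iter_from_bytes()
        elif self.packing_mode == PackingMode.PATCHING:
            return self._create_iter_from_patch_lengths()
        else:
            # Unknown modes now fail loudly instead of silently taking the patch path.
            raise ValueError(f"Invalid packing mode: {self.packing_mode}")

    def _create_iter_from_bytes(self) -> Iterator[str]:
        yield "batch packed from raw byte sequences"

    def _create_iter_from_patch_lengths(self) -> Iterator[str]:
        yield "batch packed from patch lengths"


print(next(TinyPackingIterator(PackingMode.BYTES).create_iter()))
```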

View file

@@ -5,7 +5,6 @@ from typing import Any
 from pydantic import BaseModel
 
 from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
-from bytelatent.tokenizers.byte_tokenizer import ByteTokenizer
 from bytelatent.tokenizers.tiktoken_tokenizer import TikTokenTokenizer
 
 try:
@@ -55,8 +54,6 @@ class TokenizerArgs(BaseModel):
         init_kwargs = self.init_kwargs
         if self.name == "blt":
             return BltTokenizer(**init_kwargs)
-        elif self.name == "bytes":
-            return ByteTokenizer(**init_kwargs)
         elif self.name == "mock":
             return MockTokenizer(**init_kwargs)
         elif self.name == "sp":

View file

@@ -1,35 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-
-from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
-
-
-class ByteTokenizer(Tokenizer):
-    def __init__(self):
-        self.bos_id = 256
-        self.eos_id = 257
-        self.n_words = 258
-
-    def encode(self, s: str, add_bos: bool = False, add_eos: bool = False):
-        tokens = [self.bos_id] * add_bos + list(s.encode()) + [self.eos_id] * add_eos
-        return tokens
-
-    def decode(self, tokens: list[int]):
-        byte_tokens = bytes([t for t in tokens if t < 256])
-        return byte_tokens.decode("utf-8", errors="backslashreplace")
-
-    def get_token_offsets(
-        self, text: str, tokens: list[int] | None = None
-    ) -> tuple[list[str], list[int]]:
-        if tokens is None:
-            tokens = self.encode(text)
-        decoded_chars, offsets = [], []
-        byte_pos = 0
-        for token in tokens:
-            if token < 256:
-                char = bytes([token]).decode("utf-8", errors="ignore")
-                if char:
-                    decoded_chars.append(char)
-                    offsets.append(byte_pos)
-                    byte_pos += len(char.encode("utf-8"))
-        return decoded_chars, offsets