From aeb95f12a1c31de0b62b2925d5c479b2a2dbf4d7 Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez <par@meta.com>
Date: Tue, 25 Feb 2025 11:10:59 -0800
Subject: [PATCH] Remove byte tokenizer and add config args to switch between
 byte/patch packing (#68)

Summary: Remove the standalone ByteTokenizer and select byte vs. patch packing
from the patcher's patching_mode instead of the tokenizer name.

Test Plan:

```
python -m bytelatent.train config=../internal-blt/configs/entropy_model.yaml logging.wandb=null checkpoint.dump.every=1000 checkpoint.eval.every=100000 eval=null
pytest bytelatent/
```
---
 bytelatent/args.py                            | 14 ++++++--
 bytelatent/configs/entropy_model.yaml         |  2 +-
 bytelatent/data/iterators/packing_iterator.py | 14 ++++++--
 bytelatent/tokenizers/build_tokenizer.py      |  3 --
 bytelatent/tokenizers/byte_tokenizer.py       | 35 -------------------
 5 files changed, 23 insertions(+), 45 deletions(-)
 delete mode 100644 bytelatent/tokenizers/byte_tokenizer.py

diff --git a/bytelatent/args.py b/bytelatent/args.py
index 8ffa717..11c6548 100644
--- a/bytelatent/args.py
+++ b/bytelatent/args.py
@@ -14,14 +14,18 @@ from bytelatent.data.iterators.abstract_iterator import StatefulIterator
 from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
 from bytelatent.data.iterators.looping_iterator import LoopingIterator
 from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
-from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
+from bytelatent.data.iterators.packing_iterator import (
+    PackingArgs,
+    PackingIterator,
+    PackingMode,
+)
 from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
 from bytelatent.data.iterators.sampling_iterator import SamplingIterator
 from bytelatent.data.iterators.sequence_iterator import (
     SequenceIterator,
     SequencePackingArgs,
 )
-from bytelatent.data.patcher import PatcherArgs
+from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
 from bytelatent.distributed import DistributedArgs, EnvironmentArgs
 from bytelatent.metrics import LoggingArgs
 from bytelatent.model.blt import ByteLatentTransformerArgs
@@ -202,7 +206,11 @@ class DataloaderArgs(BaseModel):
             max_length=self.max_encoder_seq_length,
             pad_to_max_length=self.pad_to_max_length,
             enable_byte_ngrams=self.enable_byte_ngrams,
-            tokenizer_name=self.tokenizer_args.name,
+            packing_mode=(
+                PackingMode.BYTES
+                if self.patcher_args.patching_mode == PatchingModeEnum.byte
+                else PackingMode.PATCHING
+            ),
         )
         packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
         if self.load_async:
diff --git a/bytelatent/configs/entropy_model.yaml b/bytelatent/configs/entropy_model.yaml
index 79cc85b..f0e9ee7 100644
--- a/bytelatent/configs/entropy_model.yaml
+++ b/bytelatent/configs/entropy_model.yaml
@@ -55,7 +55,7 @@ data:
   # so pick the most efficient, so static
   patching_mode: byte
   tokenizer_args:
-    name: bytes
+    name: blt

 profiling:
   run: false
diff --git a/bytelatent/data/iterators/packing_iterator.py b/bytelatent/data/iterators/packing_iterator.py
index 5ed280d..dc34120 100644
--- a/bytelatent/data/iterators/packing_iterator.py
+++ b/bytelatent/data/iterators/packing_iterator.py
@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+from enum import Enum
 from typing import Any

 import numpy as np
@@ -12,6 +13,11 @@ from bytelatent.data.iterators.abstract_iterator import (
 from bytelatent.data.iterators.sampling_iterator import SamplingIteratorState


+class PackingMode(str, Enum):
+    BYTES = "bytes"
+    PATCHING = "patching"
+
+
 class PackingArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
     batch_size: int
@@ -20,7 +26,7 @@ class PackingArgs(BaseModel):
     max_length: int | None
     pad_to_max_length: bool
     enable_byte_ngrams: bool
-    tokenizer_name: str
+    packing_mode: PackingMode


 class PackingIteratorState(PydanticIteratorState):
@@ -155,10 +161,12 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
         )

     def create_iter(self):
-        if self.packing_args.tokenizer_name == "bytes":
+        if self.packing_args.packing_mode == PackingMode.BYTES:
             return self._create_iter_from_bytes()
-        else:
+        elif self.packing_args.packing_mode == PackingMode.PATCHING:
             return self._create_iter_from_patch_lengths()
+        else:
+            raise ValueError(f"Invalid packing mode: {self.packing_args.packing_mode}")

     def _create_iter_from_bytes(self):
         sequence_iter = self.sequence_iterator.create_iter()
diff --git a/bytelatent/tokenizers/build_tokenizer.py b/bytelatent/tokenizers/build_tokenizer.py
index 8aa434d..f60dfa4 100644
--- a/bytelatent/tokenizers/build_tokenizer.py
+++ b/bytelatent/tokenizers/build_tokenizer.py
@@ -5,7 +5,6 @@ from typing import Any
 from pydantic import BaseModel

 from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
-from bytelatent.tokenizers.byte_tokenizer import ByteTokenizer
 from bytelatent.tokenizers.tiktoken_tokenizer import TikTokenTokenizer

 try:
@@ -55,8 +54,6 @@ class TokenizerArgs(BaseModel):
         init_kwargs = self.init_kwargs
         if self.name == "blt":
             return BltTokenizer(**init_kwargs)
-        elif self.name == "bytes":
-            return ByteTokenizer(**init_kwargs)
         elif self.name == "mock":
             return MockTokenizer(**init_kwargs)
         elif self.name == "sp":
diff --git a/bytelatent/tokenizers/byte_tokenizer.py b/bytelatent/tokenizers/byte_tokenizer.py
deleted file mode 100644
index f85f4f7..0000000
--- a/bytelatent/tokenizers/byte_tokenizer.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
-
-
-class ByteTokenizer(Tokenizer):
-    def __init__(self):
-        self.bos_id = 256
-        self.eos_id = 257
-        self.n_words = 258
-
-    def encode(self, s: str, add_bos: bool = False, add_eos: bool = False):
-        tokens = [self.bos_id] * add_bos + list(s.encode()) + [self.eos_id] * add_eos
-        return tokens
-
-    def decode(self, tokens: list[int]):
-        byte_tokens = bytes([t for t in tokens if t < 256])
-        return byte_tokens.decode("utf-8", errors="backslashreplace")
-
-    def get_token_offsets(
-        self, text: str, tokens: list[int] | None = None
-    ) -> tuple[list[str], list[int]]:
-        if tokens is None:
-            tokens = self.encode(text)
-
-        decoded_chars, offsets = [], []
-        byte_pos = 0
-        for token in tokens:
-            if token < 256:
-                char = bytes([token]).decode("utf-8", errors="ignore")
-                if char:
-                    decoded_chars.append(char)
-                    offsets.append(byte_pos)
-                    byte_pos += len(char.encode("utf-8"))
-
-        return decoded_chars, offsets
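
Note for reviewers: a minimal sketch of how mode selection behaves after this
patch, assuming the PackingMode and PatchingModeEnum values shown in the diff
are importable from this repo's bytelatent package; the resolve_packing_mode
helper is illustrative only and is not added by this patch.

```
from bytelatent.data.iterators.packing_iterator import PackingMode
from bytelatent.data.patcher import PatchingModeEnum


def resolve_packing_mode(patching_mode: PatchingModeEnum) -> PackingMode:
    # Mirrors the conditional added in DataloaderArgs: byte-level patching
    # packs raw bytes, every other patching mode packs by patch lengths.
    if patching_mode == PatchingModeEnum.byte:
        return PackingMode.BYTES
    return PackingMode.PATCHING


# entropy_model.yaml keeps patching_mode: byte, so it still gets byte packing
# even though its tokenizer_args.name is now "blt" instead of "bytes".
assert resolve_packing_mode(PatchingModeEnum.byte) == PackingMode.BYTES
```

This ties packing behavior to the patcher configuration rather than to the
tokenizer name, which is what allows ByteTokenizer to be removed outright.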