diff --git a/bytelatent/args.py b/bytelatent/args.py
index 8ffa717..c1c60dc 100644
--- a/bytelatent/args.py
+++ b/bytelatent/args.py
@@ -14,7 +14,11 @@ from bytelatent.data.iterators.abstract_iterator import StatefulIterator
 from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
 from bytelatent.data.iterators.looping_iterator import LoopingIterator
 from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
-from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
+from bytelatent.data.iterators.packing_iterator import (
+    PackingArgs,
+    PackingIterator,
+    PackingMode,
+)
 from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
 from bytelatent.data.iterators.sampling_iterator import SamplingIterator
 from bytelatent.data.iterators.sequence_iterator import (
@@ -134,6 +138,7 @@ class DataloaderArgs(BaseModel):
     buffer_size: int = 64
     file_format: str = "arrow"
 
+    packing_mode: PackingMode = PackingMode.PATCHING
     pad_to_max_length: bool = True
     max_encoder_seq_length: int = 12288
     enable_byte_ngrams: bool = False
@@ -202,7 +207,7 @@ class DataloaderArgs(BaseModel):
             max_length=self.max_encoder_seq_length,
             pad_to_max_length=self.pad_to_max_length,
             enable_byte_ngrams=self.enable_byte_ngrams,
-            tokenizer_name=self.tokenizer_args.name,
+            packing_mode=self.packing_mode,
         )
         packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
         if self.load_async:
diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml
index 1369364..563b470 100644
--- a/bytelatent/configs/debug.yaml
+++ b/bytelatent/configs/debug.yaml
@@ -71,6 +71,7 @@ data:
   root_dir: ???
   sources:
     dclm_baseline_1.0: 1.0
+  packing_mode: patching
   batch_size: 2
   prefetch_size: 64
   seq_len: 4096
diff --git a/bytelatent/configs/entropy_model.yaml b/bytelatent/configs/entropy_model.yaml
index 79cc85b..adca091 100644
--- a/bytelatent/configs/entropy_model.yaml
+++ b/bytelatent/configs/entropy_model.yaml
@@ -39,6 +39,7 @@ data:
   root_dir: ???
   sources:
     dclm_baseline_1.0: 1.0
+  packing_mode: bytes
   batch_size: 2
   prefetch_size: 64
   # seqlen is in terms of patches and
@@ -55,7 +56,7 @@ data:
   # so pick the most efficient, so static
   patching_mode: byte
   tokenizer_args:
-    name: bytes
+    name: blt
 
 profiling:
   run: false
diff --git a/bytelatent/data/iterators/packing_iterator.py b/bytelatent/data/iterators/packing_iterator.py
index 5ed280d..dc34120 100644
--- a/bytelatent/data/iterators/packing_iterator.py
+++ b/bytelatent/data/iterators/packing_iterator.py
@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+from enum import Enum
 from typing import Any
 
 import numpy as np
@@ -12,6 +13,11 @@ from bytelatent.data.iterators.abstract_iterator import (
 from bytelatent.data.iterators.sampling_iterator import SamplingIteratorState
 
 
+class PackingMode(str, Enum):
+    BYTES = "bytes"
+    PATCHING = "patching"
+
+
 class PackingArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
     batch_size: int
@@ -20,7 +26,7 @@ class PackingArgs(BaseModel):
     max_length: int | None
     pad_to_max_length: bool
     enable_byte_ngrams: bool
-    tokenizer_name: str
+    packing_mode: PackingMode
 
 
 class PackingIteratorState(PydanticIteratorState):
@@ -155,10 +161,12 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
         )
 
     def create_iter(self):
-        if self.packing_args.tokenizer_name == "bytes":
+        if self.packing_args.packing_mode == PackingMode.BYTES:
             return self._create_iter_from_bytes()
-        else:
+        elif self.packing_args.packing_mode == PackingMode.PATCHING:
             return self._create_iter_from_patch_lengths()
+        else:
+            raise ValueError(f"Invalid packing mode: {self.packing_args.packing_mode}")
 
     def _create_iter_from_bytes(self):
         sequence_iter = self.sequence_iterator.create_iter()
diff --git a/bytelatent/tokenizers/build_tokenizer.py b/bytelatent/tokenizers/build_tokenizer.py
index 8aa434d..f60dfa4 100644
--- a/bytelatent/tokenizers/build_tokenizer.py
+++ b/bytelatent/tokenizers/build_tokenizer.py
@@ -5,7 +5,6 @@ from typing import Any
 from pydantic import BaseModel
 
 from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
-from bytelatent.tokenizers.byte_tokenizer import ByteTokenizer
 from bytelatent.tokenizers.tiktoken_tokenizer import TikTokenTokenizer
 
 try:
@@ -55,8 +54,6 @@ class TokenizerArgs(BaseModel):
         init_kwargs = self.init_kwargs
         if self.name == "blt":
             return BltTokenizer(**init_kwargs)
-        elif self.name == "bytes":
-            return ByteTokenizer(**init_kwargs)
         elif self.name == "mock":
             return MockTokenizer(**init_kwargs)
         elif self.name == "sp":
diff --git a/bytelatent/tokenizers/byte_tokenizer.py b/bytelatent/tokenizers/byte_tokenizer.py
deleted file mode 100644
index f85f4f7..0000000
--- a/bytelatent/tokenizers/byte_tokenizer.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
-
-
-class ByteTokenizer(Tokenizer):
-    def __init__(self):
-        self.bos_id = 256
-        self.eos_id = 257
-        self.n_words = 258
-
-    def encode(self, s: str, add_bos: bool = False, add_eos: bool = False):
-        tokens = [self.bos_id] * add_bos + list(s.encode()) + [self.eos_id] * add_eos
-        return tokens
-
-    def decode(self, tokens: list[int]):
-        byte_tokens = bytes([t for t in tokens if t < 256])
-        return byte_tokens.decode("utf-8", errors="backslashreplace")
-
-    def get_token_offsets(
-        self, text: str, tokens: list[int] | None = None
-    ) -> tuple[list[str], list[int]]:
-        if tokens is None:
-            tokens = self.encode(text)
-
-        decoded_chars, offsets = [], []
-        byte_pos = 0
-        for token in tokens:
-            if token < 256:
-                char = bytes([token]).decode("utf-8", errors="ignore")
-                if char:
-                    decoded_chars.append(char)
-                    offsets.append(byte_pos)
-                    byte_pos += len(char.encode("utf-8"))
-
-        return decoded_chars, offsets
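
Not part of the diff: a minimal sketch of why `PackingMode` subclasses both `str` and `Enum`. The configs above keep plain strings (`packing_mode: patching`, `packing_mode: bytes`), and a pydantic model coerces those into the enum during validation while rejecting anything else. `DataConfig` below is a hypothetical stand-in for the packing-related slice of `DataloaderArgs`, not the real class.

```python
from enum import Enum

from pydantic import BaseModel, ValidationError


class PackingMode(str, Enum):
    BYTES = "bytes"
    PATCHING = "patching"


class DataConfig(BaseModel):
    # Hypothetical stand-in for the packing-related fields of DataloaderArgs.
    packing_mode: PackingMode = PackingMode.PATCHING


# Plain YAML-style strings coerce to the enum member during validation.
cfg = DataConfig(packing_mode="bytes")
assert cfg.packing_mode is PackingMode.BYTES

# Values outside the enum fail fast instead of silently taking a fallback branch,
# which is the point of replacing the old tokenizer_name == "bytes" check.
try:
    DataConfig(packing_mode="tokens")
except ValidationError as err:
    print(err)
```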