Update iterator inheritance, pass file format args, add a limit iterator

- Add a common base class, PydanticIteratorState, for all iterator states to inherit from
- Add a LimitIterator that we can use in evals
- Modify ArrowFileIterator to skip arrow shard path inference when file_format='json'
- Make EvalArgs valid to construct with defaults by making dump_dir and ckpt_dir optional
- Move the testing iterators to a common module so multiple test files can use them
- Allow SequenceIterator to take a None rng_state, which disables all rng ops (mainly for eval)

Test Plan:

- `pytest bytelatent`
- `python -m bytelatent.train config=../internal-blt/configs/entropy_model.yaml logging.wandb=null eval=null`
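
Usage sketch (illustrative, not code from this commit): the class names, constructor
arguments, and the fixture path below mirror the diffs and tests further down, but the
wiring and the limit value are only an example of how the new pieces are meant to
compose for a small, deterministic eval pass.

from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
from bytelatent.data.iterators.limit_iterator import LimitIterator

# Read a raw JSONL file directly; file_format="json" skips arrow shard-path inference.
jsonl_iterator = ArrowFileIterator(
    file_path="fixtures/test_docs.jsonl",
    file_format="json",
    worker_id=0,
    num_workers=1,
    preprocess_dir=None,
    entropy_model_name=None,
    arrow_batch_size=100,
)

# Cap how many examples an eval consumes; limit=-1 would mean "no limit".
eval_examples = LimitIterator(jsonl_iterator, limit=100)
for example in eval_examples.create_iter():
    ...  # run the eval on each BltExample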
Author: Pedro Rodriguez
Date: 2025-02-20 00:56:52 +00:00
Parent: b0956bde99
Commit: 0ffe2ab685
16 changed files with 317 additions and 133 deletions

.gitignore
@@ -168,3 +168,5 @@ figures/
 internal/
 jobs_parallel-copy/
 wandb/
+*.ipynb

@@ -72,6 +72,7 @@ def distribute_data_to_rank(
     arrow_batch_size: int,
     rank: int,
     world_size: int,
+    file_format: str,
     s3_profile: str | None = None,
     file_pattern: str = TRAIN_DATA_FILE_PATTERN,
 ) -> ArrowFileIterator:
@@ -85,6 +86,7 @@ def distribute_data_to_rank(
         rank_to_arrow_iterator_params.append(
             ArrowFileIterator(
                 file_path=chunk_path,
+                file_format=file_format,
                 worker_id=worker_id,
                 num_workers=n_workers_per_chunk,
                 preprocess_dir=preprocess_dir,
@@ -130,6 +132,7 @@ class DataloaderArgs(BaseModel):
     entropy_model_name: str | None = "transformer_100m"
     arrow_batch_size: int = 100
     buffer_size: int = 64
+    file_format: str = "arrow"
     pad_to_max_length: bool = True
     max_encoder_seq_length: int = 12288
@@ -151,6 +154,7 @@ class DataloaderArgs(BaseModel):
         for dataset_path in self.sources:
             shuffle_rng_state = get_rng_state(self.seed + 1, rank, world_size)
             arrow_iterator = distribute_data_to_rank(
+                file_format=self.file_format,
                 dataset_path=os.path.join(self.root_dir, dataset_path),
                 preprocess_dir=self.preprocess_dir,
                 entropy_model_name=self.entropy_model_name,
@@ -238,7 +242,7 @@ class LMHarnessArgs(BaseModel):
 class ValidationArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
-    max_steps: int | None = (
+    max_n_docs: int | None = (
         None  # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
     )
     use_val_from_train_src: bool = True  # Use the validation set from training sources
@@ -248,8 +252,8 @@ class ValidationArgs(BaseModel):
 class EvalArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
-    dump_dir: str
-    ckpt_dir: str
+    dump_dir: str | None = None
+    ckpt_dir: str | None = None
     metric_log_dir: str | None = None
     generator: PackedCausalTransformerGeneratorArgs = (
         PackedCausalTransformerGeneratorArgs()

@@ -2,6 +2,8 @@
 import abc
 from typing import Any, Generator, Generic, TypeVar

+import pydantic
+
 T = TypeVar("T")
 C = TypeVar("C")
@@ -23,6 +25,10 @@ class IteratorState(Generic[C]):
         pass


+class PydanticIteratorState(pydantic.BaseModel, IteratorState):
+    model_config = pydantic.ConfigDict(extra="forbid")
+
+
 def get_state_and_refresh(iterator: StatefulIterator):
     # Re-init dataloader and iterator is necessary since get_state()
     # on mp iterator shuts down MP to correctly persist state and it needs
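
The state classes changed below all migrate to this pattern: an iterator state is a
pydantic model (serializable, extra="forbid") that knows how to rebuild its iterator via
build(), and the iterator can emit that state again via get_state(). A minimal sketch of
the contract, using hypothetical names that are not part of this commit:

from bytelatent.data.iterators.abstract_iterator import (
    PydanticIteratorState,
    StatefulIterator,
)


class CountingIteratorState(PydanticIteratorState):
    position: int
    total: int

    def build(self) -> "CountingIterator":
        return CountingIterator(position=self.position, total=self.total)


class CountingIterator(StatefulIterator):
    def __init__(self, position: int = 0, total: int = 10):
        self.position = position
        self.total = total

    def get_state(self) -> CountingIteratorState:
        return CountingIteratorState(position=self.position, total=self.total)

    def create_iter(self):
        # Resumes from wherever the persisted state left off.
        while self.position < self.total:
            self.position += 1
            yield self.position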

@@ -15,13 +15,16 @@ from pydantic import BaseModel, ConfigDict
 from bytelatent import ByteLatentError
 from bytelatent.data.data_types import BltExample
 from bytelatent.data.file_util import get_fs
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text

 logger = getLogger(__name__)


-class ArrowFileIteratorState(BaseModel, IteratorState):
+class ArrowFileIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     file_path: str | None
     row_num: int
@@ -110,39 +113,51 @@ class ArrowFileIterator(StatefulIterator):
         logger.info("Arrow iterator using fs=%s", self.fs)
         if dataset_files is None:
-            # Prepare arrow shards
-            jsonl_file = file_path
-            parts = re.match(
-                r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file)
-            )
-            assert parts is not None
-            dataset = parts.group(1)
-            data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name)
-            data_dir_with_glob = os.path.join(
-                data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow"
-            )
-            if self.fs is None:
-                self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile)
-                if isinstance(self.fs, s3fs.S3FileSystem):
-                    self.filesystem_type = "s3"
-                else:
-                    self.filesystem_type = "file"
-            shard_files = self.fs.glob(data_dir_with_glob)
-            for s in shard_files:
-                complete_file = os.path.join(
-                    data_dir, f"{os.path.basename(s)}.complete"
-                )
-                if not self.fs.exists(complete_file):
-                    raise ValueError(f"Missing .complete for input file: {s}")
-            shard_files = sorted(shard_files, key=shard_sort_key)
-            if len(shard_files) == 0:
-                raise ByteLatentError(
-                    f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
-                )
-            self.dataset_files = [f for f in shard_files]
+            assert (
+                file_path is not None
+            ), "Must specify file_Path if dataset_files is None"
+            if file_format == "json":
+                if self.fs is None:
+                    self.fs = get_fs(file_path, s3_profile=s3_profile)
+                    if isinstance(self.fs, s3fs.S3FileSystem):
+                        self.filesystem_type = "s3"
+                    else:
+                        self.filesystem_type = "file"
+                self.dataset_files = [file_path]
+            else:
+                # Prepare arrow shards
+                jsonl_file = file_path
+                parts = re.match(
+                    r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file)
+                )
+                assert parts is not None
+                dataset = parts.group(1)
+                data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name)
+                data_dir_with_glob = os.path.join(
+                    data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow"
+                )
+                if self.fs is None:
+                    self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile)
+                    if isinstance(self.fs, s3fs.S3FileSystem):
+                        self.filesystem_type = "s3"
+                    else:
+                        self.filesystem_type = "file"
+                shard_files = self.fs.glob(data_dir_with_glob)
+                for s in shard_files:
+                    complete_file = os.path.join(
+                        data_dir, f"{os.path.basename(s)}.complete"
+                    )
+                    if not self.fs.exists(complete_file):
+                        raise ValueError(f"Missing .complete for input file: {s}")
+                shard_files = sorted(shard_files, key=shard_sort_key)
+                if len(shard_files) == 0:
+                    raise ByteLatentError(
+                        f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
+                    )
+                self.dataset_files = [f for f in shard_files]
         else:
             self.preprocess_dir = None
             self.dataset_files = dataset_files

@@ -0,0 +1,78 @@
+import pandas as pd
+from pydantic import ConfigDict
+
+from bytelatent.data.data_types import BltExample
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
+
+
+class BltTestIteratorState(PydanticIteratorState):
+    model_config = ConfigDict(extra="forbid")
+    position: int
+    total: int
+
+    def build(self):
+        # Rebuild the iterator (not another state) and restore its position.
+        blt_iter = BltTestIterator(total=self.total)
+        blt_iter.position = self.position
+        return blt_iter
+
+
+class BltTestIterator(StatefulIterator):
+    def __init__(self, total: int):
+        self.position = 0
+        self.total = total
+
+    def get_state(self):
+        return BltTestIteratorState(position=self.position, total=self.total)
+
+    def create_iter(self):
+        for i in range(self.total):
+            self.position += 1
+            yield BltExample(
+                sample_id=f"test_{i}",
+                text=f"This is some test {i} text.",
+                tokens=None,
+                mask=None,
+                entropies=None,
+                patch_lengths=None,
+            )
+
+
+class BltTestWithEntropiesIteratorState(PydanticIteratorState):
+    model_config = ConfigDict(extra="forbid")
+    position: int
+    total: int
+
+    def build(self):
+        blt_iter = BltTestWithEntropiesIterator(total=self.total)
+        blt_iter.position = self.position
+        return blt_iter
+
+
+class BltTestWithEntropiesIterator(StatefulIterator):
+    def __init__(self, total: int):
+        self.position = 0
+        self.total = total
+
+    def get_state(self):
+        return BltTestIteratorState(position=self.position, total=self.total)
+
+    def create_iter(self):
+        text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
+        df = pd.read_json("fixtures/tokens_with_entropies.json")
+        tokens = df["token_ids"].tolist()
+        entropies = df["entropies"].tolist()
+        # BOS and EOS
+        assert len(tokens) == len(text) + 2
+        for i in range(self.total):
+            self.position += 1
+            yield BltExample(
+                sample_id=f"test_{i}",
+                text=text,
+                tokens=tokens,
+                mask=[True] * len(tokens),
+                entropies=entropies,
+                patch_lengths=None,
+            )

@@ -0,0 +1,47 @@
+from pydantic import ConfigDict
+
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState
+from bytelatent.data.iterators.dev_iterators import BltTestIteratorState
+
+
+class LimitIteratorState(PydanticIteratorState):
+    model_config = ConfigDict(extra="forbid")
+    base_iterator_state: (
+        BltTestIteratorState | ArrowFileIteratorState | PydanticIteratorState
+    )
+    n_yielded: int
+    limit: int
+
+    def build(self) -> "LimitIterator":
+        return LimitIterator(
+            base_iterator=self.base_iterator_state.build(),
+            n_yielded=self.n_yielded,
+            limit=self.limit,
+        )
+
+
+class LimitIterator(StatefulIterator):
+    def __init__(self, base_iterator: StatefulIterator, limit: int, n_yielded: int = 0):
+        self.base_iterator = base_iterator
+        self.n_yielded = n_yielded
+        self.limit = limit
+
+    def get_state(self):
+        return LimitIteratorState(
+            base_iterator_state=self.base_iterator.get_state(),
+            n_yielded=self.n_yielded,
+            limit=self.limit,
+        )
+
+    def create_iter(self):
+        iterator = self.base_iterator.create_iter()
+        try:
+            while self.n_yielded < self.limit or self.limit < 0:
+                yield next(iterator)
+                self.n_yielded += 1
+        except StopIteration:
+            pass

@@ -1,14 +1,16 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-from pydantic import BaseModel
-
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.data.iterators.arrow_iterator import (
     ArrowFileIterator,
     ArrowFileIteratorState,
 )


-class LoopingIteratorState(BaseModel, IteratorState):
+class LoopingIteratorState(PydanticIteratorState):
     file_iterator_state: ArrowFileIteratorState
     epoch: int

@@ -6,16 +6,20 @@ from multiprocessing.synchronize import Event as EventClass
 from queue import Empty, Full

 import numpy as np
-from pydantic import BaseModel, ConfigDict
+from pydantic import ConfigDict

 from bytelatent.data.data_types import Batch
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    IteratorState,
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.data.iterators.packing_iterator import PackingIteratorState

 logger = logging.getLogger()


-class MultiprocessIteratorState(BaseModel, IteratorState):
+class MultiprocessIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     base_iterator_state: PackingIteratorState
     n_batches_to_prefetch: int

@@ -5,7 +5,10 @@ import numpy as np
 from pydantic import BaseModel, ConfigDict

 from bytelatent.data.data_types import Batch, BltSequence
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.data.iterators.sampling_iterator import SamplingIteratorState
@@ -20,7 +23,7 @@ class PackingArgs(BaseModel):
     tokenizer_name: str


-class PackingIteratorState(BaseModel, IteratorState):
+class PackingIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     sequence_iterator_state: SamplingIteratorState
     packing_args: PackingArgs

@@ -5,20 +5,29 @@ import torch
 from pydantic import BaseModel, ConfigDict

 from bytelatent.data.data_types import BltExample
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.data.iterators.arrow_iterator import (
     ArrowFileIterator,
     ArrowFileIteratorState,
 )
-from bytelatent.data.iterators.looping_iterator import LoopingIteratorState
+from bytelatent.data.iterators.limit_iterator import LimitIterator, LimitIteratorState
+from bytelatent.data.iterators.looping_iterator import (
+    LoopingIterator,
+    LoopingIteratorState,
+)
 from bytelatent.data.patcher import Patcher, PatcherArgs, PatchingModeEnum
 from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
 from bytelatent.tokenizers.build_tokenizer import TokenizerArgs


-class PreprocessIteratorState(BaseModel, IteratorState):
+class PreprocessIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
-    arrow_file_iterator_state: ArrowFileIteratorState | LoopingIteratorState
+    arrow_file_iterator_state: (
+        ArrowFileIteratorState | LoopingIteratorState | LimitIteratorState
+    )
     add_tokens: bool
     add_patches: bool
     tokenizer_args: TokenizerArgs
@@ -43,7 +52,7 @@ class PreprocessIterator(StatefulIterator):
     def __init__(
         self,
-        arrow_iterator: ArrowFileIterator,
+        arrow_iterator: ArrowFileIterator | LoopingIterator | LimitIterator,
         *,
         patcher_args: PatcherArgs,
         tokenizer_args: TokenizerArgs,

@@ -2,13 +2,16 @@
 from typing import Any

 import numpy as np
-from pydantic import BaseModel, ConfigDict
+from pydantic import ConfigDict

-from bytelatent.data.iterators.abstract_iterator import StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.data.iterators.sequence_iterator import SequenceIteratorState


-class SamplingIteratorState(BaseModel):
+class SamplingIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     rng_state: dict[str, Any]
     source_to_weight: dict[str, float]

@@ -6,7 +6,10 @@ import numpy as np
 from pydantic import BaseModel, ConfigDict

 from bytelatent.data.data_types import BltSequence
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.data.iterators.preprocess_iterator import (
     PreprocessIterator,
     PreprocessIteratorState,
@@ -21,11 +24,12 @@ class SequencePackingArgs(BaseModel):
     buffer_size: int


-class SequenceIteratorState(BaseModel, IteratorState):
+class SequenceIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     sequence_packing_args: SequencePackingArgs
     preprocess_iterator_state: PreprocessIteratorState
-    rng_state: dict[str, Any]
+    # If None, rng is disabled.
+    rng_state: dict[str, Any] | None

     def build(self):
         preprocess_iterator = self.preprocess_iterator_state.build()
@@ -41,22 +45,25 @@ class SequenceIterator(StatefulIterator):
         self,
         preprocess_iterator: PreprocessIterator,
         *,
-        rng_state: dict[str, Any],
+        rng_state: dict[str, Any] | None,
         sequence_packing_args: SequencePackingArgs,
     ):
         self.preprocess_iterator = preprocess_iterator
         self.sequence_packing_args = sequence_packing_args
         self.output_seq_len = sequence_packing_args.output_seq_len
         self.buffer_size = sequence_packing_args.buffer_size
-        self.rng = np.random.default_rng()
-        self.rng.bit_generator.state = rng_state
+        if rng_state is None:
+            self.rng = None
+        else:
+            self.rng = np.random.default_rng()
+            self.rng.bit_generator.state = rng_state

     def get_state(self):
         # TODO: need to also perist the current shuffle buffer
         return SequenceIteratorState(
             sequence_packing_args=self.sequence_packing_args,
             preprocess_iterator_state=self.preprocess_iterator.get_state(),
-            rng_state=self.rng.bit_generator.state,
+            rng_state=None if self.rng is None else self.rng.bit_generator.state,
         )

     def create_iter(self):
@@ -114,7 +121,12 @@ class SequenceIterator(StatefulIterator):
             seq_patch_lengths: list[list[int]] = x_patches.tolist()
             assert len(seq_patch_lengths) == self.buffer_size
-            for idx in self.rng.permutation(len(seq_patch_lengths)):
+            if self.rng is None:
+                permutations = list(range(len(seq_patch_lengths)))
+            else:
+                permutations = self.rng.permutation(len(seq_patch_lengths))
+            for idx in permutations:
                 assert len(seq_patch_lengths[idx]) == self.output_seq_len
                 assert (
                     sum(seq_patch_lengths[idx])
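
The practical effect of a None rng_state above is a fully deterministic sequence stream:
the shuffle buffer is emitted in insertion order instead of a random permutation, and
get_state() round-trips the None. A sketch of the two configurations; preprocess_iterator
is assumed to be an already-built PreprocessIterator, and only the two SequencePackingArgs
fields visible in this diff are set, with illustrative values:

import numpy as np

from bytelatent.data.iterators.sequence_iterator import (
    SequenceIterator,
    SequencePackingArgs,
)

preprocess_iterator = ...  # assumed: an already-built PreprocessIterator
packing_args = SequencePackingArgs(output_seq_len=4096, buffer_size=64)

# Training: a concrete rng_state makes the shuffle-buffer permutation reproducible.
train_sequences = SequenceIterator(
    preprocess_iterator,
    rng_state=np.random.default_rng(seed=0).bit_generator.state,
    sequence_packing_args=packing_args,
)

# Eval: rng_state=None disables all rng ops; sequences come out in input order.
eval_sequences = SequenceIterator(
    preprocess_iterator,
    rng_state=None,
    sequence_packing_args=packing_args,
)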

@@ -6,7 +6,10 @@ import pyarrow as pa
 import pyarrow.dataset  # pyright: ignore

 from bytelatent.constants import BLT_DATA
-from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState
+from bytelatent.data.iterators.arrow_iterator import (
+    ArrowFileIterator,
+    ArrowFileIteratorState,
+)

 ENTROPY_MODEL = "transformer_100m"
 ARROW_TEST_DATA_1 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_00.arrow")
@@ -93,3 +96,19 @@ def test_basic_arrow_file():
             i += 1
             if i >= len(expected_ids):
                 break
+
+
+def test_read_jsonl_from_arrow():
+    arrow_iterator = ArrowFileIterator(
+        file_path="fixtures/test_docs.jsonl",
+        num_workers=1,
+        worker_id=0,
+        preprocess_dir=None,
+        entropy_model_name=None,
+        file_format="json",
+        arrow_batch_size=100,
+    )
+    iterator = arrow_iterator.create_iter()
+    for i, example in enumerate(iterator):
+        assert example.sample_id == str(i)
+        assert example.text == f"test_{i}"

@@ -1,83 +1,15 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-import pandas as pd
-from pydantic import BaseModel
-
 from bytelatent.constants import BLT_DATA
-from bytelatent.data.data_types import BltExample
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.dev_iterators import (
+    BltTestIterator,
+    BltTestWithEntropiesIterator,
+)
 from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
 from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
 from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
[Removed lines omitted here: the BltTestIteratorState, BltTestIterator, BltTestWithEntropiesIteratorState, and BltTestWithEntropiesIterator definitions deleted from this test module match the code added in the new dev_iterators file shown earlier.]

 def test_preprocess_iter():
     total = 3
     tokenizer_args = TokenizerArgs(

@@ -0,0 +1,45 @@
+from bytelatent.data.iterators.dev_iterators import BltTestIterator
+from bytelatent.data.iterators.limit_iterator import LimitIterator
+
+
+def test_limit_iterator():
+    total = 10
+    limit = 5
+    base_iterator = BltTestIterator(total=total)
+    limit_iterator = LimitIterator(base_iterator, limit=limit)
+    iterator = limit_iterator.create_iter()
+    n = 0
+    for example in iterator:
+        assert example.sample_id == f"test_{n}"
+        n += 1
+    assert n == limit
+
+    limit = 10
+    base_iterator = BltTestIterator(total=total)
+    limit_iterator = LimitIterator(base_iterator, limit=limit)
+    iterator = limit_iterator.create_iter()
+    n = 0
+    for example in iterator:
+        assert example.sample_id == f"test_{n}"
+        n += 1
+    assert n == limit == total
+
+    limit = 20
+    base_iterator = BltTestIterator(total=total)
+    limit_iterator = LimitIterator(base_iterator, limit=limit)
+    iterator = limit_iterator.create_iter()
+    n = 0
+    for example in iterator:
+        assert example.sample_id == f"test_{n}"
+        n += 1
+    assert n == total
+
+    limit = -1
+    base_iterator = BltTestIterator(total=total)
+    limit_iterator = LimitIterator(base_iterator, limit=limit)
+    iterator = limit_iterator.create_iter()
+    n = 0
+    for example in iterator:
+        assert example.sample_id == f"test_{n}"
+        n += 1
+    assert n == total

fixtures/test_docs.jsonl
@@ -0,0 +1,3 @@
+{"sample_id": "0", "text": "test_0"}
+{"sample_id": "1", "text": "test_1"}
+{"sample_id": "2", "text": "test_2"}