From 0ffe2ab685bb966c9ce4e9d1779d127ab9c15e41 Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez
Date: Thu, 20 Feb 2025 00:56:52 +0000
Subject: [PATCH] Update iterator inheritance, pass file format args, add limit iterator

- Create a common base class, PydanticIteratorState, for all iterator states
- Add a LimitIterator that we can use in evals (usage sketches at the end of this patch)
- Make ArrowFileIterator skip arrow path inference when file_format='json'
- Make EvalArgs valid to construct with defaults: dump_dir and ckpt_dir are now optional
- Move the testing iterators to a common module so multiple test files can use them
- Allow SequenceIterator to take a None rng_state, which disables all rng ops (mainly for eval)

Test Plan:
- `pytest bytelatent`
- `python -m bytelatent.train config=../internal-blt/configs/entropy_model.yaml logging.wandb=null eval=null`
---
 .gitignore                                    |  2 +
 bytelatent/args.py                            | 10 ++-
 .../data/iterators/abstract_iterator.py       |  6 ++
 bytelatent/data/iterators/arrow_iterator.py   | 81 +++++++++++--------
 bytelatent/data/iterators/dev_iterators.py    | 78 ++++++++++++++++++
 bytelatent/data/iterators/limit_iterator.py   | 47 +++++++++++
 bytelatent/data/iterators/looping_iterator.py |  8 +-
 .../data/iterators/multiprocess_iterator.py   | 10 ++-
 bytelatent/data/iterators/packing_iterator.py |  7 +-
 .../data/iterators/preprocess_iterator.py     | 19 +++--
 .../data/iterators/sampling_iterator.py       |  9 ++-
 .../data/iterators/sequence_iterator.py       | 28 +++++--
 .../data/iterators/test_arrow_iterator.py     | 21 ++++-
 bytelatent/data/iterators/test_iters.py       | 76 +----------------
 .../data/iterators/test_limit_iterator.py     | 45 +++++++++++
 fixtures/test_docs.jsonl                      |  3 +
 16 files changed, 317 insertions(+), 133 deletions(-)
 create mode 100644 bytelatent/data/iterators/dev_iterators.py
 create mode 100644 bytelatent/data/iterators/limit_iterator.py
 create mode 100644 bytelatent/data/iterators/test_limit_iterator.py
 create mode 100644 fixtures/test_docs.jsonl

diff --git a/.gitignore b/.gitignore
index 2d0f075..cef4d53 100644
--- a/.gitignore
+++ b/.gitignore
@@ -168,3 +168,5 @@ figures/
 internal/
 jobs_parallel-copy/
 wandb/
+*.ipynb
+
diff --git a/bytelatent/args.py b/bytelatent/args.py
index dd1fef5..8ffa717 100644
--- a/bytelatent/args.py
+++ b/bytelatent/args.py
@@ -72,6 +72,7 @@ def distribute_data_to_rank(
     arrow_batch_size: int,
     rank: int,
     world_size: int,
+    file_format: str,
     s3_profile: str | None = None,
     file_pattern: str = TRAIN_DATA_FILE_PATTERN,
 ) -> ArrowFileIterator:
@@ -85,6 +86,7 @@ def distribute_data_to_rank(
             rank_to_arrow_iterator_params.append(
                 ArrowFileIterator(
                     file_path=chunk_path,
+                    file_format=file_format,
                     worker_id=worker_id,
                     num_workers=n_workers_per_chunk,
                     preprocess_dir=preprocess_dir,
@@ -130,6 +132,7 @@ class DataloaderArgs(BaseModel):
     entropy_model_name: str | None = "transformer_100m"
     arrow_batch_size: int = 100
     buffer_size: int = 64
+    file_format: str = "arrow"
 
     pad_to_max_length: bool = True
     max_encoder_seq_length: int = 12288
@@ -151,6 +154,7 @@ class DataloaderArgs(BaseModel):
         for dataset_path in self.sources:
             shuffle_rng_state = get_rng_state(self.seed + 1, rank, world_size)
             arrow_iterator = distribute_data_to_rank(
+                file_format=self.file_format,
                 dataset_path=os.path.join(self.root_dir, dataset_path),
                 preprocess_dir=self.preprocess_dir,
                 entropy_model_name=self.entropy_model_name,
@@ -238,7 +242,7 @@ class LMHarnessArgs(BaseModel):
 
 class ValidationArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
-    max_steps: int | None = (
+    max_n_docs: int | None = (
         None  # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
     )
     use_val_from_train_src: bool = True  # Use the validation set from training sources
@@ -248,8 +252,8 @@ class ValidationArgs(BaseModel):
 
 class EvalArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
-    dump_dir: str
-    ckpt_dir: str
+    dump_dir: str | None = None
+    ckpt_dir: str | None = None
     metric_log_dir: str | None = None
     generator: PackedCausalTransformerGeneratorArgs = (
         PackedCausalTransformerGeneratorArgs()
diff --git a/bytelatent/data/iterators/abstract_iterator.py b/bytelatent/data/iterators/abstract_iterator.py
index 8ac7f19..e80edd3 100644
--- a/bytelatent/data/iterators/abstract_iterator.py
+++ b/bytelatent/data/iterators/abstract_iterator.py
@@ -2,6 +2,8 @@
 import abc
 from typing import Any, Generator, Generic, TypeVar
 
+import pydantic
+
 T = TypeVar("T")
 C = TypeVar("C")
 
@@ -23,6 +25,10 @@ class IteratorState(Generic[C]):
         pass
 
 
+class PydanticIteratorState(pydantic.BaseModel, IteratorState):
+    model_config = pydantic.ConfigDict(extra="forbid")
+
+
 def get_state_and_refresh(iterator: StatefulIterator):
     # Re-init dataloader and iterator is necessary since get_state()
     # on mp iterator shuts down MP to correctly persist state and it needs
diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py
index 995cd02..34f58d3 100644
--- a/bytelatent/data/iterators/arrow_iterator.py
+++ b/bytelatent/data/iterators/arrow_iterator.py
@@ -15,13 +15,16 @@ from pydantic import BaseModel, ConfigDict
 from bytelatent import ByteLatentError
 from bytelatent.data.data_types import BltExample
 from bytelatent.data.file_util import get_fs
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text
 
 logger = getLogger(__name__)
 
 
-class ArrowFileIteratorState(BaseModel, IteratorState):
+class ArrowFileIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     file_path: str | None
     row_num: int
@@ -110,39 +113,51 @@ class ArrowFileIterator(StatefulIterator):
             logger.info("Arrow iterator using fs=%s", self.fs)
 
         if dataset_files is None:
-            # Prepare arrow shards
-            jsonl_file = file_path
-            parts = re.match(
-                r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file)
-            )
-            assert parts is not None
-            dataset = parts.group(1)
-            data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name)
-            data_dir_with_glob = os.path.join(
-                data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow"
-            )
-            if self.fs is None:
-                self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile)
-                if isinstance(self.fs, s3fs.S3FileSystem):
-                    self.filesystem_type = "s3"
-                else:
-                    self.filesystem_type = "file"
-            shard_files = self.fs.glob(data_dir_with_glob)
-
-            for s in shard_files:
-                complete_file = os.path.join(
-                    data_dir, f"{os.path.basename(s)}.complete"
+            assert (
+                file_path is not None
+            ), "Must specify file_path if dataset_files is None"
+            if file_format == "json":
+                if self.fs is None:
+                    self.fs = get_fs(file_path, s3_profile=s3_profile)
+                    if isinstance(self.fs, s3fs.S3FileSystem):
+                        self.filesystem_type = "s3"
+                    else:
+                        self.filesystem_type = "file"
+                self.dataset_files = [file_path]
+            else:
+                # Prepare arrow shards
+                jsonl_file = file_path
+                parts = re.match(
+                    r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file)
                 )
-
-                if not self.fs.exists(complete_file):
-                    raise ValueError(f"Missing .complete for input file: {s}")
-
-            shard_files = sorted(shard_files, key=shard_sort_key)
-            if len(shard_files) == 0:
-                raise ByteLatentError(
-                    f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
+                assert parts is not None
+                dataset = parts.group(1)
+                data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name)
+                data_dir_with_glob = os.path.join(
+                    data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow"
                 )
-            self.dataset_files = [f for f in shard_files]
+                if self.fs is None:
+                    self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile)
+                    if isinstance(self.fs, s3fs.S3FileSystem):
+                        self.filesystem_type = "s3"
+                    else:
+                        self.filesystem_type = "file"
+                shard_files = self.fs.glob(data_dir_with_glob)
+
+                for s in shard_files:
+                    complete_file = os.path.join(
+                        data_dir, f"{os.path.basename(s)}.complete"
+                    )
+
+                    if not self.fs.exists(complete_file):
+                        raise ValueError(f"Missing .complete for input file: {s}")
+
+                shard_files = sorted(shard_files, key=shard_sort_key)
+                if len(shard_files) == 0:
+                    raise ByteLatentError(
+                        f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
+                    )
+                self.dataset_files = [f for f in shard_files]
         else:
             self.preprocess_dir = None
             self.dataset_files = dataset_files
diff --git a/bytelatent/data/iterators/dev_iterators.py b/bytelatent/data/iterators/dev_iterators.py
new file mode 100644
index 0000000..1b33e3d
--- /dev/null
+++ b/bytelatent/data/iterators/dev_iterators.py
@@ -0,0 +1,78 @@
+import pandas as pd
+from pydantic import ConfigDict
+
+from bytelatent.data.data_types import BltExample
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
+
+
+class BltTestIteratorState(PydanticIteratorState):
+    model_config = ConfigDict(extra="forbid")
+    position: int
+    total: int
+
+    def build(self):
+        blt_iter = BltTestIterator(total=self.total)
+        blt_iter.position = self.position
+        return blt_iter
+
+
+class BltTestIterator(StatefulIterator):
+    def __init__(self, total: int):
+        self.position = 0
+        self.total = total
+
+    def get_state(self):
+        return BltTestIteratorState(position=self.position, total=self.total)
+
+    def create_iter(self):
+        for i in range(self.total):
+            self.position += 1
+            yield BltExample(
+                sample_id=f"test_{i}",
+                text=f"This is some test {i} text.",
+                tokens=None,
+                mask=None,
+                entropies=None,
+                patch_lengths=None,
+            )
+
+
+class BltTestWithEntropiesIteratorState(PydanticIteratorState):
+    model_config = ConfigDict(extra="forbid")
+    position: int
+    total: int
+
+    def build(self):
+        blt_iter = BltTestWithEntropiesIterator(total=self.total)
+        blt_iter.position = self.position
+        return blt_iter
+
+
+class BltTestWithEntropiesIterator(StatefulIterator):
+    def __init__(self, total: int):
+        self.position = 0
+        self.total = total
+
+    def get_state(self):
+        return BltTestWithEntropiesIteratorState(position=self.position, total=self.total)
+
+    def create_iter(self):
+        text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
+ df = pd.read_json("fixtures/tokens_with_entropies.json") + tokens = df["token_ids"].tolist() + entropies = df["entropies"].tolist() + # BOS and EOS + assert len(tokens) == len(text) + 2 + for i in range(self.total): + self.position += 1 + yield BltExample( + sample_id=f"test_{i}", + text=text, + tokens=tokens, + mask=[True] * len(tokens), + entropies=entropies, + patch_lengths=None, + ) diff --git a/bytelatent/data/iterators/limit_iterator.py b/bytelatent/data/iterators/limit_iterator.py new file mode 100644 index 0000000..4ca43a9 --- /dev/null +++ b/bytelatent/data/iterators/limit_iterator.py @@ -0,0 +1,47 @@ +from pydantic import ConfigDict + +from bytelatent.data.iterators.abstract_iterator import ( + PydanticIteratorState, + StatefulIterator, +) +from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState +from bytelatent.data.iterators.dev_iterators import BltTestIteratorState + + +class LimitIteratorState(PydanticIteratorState): + model_config = ConfigDict(extra="forbid") + base_iterator_state: ( + BltTestIteratorState | ArrowFileIteratorState | PydanticIteratorState + ) + n_yielded: int + limit: int + + def build(self) -> "LimitIterator": + return LimitIterator( + base_iterator=self.base_iterator_state.build(), + n_yielded=self.n_yielded, + limit=self.limit, + ) + + +class LimitIterator(StatefulIterator): + def __init__(self, base_iterator: StatefulIterator, limit: int, n_yielded: int = 0): + self.base_iterator = base_iterator + self.n_yielded = n_yielded + self.limit = limit + + def get_state(self): + return LimitIteratorState( + base_iterator_state=self.base_iterator.get_state(), + n_yielded=self.n_yielded, + limit=self.limit, + ) + + def create_iter(self): + iterator = self.base_iterator.create_iter() + try: + while self.n_yielded < self.limit or self.limit < 0: + yield next(iterator) + self.n_yielded += 1 + except StopIteration: + pass diff --git a/bytelatent/data/iterators/looping_iterator.py b/bytelatent/data/iterators/looping_iterator.py index 2eff38c..7406f61 100644 --- a/bytelatent/data/iterators/looping_iterator.py +++ b/bytelatent/data/iterators/looping_iterator.py @@ -1,14 +1,16 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
-from pydantic import BaseModel -from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.data.iterators.abstract_iterator import ( + PydanticIteratorState, + StatefulIterator, +) from bytelatent.data.iterators.arrow_iterator import ( ArrowFileIterator, ArrowFileIteratorState, ) -class LoopingIteratorState(BaseModel, IteratorState): +class LoopingIteratorState(PydanticIteratorState): file_iterator_state: ArrowFileIteratorState epoch: int diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py index 33bde94..b4df945 100644 --- a/bytelatent/data/iterators/multiprocess_iterator.py +++ b/bytelatent/data/iterators/multiprocess_iterator.py @@ -6,16 +6,20 @@ from multiprocessing.synchronize import Event as EventClass from queue import Empty, Full import numpy as np -from pydantic import BaseModel, ConfigDict +from pydantic import ConfigDict from bytelatent.data.data_types import Batch -from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.data.iterators.abstract_iterator import ( + IteratorState, + PydanticIteratorState, + StatefulIterator, +) from bytelatent.data.iterators.packing_iterator import PackingIteratorState logger = logging.getLogger() -class MultiprocessIteratorState(BaseModel, IteratorState): +class MultiprocessIteratorState(PydanticIteratorState): model_config = ConfigDict(extra="forbid") base_iterator_state: PackingIteratorState n_batches_to_prefetch: int diff --git a/bytelatent/data/iterators/packing_iterator.py b/bytelatent/data/iterators/packing_iterator.py index fa29149..5ed280d 100644 --- a/bytelatent/data/iterators/packing_iterator.py +++ b/bytelatent/data/iterators/packing_iterator.py @@ -5,7 +5,10 @@ import numpy as np from pydantic import BaseModel, ConfigDict from bytelatent.data.data_types import Batch, BltSequence -from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.data.iterators.abstract_iterator import ( + PydanticIteratorState, + StatefulIterator, +) from bytelatent.data.iterators.sampling_iterator import SamplingIteratorState @@ -20,7 +23,7 @@ class PackingArgs(BaseModel): tokenizer_name: str -class PackingIteratorState(BaseModel, IteratorState): +class PackingIteratorState(PydanticIteratorState): model_config = ConfigDict(extra="forbid") sequence_iterator_state: SamplingIteratorState packing_args: PackingArgs diff --git a/bytelatent/data/iterators/preprocess_iterator.py b/bytelatent/data/iterators/preprocess_iterator.py index 8eeba41..f72364d 100644 --- a/bytelatent/data/iterators/preprocess_iterator.py +++ b/bytelatent/data/iterators/preprocess_iterator.py @@ -5,20 +5,29 @@ import torch from pydantic import BaseModel, ConfigDict from bytelatent.data.data_types import BltExample -from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.data.iterators.abstract_iterator import ( + PydanticIteratorState, + StatefulIterator, +) from bytelatent.data.iterators.arrow_iterator import ( ArrowFileIterator, ArrowFileIteratorState, ) -from bytelatent.data.iterators.looping_iterator import LoopingIteratorState +from bytelatent.data.iterators.limit_iterator import LimitIterator, LimitIteratorState +from bytelatent.data.iterators.looping_iterator import ( + LoopingIterator, + LoopingIteratorState, +) from bytelatent.data.patcher import Patcher, PatcherArgs, PatchingModeEnum from bytelatent.tokenizers.blt_tokenizer import 
BltTokenizer from bytelatent.tokenizers.build_tokenizer import TokenizerArgs -class PreprocessIteratorState(BaseModel, IteratorState): +class PreprocessIteratorState(PydanticIteratorState): model_config = ConfigDict(extra="forbid") - arrow_file_iterator_state: ArrowFileIteratorState | LoopingIteratorState + arrow_file_iterator_state: ( + ArrowFileIteratorState | LoopingIteratorState | LimitIteratorState + ) add_tokens: bool add_patches: bool tokenizer_args: TokenizerArgs @@ -43,7 +52,7 @@ class PreprocessIterator(StatefulIterator): def __init__( self, - arrow_iterator: ArrowFileIterator, + arrow_iterator: ArrowFileIterator | LoopingIterator | LimitIterator, *, patcher_args: PatcherArgs, tokenizer_args: TokenizerArgs, diff --git a/bytelatent/data/iterators/sampling_iterator.py b/bytelatent/data/iterators/sampling_iterator.py index 6474bf6..170f215 100644 --- a/bytelatent/data/iterators/sampling_iterator.py +++ b/bytelatent/data/iterators/sampling_iterator.py @@ -2,13 +2,16 @@ from typing import Any import numpy as np -from pydantic import BaseModel, ConfigDict +from pydantic import ConfigDict -from bytelatent.data.iterators.abstract_iterator import StatefulIterator +from bytelatent.data.iterators.abstract_iterator import ( + PydanticIteratorState, + StatefulIterator, +) from bytelatent.data.iterators.sequence_iterator import SequenceIteratorState -class SamplingIteratorState(BaseModel): +class SamplingIteratorState(PydanticIteratorState): model_config = ConfigDict(extra="forbid") rng_state: dict[str, Any] source_to_weight: dict[str, float] diff --git a/bytelatent/data/iterators/sequence_iterator.py b/bytelatent/data/iterators/sequence_iterator.py index d90ea31..0a492be 100644 --- a/bytelatent/data/iterators/sequence_iterator.py +++ b/bytelatent/data/iterators/sequence_iterator.py @@ -6,7 +6,10 @@ import numpy as np from pydantic import BaseModel, ConfigDict from bytelatent.data.data_types import BltSequence -from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.data.iterators.abstract_iterator import ( + PydanticIteratorState, + StatefulIterator, +) from bytelatent.data.iterators.preprocess_iterator import ( PreprocessIterator, PreprocessIteratorState, @@ -21,11 +24,12 @@ class SequencePackingArgs(BaseModel): buffer_size: int -class SequenceIteratorState(BaseModel, IteratorState): +class SequenceIteratorState(PydanticIteratorState): model_config = ConfigDict(extra="forbid") sequence_packing_args: SequencePackingArgs preprocess_iterator_state: PreprocessIteratorState - rng_state: dict[str, Any] + # If None, rng is disabled. 
+ rng_state: dict[str, Any] | None def build(self): preprocess_iterator = self.preprocess_iterator_state.build() @@ -41,22 +45,25 @@ class SequenceIterator(StatefulIterator): self, preprocess_iterator: PreprocessIterator, *, - rng_state: dict[str, Any], + rng_state: dict[str, Any] | None, sequence_packing_args: SequencePackingArgs, ): self.preprocess_iterator = preprocess_iterator self.sequence_packing_args = sequence_packing_args self.output_seq_len = sequence_packing_args.output_seq_len self.buffer_size = sequence_packing_args.buffer_size - self.rng = np.random.default_rng() - self.rng.bit_generator.state = rng_state + if rng_state is None: + self.rng = None + else: + self.rng = np.random.default_rng() + self.rng.bit_generator.state = rng_state def get_state(self): # TODO: need to also perist the current shuffle buffer return SequenceIteratorState( sequence_packing_args=self.sequence_packing_args, preprocess_iterator_state=self.preprocess_iterator.get_state(), - rng_state=self.rng.bit_generator.state, + rng_state=None if self.rng is None else self.rng.bit_generator.state, ) def create_iter(self): @@ -114,7 +121,12 @@ class SequenceIterator(StatefulIterator): seq_patch_lengths: list[list[int]] = x_patches.tolist() assert len(seq_patch_lengths) == self.buffer_size - for idx in self.rng.permutation(len(seq_patch_lengths)): + if self.rng is None: + permutations = list(range(len(seq_patch_lengths))) + else: + permutations = self.rng.permutation(len(seq_patch_lengths)) + + for idx in permutations: assert len(seq_patch_lengths[idx]) == self.output_seq_len assert ( sum(seq_patch_lengths[idx]) diff --git a/bytelatent/data/iterators/test_arrow_iterator.py b/bytelatent/data/iterators/test_arrow_iterator.py index 064217e..caaa102 100644 --- a/bytelatent/data/iterators/test_arrow_iterator.py +++ b/bytelatent/data/iterators/test_arrow_iterator.py @@ -6,7 +6,10 @@ import pyarrow as pa import pyarrow.dataset # pyright: ignore from bytelatent.constants import BLT_DATA -from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState +from bytelatent.data.iterators.arrow_iterator import ( + ArrowFileIterator, + ArrowFileIteratorState, +) ENTROPY_MODEL = "transformer_100m" ARROW_TEST_DATA_1 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_00.arrow") @@ -93,3 +96,19 @@ def test_basic_arrow_file(): i += 1 if i >= len(expected_ids): break + + +def test_read_jsonl_from_arrow(): + arrow_iterator = ArrowFileIterator( + file_path="fixtures/test_docs.jsonl", + num_workers=1, + worker_id=0, + preprocess_dir=None, + entropy_model_name=None, + file_format="json", + arrow_batch_size=100, + ) + iterator = arrow_iterator.create_iter() + for i, example in enumerate(iterator): + assert example.sample_id == str(i) + assert example.text == f"test_{i}" diff --git a/bytelatent/data/iterators/test_iters.py b/bytelatent/data/iterators/test_iters.py index 9bc9d59..4749c8a 100644 --- a/bytelatent/data/iterators/test_iters.py +++ b/bytelatent/data/iterators/test_iters.py @@ -1,83 +1,15 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
-import pandas as pd -from pydantic import BaseModel from bytelatent.constants import BLT_DATA -from bytelatent.data.data_types import BltExample -from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.data.iterators.dev_iterators import ( + BltTestIterator, + BltTestWithEntropiesIterator, +) from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum from bytelatent.tokenizers.build_tokenizer import TokenizerArgs -class BltTestIteratorState(BaseModel, IteratorState): - position: int - total: int - - def build(self): - blt_iter = BltTestIteratorState(total=self.total) - blt_iter.position = self.position - return blt_iter - - -class BltTestIterator(StatefulIterator): - def __init__(self, total: int): - self.position = 0 - self.total = total - - def get_state(self): - return BltTestIteratorState(position=self.position, total=self.total) - - def create_iter(self): - for i in range(self.total): - self.position += 1 - yield BltExample( - sample_id=f"test_{i}", - text=f"This is some test {i} text.", - tokens=None, - mask=None, - entropies=None, - patch_lengths=None, - ) - - -class BltTestWithEntropiesIteratorState(BaseModel, IteratorState): - position: int - total: int - - def build(self): - blt_iter = BltTestWithEntropiesIteratorState(total=self.total) - blt_iter.position = self.position - return blt_iter - - -class BltTestWithEntropiesIterator(StatefulIterator): - def __init__(self, total: int): - self.position = 0 - self.total = total - - def get_state(self): - return BltTestIteratorState(position=self.position, total=self.total) - - def create_iter(self): - text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin." 
- df = pd.read_json("fixtures/tokens_with_entropies.json") - tokens = df["token_ids"].tolist() - entropies = df["entropies"].tolist() - # BOS and EOS - assert len(tokens) == len(text) + 2 - for i in range(self.total): - self.position += 1 - yield BltExample( - sample_id=f"test_{i}", - text=text, - tokens=tokens, - mask=[True] * len(tokens), - entropies=entropies, - patch_lengths=None, - ) - - def test_preprocess_iter(): total = 3 tokenizer_args = TokenizerArgs( diff --git a/bytelatent/data/iterators/test_limit_iterator.py b/bytelatent/data/iterators/test_limit_iterator.py new file mode 100644 index 0000000..47d5c27 --- /dev/null +++ b/bytelatent/data/iterators/test_limit_iterator.py @@ -0,0 +1,45 @@ +from bytelatent.data.iterators.dev_iterators import BltTestIterator +from bytelatent.data.iterators.limit_iterator import LimitIterator + + +def test_limit_iterator(): + total = 10 + limit = 5 + base_iterator = BltTestIterator(total=total) + limit_iterator = LimitIterator(base_iterator, limit=limit) + iterator = limit_iterator.create_iter() + n = 0 + for example in iterator: + assert example.sample_id == f"test_{n}" + n += 1 + assert n == limit + + limit = 10 + base_iterator = BltTestIterator(total=total) + limit_iterator = LimitIterator(base_iterator, limit=limit) + iterator = limit_iterator.create_iter() + n = 0 + for example in iterator: + assert example.sample_id == f"test_{n}" + n += 1 + assert n == limit == total + + limit = 20 + base_iterator = BltTestIterator(total=total) + limit_iterator = LimitIterator(base_iterator, limit=limit) + iterator = limit_iterator.create_iter() + n = 0 + for example in iterator: + assert example.sample_id == f"test_{n}" + n += 1 + assert n == total + + limit = -1 + base_iterator = BltTestIterator(total=total) + limit_iterator = LimitIterator(base_iterator, limit=limit) + iterator = limit_iterator.create_iter() + n = 0 + for example in iterator: + assert example.sample_id == f"test_{n}" + n += 1 + assert n == total diff --git a/fixtures/test_docs.jsonl b/fixtures/test_docs.jsonl new file mode 100644 index 0000000..7a10f5e --- /dev/null +++ b/fixtures/test_docs.jsonl @@ -0,0 +1,3 @@ +{"sample_id": "0", "text": "test_0"} +{"sample_id": "1", "text": "test_1"} +{"sample_id": "2", "text": "test_2"}