Mirror of https://github.com/facebookresearch/blt.git, synced 2025-02-22 21:12:15 +00:00
Update iterator inheritance, pass file format args, limit iterator
- Create a common base class for all iterator states
- Add a limit iterator that we can use in evals
- Modify ArrowFileIterator so it does not do arrow path inference if file_format='json'
- Make EvalArgs valid
- Move testing iterators to a common directory to allow usage in multiple test files
- Make it so that SequenceIterator can take a None rng_state, to disable all rng ops (mainly for eval)

Test Plan:
- `pytest bytelatent`
- `python -m bytelatent.train config=../internal-blt/configs/entropy_model.yaml logging.wandb=null eval=null`
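For context, a minimal usage sketch of the new limit iterator, mirroring the added bytelatent/data/iterators/test_limit_iterator.py; the total/limit values here are arbitrary:

```python
# Sketch: cap a stream with the new LimitIterator (values are arbitrary).
from bytelatent.data.iterators.dev_iterators import BltTestIterator
from bytelatent.data.iterators.limit_iterator import LimitIterator

base_iterator = BltTestIterator(total=100)
limit_iterator = LimitIterator(base_iterator, limit=10)  # a negative limit disables the cap

n = 0
for example in limit_iterator.create_iter():
    assert example.sample_id == f"test_{n}"
    n += 1
assert n == 10
```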
parent b0956bde99
commit 0ffe2ab685
.gitignore (vendored, 2 lines changed)

@@ -168,3 +168,5 @@ figures/
internal/
jobs_parallel-copy/
wandb/
*.ipynb
@@ -72,6 +72,7 @@ def distribute_data_to_rank(
    arrow_batch_size: int,
    rank: int,
    world_size: int,
    file_format: str,
    s3_profile: str | None = None,
    file_pattern: str = TRAIN_DATA_FILE_PATTERN,
) -> ArrowFileIterator:

@@ -85,6 +86,7 @@ def distribute_data_to_rank(
        rank_to_arrow_iterator_params.append(
            ArrowFileIterator(
                file_path=chunk_path,
                file_format=file_format,
                worker_id=worker_id,
                num_workers=n_workers_per_chunk,
                preprocess_dir=preprocess_dir,

@@ -130,6 +132,7 @@ class DataloaderArgs(BaseModel):
    entropy_model_name: str | None = "transformer_100m"
    arrow_batch_size: int = 100
    buffer_size: int = 64
    file_format: str = "arrow"

    pad_to_max_length: bool = True
    max_encoder_seq_length: int = 12288

@@ -151,6 +154,7 @@ class DataloaderArgs(BaseModel):
        for dataset_path in self.sources:
            shuffle_rng_state = get_rng_state(self.seed + 1, rank, world_size)
            arrow_iterator = distribute_data_to_rank(
                file_format=self.file_format,
                dataset_path=os.path.join(self.root_dir, dataset_path),
                preprocess_dir=self.preprocess_dir,
                entropy_model_name=self.entropy_model_name,

@@ -238,7 +242,7 @@ class LMHarnessArgs(BaseModel):

class ValidationArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    max_steps: int | None = (
    max_n_docs: int | None = (
        None # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
    )
    use_val_from_train_src: bool = True # Use the validation set from training sources

@@ -248,8 +252,8 @@ class ValidationArgs(BaseModel):

class EvalArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    dump_dir: str
    ckpt_dir: str
    dump_dir: str | None = None
    ckpt_dir: str | None = None
    metric_log_dir: str | None = None
    generator: PackedCausalTransformerGeneratorArgs = (
        PackedCausalTransformerGeneratorArgs()
@@ -2,6 +2,8 @@
import abc
from typing import Any, Generator, Generic, TypeVar

import pydantic

T = TypeVar("T")
C = TypeVar("C")

@@ -23,6 +25,10 @@ class IteratorState(Generic[C]):
        pass


class PydanticIteratorState(pydantic.BaseModel, IteratorState):
    model_config = pydantic.ConfigDict(extra="forbid")


def get_state_and_refresh(iterator: StatefulIterator):
    # Re-init dataloader and iterator is necessary since get_state()
    # on mp iterator shuts down MP to correctly persist state and it needs
@@ -15,13 +15,16 @@ from pydantic import BaseModel, ConfigDict
from bytelatent import ByteLatentError
from bytelatent.data.data_types import BltExample
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.abstract_iterator import (
    PydanticIteratorState,
    StatefulIterator,
)
from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text

logger = getLogger(__name__)


class ArrowFileIteratorState(BaseModel, IteratorState):
class ArrowFileIteratorState(PydanticIteratorState):
    model_config = ConfigDict(extra="forbid")
    file_path: str | None
    row_num: int

@@ -110,39 +113,51 @@ class ArrowFileIterator(StatefulIterator):
        logger.info("Arrow iterator using fs=%s", self.fs)

        if dataset_files is None:
            # Prepare arrow shards
            jsonl_file = file_path
            parts = re.match(
                r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file)
            )
            assert parts is not None
            dataset = parts.group(1)
            data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name)
            data_dir_with_glob = os.path.join(
                data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow"
            )
            if self.fs is None:
                self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile)
                if isinstance(self.fs, s3fs.S3FileSystem):
                    self.filesystem_type = "s3"
                else:
                    self.filesystem_type = "file"
            shard_files = self.fs.glob(data_dir_with_glob)

            for s in shard_files:
                complete_file = os.path.join(
                    data_dir, f"{os.path.basename(s)}.complete"
            assert (
                file_path is not None
            ), "Must specify file_Path if dataset_files is None"
            if file_format == "json":
                if self.fs is None:
                    self.fs = get_fs(file_path, s3_profile=s3_profile)
                    if isinstance(self.fs, s3fs.S3FileSystem):
                        self.filesystem_type = "s3"
                    else:
                        self.filesystem_type = "file"
                self.dataset_files = [file_path]
            else:
                # Prepare arrow shards
                jsonl_file = file_path
                parts = re.match(
                    r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file)
                )

                if not self.fs.exists(complete_file):
                    raise ValueError(f"Missing .complete for input file: {s}")

            shard_files = sorted(shard_files, key=shard_sort_key)
            if len(shard_files) == 0:
                raise ByteLatentError(
                    f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
                assert parts is not None
                dataset = parts.group(1)
                data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name)
                data_dir_with_glob = os.path.join(
                    data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow"
                )
            self.dataset_files = [f for f in shard_files]
                if self.fs is None:
                    self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile)
                    if isinstance(self.fs, s3fs.S3FileSystem):
                        self.filesystem_type = "s3"
                    else:
                        self.filesystem_type = "file"
                shard_files = self.fs.glob(data_dir_with_glob)

                for s in shard_files:
                    complete_file = os.path.join(
                        data_dir, f"{os.path.basename(s)}.complete"
                    )

                    if not self.fs.exists(complete_file):
                        raise ValueError(f"Missing .complete for input file: {s}")

                shard_files = sorted(shard_files, key=shard_sort_key)
                if len(shard_files) == 0:
                    raise ByteLatentError(
                        f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
                    )
                self.dataset_files = [f for f in shard_files]
        else:
            self.preprocess_dir = None
            self.dataset_files = dataset_files
bytelatent/data/iterators/dev_iterators.py (new file, 78 lines)

@@ -0,0 +1,78 @@
import pandas as pd
from pydantic import ConfigDict

from bytelatent.data.data_types import BltExample
from bytelatent.data.iterators.abstract_iterator import (
    PydanticIteratorState,
    StatefulIterator,
)


class BltTestIteratorState(PydanticIteratorState):
    model_config = ConfigDict(extra="forbid")
    position: int
    total: int

    def build(self):
        blt_iter = BltTestIteratorState(total=self.total)
        blt_iter.position = self.position
        return blt_iter


class BltTestIterator(StatefulIterator):
    def __init__(self, total: int):
        self.position = 0
        self.total = total

    def get_state(self):
        return BltTestIteratorState(position=self.position, total=self.total)

    def create_iter(self):
        for i in range(self.total):
            self.position += 1
            yield BltExample(
                sample_id=f"test_{i}",
                text=f"This is some test {i} text.",
                tokens=None,
                mask=None,
                entropies=None,
                patch_lengths=None,
            )


class BltTestWithEntropiesIteratorState(PydanticIteratorState):
    model_config = ConfigDict(extra="forbid")
    position: int
    total: int

    def build(self):
        blt_iter = BltTestWithEntropiesIteratorState(total=self.total)
        blt_iter.position = self.position
        return blt_iter


class BltTestWithEntropiesIterator(StatefulIterator):
    def __init__(self, total: int):
        self.position = 0
        self.total = total

    def get_state(self):
        return BltTestIteratorState(position=self.position, total=self.total)

    def create_iter(self):
        text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
        df = pd.read_json("fixtures/tokens_with_entropies.json")
        tokens = df["token_ids"].tolist()
        entropies = df["entropies"].tolist()
        # BOS and EOS
        assert len(tokens) == len(text) + 2
        for i in range(self.total):
            self.position += 1
            yield BltExample(
                sample_id=f"test_{i}",
                text=text,
                tokens=tokens,
                mask=[True] * len(tokens),
                entropies=entropies,
                patch_lengths=None,
            )
bytelatent/data/iterators/limit_iterator.py (new file, 47 lines)

@@ -0,0 +1,47 @@
from pydantic import ConfigDict

from bytelatent.data.iterators.abstract_iterator import (
    PydanticIteratorState,
    StatefulIterator,
)
from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState
from bytelatent.data.iterators.dev_iterators import BltTestIteratorState


class LimitIteratorState(PydanticIteratorState):
    model_config = ConfigDict(extra="forbid")
    base_iterator_state: (
        BltTestIteratorState | ArrowFileIteratorState | PydanticIteratorState
    )
    n_yielded: int
    limit: int

    def build(self) -> "LimitIterator":
        return LimitIterator(
            base_iterator=self.base_iterator_state.build(),
            n_yielded=self.n_yielded,
            limit=self.limit,
        )


class LimitIterator(StatefulIterator):
    def __init__(self, base_iterator: StatefulIterator, limit: int, n_yielded: int = 0):
        self.base_iterator = base_iterator
        self.n_yielded = n_yielded
        self.limit = limit

    def get_state(self):
        return LimitIteratorState(
            base_iterator_state=self.base_iterator.get_state(),
            n_yielded=self.n_yielded,
            limit=self.limit,
        )

    def create_iter(self):
        iterator = self.base_iterator.create_iter()
        try:
            while self.n_yielded < self.limit or self.limit < 0:
                yield next(iterator)
                self.n_yielded += 1
        except StopIteration:
            pass
@@ -1,14 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
from pydantic import BaseModel

from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.abstract_iterator import (
    PydanticIteratorState,
    StatefulIterator,
)
from bytelatent.data.iterators.arrow_iterator import (
    ArrowFileIterator,
    ArrowFileIteratorState,
)


class LoopingIteratorState(BaseModel, IteratorState):
class LoopingIteratorState(PydanticIteratorState):
    file_iterator_state: ArrowFileIteratorState
    epoch: int
@@ -6,16 +6,20 @@ from multiprocessing.synchronize import Event as EventClass
from queue import Empty, Full

import numpy as np
from pydantic import BaseModel, ConfigDict
from pydantic import ConfigDict

from bytelatent.data.data_types import Batch
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.abstract_iterator import (
    IteratorState,
    PydanticIteratorState,
    StatefulIterator,
)
from bytelatent.data.iterators.packing_iterator import PackingIteratorState

logger = logging.getLogger()


class MultiprocessIteratorState(BaseModel, IteratorState):
class MultiprocessIteratorState(PydanticIteratorState):
    model_config = ConfigDict(extra="forbid")
    base_iterator_state: PackingIteratorState
    n_batches_to_prefetch: int
@@ -5,7 +5,10 @@ import numpy as np
from pydantic import BaseModel, ConfigDict

from bytelatent.data.data_types import Batch, BltSequence
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.abstract_iterator import (
    PydanticIteratorState,
    StatefulIterator,
)
from bytelatent.data.iterators.sampling_iterator import SamplingIteratorState

@@ -20,7 +23,7 @@ class PackingArgs(BaseModel):
    tokenizer_name: str


class PackingIteratorState(BaseModel, IteratorState):
class PackingIteratorState(PydanticIteratorState):
    model_config = ConfigDict(extra="forbid")
    sequence_iterator_state: SamplingIteratorState
    packing_args: PackingArgs
@@ -5,20 +5,29 @@ import torch
from pydantic import BaseModel, ConfigDict

from bytelatent.data.data_types import BltExample
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.abstract_iterator import (
    PydanticIteratorState,
    StatefulIterator,
)
from bytelatent.data.iterators.arrow_iterator import (
    ArrowFileIterator,
    ArrowFileIteratorState,
)
from bytelatent.data.iterators.looping_iterator import LoopingIteratorState
from bytelatent.data.iterators.limit_iterator import LimitIterator, LimitIteratorState
from bytelatent.data.iterators.looping_iterator import (
    LoopingIterator,
    LoopingIteratorState,
)
from bytelatent.data.patcher import Patcher, PatcherArgs, PatchingModeEnum
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs


class PreprocessIteratorState(BaseModel, IteratorState):
class PreprocessIteratorState(PydanticIteratorState):
    model_config = ConfigDict(extra="forbid")
    arrow_file_iterator_state: ArrowFileIteratorState | LoopingIteratorState
    arrow_file_iterator_state: (
        ArrowFileIteratorState | LoopingIteratorState | LimitIteratorState
    )
    add_tokens: bool
    add_patches: bool
    tokenizer_args: TokenizerArgs

@@ -43,7 +52,7 @@ class PreprocessIterator(StatefulIterator):

    def __init__(
        self,
        arrow_iterator: ArrowFileIterator,
        arrow_iterator: ArrowFileIterator | LoopingIterator | LimitIterator,
        *,
        patcher_args: PatcherArgs,
        tokenizer_args: TokenizerArgs,
@@ -2,13 +2,16 @@
from typing import Any

import numpy as np
from pydantic import BaseModel, ConfigDict
from pydantic import ConfigDict

from bytelatent.data.iterators.abstract_iterator import StatefulIterator
from bytelatent.data.iterators.abstract_iterator import (
    PydanticIteratorState,
    StatefulIterator,
)
from bytelatent.data.iterators.sequence_iterator import SequenceIteratorState


class SamplingIteratorState(BaseModel):
class SamplingIteratorState(PydanticIteratorState):
    model_config = ConfigDict(extra="forbid")
    rng_state: dict[str, Any]
    source_to_weight: dict[str, float]
@@ -6,7 +6,10 @@ import numpy as np
from pydantic import BaseModel, ConfigDict

from bytelatent.data.data_types import BltSequence
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.abstract_iterator import (
    PydanticIteratorState,
    StatefulIterator,
)
from bytelatent.data.iterators.preprocess_iterator import (
    PreprocessIterator,
    PreprocessIteratorState,

@@ -21,11 +24,12 @@ class SequencePackingArgs(BaseModel):
    buffer_size: int


class SequenceIteratorState(BaseModel, IteratorState):
class SequenceIteratorState(PydanticIteratorState):
    model_config = ConfigDict(extra="forbid")
    sequence_packing_args: SequencePackingArgs
    preprocess_iterator_state: PreprocessIteratorState
    rng_state: dict[str, Any]
    # If None, rng is disabled.
    rng_state: dict[str, Any] | None

    def build(self):
        preprocess_iterator = self.preprocess_iterator_state.build()

@@ -41,22 +45,25 @@ class SequenceIterator(StatefulIterator):
        self,
        preprocess_iterator: PreprocessIterator,
        *,
        rng_state: dict[str, Any],
        rng_state: dict[str, Any] | None,
        sequence_packing_args: SequencePackingArgs,
    ):
        self.preprocess_iterator = preprocess_iterator
        self.sequence_packing_args = sequence_packing_args
        self.output_seq_len = sequence_packing_args.output_seq_len
        self.buffer_size = sequence_packing_args.buffer_size
        self.rng = np.random.default_rng()
        self.rng.bit_generator.state = rng_state
        if rng_state is None:
            self.rng = None
        else:
            self.rng = np.random.default_rng()
            self.rng.bit_generator.state = rng_state

    def get_state(self):
        # TODO: need to also perist the current shuffle buffer
        return SequenceIteratorState(
            sequence_packing_args=self.sequence_packing_args,
            preprocess_iterator_state=self.preprocess_iterator.get_state(),
            rng_state=self.rng.bit_generator.state,
            rng_state=None if self.rng is None else self.rng.bit_generator.state,
        )

    def create_iter(self):

@@ -114,7 +121,12 @@ class SequenceIterator(StatefulIterator):

            seq_patch_lengths: list[list[int]] = x_patches.tolist()
            assert len(seq_patch_lengths) == self.buffer_size
            for idx in self.rng.permutation(len(seq_patch_lengths)):
            if self.rng is None:
                permutations = list(range(len(seq_patch_lengths)))
            else:
                permutations = self.rng.permutation(len(seq_patch_lengths))

            for idx in permutations:
                assert len(seq_patch_lengths[idx]) == self.output_seq_len
                assert (
                    sum(seq_patch_lengths[idx])
@@ -6,7 +6,10 @@ import pyarrow as pa
import pyarrow.dataset  # pyright: ignore

from bytelatent.constants import BLT_DATA
from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState
from bytelatent.data.iterators.arrow_iterator import (
    ArrowFileIterator,
    ArrowFileIteratorState,
)

ENTROPY_MODEL = "transformer_100m"
ARROW_TEST_DATA_1 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_00.arrow")

@@ -93,3 +96,19 @@ def test_basic_arrow_file():
        i += 1
        if i >= len(expected_ids):
            break


def test_read_jsonl_from_arrow():
    arrow_iterator = ArrowFileIterator(
        file_path="fixtures/test_docs.jsonl",
        num_workers=1,
        worker_id=0,
        preprocess_dir=None,
        entropy_model_name=None,
        file_format="json",
        arrow_batch_size=100,
    )
    iterator = arrow_iterator.create_iter()
    for i, example in enumerate(iterator):
        assert example.sample_id == str(i)
        assert example.text == f"test_{i}"
@@ -1,83 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import pandas as pd
from pydantic import BaseModel

from bytelatent.constants import BLT_DATA
from bytelatent.data.data_types import BltExample
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.dev_iterators import (
    BltTestIterator,
    BltTestWithEntropiesIterator,
)
from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs


class BltTestIteratorState(BaseModel, IteratorState):
    position: int
    total: int

    def build(self):
        blt_iter = BltTestIteratorState(total=self.total)
        blt_iter.position = self.position
        return blt_iter


class BltTestIterator(StatefulIterator):
    def __init__(self, total: int):
        self.position = 0
        self.total = total

    def get_state(self):
        return BltTestIteratorState(position=self.position, total=self.total)

    def create_iter(self):
        for i in range(self.total):
            self.position += 1
            yield BltExample(
                sample_id=f"test_{i}",
                text=f"This is some test {i} text.",
                tokens=None,
                mask=None,
                entropies=None,
                patch_lengths=None,
            )


class BltTestWithEntropiesIteratorState(BaseModel, IteratorState):
    position: int
    total: int

    def build(self):
        blt_iter = BltTestWithEntropiesIteratorState(total=self.total)
        blt_iter.position = self.position
        return blt_iter


class BltTestWithEntropiesIterator(StatefulIterator):
    def __init__(self, total: int):
        self.position = 0
        self.total = total

    def get_state(self):
        return BltTestIteratorState(position=self.position, total=self.total)

    def create_iter(self):
        text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
        df = pd.read_json("fixtures/tokens_with_entropies.json")
        tokens = df["token_ids"].tolist()
        entropies = df["entropies"].tolist()
        # BOS and EOS
        assert len(tokens) == len(text) + 2
        for i in range(self.total):
            self.position += 1
            yield BltExample(
                sample_id=f"test_{i}",
                text=text,
                tokens=tokens,
                mask=[True] * len(tokens),
                entropies=entropies,
                patch_lengths=None,
            )


def test_preprocess_iter():
    total = 3
    tokenizer_args = TokenizerArgs(
bytelatent/data/iterators/test_limit_iterator.py (new file, 45 lines)

@@ -0,0 +1,45 @@
from bytelatent.data.iterators.dev_iterators import BltTestIterator
from bytelatent.data.iterators.limit_iterator import LimitIterator


def test_limit_iterator():
    total = 10
    limit = 5
    base_iterator = BltTestIterator(total=total)
    limit_iterator = LimitIterator(base_iterator, limit=limit)
    iterator = limit_iterator.create_iter()
    n = 0
    for example in iterator:
        assert example.sample_id == f"test_{n}"
        n += 1
    assert n == limit

    limit = 10
    base_iterator = BltTestIterator(total=total)
    limit_iterator = LimitIterator(base_iterator, limit=limit)
    iterator = limit_iterator.create_iter()
    n = 0
    for example in iterator:
        assert example.sample_id == f"test_{n}"
        n += 1
    assert n == limit == total

    limit = 20
    base_iterator = BltTestIterator(total=total)
    limit_iterator = LimitIterator(base_iterator, limit=limit)
    iterator = limit_iterator.create_iter()
    n = 0
    for example in iterator:
        assert example.sample_id == f"test_{n}"
        n += 1
    assert n == total

    limit = -1
    base_iterator = BltTestIterator(total=total)
    limit_iterator = LimitIterator(base_iterator, limit=limit)
    iterator = limit_iterator.create_iter()
    n = 0
    for example in iterator:
        assert example.sample_id == f"test_{n}"
        n += 1
    assert n == total
fixtures/test_docs.jsonl (new file, 3 lines)

@@ -0,0 +1,3 @@
{"sample_id": "0", "text": "test_0"}
{"sample_id": "1", "text": "test_1"}
{"sample_id": "2", "text": "test_2"}