mirror of https://github.com/facebookresearch/blt.git
synced 2025-02-22 21:12:15 +00:00

Update iterator inheritance, pass file format args, limit iterator

- Create a common base class for all iterator state inheritance
- Add a limit iterator that we can use in evals
- Modify ArrowFileIterator behavior to not do arrow path inference if file_format='json'
- Make EvalArgs valid
- Move testing iterators to a common directory to allow usage in multiple test files
- Make it so that SequenceIterator can take a None rng_state, to disable all rng ops (mainly for eval)

Test Plan:

- `pytest bytelatent`
- `python -m bytelatent.train config=../internal-blt/configs/entropy_model.yaml logging.wandb=null eval=null`

This commit is contained in:
parent b0956bde99
commit 0ffe2ab685
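
As a quick orientation, here is a hedged sketch (not part of the commit; pydantic v2 assumed) of the state contract this change standardizes: every iterator exposes `get_state()`, and every state is now a pydantic model whose `build()` reconstructs the iterator, so checkpointing a data pipeline reduces to serializing the state model.

```python
# Sketch only: BltTestIterator ships in the new dev_iterators.py added below.
from bytelatent.data.iterators.dev_iterators import BltTestIterator

iterator = BltTestIterator(total=10)
consumed = [ex.sample_id for ex in iterator.create_iter()]

state = iterator.get_state()       # a PydanticIteratorState subclass
payload = state.model_dump_json()  # safe to persist alongside a checkpoint
```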
2  .gitignore  (vendored)

@@ -168,3 +168,5 @@ figures/
 internal/
 jobs_parallel-copy/
 wandb/
+
+*.ipynb
@@ -72,6 +72,7 @@ def distribute_data_to_rank(
     arrow_batch_size: int,
     rank: int,
     world_size: int,
+    file_format: str,
     s3_profile: str | None = None,
     file_pattern: str = TRAIN_DATA_FILE_PATTERN,
 ) -> ArrowFileIterator:

@@ -85,6 +86,7 @@ def distribute_data_to_rank(
         rank_to_arrow_iterator_params.append(
             ArrowFileIterator(
                 file_path=chunk_path,
+                file_format=file_format,
                 worker_id=worker_id,
                 num_workers=n_workers_per_chunk,
                 preprocess_dir=preprocess_dir,

@@ -130,6 +132,7 @@ class DataloaderArgs(BaseModel):
     entropy_model_name: str | None = "transformer_100m"
     arrow_batch_size: int = 100
     buffer_size: int = 64
+    file_format: str = "arrow"

     pad_to_max_length: bool = True
     max_encoder_seq_length: int = 12288

@@ -151,6 +154,7 @@ class DataloaderArgs(BaseModel):
         for dataset_path in self.sources:
             shuffle_rng_state = get_rng_state(self.seed + 1, rank, world_size)
             arrow_iterator = distribute_data_to_rank(
+                file_format=self.file_format,
                 dataset_path=os.path.join(self.root_dir, dataset_path),
                 preprocess_dir=self.preprocess_dir,
                 entropy_model_name=self.entropy_model_name,

@@ -238,7 +242,7 @@ class LMHarnessArgs(BaseModel):


 class ValidationArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
-    max_steps: int | None = (
+    max_n_docs: int | None = (
         None  # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
     )
     use_val_from_train_src: bool = True  # Use the validation set from training sources

@@ -248,8 +252,8 @@ class ValidationArgs(BaseModel):


 class EvalArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
-    dump_dir: str
-    ckpt_dir: str
+    dump_dir: str | None = None
+    ckpt_dir: str | None = None
     metric_log_dir: str | None = None
     generator: PackedCausalTransformerGeneratorArgs = (
         PackedCausalTransformerGeneratorArgs()
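
The `EvalArgs` change above is what "Make EvalArgs valid" refers to: required `str` fields made a bare eval section fail validation. A minimal self-contained sketch of the difference (pydantic v2 assumed; the class name here is hypothetical):

```python
from pydantic import BaseModel, ConfigDict

class EvalArgsSketch(BaseModel):
    model_config = ConfigDict(extra="forbid")
    dump_dir: str | None = None   # was: dump_dir: str (required)
    ckpt_dir: str | None = None   # was: ckpt_dir: str (required)

EvalArgsSketch()  # now validates with no arguments; before, this raised
```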
@@ -2,6 +2,8 @@
 import abc
 from typing import Any, Generator, Generic, TypeVar

+import pydantic
+
 T = TypeVar("T")
 C = TypeVar("C")


@@ -23,6 +25,10 @@ class IteratorState(Generic[C]):
         pass


+class PydanticIteratorState(pydantic.BaseModel, IteratorState):
+    model_config = pydantic.ConfigDict(extra="forbid")
+
+
 def get_state_and_refresh(iterator: StatefulIterator):
     # Re-init dataloader and iterator is necessary since get_state()
     # on mp iterator shuts down MP to correctly persist state and it needs
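
A short sketch (assuming pydantic v2 semantics) of why the shared base pins `extra="forbid"`: a persisted state blob with a stale or misspelled key fails loudly on reload instead of being silently dropped.

```python
import pydantic

class StateSketch(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra="forbid")
    position: int

StateSketch.model_validate({"position": 3})  # ok
# StateSketch.model_validate({"position": 3, "posiiton": 4})
# -> ValidationError: the typo'd key is rejected, not ignored
```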
@@ -15,13 +15,16 @@ from pydantic import BaseModel, ConfigDict
 from bytelatent import ByteLatentError
 from bytelatent.data.data_types import BltExample
 from bytelatent.data.file_util import get_fs
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text

 logger = getLogger(__name__)


-class ArrowFileIteratorState(BaseModel, IteratorState):
+class ArrowFileIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     file_path: str | None
     row_num: int

@@ -110,39 +113,51 @@ class ArrowFileIterator(StatefulIterator):
         logger.info("Arrow iterator using fs=%s", self.fs)

         if dataset_files is None:
-            # Prepare arrow shards
-            jsonl_file = file_path
-            parts = re.match(
-                r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file)
-            )
-            assert parts is not None
-            dataset = parts.group(1)
-            data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name)
-            data_dir_with_glob = os.path.join(
-                data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow"
-            )
-            if self.fs is None:
-                self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile)
-                if isinstance(self.fs, s3fs.S3FileSystem):
-                    self.filesystem_type = "s3"
-                else:
-                    self.filesystem_type = "file"
-            shard_files = self.fs.glob(data_dir_with_glob)
-
-            for s in shard_files:
-                complete_file = os.path.join(
-                    data_dir, f"{os.path.basename(s)}.complete"
-                )
-
-                if not self.fs.exists(complete_file):
-                    raise ValueError(f"Missing .complete for input file: {s}")
-
-            shard_files = sorted(shard_files, key=shard_sort_key)
-            if len(shard_files) == 0:
-                raise ByteLatentError(
-                    f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
-                )
-            self.dataset_files = [f for f in shard_files]
+            assert (
+                file_path is not None
+            ), "Must specify file_Path if dataset_files is None"
+            if file_format == "json":
+                if self.fs is None:
+                    self.fs = get_fs(file_path, s3_profile=s3_profile)
+                    if isinstance(self.fs, s3fs.S3FileSystem):
+                        self.filesystem_type = "s3"
+                    else:
+                        self.filesystem_type = "file"
+                self.dataset_files = [file_path]
+            else:
+                # Prepare arrow shards
+                jsonl_file = file_path
+                parts = re.match(
+                    r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file)
+                )
+                assert parts is not None
+                dataset = parts.group(1)
+                data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name)
+                data_dir_with_glob = os.path.join(
+                    data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow"
+                )
+                if self.fs is None:
+                    self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile)
+                    if isinstance(self.fs, s3fs.S3FileSystem):
+                        self.filesystem_type = "s3"
+                    else:
+                        self.filesystem_type = "file"
+                shard_files = self.fs.glob(data_dir_with_glob)
+
+                for s in shard_files:
+                    complete_file = os.path.join(
+                        data_dir, f"{os.path.basename(s)}.complete"
+                    )
+
+                    if not self.fs.exists(complete_file):
+                        raise ValueError(f"Missing .complete for input file: {s}")
+
+                shard_files = sorted(shard_files, key=shard_sort_key)
+                if len(shard_files) == 0:
+                    raise ByteLatentError(
+                        f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
+                    )
+                self.dataset_files = [f for f in shard_files]
         else:
             self.preprocess_dir = None
             self.dataset_files = dataset_files
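
To make the new branch concrete: in "arrow" mode the iterator still infers preprocessed shard paths from the `.chunk.NN.jsonl` naming convention, while "json" mode reads the file directly. A small sketch of the inference that gets skipped (paths here are hypothetical):

```python
import os
import re

jsonl_file = "stackexchange.chunk.00.jsonl"  # hypothetical input name
parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file))
dataset = parts.group(1)  # -> "stackexchange"
shard_glob = os.path.join(
    "preprocess_dir", dataset, "entropy_model_name",
    f"{jsonl_file}.shard_*.arrow",
)
# file_format="json" bypasses all of this: dataset_files = [file_path]
```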
78  bytelatent/data/iterators/dev_iterators.py  (new file)

@@ -0,0 +1,78 @@
+import pandas as pd
+from pydantic import ConfigDict
+
+from bytelatent.data.data_types import BltExample
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
+
+
+class BltTestIteratorState(PydanticIteratorState):
+    model_config = ConfigDict(extra="forbid")
+    position: int
+    total: int
+
+    def build(self):
+        blt_iter = BltTestIteratorState(total=self.total)
+        blt_iter.position = self.position
+        return blt_iter
+
+
+class BltTestIterator(StatefulIterator):
+    def __init__(self, total: int):
+        self.position = 0
+        self.total = total
+
+    def get_state(self):
+        return BltTestIteratorState(position=self.position, total=self.total)
+
+    def create_iter(self):
+        for i in range(self.total):
+            self.position += 1
+            yield BltExample(
+                sample_id=f"test_{i}",
+                text=f"This is some test {i} text.",
+                tokens=None,
+                mask=None,
+                entropies=None,
+                patch_lengths=None,
+            )
+
+
+class BltTestWithEntropiesIteratorState(PydanticIteratorState):
+    model_config = ConfigDict(extra="forbid")
+    position: int
+    total: int
+
+    def build(self):
+        blt_iter = BltTestWithEntropiesIteratorState(total=self.total)
+        blt_iter.position = self.position
+        return blt_iter
+
+
+class BltTestWithEntropiesIterator(StatefulIterator):
+    def __init__(self, total: int):
+        self.position = 0
+        self.total = total
+
+    def get_state(self):
+        return BltTestIteratorState(position=self.position, total=self.total)
+
+    def create_iter(self):
+        text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
+        df = pd.read_json("fixtures/tokens_with_entropies.json")
+        tokens = df["token_ids"].tolist()
+        entropies = df["entropies"].tolist()
+        # BOS and EOS
+        assert len(tokens) == len(text) + 2
+        for i in range(self.total):
+            self.position += 1
+            yield BltExample(
+                sample_id=f"test_{i}",
+                text=text,
+                tokens=tokens,
+                mask=[True] * len(tokens),
+                entropies=entropies,
+                patch_lengths=None,
+            )
47  bytelatent/data/iterators/limit_iterator.py  (new file)

@@ -0,0 +1,47 @@
+from pydantic import ConfigDict
+
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState
+from bytelatent.data.iterators.dev_iterators import BltTestIteratorState
+
+
+class LimitIteratorState(PydanticIteratorState):
+    model_config = ConfigDict(extra="forbid")
+    base_iterator_state: (
+        BltTestIteratorState | ArrowFileIteratorState | PydanticIteratorState
+    )
+    n_yielded: int
+    limit: int
+
+    def build(self) -> "LimitIterator":
+        return LimitIterator(
+            base_iterator=self.base_iterator_state.build(),
+            n_yielded=self.n_yielded,
+            limit=self.limit,
+        )
+
+
+class LimitIterator(StatefulIterator):
+    def __init__(self, base_iterator: StatefulIterator, limit: int, n_yielded: int = 0):
+        self.base_iterator = base_iterator
+        self.n_yielded = n_yielded
+        self.limit = limit
+
+    def get_state(self):
+        return LimitIteratorState(
+            base_iterator_state=self.base_iterator.get_state(),
+            n_yielded=self.n_yielded,
+            limit=self.limit,
+        )
+
+    def create_iter(self):
+        iterator = self.base_iterator.create_iter()
+        try:
+            while self.n_yielded < self.limit or self.limit < 0:
+                yield next(iterator)
+                self.n_yielded += 1
+        except StopIteration:
+            pass
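
Usage sketch (not from the commit): the wrapper counts what it has yielded and carries that count in its state, and a negative limit means "no cap", per the `or self.limit < 0` condition above.

```python
from bytelatent.data.iterators.dev_iterators import BltTestIterator
from bytelatent.data.iterators.limit_iterator import LimitIterator

limited = LimitIterator(BltTestIterator(total=100), limit=8)
assert len(list(limited.create_iter())) == 8    # capped, e.g. for a quick eval
state = limited.get_state()                     # n_yielded and base state persist

unlimited = LimitIterator(BltTestIterator(total=5), limit=-1)
assert len(list(unlimited.create_iter())) == 5  # drains until StopIteration
```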
@@ -1,14 +1,16 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-from pydantic import BaseModel
-
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.data.iterators.arrow_iterator import (
     ArrowFileIterator,
     ArrowFileIteratorState,
 )


-class LoopingIteratorState(BaseModel, IteratorState):
+class LoopingIteratorState(PydanticIteratorState):
     file_iterator_state: ArrowFileIteratorState
     epoch: int

@@ -6,16 +6,20 @@ from multiprocessing.synchronize import Event as EventClass
 from queue import Empty, Full

 import numpy as np
-from pydantic import BaseModel, ConfigDict
+from pydantic import ConfigDict

 from bytelatent.data.data_types import Batch
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    IteratorState,
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.data.iterators.packing_iterator import PackingIteratorState

 logger = logging.getLogger()


-class MultiprocessIteratorState(BaseModel, IteratorState):
+class MultiprocessIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     base_iterator_state: PackingIteratorState
     n_batches_to_prefetch: int
@@ -5,7 +5,10 @@ import numpy as np
 from pydantic import BaseModel, ConfigDict

 from bytelatent.data.data_types import Batch, BltSequence
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.data.iterators.sampling_iterator import SamplingIteratorState


@@ -20,7 +23,7 @@ class PackingArgs(BaseModel):
     tokenizer_name: str


-class PackingIteratorState(BaseModel, IteratorState):
+class PackingIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     sequence_iterator_state: SamplingIteratorState
     packing_args: PackingArgs
@@ -5,20 +5,29 @@ import torch
 from pydantic import BaseModel, ConfigDict

 from bytelatent.data.data_types import BltExample
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.data.iterators.arrow_iterator import (
     ArrowFileIterator,
     ArrowFileIteratorState,
 )
-from bytelatent.data.iterators.looping_iterator import LoopingIteratorState
+from bytelatent.data.iterators.limit_iterator import LimitIterator, LimitIteratorState
+from bytelatent.data.iterators.looping_iterator import (
+    LoopingIterator,
+    LoopingIteratorState,
+)
 from bytelatent.data.patcher import Patcher, PatcherArgs, PatchingModeEnum
 from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
 from bytelatent.tokenizers.build_tokenizer import TokenizerArgs


-class PreprocessIteratorState(BaseModel, IteratorState):
+class PreprocessIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
-    arrow_file_iterator_state: ArrowFileIteratorState | LoopingIteratorState
+    arrow_file_iterator_state: (
+        ArrowFileIteratorState | LoopingIteratorState | LimitIteratorState
+    )
     add_tokens: bool
     add_patches: bool
     tokenizer_args: TokenizerArgs

@@ -43,7 +52,7 @@ class PreprocessIterator(StatefulIterator):

     def __init__(
         self,
-        arrow_iterator: ArrowFileIterator,
+        arrow_iterator: ArrowFileIterator | LoopingIterator | LimitIterator,
         *,
         patcher_args: PatcherArgs,
         tokenizer_args: TokenizerArgs,
@@ -2,13 +2,16 @@
 from typing import Any

 import numpy as np
-from pydantic import BaseModel, ConfigDict
+from pydantic import ConfigDict

-from bytelatent.data.iterators.abstract_iterator import StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.data.iterators.sequence_iterator import SequenceIteratorState


-class SamplingIteratorState(BaseModel):
+class SamplingIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     rng_state: dict[str, Any]
     source_to_weight: dict[str, float]
@@ -6,7 +6,10 @@ import numpy as np
 from pydantic import BaseModel, ConfigDict

 from bytelatent.data.data_types import BltSequence
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.data.iterators.preprocess_iterator import (
     PreprocessIterator,
     PreprocessIteratorState,

@@ -21,11 +24,12 @@ class SequencePackingArgs(BaseModel):
     buffer_size: int


-class SequenceIteratorState(BaseModel, IteratorState):
+class SequenceIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     sequence_packing_args: SequencePackingArgs
     preprocess_iterator_state: PreprocessIteratorState
-    rng_state: dict[str, Any]
+    # If None, rng is disabled.
+    rng_state: dict[str, Any] | None

     def build(self):
         preprocess_iterator = self.preprocess_iterator_state.build()

@@ -41,22 +45,25 @@ class SequenceIterator(StatefulIterator):
         self,
         preprocess_iterator: PreprocessIterator,
         *,
-        rng_state: dict[str, Any],
+        rng_state: dict[str, Any] | None,
         sequence_packing_args: SequencePackingArgs,
     ):
         self.preprocess_iterator = preprocess_iterator
         self.sequence_packing_args = sequence_packing_args
         self.output_seq_len = sequence_packing_args.output_seq_len
         self.buffer_size = sequence_packing_args.buffer_size
-        self.rng = np.random.default_rng()
-        self.rng.bit_generator.state = rng_state
+        if rng_state is None:
+            self.rng = None
+        else:
+            self.rng = np.random.default_rng()
+            self.rng.bit_generator.state = rng_state

     def get_state(self):
         # TODO: need to also perist the current shuffle buffer
         return SequenceIteratorState(
             sequence_packing_args=self.sequence_packing_args,
             preprocess_iterator_state=self.preprocess_iterator.get_state(),
-            rng_state=self.rng.bit_generator.state,
+            rng_state=None if self.rng is None else self.rng.bit_generator.state,
         )

     def create_iter(self):

@@ -114,7 +121,12 @@ class SequenceIterator(StatefulIterator):

             seq_patch_lengths: list[list[int]] = x_patches.tolist()
             assert len(seq_patch_lengths) == self.buffer_size
-            for idx in self.rng.permutation(len(seq_patch_lengths)):
+            if self.rng is None:
+                permutations = list(range(len(seq_patch_lengths)))
+            else:
+                permutations = self.rng.permutation(len(seq_patch_lengths))
+
+            for idx in permutations:
                 assert len(seq_patch_lengths[idx]) == self.output_seq_len
                 assert (
                     sum(seq_patch_lengths[idx])
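
The effect of a `None` rng on sequence order, isolated into a runnable sketch (the helper name is hypothetical): with rng disabled, the shuffle buffer is emitted in its natural order, which is the deterministic behavior eval wants.

```python
import numpy as np

def order_indices(rng: np.random.Generator | None, n: int) -> list[int]:
    # Mirrors the branch in SequenceIterator.create_iter above.
    if rng is None:
        return list(range(n))        # deterministic: eval path
    return list(rng.permutation(n))  # shuffled: training path

assert order_indices(None, 4) == [0, 1, 2, 3]
print(order_indices(np.random.default_rng(0), 4))
```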
@@ -6,7 +6,10 @@ import pyarrow as pa
 import pyarrow.dataset  # pyright: ignore

 from bytelatent.constants import BLT_DATA
-from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState
+from bytelatent.data.iterators.arrow_iterator import (
+    ArrowFileIterator,
+    ArrowFileIteratorState,
+)

 ENTROPY_MODEL = "transformer_100m"
 ARROW_TEST_DATA_1 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_00.arrow")

@@ -93,3 +96,19 @@ def test_basic_arrow_file():
         i += 1
         if i >= len(expected_ids):
             break
+
+
+def test_read_jsonl_from_arrow():
+    arrow_iterator = ArrowFileIterator(
+        file_path="fixtures/test_docs.jsonl",
+        num_workers=1,
+        worker_id=0,
+        preprocess_dir=None,
+        entropy_model_name=None,
+        file_format="json",
+        arrow_batch_size=100,
+    )
+    iterator = arrow_iterator.create_iter()
+    for i, example in enumerate(iterator):
+        assert example.sample_id == str(i)
+        assert example.text == f"test_{i}"
@@ -1,83 +1,15 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-import pandas as pd
-from pydantic import BaseModel
-
 from bytelatent.constants import BLT_DATA
-from bytelatent.data.data_types import BltExample
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.dev_iterators import (
+    BltTestIterator,
+    BltTestWithEntropiesIterator,
+)
 from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
 from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
 from bytelatent.tokenizers.build_tokenizer import TokenizerArgs


-class BltTestIteratorState(BaseModel, IteratorState):
-    position: int
-    total: int
-
-    def build(self):
-        blt_iter = BltTestIteratorState(total=self.total)
-        blt_iter.position = self.position
-        return blt_iter
-
-
-class BltTestIterator(StatefulIterator):
-    def __init__(self, total: int):
-        self.position = 0
-        self.total = total
-
-    def get_state(self):
-        return BltTestIteratorState(position=self.position, total=self.total)
-
-    def create_iter(self):
-        for i in range(self.total):
-            self.position += 1
-            yield BltExample(
-                sample_id=f"test_{i}",
-                text=f"This is some test {i} text.",
-                tokens=None,
-                mask=None,
-                entropies=None,
-                patch_lengths=None,
-            )
-
-
-class BltTestWithEntropiesIteratorState(BaseModel, IteratorState):
-    position: int
-    total: int
-
-    def build(self):
-        blt_iter = BltTestWithEntropiesIteratorState(total=self.total)
-        blt_iter.position = self.position
-        return blt_iter
-
-
-class BltTestWithEntropiesIterator(StatefulIterator):
-    def __init__(self, total: int):
-        self.position = 0
-        self.total = total
-
-    def get_state(self):
-        return BltTestIteratorState(position=self.position, total=self.total)
-
-    def create_iter(self):
-        text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
-        df = pd.read_json("fixtures/tokens_with_entropies.json")
-        tokens = df["token_ids"].tolist()
-        entropies = df["entropies"].tolist()
-        # BOS and EOS
-        assert len(tokens) == len(text) + 2
-        for i in range(self.total):
-            self.position += 1
-            yield BltExample(
-                sample_id=f"test_{i}",
-                text=text,
-                tokens=tokens,
-                mask=[True] * len(tokens),
-                entropies=entropies,
-                patch_lengths=None,
-            )
-
-
 def test_preprocess_iter():
     total = 3
     tokenizer_args = TokenizerArgs(
45  bytelatent/data/iterators/test_limit_iterator.py  (new file)

@@ -0,0 +1,45 @@
+from bytelatent.data.iterators.dev_iterators import BltTestIterator
+from bytelatent.data.iterators.limit_iterator import LimitIterator
+
+
+def test_limit_iterator():
+    total = 10
+    limit = 5
+    base_iterator = BltTestIterator(total=total)
+    limit_iterator = LimitIterator(base_iterator, limit=limit)
+    iterator = limit_iterator.create_iter()
+    n = 0
+    for example in iterator:
+        assert example.sample_id == f"test_{n}"
+        n += 1
+    assert n == limit
+
+    limit = 10
+    base_iterator = BltTestIterator(total=total)
+    limit_iterator = LimitIterator(base_iterator, limit=limit)
+    iterator = limit_iterator.create_iter()
+    n = 0
+    for example in iterator:
+        assert example.sample_id == f"test_{n}"
+        n += 1
+    assert n == limit == total
+
+    limit = 20
+    base_iterator = BltTestIterator(total=total)
+    limit_iterator = LimitIterator(base_iterator, limit=limit)
+    iterator = limit_iterator.create_iter()
+    n = 0
+    for example in iterator:
+        assert example.sample_id == f"test_{n}"
+        n += 1
+    assert n == total
+
+    limit = -1
+    base_iterator = BltTestIterator(total=total)
+    limit_iterator = LimitIterator(base_iterator, limit=limit)
+    iterator = limit_iterator.create_iter()
+    n = 0
+    for example in iterator:
+        assert example.sample_id == f"test_{n}"
+        n += 1
+    assert n == total
3  fixtures/test_docs.jsonl  (new file)

@@ -0,0 +1,3 @@
+{"sample_id": "0", "text": "test_0"}
+{"sample_id": "1", "text": "test_1"}
+{"sample_id": "2", "text": "test_2"}