Update iterators

Pedro Rodriguez 2025-02-20 00:35:04 +00:00
parent b0956bde99
commit 2a717d6b40
14 changed files with 285 additions and 132 deletions

.gitignore

@@ -168,3 +168,5 @@ figures/
internal/
jobs_parallel-copy/
wandb/
*.ipynb

bytelatent/args.py

@@ -72,6 +72,7 @@ def distribute_data_to_rank(
arrow_batch_size: int,
rank: int,
world_size: int,
file_format: str,
s3_profile: str | None = None,
file_pattern: str = TRAIN_DATA_FILE_PATTERN,
) -> ArrowFileIterator:
@@ -85,6 +86,7 @@ def distribute_data_to_rank(
rank_to_arrow_iterator_params.append(
ArrowFileIterator(
file_path=chunk_path,
file_format=file_format,
worker_id=worker_id,
num_workers=n_workers_per_chunk,
preprocess_dir=preprocess_dir,
@@ -130,6 +132,7 @@ class DataloaderArgs(BaseModel):
entropy_model_name: str | None = "transformer_100m"
arrow_batch_size: int = 100
buffer_size: int = 64
file_format: str = "arrow"
pad_to_max_length: bool = True
max_encoder_seq_length: int = 12288
@@ -151,6 +154,7 @@ class DataloaderArgs(BaseModel):
for dataset_path in self.sources:
shuffle_rng_state = get_rng_state(self.seed + 1, rank, world_size)
arrow_iterator = distribute_data_to_rank(
file_format=self.file_format,
dataset_path=os.path.join(self.root_dir, dataset_path),
preprocess_dir=self.preprocess_dir,
entropy_model_name=self.entropy_model_name,
@@ -238,7 +242,7 @@ class LMHarnessArgs(BaseModel):
class ValidationArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
max_steps: int | None = (
max_n_docs: int | None = (
None # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
)
use_val_from_train_src: bool = True # Use the validation set from training sources
@@ -248,8 +252,8 @@ class ValidationArgs(BaseModel):
class EvalArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
dump_dir: str
ckpt_dir: str
dump_dir: str | None = None
ckpt_dir: str | None = None
metric_log_dir: str | None = None
generator: PackedCausalTransformerGeneratorArgs = (
PackedCausalTransformerGeneratorArgs()
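
A quick illustration of the new knob: file_format defaults to "arrow", so existing configs are unaffected, while "json" makes the iterators read raw chunk files directly (see the arrow_iterator change below). A minimal sketch, assuming the fields shown in this diff; the paths and source weights are hypothetical placeholders.

# Minimal sketch, not from the repo: file_format is the field added in this
# commit; root_dir, sources, and preprocess_dir values are placeholders.
from bytelatent.args import DataloaderArgs

dataloader_args = DataloaderArgs(
    root_dir="/data",                     # hypothetical root
    sources={"demo_dataset": 1.0},        # hypothetical source -> weight
    preprocess_dir="/data/preprocessed",  # hypothetical
    file_format="json",                   # new; "arrow" remains the default
)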

bytelatent/data/iterators/abstract_iterator.py

@@ -2,6 +2,8 @@
import abc
from typing import Any, Generator, Generic, TypeVar
import pydantic
T = TypeVar("T")
C = TypeVar("C")
@@ -23,6 +25,10 @@ class IteratorState(Generic[C]):
pass
class PydanticIteratorState(pydantic.BaseModel, IteratorState):
model_config = pydantic.ConfigDict(extra="forbid")
def get_state_and_refresh(iterator: StatefulIterator):
# Re-init dataloader and iterator is necessary since get_state()
# on mp iterator shuts down MP to correctly persist state and it needs
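
PydanticIteratorState centralizes the strict pydantic config so the concrete state classes stop inheriting from BaseModel and IteratorState separately. The contract it pins down, as a minimal sketch with a toy counter (CounterIterator and CounterState are illustrative, not part of the repo):

# Minimal sketch of the StatefulIterator / IteratorState contract,
# assuming only the two base classes shown in this diff.
from bytelatent.data.iterators.abstract_iterator import (
    PydanticIteratorState,
    StatefulIterator,
)

class CounterState(PydanticIteratorState):
    position: int

    def build(self) -> "CounterIterator":
        # State -> iterator: reconstruct exactly where we stopped.
        return CounterIterator(position=self.position)

class CounterIterator(StatefulIterator):
    def __init__(self, position: int = 0):
        self.position = position

    def get_state(self) -> CounterState:
        # Iterator -> state: a serializable snapshot.
        return CounterState(position=self.position)

    def create_iter(self):
        while True:
            self.position += 1
            yield self.position

it = CounterIterator()
gen = it.create_iter()
next(gen), next(gen)              # 1, 2
resumed = it.get_state().build()  # fresh iterator at position=2
assert next(resumed.create_iter()) == 3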

bytelatent/data/iterators/arrow_iterator.py

@@ -15,13 +15,16 @@ from pydantic import BaseModel, ConfigDict
from bytelatent import ByteLatentError
from bytelatent.data.data_types import BltExample
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.abstract_iterator import (
PydanticIteratorState,
StatefulIterator,
)
from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text
logger = getLogger(__name__)
class ArrowFileIteratorState(BaseModel, IteratorState):
class ArrowFileIteratorState(PydanticIteratorState):
model_config = ConfigDict(extra="forbid")
file_path: str | None
row_num: int
@@ -110,39 +113,42 @@ class ArrowFileIterator(StatefulIterator):
logger.info("Arrow iterator using fs=%s", self.fs)
if dataset_files is None:
# Prepare arrow shards
jsonl_file = file_path
parts = re.match(
r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file)
)
assert parts is not None
dataset = parts.group(1)
data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name)
data_dir_with_glob = os.path.join(
data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow"
)
if self.fs is None:
self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile)
if isinstance(self.fs, s3fs.S3FileSystem):
self.filesystem_type = "s3"
else:
self.filesystem_type = "file"
shard_files = self.fs.glob(data_dir_with_glob)
for s in shard_files:
complete_file = os.path.join(
data_dir, f"{os.path.basename(s)}.complete"
if file_format == "json":
self.dataset_files = [file_path]
else:
# Prepare arrow shards
jsonl_file = file_path
parts = re.match(
r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file)
)
if not self.fs.exists(complete_file):
raise ValueError(f"Missing .complete for input file: {s}")
shard_files = sorted(shard_files, key=shard_sort_key)
if len(shard_files) == 0:
raise ByteLatentError(
f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
assert parts is not None
dataset = parts.group(1)
data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name)
data_dir_with_glob = os.path.join(
data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow"
)
self.dataset_files = [f for f in shard_files]
if self.fs is None:
self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile)
if isinstance(self.fs, s3fs.S3FileSystem):
self.filesystem_type = "s3"
else:
self.filesystem_type = "file"
shard_files = self.fs.glob(data_dir_with_glob)
for s in shard_files:
complete_file = os.path.join(
data_dir, f"{os.path.basename(s)}.complete"
)
if not self.fs.exists(complete_file):
raise ValueError(f"Missing .complete for input file: {s}")
shard_files = sorted(shard_files, key=shard_sort_key)
if len(shard_files) == 0:
raise ByteLatentError(
f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
)
self.dataset_files = [f for f in shard_files]
else:
self.preprocess_dir = None
self.dataset_files = dataset_files
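
The new branch means file_format="json" bypasses shard discovery entirely: file_path itself becomes the single dataset file, while "arrow" keeps the old behavior of globbing for .shard_*.arrow files and insisting on a .complete marker for each. A hedged usage sketch; the keyword arguments mirror ones visible in this diff, but the path is a placeholder and the full signature is an assumption:

# Sketch only: the jsonl path is hypothetical, and defaults for any
# constructor arguments not shown in this diff are assumptions.
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator

json_iterator = ArrowFileIterator(
    file_path="data/demo.chunk.00.jsonl",  # hypothetical path
    file_format="json",       # read the jsonl directly, no arrow shards
    worker_id=0,
    num_workers=1,
    preprocess_dir=None,      # only consulted in "arrow" mode
    entropy_model_name=None,  # likewise
    arrow_batch_size=100,
)
for example in json_iterator.create_iter():
    ...  # BltExample records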

bytelatent/data/iterators/dev_iterators.py

@@ -0,0 +1,78 @@
import pandas as pd
from pydantic import ConfigDict
from bytelatent.data.data_types import BltExample
from bytelatent.data.iterators.abstract_iterator import (
PydanticIteratorState,
StatefulIterator,
)
class BltTestIteratorState(PydanticIteratorState):
model_config = ConfigDict(extra="forbid")
position: int
total: int
def build(self):
blt_iter = BltTestIterator(total=self.total)
blt_iter.position = self.position
return blt_iter
class BltTestIterator(StatefulIterator):
def __init__(self, total: int):
self.position = 0
self.total = total
def get_state(self):
return BltTestIteratorState(position=self.position, total=self.total)
def create_iter(self):
for i in range(self.total):
self.position += 1
yield BltExample(
sample_id=f"test_{i}",
text=f"This is some test {i} text.",
tokens=None,
mask=None,
entropies=None,
patch_lengths=None,
)
class BltTestWithEntropiesIteratorState(PydanticIteratorState):
model_config = ConfigDict(extra="forbid")
position: int
total: int
def build(self):
blt_iter = BltTestWithEntropiesIterator(total=self.total)
blt_iter.position = self.position
return blt_iter
class BltTestWithEntropiesIterator(StatefulIterator):
def __init__(self, total: int):
self.position = 0
self.total = total
def get_state(self):
return BltTestWithEntropiesIteratorState(position=self.position, total=self.total)
def create_iter(self):
text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
df = pd.read_json("fixtures/tokens_with_entropies.json")
tokens = df["token_ids"].tolist()
entropies = df["entropies"].tolist()
# BOS and EOS
assert len(tokens) == len(text) + 2
for i in range(self.total):
self.position += 1
yield BltExample(
sample_id=f"test_{i}",
text=text,
tokens=tokens,
mask=[True] * len(tokens),
entropies=entropies,
patch_lengths=None,
)
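
These are the test doubles promoted out of test_iters.py into an importable module. The state round-trip they exist to exercise, as a short sketch:

# Sketch: synthetic BltExamples for checkpoint/resume tests, no real data.
from bytelatent.data.iterators.dev_iterators import BltTestIterator

it = BltTestIterator(total=4)
gen = it.create_iter()
assert next(gen).sample_id == "test_0"

state = it.get_state()  # BltTestIteratorState(position=1, total=4)
clone = state.build()   # a new BltTestIterator carrying that position
assert clone.position == 1 and clone.total == 4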

bytelatent/data/iterators/limit_iterator.py

@@ -0,0 +1,47 @@
from pydantic import ConfigDict
from bytelatent.data.iterators.abstract_iterator import (
PydanticIteratorState,
StatefulIterator,
)
from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState
from bytelatent.data.iterators.dev_iterators import BltTestIteratorState
class LimitIteratorState(PydanticIteratorState):
model_config = ConfigDict(extra="forbid")
base_iterator_state: (
BltTestIteratorState | ArrowFileIteratorState | PydanticIteratorState
)
n_yielded: int
limit: int
def build(self) -> "LimitIterator":
return LimitIterator(
base_iterator=self.base_iterator_state.build(),
n_yielded=self.n_yielded,
limit=self.limit,
)
class LimitIterator(StatefulIterator):
def __init__(self, base_iterator: StatefulIterator, limit: int, n_yielded: int = 0):
self.base_iterator = base_iterator
self.n_yielded = n_yielded
self.limit = limit
def get_state(self):
return LimitIteratorState(
base_iterator_state=self.base_iterator.get_state(),
n_yielded=self.n_yielded,
limit=self.limit,
)
def create_iter(self):
iterator = self.base_iterator.create_iter()
try:
while self.n_yielded < self.limit or self.limit < 0:
yield next(iterator)
self.n_yielded += 1
except StopIteration:
pass
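
LimitIterator caps how many examples it forwards from any stateful base iterator; a negative limit means no cap (the self.limit < 0 arm), and the base iterator's own exhaustion ends iteration via the caught StopIteration. A short sketch using the dev iterator:

from bytelatent.data.iterators.dev_iterators import BltTestIterator
from bytelatent.data.iterators.limit_iterator import LimitIterator

limited = LimitIterator(BltTestIterator(total=10), limit=3)
ids = [ex.sample_id for ex in limited.create_iter()]
assert ids == ["test_0", "test_1", "test_2"]

# limit=-1 disables the cap; iteration ends when the base runs out.
unlimited = LimitIterator(BltTestIterator(total=10), limit=-1)
assert sum(1 for _ in unlimited.create_iter()) == 10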

bytelatent/data/iterators/looping_iterator.py

@@ -1,14 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
from pydantic import BaseModel
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.abstract_iterator import (
PydanticIteratorState,
StatefulIterator,
)
from bytelatent.data.iterators.arrow_iterator import (
ArrowFileIterator,
ArrowFileIteratorState,
)
class LoopingIteratorState(BaseModel, IteratorState):
class LoopingIteratorState(PydanticIteratorState):
file_iterator_state: ArrowFileIteratorState
epoch: int

bytelatent/data/iterators/multiprocess_iterator.py

@@ -6,16 +6,20 @@ from multiprocessing.synchronize import Event as EventClass
from queue import Empty, Full
import numpy as np
from pydantic import BaseModel, ConfigDict
from pydantic import ConfigDict
from bytelatent.data.data_types import Batch
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.abstract_iterator import (
IteratorState,
PydanticIteratorState,
StatefulIterator,
)
from bytelatent.data.iterators.packing_iterator import PackingIteratorState
logger = logging.getLogger()
class MultiprocessIteratorState(BaseModel, IteratorState):
class MultiprocessIteratorState(PydanticIteratorState):
model_config = ConfigDict(extra="forbid")
base_iterator_state: PackingIteratorState
n_batches_to_prefetch: int

bytelatent/data/iterators/packing_iterator.py

@@ -5,7 +5,10 @@ import numpy as np
from pydantic import BaseModel, ConfigDict
from bytelatent.data.data_types import Batch, BltSequence
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.abstract_iterator import (
PydanticIteratorState,
StatefulIterator,
)
from bytelatent.data.iterators.sampling_iterator import SamplingIteratorState
@@ -20,7 +23,7 @@ class PackingArgs(BaseModel):
tokenizer_name: str
class PackingIteratorState(BaseModel, IteratorState):
class PackingIteratorState(PydanticIteratorState):
model_config = ConfigDict(extra="forbid")
sequence_iterator_state: SamplingIteratorState
packing_args: PackingArgs

bytelatent/data/iterators/preprocess_iterator.py

@@ -5,20 +5,29 @@ import torch
from pydantic import BaseModel, ConfigDict
from bytelatent.data.data_types import BltExample
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.abstract_iterator import (
PydanticIteratorState,
StatefulIterator,
)
from bytelatent.data.iterators.arrow_iterator import (
ArrowFileIterator,
ArrowFileIteratorState,
)
from bytelatent.data.iterators.looping_iterator import LoopingIteratorState
from bytelatent.data.iterators.limit_iterator import LimitIterator, LimitIteratorState
from bytelatent.data.iterators.looping_iterator import (
LoopingIterator,
LoopingIteratorState,
)
from bytelatent.data.patcher import Patcher, PatcherArgs, PatchingModeEnum
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
class PreprocessIteratorState(BaseModel, IteratorState):
class PreprocessIteratorState(PydanticIteratorState):
model_config = ConfigDict(extra="forbid")
arrow_file_iterator_state: ArrowFileIteratorState | LoopingIteratorState
arrow_file_iterator_state: (
ArrowFileIteratorState | LoopingIteratorState | LimitIteratorState
)
add_tokens: bool
add_patches: bool
tokenizer_args: TokenizerArgs
@@ -43,7 +52,7 @@ class PreprocessIterator(StatefulIterator):
def __init__(
self,
arrow_iterator: ArrowFileIterator,
arrow_iterator: ArrowFileIterator | LoopingIterator | LimitIterator,
*,
patcher_args: PatcherArgs,
tokenizer_args: TokenizerArgs,
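
Widening both the state union and the constructor annotation is what lets a LimitIterator (or LoopingIterator) sit between the file reader and tokenization/patching. A hedged composition sketch; it assumes PatcherArgs and TokenizerArgs construct with defaults, which this diff does not show:

from bytelatent.data.iterators.dev_iterators import BltTestIterator
from bytelatent.data.iterators.limit_iterator import LimitIterator
from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
from bytelatent.data.patcher import PatcherArgs
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs

capped = LimitIterator(BltTestIterator(total=100), limit=8)
preprocess = PreprocessIterator(
    capped,                          # accepted by the widened annotation
    patcher_args=PatcherArgs(),      # assumption: defaults exist
    tokenizer_args=TokenizerArgs(),  # assumption: defaults exist
)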

bytelatent/data/iterators/sampling_iterator.py

@@ -2,13 +2,16 @@
from typing import Any
import numpy as np
from pydantic import BaseModel, ConfigDict
from pydantic import ConfigDict
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
from bytelatent.data.iterators.abstract_iterator import (
PydanticIteratorState,
StatefulIterator,
)
from bytelatent.data.iterators.sequence_iterator import SequenceIteratorState
class SamplingIteratorState(BaseModel):
class SamplingIteratorState(PydanticIteratorState):
model_config = ConfigDict(extra="forbid")
rng_state: dict[str, Any]
source_to_weight: dict[str, float]
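
SamplingIteratorState keeps the numpy generator as rng_state: dict[str, Any] because bit_generator.state is a plain dict that pydantic can serialize. The round-trip pattern in isolation, plus a weighted pick of the kind source_to_weight drives (source names here are hypothetical):

import numpy as np

rng = np.random.default_rng(42)
rng.random()                        # advance the stream a bit
snapshot = rng.bit_generator.state  # plain dict, safe to checkpoint
expected = rng.random()             # next draw after the snapshot

restored = np.random.default_rng()
restored.bit_generator.state = snapshot
assert restored.random() == expected  # resumes exactly where it left off

source_to_weight = {"wiki": 0.25, "web": 0.75}  # hypothetical sources
names, weights = zip(*source_to_weight.items())
picked = restored.choice(names, p=list(weights))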

bytelatent/data/iterators/sequence_iterator.py

@@ -6,7 +6,10 @@ import numpy as np
from pydantic import BaseModel, ConfigDict
from bytelatent.data.data_types import BltSequence
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.abstract_iterator import (
PydanticIteratorState,
StatefulIterator,
)
from bytelatent.data.iterators.preprocess_iterator import (
PreprocessIterator,
PreprocessIteratorState,
@@ -21,11 +24,12 @@ class SequencePackingArgs(BaseModel):
buffer_size: int
class SequenceIteratorState(BaseModel, IteratorState):
class SequenceIteratorState(PydanticIteratorState):
model_config = ConfigDict(extra="forbid")
sequence_packing_args: SequencePackingArgs
preprocess_iterator_state: PreprocessIteratorState
rng_state: dict[str, Any]
# If None, rng is disabled.
rng_state: dict[str, Any] | None
def build(self):
preprocess_iterator = self.preprocess_iterator_state.build()
@@ -41,22 +45,25 @@ class SequenceIterator(StatefulIterator):
self,
preprocess_iterator: PreprocessIterator,
*,
rng_state: dict[str, Any],
rng_state: dict[str, Any] | None,
sequence_packing_args: SequencePackingArgs,
):
self.preprocess_iterator = preprocess_iterator
self.sequence_packing_args = sequence_packing_args
self.output_seq_len = sequence_packing_args.output_seq_len
self.buffer_size = sequence_packing_args.buffer_size
self.rng = np.random.default_rng()
self.rng.bit_generator.state = rng_state
if rng_state is None:
self.rng = None
else:
self.rng = np.random.default_rng()
self.rng.bit_generator.state = rng_state
def get_state(self):
# TODO: need to also persist the current shuffle buffer
return SequenceIteratorState(
sequence_packing_args=self.sequence_packing_args,
preprocess_iterator_state=self.preprocess_iterator.get_state(),
rng_state=self.rng.bit_generator.state,
rng_state=None if self.rng is None else self.rng.bit_generator.state,
)
def create_iter(self):
@@ -114,7 +121,12 @@ class SequenceIterator(StatefulIterator):
seq_patch_lengths: list[list[int]] = x_patches.tolist()
assert len(seq_patch_lengths) == self.buffer_size
for idx in self.rng.permutation(len(seq_patch_lengths)):
if self.rng is None:
permutations = list(range(len(seq_patch_lengths)))
else:
permutations = self.rng.permutation(len(seq_patch_lengths))
for idx in permutations:
assert len(seq_patch_lengths[idx]) == self.output_seq_len
assert (
sum(seq_patch_lengths[idx])
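
Allowing rng_state=None gives a shuffle-free mode: with no rng, the permutation collapses to range(len(...)), so sequences stream out in deterministic buffer order (what validation wants), and get_state() round-trips the None. The distilled logic, with order() as an illustrative helper rather than a repo function:

import numpy as np

def order(n: int, rng: np.random.Generator | None) -> list[int]:
    # Mirrors the permutation choice in SequenceIterator.create_iter.
    if rng is None:
        return list(range(n))        # deterministic: validation
    return list(rng.permutation(n))  # shuffled: training

assert order(4, None) == [0, 1, 2, 3]
assert sorted(order(4, np.random.default_rng(0))) == [0, 1, 2, 3]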

bytelatent/data/iterators/test_iters.py

@@ -1,83 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import pandas as pd
from pydantic import BaseModel
from bytelatent.constants import BLT_DATA
from bytelatent.data.data_types import BltExample
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.dev_iterators import (
BltTestIterator,
BltTestWithEntropiesIterator,
)
from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
class BltTestIteratorState(BaseModel, IteratorState):
position: int
total: int
def build(self):
blt_iter = BltTestIteratorState(total=self.total)
blt_iter.position = self.position
return blt_iter
class BltTestIterator(StatefulIterator):
def __init__(self, total: int):
self.position = 0
self.total = total
def get_state(self):
return BltTestIteratorState(position=self.position, total=self.total)
def create_iter(self):
for i in range(self.total):
self.position += 1
yield BltExample(
sample_id=f"test_{i}",
text=f"This is some test {i} text.",
tokens=None,
mask=None,
entropies=None,
patch_lengths=None,
)
class BltTestWithEntropiesIteratorState(BaseModel, IteratorState):
position: int
total: int
def build(self):
blt_iter = BltTestWithEntropiesIteratorState(total=self.total)
blt_iter.position = self.position
return blt_iter
class BltTestWithEntropiesIterator(StatefulIterator):
def __init__(self, total: int):
self.position = 0
self.total = total
def get_state(self):
return BltTestIteratorState(position=self.position, total=self.total)
def create_iter(self):
text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
df = pd.read_json("fixtures/tokens_with_entropies.json")
tokens = df["token_ids"].tolist()
entropies = df["entropies"].tolist()
# BOS and EOS
assert len(tokens) == len(text) + 2
for i in range(self.total):
self.position += 1
yield BltExample(
sample_id=f"test_{i}",
text=text,
tokens=tokens,
mask=[True] * len(tokens),
entropies=entropies,
patch_lengths=None,
)
def test_preprocess_iter():
total = 3
tokenizer_args = TokenizerArgs(

bytelatent/data/iterators/test_limit_iterator.py

@@ -0,0 +1,45 @@
from bytelatent.data.iterators.dev_iterators import BltTestIterator
from bytelatent.data.iterators.limit_iterator import LimitIterator
def test_limit_iterator():
total = 10
limit = 5
base_iterator = BltTestIterator(total=total)
limit_iterator = LimitIterator(base_iterator, limit=limit)
iterator = limit_iterator.create_iter()
n = 0
for example in iterator:
assert example.sample_id == f"test_{n}"
n += 1
assert n == limit
limit = 10
base_iterator = BltTestIterator(total=total)
limit_iterator = LimitIterator(base_iterator, limit=limit)
iterator = limit_iterator.create_iter()
n = 0
for example in iterator:
assert example.sample_id == f"test_{n}"
n += 1
assert n == limit == total
limit = 20
base_iterator = BltTestIterator(total=total)
limit_iterator = LimitIterator(base_iterator, limit=limit)
iterator = limit_iterator.create_iter()
n = 0
for example in iterator:
assert example.sample_id == f"test_{n}"
n += 1
assert n == total
limit = -1
base_iterator = BltTestIterator(total=total)
limit_iterator = LimitIterator(base_iterator, limit=limit)
iterator = limit_iterator.create_iter()
n = 0
for example in iterator:
assert example.sample_id == f"test_{n}"
n += 1
assert n == total