blt/bytelatent/data/iterators/preprocess_iterator.py
Pedro Rodriguez 0ffe2ab685 Update iterator inheritance, pass file format args, limit iterator
- Create a common class to use in all inheritance for states
- Add a limit iterator that we can use in evals
- Modify ArrowFileIterator behavior to not do arrow path inference if file_format='json'
- Make EvalArgs valid
- Move testing iterators to a common directory to allow usage in multiple test files
- Make it so that SequenceIterator can take a None rng_state, to disable all rng ops (for eval mainly)

Test Plan:

- `pytest bytelatent`
- `python -m bytelatent.train config=../internal-blt/configs/entropy_model.yaml logging.wandb=null eval=null`
2025-02-20 00:57:17 +00:00

121 lines
4.2 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
from typing import Any, Generator
import torch
from pydantic import BaseModel, ConfigDict
from bytelatent.data.data_types import BltExample
from bytelatent.data.iterators.abstract_iterator import (
PydanticIteratorState,
StatefulIterator,
)
from bytelatent.data.iterators.arrow_iterator import (
ArrowFileIterator,
ArrowFileIteratorState,
)
from bytelatent.data.iterators.limit_iterator import LimitIterator, LimitIteratorState
from bytelatent.data.iterators.looping_iterator import (
LoopingIterator,
LoopingIteratorState,
)
from bytelatent.data.patcher import Patcher, PatcherArgs, PatchingModeEnum
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
class PreprocessIteratorState(PydanticIteratorState):
    """Serializable snapshot of a PreprocessIterator.

    Captures the upstream (arrow/looping/limit) iterator's state together
    with the preprocessing configuration, so iteration can be resumed via
    build().
    """

    model_config = ConfigDict(extra="forbid")
    # State of whichever source iterator feeds raw examples into preprocessing.
    arrow_file_iterator_state: (
        ArrowFileIteratorState | LoopingIteratorState | LimitIteratorState
    )
    # Whether tokenization / patching were enabled on the captured iterator.
    add_tokens: bool
    add_patches: bool
    # Configs needed to lazily rebuild the tokenizer and patcher.
    tokenizer_args: TokenizerArgs
    patcher_args: PatcherArgs

    def build(self):
        """Reconstruct the PreprocessIterator this state was captured from."""
        source_iterator = self.arrow_file_iterator_state.build()
        return PreprocessIterator(
            source_iterator,
            tokenizer_args=self.tokenizer_args,
            patcher_args=self.patcher_args,
            add_patches=self.add_patches,
            add_tokens=self.add_tokens,
        )
class PreprocessIterator(StatefulIterator):
    """
    Take BltExamples with fields filled in only from ArrowFileIterator, and fill in fields that require
    preprocessing like tokenization and patching
    """

    def __init__(
        self,
        arrow_iterator: ArrowFileIterator | LoopingIterator | LimitIterator,
        *,
        patcher_args: PatcherArgs,
        tokenizer_args: TokenizerArgs,
        add_tokens: bool = True,
        add_patches: bool = True,
    ):
        self.arrow_iterator = arrow_iterator
        self.tokenizer_args = tokenizer_args
        self.patcher_args = patcher_args
        self.add_tokens = add_tokens
        self.add_patches = add_patches
        # Tokenizer and patcher are built lazily in create_iter(), not here,
        # so that constructing/serializing this iterator stays cheap.
        self.tokenizer: BltTokenizer | None = None
        self.patcher: Patcher | None = None

    def get_state(self) -> PreprocessIteratorState:
        """
        Snapshot this iterator. The only mutable state lives in the upstream
        iterator; everything else is static configuration.
        """
        return PreprocessIteratorState(
            arrow_file_iterator_state=self.arrow_iterator.get_state(),
            patcher_args=self.patcher_args,
            tokenizer_args=self.tokenizer_args,
            add_patches=self.add_patches,
            add_tokens=self.add_tokens,
        )

    def _compute_patch_lengths(self, tokens, entropies):
        """Patch one tokenized example; returns None when patching is disabled."""
        if self.patcher is None:
            return None
        token_batch = torch.tensor(tokens).unsqueeze(0)
        patches = self.patcher.patch(
            token_batch,
            include_next_token=False,
            entropies=entropies,
        )
        # patch() returns a batch; take the single row and convert to a list.
        return patches[0][0].tolist()

    def create_iter(self) -> Generator[BltExample, Any, None]:
        # Lazily build the (possibly expensive) tokenizer/patcher on first use.
        if self.add_tokens and self.tokenizer is None:
            self.tokenizer = self.tokenizer_args.build()
        if self.add_patches and self.patcher is None:
            self.patcher = self.patcher_args.build()

        # Invariant across the loop: entropies are only needed for entropy-mode patching.
        needs_entropies = (
            self.patcher is not None
            and self.patcher.patching_mode == PatchingModeEnum.entropy
        )

        for example in self.arrow_iterator.create_iter():
            tokens = (
                self.tokenizer.encode(example.text)
                if self.add_tokens
                else example.tokens
            )
            if needs_entropies:
                assert (
                    example.entropies is not None
                ), "For patching, entropies cannot be None"
                entropies = torch.tensor(example.entropies).unsqueeze(0)
            else:
                entropies = None
            patch_lengths = self._compute_patch_lengths(tokens, entropies)
            yield BltExample(
                sample_id=example.sample_id,
                text=example.text,
                tokens=tokens,
                mask=[True] * len(tokens),
                patch_lengths=patch_lengths,
                entropies=example.entropies,
            )