blt/bytelatent/data/iterators/limit_iterator.py
Pedro Rodriguez 0ffe2ab685 Update iterator inheritance, pass file format args, limit iterator
- Create a common class to use in all inheritance for states
- Add a limit iterator that we can use in evals
- Modify ArrowFileIterator behavior to not do arrow path inference if file_format='json'
- Make EvalArgs valid
- Move testing iterators to a common directory to allow usage in multiple test files
- Make it so that SequenceIterator can take a None rng_state, to disable all rng ops (for eval mainly)

Test Plan:

- `pytest bytelatent`
- `python -m bytelatent.train config=../internal-blt/configs/entropy_model.yaml logging.wandb=null eval=null`
2025-02-20 00:57:17 +00:00

48 lines
1.4 KiB
Python

from pydantic import ConfigDict
from bytelatent.data.iterators.abstract_iterator import (
PydanticIteratorState,
StatefulIterator,
)
from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState
from bytelatent.data.iterators.dev_iterators import BltTestIteratorState
class LimitIteratorState(PydanticIteratorState):
    """Serializable snapshot of a ``LimitIterator``.

    Stores the wrapped iterator's state together with the limit bookkeeping;
    ``build`` turns the snapshot back into a ``LimitIterator``.
    """

    # Reject any field that is not declared on this model.
    model_config = ConfigDict(extra="forbid")

    # State of the wrapped iterator; its ``build()`` is invoked when rebuilding.
    base_iterator_state: (
        BltTestIteratorState | ArrowFileIteratorState | PydanticIteratorState
    )
    # Item count recorded by the iterator when the snapshot was taken.
    n_yielded: int
    # Passed through unchanged as the rebuilt iterator's ``limit``.
    limit: int

    def build(self) -> "LimitIterator":
        """Rebuild the wrapped iterator and return a ``LimitIterator`` around it.

        Returns:
            A ``LimitIterator`` constructed with this state's ``limit`` and
            ``n_yielded`` values.
        """
        rebuilt_base = self.base_iterator_state.build()
        return LimitIterator(
            base_iterator=rebuilt_base,
            limit=self.limit,
            n_yielded=self.n_yielded,
        )
class LimitIterator(StatefulIterator):
    """Stateful iterator that yields at most ``limit`` items from a base iterator.

    A negative ``limit`` disables the cap, so iteration continues until the
    base iterator is exhausted. ``n_yielded`` counts the items handed out so
    far and is part of the state returned by ``get_state``.
    """

    def __init__(self, base_iterator: StatefulIterator, limit: int, n_yielded: int = 0):
        """Wrap ``base_iterator`` with an item cap.

        Args:
            base_iterator: Iterator whose ``create_iter()`` supplies the items.
            limit: Maximum number of items to yield; negative means no cap.
            n_yielded: Number of items already counted against ``limit``.
        """
        self.base_iterator = base_iterator
        self.n_yielded = n_yielded
        self.limit = limit

    def get_state(self) -> "LimitIteratorState":
        """Return a ``LimitIteratorState`` capturing the current progress."""
        return LimitIteratorState(
            base_iterator_state=self.base_iterator.get_state(),
            n_yielded=self.n_yielded,
            limit=self.limit,
        )

    def create_iter(self):
        """Yield items from the base iterator until the limit is reached.

        Stops early, without raising, if the base iterator runs out first.
        Nothing is pulled from the base iterator once ``n_yielded`` has
        reached a non-negative ``limit``.
        """
        iterator = self.base_iterator.create_iter()
        while self.limit < 0 or self.n_yielded < self.limit:
            # Keep the try narrow: only exhaustion of the base iterator should
            # end iteration quietly.
            try:
                item = next(iterator)
            except StopIteration:
                return
            # Count the item *before* suspending at the yield. Otherwise a
            # get_state() call made while the consumer holds this item would
            # report a stale n_yielded, and an iterator rebuilt from that state
            # would hand out one item more than `limit` in total.
            self.n_yielded += 1
            yield item