blt/bytelatent/data/iterators/looping_iterator.py
Pedro Rodriguez 0ffe2ab685 Update iterator inheritance, pass file format args, limit iterator
- Create a common base class that all iterator state classes inherit from
- Add a limit iterator for use in evals
- Change ArrowFileIterator to skip Arrow path inference when file_format='json'
- Make EvalArgs valid
- Move test iterators into a common directory so multiple test files can use them
- Allow SequenceIterator to take a None rng_state, which disables all RNG operations (mainly for eval)

Test Plan:

- `pytest bytelatent`
- `python -m bytelatent.train config=../internal-blt/configs/entropy_model.yaml logging.wandb=null eval=null`
2025-02-20 00:57:17 +00:00

39 lines
1 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
from bytelatent.data.iterators.abstract_iterator import (
PydanticIteratorState,
StatefulIterator,
)
from bytelatent.data.iterators.arrow_iterator import (
ArrowFileIterator,
ArrowFileIteratorState,
)
class LoopingIteratorState(PydanticIteratorState):
    """Serializable snapshot of a ``LoopingIterator``.

    Holds the inner file iterator's own state plus the epoch counter, and can
    turn itself back into a live ``LoopingIterator`` via ``build``.
    """

    # State of the wrapped ArrowFileIterator at snapshot time.
    file_iterator_state: ArrowFileIteratorState
    # Value of LoopingIterator.epoch at snapshot time.
    epoch: int

    def build(self) -> "LoopingIterator":
        """Recreate the inner file iterator and wrap it with the saved epoch."""
        restored_file_iterator = self.file_iterator_state.build()
        return LoopingIterator(
            file_iterator=restored_file_iterator,
            epoch=self.epoch,
        )
class LoopingIterator(StatefulIterator):
    """Replay an ``ArrowFileIterator`` endlessly, counting passes as epochs.

    Each pass calls ``file_iterator.create_iter()`` and yields everything that
    generator produces. When it is exhausted, a new pass begins.

    Args:
        file_iterator: The underlying file iterator to loop over.
        epoch: Epoch counter to start from. ``create_iter`` increments it
            before every pass, so the default of ``-1`` makes the first pass
            epoch ``0``.
    """

    def __init__(self, file_iterator: ArrowFileIterator, epoch: int = -1):
        self.file_iterator = file_iterator
        self.epoch = epoch

    def get_state(self) -> "LoopingIteratorState":
        """Return a snapshot of the inner iterator's state and the epoch counter."""
        return LoopingIteratorState(
            file_iterator_state=self.file_iterator.get_state(), epoch=self.epoch
        )

    def create_iter(self):
        """Yield items from the file iterator forever, starting a new pass each time it ends.

        Raises:
            RuntimeError: If two consecutive passes yield no items. Without
                this check, an empty source (or an inner iterator that does
                not restart on ``create_iter()``) would spin in ``while True``
                forever without yielding or returning.
        """
        # A single empty pass is tolerated: a state captured exactly at the end
        # of a pass can legitimately resume into an empty remainder.
        consecutive_empty_passes = 0
        while True:
            # NOTE(review): the counter is bumped at the start of every pass,
            # including the first pass after rebuilding from a state captured
            # mid-pass, so a resumed run reports one epoch higher than the pass
            # it is finishing. Confirm whether anything relies on this.
            self.epoch += 1
            yielded_any = False
            for item in self.file_iterator.create_iter():
                yielded_any = True
                yield item
            if yielded_any:
                consecutive_empty_passes = 0
                continue
            consecutive_empty_passes += 1
            if consecutive_empty_passes >= 2:
                raise RuntimeError(
                    "LoopingIterator: file iterator produced no items on two "
                    "consecutive passes; the source is empty or does not "
                    "restart on create_iter()"
                )