blt/bytelatent/data/iterators/abstract_iterator.py
Pedro Rodriguez 0ffe2ab685 Update iterator inheritance, pass file format args, limit iterator
- Create a common class to use in all inheritance for states
- Add a limit iterator that we can use in evals
- Modify ArrowFileIterator behavior to not do arrow path inference if file_format='json'
- Make EvalArgs valid
- Move testing iterators to a common directory to allow usage in multiple test files
- Make it so that SequenceIterator can take a None rng_state, to disable all rng ops (for eval mainly)

Test Plan:

- `pytest bytelatent`
- `python -m bytelatent.train config=../internal-blt/configs/entropy_model.yaml logging.wandb=null eval=null`
2025-02-20 00:57:17 +00:00

40 lines
994 B
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
import abc
from typing import Any, Generator, Generic, TypeVar
import pydantic
# T: type of the items yielded by the generator that
# StatefulIterator.create_iter() returns.
T = TypeVar("T")
# C: type of the state object returned by StatefulIterator.get_state().
C = TypeVar("C")
class StatefulIterator(Generic[T, C], abc.ABC):
    """Abstract base for iterators that expose a state object.

    Type parameters:
        T: type of the items yielded by the generator from ``create_iter``.
        C: type of the state object returned by ``get_state``.

    Both methods are abstract, so concrete subclasses must implement them
    before they can be instantiated.
    """

    @abc.abstractmethod
    def get_state(self) -> C:
        """Return the state object for this iterator."""

    @abc.abstractmethod
    def create_iter(self) -> Generator[T, Any, None]:
        """Return a generator that yields this iterator's items."""
class IteratorState(Generic[C], abc.ABC):
    """Abstract base for state objects that can build a ``StatefulIterator``.

    The class derives from ``abc.ABC`` so that ``@abc.abstractmethod`` is
    actually enforced: the decorator only has an effect when the class's
    metaclass is ``ABCMeta`` (or derived from it). Without it, a subclass
    that forgot to implement ``build`` could be instantiated and would
    silently return ``None`` from ``build()``.
    """

    @abc.abstractmethod
    def build(self) -> StatefulIterator[T, C]:
        """Construct and return the ``StatefulIterator`` for this state."""
class PydanticIteratorState(pydantic.BaseModel, IteratorState):
    # Base class combining pydantic.BaseModel with the IteratorState
    # interface.
    #
    # extra="forbid": fields that are not declared on the model are rejected
    # rather than being accepted silently.
    model_config = pydantic.ConfigDict(
        extra="forbid",
    )
def get_state_and_refresh(iterator: StatefulIterator):
    """Capture ``iterator``'s state and build a fresh iterator from it.

    Re-initialising the data loader and the Python iterator is necessary
    because ``get_state()`` on a multiprocessing iterator shuts
    multiprocessing down in order to persist state correctly, so it has to
    be restarted afterwards.

    Args:
        iterator: The iterator whose state is captured.

    Returns:
        A tuple ``(state, data_loader, py_iterator)``: the object returned by
        ``iterator.get_state()``, the object returned by ``state.build()``,
        and the result of calling ``create_iter()`` on that built object.
    """
    snapshot = iterator.get_state()
    rebuilt_loader = snapshot.build()
    return snapshot, rebuilt_loader, rebuilt_loader.create_iter()