mirror of
https://github.com/facebookresearch/blt.git
synced 2025-02-23 13:32:14 +00:00
- Create a common base class for all iterator state classes to inherit from
- Add a limit iterator that we can use in evals
- Modify ArrowFileIterator behavior so it does not do arrow path inference if file_format='json'
- Make EvalArgs valid
- Move testing iterators to a common directory to allow usage in multiple test files
- Allow SequenceIterator to take a None rng_state, which disables all rng ops (mainly for eval)

Test Plan:
- `pytest bytelatent`
- `python -m bytelatent.train config=../internal-blt/configs/entropy_model.yaml logging.wandb=null eval=null`
40 lines
994 B
Python
40 lines
994 B
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
import abc
|
|
from typing import Any, Generator, Generic, TypeVar
|
|
|
|
import pydantic
|
|
|
|
# T: item type yielded by the generator from StatefulIterator.create_iter().
T = TypeVar("T")
# C: state type returned by StatefulIterator.get_state().
C = TypeVar("C")
|
|
|
|
|
|
class StatefulIterator(Generic[T, C], abc.ABC):
    """Abstract interface for an iterator that can report its state.

    Type parameters:
        T: type of the items yielded by the generator from ``create_iter``.
        C: type of the state object returned by ``get_state``.
    """

    @abc.abstractmethod
    def get_state(self) -> C:
        """Return the state object (of type ``C``) for this iterator."""
        ...

    @abc.abstractmethod
    def create_iter(self) -> Generator[T, Any, None]:
        """Return a generator that yields items of type ``T``."""
        ...
|
|
|
|
|
|
class IteratorState(Generic[C]):
    """Base class for state objects that can build a ``StatefulIterator``.

    NOTE(review): the only base here is ``Generic[C]``, so this class does
    not use an ``abc.ABCMeta``-derived metaclass by itself; the
    ``@abc.abstractmethod`` marker on ``build`` is therefore only enforced
    in subclasses whose metaclass derives from ``ABCMeta``.
    """

    @abc.abstractmethod
    def build(self) -> StatefulIterator[T, C]:
        """Construct and return the ``StatefulIterator`` for this state."""
        ...
|
|
|
|
|
|
class PydanticIteratorState(pydantic.BaseModel, IteratorState):
    """``IteratorState`` that is also a pydantic model."""

    # extra="forbid": fields not declared on the model are rejected by
    # pydantic instead of being silently accepted.
    model_config = pydantic.ConfigDict(extra="forbid")
|
|
|
|
|
|
def get_state_and_refresh(iterator: StatefulIterator):
    """Capture ``iterator``'s state and rebuild a fresh iterator from it.

    Re-initialising the data loader and the Python iterator is necessary
    because calling ``get_state()`` on a multiprocessing iterator shuts
    multiprocessing down in order to persist the state correctly, so it
    has to be restarted afterwards.

    Args:
        iterator: The iterator whose state should be captured.

    Returns:
        A ``(state, data_loader, py_iterator)`` tuple, where ``state`` is
        the result of ``iterator.get_state()``, ``data_loader`` is the
        object returned by ``state.build()`` and ``py_iterator`` is the
        result of ``data_loader.create_iter()``.
    """
    snapshot = iterator.get_state()
    rebuilt_loader = snapshot.build()
    return snapshot, rebuilt_loader, rebuilt_loader.create_iter()
|