Update iterator inheritance, pass file format args, limit iterator (#63)

- Create a common base class (`PydanticIteratorState`) that all iterator state classes inherit from
- Add a limit iterator that we can use in evals
- Modify ArrowFileIterator behavior to not do arrow path inference if file_format='json'
- Make EvalArgs valid
- Move testing iterators to a common directory to allow usage in multiple test files
- Allow SequenceIterator to take `rng_state=None`, which disables all RNG operations (mainly for eval)
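
The limit iterator in the list above can be pictured as a thin wrapper that stops after a fixed number of items, which is what makes it useful for capping eval runs. The following is a minimal sketch under assumptions: the class name `LimitIterator`, the `limit` parameter, and the "negative means unlimited" convention are illustrative choices for this example, and the class added in this commit may have a different interface and state handling.

```python
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


class LimitIterator:
    """Hypothetical wrapper that yields at most `limit` items from `base`.

    Assumption of this sketch: a negative limit means "no limit".
    """

    def __init__(self, base: Iterable[T], limit: int) -> None:
        self.base = base
        self.limit = limit

    def __iter__(self) -> Iterator[T]:
        if self.limit == 0:
            return
        for n_yielded, item in enumerate(self.base, start=1):
            yield item
            # Check after yielding so no extra item is pulled from `base`.
            if self.limit > 0 and n_yielded >= self.limit:
                return


# Usage: cap an "eval" stream at three items.
eval_items = list(LimitIterator(range(10), limit=3))
```

Checking the count after each `yield`, rather than before fetching the next item, leaves the wrapped iterator positioned exactly at the first unconsumed item.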

Test Plan:

- `pytest bytelatent`
- `python -m bytelatent.train config=../internal-blt/configs/entropy_model.yaml logging.wandb=null eval=null`
Pedro Rodriguez 2025-02-21 16:21:07 -08:00 committed by GitHub
parent b0956bde99
commit fc3399ef40
16 changed files with 317 additions and 133 deletions

@@ -6,16 +6,20 @@ from multiprocessing.synchronize import Event as EventClass
 from queue import Empty, Full
 import numpy as np
-from pydantic import BaseModel, ConfigDict
+from pydantic import ConfigDict
 from bytelatent.data.data_types import Batch
-from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.data.iterators.abstract_iterator import (
+    IteratorState,
+    PydanticIteratorState,
+    StatefulIterator,
+)
 from bytelatent.data.iterators.packing_iterator import PackingIteratorState
 logger = logging.getLogger()
-class MultiprocessIteratorState(BaseModel, IteratorState):
+class MultiprocessIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     base_iterator_state: PackingIteratorState
     n_batches_to_prefetch: int
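
The hunk replaces the `BaseModel, IteratorState` pair of base classes with a single shared base, so each state class declares only its fields. The pattern can be sketched as follows. This is a rough illustration, not the repository's code: it uses stdlib `dataclasses` as a stand-in for pydantic so that it runs without third-party packages, and `SerializableIteratorState`, `CountingIteratorState`, `CountingIterator`, `build`, and `get_state` are hypothetical names chosen for the example.

```python
import abc
from dataclasses import asdict, dataclass
from typing import Any


class IteratorState(abc.ABC):
    @abc.abstractmethod
    def build(self) -> Any:
        """Recreate the iterator this state describes."""


@dataclass
class SerializableIteratorState(IteratorState):
    """Hypothetical shared base: serialization lives here, once."""

    def to_dict(self) -> dict[str, Any]:
        return asdict(self)


@dataclass
class CountingIteratorState(SerializableIteratorState):
    # Concrete states only declare fields and how to rebuild.
    position: int

    def build(self) -> "CountingIterator":
        return CountingIterator(self.position)


class CountingIterator:
    """Toy stateful iterator that counts upward from `position`."""

    def __init__(self, position: int = 0) -> None:
        self.position = position

    def __next__(self) -> int:
        value = self.position
        self.position += 1
        return value

    def get_state(self) -> CountingIteratorState:
        return CountingIteratorState(position=self.position)


# Usage: checkpoint after two items, then resume from the saved state.
iterator = CountingIterator()
next(iterator)
next(iterator)
saved = iterator.get_state().to_dict()
resumed = CountingIteratorState(**saved).build()
```

With one shared base, a change to how states are serialized or validated is made in a single place instead of in every state class.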