blt/bytelatent/data/iterators/sampling_iterator.py
Pedro Rodriguez fc3399ef40
Some checks failed
Lint with Black / lint (push) Has been cancelled
Lint with isort / lint (push) Has been cancelled
Update iterator inheritance, pass file format args, limit iterator (#63)
- Create a common class to use in all inheritance for states
- Add a limit iterator that we can use in evals
- Modify ArrowFileIterator behavior to not do arrow path inference if file_format='json'
- Make EvalArgs valid
- Move testing iterators to a common directory to allow usage in multiple test files
- Make it so that SequenceIterator can take a None rng_state, to disable all rng ops (for eval mainly)

Test Plan:

- `pytest bytelatent`
- `python -m bytelatent.train config=../internal-blt/configs/entropy_model.yaml logging.wandb=null eval=null`
2025-02-21 16:21:07 -08:00

70 lines
2.3 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
from typing import Any
import numpy as np
from pydantic import ConfigDict
from bytelatent.data.iterators.abstract_iterator import (
PydanticIteratorState,
StatefulIterator,
)
from bytelatent.data.iterators.sequence_iterator import SequenceIteratorState
class SamplingIteratorState(PydanticIteratorState):
    """Serializable snapshot from which a ``SamplingIterator`` can be rebuilt."""

    model_config = ConfigDict(extra="forbid")
    # State dict for the NumPy bit generator used to pick sources.
    rng_state: dict[str, Any]
    # Sampling weight for each source, keyed by source name.
    source_to_weight: dict[str, float]
    # Saved state of each source's iterator, keyed by source name.
    source_to_iterator_state: dict[str, SequenceIteratorState]

    def build(self) -> "SamplingIterator":
        """Rebuild each source iterator from its state and wrap them all."""
        rebuilt_iterators = {}
        for name, saved_state in self.source_to_iterator_state.items():
            rebuilt_iterators[name] = saved_state.build()
        return SamplingIterator(
            rng_state=self.rng_state,
            source_to_weight=self.source_to_weight,
            source_to_iterator=rebuilt_iterators,
        )
class SamplingIterator(StatefulIterator):
    """Interleave several source iterators by weighted random choice.

    On every step one source is drawn with probability proportional to its
    weight, and the next item of that source's iterator is yielded.
    """

    def __init__(
        self,
        *,
        rng_state: dict[str, Any],
        source_to_weight: dict[str, float],
        source_to_iterator: dict[str, StatefulIterator],
    ):
        """Store the sources and restore the random generator.

        Args:
            rng_state: State dict assigned to the generator's bit generator.
            source_to_weight: Weight per source name; weights are normalized
                to probabilities when iterating, so they need not sum to 1.
            source_to_iterator: Iterator per source name; must contain an
                entry for every key of ``source_to_weight``.
        """
        # Create a generator, then overwrite its bit-generator state so that
        # sampling continues from the supplied state.
        self.rng = np.random.default_rng()
        self.rng.bit_generator.state = rng_state
        self.source_to_weight = source_to_weight
        self.source_to_iterator = source_to_iterator

    def get_state(self) -> SamplingIteratorState:
        """Return a snapshot of the RNG state, weights and per-source states."""
        return SamplingIteratorState(
            rng_state=self.rng.bit_generator.state,
            source_to_weight=self.source_to_weight,
            source_to_iterator_state={
                source: iterator.get_state()
                for source, iterator in self.source_to_iterator.items()
            },
        )

    def create_iter(self):
        """Yield items indefinitely, drawing the source of each item at random.

        Raises:
            ValueError: If there are no sources or the weights sum to zero
                (raised when the first item is requested).
            RuntimeError: If the chosen source iterator is exhausted.
        """
        possible_sources = list(self.source_to_weight)
        n_sources = len(possible_sources)
        if n_sources == 0:
            raise ValueError("SamplingIterator requires at least one source")

        weights = np.array(list(self.source_to_weight.values()))
        total_weight = weights.sum()
        if total_weight == 0:
            raise ValueError("SamplingIterator source weights sum to zero")
        # The weights are fixed for the lifetime of this generator, so the
        # normalization is done once instead of on every draw.
        norm_weights = weights / total_weight

        source_to_python_iter = {
            source: self.source_to_iterator[source].create_iter()
            for source in possible_sources
        }
        while True:
            source_choice = possible_sources[self.rng.choice(n_sources, p=norm_weights)]
            try:
                item = next(source_to_python_iter[source_choice])
            except StopIteration as exc:
                # A StopIteration escaping a generator body is converted to a
                # bare RuntimeError (PEP 479); raise one that names the source.
                raise RuntimeError(
                    f"Source iterator {source_choice!r} is exhausted"
                ) from exc
            yield item