blt/bytelatent/data/iterators/sampling_iterator.py
2024-12-12 15:32:30 -08:00

67 lines
2.2 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
from typing import Any
import numpy as np
from pydantic import BaseModel, ConfigDict
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
from bytelatent.data.iterators.sequence_iterator import SequenceIteratorState
class SamplingIteratorState(BaseModel):
    """Serializable snapshot from which a ``SamplingIterator`` can be rebuilt.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # State dict assigned to the NumPy bit generator on rebuild.
    rng_state: dict[str, Any]
    # Per-source sampling weight, keyed by source name.
    source_to_weight: dict[str, float]
    # Per-source iterator snapshot, keyed by source name.
    source_to_iterator_state: dict[str, SequenceIteratorState]

    def build(self) -> "SamplingIterator":
        """Rebuild a ``SamplingIterator`` from this snapshot.

        Each per-source state is turned back into an iterator through its own
        ``build()`` method; the RNG state and the weights are passed through
        unchanged.
        """
        rebuilt_iterators = {}
        for name, sub_state in self.source_to_iterator_state.items():
            rebuilt_iterators[name] = sub_state.build()

        return SamplingIterator(
            rng_state=self.rng_state,
            source_to_weight=self.source_to_weight,
            source_to_iterator=rebuilt_iterators,
        )
class SamplingIterator(StatefulIterator):
    """Interleave several source iterators by weighted random sampling.

    On every step one source is drawn at random, with probability
    proportional to its weight, and the next item of that source's iterator
    is yielded.
    """

    def __init__(
        self,
        *,
        rng_state: dict[str, Any],
        source_to_weight: dict[str, float],
        source_to_iterator: dict[str, StatefulIterator],
    ):
        """Create the sampler.

        Args:
            rng_state: State dict assigned to the NumPy bit generator, so
                that sampling continues from exactly that state.
            source_to_weight: Sampling weight for each source name. Weights
                are normalized to probabilities in ``create_iter``.
            source_to_iterator: Iterator for each source name; it must hold
                an entry for every key of ``source_to_weight``.
        """
        # The freshly seeded generator is only a container: its state is
        # overwritten right away with the caller-provided one.
        self.rng = np.random.default_rng()
        self.rng.bit_generator.state = rng_state
        self.source_to_weight = source_to_weight
        self.source_to_iterator = source_to_iterator

    def get_state(self) -> SamplingIteratorState:
        """Return a snapshot of the RNG state, weights and per-source states."""
        return SamplingIteratorState(
            rng_state=self.rng.bit_generator.state,
            source_to_weight=self.source_to_weight,
            source_to_iterator_state={
                source: iterator.get_state()
                for source, iterator in self.source_to_iterator.items()
            },
        )

    def create_iter(self):
        """Yield items drawn from the sources according to their weights.

        The generator loops forever on its own; it only ends through an
        exception.

        Raises:
            ValueError: If there are no sources or the weights sum to zero,
                since no probability distribution can be formed from them.
            KeyError: If a weighted source has no entry in
                ``source_to_iterator``.
        """
        possible_sources = list(self.source_to_weight)
        n_sources = len(possible_sources)
        weights = np.array(
            [self.source_to_weight[source] for source in possible_sources]
        )
        total_weight = weights.sum()
        if n_sources == 0 or total_weight == 0:
            raise ValueError(
                "SamplingIterator needs at least one source and a non-zero "
                f"total weight, got source_to_weight={self.source_to_weight}"
            )
        # The weights never change while iterating, so the probabilities are
        # computed once instead of on every draw.
        norm_weights = weights / total_weight

        source_to_python_iter = {
            source: self.source_to_iterator[source].create_iter()
            for source in possible_sources
        }
        while True:
            source_choice = possible_sources[
                self.rng.choice(n_sources, p=norm_weights)
            ]
            # NOTE: if the chosen source is exhausted, next() raises
            # StopIteration, which Python turns into RuntimeError inside a
            # generator (PEP 479).
            yield next(source_to_python_iter[source_choice])