# Copyright (c) Meta Platforms, Inc. and affiliates.
from typing import Any

import numpy as np
from pydantic import BaseModel, ConfigDict

from bytelatent.data.iterators.abstract_iterator import StatefulIterator
from bytelatent.data.iterators.sequence_iterator import SequenceIteratorState


class SamplingIteratorState(BaseModel):
    """Serializable snapshot of a ``SamplingIterator``.

    Holds the NumPy bit-generator state, the per-source sampling weights and
    the state of every per-source iterator, so that ``build()`` can
    reconstruct an equivalent ``SamplingIterator``.
    """

    # Reject unknown fields so a malformed or stale state fails validation.
    model_config = ConfigDict(extra="forbid")

    # Value assigned to ``Generator.bit_generator.state`` on rebuild.
    rng_state: dict[str, Any]
    # Source name -> sampling weight (normalized at iteration time).
    source_to_weight: dict[str, float]
    # Source name -> state of that source's iterator.
    source_to_iterator_state: dict[str, SequenceIteratorState]

    def build(self) -> "SamplingIterator":
        """Rebuild a ``SamplingIterator`` from this state.

        Each per-source iterator is recreated by calling ``build()`` on its
        stored state.
        """
        return SamplingIterator(
            rng_state=self.rng_state,
            source_to_weight=self.source_to_weight,
            source_to_iterator={
                source: state.build()
                for source, state in self.source_to_iterator_state.items()
            },
        )


class SamplingIterator(StatefulIterator):
    """Yield items from several source iterators, picked at random by weight."""

    def __init__(
        self,
        *,
        rng_state: dict[str, Any],
        source_to_weight: dict[str, float],
        source_to_iterator: dict[str, StatefulIterator],
    ):
        """Create the iterator.

        Args:
            rng_state: State dict assigned to the bit generator of a freshly
                created ``np.random.default_rng()``.
            source_to_weight: Source name -> sampling weight. Weights are
                normalized to probabilities in ``create_iter``.
            source_to_iterator: Source name -> iterator to draw items from.
                Must contain an entry for every key of ``source_to_weight``.
        """
        # The freshly seeded generator is only a container; its state is
        # overwritten immediately so sampling resumes from ``rng_state``.
        self.rng = np.random.default_rng()
        self.rng.bit_generator.state = rng_state
        self.source_to_weight = source_to_weight
        self.source_to_iterator = source_to_iterator

    def get_state(self) -> SamplingIteratorState:
        """Return a snapshot of the RNG, the weights and every source iterator.

        NOTE(review): the state model types per-source states as
        ``SequenceIteratorState``; presumably every source iterator's
        ``get_state()`` returns that type -- confirm against callers.
        """
        return SamplingIteratorState(
            rng_state=self.rng.bit_generator.state,
            source_to_weight=self.source_to_weight,
            source_to_iterator_state={
                source: iterator.get_state()
                for source, iterator in self.source_to_iterator.items()
            },
        )

    def create_iter(self):
        """Return a generator that endlessly samples items from the sources.

        On every step a source is chosen with probability proportional to
        its weight and the next item of that source's iterator is yielded.

        Raises:
            ValueError: If there are no sources or the weights do not sum to
                a finite positive number.
            KeyError: If a weighted source has no entry in
                ``source_to_iterator``.
            RuntimeError: If the chosen source iterator is exhausted.
        """
        possible_sources = list(self.source_to_weight)
        if not possible_sources:
            raise ValueError("SamplingIterator requires at least one source")

        weights = np.array(list(self.source_to_weight.values()))
        total_weight = weights.sum()
        if not np.isfinite(total_weight) or total_weight <= 0:
            raise ValueError(
                "Source weights must sum to a finite positive number, "
                f"got {total_weight}"
            )
        # The weights never change while iterating, so normalize them once
        # instead of rebuilding the probability vector on every step.
        norm_weights = weights / total_weight
        n_sources = len(possible_sources)

        source_to_python_iter = {
            source: self.source_to_iterator[source].create_iter()
            for source in possible_sources
        }

        while True:
            source_choice = possible_sources[
                self.rng.choice(n_sources, p=norm_weights)
            ]
            try:
                item = next(source_to_python_iter[source_choice])
            except StopIteration as err:
                # A StopIteration escaping a generator body is turned into a
                # bare RuntimeError (PEP 479); raise one that names the
                # exhausted source instead.
                raise RuntimeError(
                    f"Source iterator {source_choice!r} is exhausted"
                ) from err
            yield item