# Copyright (c) Meta Platforms, Inc. and affiliates.
from typing import Any

import numpy as np
from pydantic import BaseModel, ConfigDict

from bytelatent.data.iterators.abstract_iterator import StatefulIterator
from bytelatent.data.iterators.sequence_iterator import SequenceIteratorState


class SamplingIteratorState(BaseModel):
    """Serializable checkpoint of a ``SamplingIterator``.

    Calling ``build()`` turns this snapshot back into a live iterator.
    """

    model_config = ConfigDict(extra="forbid")
    # Snapshot of the sampler's ``rng.bit_generator.state``.
    rng_state: dict[str, Any]
    # Unnormalized sampling weight for each source name.
    source_to_weight: dict[str, float]
    # Checkpointed state of the child iterator for each source name.
    source_to_iterator_state: dict[str, SequenceIteratorState]

    def build(self) -> "SamplingIterator":
        """Rebuild every child iterator and wrap them in a ``SamplingIterator``."""
        children = {}
        for name, child_state in self.source_to_iterator_state.items():
            children[name] = child_state.build()
        return SamplingIterator(
            source_to_iterator=children,
            source_to_weight=self.source_to_weight,
            rng_state=self.rng_state,
        )


class SamplingIterator(StatefulIterator):
    """Interleave several child iterators by sampling a source for each item.

    On every step a source name is drawn at random with probability
    proportional to its weight in ``source_to_weight``, and the next item of
    that source's iterator is yielded. The numpy RNG state is saved by
    ``get_state()`` and restored in ``__init__``, so a restored sampler makes
    the same sequence of source choices.
    """

    def __init__(
        self,
        *,
        rng_state: dict[str, Any],
        source_to_weight: dict[str, float],
        source_to_iterator: dict[str, StatefulIterator],
    ):
        """
        Args:
            rng_state: A ``bit_generator.state`` dict (as produced by
                ``get_state().rng_state``). It is assumed to describe the same
                bit-generator type that ``np.random.default_rng()`` creates.
            source_to_weight: Unnormalized sampling weight per source name.
            source_to_iterator: Child iterator per source name; must contain
                every key of ``source_to_weight``.
        """
        self.rng = np.random.default_rng()
        # Replace the freshly seeded state with the checkpointed one so that
        # sampling resumes exactly where the saved generator left off.
        self.rng.bit_generator.state = rng_state
        self.source_to_weight = source_to_weight
        self.source_to_iterator = source_to_iterator

    def get_state(self) -> SamplingIteratorState:
        """Snapshot the RNG state, the weights, and every child iterator's state."""
        return SamplingIteratorState(
            rng_state=self.rng.bit_generator.state,
            source_to_weight=self.source_to_weight,
            source_to_iterator_state={
                source: iterator.get_state()
                for source, iterator in self.source_to_iterator.items()
            },
        )

    def create_iter(self):
        """Yield items indefinitely, each drawn from a weighted-random source.

        Raises:
            ValueError: if there are no sources, or the weights are not all
                finite and non-negative with a positive sum.
            KeyError: if a weighted source has no entry in
                ``source_to_iterator``.
            RuntimeError: if the chosen source's iterator is exhausted.
        """
        # Snapshot names and weights once; dict iteration keeps them aligned.
        possible_sources = list(self.source_to_weight)
        weights = np.array([self.source_to_weight[s] for s in possible_sources])

        if not possible_sources:
            raise ValueError("SamplingIterator requires at least one source")
        if (
            not np.isfinite(weights).all()
            or (weights < 0).any()
            or weights.sum() <= 0
        ):
            raise ValueError(
                "Sampling weights must be finite and non-negative with a "
                f"positive sum, got {self.source_to_weight}"
            )

        # The weights never change during iteration, so normalize them once
        # rather than on every draw.
        norm_weights = weights / weights.sum()
        n_sources = len(possible_sources)

        source_to_python_iter = {
            source: self.source_to_iterator[source].create_iter()
            for source in possible_sources
        }
        while True:
            source_choice = possible_sources[self.rng.choice(n_sources, p=norm_weights)]
            try:
                item = next(source_to_python_iter[source_choice])
            except StopIteration as e:
                # A bare StopIteration escaping a generator becomes an opaque
                # RuntimeError (PEP 479); name the exhausted source instead.
                raise RuntimeError(
                    f"SamplingIterator source {source_choice!r} was exhausted"
                ) from e
            yield item