# Copyright (c) Meta Platforms, Inc. and affiliates.
from typing import Any, Generator

import torch
from pydantic import BaseModel, ConfigDict

from bytelatent.data.data_types import BltExample
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.arrow_iterator import (
    ArrowFileIterator,
    ArrowFileIteratorState,
)
from bytelatent.data.iterators.looping_iterator import LoopingIteratorState
from bytelatent.data.patcher import Patcher, PatcherArgs, PatchingModeEnum
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs


class PreprocessIteratorState(BaseModel, IteratorState):
    model_config = ConfigDict(extra="forbid")
    arrow_file_iterator_state: ArrowFileIteratorState | LoopingIteratorState
    add_tokens: bool
    add_patches: bool
    tokenizer_args: TokenizerArgs
    patcher_args: PatcherArgs

    def build(self):
        arrow_iterator = self.arrow_file_iterator_state.build()
        return PreprocessIterator(
            arrow_iterator,
            patcher_args=self.patcher_args,
            tokenizer_args=self.tokenizer_args,
            add_tokens=self.add_tokens,
            add_patches=self.add_patches,
        )


class PreprocessIterator(StatefulIterator):
    """
    Takes BltExamples whose fields are filled in only by ArrowFileIterator
    and fills in the fields that require preprocessing, such as tokenization
    and patching.
    """

    def __init__(
        self,
        arrow_iterator: ArrowFileIterator,
        *,
        patcher_args: PatcherArgs,
        tokenizer_args: TokenizerArgs,
        add_tokens: bool = True,
        add_patches: bool = True,
    ):
        self.arrow_iterator = arrow_iterator
        self.tokenizer_args = tokenizer_args
        self.patcher_args = patcher_args
        self.add_tokens = add_tokens
        self.add_patches = add_patches
        self.tokenizer: BltTokenizer | None = None
        self.patcher: Patcher | None = None

    def get_state(self) -> PreprocessIteratorState:
        """
        The only state to maintain here is the arrow iterator's;
        this iterator has no internal state of its own.
        """
        return PreprocessIteratorState(
            arrow_file_iterator_state=self.arrow_iterator.get_state(),
            tokenizer_args=self.tokenizer_args,
            patcher_args=self.patcher_args,
            add_tokens=self.add_tokens,
            add_patches=self.add_patches,
        )

    def create_iter(self) -> Generator[BltExample, Any, None]:
        # Lazily construct the tokenizer and patcher the first time the iterator runs.
        if self.tokenizer is None and self.add_tokens:
            self.tokenizer = self.tokenizer_args.build()
        if self.patcher is None and self.add_patches:
            self.patcher = self.patcher_args.build()

        example_iter = self.arrow_iterator.create_iter()
        for example in example_iter:
            if self.add_tokens:
                tokens = self.tokenizer.encode(example.text)
            else:
                tokens = example.tokens
            # Entropy-based patching requires entropies precomputed upstream.
            if (
                self.patcher is not None
                and self.patcher.patching_mode == PatchingModeEnum.entropy
            ):
                assert (
                    example.entropies is not None
                ), "For patching, entropies cannot be None"
                entropies = torch.tensor(example.entropies).unsqueeze(0)
            else:
                entropies = None
            if self.patcher is None:
                patch_lengths = None
            else:
                # patch() operates on a batch dimension; take the single example's lengths.
                patch_lengths = self.patcher.patch(
                    torch.tensor(tokens).unsqueeze(0),
                    include_next_token=False,
                    entropies=entropies,
                )[0][0].tolist()
            yield BltExample(
                sample_id=example.sample_id,
                text=example.text,
                tokens=tokens,
                mask=[True] * len(tokens),
                patch_lengths=patch_lengths,
                entropies=example.entropies,
            )
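

# --- Hypothetical usage sketch (illustration only, not part of this module) ---
# A minimal sketch of how PreprocessIterator composes with an upstream
# ArrowFileIterator and how its state checkpointing round-trips. The constructor
# arguments shown for ArrowFileIterator, PatcherArgs, and TokenizerArgs are
# assumptions for illustration; consult their definitions for the actual fields.
#
#     arrow_iterator = ArrowFileIterator(...)  # built from a dataset shard (args assumed)
#     preprocess_iterator = PreprocessIterator(
#         arrow_iterator,
#         patcher_args=PatcherArgs(),      # fields assumed; defaults for illustration
#         tokenizer_args=TokenizerArgs(),  # fields assumed; defaults for illustration
#         add_tokens=True,
#         add_patches=True,
#     )
#     for example in preprocess_iterator.create_iter():
#         ...  # example.tokens, example.mask, and example.patch_lengths are now populated
#
#     # Checkpoint/resume: capture the pydantic state object, then rebuild from it.
#     state = preprocess_iterator.get_state()
#     resumed_iterator = state.build()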