blt/bytelatent/data/iterators/preprocess_iterator.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
from typing import Any, Generator

import torch
from pydantic import BaseModel, ConfigDict

from bytelatent.data.data_types import BltExample
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.arrow_iterator import (
    ArrowFileIterator,
    ArrowFileIteratorState,
)
from bytelatent.data.iterators.looping_iterator import LoopingIteratorState
from bytelatent.data.patcher import Patcher, PatcherArgs, PatchingModeEnum
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs


class PreprocessIteratorState(BaseModel, IteratorState):
    model_config = ConfigDict(extra="forbid")
    arrow_file_iterator_state: ArrowFileIteratorState | LoopingIteratorState
    add_tokens: bool
    add_patches: bool
    tokenizer_args: TokenizerArgs
    patcher_args: PatcherArgs

    def build(self):
        arrow_iterator = self.arrow_file_iterator_state.build()
        return PreprocessIterator(
            arrow_iterator,
            patcher_args=self.patcher_args,
            tokenizer_args=self.tokenizer_args,
            add_tokens=self.add_tokens,
            add_patches=self.add_patches,
        )


class PreprocessIterator(StatefulIterator):
    """
    Take BltExamples whose fields are filled in only by the ArrowFileIterator, and fill in
    the remaining fields that require preprocessing, such as tokenization and patching.
    """

    def __init__(
        self,
        arrow_iterator: ArrowFileIterator,
        *,
        patcher_args: PatcherArgs,
        tokenizer_args: TokenizerArgs,
        add_tokens: bool = True,
        add_patches: bool = True,
    ):
        self.arrow_iterator = arrow_iterator
        self.tokenizer_args = tokenizer_args
        self.patcher_args = patcher_args
        self.add_tokens = add_tokens
        self.add_patches = add_patches
        # The tokenizer and patcher are built lazily, on the first call to create_iter()
        self.tokenizer: BltTokenizer | None = None
        self.patcher: Patcher | None = None

    def get_state(self) -> PreprocessIteratorState:
        """
        The only state to maintain here comes from the arrow iterator;
        this iterator has no internal state of its own.
        """
        return PreprocessIteratorState(
            arrow_file_iterator_state=self.arrow_iterator.get_state(),
            tokenizer_args=self.tokenizer_args,
            patcher_args=self.patcher_args,
            add_tokens=self.add_tokens,
            add_patches=self.add_patches,
        )
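
    # Checkpoint/resume sketch (an assumption based on the pydantic state model
    # above, not an API documented here): the state serializes to a plain dict
    # with model_dump() and rebuilds the whole pipeline via build().
    #
    #     state = preprocess_iterator.get_state()
    #     payload = state.model_dump()
    #     resumed_iterator = PreprocessIteratorState(**payload).build()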

    def create_iter(self) -> Generator[BltExample, Any, None]:
        # Build the tokenizer/patcher lazily so that constructing (or
        # deserializing) this iterator stays cheap until iteration starts.
        if self.tokenizer is None and self.add_tokens:
            self.tokenizer = self.tokenizer_args.build()
        if self.patcher is None and self.add_patches:
            self.patcher = self.patcher_args.build()

        example_iter = self.arrow_iterator.create_iter()
        for example in example_iter:
            if self.add_tokens:
                tokens = self.tokenizer.encode(example.text)
            else:
                tokens = example.tokens
            # Entropy-based patching requires per-token entropies, which must
            # already be present on the incoming example.
            if (
                self.patcher is not None
                and self.patcher.patching_mode == PatchingModeEnum.entropy
            ):
                assert (
                    example.entropies is not None
                ), "For entropy patching, entropies cannot be None"
                entropies = torch.tensor(example.entropies).unsqueeze(0)
            else:
                entropies = None
            if self.patcher is None:
                patch_lengths = None
            else:
                # The patcher operates on a [batch, seq] tensor, so add a batch
                # dimension and take the patch lengths of the single sequence.
                patch_lengths = self.patcher.patch(
                    torch.tensor(tokens).unsqueeze(0),
                    include_next_token=False,
                    entropies=entropies,
                )[0][0].tolist()
            yield BltExample(
                sample_id=example.sample_id,
                text=example.text,
                tokens=tokens,
                mask=[True] * len(tokens),
                patch_lengths=patch_lengths,
                entropies=example.entropies,
            )
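

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The constructor
# arguments for ArrowFileIterator, PatcherArgs, and TokenizerArgs are
# placeholders/assumptions here; consult their definitions for the real
# signatures.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    arrow_iterator = ArrowFileIterator(...)  # hypothetical construction
    preprocess_iterator = PreprocessIterator(
        arrow_iterator,
        patcher_args=PatcherArgs(),      # assumption: defaults are usable
        tokenizer_args=TokenizerArgs(),  # assumption: defaults are usable
        add_tokens=True,
        add_patches=False,  # skip patching so no patcher/entropies are needed
    )
    for i, example in enumerate(preprocess_iterator.create_iter()):
        print(example.sample_id, len(example.tokens))
        if i >= 2:
            break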