# Copyright (c) Meta Platforms, Inc. and affiliates.
from typing import Any, Generator

import torch
from pydantic import BaseModel, ConfigDict

from bytelatent.data.data_types import BltExample
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.arrow_iterator import (
    ArrowFileIterator,
    ArrowFileIteratorState,
)
from bytelatent.data.iterators.looping_iterator import LoopingIteratorState
from bytelatent.data.patcher import Patcher, PatcherArgs, PatchingModeEnum
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs


class PreprocessIteratorState(BaseModel, IteratorState):
    """
    Pydantic-serializable state for PreprocessIterator: everything needed to
    rebuild the iterator via build(), including the underlying arrow
    iterator's state.
    """

    model_config = ConfigDict(extra="forbid")
    arrow_file_iterator_state: ArrowFileIteratorState | LoopingIteratorState
    add_tokens: bool
    add_patches: bool
    tokenizer_args: TokenizerArgs
    patcher_args: PatcherArgs

    def build(self):
        arrow_iterator = self.arrow_file_iterator_state.build()
        return PreprocessIterator(
            arrow_iterator,
            patcher_args=self.patcher_args,
            tokenizer_args=self.tokenizer_args,
            add_tokens=self.add_tokens,
            add_patches=self.add_patches,
        )
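

# Illustrative note, not part of the original module: because the state above
# is a pydantic v2 model, a training loop can checkpoint and resume the
# pipeline with standard pydantic (de)serialization, e.g.:
#
#     payload = iterator.get_state().model_dump_json()
#     state = PreprocessIteratorState.model_validate_json(payload)
#     iterator = state.build()

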
class PreprocessIterator(StatefulIterator):
    """
    Takes BltExamples whose fields are filled in only by the ArrowFileIterator,
    and populates the fields that require preprocessing, such as tokenization
    and patching.
    """

    def __init__(
        self,
        arrow_iterator: ArrowFileIterator,
        *,
        patcher_args: PatcherArgs,
        tokenizer_args: TokenizerArgs,
        add_tokens: bool = True,
        add_patches: bool = True,
    ):
        self.arrow_iterator = arrow_iterator
        self.tokenizer_args = tokenizer_args
        self.patcher_args = patcher_args
        self.add_tokens = add_tokens
        self.add_patches = add_patches
        # Built lazily on the first call to create_iter().
        self.tokenizer: BltTokenizer | None = None
        self.patcher: Patcher | None = None

    def get_state(self) -> PreprocessIteratorState:
        """
        The only state to maintain here is the arrow iterator's; this
        iterator has no internal state of its own.
        """
        return PreprocessIteratorState(
            arrow_file_iterator_state=self.arrow_iterator.get_state(),
            tokenizer_args=self.tokenizer_args,
            patcher_args=self.patcher_args,
            add_tokens=self.add_tokens,
            add_patches=self.add_patches,
        )

    def create_iter(self) -> Generator[BltExample, Any, None]:
        # Build the tokenizer and patcher lazily, on first iteration.
        if self.tokenizer is None and self.add_tokens:
            self.tokenizer = self.tokenizer_args.build()
        if self.patcher is None and self.add_patches:
            self.patcher = self.patcher_args.build()

        example_iter = self.arrow_iterator.create_iter()
        for example in example_iter:
            if self.add_tokens:
                tokens = self.tokenizer.encode(example.text)
            else:
                tokens = example.tokens
            if (
                self.patcher is not None
                and self.patcher.patching_mode == PatchingModeEnum.entropy
            ):
                # Entropy patching requires precomputed per-token entropies,
                # shaped here as a batch containing a single sequence.
                assert (
                    example.entropies is not None
                ), "For patching, entropies cannot be None"
                entropies = torch.tensor(example.entropies).unsqueeze(0)
            else:
                entropies = None
            if self.patcher is None:
                patch_lengths = None
            else:
                # patch() operates on batches; unwrap to the patch lengths of
                # the single sequence in the batch.
                patch_lengths = self.patcher.patch(
                    torch.tensor(tokens).unsqueeze(0),
                    include_next_token=False,
                    entropies=entropies,
                )[0][0].tolist()
            yield BltExample(
                sample_id=example.sample_id,
                text=example.text,
                tokens=tokens,
                # All positions are valid at this stage (the mask is all True).
                mask=[True] * len(tokens),
                patch_lengths=patch_lengths,
                entropies=example.entropies,
            )
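

# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the original module. It assumes an
# ArrowFileIterator and the TokenizerArgs/PatcherArgs configs have been
# constructed elsewhere (their constructor arguments are omitted here, since
# they depend on the dataset and tokenizer on disk):
#
#     iterator = PreprocessIterator(
#         arrow_iterator,
#         patcher_args=patcher_args,
#         tokenizer_args=tokenizer_args,
#     )
#     for example in iterator.create_iter():
#         # example.tokens, example.mask and (when add_patches is set)
#         # example.patch_lengths are now populated.
#         ...
# ---------------------------------------------------------------------------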