mirror of
https://github.com/facebookresearch/blt.git
synced 2025-01-18 16:37:46 +00:00
163 lines
5 KiB
Python
163 lines
5 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
import pandas as pd
|
|
from pydantic import BaseModel
|
|
|
|
from bytelatent.constants import BLT_DATA
|
|
from bytelatent.data.data_types import BltExample
|
|
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
|
|
from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
|
|
from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
|
|
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
|
|
|
|
|
|
class BltTestIteratorState(BaseModel, IteratorState):
    """Serializable checkpoint of a BltTestIterator's progress."""

    # Number of examples already yielded at checkpoint time.
    position: int
    # Total number of examples the iterator produces.
    total: int

    def build(self):
        """Reconstruct the iterator this state was captured from.

        Bug fix: the original constructed another ``BltTestIteratorState``
        (which would also fail pydantic validation, since ``position`` is a
        required field not passed to the constructor) instead of rebuilding
        the iterator itself, as the ``IteratorState`` contract requires.
        """
        blt_iter = BltTestIterator(total=self.total)
        blt_iter.position = self.position
        return blt_iter
|
|
|
|
|
|
class BltTestIterator(StatefulIterator):
    """Stateful iterator that yields synthetic text-only BltExamples for tests."""

    def __init__(self, total: int):
        # Count of examples yielded so far; advanced by create_iter().
        self.position = 0
        self.total = total

    def get_state(self):
        """Snapshot current progress as a serializable state object."""
        return BltTestIteratorState(position=self.position, total=self.total)

    def create_iter(self):
        """Yield ``total`` synthetic examples, advancing ``position`` each time.

        Tokens, mask, entropies and patch lengths are left as None; downstream
        preprocessing is expected to fill them in.
        """
        for idx in range(self.total):
            self.position += 1
            example = BltExample(
                sample_id=f"test_{idx}",
                text=f"This is some test {idx} text.",
                tokens=None,
                mask=None,
                entropies=None,
                patch_lengths=None,
            )
            yield example
|
|
|
|
|
|
class BltTestWithEntropiesIteratorState(BaseModel, IteratorState):
    """Serializable checkpoint of a BltTestWithEntropiesIterator's progress."""

    # Number of examples already yielded at checkpoint time.
    position: int
    # Total number of examples the iterator produces.
    total: int

    def build(self):
        """Reconstruct the iterator this state was captured from.

        Bug fix: the original constructed another
        ``BltTestWithEntropiesIteratorState`` (which would also fail pydantic
        validation, since ``position`` is a required field not passed to the
        constructor) instead of rebuilding the iterator itself, as the
        ``IteratorState`` contract requires.
        """
        blt_iter = BltTestWithEntropiesIterator(total=self.total)
        blt_iter.position = self.position
        return blt_iter
|
|
|
|
|
|
class BltTestWithEntropiesIterator(StatefulIterator):
    """Stateful iterator yielding a fixed pre-tokenized example with entropies.

    Tokens and per-token entropies are loaded from a JSON fixture so that
    entropy-based patching can be exercised without running a model.
    """

    def __init__(self, total: int):
        # Count of examples yielded so far; advanced by create_iter().
        self.position = 0
        self.total = total

    def get_state(self):
        """Snapshot current progress as a serializable state object.

        Bug fix: previously returned a ``BltTestIteratorState`` (copy-paste
        from ``BltTestIterator``), so a checkpoint restored via ``build()``
        would reconstruct the wrong iterator class, without entropies.
        """
        return BltTestWithEntropiesIteratorState(
            position=self.position, total=self.total
        )

    def create_iter(self):
        """Yield ``total`` copies of the fixture example, advancing ``position``."""
        text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
        df = pd.read_json("fixtures/tokens_with_entropies.json")
        tokens = df["token_ids"].tolist()
        entropies = df["entropies"].tolist()
        # BOS and EOS account for the +2 over the byte-level character count.
        assert len(tokens) == len(text) + 2
        for i in range(self.total):
            self.position += 1
            yield BltExample(
                sample_id=f"test_{i}",
                text=text,
                tokens=tokens,
                mask=[True] * len(tokens),
                entropies=entropies,
                patch_lengths=None,
            )
|
|
|
|
|
|
def test_preprocess_iter():
    """PreprocessIterator tokenization: token count matches text+BOS/EOS, mask aligns."""
    total = 3
    tokenizer_args = TokenizerArgs(
        name="blt",
        init_kwargs={
            "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
        },
    )
    modes = [PatchingModeEnum.bpe, PatchingModeEnum.space]
    for mode in modes:
        example_it = PreprocessIterator(
            BltTestIterator(total),
            tokenizer_args=tokenizer_args,
            patcher_args=PatcherArgs(patching_mode=mode),
        )
        count = 0
        for example in example_it.create_iter():
            assert isinstance(example.tokens, list)
            assert isinstance(example.tokens[0], int)
            # Byte-level tokenization plus BOS and EOS explains the +2.
            assert len(example.tokens) == len(example.text) + 2
            assert example.mask is not None
            assert len(example.tokens) == len(example.mask)
            count += 1
        assert count == total
|
|
|
|
|
|
def test_non_entropy_patch_iter():
    """Non-entropy patching modes: patch lengths are ints summing to the token count."""
    total = 3
    tokenizer_args = TokenizerArgs(
        name="blt",
        init_kwargs={
            "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
        },
    )
    modes = [PatchingModeEnum.bpe, PatchingModeEnum.space]
    for mode in modes:
        example_it = PreprocessIterator(
            BltTestIterator(total),
            tokenizer_args=tokenizer_args,
            patcher_args=PatcherArgs(patching_mode=mode),
        )

        count = 0
        for example in example_it.create_iter():
            assert isinstance(example.patch_lengths, list)
            assert isinstance(example.patch_lengths[0], int)
            # Patches partition the token sequence exactly.
            assert len(example.tokens) == sum(example.patch_lengths)
            count += 1

        assert count == total
|
|
|
|
|
|
def test_entropy_patch_iter():
    """Entropy patching from fixture entropies: patch lengths partition the tokens."""
    total = 2
    # Threshold chosen to produce a non-trivial patching of the fixture example.
    patcher_args = PatcherArgs(
        patching_mode=PatchingModeEnum.entropy, threshold=1.335442066192627
    )
    tokenizer_args = TokenizerArgs(
        name="blt",
        init_kwargs={
            "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
        },
    )
    example_it = PreprocessIterator(
        BltTestWithEntropiesIterator(total),
        tokenizer_args=tokenizer_args,
        patcher_args=patcher_args,
    )

    count = 0
    for example in example_it.create_iter():
        assert isinstance(example.patch_lengths, list)
        assert isinstance(example.patch_lengths[0], int)
        # Patches partition the token sequence exactly.
        assert len(example.tokens) == sum(example.patch_lengths)
        count += 1

    assert count == total
|