mirror of
https://github.com/facebookresearch/blt.git
synced 2025-02-23 21:42:14 +00:00
- Create a common base class for all iterator states to inherit from - Add a limit iterator that we can use in evals - Modify ArrowFileIterator so it does not infer Arrow paths when file_format='json' - Make EvalArgs valid - Move testing iterators to a common directory so they can be used in multiple test files - Allow SequenceIterator to take a None rng_state, which disables all RNG ops (mainly for evals) Test Plan: - `pytest bytelatent` - `python -m bytelatent.train config=../internal-blt/configs/entropy_model.yaml logging.wandb=null eval=null`
79 lines
2.2 KiB
Python
79 lines
2.2 KiB
Python
import pandas as pd
|
|
from pydantic import ConfigDict
|
|
|
|
from bytelatent.data.data_types import BltExample
|
|
from bytelatent.data.iterators.abstract_iterator import (
|
|
PydanticIteratorState,
|
|
StatefulIterator,
|
|
)
|
|
|
|
|
|
class BltTestIteratorState(PydanticIteratorState):
    """Serializable checkpoint of a ``BltTestIterator``.

    Records how many examples the iterator had yielded (``position``) and how
    many it yields in total (``total``), so ``build`` can recreate an iterator
    at the same point.
    """

    # Reject fields not declared below when validating a state.
    model_config = ConfigDict(extra="forbid")

    # Number of examples already yielded when the state was captured.
    position: int
    # Total number of examples the iterator produces.
    total: int

    def build(self) -> "BltTestIterator":
        """Recreate a ``BltTestIterator`` positioned where this state was captured.

        Returns:
            A new ``BltTestIterator`` with this state's ``total`` and ``position``.
        """
        # Build the iterator (not another state object), then restore its progress.
        blt_iter = BltTestIterator(total=self.total)
        blt_iter.position = self.position
        return blt_iter
|
|
|
|
|
|
class BltTestIterator(StatefulIterator):
    """Synthetic iterator yielding ``total`` text-only ``BltExample`` objects.

    Example ``i`` has ``sample_id`` ``"test_{i}"`` and text
    ``"This is some test {i} text."``. Its tokens, mask, entropies and
    patch_lengths are all ``None``. ``position`` counts the examples yielded
    so far; ``get_state`` captures it.
    """

    def __init__(self, total: int):
        # Number of examples yielded so far, which is also the index of the next one.
        self.position = 0
        # Number of examples the iterator produces in total.
        self.total = total

    def get_state(self) -> BltTestIteratorState:
        """Return a snapshot of the current progress."""
        return BltTestIteratorState(position=self.position, total=self.total)

    def create_iter(self):
        """Yield the remaining examples, starting at index ``self.position``.

        Starting from ``position`` rather than 0 lets an iterator rebuilt via
        ``BltTestIteratorState.build`` resume where its state was captured.
        A freshly constructed iterator (``position == 0``) still yields all
        ``total`` examples.
        """
        for i in range(self.position, self.total):
            # Advance before yielding: a state captured while the consumer
            # holds example ``i`` already counts it as consumed.
            self.position += 1
            yield BltExample(
                sample_id=f"test_{i}",
                text=f"This is some test {i} text.",
                tokens=None,
                mask=None,
                entropies=None,
                patch_lengths=None,
            )
|
|
|
|
|
|
class BltTestWithEntropiesIteratorState(PydanticIteratorState):
    """Serializable checkpoint of a ``BltTestWithEntropiesIterator``.

    Records how many examples the iterator had yielded (``position``) and how
    many it yields in total (``total``), so ``build`` can recreate an iterator
    at the same point.
    """

    # Reject fields not declared below when validating a state.
    model_config = ConfigDict(extra="forbid")

    # Number of examples already yielded when the state was captured.
    position: int
    # Total number of examples the iterator produces.
    total: int

    def build(self) -> "BltTestWithEntropiesIterator":
        """Recreate a ``BltTestWithEntropiesIterator`` at the captured position.

        Returns:
            A new ``BltTestWithEntropiesIterator`` with this state's ``total``
            and ``position``.
        """
        # Build the iterator (not another state object), then restore its progress.
        blt_iter = BltTestWithEntropiesIterator(total=self.total)
        blt_iter.position = self.position
        return blt_iter
|
|
|
|
|
|
class BltTestWithEntropiesIterator(StatefulIterator):
    """Synthetic iterator yielding ``total`` examples built from a token/entropy fixture.

    Every example shares the same fixed text and the ``token_ids`` and
    ``entropies`` columns read from ``fixtures/tokens_with_entropies.json``,
    with an all-True mask. Only ``sample_id`` (``"test_{i}"``) varies.
    ``position`` counts the examples yielded so far; ``get_state`` captures it.
    """

    def __init__(self, total: int):
        # Number of examples yielded so far, which is also the index of the next one.
        self.position = 0
        # Number of examples the iterator produces in total.
        self.total = total

    def get_state(self) -> BltTestWithEntropiesIteratorState:
        """Return a snapshot of the current progress.

        Uses this iterator's own state class, so that ``build()`` on the
        snapshot recreates a ``BltTestWithEntropiesIterator``.
        """
        return BltTestWithEntropiesIteratorState(
            position=self.position, total=self.total
        )

    def create_iter(self):
        """Yield the remaining examples, starting at index ``self.position``.

        Starting from ``position`` rather than 0 lets an iterator rebuilt from
        its state resume where it left off. A freshly constructed iterator
        still yields all ``total`` examples.
        """
        text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
        # NOTE(review): this path is resolved against the current working
        # directory, so the iterator only works when run from the directory
        # that contains ``fixtures/`` (presumably the repo root) — confirm.
        df = pd.read_json("fixtures/tokens_with_entropies.json")
        tokens = df["token_ids"].tolist()
        entropies = df["entropies"].tolist()
        # The fixture must hold one token per character of ``text``, plus BOS and EOS.
        assert len(tokens) == len(text) + 2
        for i in range(self.position, self.total):
            # Advance before yielding: a state captured while the consumer
            # holds example ``i`` already counts it as consumed.
            self.position += 1
            yield BltExample(
                sample_id=f"test_{i}",
                text=text,
                tokens=tokens,
                mask=[True] * len(tokens),
                entropies=entropies,
                patch_lengths=None,
            )
|