Mirror of https://github.com/facebookresearch/blt.git
- Create a common class to use in all inheritance for states (the save/resume pattern is sketched below)
- Add a limit iterator that we can use in evals
- Modify ArrowFileIterator behavior to not do arrow path inference if file_format='json'
- Make EvalArgs valid
- Move testing iterators to a common directory to allow usage in multiple test files
- Make it so that SequenceIterator can take a None rng_state, to disable all rng ops (mainly for eval)

Test Plan:
- `pytest bytelatent`
- `python -m bytelatent.train config=../internal-blt/configs/entropy_model.yaml logging.wandb=null eval=null`
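The tests in the file below pin down the stateful-iterator contract these changes build on: `ArrowFileIteratorState` is a serializable snapshot that can `build()` a fresh iterator, and `get_state()` returns an updated snapshot after rows are consumed. Here is a minimal sketch of the checkpoint/resume loop this enables; the class, its fields, and the methods all appear in the test file below, while the dataset path and the loop body are illustrative assumptions:

```python
from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState

# Sketch only: every field below mirrors the states constructed in the tests.
state = ArrowFileIteratorState(
    file_path=None,
    num_workers=1,
    worker_id=0,
    preprocess_dir=None,
    entropy_model_name=None,
    dataset_files=["data/shard_00.arrow"],  # hypothetical path
    row_num=0,
    arrow_batch_size=100,
    s3_profile=None,
    file_format="arrow",
)
iterator = state.build()
for example in iterator.create_iter():
    print(example.sample_id)
    # get_state() snapshots the current row_num; persist it and call build()
    # on the loaded snapshot later to resume from the same row.
    state = iterator.get_state()
```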
115 lines · 3.4 KiB · Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
import numpy as np

import pyarrow as pa

# pyarrow needs the initialization from this import
import pyarrow.dataset  # pyright: ignore

from bytelatent.constants import BLT_DATA
from bytelatent.data.iterators.arrow_iterator import (
    ArrowFileIterator,
    ArrowFileIteratorState,
)

ENTROPY_MODEL = "transformer_100m"
ARROW_TEST_DATA_1 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_00.arrow")
ARROW_TEST_DATA_2 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_01.arrow")

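# The tests below cover three ArrowFileIterator behaviors: building an
# iterator from a serialized ArrowFileIteratorState, resuming mid-file from a
# saved row_num, and sharding rows across (num_workers, worker_id).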
def test_basic_arrow_file():
    dataset = pa.dataset.dataset(ARROW_TEST_DATA_1, format="arrow")
    n_head = 1000
    head_df = dataset.head(n_head).to_pandas()

    initial_state = ArrowFileIteratorState(
        file_path=None,
        num_workers=1,
        worker_id=0,
        preprocess_dir=None,
        entropy_model_name=ENTROPY_MODEL,
        dataset_files=[ARROW_TEST_DATA_1],
        row_num=0,
        arrow_batch_size=100,
        s3_profile=None,
        file_format="arrow",
    )
    arrow_file = initial_state.build()
    start_state = arrow_file.get_state()
    assert start_state.row_num == initial_state.row_num

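    # Consuming one example should advance row_num to 1, while rebuilding from
    # the untouched initial_state should replay the same first example.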
    sample_id = None
    for example in arrow_file.create_iter():
        sample_id = example.sample_id
        assert head_df.iloc[0]["sample_id"] == sample_id
        break

    assert arrow_file.get_state().row_num == 1
    arrow_file = initial_state.build()
    for example in arrow_file.create_iter():
        assert example.sample_id == sample_id
        assert head_df.iloc[0]["sample_id"] == sample_id
        break

    # Test resume far enough in to be past the batch size of 100
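    # (row_num=251 with arrow_batch_size=100 means resume must skip two full
    # record batches plus 51 rows before yielding row 251.)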
    resumed_state = ArrowFileIteratorState(
        file_path=None,
        num_workers=1,
        worker_id=0,
        preprocess_dir=None,
        entropy_model_name=ENTROPY_MODEL,
        dataset_files=[ARROW_TEST_DATA_1],
        row_num=251,
        arrow_batch_size=100,
        s3_profile=None,
        file_format="arrow",
    )
    arrow_file = resumed_state.build()
    for example in arrow_file.create_iter():
        assert example.sample_id == head_df.iloc[251]["sample_id"]
        assert arrow_file.get_state().row_num == 252
        break

    world_rank = 1
    world_size = 4
    # Test World Size and Rank
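    # With num_workers=world_size and worker_id=world_rank, only rows whose
    # index satisfies i % world_size == world_rank (rows 1, 5, 9, ...) should
    # be yielded, matching expected_ids built below.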
    rank_state = ArrowFileIteratorState(
        file_path=None,
        num_workers=world_size,
        worker_id=world_rank,
        preprocess_dir=None,
        entropy_model_name=ENTROPY_MODEL,
        dataset_files=[ARROW_TEST_DATA_1],
        row_num=0,
        arrow_batch_size=100,
        s3_profile=None,
        file_format="arrow",
    )
    arrow_file = rank_state.build()
    expected_ids = []
    for i in range(n_head):
        if i % world_size == world_rank:
            expected_ids.append(head_df.iloc[i]["sample_id"])
    print(len(expected_ids))
    i = 0
    for example in arrow_file.create_iter():
        assert example.sample_id == expected_ids[i]
        i += 1
        if i >= len(expected_ids):
            break

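
# file_format="json" makes ArrowFileIterator read the JSONL file directly
# instead of inferring a preprocessed arrow path (see the commit notes above).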
def test_read_jsonl_from_arrow():
    arrow_iterator = ArrowFileIterator(
        file_path="fixtures/test_docs.jsonl",
        num_workers=1,
        worker_id=0,
        preprocess_dir=None,
        entropy_model_name=None,
        file_format="json",
        arrow_batch_size=100,
    )
    iterator = arrow_iterator.create_iter()
    for i, example in enumerate(iterator):
        assert example.sample_id == str(i)
        assert example.text == f"test_{i}"
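
# A minimal sketch (an assumption, not a file from this repo) of how a
# matching fixtures/test_docs.jsonl could be generated, assuming the JSONL
# keys mirror the example attributes asserted above:
#
#   import json
#   with open("fixtures/test_docs.jsonl", "w") as f:
#       for i in range(100):  # hypothetical row count
#           f.write(json.dumps({"sample_id": str(i), "text": f"test_{i}"}) + "\n")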