blt/bytelatent/data/iterators/test_arrow_iterator.py
Pedro Rodriguez 6ffeb66b53
Some checks are pending
Lint with Black / lint (push) Waiting to run
Lint with isort / lint (push) Waiting to run
Changes for training entropy model and correcting attention in local models (#25)
Summary:

- Refactor local model configs to be separate and clearer
- Add attention arguments and correct which attention is used in local models
- Prepare for adding an entropy-model training script
- Fix failing unit tests

Test Plan:
(empty in the captured commit message)

Commit date: 2025-01-17 14:23:01 -08:00

93 lines
2.8 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
import itertools

import numpy as np
import pyarrow as pa

# pyarrow needs the initialization from this import
import pyarrow.dataset  # pyright: ignore

from bytelatent.constants import BLT_DATA
from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState
# Entropy model name recorded in every ArrowFileIteratorState the test builds.
ENTROPY_MODEL = "transformer_100m"

# Two consecutive arrow shards of the same stackexchange chunk, resolved under
# BLT_DATA and converted to plain ``str`` paths. Only the first shard is read
# by test_basic_arrow_file in this module.
ARROW_TEST_DATA_1, ARROW_TEST_DATA_2 = (
    str(BLT_DATA / shard_name)
    for shard_name in (
        "stackexchange.chunk.00.jsonl.shard_00.arrow",
        "stackexchange.chunk.00.jsonl.shard_01.arrow",
    )
)
def _make_state(*, row_num=0, num_workers=1, worker_id=0, arrow_batch_size=100):
    """Return an ArrowFileIteratorState that reads only ARROW_TEST_DATA_1.

    Only the fields the test varies are parameters; the rest are fixed
    (no file_path, no preprocess_dir, no s3_profile, ENTROPY_MODEL).
    """
    return ArrowFileIteratorState(
        file_path=None,
        num_workers=num_workers,
        worker_id=worker_id,
        preprocess_dir=None,
        entropy_model_name=ENTROPY_MODEL,
        dataset_files=[ARROW_TEST_DATA_1],
        row_num=row_num,
        arrow_batch_size=arrow_batch_size,
        s3_profile=None,
    )


def test_basic_arrow_file():
    """Check ArrowFileIterator ordering, state tracking, resuming and sharding.

    The first ``n_head`` rows read directly through ``pyarrow.dataset`` are
    the ground truth that the iterator's ``sample_id`` values are compared to.
    """
    dataset = pa.dataset.dataset(ARROW_TEST_DATA_1, format="arrow")
    n_head = 1000
    head_ids = dataset.head(n_head).to_pandas()["sample_id"]

    # Fresh iterator: its state matches the one it was built from, the first
    # example is row 0, and consuming one example advances row_num to 1.
    initial_state = _make_state()
    arrow_file = initial_state.build()
    assert arrow_file.get_state().row_num == initial_state.row_num
    examples = iter(arrow_file.create_iter())
    first_id = next(examples).sample_id
    assert first_id == head_ids.iloc[0]
    assert arrow_file.get_state().row_num == 1

    # Rebuilding from the same state must replay from the very first row.
    arrow_file = initial_state.build()
    assert next(iter(arrow_file.create_iter())).sample_id == first_id

    # Resume far enough in to be past the arrow batch size of 100, so the
    # resumed position lies beyond the first batches of the file.
    resume_row = 251
    arrow_file = _make_state(row_num=resume_row).build()
    examples = iter(arrow_file.create_iter())
    assert next(examples).sample_id == head_ids.iloc[resume_row]
    assert arrow_file.get_state().row_num == resume_row + 1

    # World size and rank: worker `world_rank` of `world_size` must yield
    # exactly the rows whose index is congruent to world_rank mod world_size,
    # in file order.
    world_rank = 1
    world_size = 4
    arrow_file = _make_state(num_workers=world_size, worker_id=world_rank).build()
    expected_ids = head_ids.iloc[world_rank::world_size].tolist()
    # islice consumes at most len(expected_ids) examples; comparing the whole
    # lists also fails if the iterator stops early, which a per-item
    # assert-then-break loop would silently accept.
    actual_ids = [
        example.sample_id
        for example in itertools.islice(arrow_file.create_iter(), len(expected_ids))
    ]
    assert actual_ids == expected_ids