# Copyright (c) Meta Platforms, Inc. and affiliates. import numpy as np import pyarrow as pa # pyarrow needs the initialization from this import import pyarrow.dataset # pyright: ignore from bytelatent.constants import BLT_DATA from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState ENTROPY_MODEL = "transformer_100m" ARROW_TEST_DATA_1 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_00.arrow") ARROW_TEST_DATA_2 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_01.arrow") def test_basic_arrow_file(): dataset = pa.dataset.dataset(ARROW_TEST_DATA_1, format="arrow") n_head = 1000 head_df = dataset.head(n_head).to_pandas() initial_state = ArrowFileIteratorState( file_path=None, num_workers=1, worker_id=0, preprocess_dir=None, entropy_model_name=ENTROPY_MODEL, dataset_files=[ARROW_TEST_DATA_1], row_num=0, arrow_batch_size=100, s3_profile=None, ) arrow_file = initial_state.build() start_state = arrow_file.get_state() assert start_state.row_num == initial_state.row_num sample_id = None for example in arrow_file.create_iter(): sample_id = example.sample_id assert head_df.iloc[0]["sample_id"] == sample_id break assert arrow_file.get_state().row_num == 1 arrow_file = initial_state.build() for example in arrow_file.create_iter(): assert example.sample_id == sample_id assert head_df.iloc[0]["sample_id"] == sample_id break # Test resume far enough in to be past the batch size of 100 resumed_state = ArrowFileIteratorState( file_path=None, num_workers=1, worker_id=0, preprocess_dir=None, entropy_model_name=ENTROPY_MODEL, dataset_files=[ARROW_TEST_DATA_1], row_num=251, arrow_batch_size=100, s3_profile=None, ) arrow_file = resumed_state.build() for example in arrow_file.create_iter(): assert example.sample_id == head_df.iloc[251]["sample_id"] assert arrow_file.get_state().row_num == 252 break world_rank = 1 world_size = 4 # Test World Size and Rank rank_state = ArrowFileIteratorState( file_path=None, num_workers=world_size, worker_id=world_rank, preprocess_dir=None, entropy_model_name=ENTROPY_MODEL, dataset_files=[ARROW_TEST_DATA_1], row_num=0, arrow_batch_size=100, s3_profile=None, ) arrow_file = rank_state.build() expected_ids = [] for i in range(n_head): if i % world_size == world_rank: expected_ids.append(head_df.iloc[i]["sample_id"]) print(len(expected_ids)) i = 0 for example in arrow_file.create_iter(): assert example.sample_id == expected_ids[i] i += 1 if i >= len(expected_ids): break