# Copyright (c) Meta Platforms, Inc. and affiliates.
import numpy as np
import pyarrow as pa

# pyarrow needs the initialization from this import
import pyarrow.dataset  # pyright: ignore

from bytelatent.constants import BLT_DATA
from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState

ENTROPY_MODEL = "transformer_100m"
ARROW_TEST_DATA_1 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_00.arrow")
ARROW_TEST_DATA_2 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_01.arrow")

def test_basic_arrow_file():
    dataset = pa.dataset.dataset(ARROW_TEST_DATA_1, format="arrow")
    n_head = 1000
    head_df = dataset.head(n_head).to_pandas()

    initial_state = ArrowFileIteratorState(
        file_path=None,
        num_workers=1,
        worker_id=0,
        preprocess_dir=None,
        entropy_model_name=ENTROPY_MODEL,
        dataset_files=[ARROW_TEST_DATA_1],
        row_num=0,
        arrow_batch_size=100,
        s3_profile=None,
    )
    arrow_file = initial_state.build()
    start_state = arrow_file.get_state()
    assert start_state.row_num == initial_state.row_num

    # Read one example: it should match row 0 of the pandas head, and the
    # iterator state should advance to row 1.
    sample_id = None
    for example in arrow_file.create_iter():
        sample_id = example.sample_id
        assert head_df.iloc[0]["sample_id"] == sample_id
        break

    assert arrow_file.get_state().row_num == 1
    # Rebuilding from the initial state should restart iteration at row 0.
    arrow_file = initial_state.build()
    for example in arrow_file.create_iter():
        assert example.sample_id == sample_id
        assert head_df.iloc[0]["sample_id"] == sample_id
        break
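
    # A hedged sketch, not in the original test: the resume check below builds
    # a state by hand, but the same round trip should also work with a state
    # captured mid-stream. This assumes get_state() returns a fully resumable
    # ArrowFileIteratorState whose build() continues from the recorded row.
    mid_state = arrow_file.get_state()  # row_num == 1 after the loop above
    resumed_from_snapshot = mid_state.build()
    for example in resumed_from_snapshot.create_iter():
        # The snapshot was taken after one row was consumed, so the rebuilt
        # iterator should yield row 1 of the head dataframe next.
        assert example.sample_id == head_df.iloc[1]["sample_id"]
        break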

    # Test resuming far enough in to be past the arrow_batch_size of 100.
    resumed_state = ArrowFileIteratorState(
        file_path=None,
        num_workers=1,
        worker_id=0,
        preprocess_dir=None,
        entropy_model_name=ENTROPY_MODEL,
        dataset_files=[ARROW_TEST_DATA_1],
        row_num=251,
        arrow_batch_size=100,
        s3_profile=None,
    )
    arrow_file = resumed_state.build()
    for example in arrow_file.create_iter():
        assert example.sample_id == head_df.iloc[251]["sample_id"]
        assert arrow_file.get_state().row_num == 252
        break
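
    # A hedged sketch, not in the original test: after resuming at row 251,
    # continued iteration should stay in lockstep with the pandas head for the
    # next few rows. Assumes create_iter() yields rows sequentially, as the
    # single-row assertions above suggest.
    arrow_file = resumed_state.build()
    for offset, example in enumerate(arrow_file.create_iter()):
        assert example.sample_id == head_df.iloc[251 + offset]["sample_id"]
        if offset >= 3:
            break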

    # Test world size and rank: with num_workers=world_size and
    # worker_id=world_rank, the iterator should yield exactly the rows whose
    # index satisfies i % world_size == world_rank.
    world_rank = 1
    world_size = 4
    rank_state = ArrowFileIteratorState(
        file_path=None,
        num_workers=world_size,
        worker_id=world_rank,
        preprocess_dir=None,
        entropy_model_name=ENTROPY_MODEL,
        dataset_files=[ARROW_TEST_DATA_1],
        row_num=0,
        arrow_batch_size=100,
        s3_profile=None,
    )
    arrow_file = rank_state.build()
    expected_ids = []
    for i in range(n_head):
        if i % world_size == world_rank:
            expected_ids.append(head_df.iloc[i]["sample_id"])
    print(len(expected_ids))
    i = 0
    for example in arrow_file.create_iter():
        assert example.sample_id == expected_ids[i]
        i += 1
        if i >= len(expected_ids):
            break
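
    # A hedged sketch, not in the original test: repeating the rank check for
    # every rank should partition the first n_head rows with no overlap or
    # gaps. Assumes each rank's iterator yields its shard in order, as was
    # checked above for rank 1 only.
    seen_ids = []
    for rank in range(world_size):
        shard_state = ArrowFileIteratorState(
            file_path=None,
            num_workers=world_size,
            worker_id=rank,
            preprocess_dir=None,
            entropy_model_name=ENTROPY_MODEL,
            dataset_files=[ARROW_TEST_DATA_1],
            row_num=0,
            arrow_batch_size=100,
            s3_profile=None,
        )
        shard_expected = [
            head_df.iloc[i]["sample_id"]
            for i in range(n_head)
            if i % world_size == rank
        ]
        for j, example in enumerate(shard_state.build().create_iter()):
            assert example.sample_id == shard_expected[j]
            seen_ids.append(example.sample_id)
            if j + 1 >= len(shard_expected):
                break
    # Together the per-rank shards should cover every head row exactly once.
    assert sorted(seen_ids) == sorted(head_df["sample_id"])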