# blt/bytelatent/data/iterators/arrow_iterator.py
# (217 lines, 8 KiB; snapshot from 2024-12-12 23:32:30 +00:00)
# Copyright (c) Meta Platforms, Inc. and affiliates.
import re
from logging import getLogger
from pathlib import Path
from typing import Any, Generator
import pyarrow as pa
# pyarrow needs the initialization from this import
import pyarrow.dataset # pyright: ignore
from pydantic import BaseModel, ConfigDict
from bytelatent import ByteLatentError
from bytelatent.data.data_types import BltExample
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
# Module logger; records emitted here carry this module's dotted name.
logger = getLogger(name=__name__)
class ArrowFileIteratorState(BaseModel, IteratorState):
    """Serializable snapshot of an ``ArrowFileIterator``.

    ``build`` turns the snapshot back into a live iterator that resumes at
    ``row_num``.
    """

    # Reject unknown keys so stale or misspelled state fails validation.
    model_config = ConfigDict(extra="forbid")
    file_path: str | None
    row_num: int
    num_workers: int
    worker_id: int
    preprocess_dir: str | None
    dataset_files: list[str] | None
    entropy_model_name: str | None
    arrow_batch_size: int = 100

    def build(self) -> "ArrowFileIterator":
        """Recreate the iterator described by this state."""
        iterator = ArrowFileIterator(
            dataset_files=self.dataset_files,
            file_path=self.file_path,
            preprocess_dir=self.preprocess_dir,
            entropy_model_name=self.entropy_model_name,
            worker_id=self.worker_id,
            num_workers=self.num_workers,
            arrow_batch_size=self.arrow_batch_size,
        )
        # A freshly built iterator starts at row 0; only seek when resuming.
        if self.row_num:
            iterator._set_row_num(self.row_num)
        return iterator
def shard_sort_key(file: str | Path):
match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file))
shard_number = int(match.group(1))
return shard_number
class ArrowFileIterator(StatefulIterator):
    """Stateful iterator over ``BltExample`` rows stored in Arrow files.

    Rows are read in record batches of ``arrow_batch_size`` rows and split
    round-robin across workers: this worker yields the rows whose 0-based
    global index ``i`` satisfies ``i % num_workers == worker_id``. ``row_num``
    counts every row read (including rows owned by other workers); that is
    the position ``get_state`` saves and ``_set_row_num`` seeks back to.
    """

    def __init__(
        self,
        *,
        file_path: str | None,
        worker_id: int,
        num_workers: int,
        preprocess_dir: str | None,
        entropy_model_name: str | None,
        arrow_batch_size: int,
        dataset_files: list[str] | None = None,
    ):
        """Record the configuration and resolve which Arrow files to read.

        Args:
            file_path: Source ``<dataset>.chunk.<N>.jsonl`` file, used to
                locate its preprocessed shards when ``dataset_files`` is None.
            worker_id: Index of this worker, in ``[0, num_workers)``.
            num_workers: Number of workers sharing the rows round-robin.
            preprocess_dir: Root of the preprocessed shard tree; required
                when ``dataset_files`` is None.
            entropy_model_name: Sub-directory holding the shards; required
                when ``dataset_files`` is None.
            arrow_batch_size: Rows per Arrow record batch.
            dataset_files: Explicit list of Arrow files; skips shard
                discovery, and ``preprocess_dir`` is then recorded as None.

        Raises:
            ByteLatentError: if both ``file_path`` and ``dataset_files`` are
                None, if discovery inputs are missing or malformed, or if no
                shard files are found.
            ValueError: if a discovered shard lacks its ``.complete`` marker.
        """
        assert 0 <= worker_id < num_workers, (worker_id, num_workers)
        if file_path is None and dataset_files is None:
            raise ByteLatentError("file_path and dataset_files cannot both be None")
        self.row_num = 0
        self.iter_id = 0
        self.batch_iterator = None
        self.batch_to_consume = None
        self.dataset = None
        self.file_path = file_path
        self.worker_id = worker_id
        self.num_workers = num_workers
        self.preprocess_dir = preprocess_dir
        self.entropy_model_name = entropy_model_name
        self.arrow_batch_size = arrow_batch_size
        if dataset_files is None:
            self.dataset_files = self._discover_shard_files(
                file_path, preprocess_dir, entropy_model_name
            )
        else:
            # In this class preprocess_dir is only used for shard discovery,
            # so it is not recorded when files are given explicitly.
            self.preprocess_dir = None
            self.dataset_files = dataset_files

    @staticmethod
    def _discover_shard_files(
        file_path: str,
        preprocess_dir: str | None,
        entropy_model_name: str | None,
    ) -> list[str]:
        """Return the Arrow shards for ``file_path``, sorted by shard number.

        Shards are looked up in
        ``<preprocess_dir>/<dataset>/<entropy_model_name>/<jsonl name>.shard_*.arrow``.
        """
        if preprocess_dir is None or entropy_model_name is None:
            raise ByteLatentError(
                "preprocess_dir and entropy_model_name are required when "
                f"dataset_files is None, got preprocess_dir={preprocess_dir} "
                f"and entropy_model_name={entropy_model_name}"
            )
        jsonl_file = Path(file_path)
        parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name)
        if parts is None:
            raise ByteLatentError(
                f"Expected a file named <dataset>.chunk.<N>.jsonl, got: {file_path}"
            )
        dataset = parts.group(1)
        data_dir = Path(preprocess_dir) / dataset / entropy_model_name
        shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow"))
        for s in shard_files:
            # Every shard must have a sibling "<name>.complete" marker file;
            # presumably written once preprocessing of that shard finished.
            if not (data_dir / f"{s.name}.complete").exists():
                raise ValueError(f"Missing .complete for input file: {s}")
        shard_files = sorted(shard_files, key=shard_sort_key)
        if len(shard_files) == 0:
            raise ByteLatentError(
                f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
            )
        return [str(f) for f in shard_files]

    def get_state(self) -> ArrowFileIteratorState:
        """Snapshot the configuration and current ``row_num``."""
        return ArrowFileIteratorState(
            file_path=self.file_path,
            row_num=self.row_num,
            worker_id=self.worker_id,
            num_workers=self.num_workers,
            preprocess_dir=self.preprocess_dir,
            entropy_model_name=self.entropy_model_name,
            arrow_batch_size=self.arrow_batch_size,
            dataset_files=self.dataset_files,
        )

    def create_iter(
        self,
    ) -> Generator[BltExample, Any, None]:
        """Yield this worker's rows, continuing from the current position.

        The dataset and its batch iterator are opened lazily on first use
        (or were already opened and positioned by ``_set_row_num``).
        """
        if self.dataset is None:
            self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow")
            self.batch_iterator = self.dataset.to_batches(
                batch_size=self.arrow_batch_size
            )
        self.iter_id += 1
        if self.batch_to_consume is not None:
            # Remainder of the batch that _set_row_num seeked into. Clear it
            # before yielding so it is consumed at most once.
            batch_columns: dict[str, list] = self.batch_to_consume
            self.batch_to_consume = None
            yield from self._iter_batch_rows(batch_columns)
        for batch in self.batch_iterator:
            yield from self._iter_batch_rows(batch.to_pydict())

    def _iter_batch_rows(
        self, batch_columns: dict[str, list]
    ) -> Generator[BltExample, Any, None]:
        """Advance ``row_num`` over one batch, yielding only this worker's rows."""
        for sample_id, text, entropies in zip(
            batch_columns["sample_id"],
            batch_columns["text"],
            batch_columns["entropies"],
        ):
            owner = self.row_num % self.num_workers
            # Count the row before yielding so a state saved while the
            # consumer holds this example resumes at the following row.
            self.row_num += 1
            if owner == self.worker_id:
                yield BltExample(
                    sample_id=sample_id,
                    entropies=entropies,
                    text=text,
                    tokens=None,
                    mask=None,
                    patch_lengths=None,
                )

    def _set_row_num(self, target_row_num: int):
        """Position the iterator so the next row read is ``target_row_num``.

        Whole record batches before the target are skipped. The tail of the
        batch containing the target is stored in ``batch_to_consume`` so
        ``create_iter`` emits it first. ``None`` or 0 resets to the start.
        """
        logger.info(
            f"Setting arrow position to {target_row_num} for {self.dataset_files}"
        )
        # Drop any leftover partial batch from an earlier seek. Otherwise a
        # seek landing exactly on a batch boundary would replay stale rows.
        self.batch_to_consume = None
        if target_row_num is None or target_row_num == 0:
            self.row_num = 0
            self.dataset = None
            self.batch_iterator = None
        else:
            self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow")
            self.batch_iterator = self.dataset.to_batches(
                batch_size=self.arrow_batch_size
            )
            curr_remaining = target_row_num
            for batch in self.batch_iterator:
                if len(batch) > curr_remaining:
                    # Slice in Arrow before converting, so skipped rows are
                    # never materialized as Python objects.
                    self.batch_to_consume = batch.slice(curr_remaining).to_pydict()
                    break
                if len(batch) == curr_remaining:
                    # The target is the first row of the next batch, which
                    # create_iter will read from batch_iterator.
                    break
                curr_remaining -= len(batch)
            self.row_num = target_row_num
        logger.info(
            f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
        )
TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"


def find_and_sanitize_chunks(
    dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN
) -> list[str]:
    """Return the chunk files in ``dataset_path`` to spread over ``world_size`` ranks.

    Matching files are returned in sorted order, so every rank and every
    rerun sees the same list. With more chunks than ranks, only the first
    ``world_size`` chunks are kept (and the discard is logged). Otherwise
    ``world_size`` must be a multiple of the number of chunks.

    Args:
        dataset_path: Directory to search (non-recursively for the default
            pattern).
        world_size: Number of ranks that will consume the chunks.
        file_pattern: ``Path.glob`` pattern selecting chunk files.

    Returns:
        Sorted list of chunk file paths as strings.

    Raises:
        ValueError: if no file matches, or if ``world_size`` is not a
            multiple of the number of chunks.
    """
    # Path.glob yields entries in filesystem order, which is not guaranteed
    # to be stable; sort so the truncation below picks the same chunks.
    dataset_chunks = sorted(str(p) for p in Path(dataset_path).glob(file_pattern))
    n_chunks = len(dataset_chunks)
    # Check emptiness first: the modulo below would otherwise raise
    # ZeroDivisionError instead of a meaningful error.
    if n_chunks == 0:
        raise ValueError(f"No valid chunks in {dataset_path}")
    if n_chunks > world_size:
        getLogger(__name__).warning(
            "Discarding %d of %d chunks in %s to match world_size=%d",
            n_chunks - world_size,
            n_chunks,
            dataset_path,
            world_size,
        )
        dataset_chunks = dataset_chunks[:world_size]
    elif world_size % n_chunks != 0:
        raise ValueError("World size should be a multiple of number of chunks")
    return dataset_chunks