From 936d9437be43c24bfe4718388f6c28d8d50427cd Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 6 Feb 2025 09:57:22 -0800 Subject: [PATCH] Allow ArrowIterator to read from json (#45) Summary: Currently, arrow iterator can only read arrow files. However, the pyarrow library can read other formats, including jsonlines. This allows the same ArrowIterator to read from jsonlines, so we can read from the original source data, and simply omit the entropy column when doing so Test Plan: Run train script until dataloader starts --- bytelatent/args.py | 38 +++++++- bytelatent/data/iterators/arrow_iterator.py | 97 +++++++++++-------- bytelatent/preprocess/preprocess_entropies.py | 26 +++-- bytelatent/stool.py | 2 +- 4 files changed, 105 insertions(+), 58 deletions(-) diff --git a/bytelatent/args.py b/bytelatent/args.py index fc72b32..263e8e3 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -1,8 +1,10 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import json import logging import os from typing import Any +import fsspec import numpy as np import yaml from omegaconf import OmegaConf @@ -10,11 +12,9 @@ from pydantic import BaseModel, ConfigDict from bytelatent.checkpoint import CheckpointArgs from bytelatent.data.data_types import Batch +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import StatefulIterator -from bytelatent.data.iterators.arrow_iterator import ( - ArrowFileIterator, - find_and_sanitize_chunks, -) +from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator from bytelatent.data.iterators.looping_iterator import LoopingIterator from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator @@ -53,6 +53,33 @@ def parse_args(args_cls): return pydantic_args +TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" + + +def find_and_sanitize_chunks( + dataset_path: str, + world_size: int, + file_pattern: str, + s3_profile: str | None = None, +): + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) + n_chunks = len(dataset_chunks) + + if n_chunks > world_size: + n_discard = n_chunks - world_size + dataset_chunks = dataset_chunks[:world_size] + else: + assert ( + world_size % n_chunks == 0 + ), "World size should be a multiple of number of chunks" + + assert n_chunks > 0, f"No valid chunks in {dataset_path}" + + return dataset_chunks + + def distribute_data_to_rank( *, dataset_path: str, @@ -62,9 +89,10 @@ def distribute_data_to_rank( rank: int, world_size: int, s3_profile: str | None = None, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, ) -> ArrowFileIterator: dataset_chunks = find_and_sanitize_chunks( - dataset_path, world_size, s3_profile=s3_profile + dataset_path, world_size, file_pattern, s3_profile=s3_profile ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index 4e7b99e..1c68d3a 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -16,6 +16,7 @@ from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text logger = getLogger(__name__) @@ -32,6 +33,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState): arrow_batch_size: int = 100 s3_profile: str | None filesystem_type: str | None = None + file_format: str def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -44,6 +46,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState): dataset_files=self.dataset_files, s3_profile=self.s3_profile, filesystem_type=self.filesystem_type, + file_format=self.file_format, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) @@ -70,6 +73,7 @@ class ArrowFileIterator(StatefulIterator): dataset_files: list[str] | None = None, s3_profile: str | None = None, filesystem_type: str | None = None, + file_format: str = "arrow", ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -87,12 +91,16 @@ class ArrowFileIterator(StatefulIterator): self.arrow_batch_size = arrow_batch_size self.s3_profile = s3_profile self.filesystem_type = filesystem_type + self.file_format = file_format self.fs = None if self.filesystem_type is not None: if self.filesystem_type == "file": self.fs = fsspec.filesystem("file") elif self.filesystem_type == "s3": self.fs = fsspec.filesystem("s3", profile=s3_profile) + else: + raise ValueError("Unknown filesystem") + logger.info("Arrow iterator using fs=%s", self.fs) if dataset_files is None: # Prepare arrow shards @@ -153,6 +161,7 @@ class ArrowFileIterator(StatefulIterator): dataset_files=self.dataset_files, s3_profile=self.s3_profile, filesystem_type=self.filesystem_type, + file_format=self.file_format, ) def create_iter( @@ -164,7 +173,7 @@ class ArrowFileIterator(StatefulIterator): else: filesystem = None self.dataset = pa.dataset.dataset( - self.dataset_files, format="arrow", filesystem=filesystem + self.dataset_files, format=self.file_format, filesystem=filesystem ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size @@ -173,13 +182,22 @@ class ArrowFileIterator(StatefulIterator): if self.batch_to_consume is not None: batch_columns: dict[str, list] = self.batch_to_consume self.batch_to_consume = None - sample_ids = batch_columns["sample_id"] - texts = batch_columns["text"] - entropies = batch_columns["entropies"] + if self.file_format == "arrow": + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + elif self.file_format == "json": + # This data hasn't been preprocessed to a uniform format, + # so we have to do it now and omit entropies + sample_ids = batch_columns[get_id_key(batch_columns)] + texts = get_text(batch_columns) + entropies = None + else: + raise ValueError(f"Unknown file format: {self.file_format}") for i in range(len(sample_ids)): out = BltExample( sample_id=sample_ids[i], - entropies=entropies[i], + entropies=entropies[i] if entropies is not None else None, text=texts[i], tokens=None, mask=None, @@ -191,13 +209,22 @@ class ArrowFileIterator(StatefulIterator): for batch in self.batch_iterator: batch_columns = batch.to_pydict() - sample_ids = batch_columns["sample_id"] - texts = batch_columns["text"] - entropies = batch_columns["entropies"] + if self.file_format == "arrow": + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + elif self.file_format == "json": + # This data hasn't been preprocessed to a uniform format, + # so we have to do it now and omit entropies + sample_ids = batch_columns[get_id_key(batch_columns)] + texts = get_text(batch_columns) + entropies = None + else: + raise ValueError(f"Unknown file format: {self.file_format}") for i in range(len(sample_ids)): out = BltExample( sample_id=sample_ids[i], - entropies=entropies[i], + entropies=entropies[i] if entropies is not None else None, text=texts[i], tokens=None, mask=None, @@ -231,13 +258,24 @@ class ArrowFileIterator(StatefulIterator): for batch in self.batch_iterator: if len(batch) > curr_remaining: batch_columns: dict[str, list] = batch.to_pydict() - batch_columns["sample_id"] = batch_columns["sample_id"][ - curr_remaining: - ] - batch_columns["entropies"] = batch_columns["entropies"][ - curr_remaining: - ] - batch_columns["text"] = batch_columns["text"][curr_remaining:] + if self.file_format == "arrow": + leftover_sample_ids = batch_columns["sample_id"][ + curr_remaining: + ] + leftover_entropies = batch_columns["entropies"][curr_remaining:] + leftover_texts = batch_columns["text"][curr_remaining:] + elif self.file_format == "json": + leftover_sample_ids = batch_columns[get_id_key(batch_columns)][ + curr_remaining: + ] + leftover_entropies = None + leftover_texts = get_text(batch_columns)[curr_remaining:] + else: + raise ValueError(f"Unknown file format: {self.file_format}") + + batch_columns["sample_id"] = leftover_sample_ids + batch_columns["entropies"] = leftover_entropies + batch_columns["text"] = leftover_texts self.batch_to_consume = batch_columns break elif len(batch) == curr_remaining: @@ -250,30 +288,3 @@ class ArrowFileIterator(StatefulIterator): logger.info( f"Finished setting arrow position to {target_row_num} for {self.dataset_files}" ) - - -TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" - - -def find_and_sanitize_chunks( - dataset_path: str, - world_size: int, - file_pattern: str = TRAIN_DATA_FILE_PATTERN, - s3_profile: str | None = None, -): - fs = get_fs(dataset_path, s3_profile=s3_profile) - path_with_glob = os.path.join(dataset_path, file_pattern) - dataset_chunks = fs.glob(path_with_glob) - n_chunks = len(dataset_chunks) - - if n_chunks > world_size: - n_discard = n_chunks - world_size - dataset_chunks = dataset_chunks[:world_size] - else: - assert ( - world_size % n_chunks == 0 - ), "World size should be a multiple of number of chunks" - - assert n_chunks > 0, f"No valid chunks in {dataset_path}" - - return dataset_chunks diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py index 519da94..31a4802 100644 --- a/bytelatent/preprocess/preprocess_entropies.py +++ b/bytelatent/preprocess/preprocess_entropies.py @@ -15,29 +15,37 @@ from bytelatent.entropy_model import load_entropy_model from bytelatent.tokenizers.build_tokenizer import TokenizerArgs -def get_id_from_doc(doc: dict) -> int: +def get_id_key(doc: dict) -> int: """ We need a reliable way to ensure that samples from jsonl and arrow are the same, but there is no unique id field, so derive the best possible """ if "sample_id" in doc: - sample_id = doc["sample_id"] + return "sample_id" elif "title" in doc: - sample_id = doc["title"] + return "title" elif "qid" in doc: - sample_id = doc["qid"] + return "qid" elif "paper_id" in doc: - sample_id = doc["paper_id"] + return "paper_id" elif "path" in doc: - sample_id = doc["path"] + return "path" elif "url" in doc: - sample_id = doc["url"] + return "url" elif "id" in doc: - sample_id = doc["id"] + return "id" else: raise ValueError(f"Could not find a id key from: {doc.keys()}") - return str(sample_id) + + +def get_id_from_doc(doc: dict) -> int: + """ + We need a reliable way to ensure that samples from jsonl + and arrow are the same, but there is no unique id field, + so derive the best possible + """ + return str(doc[get_id_key(doc)]) def get_text(doc: dict): diff --git a/bytelatent/stool.py b/bytelatent/stool.py index b156177..b47ddc7 100644 --- a/bytelatent/stool.py +++ b/bytelatent/stool.py @@ -4,10 +4,10 @@ import json import os import shutil import subprocess -from pydantic import BaseModel from typing import Any, Dict from omegaconf import OmegaConf +from pydantic import BaseModel class StoolArgs(BaseModel):