Allow ArrowIterator to read from json

Summary:
Adds a file_format option ("arrow" or "json") to ArrowFileIterator so it can read jsonl shards directly: for json input the sample id and text are derived with get_id_key/get_text from the entropy preprocessing module, and entropies are left as None. find_and_sanitize_chunks moves out of arrow_iterator.py, and its file pattern becomes an explicit argument passed through distribute_data_to_rank.

Test Plan:
Pedro Rodriguez 2025-02-06 17:43:10 +00:00
parent afedb16598
commit 0e9421af07
4 changed files with 105 additions and 58 deletions
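For reference, a minimal usage sketch of the new json path (hypothetical file name; constructor arguments not shown in this diff are omitted and may be required in practice):

from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator

# Read a raw jsonl shard directly instead of a preprocessed arrow file.
iterator = ArrowFileIterator(
    file_path="data/demo.chunk.00.jsonl",  # hypothetical path
    worker_id=0,
    num_workers=1,
    arrow_batch_size=100,
    file_format="json",
)
for example in iterator.create_iter():
    # BltExample with sample_id and text populated; entropies is None for json input
    print(example.sample_id, example.text[:80])
    break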

View file

@@ -1,8 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
import logging
import os
from typing import Any
import fsspec
import numpy as np
import yaml
from omegaconf import OmegaConf
@@ -10,11 +12,9 @@ from pydantic import BaseModel, ConfigDict
from bytelatent.checkpoint import CheckpointArgs
from bytelatent.data.data_types import Batch
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
from bytelatent.data.iterators.arrow_iterator import (
    ArrowFileIterator,
    find_and_sanitize_chunks,
)
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
from bytelatent.data.iterators.looping_iterator import LoopingIterator
from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
@@ -53,6 +53,33 @@ def parse_args(args_cls):
    return pydantic_args
TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
def find_and_sanitize_chunks(
    dataset_path: str,
    world_size: int,
    file_pattern: str,
    s3_profile: str | None = None,
):
    fs = get_fs(dataset_path, s3_profile=s3_profile)
    path_with_glob = os.path.join(dataset_path, file_pattern)
    dataset_chunks = fs.glob(path_with_glob)
    n_chunks = len(dataset_chunks)
    if n_chunks > world_size:
        n_discard = n_chunks - world_size
        dataset_chunks = dataset_chunks[:world_size]
    else:
        assert (
            world_size % n_chunks == 0
        ), "World size should be a multiple of number of chunks"
    assert n_chunks > 0, f"No valid chunks in {dataset_path}"
    return dataset_chunks
def distribute_data_to_rank(
    *,
    dataset_path: str,
@@ -62,9 +89,10 @@ def distribute_data_to_rank(
    rank: int,
    world_size: int,
    s3_profile: str | None = None,
    file_pattern: str = TRAIN_DATA_FILE_PATTERN,
) -> ArrowFileIterator:
    dataset_chunks = find_and_sanitize_chunks(
        dataset_path, world_size, s3_profile=s3_profile
        dataset_path, world_size, file_pattern, s3_profile=s3_profile
    )
    n_workers_per_chunk = world_size // len(dataset_chunks)
    rank_to_arrow_iterator_params = []
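Since world_size is asserted to be a multiple of the chunk count, each chunk ends up shared by n_workers_per_chunk ranks. A toy sketch of one plausible rank-to-chunk assignment (the actual mapping code sits below this hunk and is not shown here):

def assign_chunk(rank: int, world_size: int, dataset_chunks: list[str]) -> tuple[str, int, int]:
    # Hypothetical helper mirroring the arithmetic above: with 2 chunks and
    # world_size=8, ranks 0-3 read chunk 0 and ranks 4-7 read chunk 1,
    # each as worker_id = rank % 4 out of num_workers = 4.
    n_workers_per_chunk = world_size // len(dataset_chunks)
    chunk = dataset_chunks[rank // n_workers_per_chunk]
    worker_id = rank % n_workers_per_chunk
    return chunk, worker_id, n_workers_per_chunk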

View file

@@ -16,6 +16,7 @@ from bytelatent import ByteLatentError
from bytelatent.data.data_types import BltExample
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text
logger = getLogger(__name__)
@@ -32,6 +33,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
    arrow_batch_size: int = 100
    s3_profile: str | None
    filesystem_type: str | None = None
    file_format: str
    def build(self) -> "ArrowFileIterator":
        arrow_file = ArrowFileIterator(
@@ -44,6 +46,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
            dataset_files=self.dataset_files,
            s3_profile=self.s3_profile,
            filesystem_type=self.filesystem_type,
            file_format=self.file_format,
        )
        if self.row_num != 0:
            arrow_file._set_row_num(self.row_num)
@@ -70,6 +73,7 @@ class ArrowFileIterator(StatefulIterator):
        dataset_files: list[str] | None = None,
        s3_profile: str | None = None,
        filesystem_type: str | None = None,
        file_format: str = "arrow",
    ):
        assert 0 <= worker_id < num_workers, (worker_id, num_workers)
        if file_path is None and dataset_files is None:
@@ -87,12 +91,16 @@ class ArrowFileIterator(StatefulIterator):
        self.arrow_batch_size = arrow_batch_size
        self.s3_profile = s3_profile
        self.filesystem_type = filesystem_type
        self.file_format = file_format
        self.fs = None
        if self.filesystem_type is not None:
            if self.filesystem_type == "file":
                self.fs = fsspec.filesystem("file")
            elif self.filesystem_type == "s3":
                self.fs = fsspec.filesystem("s3", profile=s3_profile)
            else:
                raise ValueError("Unknown filesystem")
        logger.info("Arrow iterator using fs=%s", self.fs)
        if dataset_files is None:
            # Prepare arrow shards
@@ -153,6 +161,7 @@ class ArrowFileIterator(StatefulIterator):
            dataset_files=self.dataset_files,
            s3_profile=self.s3_profile,
            filesystem_type=self.filesystem_type,
            file_format=self.file_format,
        )
    def create_iter(
@@ -164,7 +173,7 @@
        else:
            filesystem = None
        self.dataset = pa.dataset.dataset(
            self.dataset_files, format="arrow", filesystem=filesystem
            self.dataset_files, format=self.file_format, filesystem=filesystem
        )
        self.batch_iterator = self.dataset.to_batches(
            batch_size=self.arrow_batch_size
@@ -173,13 +182,22 @@
        if self.batch_to_consume is not None:
            batch_columns: dict[str, list] = self.batch_to_consume
            self.batch_to_consume = None
            sample_ids = batch_columns["sample_id"]
            texts = batch_columns["text"]
            entropies = batch_columns["entropies"]
            if self.file_format == "arrow":
                sample_ids = batch_columns["sample_id"]
                texts = batch_columns["text"]
                entropies = batch_columns["entropies"]
            elif self.file_format == "json":
                # This data hasn't been preprocessed to a uniform format,
                # so we have to do it now and omit entropies
                sample_ids = batch_columns[get_id_key(batch_columns)]
                texts = get_text(batch_columns)
                entropies = None
            else:
                raise ValueError(f"Unknown file format: {self.file_format}")
            for i in range(len(sample_ids)):
                out = BltExample(
                    sample_id=sample_ids[i],
                    entropies=entropies[i],
                    entropies=entropies[i] if entropies is not None else None,
                    text=texts[i],
                    tokens=None,
                    mask=None,
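For json input, batch_columns is a column-oriented dict whose keys come straight from the jsonl records, so the id and text columns have to be located at read time. A small illustration with made-up data (get_text's exact field preference is not shown in this diff):

# One pyarrow batch converted via to_pydict() for a jsonl file with "title" and "text" fields
batch_columns = {
    "title": ["Doc A", "Doc B"],
    "text": ["first document body", "second document body"],
}
# get_id_key(batch_columns) -> "title", so sample_ids == ["Doc A", "Doc B"];
# get_text(batch_columns) presumably returns the "text" column; entropies stays None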
@@ -191,13 +209,22 @@
        for batch in self.batch_iterator:
            batch_columns = batch.to_pydict()
            sample_ids = batch_columns["sample_id"]
            texts = batch_columns["text"]
            entropies = batch_columns["entropies"]
            if self.file_format == "arrow":
                sample_ids = batch_columns["sample_id"]
                texts = batch_columns["text"]
                entropies = batch_columns["entropies"]
            elif self.file_format == "json":
                # This data hasn't been preprocessed to a uniform format,
                # so we have to do it now and omit entropies
                sample_ids = batch_columns[get_id_key(batch_columns)]
                texts = get_text(batch_columns)
                entropies = None
            else:
                raise ValueError(f"Unknown file format: {self.file_format}")
            for i in range(len(sample_ids)):
                out = BltExample(
                    sample_id=sample_ids[i],
                    entropies=entropies[i],
                    entropies=entropies[i] if entropies is not None else None,
                    text=texts[i],
                    tokens=None,
                    mask=None,
@@ -231,13 +258,24 @@
            for batch in self.batch_iterator:
                if len(batch) > curr_remaining:
                    batch_columns: dict[str, list] = batch.to_pydict()
                    batch_columns["sample_id"] = batch_columns["sample_id"][
                        curr_remaining:
                    ]
                    batch_columns["entropies"] = batch_columns["entropies"][
                        curr_remaining:
                    ]
                    batch_columns["text"] = batch_columns["text"][curr_remaining:]
                    if self.file_format == "arrow":
                        leftover_sample_ids = batch_columns["sample_id"][
                            curr_remaining:
                        ]
                        leftover_entropies = batch_columns["entropies"][curr_remaining:]
                        leftover_texts = batch_columns["text"][curr_remaining:]
                    elif self.file_format == "json":
                        leftover_sample_ids = batch_columns[get_id_key(batch_columns)][
                            curr_remaining:
                        ]
                        leftover_entropies = None
                        leftover_texts = get_text(batch_columns)[curr_remaining:]
                    else:
                        raise ValueError(f"Unknown file format: {self.file_format}")
                    batch_columns["sample_id"] = leftover_sample_ids
                    batch_columns["entropies"] = leftover_entropies
                    batch_columns["text"] = leftover_texts
                    self.batch_to_consume = batch_columns
                    break
                elif len(batch) == curr_remaining:
@@ -250,30 +288,3 @@ class ArrowFileIterator(StatefulIterator):
        logger.info(
            f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
        )
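A worked example of the resume arithmetic above, using toy numbers rather than the actual loop: with target_row_num=250 and arrow_batch_size=100, two full batches are skipped and only rows 50 onward of the third batch are kept as batch_to_consume:

# hypothetical numbers illustrating the slicing in _set_row_num
target_row_num = 250
arrow_batch_size = 100
full_batches, curr_remaining = divmod(target_row_num, arrow_batch_size)  # 2 full batches, 50 rows left
kept_rows = list(range(arrow_batch_size))[curr_remaining:]  # rows 50..99 of the third batch survive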
TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
def find_and_sanitize_chunks(
    dataset_path: str,
    world_size: int,
    file_pattern: str = TRAIN_DATA_FILE_PATTERN,
    s3_profile: str | None = None,
):
    fs = get_fs(dataset_path, s3_profile=s3_profile)
    path_with_glob = os.path.join(dataset_path, file_pattern)
    dataset_chunks = fs.glob(path_with_glob)
    n_chunks = len(dataset_chunks)
    if n_chunks > world_size:
        n_discard = n_chunks - world_size
        dataset_chunks = dataset_chunks[:world_size]
    else:
        assert (
            world_size % n_chunks == 0
        ), "World size should be a multiple of number of chunks"
    assert n_chunks > 0, f"No valid chunks in {dataset_path}"
    return dataset_chunks

View file

@@ -15,29 +15,37 @@ from bytelatent.entropy_model import load_entropy_model
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
def get_id_from_doc(doc: dict) -> int:
def get_id_key(doc: dict) -> str:
"""
We need a reliable way to ensure that samples from jsonl
and arrow are the same, but there is no unique id field,
so derive the best possible
"""
if "sample_id" in doc:
sample_id = doc["sample_id"]
return "sample_id"
elif "title" in doc:
sample_id = doc["title"]
return "title"
elif "qid" in doc:
sample_id = doc["qid"]
return "qid"
elif "paper_id" in doc:
sample_id = doc["paper_id"]
return "paper_id"
elif "path" in doc:
sample_id = doc["path"]
return "path"
elif "url" in doc:
sample_id = doc["url"]
return "url"
elif "id" in doc:
sample_id = doc["id"]
return "id"
else:
raise ValueError(f"Could not find a id key from: {doc.keys()}")
return str(sample_id)
def get_id_from_doc(doc: dict) -> str:
"""
We need a reliable way to ensure that samples from jsonl
and arrow are the same, but there is no unique id field,
so derive the best possible
"""
return str(doc[get_id_key(doc)])
def get_text(doc: dict):
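A quick illustration of the key-priority order above, with made-up documents:

# get_id_key returns the first matching key; get_id_from_doc stringifies its value
assert get_id_key({"sample_id": "abc", "url": "https://example.com"}) == "sample_id"
assert get_id_key({"title": "Attention Is All You Need", "text": "..."}) == "title"
assert get_id_from_doc({"qid": 42}) == "42"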

View file

@@ -4,10 +4,10 @@ import json
import os
import shutil
import subprocess
from pydantic import BaseModel
from typing import Any, Dict
from omegaconf import OmegaConf
from pydantic import BaseModel
class StoolArgs(BaseModel):